Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 85 additions & 4 deletions src/specify_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2405,6 +2405,89 @@ def preset_set_priority(
console.print("\n[dim]Lower priority = higher precedence in template resolution[/dim]")


@preset_app.command("enable")
def preset_enable(
pack_id: str = typer.Argument(help="Preset ID to enable"),
):
"""Enable a disabled preset."""
from .presets import PresetManager

project_root = Path.cwd()

# Check if we're in a spec-kit project
specify_dir = project_root / ".specify"
if not specify_dir.exists():
console.print("[red]Error:[/red] Not a spec-kit project (no .specify/ directory)")
console.print("Run this command from a spec-kit project root")
raise typer.Exit(1)

manager = PresetManager(project_root)

# Check if preset is installed
if not manager.registry.is_installed(pack_id):
console.print(f"[red]Error:[/red] Preset '{pack_id}' is not installed")
raise typer.Exit(1)

# Get current metadata
metadata = manager.registry.get(pack_id)
if metadata is None or not isinstance(metadata, dict):
console.print(f"[red]Error:[/red] Preset '{pack_id}' not found in registry (corrupted state)")
raise typer.Exit(1)

if metadata.get("enabled", True):
console.print(f"[yellow]Preset '{pack_id}' is already enabled[/yellow]")
raise typer.Exit(0)

# Enable the preset
manager.registry.update(pack_id, {"enabled": True})

console.print(f"[green]✓[/green] Preset '{pack_id}' enabled")
console.print("\nTemplates from this preset will now be included in resolution.")
console.print("[dim]Note: Previously registered commands/skills remain active.[/dim]")


@preset_app.command("disable")
def preset_disable(
pack_id: str = typer.Argument(help="Preset ID to disable"),
):
"""Disable a preset without removing it."""
from .presets import PresetManager

project_root = Path.cwd()

# Check if we're in a spec-kit project
specify_dir = project_root / ".specify"
if not specify_dir.exists():
console.print("[red]Error:[/red] Not a spec-kit project (no .specify/ directory)")
console.print("Run this command from a spec-kit project root")
raise typer.Exit(1)

manager = PresetManager(project_root)

# Check if preset is installed
if not manager.registry.is_installed(pack_id):
console.print(f"[red]Error:[/red] Preset '{pack_id}' is not installed")
raise typer.Exit(1)

# Get current metadata
metadata = manager.registry.get(pack_id)
if metadata is None or not isinstance(metadata, dict):
console.print(f"[red]Error:[/red] Preset '{pack_id}' not found in registry (corrupted state)")
raise typer.Exit(1)

if not metadata.get("enabled", True):
console.print(f"[yellow]Preset '{pack_id}' is already disabled[/yellow]")
raise typer.Exit(0)

# Disable the preset
manager.registry.update(pack_id, {"enabled": False})

console.print(f"[green]✓[/green] Preset '{pack_id}' disabled")
console.print("\nTemplates from this preset will be skipped during resolution.")
console.print("[dim]Note: Previously registered commands/skills remain active until preset removal.[/dim]")
console.print(f"To re-enable: specify preset enable {pack_id}")


# ===== Preset Catalog Commands =====


Expand Down Expand Up @@ -3841,8 +3924,7 @@ def extension_enable(
console.print(f"[yellow]Extension '{display_name}' is already enabled[/yellow]")
raise typer.Exit(0)

metadata["enabled"] = True
manager.registry.update(extension_id, metadata)
manager.registry.update(extension_id, {"enabled": True})

# Enable hooks in extensions.yml
config = hook_executor.get_project_config()
Expand Down Expand Up @@ -3889,8 +3971,7 @@ def extension_disable(
console.print(f"[yellow]Extension '{display_name}' is already disabled[/yellow]")
raise typer.Exit(0)

metadata["enabled"] = False
manager.registry.update(extension_id, metadata)
manager.registry.update(extension_id, {"enabled": False})

# Disable hooks in extensions.yml
config = hook_executor.get_project_config()
Expand Down
99 changes: 79 additions & 20 deletions src/specify_cli/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,17 @@ def _load(self) -> dict:

try:
with open(self.registry_path, 'r') as f:
return json.load(f)
data = json.load(f)
# Validate loaded data is a dict (handles corrupted registry files)
if not isinstance(data, dict):
return {
"schema_version": self.SCHEMA_VERSION,
"extensions": {}
}
# Normalize extensions field (handles corrupted extensions value)
if not isinstance(data.get("extensions"), dict):
data["extensions"] = {}
return data
except (json.JSONDecodeError, FileNotFoundError):
# Corrupted or missing registry, start fresh
return {
Expand All @@ -244,7 +254,7 @@ def add(self, extension_id: str, metadata: dict):
metadata: Extension metadata (version, source, etc.)
"""
self.data["extensions"][extension_id] = {
**metadata,
**copy.deepcopy(metadata),
"installed_at": datetime.now(timezone.utc).isoformat()
}
self._save()
Expand All @@ -267,10 +277,11 @@ def update(self, extension_id: str, metadata: dict):
Raises:
KeyError: If extension is not installed
"""
if extension_id not in self.data["extensions"]:
extensions = self.data.get("extensions")
if not isinstance(extensions, dict) or extension_id not in extensions:
raise KeyError(f"Extension '{extension_id}' is not installed")
# Merge new metadata with existing, preserving original installed_at
existing = self.data["extensions"][extension_id]
existing = extensions[extension_id]
# Handle corrupted registry entries (e.g., string/list instead of dict)
if not isinstance(existing, dict):
existing = {}
Expand All @@ -283,7 +294,7 @@ def update(self, extension_id: str, metadata: dict):
else:
# If not present in existing, explicitly remove from merged if caller provided it
merged.pop("installed_at", None)
self.data["extensions"][extension_id] = merged
extensions[extension_id] = merged
self._save()

def restore(self, extension_id: str, metadata: dict):
Expand All @@ -296,8 +307,16 @@ def restore(self, extension_id: str, metadata: dict):
Args:
extension_id: Extension ID
metadata: Complete extension metadata including installed_at

Raises:
ValueError: If metadata is None or not a dict
"""
self.data["extensions"][extension_id] = dict(metadata)
if metadata is None or not isinstance(metadata, dict):
raise ValueError(f"Cannot restore '{extension_id}': metadata must be a dict")
# Ensure extensions dict exists (handle corrupted registry)
if not isinstance(self.data.get("extensions"), dict):
self.data["extensions"] = {}
self.data["extensions"][extension_id] = copy.deepcopy(metadata)
self._save()

def remove(self, extension_id: str):
Expand All @@ -306,8 +325,11 @@ def remove(self, extension_id: str):
Args:
extension_id: Extension ID
"""
if extension_id in self.data["extensions"]:
del self.data["extensions"][extension_id]
extensions = self.data.get("extensions")
if not isinstance(extensions, dict):
return
if extension_id in extensions:
del extensions[extension_id]
self._save()

def get(self, extension_id: str) -> Optional[dict]:
Expand All @@ -320,21 +342,49 @@ def get(self, extension_id: str) -> Optional[dict]:
extension_id: Extension ID

Returns:
Deep copy of extension metadata, or None if not found
Deep copy of extension metadata, or None if not found or corrupted
"""
entry = self.data["extensions"].get(extension_id)
return copy.deepcopy(entry) if entry is not None else None
extensions = self.data.get("extensions")
if not isinstance(extensions, dict):
return None
entry = extensions.get(extension_id)
# Return None for missing or corrupted (non-dict) entries
if entry is None or not isinstance(entry, dict):
return None
return copy.deepcopy(entry)

def list(self) -> Dict[str, dict]:
"""Get all installed extensions.
"""Get all installed extensions with valid metadata.

Returns a deep copy of extensions with dict metadata only.
Corrupted entries (non-dict values) are filtered out.

Returns:
Dictionary of extension_id -> metadata (deep copies), empty dict if corrupted
"""
extensions = self.data.get("extensions", {}) or {}
if not isinstance(extensions, dict):
return {}
# Filter to only valid dict entries to match type contract
return {
ext_id: copy.deepcopy(meta)
for ext_id, meta in extensions.items()
if isinstance(meta, dict)
}

def keys(self) -> set:
"""Get all extension IDs including corrupted entries.

Returns a deep copy of the extensions mapping to prevent callers
from accidentally mutating nested internal registry state.
Lightweight method that returns IDs without deep-copying metadata.
Use this when you only need to check which extensions are tracked.

Returns:
Dictionary of extension_id -> metadata (deep copies)
Set of extension IDs (includes corrupted entries)
"""
return copy.deepcopy(self.data["extensions"])
extensions = self.data.get("extensions", {}) or {}
if not isinstance(extensions, dict):
return set()
return set(extensions.keys())

def is_installed(self, extension_id: str) -> bool:
"""Check if extension is installed.
Expand All @@ -343,17 +393,23 @@ def is_installed(self, extension_id: str) -> bool:
extension_id: Extension ID

Returns:
True if extension is installed
True if extension is installed, False if not or registry corrupted
"""
return extension_id in self.data["extensions"]
extensions = self.data.get("extensions")
if not isinstance(extensions, dict):
return False
return extension_id in extensions

def list_by_priority(self) -> List[tuple]:
def list_by_priority(self, include_disabled: bool = False) -> List[tuple]:
"""Get all installed extensions sorted by priority.

Lower priority number = higher precedence (checked first).
Extensions with equal priority are sorted alphabetically by ID
for deterministic ordering.

Args:
include_disabled: If True, include disabled extensions. Default False.

Returns:
List of (extension_id, metadata_copy) tuples sorted by priority.
Metadata is deep-copied to prevent accidental mutation.
Expand All @@ -365,6 +421,9 @@ def list_by_priority(self) -> List[tuple]:
for ext_id, meta in extensions.items():
if not isinstance(meta, dict):
continue
# Skip disabled extensions unless explicitly requested
if not include_disabled and not meta.get("enabled", True):
continue
metadata_copy = copy.deepcopy(meta)
metadata_copy["priority"] = normalize_priority(metadata_copy.get("priority", 10))
sortable_extensions.append((ext_id, metadata_copy))
Expand Down Expand Up @@ -633,7 +692,7 @@ def remove(self, extension_id: str, keep_config: bool = False) -> bool:

# Get registered commands before removal
metadata = self.registry.get(extension_id)
registered_commands = metadata.get("registered_commands", {})
registered_commands = metadata.get("registered_commands", {}) if metadata else {}

extension_dir = self.extensions_dir / extension_id

Expand Down
Loading
Loading