diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index ff2364d29..9bcf11b75 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -2405,6 +2405,89 @@ def preset_set_priority( console.print("\n[dim]Lower priority = higher precedence in template resolution[/dim]") +@preset_app.command("enable") +def preset_enable( + pack_id: str = typer.Argument(help="Preset ID to enable"), +): + """Enable a disabled preset.""" + from .presets import PresetManager + + project_root = Path.cwd() + + # Check if we're in a spec-kit project + specify_dir = project_root / ".specify" + if not specify_dir.exists(): + console.print("[red]Error:[/red] Not a spec-kit project (no .specify/ directory)") + console.print("Run this command from a spec-kit project root") + raise typer.Exit(1) + + manager = PresetManager(project_root) + + # Check if preset is installed + if not manager.registry.is_installed(pack_id): + console.print(f"[red]Error:[/red] Preset '{pack_id}' is not installed") + raise typer.Exit(1) + + # Get current metadata + metadata = manager.registry.get(pack_id) + if metadata is None or not isinstance(metadata, dict): + console.print(f"[red]Error:[/red] Preset '{pack_id}' not found in registry (corrupted state)") + raise typer.Exit(1) + + if metadata.get("enabled", True): + console.print(f"[yellow]Preset '{pack_id}' is already enabled[/yellow]") + raise typer.Exit(0) + + # Enable the preset + manager.registry.update(pack_id, {"enabled": True}) + + console.print(f"[green]✓[/green] Preset '{pack_id}' enabled") + console.print("\nTemplates from this preset will now be included in resolution.") + console.print("[dim]Note: Previously registered commands/skills remain active.[/dim]") + + +@preset_app.command("disable") +def preset_disable( + pack_id: str = typer.Argument(help="Preset ID to disable"), +): + """Disable a preset without removing it.""" + from .presets import PresetManager + + project_root = Path.cwd() + + # Check if we're in a spec-kit project + specify_dir = project_root / ".specify" + if not specify_dir.exists(): + console.print("[red]Error:[/red] Not a spec-kit project (no .specify/ directory)") + console.print("Run this command from a spec-kit project root") + raise typer.Exit(1) + + manager = PresetManager(project_root) + + # Check if preset is installed + if not manager.registry.is_installed(pack_id): + console.print(f"[red]Error:[/red] Preset '{pack_id}' is not installed") + raise typer.Exit(1) + + # Get current metadata + metadata = manager.registry.get(pack_id) + if metadata is None or not isinstance(metadata, dict): + console.print(f"[red]Error:[/red] Preset '{pack_id}' not found in registry (corrupted state)") + raise typer.Exit(1) + + if not metadata.get("enabled", True): + console.print(f"[yellow]Preset '{pack_id}' is already disabled[/yellow]") + raise typer.Exit(0) + + # Disable the preset + manager.registry.update(pack_id, {"enabled": False}) + + console.print(f"[green]✓[/green] Preset '{pack_id}' disabled") + console.print("\nTemplates from this preset will be skipped during resolution.") + console.print("[dim]Note: Previously registered commands/skills remain active until preset removal.[/dim]") + console.print(f"To re-enable: specify preset enable {pack_id}") + + # ===== Preset Catalog Commands ===== @@ -3841,8 +3924,7 @@ def extension_enable( console.print(f"[yellow]Extension '{display_name}' is already enabled[/yellow]") raise typer.Exit(0) - metadata["enabled"] = True - manager.registry.update(extension_id, metadata) + manager.registry.update(extension_id, {"enabled": True}) # Enable hooks in extensions.yml config = hook_executor.get_project_config() @@ -3889,8 +3971,7 @@ def extension_disable( console.print(f"[yellow]Extension '{display_name}' is already disabled[/yellow]") raise typer.Exit(0) - metadata["enabled"] = False - manager.registry.update(extension_id, metadata) + manager.registry.update(extension_id, {"enabled": False}) # Disable hooks in extensions.yml config = hook_executor.get_project_config() diff --git a/src/specify_cli/extensions.py b/src/specify_cli/extensions.py index 984ca83d6..0dca39a0c 100644 --- a/src/specify_cli/extensions.py +++ b/src/specify_cli/extensions.py @@ -222,7 +222,17 @@ def _load(self) -> dict: try: with open(self.registry_path, 'r') as f: - return json.load(f) + data = json.load(f) + # Validate loaded data is a dict (handles corrupted registry files) + if not isinstance(data, dict): + return { + "schema_version": self.SCHEMA_VERSION, + "extensions": {} + } + # Normalize extensions field (handles corrupted extensions value) + if not isinstance(data.get("extensions"), dict): + data["extensions"] = {} + return data except (json.JSONDecodeError, FileNotFoundError): # Corrupted or missing registry, start fresh return { @@ -244,7 +254,7 @@ def add(self, extension_id: str, metadata: dict): metadata: Extension metadata (version, source, etc.) """ self.data["extensions"][extension_id] = { - **metadata, + **copy.deepcopy(metadata), "installed_at": datetime.now(timezone.utc).isoformat() } self._save() @@ -267,15 +277,16 @@ def update(self, extension_id: str, metadata: dict): Raises: KeyError: If extension is not installed """ - if extension_id not in self.data["extensions"]: + extensions = self.data.get("extensions") + if not isinstance(extensions, dict) or extension_id not in extensions: raise KeyError(f"Extension '{extension_id}' is not installed") # Merge new metadata with existing, preserving original installed_at - existing = self.data["extensions"][extension_id] + existing = extensions[extension_id] # Handle corrupted registry entries (e.g., string/list instead of dict) if not isinstance(existing, dict): existing = {} - # Merge: existing fields preserved, new fields override - merged = {**existing, **metadata} + # Merge: existing fields preserved, new fields override (deep copy to prevent caller mutation) + merged = {**existing, **copy.deepcopy(metadata)} # Always preserve original installed_at based on key existence, not truthiness, # to handle cases where the field exists but may be falsy (legacy/corruption) if "installed_at" in existing: @@ -283,7 +294,7 @@ def update(self, extension_id: str, metadata: dict): else: # If not present in existing, explicitly remove from merged if caller provided it merged.pop("installed_at", None) - self.data["extensions"][extension_id] = merged + extensions[extension_id] = merged self._save() def restore(self, extension_id: str, metadata: dict): @@ -296,8 +307,16 @@ def restore(self, extension_id: str, metadata: dict): Args: extension_id: Extension ID metadata: Complete extension metadata including installed_at + + Raises: + ValueError: If metadata is None or not a dict """ - self.data["extensions"][extension_id] = dict(metadata) + if metadata is None or not isinstance(metadata, dict): + raise ValueError(f"Cannot restore '{extension_id}': metadata must be a dict") + # Ensure extensions dict exists (handle corrupted registry) + if not isinstance(self.data.get("extensions"), dict): + self.data["extensions"] = {} + self.data["extensions"][extension_id] = copy.deepcopy(metadata) self._save() def remove(self, extension_id: str): @@ -306,8 +325,11 @@ def remove(self, extension_id: str): Args: extension_id: Extension ID """ - if extension_id in self.data["extensions"]: - del self.data["extensions"][extension_id] + extensions = self.data.get("extensions") + if not isinstance(extensions, dict): + return + if extension_id in extensions: + del extensions[extension_id] self._save() def get(self, extension_id: str) -> Optional[dict]: @@ -320,21 +342,49 @@ def get(self, extension_id: str) -> Optional[dict]: extension_id: Extension ID Returns: - Deep copy of extension metadata, or None if not found + Deep copy of extension metadata, or None if not found or corrupted """ - entry = self.data["extensions"].get(extension_id) - return copy.deepcopy(entry) if entry is not None else None + extensions = self.data.get("extensions") + if not isinstance(extensions, dict): + return None + entry = extensions.get(extension_id) + # Return None for missing or corrupted (non-dict) entries + if entry is None or not isinstance(entry, dict): + return None + return copy.deepcopy(entry) def list(self) -> Dict[str, dict]: - """Get all installed extensions. + """Get all installed extensions with valid metadata. + + Returns a deep copy of extensions with dict metadata only. + Corrupted entries (non-dict values) are filtered out. + + Returns: + Dictionary of extension_id -> metadata (deep copies), empty dict if corrupted + """ + extensions = self.data.get("extensions", {}) or {} + if not isinstance(extensions, dict): + return {} + # Filter to only valid dict entries to match type contract + return { + ext_id: copy.deepcopy(meta) + for ext_id, meta in extensions.items() + if isinstance(meta, dict) + } + + def keys(self) -> set: + """Get all extension IDs including corrupted entries. - Returns a deep copy of the extensions mapping to prevent callers - from accidentally mutating nested internal registry state. + Lightweight method that returns IDs without deep-copying metadata. + Use this when you only need to check which extensions are tracked. Returns: - Dictionary of extension_id -> metadata (deep copies) + Set of extension IDs (includes corrupted entries) """ - return copy.deepcopy(self.data["extensions"]) + extensions = self.data.get("extensions", {}) or {} + if not isinstance(extensions, dict): + return set() + return set(extensions.keys()) def is_installed(self, extension_id: str) -> bool: """Check if extension is installed. @@ -343,17 +393,23 @@ def is_installed(self, extension_id: str) -> bool: extension_id: Extension ID Returns: - True if extension is installed + True if extension is installed, False if not or registry corrupted """ - return extension_id in self.data["extensions"] + extensions = self.data.get("extensions") + if not isinstance(extensions, dict): + return False + return extension_id in extensions - def list_by_priority(self) -> List[tuple]: + def list_by_priority(self, include_disabled: bool = False) -> List[tuple]: """Get all installed extensions sorted by priority. Lower priority number = higher precedence (checked first). Extensions with equal priority are sorted alphabetically by ID for deterministic ordering. + Args: + include_disabled: If True, include disabled extensions. Default False. + Returns: List of (extension_id, metadata_copy) tuples sorted by priority. Metadata is deep-copied to prevent accidental mutation. @@ -365,6 +421,9 @@ def list_by_priority(self) -> List[tuple]: for ext_id, meta in extensions.items(): if not isinstance(meta, dict): continue + # Skip disabled extensions unless explicitly requested + if not include_disabled and not meta.get("enabled", True): + continue metadata_copy = copy.deepcopy(meta) metadata_copy["priority"] = normalize_priority(metadata_copy.get("priority", 10)) sortable_extensions.append((ext_id, metadata_copy)) @@ -633,7 +692,7 @@ def remove(self, extension_id: str, keep_config: bool = False) -> bool: # Get registered commands before removal metadata = self.registry.get(extension_id) - registered_commands = metadata.get("registered_commands", {}) + registered_commands = metadata.get("registered_commands", {}) if metadata else {} extension_dir = self.extensions_dir / extension_id diff --git a/src/specify_cli/presets.py b/src/specify_cli/presets.py index 121d59617..aaa6e52e5 100644 --- a/src/specify_cli/presets.py +++ b/src/specify_cli/presets.py @@ -238,7 +238,17 @@ def _load(self) -> dict: try: with open(self.registry_path, 'r') as f: - return json.load(f) + data = json.load(f) + # Validate loaded data is a dict (handles corrupted registry files) + if not isinstance(data, dict): + return { + "schema_version": self.SCHEMA_VERSION, + "presets": {} + } + # Normalize presets field (handles corrupted presets value) + if not isinstance(data.get("presets"), dict): + data["presets"] = {} + return data except (json.JSONDecodeError, FileNotFoundError): return { "schema_version": self.SCHEMA_VERSION, @@ -259,7 +269,7 @@ def add(self, pack_id: str, metadata: dict): metadata: Pack metadata (version, source, etc.) """ self.data["presets"][pack_id] = { - **metadata, + **copy.deepcopy(metadata), "installed_at": datetime.now(timezone.utc).isoformat() } self._save() @@ -270,8 +280,11 @@ def remove(self, pack_id: str): Args: pack_id: Preset ID """ - if pack_id in self.data["presets"]: - del self.data["presets"][pack_id] + packs = self.data.get("presets") + if not isinstance(packs, dict): + return + if pack_id in packs: + del packs[pack_id] self._save() def update(self, pack_id: str, updates: dict): @@ -288,14 +301,15 @@ def update(self, pack_id: str, updates: dict): Raises: KeyError: If preset is not installed """ - if pack_id not in self.data["presets"]: + packs = self.data.get("presets") + if not isinstance(packs, dict) or pack_id not in packs: raise KeyError(f"Preset '{pack_id}' not found in registry") - existing = self.data["presets"][pack_id] + existing = packs[pack_id] # Handle corrupted registry entries (e.g., string/list instead of dict) if not isinstance(existing, dict): existing = {} - # Merge: existing fields preserved, new fields override - merged = {**existing, **updates} + # Merge: existing fields preserved, new fields override (deep copy to prevent caller mutation) + merged = {**existing, **copy.deepcopy(updates)} # Always preserve original installed_at based on key existence, not truthiness, # to handle cases where the field exists but may be falsy (legacy/corruption) if "installed_at" in existing: @@ -303,35 +317,95 @@ def update(self, pack_id: str, updates: dict): else: # If not present in existing, explicitly remove from merged if caller provided it merged.pop("installed_at", None) - self.data["presets"][pack_id] = merged + packs[pack_id] = merged + self._save() + + def restore(self, pack_id: str, metadata: dict): + """Restore preset metadata to registry without modifying timestamps. + + Use this method for rollback scenarios where you have a complete backup + of the registry entry (including installed_at) and want to restore it + exactly as it was. + + Args: + pack_id: Preset ID + metadata: Complete preset metadata including installed_at + + Raises: + ValueError: If metadata is None or not a dict + """ + if metadata is None or not isinstance(metadata, dict): + raise ValueError(f"Cannot restore '{pack_id}': metadata must be a dict") + # Ensure presets dict exists (handle corrupted registry) + if not isinstance(self.data.get("presets"), dict): + self.data["presets"] = {} + self.data["presets"][pack_id] = copy.deepcopy(metadata) self._save() def get(self, pack_id: str) -> Optional[dict]: """Get preset metadata from registry. + Returns a deep copy to prevent callers from accidentally mutating + nested internal registry state without going through the write path. + Args: pack_id: Preset ID Returns: - Pack metadata or None if not found + Deep copy of preset metadata, or None if not found or corrupted """ - return self.data["presets"].get(pack_id) + packs = self.data.get("presets") + if not isinstance(packs, dict): + return None + entry = packs.get(pack_id) + # Return None for missing or corrupted (non-dict) entries + if entry is None or not isinstance(entry, dict): + return None + return copy.deepcopy(entry) def list(self) -> Dict[str, dict]: - """Get all installed presets. + """Get all installed presets with valid metadata. + + Returns a deep copy of presets with dict metadata only. + Corrupted entries (non-dict values) are filtered out. Returns: - Dictionary of pack_id -> metadata + Dictionary of pack_id -> metadata (deep copies), empty dict if corrupted """ - return self.data["presets"] + packs = self.data.get("presets", {}) or {} + if not isinstance(packs, dict): + return {} + # Filter to only valid dict entries to match type contract + return { + pack_id: copy.deepcopy(meta) + for pack_id, meta in packs.items() + if isinstance(meta, dict) + } + + def keys(self) -> set: + """Get all preset IDs including corrupted entries. + + Lightweight method that returns IDs without deep-copying metadata. + Use this when you only need to check which presets are tracked. + + Returns: + Set of preset IDs (includes corrupted entries) + """ + packs = self.data.get("presets", {}) or {} + if not isinstance(packs, dict): + return set() + return set(packs.keys()) - def list_by_priority(self) -> List[tuple]: + def list_by_priority(self, include_disabled: bool = False) -> List[tuple]: """Get all installed presets sorted by priority. Lower priority number = higher precedence (checked first). Presets with equal priority are sorted alphabetically by ID for deterministic ordering. + Args: + include_disabled: If True, include disabled presets. Default False. + Returns: List of (pack_id, metadata_copy) tuples sorted by priority. Metadata is deep-copied to prevent accidental mutation. @@ -343,6 +417,9 @@ def list_by_priority(self) -> List[tuple]: for pack_id, meta in packs.items(): if not isinstance(meta, dict): continue + # Skip disabled presets unless explicitly requested + if not include_disabled and not meta.get("enabled", True): + continue metadata_copy = copy.deepcopy(meta) metadata_copy["priority"] = normalize_priority(metadata_copy.get("priority", 10)) sortable_packs.append((pack_id, metadata_copy)) @@ -358,9 +435,12 @@ def is_installed(self, pack_id: str) -> bool: pack_id: Preset ID Returns: - True if pack is installed + True if pack is installed, False if not or registry corrupted """ - return pack_id in self.data["presets"] + packs = self.data.get("presets") + if not isinstance(packs, dict): + return False + return pack_id in packs class PresetManager: @@ -1466,12 +1546,20 @@ def _get_all_extensions_by_priority(self) -> list[tuple[int, str, dict | None]]: return [] registry = ExtensionRegistry(self.extensions_dir) - registered_extensions = registry.list_by_priority() - registered_extension_ids = {ext_id for ext_id, _ in registered_extensions} + # Use keys() to track ALL extensions (including corrupted entries) without deep copy + # This prevents corrupted entries from being picked up as "unregistered" dirs + registered_extension_ids = registry.keys() + + # Get all registered extensions including disabled; we filter disabled manually below + all_registered = registry.list_by_priority(include_disabled=True) all_extensions: list[tuple[int, str, dict | None]] = [] - for ext_id, metadata in registered_extensions: + # Only include enabled extensions in the result + for ext_id, metadata in all_registered: + # Skip disabled extensions + if not metadata.get("enabled", True): + continue priority = normalize_priority(metadata.get("priority") if metadata else None) all_extensions.append((priority, ext_id, metadata)) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index c87ba5b53..d99295ea8 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -420,6 +420,48 @@ def test_restore_can_recreate_removed_entry(self, temp_dir): assert registry.is_installed("test-ext") assert registry.get("test-ext")["version"] == "1.0.0" + def test_restore_rejects_none_metadata(self, temp_dir): + """Test restore() raises ValueError for None metadata.""" + extensions_dir = temp_dir / "extensions" + extensions_dir.mkdir() + registry = ExtensionRegistry(extensions_dir) + + with pytest.raises(ValueError, match="metadata must be a dict"): + registry.restore("test-ext", None) + + def test_restore_rejects_non_dict_metadata(self, temp_dir): + """Test restore() raises ValueError for non-dict metadata.""" + extensions_dir = temp_dir / "extensions" + extensions_dir.mkdir() + registry = ExtensionRegistry(extensions_dir) + + with pytest.raises(ValueError, match="metadata must be a dict"): + registry.restore("test-ext", "not-a-dict") + + with pytest.raises(ValueError, match="metadata must be a dict"): + registry.restore("test-ext", ["list", "not", "dict"]) + + def test_restore_uses_deep_copy(self, temp_dir): + """Test restore() deep copies metadata to prevent mutation.""" + extensions_dir = temp_dir / "extensions" + extensions_dir.mkdir() + registry = ExtensionRegistry(extensions_dir) + + original_metadata = { + "version": "1.0.0", + "nested": {"key": "original"}, + } + registry.restore("test-ext", original_metadata) + + # Mutate the original metadata after restore + original_metadata["version"] = "MUTATED" + original_metadata["nested"]["key"] = "MUTATED" + + # Registry should have the original values + stored = registry.get("test-ext") + assert stored["version"] == "1.0.0" + assert stored["nested"]["key"] == "original" + def test_get_returns_deep_copy(self, temp_dir): """Test that get() returns deep copies for nested structures.""" extensions_dir = temp_dir / "extensions" @@ -439,6 +481,26 @@ def test_get_returns_deep_copy(self, temp_dir): internal = registry.data["extensions"]["test-ext"] assert internal["registered_commands"] == {"claude": ["cmd1"]} + def test_get_returns_none_for_corrupted_entry(self, temp_dir): + """Test that get() returns None for corrupted (non-dict) entries.""" + extensions_dir = temp_dir / "extensions" + extensions_dir.mkdir() + + registry = ExtensionRegistry(extensions_dir) + + # Directly corrupt the registry with non-dict entries + registry.data["extensions"]["corrupted-string"] = "not a dict" + registry.data["extensions"]["corrupted-list"] = ["not", "a", "dict"] + registry.data["extensions"]["corrupted-int"] = 42 + registry._save() + + # All corrupted entries should return None + assert registry.get("corrupted-string") is None + assert registry.get("corrupted-list") is None + assert registry.get("corrupted-int") is None + # Non-existent should also return None + assert registry.get("nonexistent") is None + def test_list_returns_deep_copy(self, temp_dir): """Test that list() returns deep copies for nested structures.""" extensions_dir = temp_dir / "extensions" @@ -458,6 +520,20 @@ def test_list_returns_deep_copy(self, temp_dir): internal = registry.data["extensions"]["test-ext"] assert internal["registered_commands"] == {"claude": ["cmd1"]} + def test_list_returns_empty_dict_for_corrupted_registry(self, temp_dir): + """Test that list() returns empty dict when extensions is not a dict.""" + extensions_dir = temp_dir / "extensions" + extensions_dir.mkdir() + registry = ExtensionRegistry(extensions_dir) + + # Corrupt the registry - extensions is a list instead of dict + registry.data["extensions"] = ["not", "a", "dict"] + registry._save() + + # list() should return empty dict, not crash + result = registry.list() + assert result == {} + # ===== ExtensionManager Tests ===== @@ -2500,6 +2576,40 @@ def test_list_by_priority_invalid_priority_defaults(self, temp_dir): assert [item[0] for item in result] == ["ext-high", "ext-invalid"] assert result[1][1]["priority"] == 10 + def test_list_by_priority_excludes_disabled(self, temp_dir): + """Test that list_by_priority excludes disabled extensions by default.""" + extensions_dir = temp_dir / "extensions" + extensions_dir.mkdir() + + registry = ExtensionRegistry(extensions_dir) + registry.add("ext-enabled", {"version": "1.0.0", "enabled": True, "priority": 5}) + registry.add("ext-disabled", {"version": "1.0.0", "enabled": False, "priority": 1}) + registry.add("ext-default", {"version": "1.0.0", "priority": 10}) # no enabled field = True + + # Default: exclude disabled + by_priority = registry.list_by_priority() + ext_ids = [p[0] for p in by_priority] + assert "ext-enabled" in ext_ids + assert "ext-default" in ext_ids + assert "ext-disabled" not in ext_ids + + def test_list_by_priority_includes_disabled_when_requested(self, temp_dir): + """Test that list_by_priority includes disabled extensions when requested.""" + extensions_dir = temp_dir / "extensions" + extensions_dir.mkdir() + + registry = ExtensionRegistry(extensions_dir) + registry.add("ext-enabled", {"version": "1.0.0", "enabled": True, "priority": 5}) + registry.add("ext-disabled", {"version": "1.0.0", "enabled": False, "priority": 1}) + + # Include disabled + by_priority = registry.list_by_priority(include_disabled=True) + ext_ids = [p[0] for p in by_priority] + assert "ext-enabled" in ext_ids + assert "ext-disabled" in ext_ids + # Disabled ext has lower priority number, so it comes first when included + assert ext_ids[0] == "ext-disabled" + def test_install_with_priority(self, extension_dir, project_dir): """Test that install_from_directory stores priority.""" manager = ExtensionManager(project_dir) @@ -2541,8 +2651,8 @@ def test_priority_preserved_on_update(self, temp_dir): assert updated["priority"] == 5 # Preserved assert updated["enabled"] is False # Updated - def test_resolve_uses_unregistered_extension_dirs_when_registry_partially_corrupted(self, project_dir): - """Resolution scans unregistered extension dirs after valid registry entries.""" + def test_corrupted_extension_entry_not_picked_up_as_unregistered(self, project_dir): + """Corrupted registry entries are still tracked and NOT picked up as unregistered.""" extensions_dir = project_dir / ".specify" / "extensions" valid_dir = extensions_dir / "valid-ext" / "templates" @@ -2555,20 +2665,21 @@ def test_resolve_uses_unregistered_extension_dirs_when_registry_partially_corrup registry = ExtensionRegistry(extensions_dir) registry.add("valid-ext", {"version": "1.0.0", "priority": 10}) + # Corrupt the entry - should still be tracked, not picked up as unregistered registry.data["extensions"]["broken-ext"] = "corrupted" registry._save() from specify_cli.presets import PresetResolver resolver = PresetResolver(project_dir) + # Corrupted extension templates should NOT be resolved resolved = resolver.resolve("target-template") - sourced = resolver.resolve_with_source("target-template") + assert resolved is None - assert resolved is not None - assert resolved.name == "target-template.md" - assert "Broken Target" in resolved.read_text() - assert sourced is not None - assert sourced["source"] == "extension:broken-ext (unregistered)" + # Valid extension template should still resolve + valid_resolved = resolver.resolve("other-template") + assert valid_resolved is not None + assert "Valid" in valid_resolved.read_text() class TestExtensionPriorityCLI: diff --git a/tests/test_presets.py b/tests/test_presets.py index b6fe81d5b..2716b73dc 100644 --- a/tests/test_presets.py +++ b/tests/test_presets.py @@ -369,6 +369,172 @@ def test_get_nonexistent(self, temp_dir): registry = PresetRegistry(packs_dir) assert registry.get("nonexistent") is None + def test_restore(self, temp_dir): + """Test restore() preserves timestamps exactly.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + # Create original entry with a specific timestamp + original_metadata = { + "version": "1.0.0", + "source": "local", + "installed_at": "2025-01-15T10:30:00+00:00", + "enabled": True, + } + registry.restore("test-pack", original_metadata) + + # Verify exact restoration + restored = registry.get("test-pack") + assert restored["installed_at"] == "2025-01-15T10:30:00+00:00" + assert restored["version"] == "1.0.0" + assert restored["enabled"] is True + + def test_restore_rejects_none_metadata(self, temp_dir): + """Test restore() raises ValueError for None metadata.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + with pytest.raises(ValueError, match="metadata must be a dict"): + registry.restore("test-pack", None) + + def test_restore_rejects_non_dict_metadata(self, temp_dir): + """Test restore() raises ValueError for non-dict metadata.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + with pytest.raises(ValueError, match="metadata must be a dict"): + registry.restore("test-pack", "not-a-dict") + + with pytest.raises(ValueError, match="metadata must be a dict"): + registry.restore("test-pack", ["list", "not", "dict"]) + + def test_restore_uses_deep_copy(self, temp_dir): + """Test restore() deep copies metadata to prevent mutation.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + original_metadata = { + "version": "1.0.0", + "nested": {"key": "original"}, + } + registry.restore("test-pack", original_metadata) + + # Mutate the original metadata after restore + original_metadata["version"] = "MUTATED" + original_metadata["nested"]["key"] = "MUTATED" + + # Registry should have the original values + stored = registry.get("test-pack") + assert stored["version"] == "1.0.0" + assert stored["nested"]["key"] == "original" + + def test_get_returns_deep_copy(self, temp_dir): + """Test that get() returns a deep copy to prevent mutation.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + registry.add("test-pack", {"version": "1.0.0", "nested": {"key": "original"}}) + + # Get and mutate the returned copy + metadata = registry.get("test-pack") + metadata["version"] = "MUTATED" + metadata["nested"]["key"] = "MUTATED" + + # Original should be unchanged + fresh = registry.get("test-pack") + assert fresh["version"] == "1.0.0" + assert fresh["nested"]["key"] == "original" + + def test_get_returns_none_for_corrupted_entry(self, temp_dir): + """Test that get() returns None for corrupted (non-dict) entries.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + # Directly corrupt the registry with non-dict entries + registry.data["presets"]["corrupted-string"] = "not a dict" + registry.data["presets"]["corrupted-list"] = ["not", "a", "dict"] + registry.data["presets"]["corrupted-int"] = 42 + registry._save() + + # All corrupted entries should return None + assert registry.get("corrupted-string") is None + assert registry.get("corrupted-list") is None + assert registry.get("corrupted-int") is None + # Non-existent should also return None + assert registry.get("nonexistent") is None + + def test_list_returns_deep_copy(self, temp_dir): + """Test that list() returns deep copies to prevent mutation.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + registry.add("test-pack", {"version": "1.0.0", "nested": {"key": "original"}}) + + # Get list and mutate + all_packs = registry.list() + all_packs["test-pack"]["version"] = "MUTATED" + all_packs["test-pack"]["nested"]["key"] = "MUTATED" + + # Original should be unchanged + fresh = registry.get("test-pack") + assert fresh["version"] == "1.0.0" + assert fresh["nested"]["key"] == "original" + + def test_list_returns_empty_dict_for_corrupted_registry(self, temp_dir): + """Test that list() returns empty dict when presets is not a dict.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + # Corrupt the registry - presets is a list instead of dict + registry.data["presets"] = ["not", "a", "dict"] + registry._save() + + # list() should return empty dict, not crash + result = registry.list() + assert result == {} + + def test_list_by_priority_excludes_disabled(self, temp_dir): + """Test that list_by_priority excludes disabled presets by default.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + registry.add("pack-enabled", {"version": "1.0.0", "enabled": True, "priority": 5}) + registry.add("pack-disabled", {"version": "1.0.0", "enabled": False, "priority": 1}) + registry.add("pack-default", {"version": "1.0.0", "priority": 10}) # no enabled field = True + + # Default: exclude disabled + by_priority = registry.list_by_priority() + pack_ids = [p[0] for p in by_priority] + assert "pack-enabled" in pack_ids + assert "pack-default" in pack_ids + assert "pack-disabled" not in pack_ids + + def test_list_by_priority_includes_disabled_when_requested(self, temp_dir): + """Test that list_by_priority includes disabled presets when requested.""" + packs_dir = temp_dir / "packs" + packs_dir.mkdir() + registry = PresetRegistry(packs_dir) + + registry.add("pack-enabled", {"version": "1.0.0", "enabled": True, "priority": 5}) + registry.add("pack-disabled", {"version": "1.0.0", "enabled": False, "priority": 1}) + + # Include disabled + by_priority = registry.list_by_priority(include_disabled=True) + pack_ids = [p[0] for p in by_priority] + assert "pack-enabled" in pack_ids + assert "pack-disabled" in pack_ids + # Disabled pack has lower priority number, so it comes first when included + assert pack_ids[0] == "pack-disabled" + # ===== PresetManager Tests ===== @@ -707,6 +873,44 @@ def test_resolve_extension_provided_templates(self, project_dir): assert result is not None assert "Extension Custom Template" in result.read_text() + def test_resolve_disabled_extension_templates_skipped(self, project_dir): + """Test that disabled extension templates are not resolved.""" + # Create extension with templates + ext_dir = project_dir / ".specify" / "extensions" / "disabled-ext" + ext_templates_dir = ext_dir / "templates" + ext_templates_dir.mkdir(parents=True) + ext_template = ext_templates_dir / "disabled-template.md" + ext_template.write_text("# Disabled Extension Template\n") + + # Register extension as disabled + extensions_dir = project_dir / ".specify" / "extensions" + ext_registry = ExtensionRegistry(extensions_dir) + ext_registry.add("disabled-ext", {"version": "1.0.0", "priority": 1, "enabled": False}) + + # Template should NOT be resolved because extension is disabled + resolver = PresetResolver(project_dir) + result = resolver.resolve("disabled-template") + assert result is None, "Disabled extension template should not be resolved" + + def test_resolve_disabled_extension_not_picked_up_as_unregistered(self, project_dir): + """Test that disabled extensions are not picked up via unregistered dir scan.""" + # Create extension directory with templates + ext_dir = project_dir / ".specify" / "extensions" / "test-disabled-ext" + ext_templates_dir = ext_dir / "templates" + ext_templates_dir.mkdir(parents=True) + ext_template = ext_templates_dir / "unique-disabled-template.md" + ext_template.write_text("# Should Not Resolve\n") + + # Register the extension but disable it + extensions_dir = project_dir / ".specify" / "extensions" + ext_registry = ExtensionRegistry(extensions_dir) + ext_registry.add("test-disabled-ext", {"version": "1.0.0", "enabled": False}) + + # Verify the template is NOT resolved (even though the directory exists) + resolver = PresetResolver(project_dir) + result = resolver.resolve("unique-disabled-template") + assert result is None, "Disabled extension should not be picked up as unregistered" + def test_resolve_pack_over_extension(self, project_dir, pack_dir, temp_dir, valid_pack_data): """Test that pack templates take priority over extension templates.""" # Create extension with templates @@ -2001,3 +2205,189 @@ def test_mixed_legacy_and_new_presets_ordering(self, temp_dir): "legacy-pack", "low-priority-pack", ] + + +class TestPresetEnableDisable: + """Test preset enable/disable CLI commands.""" + + def test_disable_preset(self, project_dir, pack_dir): + """Test disable command sets enabled=False.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + # Install preset + manager = PresetManager(project_dir) + manager.install_from_directory(pack_dir, "0.1.5") + + # Verify initially enabled + assert manager.registry.get("test-pack").get("enabled", True) is True + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "disable", "test-pack"]) + + assert result.exit_code == 0, result.output + assert "disabled" in result.output.lower() + + # Reload registry to see updated value + manager2 = PresetManager(project_dir) + assert manager2.registry.get("test-pack")["enabled"] is False + + def test_enable_preset(self, project_dir, pack_dir): + """Test enable command sets enabled=True.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + # Install preset and disable it + manager = PresetManager(project_dir) + manager.install_from_directory(pack_dir, "0.1.5") + manager.registry.update("test-pack", {"enabled": False}) + + # Verify disabled + assert manager.registry.get("test-pack")["enabled"] is False + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "enable", "test-pack"]) + + assert result.exit_code == 0, result.output + assert "enabled" in result.output.lower() + + # Reload registry to see updated value + manager2 = PresetManager(project_dir) + assert manager2.registry.get("test-pack")["enabled"] is True + + def test_disable_already_disabled(self, project_dir, pack_dir): + """Test disable on already disabled preset shows warning.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + # Install preset and disable it + manager = PresetManager(project_dir) + manager.install_from_directory(pack_dir, "0.1.5") + manager.registry.update("test-pack", {"enabled": False}) + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "disable", "test-pack"]) + + assert result.exit_code == 0, result.output + assert "already disabled" in result.output.lower() + + def test_enable_already_enabled(self, project_dir, pack_dir): + """Test enable on already enabled preset shows warning.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + # Install preset (enabled by default) + manager = PresetManager(project_dir) + manager.install_from_directory(pack_dir, "0.1.5") + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "enable", "test-pack"]) + + assert result.exit_code == 0, result.output + assert "already enabled" in result.output.lower() + + def test_disable_not_installed(self, project_dir): + """Test disable fails for non-installed preset.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "disable", "nonexistent"]) + + assert result.exit_code == 1, result.output + assert "not installed" in result.output.lower() + + def test_enable_not_installed(self, project_dir): + """Test enable fails for non-installed preset.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "enable", "nonexistent"]) + + assert result.exit_code == 1, result.output + assert "not installed" in result.output.lower() + + def test_disabled_preset_excluded_from_resolution(self, project_dir, pack_dir): + """Test that disabled presets are excluded from template resolution.""" + # Install preset with a template + manager = PresetManager(project_dir) + manager.install_from_directory(pack_dir, "0.1.5") + + # Create a template in the preset directory + preset_template = project_dir / ".specify" / "presets" / "test-pack" / "templates" / "test-template.md" + preset_template.parent.mkdir(parents=True, exist_ok=True) + preset_template.write_text("# Template from test-pack") + + resolver = PresetResolver(project_dir) + + # Template should be found when enabled + result = resolver.resolve("test-template", "template") + assert result is not None + assert "test-pack" in str(result) + + # Disable the preset + manager.registry.update("test-pack", {"enabled": False}) + + # Template should NOT be found when disabled + resolver2 = PresetResolver(project_dir) + result2 = resolver2.resolve("test-template", "template") + assert result2 is None + + def test_enable_corrupted_registry_entry(self, project_dir, pack_dir): + """Test enable fails gracefully for corrupted registry entry.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + # Install preset then corrupt the registry entry + manager = PresetManager(project_dir) + manager.install_from_directory(pack_dir, "0.1.5") + manager.registry.data["presets"]["test-pack"] = "corrupted-string" + manager.registry._save() + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "enable", "test-pack"]) + + assert result.exit_code == 1 + assert "corrupted state" in result.output.lower() + + def test_disable_corrupted_registry_entry(self, project_dir, pack_dir): + """Test disable fails gracefully for corrupted registry entry.""" + from typer.testing import CliRunner + from unittest.mock import patch + from specify_cli import app + + runner = CliRunner() + + # Install preset then corrupt the registry entry + manager = PresetManager(project_dir) + manager.install_from_directory(pack_dir, "0.1.5") + manager.registry.data["presets"]["test-pack"] = "corrupted-string" + manager.registry._save() + + with patch.object(Path, "cwd", return_value=project_dir): + result = runner.invoke(app, ["preset", "disable", "test-pack"]) + + assert result.exit_code == 1 + assert "corrupted state" in result.output.lower()