From 59d4e867c0b43603fb054a7e401d9ed0f5bd5b0c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 30 Jan 2026 23:45:05 +0000 Subject: [PATCH 01/15] tests and base classes --- .pyrit_conf_example | 70 ++++ pyrit/common/path.py | 28 +- pyrit/setup/__init__.py | 3 + pyrit/setup/configuration_loader.py | 319 +++++++++++++++ tests/unit/setup/test_configuration_loader.py | 364 ++++++++++++++++++ 5 files changed, 776 insertions(+), 8 deletions(-) create mode 100644 .pyrit_conf_example create mode 100644 pyrit/setup/configuration_loader.py create mode 100644 tests/unit/setup/test_configuration_loader.py diff --git a/.pyrit_conf_example b/.pyrit_conf_example new file mode 100644 index 0000000000..70c60f74fd --- /dev/null +++ b/.pyrit_conf_example @@ -0,0 +1,70 @@ +# PyRIT Configuration File Example +# ================================ +# Copy this file to ~/.pyrit/.pyrit_conf or specify a custom path when loading. +# +# For documentation on configuration options, see: +# https://github.com/Azure/PyRIT/blob/main/doc/setup/configuration.md + +# Memory Database Type +# -------------------- +# Specifies which database backend to use for storing prompts and results. +# Options: in_memory, sqlite, azure_sql (case-insensitive) +# - in_memory: Temporary in-memory database (data lost on exit) +# - sqlite: Persistent local SQLite database (default) +# - azure_sql: Azure SQL database (requires connection string in env vars) +memory_db_type: sqlite + +# Initializers +# ------------ +# List of built-in initializers to run during PyRIT initialization. +# Initializers configure default values for converters, scorers, and targets. +# Names are normalized to snake_case (e.g., "SimpleInitializer" -> "simple"). +# +# Available initializers: +# - simple: Basic OpenAI configuration (requires OPENAI_CHAT_* env vars) +# - airt: AI Red Team setup with Azure OpenAI (requires AZURE_OPENAI_* env vars) +# - load_default_datasets: Loads default datasets for all registered scenarios +# - objective_list: Sets default objectives for scenarios +# - openai_objective_target: Sets up OpenAI target for scenarios +# +# Each initializer can be specified as: +# - A simple string (name only) +# - A dictionary with 'name' and optional 'args' for constructor arguments +# +# Example: +# initializers: +# - simple +# - name: airt +# args: +# some_param: value +initializers: + - simple + +# Initialization Scripts +# ---------------------- +# List of paths to custom Python scripts containing PyRITInitializer subclasses. +# Paths can be absolute or relative to the current working directory. +# +# Example: +# initialization_scripts: +# - /path/to/my_custom_initializer.py +# - ./local_initializer.py +initialization_scripts: [] + +# Environment Files +# ----------------- +# List of .env file paths to load during initialization. +# Later files override values from earlier files. +# If not specified, PyRIT loads ~/.pyrit/.env and ~/.pyrit/.env.local by default. +# +# Example: +# env_files: +# - /path/to/.env +# - /path/to/.env.local +env_files: [] + +# Silent Mode +# ----------- +# If true, suppresses print statements during initialization. +# Useful for non-interactive environments or when embedding PyRIT in other tools. +silent: false diff --git a/pyrit/common/path.py b/pyrit/common/path.py index 4094ba8a4b..b61eb09d91 100644 --- a/pyrit/common/path.py +++ b/pyrit/common/path.py @@ -33,6 +33,10 @@ def in_git_repo() -> bool: CONFIGURATION_DIRECTORY_PATH = pathlib.Path.home() / ".pyrit" +# Default configuration file name and path +DEFAULT_CONFIG_FILENAME = ".pyrit_conf" +DEFAULT_CONFIG_PATH = CONFIGURATION_DIRECTORY_PATH / DEFAULT_CONFIG_FILENAME + # Points to the root of the project HOME_PATH = pathlib.Path(PYRIT_PATH, "..").resolve() @@ -54,22 +58,30 @@ def in_git_repo() -> bool: DATASETS_PATH = pathlib.Path(PYRIT_PATH, "datasets").resolve() EXECUTOR_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "executors").resolve() -EXECUTOR_RED_TEAM_PATH = pathlib.Path(EXECUTOR_SEED_PROMPT_PATH, "red_teaming").resolve() -EXECUTOR_SIMULATED_TARGET_PATH = pathlib.Path(EXECUTOR_SEED_PROMPT_PATH, "simulated_target").resolve() -CONVERTER_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "prompt_converters").resolve() +EXECUTOR_RED_TEAM_PATH = pathlib.Path( + EXECUTOR_SEED_PROMPT_PATH, "red_teaming").resolve() +EXECUTOR_SIMULATED_TARGET_PATH = pathlib.Path( + EXECUTOR_SEED_PROMPT_PATH, "simulated_target").resolve() +CONVERTER_SEED_PROMPT_PATH = pathlib.Path( + DATASETS_PATH, "prompt_converters").resolve() SCORER_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "score").resolve() -SCORER_CONTENT_CLASSIFIERS_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "content_classifiers").resolve() +SCORER_CONTENT_CLASSIFIERS_PATH = pathlib.Path( + SCORER_SEED_PROMPT_PATH, "content_classifiers").resolve() SCORER_LIKERT_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "likert").resolve() SCORER_SCALES_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "scales").resolve() HARM_DEFINITION_PATH = pathlib.Path(DATASETS_PATH, "harm_definition").resolve() -JAILBREAK_TEMPLATES_PATH = pathlib.Path(DATASETS_PATH, "jailbreak", "templates").resolve() +JAILBREAK_TEMPLATES_PATH = pathlib.Path( + DATASETS_PATH, "jailbreak", "templates").resolve() SCORER_EVALS_PATH = pathlib.Path(DATASETS_PATH, "scorer_evals").resolve() SCORER_EVALS_HARM_PATH = pathlib.Path(SCORER_EVALS_PATH, "harm").resolve() -SCORER_EVALS_OBJECTIVE_PATH = pathlib.Path(SCORER_EVALS_PATH, "objective").resolve() -SCORER_EVALS_REFUSAL_SCORER_PATH = pathlib.Path(SCORER_EVALS_PATH, "refusal_scorer").resolve() -SCORER_EVALS_TRUE_FALSE_PATH = pathlib.Path(SCORER_EVALS_PATH, "true_false").resolve() +SCORER_EVALS_OBJECTIVE_PATH = pathlib.Path( + SCORER_EVALS_PATH, "objective").resolve() +SCORER_EVALS_REFUSAL_SCORER_PATH = pathlib.Path( + SCORER_EVALS_PATH, "refusal_scorer").resolve() +SCORER_EVALS_TRUE_FALSE_PATH = pathlib.Path( + SCORER_EVALS_PATH, "true_false").resolve() SCORER_EVALS_LIKERT_PATH = pathlib.Path(SCORER_EVALS_PATH, "likert").resolve() diff --git a/pyrit/setup/__init__.py b/pyrit/setup/__init__.py index 4ecdbd9d43..2929a59ea3 100644 --- a/pyrit/setup/__init__.py +++ b/pyrit/setup/__init__.py @@ -3,6 +3,7 @@ """Module containing initialization PyRIT.""" +from pyrit.setup.configuration_loader import ConfigurationLoader, initialize_from_config_async from pyrit.setup.initialization import AZURE_SQL, IN_MEMORY, SQLITE, MemoryDatabaseType, initialize_pyrit_async __all__ = [ @@ -10,5 +11,7 @@ "SQLITE", "IN_MEMORY", "initialize_pyrit_async", + "initialize_from_config_async", "MemoryDatabaseType", + "ConfigurationLoader", ] diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py new file mode 100644 index 0000000000..2da6eda6bc --- /dev/null +++ b/pyrit/setup/configuration_loader.py @@ -0,0 +1,319 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Configuration loader for PyRIT initialization. + +This module provides the ConfigurationLoader class that loads PyRIT configuration +from YAML files and initializes PyRIT accordingly. +""" + +import pathlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union + +from pyrit.common.path import DEFAULT_CONFIG_PATH +from pyrit.common.yaml_loadable import YamlLoadable +from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.setup.initialization import ( + AZURE_SQL, + IN_MEMORY, + SQLITE, + initialize_pyrit_async, +) + +if TYPE_CHECKING: + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + +# Type alias for YAML-serializable values that can be passed as initializer args +# This matches what YAML can represent: primitives, lists, and nested dicts +YamlPrimitive = Union[str, int, float, bool, None] +YamlValue = Union[YamlPrimitive, List["YamlValue"], Dict[str, "YamlValue"]] + +# Mapping from snake_case config values to internal constants +_MEMORY_DB_TYPE_MAP: Dict[str, str] = { + "in_memory": IN_MEMORY, + "sqlite": SQLITE, + "azure_sql": AZURE_SQL, +} + + +@dataclass +class InitializerConfig: + """ + Configuration for a single initializer. + + Attributes: + name: The name of the initializer (must be registered in InitializerRegistry). + args: Optional dictionary of YAML-serializable arguments to pass to the initializer constructor. + """ + + name: str + args: Optional[Dict[str, YamlValue]] = None + + +@dataclass +class ConfigurationLoader(YamlLoadable): + """ + Loader for PyRIT configuration from YAML files. + + This class loads configuration from a YAML file and provides methods to + initialize PyRIT with the loaded configuration. + + Attributes: + memory_db_type: The type of memory database (in_memory, sqlite, azure_sql). + initializers: List of initializer configurations (name + optional args). + initialization_scripts: List of paths to custom initialization scripts. + env_files: List of environment file paths to load. + silent: Whether to suppress initialization messages. + + Example YAML configuration: + memory_db_type: sqlite + + initializers: + - simple + - name: airt + args: + some_param: value + + initialization_scripts: + - /path/to/custom_initializer.py + + env_files: + - /path/to/.env + - /path/to/.env.local + + silent: false + """ + + memory_db_type: str = "sqlite" + initializers: List[Union[str, Dict[str, Any]] + ] = field(default_factory=list) + initialization_scripts: List[str] = field(default_factory=list) + env_files: List[str] = field(default_factory=list) + silent: bool = False + + def __post_init__(self) -> None: + """Validate and normalize the configuration after loading.""" + self._normalize_memory_db_type() + self._normalize_initializers() + + def _normalize_memory_db_type(self) -> None: + """ + Normalize and validate memory_db_type. + + Converts the input to lowercase snake_case and validates against known types. + Stores the normalized snake_case value for config consistency, but maps + to internal constants when initializing. + """ + # Normalize to lowercase + normalized = self.memory_db_type.lower().replace("-", "_") + + # Also handle PascalCase inputs (e.g., "InMemory" -> "in_memory") + if normalized not in _MEMORY_DB_TYPE_MAP: + # Try converting from PascalCase + normalized = class_name_to_snake_case(self.memory_db_type) + + if normalized not in _MEMORY_DB_TYPE_MAP: + valid_types = list(_MEMORY_DB_TYPE_MAP.keys()) + raise ValueError( + f"Invalid memory_db_type '{self.memory_db_type}'. " + f"Must be one of: {', '.join(valid_types)}" + ) + + # Store normalized snake_case value + self.memory_db_type = normalized + + def _normalize_initializers(self) -> None: + """ + Normalize initializer entries to InitializerConfig objects. + + Converts initializer names to snake_case for consistent registry lookup. + """ + normalized: List[InitializerConfig] = [] + for entry in self.initializers: + if isinstance(entry, str): + # Simple string entry: normalize name to snake_case + name = class_name_to_snake_case(entry) + normalized.append(InitializerConfig(name=name)) + elif isinstance(entry, dict): + # Dict entry: name and optional args + if "name" not in entry: + raise ValueError( + f"Initializer configuration must have a 'name' field. Got: {entry}" + ) + name = class_name_to_snake_case(entry["name"]) + normalized.append( + InitializerConfig( + name=name, + args=entry.get("args"), + ) + ) + else: + raise ValueError( + f"Initializer entry must be a string or dict, got: {type(entry).__name__}" + ) + self._initializer_configs = normalized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ConfigurationLoader": + """ + Create a ConfigurationLoader from a dictionary. + + Args: + data: Dictionary containing configuration values. + + Returns: + A new ConfigurationLoader instance. + """ + # Filter out None values and empty lists to use defaults + filtered_data = { + k: v for k, v in data.items() + if v is not None and v != [] + } + return cls(**filtered_data) + + @classmethod + def get_default_config_path(cls) -> pathlib.Path: + """ + Get the default configuration file path. + + Returns: + Path to the default config file in ~/.pyrit/.pyrit_conf + """ + return DEFAULT_CONFIG_PATH + + def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: + """ + Resolve initializer names to PyRITInitializer instances. + + Uses the InitializerRegistry to look up initializer classes by name + and instantiate them with optional arguments. + + Returns: + Sequence of PyRITInitializer instances. + + Raises: + ValueError: If an initializer name is not found in the registry. + """ + from pyrit.registry import InitializerRegistry + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + if not self._initializer_configs: + return [] + + registry = InitializerRegistry() + resolved: List[PyRITInitializer] = [] + + for config in self._initializer_configs: + initializer_class = registry.get_class(config.name) + if initializer_class is None: + available = ", ".join(sorted(registry.get_names())) + raise ValueError( + f"Initializer '{config.name}' not found in registry.\n" + f"Available initializers: {available}" + ) + + # Instantiate with args if provided + if config.args: + instance = initializer_class(**config.args) + else: + instance = initializer_class() + + resolved.append(instance) + + return resolved + + def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: + """ + Resolve initialization script paths. + + Returns: + Sequence of Path objects, or None if no scripts configured. + """ + if not self.initialization_scripts: + return None + + resolved: List[pathlib.Path] = [] + for script_str in self.initialization_scripts: + script_path = pathlib.Path(script_str) + if not script_path.is_absolute(): + script_path = pathlib.Path.cwd() / script_path + resolved.append(script_path) + + return resolved + + def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: + """ + Resolve environment file paths. + + Returns: + Sequence of Path objects, or None if no env files configured. + """ + if not self.env_files: + return None + + resolved: List[pathlib.Path] = [] + for env_str in self.env_files: + env_path = pathlib.Path(env_str) + if not env_path.is_absolute(): + env_path = pathlib.Path.cwd() / env_path + resolved.append(env_path) + + return resolved + + async def initialize_pyrit_async(self) -> None: + """ + Initialize PyRIT with the loaded configuration. + + This method resolves all initializer names to instances and calls + the core initialize_pyrit_async function. + + Raises: + ValueError: If configuration is invalid or initializers cannot be resolved. + """ + resolved_initializers = self._resolve_initializers() + resolved_scripts = self._resolve_initialization_scripts() + resolved_env_files = self._resolve_env_files() + + # Map snake_case memory_db_type to internal constant + internal_memory_db_type = _MEMORY_DB_TYPE_MAP[self.memory_db_type] + + await initialize_pyrit_async( + memory_db_type=internal_memory_db_type, + initialization_scripts=resolved_scripts, + initializers=resolved_initializers if resolved_initializers else None, + env_files=resolved_env_files, + silent=self.silent, + ) + + +async def initialize_from_config_async( + config_path: Optional[Union[str, pathlib.Path]] = None, +) -> ConfigurationLoader: + """ + Initialize PyRIT from a configuration file. + + This is a convenience function that loads a ConfigurationLoader from + a YAML file and initializes PyRIT. + + Args: + config_path: Path to the configuration file. If None, uses the default + path (~/.pyrit/.pyrit_conf). Can be a string or pathlib.Path. + + Returns: + The loaded ConfigurationLoader instance. + + Raises: + FileNotFoundError: If the configuration file does not exist. + ValueError: If the configuration is invalid. + """ + if config_path is None: + config_path = ConfigurationLoader.get_default_config_path() + elif isinstance(config_path, str): + config_path = pathlib.Path(config_path) + + config = ConfigurationLoader.from_yaml_file(config_path) + await config.initialize_pyrit_async() + return config diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py new file mode 100644 index 0000000000..d652969e36 --- /dev/null +++ b/tests/unit/setup/test_configuration_loader.py @@ -0,0 +1,364 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pathlib +import tempfile +from unittest import mock + +import pytest + +from pyrit.setup.configuration_loader import ( + ConfigurationLoader, + InitializerConfig, + initialize_from_config_async, +) + + +class TestInitializerConfig: + """Tests for InitializerConfig dataclass.""" + + def test_initializer_config_with_name_only(self): + """Test creating InitializerConfig with just a name.""" + config = InitializerConfig(name="simple") + assert config.name == "simple" + assert config.args is None + + def test_initializer_config_with_args(self): + """Test creating InitializerConfig with name and args.""" + config = InitializerConfig(name="custom", args={"param1": "value1"}) + assert config.name == "custom" + assert config.args == {"param1": "value1"} + + +class TestConfigurationLoader: + """Tests for ConfigurationLoader class.""" + + def test_default_values(self): + """Test default configuration values.""" + config = ConfigurationLoader() + assert config.memory_db_type == "sqlite" + assert config.initializers == [] + assert config.initialization_scripts == [] + assert config.env_files == [] + assert config.silent is False + + def test_valid_memory_db_types_snake_case(self): + """Test all valid memory database types in snake_case.""" + for db_type in ["in_memory", "sqlite", "azure_sql"]: + config = ConfigurationLoader(memory_db_type=db_type) + assert config.memory_db_type == db_type + + def test_memory_db_type_normalization_from_pascal_case(self): + """Test that PascalCase memory_db_type is normalized to snake_case.""" + config = ConfigurationLoader(memory_db_type="InMemory") + assert config.memory_db_type == "in_memory" + + config = ConfigurationLoader(memory_db_type="SQLite") + assert config.memory_db_type == "sqlite" + + config = ConfigurationLoader(memory_db_type="AzureSQL") + assert config.memory_db_type == "azure_sql" + + def test_memory_db_type_normalization_case_insensitive(self): + """Test that memory_db_type normalization is case-insensitive.""" + config = ConfigurationLoader(memory_db_type="SQLITE") + assert config.memory_db_type == "sqlite" + + config = ConfigurationLoader(memory_db_type="In_Memory") + assert config.memory_db_type == "in_memory" + + def test_invalid_memory_db_type_raises_error(self): + """Test that invalid memory_db_type raises ValueError.""" + with pytest.raises(ValueError, match="Invalid memory_db_type"): + ConfigurationLoader(memory_db_type="InvalidType") + + def test_initializer_as_string(self): + """Test initializers specified as simple strings.""" + config = ConfigurationLoader(initializers=["simple", "airt"]) + assert len(config._initializer_configs) == 2 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[0].args is None + assert config._initializer_configs[1].name == "airt" + + def test_initializer_as_dict_with_name_only(self): + """Test initializers specified as dicts with only name.""" + config = ConfigurationLoader(initializers=[{"name": "simple"}]) + assert len(config._initializer_configs) == 1 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[0].args is None + + def test_initializer_as_dict_with_args(self): + """Test initializers specified as dicts with name and args.""" + config = ConfigurationLoader( + initializers=[{"name": "custom", "args": { + "param1": "value1", "param2": 42}}] + ) + assert len(config._initializer_configs) == 1 + assert config._initializer_configs[0].name == "custom" + assert config._initializer_configs[0].args == { + "param1": "value1", "param2": 42} + + def test_mixed_initializer_formats(self): + """Test initializers with mixed string and dict formats.""" + config = ConfigurationLoader( + initializers=[ + "simple", + {"name": "airt"}, + {"name": "custom", "args": {"key": "value"}}, + ] + ) + assert len(config._initializer_configs) == 3 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[1].name == "airt" + assert config._initializer_configs[2].name == "custom" + assert config._initializer_configs[2].args == {"key": "value"} + + def test_initializer_name_normalization_from_pascal_case(self): + """Test that PascalCase initializer names are normalized to snake_case.""" + config = ConfigurationLoader( + initializers=["SimpleInitializer", "AIRTInitializer"]) + assert config._initializer_configs[0].name == "simple_initializer" + assert config._initializer_configs[1].name == "airt_initializer" + + def test_initializer_name_normalization_preserves_snake_case(self): + """Test that snake_case names are preserved.""" + config = ConfigurationLoader( + initializers=["simple_initializer", "airt_init"]) + assert config._initializer_configs[0].name == "simple_initializer" + assert config._initializer_configs[1].name == "airt_init" + + def test_initializer_name_already_snake_case(self): + """Test that snake_case names remain unchanged.""" + config = ConfigurationLoader( + initializers=["load_default_datasets", "objective_list"]) + assert config._initializer_configs[0].name == "load_default_datasets" + assert config._initializer_configs[1].name == "objective_list" + + def test_initializer_dict_without_name_raises_error(self): + """Test that dict initializer without 'name' raises ValueError.""" + with pytest.raises(ValueError, match="must have a 'name' field"): + ConfigurationLoader(initializers=[{"args": {"key": "value"}}]) + + def test_initializer_invalid_type_raises_error(self): + """Test that invalid initializer type raises ValueError.""" + with pytest.raises(ValueError, match="must be a string or dict"): + ConfigurationLoader(initializers=[123]) # type: ignore + + def test_from_dict_with_all_fields(self): + """Test from_dict with all configuration fields.""" + data = { + "memory_db_type": "InMemory", + "initializers": ["simple"], + "initialization_scripts": ["/path/to/script.py"], + "env_files": ["/path/to/.env"], + "silent": True, + } + config = ConfigurationLoader.from_dict(data) + assert config.memory_db_type == "in_memory" # Normalized to snake_case + assert config.initializers == ["simple"] + assert config.initialization_scripts == ["/path/to/script.py"] + assert config.env_files == ["/path/to/.env"] + assert config.silent is True + + def test_from_dict_filters_none_values(self): + """Test that from_dict filters out None values.""" + data = { + "memory_db_type": "SQLite", + "initializers": None, + "env_files": [], + } + config = ConfigurationLoader.from_dict(data) + assert config.memory_db_type == "sqlite" # Normalized to snake_case + assert config.initializers == [] # Uses default, not None + + def test_from_yaml_file(self): + """Test loading configuration from a YAML file.""" + yaml_content = """ +memory_db_type: in_memory +initializers: + - simple + - name: airt + args: + key: value +silent: true +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(yaml_content) + yaml_path = f.name + + try: + config = ConfigurationLoader.from_yaml_file(yaml_path) + assert config.memory_db_type == "in_memory" + assert len(config._initializer_configs) == 2 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[1].name == "airt" + assert config._initializer_configs[1].args == {"key": "value"} + assert config.silent is True + finally: + pathlib.Path(yaml_path).unlink() + + def test_get_default_config_path(self): + """Test get_default_config_path returns expected path.""" + default_path = ConfigurationLoader.get_default_config_path() + assert default_path.name == ".pyrit_conf" + assert ".pyrit" in str(default_path) + + +class TestConfigurationLoaderResolvers: + """Tests for ConfigurationLoader path resolution methods.""" + + def test_resolve_initialization_scripts_empty(self): + """Test that empty scripts returns None.""" + config = ConfigurationLoader() + assert config._resolve_initialization_scripts() is None + + def test_resolve_initialization_scripts_absolute_path(self): + """Test resolving absolute script paths.""" + config = ConfigurationLoader(initialization_scripts=[ + "/absolute/path/script.py"]) + resolved = config._resolve_initialization_scripts() + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0] == pathlib.Path("/absolute/path/script.py") + + def test_resolve_initialization_scripts_relative_path(self): + """Test resolving relative script paths (converted to absolute).""" + config = ConfigurationLoader( + initialization_scripts=["relative/script.py"]) + resolved = config._resolve_initialization_scripts() + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].is_absolute() + assert str(resolved[0]).endswith("relative/script.py") + + def test_resolve_env_files_empty(self): + """Test that empty env files returns None.""" + config = ConfigurationLoader() + assert config._resolve_env_files() is None + + def test_resolve_env_files_absolute_path(self): + """Test resolving absolute env file paths.""" + config = ConfigurationLoader(env_files=["/path/to/.env"]) + resolved = config._resolve_env_files() + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0] == pathlib.Path("/path/to/.env") + + +@pytest.mark.usefixtures("patch_central_database") +class TestConfigurationLoaderInitialization: + """Tests for ConfigurationLoader.initialize_pyrit_async method.""" + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") + async def test_initialize_pyrit_async_basic(self, mock_init): + """Test basic initialization with minimal configuration.""" + config = ConfigurationLoader(memory_db_type="in_memory") + await config.initialize_pyrit_async() + + mock_init.assert_called_once() + call_kwargs = mock_init.call_args.kwargs + # Should map snake_case to internal constant + assert call_kwargs["memory_db_type"] == "InMemory" + assert call_kwargs["initialization_scripts"] is None + assert call_kwargs["initializers"] is None + assert call_kwargs["env_files"] is None + assert call_kwargs["silent"] is False + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") + @mock.patch("pyrit.registry.InitializerRegistry") + async def test_initialize_pyrit_async_with_initializers(self, mock_registry_cls, mock_init): + """Test initialization with initializers resolved from registry.""" + # Setup mock registry + mock_registry = mock.MagicMock() + mock_registry_cls.return_value = mock_registry + + # Mock an initializer class + mock_initializer_class = mock.MagicMock() + mock_initializer_instance = mock.MagicMock() + mock_initializer_class.return_value = mock_initializer_instance + mock_registry.get_class.return_value = mock_initializer_class + + config = ConfigurationLoader( + memory_db_type="in_memory", + initializers=["simple"], + ) + await config.initialize_pyrit_async() + + # Verify registry was used to resolve initializer + mock_registry.get_class.assert_called_once_with("simple") + mock_initializer_class.assert_called_once_with() + + # Verify initialize was called with resolved initializers + mock_init.assert_called_once() + call_kwargs = mock_init.call_args.kwargs + assert call_kwargs["initializers"] == [mock_initializer_instance] + + @pytest.mark.asyncio + @mock.patch("pyrit.registry.InitializerRegistry") + async def test_initialize_pyrit_async_unknown_initializer_raises_error(self, mock_registry_cls): + """Test that unknown initializer name raises ValueError.""" + mock_registry = mock.MagicMock() + mock_registry_cls.return_value = mock_registry + mock_registry.get_class.return_value = None + mock_registry.get_names.return_value = ["simple", "airt"] + + config = ConfigurationLoader( + memory_db_type="in_memory", + initializers=["unknown_initializer"], + ) + + with pytest.raises(ValueError, match="not found in registry"): + await config.initialize_pyrit_async() + + +@pytest.mark.usefixtures("patch_central_database") +class TestInitializeFromConfigAsync: + """Tests for initialize_from_config_async function.""" + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async") + async def test_initialize_from_config_with_path(self, mock_init, mock_from_yaml): + """Test initialize_from_config_async with explicit path.""" + mock_config = ConfigurationLoader() + mock_from_yaml.return_value = mock_config + + result = await initialize_from_config_async("/path/to/config.yaml") + + mock_from_yaml.assert_called_once_with( + pathlib.Path("/path/to/config.yaml")) + mock_init.assert_called_once() + assert result is mock_config + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async") + async def test_initialize_from_config_with_string_path(self, mock_init, mock_from_yaml): + """Test initialize_from_config_async with string path.""" + mock_config = ConfigurationLoader() + mock_from_yaml.return_value = mock_config + + result = await initialize_from_config_async("/path/to/config.yaml") + + # Should convert string to Path + call_args = mock_from_yaml.call_args[0][0] + assert isinstance(call_args, pathlib.Path) + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.get_default_config_path") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async") + async def test_initialize_from_config_default_path(self, mock_init, mock_from_yaml, mock_default_path): + """Test initialize_from_config_async uses default path when none specified.""" + mock_config = ConfigurationLoader() + mock_from_yaml.return_value = mock_config + mock_default_path.return_value = pathlib.Path( + "/default/path/.pyrit_conf") + + await initialize_from_config_async() + + mock_default_path.assert_called_once() + mock_from_yaml.assert_called_once_with( + pathlib.Path("/default/path/.pyrit_conf")) From 6518e397fb05afb1738e8c74ae353dd7c1570875 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Sat, 31 Jan 2026 00:45:35 +0000 Subject: [PATCH 02/15] tweaks --- pyrit/cli/frontend_core.py | 138 ++++++++++++++++++++++++++++++++++--- pyrit/cli/pyrit_scan.py | 17 ++++- pyrit/cli/pyrit_shell.py | 7 ++ 3 files changed, 152 insertions(+), 10 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 6cfc8b5a45..b8d4b9b956 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -46,6 +46,7 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i ScenarioMetadata, ScenarioRegistry, ) + from pyrit.setup import ConfigurationLoader logger = logging.getLogger(__name__) @@ -66,16 +67,23 @@ class FrontendCore: def __init__( self, *, - database: str = SQLITE, + config_file: Optional[Path] = None, + database: Optional[str] = None, initialization_scripts: Optional[list[Path]] = None, initializer_names: Optional[list[str]] = None, env_files: Optional[list[Path]] = None, - log_level: str = "WARNING", + log_level: Optional[str] = None, ): """ Initialize PyRIT context. + Configuration is loaded in the following order (later values override earlier): + 1. Default config file (~/.pyrit/.pyrit_conf) if it exists + 2. Explicit config_file argument if provided + 3. Individual CLI arguments (database, initializers, etc.) + Args: + config_file: Optional path to a YAML configuration file. database: Database type (InMemory, SQLite, or AzureSQL). initialization_scripts: Optional list of initialization script paths. initializer_names: Optional list of built-in initializer names to run. @@ -83,14 +91,34 @@ def __init__( log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to WARNING. Raises: - ValueError: If database or log_level are invalid. + ValueError: If database or log_level are invalid, or if config file is invalid. + FileNotFoundError: If an explicitly specified config_file does not exist. """ - # Validate inputs - self._database = validate_database(database=database) - self._initialization_scripts = initialization_scripts - self._initializer_names = initializer_names - self._env_files = env_files - self._log_level = validate_log_level(log_level=log_level) + from pyrit.setup import ConfigurationLoader + + # Load configuration from files and merge with CLI arguments + config = self._load_and_merge_config( + config_file=config_file, + database=database, + initialization_scripts=initialization_scripts, + initializer_names=initializer_names, + env_files=env_files, + ) + + # Store the merged configuration + self._config = config + + # Extract values from config for internal use + # Map snake_case db type back to PascalCase for backward compatibility + db_type_map = {"in_memory": IN_MEMORY, "sqlite": SQLITE, "azure_sql": AZURE_SQL} + self._database = db_type_map[config.memory_db_type] + self._initialization_scripts = config._resolve_initialization_scripts() + self._initializer_names = [ic.name for ic in config._initializer_configs] if config._initializer_configs else None + self._env_files = config._resolve_env_files() + + # Log level comes from CLI arg (not in config file), default to WARNING + effective_log_level = log_level if log_level is not None else "WARNING" + self._log_level = validate_log_level(log_level=effective_log_level) # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None @@ -100,6 +128,93 @@ def __init__( # Configure logging logging.basicConfig(level=getattr(logging, self._log_level)) + def _load_and_merge_config( + self, + *, + config_file: Optional[Path], + database: Optional[str], + initialization_scripts: Optional[list[Path]], + initializer_names: Optional[list[str]], + env_files: Optional[list[Path]], + ) -> "ConfigurationLoader": + """ + Load configuration from files and merge with CLI arguments. + + Precedence (later overrides earlier): + 1. Default config file (~/.pyrit/.pyrit_conf) if it exists + 2. Explicit config_file argument if provided + 3. Individual CLI arguments + + Args: + config_file: Optional explicit config file path. + database: Optional database type from CLI. + initialization_scripts: Optional scripts from CLI. + initializer_names: Optional initializer names from CLI. + env_files: Optional env files from CLI. + + Returns: + Merged ConfigurationLoader instance. + """ + from pyrit.setup import ConfigurationLoader + + # Start with defaults + config_data: dict = { + "memory_db_type": "sqlite", + "initializers": [], + "initialization_scripts": [], + "env_files": [], + } + + # 1. Try loading default config file if it exists + default_config_path = ConfigurationLoader.get_default_config_path() + if default_config_path.exists(): + try: + default_config = ConfigurationLoader.from_yaml_file(default_config_path) + config_data["memory_db_type"] = default_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in default_config._initializer_configs + ] + config_data["initialization_scripts"] = default_config.initialization_scripts + config_data["env_files"] = default_config.env_files + except Exception as e: + logger.warning(f"Failed to load default config file {default_config_path}: {e}") + + # 2. Load explicit config file if provided (overrides default) + if config_file is not None: + if not config_file.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file}") + explicit_config = ConfigurationLoader.from_yaml_file(config_file) + config_data["memory_db_type"] = explicit_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in explicit_config._initializer_configs + ] + config_data["initialization_scripts"] = explicit_config.initialization_scripts + config_data["env_files"] = explicit_config.env_files + + # 3. Apply CLI overrides (non-None values take precedence) + if database is not None: + # Normalize to snake_case for ConfigurationLoader + normalized_db = database.lower().replace("-", "_") + # Handle PascalCase inputs + if normalized_db == "inmemory": + normalized_db = "in_memory" + elif normalized_db == "azuresql": + normalized_db = "azure_sql" + config_data["memory_db_type"] = normalized_db + + if initialization_scripts is not None: + config_data["initialization_scripts"] = [str(p) for p in initialization_scripts] + + if initializer_names is not None: + config_data["initializers"] = initializer_names + + if env_files is not None: + config_data["env_files"] = [str(p) for p in env_files] + + return ConfigurationLoader.from_dict(config_data) + async def initialize_async(self) -> None: """Initialize PyRIT and load registries (heavy operation).""" if self._initialized: @@ -734,6 +849,11 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path # Shared argument help text ARG_HELP = { + "config_file": ( + "Path to a YAML configuration file. Allows specifying database, initializers (with args), " + "initialization scripts, and env files. CLI arguments override config file values. " + "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." + ), "initializers": "Built-in initializer names to run before the scenario (e.g., openai_objective_target)", "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", "env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files " diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index df342ce0cd..ba025c61d7 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -10,6 +10,7 @@ import asyncio import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from pathlib import Path from typing import Optional from pyrit.cli import frontend_core @@ -34,6 +35,9 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: # Run a scenario with built-in initializers pyrit_scan foundry --initializers openai_objective_target load_default_datasets + # Run with a configuration file (recommended for complex setups) + pyrit_scan foundry --config-file ./my_config.yaml + # Run with custom initialization scripts pyrit_scan garak.encoding --initialization-scripts ./my_config.py @@ -45,6 +49,12 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: formatter_class=RawDescriptionHelpFormatter, ) + parser.add_argument( + "--config-file", + type=Path, + help=frontend_core.ARG_HELP["config_file"], + ) + parser.add_argument( "--log-level", type=frontend_core.validate_log_level_argparse, @@ -182,6 +192,7 @@ def main(args: Optional[list[str]] = None) -> int: return 1 context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, database=parsed_args.database, initialization_scripts=initialization_scripts, env_files=env_files, @@ -194,7 +205,10 @@ def main(args: Optional[list[str]] = None) -> int: # Discover from scenarios directory scenarios_path = frontend_core.get_default_initializer_discovery_path() - context = frontend_core.FrontendCore(log_level=parsed_args.log_level) + context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, + log_level=parsed_args.log_level, + ) return asyncio.run(frontend_core.print_initializers_list_async(context=context, discovery_path=scenarios_path)) # Verify scenario was provided @@ -218,6 +232,7 @@ def main(args: Optional[list[str]] = None) -> int: # Create context with initializers context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, database=parsed_args.database, initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index bcb0743425..46ccaab73e 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -453,6 +453,12 @@ def main() -> int: description="PyRIT Interactive Shell - Load modules once, run commands instantly", ) + parser.add_argument( + "--config-file", + type=frontend_core.Path, + help=frontend_core.ARG_HELP["config_file"], + ) + parser.add_argument( "--database", choices=[frontend_core.IN_MEMORY, frontend_core.SQLITE, frontend_core.AZURE_SQL], @@ -488,6 +494,7 @@ def main() -> int: # Create context (initializers are specified per-run, not at startup) context = frontend_core.FrontendCore( + config_file=args.config_file, database=args.database, initialization_scripts=None, initializer_names=None, From f012424b160ec3316715b568f3e6a8f5622c350d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Sat, 31 Jan 2026 00:52:55 +0000 Subject: [PATCH 03/15] config-file arg --- pyrit/cli/pyrit_scan.py | 9 ++++-- pyrit/cli/pyrit_shell.py | 60 ++++++++++++++++++++++++++-------------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index ba025c61d7..cf27df6840 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -186,7 +186,8 @@ def main(args: Optional[list[str]] = None) -> int: env_files = None if parsed_args.env_files: try: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) + env_files = frontend_core.resolve_env_files( + env_file_paths=parsed_args.env_files) except ValueError as e: print(f"Error: {e}") return 1 @@ -228,7 +229,8 @@ def main(args: Optional[list[str]] = None) -> int: # Collect environment files env_files = None if parsed_args.env_files: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) + env_files = frontend_core.resolve_env_files( + env_file_paths=parsed_args.env_files) # Create context with initializers context = frontend_core.FrontendCore( @@ -243,7 +245,8 @@ def main(args: Optional[list[str]] = None) -> int: # Parse memory labels if provided memory_labels = None if parsed_args.memory_labels: - memory_labels = frontend_core.parse_memory_labels(json_string=parsed_args.memory_labels) + memory_labels = frontend_core.parse_memory_labels( + json_string=parsed_args.memory_labels) # Run scenario asyncio.run( diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 46ccaab73e..a18b4ab145 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -105,7 +105,8 @@ def __init__( self._scenario_history: list[tuple[str, ScenarioResult]] = [] # Initialize PyRIT in background thread for faster startup - self._init_thread = threading.Thread(target=self._background_init, daemon=True) + self._init_thread = threading.Thread( + target=self._background_init, daemon=True) self._init_complete = threading.Event() self._init_thread.start() @@ -125,7 +126,8 @@ def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" self._ensure_initialized() try: - asyncio.run(frontend_core.print_scenarios_list_async(context=self.context)) + asyncio.run(frontend_core.print_scenarios_list_async( + context=self.context)) except Exception as e: print(f"Error listing scenarios: {e}") @@ -136,7 +138,8 @@ def do_list_initializers(self, arg: str) -> None: # Discover from scenarios directory by default (same as scan) discovery_path = frontend_core.get_default_initializer_discovery_path() asyncio.run( - frontend_core.print_initializers_list_async(context=self.context, discovery_path=discovery_path) + frontend_core.print_initializers_list_async( + context=self.context, discovery_path=discovery_path) ) except Exception as e: print(f"Error listing initializers: {e}") @@ -179,14 +182,19 @@ def do_run(self, line: str) -> None: print("\nUsage: run [options]") print("\nNote: Every scenario requires an initializer.") print("\nOptions:") - print(f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") + print( + f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") print( f" --initialization-scripts <...> {frontend_core.ARG_HELP['initialization_scripts']} (alternative to --initializers)" ) - print(f" --strategies, -s ... {frontend_core.ARG_HELP['scenario_strategies']}") - print(f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") - print(f" --max-retries {frontend_core.ARG_HELP['max_retries']}") - print(f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") + print( + f" --strategies, -s ... {frontend_core.ARG_HELP['scenario_strategies']}") + print( + f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") + print( + f" --max-retries {frontend_core.ARG_HELP['max_retries']}") + print( + f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") print( f" --database Override default database ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL})" ) @@ -194,7 +202,8 @@ def do_run(self, line: str) -> None: f" --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" ) print("\nExample:") - print(" run foundry --initializers openai_objective_target load_default_datasets") + print( + " run foundry --initializers openai_objective_target load_default_datasets") print("\nType 'help run' for more details and examples") return @@ -220,7 +229,8 @@ def do_run(self, line: str) -> None: resolved_env_files = None if args["env_files"]: try: - resolved_env_files = frontend_core.resolve_env_files(env_file_paths=args["env_files"]) + resolved_env_files = frontend_core.resolve_env_files( + env_file_paths=args["env_files"]) except ValueError as e: print(f"Error: {e}") return @@ -283,7 +293,8 @@ def do_scenario_history(self, arg: str) -> None: print(f"{idx}) {command}") print("=" * 80) print(f"\nTotal runs: {len(self._scenario_history)}") - print("\nUse 'print-scenario ' to view detailed results for a specific run.") + print( + "\nUse 'print-scenario ' to view detailed results for a specific run.") print("Use 'print-scenario' to view detailed results for all runs.") def do_print_scenario(self, arg: str) -> None: @@ -325,7 +336,8 @@ def do_print_scenario(self, arg: str) -> None: try: scenario_num = int(arg) if scenario_num < 1 or scenario_num > len(self._scenario_history): - print(f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") + print( + f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") return command, result = self._scenario_history[scenario_num - 1] @@ -338,7 +350,8 @@ def do_print_scenario(self, arg: str) -> None: printer = ConsoleScenarioResultPrinter() asyncio.run(printer.print_summary_async(result)) except ValueError: - print(f"Error: Invalid scenario number '{arg}'. Must be an integer.") + print( + f"Error: Invalid scenario number '{arg}'. Must be an integer.") def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" @@ -351,12 +364,14 @@ def do_help(self, arg: str) -> None: print(" --database ") print(" Default database type: InMemory, SQLite, or AzureSQL") print(" Default: SQLite") - print(" Can be overridden per-run with 'run --database '") + print( + " Can be overridden per-run with 'run --database '") print() print(" --log-level ") print(" Default logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL") print(" Default: WARNING") - print(" Can be overridden per-run with 'run --log-level '") + print( + " Can be overridden per-run with 'run --log-level '") print() print("=" * 70) print("Run Command Options (specified when running scenarios):") @@ -364,9 +379,11 @@ def do_help(self, arg: str) -> None: print(" --initializers [ ...] (REQUIRED)") print(f" {frontend_core.ARG_HELP['initializers']}") print(" Every scenario requires at least one initializer") - print(" Example: run foundry --initializers openai_objective_target load_default_datasets") + print( + " Example: run foundry --initializers openai_objective_target load_default_datasets") print() - print(" --initialization-scripts [ ...] (Alternative to --initializers)") + print( + " --initialization-scripts [ ...] (Alternative to --initializers)") print(f" {frontend_core.ARG_HELP['initialization_scripts']}") print(" Example: run foundry --initialization-scripts ./my_init.py") print() @@ -382,7 +399,8 @@ def do_help(self, arg: str) -> None: print() print(" --memory-labels ") print(f" {frontend_core.ARG_HELP['memory_labels']}") - print(' Example: run foundry --memory-labels \'{"env":"test"}\'') + print( + ' Example: run foundry --memory-labels \'{"env":"test"}\'') print() print("Start the shell like:") print(" pyrit_shell") @@ -461,7 +479,8 @@ def main() -> int: parser.add_argument( "--database", - choices=[frontend_core.IN_MEMORY, frontend_core.SQLITE, frontend_core.AZURE_SQL], + choices=[frontend_core.IN_MEMORY, + frontend_core.SQLITE, frontend_core.AZURE_SQL], default=frontend_core.SQLITE, help=f"Default database type to use ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE}, can be overridden per-run)", ) @@ -487,7 +506,8 @@ def main() -> int: env_files = None if args.env_files: try: - env_files = frontend_core.resolve_env_files(env_file_paths=args.env_files) + env_files = frontend_core.resolve_env_files( + env_file_paths=args.env_files) except ValueError as e: print(f"Error: {e}") return 1 From 1753576c7fefb779c85dfc9a4958d104da9a2797 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 5 Feb 2026 19:50:30 +0000 Subject: [PATCH 04/15] precedence --- pyrit/cli/frontend_core.py | 81 +++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index b8d4b9b956..2637264be2 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -110,10 +110,12 @@ def __init__( # Extract values from config for internal use # Map snake_case db type back to PascalCase for backward compatibility - db_type_map = {"in_memory": IN_MEMORY, "sqlite": SQLITE, "azure_sql": AZURE_SQL} + db_type_map = {"in_memory": IN_MEMORY, + "sqlite": SQLITE, "azure_sql": AZURE_SQL} self._database = db_type_map[config.memory_db_type] self._initialization_scripts = config._resolve_initialization_scripts() - self._initializer_names = [ic.name for ic in config._initializer_configs] if config._initializer_configs else None + self._initializer_names = [ + ic.name for ic in config._initializer_configs] if config._initializer_configs else None self._env_files = config._resolve_env_files() # Log level comes from CLI arg (not in config file), default to WARNING @@ -169,7 +171,8 @@ def _load_and_merge_config( default_config_path = ConfigurationLoader.get_default_config_path() if default_config_path.exists(): try: - default_config = ConfigurationLoader.from_yaml_file(default_config_path) + default_config = ConfigurationLoader.from_yaml_file( + default_config_path) config_data["memory_db_type"] = default_config.memory_db_type config_data["initializers"] = [ {"name": ic.name, "args": ic.args} if ic.args else ic.name @@ -178,12 +181,14 @@ def _load_and_merge_config( config_data["initialization_scripts"] = default_config.initialization_scripts config_data["env_files"] = default_config.env_files except Exception as e: - logger.warning(f"Failed to load default config file {default_config_path}: {e}") + logger.warning( + f"Failed to load default config file {default_config_path}: {e}") # 2. Load explicit config file if provided (overrides default) if config_file is not None: if not config_file.exists(): - raise FileNotFoundError(f"Configuration file not found: {config_file}") + raise FileNotFoundError( + f"Configuration file not found: {config_file}") explicit_config = ConfigurationLoader.from_yaml_file(config_file) config_data["memory_db_type"] = explicit_config.memory_db_type config_data["initializers"] = [ @@ -205,7 +210,8 @@ def _load_and_merge_config( config_data["memory_db_type"] = normalized_db if initialization_scripts is not None: - config_data["initialization_scripts"] = [str(p) for p in initialization_scripts] + config_data["initialization_scripts"] = [ + str(p) for p in initialization_scripts] if initializer_names is not None: config_data["initializers"] = initializer_names @@ -213,7 +219,17 @@ def _load_and_merge_config( if env_files is not None: config_data["env_files"] = [str(p) for p in env_files] - return ConfigurationLoader.from_dict(config_data) + try: + return ConfigurationLoader.from_dict(config_data) + except ValueError as e: + # Re-raise with user-friendly message for CLI users + error_msg = str(e) + if "memory_db_type" in error_msg: + raise ValueError( + f"Invalid database type '{database}'. " + f"Must be one of: InMemory, SQLite, AzureSQL" + ) from e + raise async def initialize_async(self) -> None: """Initialize PyRIT and load registries (heavy operation).""" @@ -385,7 +401,8 @@ async def run_scenario_async( if scenario_class is None: available = ", ".join(context.scenario_registry.get_names()) - raise ValueError(f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") + raise ValueError( + f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") # Build initialization kwargs (these go to initialize_async, not __init__) init_kwargs: dict[str, Any] = {} @@ -510,13 +527,15 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if scenario_metadata.aggregate_strategies: agg_strategies = scenario_metadata.aggregate_strategies print(" Aggregate Strategies:") - formatted = _format_wrapped_text(text=", ".join(agg_strategies), indent=" - ") + formatted = _format_wrapped_text( + text=", ".join(agg_strategies), indent=" - ") print(formatted) if scenario_metadata.all_strategies: strategies = scenario_metadata.all_strategies print(f" Available Strategies ({len(strategies)}):") - formatted = _format_wrapped_text(text=", ".join(strategies), indent=" ") + formatted = _format_wrapped_text( + text=", ".join(strategies), indent=" ") print(formatted) if scenario_metadata.default_strategy: @@ -528,7 +547,8 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if datasets: size_suffix = f", max {max_size} per dataset" if max_size else "" print(f" Default Datasets ({len(datasets)}{size_suffix}):") - formatted = _format_wrapped_text(text=", ".join(datasets), indent=" ") + formatted = _format_wrapped_text( + text=", ".join(datasets), indent=" ") print(formatted) else: print(" Default Datasets: None") @@ -555,7 +575,8 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") if initializer_metadata.class_description: print(" Description:") - print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) + print(_format_wrapped_text( + text=initializer_metadata.class_description, indent=" ")) def validate_database(*, database: str) -> str: @@ -573,7 +594,8 @@ def validate_database(*, database: str) -> str: """ valid_databases = [IN_MEMORY, SQLITE, AZURE_SQL] if database not in valid_databases: - raise ValueError(f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") + raise ValueError( + f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") return database @@ -593,7 +615,8 @@ def validate_log_level(*, log_level: str) -> str: valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] level_upper = log_level.upper() if level_upper not in valid_levels: - raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") + raise ValueError( + f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") return level_upper @@ -618,11 +641,13 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int """ # Reject boolean types explicitly (int(True) == 1, int(False) == 0) if isinstance(value, bool): - raise ValueError(f"{name} must be an integer string, got boolean: {value}") + raise ValueError( + f"{name} must be an integer string, got boolean: {value}") # Ensure value is a string if not isinstance(value, str): - raise ValueError(f"{name} must be a string, got {type(value).__name__}: {value}") + raise ValueError( + f"{name} must be a string, got {type(value).__name__}: {value}") # Strip whitespace and validate it looks like an integer value = value.strip() @@ -635,7 +660,8 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int raise ValueError(f"{name} must be an integer, got: {value}") from e if min_value is not None and int_value < min_value: - raise ValueError(f"{name} must be at least {min_value}, got: {int_value}") + raise ValueError( + f"{name} must be at least {min_value}, got: {int_value}") return int_value @@ -679,7 +705,8 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A sig = inspect.signature(validator_func) params = list(sig.parameters.keys()) if not params: - raise ValueError(f"Validator function {validator_func.__name__} must have at least one parameter") + raise ValueError( + f"Validator function {validator_func.__name__} must have at least one parameter") first_param = params[0] def wrapper(value: Any) -> Any: @@ -692,7 +719,8 @@ def wrapper(value: Any) -> Any: raise ap.ArgumentTypeError(str(e)) from e # Preserve function metadata for better debugging - wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") + wrapper.__name__ = getattr( + validator_func, "__name__", "argparse_validator") wrapper.__doc__ = getattr(validator_func, "__doc__", None) return wrapper @@ -752,7 +780,8 @@ def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: validate_database_argparse = _argparse_validator(validate_database) validate_log_level_argparse = _argparse_validator(validate_log_level) positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) -non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) +non_negative_int = _argparse_validator( + lambda v: validate_integer(v, min_value=0)) resolve_env_files_argparse = _argparse_validator(resolve_env_files) @@ -780,7 +809,8 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: # Validate all keys and values are strings for key, value in labels.items(): if not isinstance(key, str) or not isinstance(value, str): - raise ValueError(f"All label keys and values must be strings. Got: {key}={value}") + raise ValueError( + f"All label keys and values must be strings. Got: {key}={value}") return labels @@ -949,13 +979,15 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-concurrency requires a value") - result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) + result["max_concurrency"] = validate_integer( + parts[i], name="--max-concurrency", min_value=1) i += 1 elif parts[i] == "--max-retries": i += 1 if i >= len(parts): raise ValueError("--max-retries requires a value") - result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) + result["max_retries"] = validate_integer( + parts[i], name="--max-retries", min_value=0) i += 1 elif parts[i] == "--memory-labels": i += 1 @@ -986,7 +1018,8 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-dataset-size requires a value") - result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) + result["max_dataset_size"] = validate_integer( + parts[i], name="--max-dataset-size", min_value=1) i += 1 else: logger.warning(f"Unknown argument: {parts[i]}") From 3d1ff3ee3bc7e2ce1026aa5126ea0ed97640d95c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 5 Feb 2026 21:40:14 +0000 Subject: [PATCH 05/15] precommit --- pyrit/cli/frontend_core.py | 80 +++++++------------ pyrit/cli/pyrit_scan.py | 9 +-- pyrit/cli/pyrit_shell.py | 60 +++++--------- pyrit/common/path.py | 24 ++---- pyrit/setup/configuration_loader.py | 28 +++---- tests/unit/setup/test_configuration_loader.py | 32 +++----- 6 files changed, 83 insertions(+), 150 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 2637264be2..ef77818870 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -94,8 +94,6 @@ def __init__( ValueError: If database or log_level are invalid, or if config file is invalid. FileNotFoundError: If an explicitly specified config_file does not exist. """ - from pyrit.setup import ConfigurationLoader - # Load configuration from files and merge with CLI arguments config = self._load_and_merge_config( config_file=config_file, @@ -110,12 +108,12 @@ def __init__( # Extract values from config for internal use # Map snake_case db type back to PascalCase for backward compatibility - db_type_map = {"in_memory": IN_MEMORY, - "sqlite": SQLITE, "azure_sql": AZURE_SQL} + db_type_map = {"in_memory": IN_MEMORY, "sqlite": SQLITE, "azure_sql": AZURE_SQL} self._database = db_type_map[config.memory_db_type] self._initialization_scripts = config._resolve_initialization_scripts() - self._initializer_names = [ - ic.name for ic in config._initializer_configs] if config._initializer_configs else None + self._initializer_names = ( + [ic.name for ic in config._initializer_configs] if config._initializer_configs else None + ) self._env_files = config._resolve_env_files() # Log level comes from CLI arg (not in config file), default to WARNING @@ -156,6 +154,10 @@ def _load_and_merge_config( Returns: Merged ConfigurationLoader instance. + + Raises: + FileNotFoundError: If an explicitly specified config_file does not exist. + ValueError: If the database type is invalid. """ from pyrit.setup import ConfigurationLoader @@ -171,8 +173,7 @@ def _load_and_merge_config( default_config_path = ConfigurationLoader.get_default_config_path() if default_config_path.exists(): try: - default_config = ConfigurationLoader.from_yaml_file( - default_config_path) + default_config = ConfigurationLoader.from_yaml_file(default_config_path) config_data["memory_db_type"] = default_config.memory_db_type config_data["initializers"] = [ {"name": ic.name, "args": ic.args} if ic.args else ic.name @@ -181,14 +182,12 @@ def _load_and_merge_config( config_data["initialization_scripts"] = default_config.initialization_scripts config_data["env_files"] = default_config.env_files except Exception as e: - logger.warning( - f"Failed to load default config file {default_config_path}: {e}") + logger.warning(f"Failed to load default config file {default_config_path}: {e}") # 2. Load explicit config file if provided (overrides default) if config_file is not None: if not config_file.exists(): - raise FileNotFoundError( - f"Configuration file not found: {config_file}") + raise FileNotFoundError(f"Configuration file not found: {config_file}") explicit_config = ConfigurationLoader.from_yaml_file(config_file) config_data["memory_db_type"] = explicit_config.memory_db_type config_data["initializers"] = [ @@ -210,8 +209,7 @@ def _load_and_merge_config( config_data["memory_db_type"] = normalized_db if initialization_scripts is not None: - config_data["initialization_scripts"] = [ - str(p) for p in initialization_scripts] + config_data["initialization_scripts"] = [str(p) for p in initialization_scripts] if initializer_names is not None: config_data["initializers"] = initializer_names @@ -226,8 +224,7 @@ def _load_and_merge_config( error_msg = str(e) if "memory_db_type" in error_msg: raise ValueError( - f"Invalid database type '{database}'. " - f"Must be one of: InMemory, SQLite, AzureSQL" + f"Invalid database type '{database}'. Must be one of: InMemory, SQLite, AzureSQL" ) from e raise @@ -401,8 +398,7 @@ async def run_scenario_async( if scenario_class is None: available = ", ".join(context.scenario_registry.get_names()) - raise ValueError( - f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") + raise ValueError(f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") # Build initialization kwargs (these go to initialize_async, not __init__) init_kwargs: dict[str, Any] = {} @@ -527,15 +523,13 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if scenario_metadata.aggregate_strategies: agg_strategies = scenario_metadata.aggregate_strategies print(" Aggregate Strategies:") - formatted = _format_wrapped_text( - text=", ".join(agg_strategies), indent=" - ") + formatted = _format_wrapped_text(text=", ".join(agg_strategies), indent=" - ") print(formatted) if scenario_metadata.all_strategies: strategies = scenario_metadata.all_strategies print(f" Available Strategies ({len(strategies)}):") - formatted = _format_wrapped_text( - text=", ".join(strategies), indent=" ") + formatted = _format_wrapped_text(text=", ".join(strategies), indent=" ") print(formatted) if scenario_metadata.default_strategy: @@ -547,8 +541,7 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if datasets: size_suffix = f", max {max_size} per dataset" if max_size else "" print(f" Default Datasets ({len(datasets)}{size_suffix}):") - formatted = _format_wrapped_text( - text=", ".join(datasets), indent=" ") + formatted = _format_wrapped_text(text=", ".join(datasets), indent=" ") print(formatted) else: print(" Default Datasets: None") @@ -575,8 +568,7 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") if initializer_metadata.class_description: print(" Description:") - print(_format_wrapped_text( - text=initializer_metadata.class_description, indent=" ")) + print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) def validate_database(*, database: str) -> str: @@ -594,8 +586,7 @@ def validate_database(*, database: str) -> str: """ valid_databases = [IN_MEMORY, SQLITE, AZURE_SQL] if database not in valid_databases: - raise ValueError( - f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") + raise ValueError(f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") return database @@ -615,8 +606,7 @@ def validate_log_level(*, log_level: str) -> str: valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] level_upper = log_level.upper() if level_upper not in valid_levels: - raise ValueError( - f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") + raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") return level_upper @@ -641,13 +631,11 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int """ # Reject boolean types explicitly (int(True) == 1, int(False) == 0) if isinstance(value, bool): - raise ValueError( - f"{name} must be an integer string, got boolean: {value}") + raise ValueError(f"{name} must be an integer string, got boolean: {value}") # Ensure value is a string if not isinstance(value, str): - raise ValueError( - f"{name} must be a string, got {type(value).__name__}: {value}") + raise ValueError(f"{name} must be a string, got {type(value).__name__}: {value}") # Strip whitespace and validate it looks like an integer value = value.strip() @@ -660,8 +648,7 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int raise ValueError(f"{name} must be an integer, got: {value}") from e if min_value is not None and int_value < min_value: - raise ValueError( - f"{name} must be at least {min_value}, got: {int_value}") + raise ValueError(f"{name} must be at least {min_value}, got: {int_value}") return int_value @@ -705,8 +692,7 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A sig = inspect.signature(validator_func) params = list(sig.parameters.keys()) if not params: - raise ValueError( - f"Validator function {validator_func.__name__} must have at least one parameter") + raise ValueError(f"Validator function {validator_func.__name__} must have at least one parameter") first_param = params[0] def wrapper(value: Any) -> Any: @@ -719,8 +705,7 @@ def wrapper(value: Any) -> Any: raise ap.ArgumentTypeError(str(e)) from e # Preserve function metadata for better debugging - wrapper.__name__ = getattr( - validator_func, "__name__", "argparse_validator") + wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") wrapper.__doc__ = getattr(validator_func, "__doc__", None) return wrapper @@ -780,8 +765,7 @@ def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: validate_database_argparse = _argparse_validator(validate_database) validate_log_level_argparse = _argparse_validator(validate_log_level) positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) -non_negative_int = _argparse_validator( - lambda v: validate_integer(v, min_value=0)) +non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) resolve_env_files_argparse = _argparse_validator(resolve_env_files) @@ -809,8 +793,7 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: # Validate all keys and values are strings for key, value in labels.items(): if not isinstance(key, str) or not isinstance(value, str): - raise ValueError( - f"All label keys and values must be strings. Got: {key}={value}") + raise ValueError(f"All label keys and values must be strings. Got: {key}={value}") return labels @@ -979,15 +962,13 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-concurrency requires a value") - result["max_concurrency"] = validate_integer( - parts[i], name="--max-concurrency", min_value=1) + result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) i += 1 elif parts[i] == "--max-retries": i += 1 if i >= len(parts): raise ValueError("--max-retries requires a value") - result["max_retries"] = validate_integer( - parts[i], name="--max-retries", min_value=0) + result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) i += 1 elif parts[i] == "--memory-labels": i += 1 @@ -1018,8 +999,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-dataset-size requires a value") - result["max_dataset_size"] = validate_integer( - parts[i], name="--max-dataset-size", min_value=1) + result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) i += 1 else: logger.warning(f"Unknown argument: {parts[i]}") diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index cf27df6840..ba025c61d7 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -186,8 +186,7 @@ def main(args: Optional[list[str]] = None) -> int: env_files = None if parsed_args.env_files: try: - env_files = frontend_core.resolve_env_files( - env_file_paths=parsed_args.env_files) + env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) except ValueError as e: print(f"Error: {e}") return 1 @@ -229,8 +228,7 @@ def main(args: Optional[list[str]] = None) -> int: # Collect environment files env_files = None if parsed_args.env_files: - env_files = frontend_core.resolve_env_files( - env_file_paths=parsed_args.env_files) + env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) # Create context with initializers context = frontend_core.FrontendCore( @@ -245,8 +243,7 @@ def main(args: Optional[list[str]] = None) -> int: # Parse memory labels if provided memory_labels = None if parsed_args.memory_labels: - memory_labels = frontend_core.parse_memory_labels( - json_string=parsed_args.memory_labels) + memory_labels = frontend_core.parse_memory_labels(json_string=parsed_args.memory_labels) # Run scenario asyncio.run( diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index a18b4ab145..46ccaab73e 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -105,8 +105,7 @@ def __init__( self._scenario_history: list[tuple[str, ScenarioResult]] = [] # Initialize PyRIT in background thread for faster startup - self._init_thread = threading.Thread( - target=self._background_init, daemon=True) + self._init_thread = threading.Thread(target=self._background_init, daemon=True) self._init_complete = threading.Event() self._init_thread.start() @@ -126,8 +125,7 @@ def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" self._ensure_initialized() try: - asyncio.run(frontend_core.print_scenarios_list_async( - context=self.context)) + asyncio.run(frontend_core.print_scenarios_list_async(context=self.context)) except Exception as e: print(f"Error listing scenarios: {e}") @@ -138,8 +136,7 @@ def do_list_initializers(self, arg: str) -> None: # Discover from scenarios directory by default (same as scan) discovery_path = frontend_core.get_default_initializer_discovery_path() asyncio.run( - frontend_core.print_initializers_list_async( - context=self.context, discovery_path=discovery_path) + frontend_core.print_initializers_list_async(context=self.context, discovery_path=discovery_path) ) except Exception as e: print(f"Error listing initializers: {e}") @@ -182,19 +179,14 @@ def do_run(self, line: str) -> None: print("\nUsage: run [options]") print("\nNote: Every scenario requires an initializer.") print("\nOptions:") - print( - f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") + print(f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") print( f" --initialization-scripts <...> {frontend_core.ARG_HELP['initialization_scripts']} (alternative to --initializers)" ) - print( - f" --strategies, -s ... {frontend_core.ARG_HELP['scenario_strategies']}") - print( - f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") - print( - f" --max-retries {frontend_core.ARG_HELP['max_retries']}") - print( - f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") + print(f" --strategies, -s ... {frontend_core.ARG_HELP['scenario_strategies']}") + print(f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") + print(f" --max-retries {frontend_core.ARG_HELP['max_retries']}") + print(f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") print( f" --database Override default database ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL})" ) @@ -202,8 +194,7 @@ def do_run(self, line: str) -> None: f" --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" ) print("\nExample:") - print( - " run foundry --initializers openai_objective_target load_default_datasets") + print(" run foundry --initializers openai_objective_target load_default_datasets") print("\nType 'help run' for more details and examples") return @@ -229,8 +220,7 @@ def do_run(self, line: str) -> None: resolved_env_files = None if args["env_files"]: try: - resolved_env_files = frontend_core.resolve_env_files( - env_file_paths=args["env_files"]) + resolved_env_files = frontend_core.resolve_env_files(env_file_paths=args["env_files"]) except ValueError as e: print(f"Error: {e}") return @@ -293,8 +283,7 @@ def do_scenario_history(self, arg: str) -> None: print(f"{idx}) {command}") print("=" * 80) print(f"\nTotal runs: {len(self._scenario_history)}") - print( - "\nUse 'print-scenario ' to view detailed results for a specific run.") + print("\nUse 'print-scenario ' to view detailed results for a specific run.") print("Use 'print-scenario' to view detailed results for all runs.") def do_print_scenario(self, arg: str) -> None: @@ -336,8 +325,7 @@ def do_print_scenario(self, arg: str) -> None: try: scenario_num = int(arg) if scenario_num < 1 or scenario_num > len(self._scenario_history): - print( - f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") + print(f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") return command, result = self._scenario_history[scenario_num - 1] @@ -350,8 +338,7 @@ def do_print_scenario(self, arg: str) -> None: printer = ConsoleScenarioResultPrinter() asyncio.run(printer.print_summary_async(result)) except ValueError: - print( - f"Error: Invalid scenario number '{arg}'. Must be an integer.") + print(f"Error: Invalid scenario number '{arg}'. Must be an integer.") def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" @@ -364,14 +351,12 @@ def do_help(self, arg: str) -> None: print(" --database ") print(" Default database type: InMemory, SQLite, or AzureSQL") print(" Default: SQLite") - print( - " Can be overridden per-run with 'run --database '") + print(" Can be overridden per-run with 'run --database '") print() print(" --log-level ") print(" Default logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL") print(" Default: WARNING") - print( - " Can be overridden per-run with 'run --log-level '") + print(" Can be overridden per-run with 'run --log-level '") print() print("=" * 70) print("Run Command Options (specified when running scenarios):") @@ -379,11 +364,9 @@ def do_help(self, arg: str) -> None: print(" --initializers [ ...] (REQUIRED)") print(f" {frontend_core.ARG_HELP['initializers']}") print(" Every scenario requires at least one initializer") - print( - " Example: run foundry --initializers openai_objective_target load_default_datasets") + print(" Example: run foundry --initializers openai_objective_target load_default_datasets") print() - print( - " --initialization-scripts [ ...] (Alternative to --initializers)") + print(" --initialization-scripts [ ...] (Alternative to --initializers)") print(f" {frontend_core.ARG_HELP['initialization_scripts']}") print(" Example: run foundry --initialization-scripts ./my_init.py") print() @@ -399,8 +382,7 @@ def do_help(self, arg: str) -> None: print() print(" --memory-labels ") print(f" {frontend_core.ARG_HELP['memory_labels']}") - print( - ' Example: run foundry --memory-labels \'{"env":"test"}\'') + print(' Example: run foundry --memory-labels \'{"env":"test"}\'') print() print("Start the shell like:") print(" pyrit_shell") @@ -479,8 +461,7 @@ def main() -> int: parser.add_argument( "--database", - choices=[frontend_core.IN_MEMORY, - frontend_core.SQLITE, frontend_core.AZURE_SQL], + choices=[frontend_core.IN_MEMORY, frontend_core.SQLITE, frontend_core.AZURE_SQL], default=frontend_core.SQLITE, help=f"Default database type to use ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE}, can be overridden per-run)", ) @@ -506,8 +487,7 @@ def main() -> int: env_files = None if args.env_files: try: - env_files = frontend_core.resolve_env_files( - env_file_paths=args.env_files) + env_files = frontend_core.resolve_env_files(env_file_paths=args.env_files) except ValueError as e: print(f"Error: {e}") return 1 diff --git a/pyrit/common/path.py b/pyrit/common/path.py index b61eb09d91..fcdc4c92ec 100644 --- a/pyrit/common/path.py +++ b/pyrit/common/path.py @@ -58,30 +58,22 @@ def in_git_repo() -> bool: DATASETS_PATH = pathlib.Path(PYRIT_PATH, "datasets").resolve() EXECUTOR_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "executors").resolve() -EXECUTOR_RED_TEAM_PATH = pathlib.Path( - EXECUTOR_SEED_PROMPT_PATH, "red_teaming").resolve() -EXECUTOR_SIMULATED_TARGET_PATH = pathlib.Path( - EXECUTOR_SEED_PROMPT_PATH, "simulated_target").resolve() -CONVERTER_SEED_PROMPT_PATH = pathlib.Path( - DATASETS_PATH, "prompt_converters").resolve() +EXECUTOR_RED_TEAM_PATH = pathlib.Path(EXECUTOR_SEED_PROMPT_PATH, "red_teaming").resolve() +EXECUTOR_SIMULATED_TARGET_PATH = pathlib.Path(EXECUTOR_SEED_PROMPT_PATH, "simulated_target").resolve() +CONVERTER_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "prompt_converters").resolve() SCORER_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "score").resolve() -SCORER_CONTENT_CLASSIFIERS_PATH = pathlib.Path( - SCORER_SEED_PROMPT_PATH, "content_classifiers").resolve() +SCORER_CONTENT_CLASSIFIERS_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "content_classifiers").resolve() SCORER_LIKERT_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "likert").resolve() SCORER_SCALES_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "scales").resolve() HARM_DEFINITION_PATH = pathlib.Path(DATASETS_PATH, "harm_definition").resolve() -JAILBREAK_TEMPLATES_PATH = pathlib.Path( - DATASETS_PATH, "jailbreak", "templates").resolve() +JAILBREAK_TEMPLATES_PATH = pathlib.Path(DATASETS_PATH, "jailbreak", "templates").resolve() SCORER_EVALS_PATH = pathlib.Path(DATASETS_PATH, "scorer_evals").resolve() SCORER_EVALS_HARM_PATH = pathlib.Path(SCORER_EVALS_PATH, "harm").resolve() -SCORER_EVALS_OBJECTIVE_PATH = pathlib.Path( - SCORER_EVALS_PATH, "objective").resolve() -SCORER_EVALS_REFUSAL_SCORER_PATH = pathlib.Path( - SCORER_EVALS_PATH, "refusal_scorer").resolve() -SCORER_EVALS_TRUE_FALSE_PATH = pathlib.Path( - SCORER_EVALS_PATH, "true_false").resolve() +SCORER_EVALS_OBJECTIVE_PATH = pathlib.Path(SCORER_EVALS_PATH, "objective").resolve() +SCORER_EVALS_REFUSAL_SCORER_PATH = pathlib.Path(SCORER_EVALS_PATH, "refusal_scorer").resolve() +SCORER_EVALS_TRUE_FALSE_PATH = pathlib.Path(SCORER_EVALS_PATH, "true_false").resolve() SCORER_EVALS_LIKERT_PATH = pathlib.Path(SCORER_EVALS_PATH, "likert").resolve() diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 2da6eda6bc..0f08ee526d 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -88,8 +88,7 @@ class ConfigurationLoader(YamlLoadable): """ memory_db_type: str = "sqlite" - initializers: List[Union[str, Dict[str, Any]] - ] = field(default_factory=list) + initializers: List[Union[str, Dict[str, Any]]] = field(default_factory=list) initialization_scripts: List[str] = field(default_factory=list) env_files: List[str] = field(default_factory=list) silent: bool = False @@ -106,6 +105,9 @@ def _normalize_memory_db_type(self) -> None: Converts the input to lowercase snake_case and validates against known types. Stores the normalized snake_case value for config consistency, but maps to internal constants when initializing. + + Raises: + ValueError: If the memory_db_type is not a valid database type. """ # Normalize to lowercase normalized = self.memory_db_type.lower().replace("-", "_") @@ -118,8 +120,7 @@ def _normalize_memory_db_type(self) -> None: if normalized not in _MEMORY_DB_TYPE_MAP: valid_types = list(_MEMORY_DB_TYPE_MAP.keys()) raise ValueError( - f"Invalid memory_db_type '{self.memory_db_type}'. " - f"Must be one of: {', '.join(valid_types)}" + f"Invalid memory_db_type '{self.memory_db_type}'. Must be one of: {', '.join(valid_types)}" ) # Store normalized snake_case value @@ -130,6 +131,9 @@ def _normalize_initializers(self) -> None: Normalize initializer entries to InitializerConfig objects. Converts initializer names to snake_case for consistent registry lookup. + + Raises: + ValueError: If an initializer entry is missing a 'name' field or has an invalid type. """ normalized: List[InitializerConfig] = [] for entry in self.initializers: @@ -140,9 +144,7 @@ def _normalize_initializers(self) -> None: elif isinstance(entry, dict): # Dict entry: name and optional args if "name" not in entry: - raise ValueError( - f"Initializer configuration must have a 'name' field. Got: {entry}" - ) + raise ValueError(f"Initializer configuration must have a 'name' field. Got: {entry}") name = class_name_to_snake_case(entry["name"]) normalized.append( InitializerConfig( @@ -151,9 +153,7 @@ def _normalize_initializers(self) -> None: ) ) else: - raise ValueError( - f"Initializer entry must be a string or dict, got: {type(entry).__name__}" - ) + raise ValueError(f"Initializer entry must be a string or dict, got: {type(entry).__name__}") self._initializer_configs = normalized @classmethod @@ -168,10 +168,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ConfigurationLoader": A new ConfigurationLoader instance. """ # Filter out None values and empty lists to use defaults - filtered_data = { - k: v for k, v in data.items() - if v is not None and v != [] - } + filtered_data = {k: v for k, v in data.items() if v is not None and v != []} return cls(**filtered_data) @classmethod @@ -211,8 +208,7 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: if initializer_class is None: available = ", ".join(sorted(registry.get_names())) raise ValueError( - f"Initializer '{config.name}' not found in registry.\n" - f"Available initializers: {available}" + f"Initializer '{config.name}' not found in registry.\nAvailable initializers: {available}" ) # Instantiate with args if provided diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index d652969e36..eb6cf94b0a 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -89,14 +89,10 @@ def test_initializer_as_dict_with_name_only(self): def test_initializer_as_dict_with_args(self): """Test initializers specified as dicts with name and args.""" - config = ConfigurationLoader( - initializers=[{"name": "custom", "args": { - "param1": "value1", "param2": 42}}] - ) + config = ConfigurationLoader(initializers=[{"name": "custom", "args": {"param1": "value1", "param2": 42}}]) assert len(config._initializer_configs) == 1 assert config._initializer_configs[0].name == "custom" - assert config._initializer_configs[0].args == { - "param1": "value1", "param2": 42} + assert config._initializer_configs[0].args == {"param1": "value1", "param2": 42} def test_mixed_initializer_formats(self): """Test initializers with mixed string and dict formats.""" @@ -115,22 +111,19 @@ def test_mixed_initializer_formats(self): def test_initializer_name_normalization_from_pascal_case(self): """Test that PascalCase initializer names are normalized to snake_case.""" - config = ConfigurationLoader( - initializers=["SimpleInitializer", "AIRTInitializer"]) + config = ConfigurationLoader(initializers=["SimpleInitializer", "AIRTInitializer"]) assert config._initializer_configs[0].name == "simple_initializer" assert config._initializer_configs[1].name == "airt_initializer" def test_initializer_name_normalization_preserves_snake_case(self): """Test that snake_case names are preserved.""" - config = ConfigurationLoader( - initializers=["simple_initializer", "airt_init"]) + config = ConfigurationLoader(initializers=["simple_initializer", "airt_init"]) assert config._initializer_configs[0].name == "simple_initializer" assert config._initializer_configs[1].name == "airt_init" def test_initializer_name_already_snake_case(self): """Test that snake_case names remain unchanged.""" - config = ConfigurationLoader( - initializers=["load_default_datasets", "objective_list"]) + config = ConfigurationLoader(initializers=["load_default_datasets", "objective_list"]) assert config._initializer_configs[0].name == "load_default_datasets" assert config._initializer_configs[1].name == "objective_list" @@ -214,8 +207,7 @@ def test_resolve_initialization_scripts_empty(self): def test_resolve_initialization_scripts_absolute_path(self): """Test resolving absolute script paths.""" - config = ConfigurationLoader(initialization_scripts=[ - "/absolute/path/script.py"]) + config = ConfigurationLoader(initialization_scripts=["/absolute/path/script.py"]) resolved = config._resolve_initialization_scripts() assert resolved is not None assert len(resolved) == 1 @@ -223,8 +215,7 @@ def test_resolve_initialization_scripts_absolute_path(self): def test_resolve_initialization_scripts_relative_path(self): """Test resolving relative script paths (converted to absolute).""" - config = ConfigurationLoader( - initialization_scripts=["relative/script.py"]) + config = ConfigurationLoader(initialization_scripts=["relative/script.py"]) resolved = config._resolve_initialization_scripts() assert resolved is not None assert len(resolved) == 1 @@ -327,8 +318,7 @@ async def test_initialize_from_config_with_path(self, mock_init, mock_from_yaml) result = await initialize_from_config_async("/path/to/config.yaml") - mock_from_yaml.assert_called_once_with( - pathlib.Path("/path/to/config.yaml")) + mock_from_yaml.assert_called_once_with(pathlib.Path("/path/to/config.yaml")) mock_init.assert_called_once() assert result is mock_config @@ -354,11 +344,9 @@ async def test_initialize_from_config_default_path(self, mock_init, mock_from_ya """Test initialize_from_config_async uses default path when none specified.""" mock_config = ConfigurationLoader() mock_from_yaml.return_value = mock_config - mock_default_path.return_value = pathlib.Path( - "/default/path/.pyrit_conf") + mock_default_path.return_value = pathlib.Path("/default/path/.pyrit_conf") await initialize_from_config_async() mock_default_path.assert_called_once() - mock_from_yaml.assert_called_once_with( - pathlib.Path("/default/path/.pyrit_conf")) + mock_from_yaml.assert_called_once_with(pathlib.Path("/default/path/.pyrit_conf")) From 930f7ecc5bae241e1787d8ff0618535d7c2cc26a Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 5 Feb 2026 21:49:47 +0000 Subject: [PATCH 06/15] unit tests precommit --- pyrit/cli/frontend_core.py | 4 ++-- pyrit/cli/pyrit_shell.py | 11 ++++++----- tests/unit/cli/test_frontend_core.py | 5 ++++- tests/unit/setup/test_configuration_loader.py | 9 ++++++--- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index ef77818870..c2d0cade83 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -19,7 +19,7 @@ import logging import sys from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence try: import termcolor @@ -162,7 +162,7 @@ def _load_and_merge_config( from pyrit.setup import ConfigurationLoader # Start with defaults - config_data: dict = { + config_data: Dict[str, Any] = { "memory_db_type": "sqlite", "initializers": [], "initialization_scripts": [], diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 46ccaab73e..aadaa01ab7 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -14,7 +14,8 @@ import cmd import sys import threading -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from pyrit.models.scenario_result import ScenarioResult @@ -217,16 +218,16 @@ def do_run(self, line: str) -> None: return # Resolve env files if provided - resolved_env_files = None + resolved_env_files: Optional[list[Path]] = None if args["env_files"]: try: - resolved_env_files = frontend_core.resolve_env_files(env_file_paths=args["env_files"]) + resolved_env_files = list(frontend_core.resolve_env_files(env_file_paths=args["env_files"])) except ValueError as e: print(f"Error: {e}") return else: # Use default env files from shell startup - resolved_env_files = self.default_env_files + resolved_env_files = list(self.default_env_files) if self.default_env_files else None # Create a context for this run with overrides run_context = frontend_core.FrontendCore( @@ -455,7 +456,7 @@ def main() -> int: parser.add_argument( "--config-file", - type=frontend_core.Path, + type=Path, help=frontend_core.ARG_HELP["config_file"], ) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index fca9dd6439..3ed4e9e81c 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -40,7 +40,10 @@ def test_init_with_all_parameters(self): ) assert context._database == frontend_core.IN_MEMORY - assert context._initialization_scripts == scripts + # Check path ends with expected components (Windows adds drive letter to Unix-style paths) + assert context._initialization_scripts is not None + assert len(context._initialization_scripts) == 1 + assert context._initialization_scripts[0].parts[-2:] == ("test", "script.py") assert context._initializer_names == initializers assert context._log_level == "DEBUG" diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index eb6cf94b0a..a55dab6279 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -211,7 +211,8 @@ def test_resolve_initialization_scripts_absolute_path(self): resolved = config._resolve_initialization_scripts() assert resolved is not None assert len(resolved) == 1 - assert resolved[0] == pathlib.Path("/absolute/path/script.py") + # Check path ends with expected components (Windows adds drive letter to Unix-style paths) + assert resolved[0].parts[-3:] == ("absolute", "path", "script.py") def test_resolve_initialization_scripts_relative_path(self): """Test resolving relative script paths (converted to absolute).""" @@ -220,7 +221,8 @@ def test_resolve_initialization_scripts_relative_path(self): assert resolved is not None assert len(resolved) == 1 assert resolved[0].is_absolute() - assert str(resolved[0]).endswith("relative/script.py") + # Check path ends with expected components (works on both Unix and Windows) + assert resolved[0].parts[-2:] == ("relative", "script.py") def test_resolve_env_files_empty(self): """Test that empty env files returns None.""" @@ -233,7 +235,8 @@ def test_resolve_env_files_absolute_path(self): resolved = config._resolve_env_files() assert resolved is not None assert len(resolved) == 1 - assert resolved[0] == pathlib.Path("/path/to/.env") + # Check path ends with expected components (Windows adds drive letter to Unix-style paths) + assert resolved[0].parts[-3:] == ("path", "to", ".env") @pytest.mark.usefixtures("patch_central_database") From 2fdd416187cf3a9c3974f6affab74c3fd6b0e2db Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 5 Feb 2026 23:51:40 +0000 Subject: [PATCH 07/15] fixes --- pyrit/cli/frontend_core.py | 146 ++++++---------------------- pyrit/setup/configuration_loader.py | 96 ++++++++++++++++++ 2 files changed, 124 insertions(+), 118 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index c2d0cade83..69e9d0239a 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -19,7 +19,7 @@ import logging import sys from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence try: import termcolor @@ -46,7 +46,6 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i ScenarioMetadata, ScenarioRegistry, ) - from pyrit.setup import ConfigurationLoader logger = logging.getLogger(__name__) @@ -83,7 +82,7 @@ def __init__( 3. Individual CLI arguments (database, initializers, etc.) Args: - config_file: Optional path to a YAML configuration file. + config_file: Optional path to a YAML-formatted configuration file.\n The file uses .pyrit_conf extension but is YAML format. database: Database type (InMemory, SQLite, or AzureSQL). initialization_scripts: Optional list of initialization script paths. initializer_names: Optional list of built-in initializer names to run. @@ -94,32 +93,43 @@ def __init__( ValueError: If database or log_level are invalid, or if config file is invalid. FileNotFoundError: If an explicitly specified config_file does not exist. """ - # Load configuration from files and merge with CLI arguments - config = self._load_and_merge_config( - config_file=config_file, - database=database, - initialization_scripts=initialization_scripts, - initializer_names=initializer_names, - env_files=env_files, - ) + from pyrit.setup import ConfigurationLoader + from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP + + # Validate log level early + effective_log_level = log_level if log_level is not None else "WARNING" + self._log_level = validate_log_level(log_level=effective_log_level) + + # Load configuration using ConfigurationLoader.load_with_overrides + try: + config = ConfigurationLoader.load_with_overrides( + config_file=config_file, + memory_db_type=database, + initializers=initializer_names, + initialization_scripts=[str(p) for p in initialization_scripts] if initialization_scripts else None, + env_files=[str(p) for p in env_files] if env_files else None, + ) + except ValueError as e: + # Re-raise with user-friendly message for CLI users + error_msg = str(e) + if "memory_db_type" in error_msg: + raise ValueError( + f"Invalid database type '{database}'. Must be one of: InMemory, SQLite, AzureSQL" + ) from e + raise # Store the merged configuration self._config = config # Extract values from config for internal use - # Map snake_case db type back to PascalCase for backward compatibility - db_type_map = {"in_memory": IN_MEMORY, "sqlite": SQLITE, "azure_sql": AZURE_SQL} - self._database = db_type_map[config.memory_db_type] + # Use canonical mapping from configuration_loader + self._database = _MEMORY_DB_TYPE_MAP[config.memory_db_type] self._initialization_scripts = config._resolve_initialization_scripts() self._initializer_names = ( [ic.name for ic in config._initializer_configs] if config._initializer_configs else None ) self._env_files = config._resolve_env_files() - # Log level comes from CLI arg (not in config file), default to WARNING - effective_log_level = log_level if log_level is not None else "WARNING" - self._log_level = validate_log_level(log_level=effective_log_level) - # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None self._initializer_registry: Optional[InitializerRegistry] = None @@ -128,106 +138,6 @@ def __init__( # Configure logging logging.basicConfig(level=getattr(logging, self._log_level)) - def _load_and_merge_config( - self, - *, - config_file: Optional[Path], - database: Optional[str], - initialization_scripts: Optional[list[Path]], - initializer_names: Optional[list[str]], - env_files: Optional[list[Path]], - ) -> "ConfigurationLoader": - """ - Load configuration from files and merge with CLI arguments. - - Precedence (later overrides earlier): - 1. Default config file (~/.pyrit/.pyrit_conf) if it exists - 2. Explicit config_file argument if provided - 3. Individual CLI arguments - - Args: - config_file: Optional explicit config file path. - database: Optional database type from CLI. - initialization_scripts: Optional scripts from CLI. - initializer_names: Optional initializer names from CLI. - env_files: Optional env files from CLI. - - Returns: - Merged ConfigurationLoader instance. - - Raises: - FileNotFoundError: If an explicitly specified config_file does not exist. - ValueError: If the database type is invalid. - """ - from pyrit.setup import ConfigurationLoader - - # Start with defaults - config_data: Dict[str, Any] = { - "memory_db_type": "sqlite", - "initializers": [], - "initialization_scripts": [], - "env_files": [], - } - - # 1. Try loading default config file if it exists - default_config_path = ConfigurationLoader.get_default_config_path() - if default_config_path.exists(): - try: - default_config = ConfigurationLoader.from_yaml_file(default_config_path) - config_data["memory_db_type"] = default_config.memory_db_type - config_data["initializers"] = [ - {"name": ic.name, "args": ic.args} if ic.args else ic.name - for ic in default_config._initializer_configs - ] - config_data["initialization_scripts"] = default_config.initialization_scripts - config_data["env_files"] = default_config.env_files - except Exception as e: - logger.warning(f"Failed to load default config file {default_config_path}: {e}") - - # 2. Load explicit config file if provided (overrides default) - if config_file is not None: - if not config_file.exists(): - raise FileNotFoundError(f"Configuration file not found: {config_file}") - explicit_config = ConfigurationLoader.from_yaml_file(config_file) - config_data["memory_db_type"] = explicit_config.memory_db_type - config_data["initializers"] = [ - {"name": ic.name, "args": ic.args} if ic.args else ic.name - for ic in explicit_config._initializer_configs - ] - config_data["initialization_scripts"] = explicit_config.initialization_scripts - config_data["env_files"] = explicit_config.env_files - - # 3. Apply CLI overrides (non-None values take precedence) - if database is not None: - # Normalize to snake_case for ConfigurationLoader - normalized_db = database.lower().replace("-", "_") - # Handle PascalCase inputs - if normalized_db == "inmemory": - normalized_db = "in_memory" - elif normalized_db == "azuresql": - normalized_db = "azure_sql" - config_data["memory_db_type"] = normalized_db - - if initialization_scripts is not None: - config_data["initialization_scripts"] = [str(p) for p in initialization_scripts] - - if initializer_names is not None: - config_data["initializers"] = initializer_names - - if env_files is not None: - config_data["env_files"] = [str(p) for p in env_files] - - try: - return ConfigurationLoader.from_dict(config_data) - except ValueError as e: - # Re-raise with user-friendly message for CLI users - error_msg = str(e) - if "memory_db_type" in error_msg: - raise ValueError( - f"Invalid database type '{database}'. Must be one of: InMemory, SQLite, AzureSQL" - ) from e - raise - async def initialize_async(self) -> None: """Initialize PyRIT and load registries (heavy operation).""" if self._initialized: diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 0f08ee526d..43781482a6 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -171,6 +171,102 @@ def from_dict(cls, data: Dict[str, Any]) -> "ConfigurationLoader": filtered_data = {k: v for k, v in data.items() if v is not None and v != []} return cls(**filtered_data) + @staticmethod + def load_with_overrides( + config_file: Optional[pathlib.Path] = None, + *, + memory_db_type: Optional[str] = None, + initializers: Optional[Sequence[Union[str, Dict[str, Any]]]] = None, + initialization_scripts: Optional[Sequence[str]] = None, + env_files: Optional[Sequence[str]] = None, + ) -> "ConfigurationLoader": + """ + Load configuration with optional overrides. + + This factory method implements a 3-layer configuration precedence: + 1. Default config file (~/.pyrit/.pyrit_conf) if it exists + 2. Explicit config_file argument if provided + 3. Individual override arguments (non-None values take precedence) + + This is a staticmethod (not classmethod) because it's a pure factory function + that doesn't need access to class state and can be reused by multiple interfaces + (CLI, shell, programmatic API). + + Args: + config_file: Optional path to a YAML-formatted configuration file. + memory_db_type: Override for database type (in_memory, sqlite, azure_sql). + initializers: Override for initializer list. + initialization_scripts: Override for initialization script paths. + env_files: Override for environment file paths. + + Returns: + A merged ConfigurationLoader instance. + + Raises: + FileNotFoundError: If an explicitly specified config_file does not exist. + ValueError: If the configuration is invalid. + """ + import logging + + logger = logging.getLogger(__name__) + + # Start with defaults + config_data: Dict[str, Any] = { + "memory_db_type": "sqlite", + "initializers": [], + "initialization_scripts": [], + "env_files": [], + } + + # 1. Try loading default config file if it exists + default_config_path = DEFAULT_CONFIG_PATH + if default_config_path.exists(): + try: + default_config = ConfigurationLoader.from_yaml_file(default_config_path) + config_data["memory_db_type"] = default_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in default_config._initializer_configs + ] + config_data["initialization_scripts"] = default_config.initialization_scripts + config_data["env_files"] = default_config.env_files + except Exception as e: + logger.warning(f"Failed to load default config file {default_config_path}: {e}") + + # 2. Load explicit config file if provided (overrides default) + if config_file is not None: + if not config_file.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file}") + explicit_config = ConfigurationLoader.from_yaml_file(config_file) + config_data["memory_db_type"] = explicit_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in explicit_config._initializer_configs + ] + config_data["initialization_scripts"] = explicit_config.initialization_scripts + config_data["env_files"] = explicit_config.env_files + + # 3. Apply overrides (non-None values take precedence) + if memory_db_type is not None: + # Normalize to snake_case + normalized_db = memory_db_type.lower().replace("-", "_") + if normalized_db == "inmemory": + normalized_db = "in_memory" + elif normalized_db == "azuresql": + normalized_db = "azure_sql" + config_data["memory_db_type"] = normalized_db + + if initializers is not None: + config_data["initializers"] = initializers + + if initialization_scripts is not None: + config_data["initialization_scripts"] = initialization_scripts + + if env_files is not None: + config_data["env_files"] = env_files + + return ConfigurationLoader.from_dict(config_data) + @classmethod def get_default_config_path(cls) -> pathlib.Path: """ From 761e94ea844ee9c7dfd7a7917909092a9101c643 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 6 Feb 2026 00:45:49 +0000 Subject: [PATCH 08/15] devcontainer and conf --- .devcontainer/devcontainer_setup.sh | 15 +++++++++++++++ .pyrit_conf_example | 9 ++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/.devcontainer/devcontainer_setup.sh b/.devcontainer/devcontainer_setup.sh index 822b850f4a..d5107b3d19 100644 --- a/.devcontainer/devcontainer_setup.sh +++ b/.devcontainer/devcontainer_setup.sh @@ -3,6 +3,21 @@ set -e MYPY_CACHE="/workspace/.mypy_cache" VIRTUAL_ENV="/opt/venv" +PYRIT_CONFIG_DIR="/home/vscode/.pyrit" +PYRIT_CONFIG_FILE="$PYRIT_CONFIG_DIR/.pyrit_conf" + +# Create the .pyrit config directory and copy example config if not exists +if [ ! -d "$PYRIT_CONFIG_DIR" ]; then + echo "Creating PyRIT config directory..." + mkdir -p "$PYRIT_CONFIG_DIR" +fi + +if [ ! -f "$PYRIT_CONFIG_FILE" ] && [ -f "/workspace/.pyrit_conf_example" ]; then + echo "Copying example PyRIT config file..." + cp /workspace/.pyrit_conf_example "$PYRIT_CONFIG_FILE" + echo "✅ Created $PYRIT_CONFIG_FILE from example. Edit as needed." +fi + # Create the mypy cache directory if it doesn't exist if [ ! -d "$MYPY_CACHE" ]; then echo "Creating mypy cache directory..." diff --git a/.pyrit_conf_example b/.pyrit_conf_example index 70c60f74fd..bf3af0f5c4 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -1,6 +1,7 @@ # PyRIT Configuration File Example # ================================ -# Copy this file to ~/.pyrit/.pyrit_conf or specify a custom path when loading. +# This is a YAML-formatted configuration file. Copy to ~/.pyrit/.pyrit_conf +# or specify a custom path when loading via --config-file. # # For documentation on configuration options, see: # https://github.com/Azure/PyRIT/blob/main/doc/setup/configuration.md @@ -38,12 +39,13 @@ memory_db_type: sqlite # args: # some_param: value initializers: - - simple + - simple # Initialization Scripts # ---------------------- # List of paths to custom Python scripts containing PyRITInitializer subclasses. # Paths can be absolute or relative to the current working directory. +# Set to [] to explicitly load no scripts. Comment out to use defaults. # # Example: # initialization_scripts: @@ -55,7 +57,8 @@ initialization_scripts: [] # ----------------- # List of .env file paths to load during initialization. # Later files override values from earlier files. -# If not specified, PyRIT loads ~/.pyrit/.env and ~/.pyrit/.env.local by default. +# Set to [] to explicitly load no environment files. +# Comment out this field to use default behavior. # # Example: # env_files: From 3e0597424e3c97d985f809fd1f498863cf3c55aa Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 6 Feb 2026 19:29:01 +0000 Subject: [PATCH 09/15] empty list behavior --- .pyrit_conf_example | 15 +++-- pyrit/cli/frontend_core.py | 57 ++++++++++++------- pyrit/setup/configuration_loader.py | 38 +++++++++---- tests/unit/setup/test_configuration_loader.py | 26 +++++++-- 4 files changed, 95 insertions(+), 41 deletions(-) diff --git a/.pyrit_conf_example b/.pyrit_conf_example index bf3af0f5c4..46014434f8 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -45,26 +45,31 @@ initializers: # ---------------------- # List of paths to custom Python scripts containing PyRITInitializer subclasses. # Paths can be absolute or relative to the current working directory. -# Set to [] to explicitly load no scripts. Comment out to use defaults. +# +# Behavior: +# - Omit this field (or set to null): No custom scripts loaded (default) +# - Set to []: Explicitly load no scripts (same as omitting) +# - Set to list of paths: Load the specified scripts # # Example: # initialization_scripts: # - /path/to/my_custom_initializer.py # - ./local_initializer.py -initialization_scripts: [] # Environment Files # ----------------- # List of .env file paths to load during initialization. # Later files override values from earlier files. -# Set to [] to explicitly load no environment files. -# Comment out this field to use default behavior. +# +# Behavior: +# - Omit this field (or set to null): Load default .env and .env.local from ~/.pyrit/ if they exist +# - Set to []: Explicitly load NO environment files +# - Set to list of paths: Load only the specified files # # Example: # env_files: # - /path/to/.env # - /path/to/.env.local -env_files: [] # Silent Mode # ----------- diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 69e9d0239a..4122a40f32 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -82,7 +82,8 @@ def __init__( 3. Individual CLI arguments (database, initializers, etc.) Args: - config_file: Optional path to a YAML-formatted configuration file.\n The file uses .pyrit_conf extension but is YAML format. + config_file: Optional path to a YAML-formatted configuration file. + The file uses .pyrit_conf extension but is YAML format. database: Database type (InMemory, SQLite, or AzureSQL). initialization_scripts: Optional list of initialization script paths. initializer_names: Optional list of built-in initializer names to run. @@ -106,7 +107,8 @@ def __init__( config_file=config_file, memory_db_type=database, initializers=initializer_names, - initialization_scripts=[str(p) for p in initialization_scripts] if initialization_scripts else None, + initialization_scripts=[ + str(p) for p in initialization_scripts] if initialization_scripts else None, env_files=[str(p) for p in env_files] if env_files else None, ) except ValueError as e: @@ -308,7 +310,8 @@ async def run_scenario_async( if scenario_class is None: available = ", ".join(context.scenario_registry.get_names()) - raise ValueError(f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") + raise ValueError( + f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") # Build initialization kwargs (these go to initialize_async, not __init__) init_kwargs: dict[str, Any] = {} @@ -433,13 +436,15 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if scenario_metadata.aggregate_strategies: agg_strategies = scenario_metadata.aggregate_strategies print(" Aggregate Strategies:") - formatted = _format_wrapped_text(text=", ".join(agg_strategies), indent=" - ") + formatted = _format_wrapped_text( + text=", ".join(agg_strategies), indent=" - ") print(formatted) if scenario_metadata.all_strategies: strategies = scenario_metadata.all_strategies print(f" Available Strategies ({len(strategies)}):") - formatted = _format_wrapped_text(text=", ".join(strategies), indent=" ") + formatted = _format_wrapped_text( + text=", ".join(strategies), indent=" ") print(formatted) if scenario_metadata.default_strategy: @@ -451,7 +456,8 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if datasets: size_suffix = f", max {max_size} per dataset" if max_size else "" print(f" Default Datasets ({len(datasets)}{size_suffix}):") - formatted = _format_wrapped_text(text=", ".join(datasets), indent=" ") + formatted = _format_wrapped_text( + text=", ".join(datasets), indent=" ") print(formatted) else: print(" Default Datasets: None") @@ -478,7 +484,8 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") if initializer_metadata.class_description: print(" Description:") - print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) + print(_format_wrapped_text( + text=initializer_metadata.class_description, indent=" ")) def validate_database(*, database: str) -> str: @@ -496,7 +503,8 @@ def validate_database(*, database: str) -> str: """ valid_databases = [IN_MEMORY, SQLITE, AZURE_SQL] if database not in valid_databases: - raise ValueError(f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") + raise ValueError( + f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") return database @@ -516,7 +524,8 @@ def validate_log_level(*, log_level: str) -> str: valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] level_upper = log_level.upper() if level_upper not in valid_levels: - raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") + raise ValueError( + f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") return level_upper @@ -541,11 +550,13 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int """ # Reject boolean types explicitly (int(True) == 1, int(False) == 0) if isinstance(value, bool): - raise ValueError(f"{name} must be an integer string, got boolean: {value}") + raise ValueError( + f"{name} must be an integer string, got boolean: {value}") # Ensure value is a string if not isinstance(value, str): - raise ValueError(f"{name} must be a string, got {type(value).__name__}: {value}") + raise ValueError( + f"{name} must be a string, got {type(value).__name__}: {value}") # Strip whitespace and validate it looks like an integer value = value.strip() @@ -558,7 +569,8 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int raise ValueError(f"{name} must be an integer, got: {value}") from e if min_value is not None and int_value < min_value: - raise ValueError(f"{name} must be at least {min_value}, got: {int_value}") + raise ValueError( + f"{name} must be at least {min_value}, got: {int_value}") return int_value @@ -602,7 +614,8 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A sig = inspect.signature(validator_func) params = list(sig.parameters.keys()) if not params: - raise ValueError(f"Validator function {validator_func.__name__} must have at least one parameter") + raise ValueError( + f"Validator function {validator_func.__name__} must have at least one parameter") first_param = params[0] def wrapper(value: Any) -> Any: @@ -615,7 +628,8 @@ def wrapper(value: Any) -> Any: raise ap.ArgumentTypeError(str(e)) from e # Preserve function metadata for better debugging - wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") + wrapper.__name__ = getattr( + validator_func, "__name__", "argparse_validator") wrapper.__doc__ = getattr(validator_func, "__doc__", None) return wrapper @@ -675,7 +689,8 @@ def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: validate_database_argparse = _argparse_validator(validate_database) validate_log_level_argparse = _argparse_validator(validate_log_level) positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) -non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) +non_negative_int = _argparse_validator( + lambda v: validate_integer(v, min_value=0)) resolve_env_files_argparse = _argparse_validator(resolve_env_files) @@ -703,7 +718,8 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: # Validate all keys and values are strings for key, value in labels.items(): if not isinstance(key, str) or not isinstance(value, str): - raise ValueError(f"All label keys and values must be strings. Got: {key}={value}") + raise ValueError( + f"All label keys and values must be strings. Got: {key}={value}") return labels @@ -872,13 +888,15 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-concurrency requires a value") - result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) + result["max_concurrency"] = validate_integer( + parts[i], name="--max-concurrency", min_value=1) i += 1 elif parts[i] == "--max-retries": i += 1 if i >= len(parts): raise ValueError("--max-retries requires a value") - result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) + result["max_retries"] = validate_integer( + parts[i], name="--max-retries", min_value=0) i += 1 elif parts[i] == "--memory-labels": i += 1 @@ -909,7 +927,8 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-dataset-size requires a value") - result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) + result["max_dataset_size"] = validate_integer( + parts[i], name="--max-dataset-size", min_value=1) i += 1 else: logger.warning(f"Unknown argument: {parts[i]}") diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 43781482a6..c4de19e1ff 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -65,7 +65,9 @@ class ConfigurationLoader(YamlLoadable): memory_db_type: The type of memory database (in_memory, sqlite, azure_sql). initializers: List of initializer configurations (name + optional args). initialization_scripts: List of paths to custom initialization scripts. + None means "use defaults", [] means "load nothing". env_files: List of environment file paths to load. + None means "use defaults (.env, .env.local)", [] means "load nothing". silent: Whether to suppress initialization messages. Example YAML configuration: @@ -89,8 +91,8 @@ class ConfigurationLoader(YamlLoadable): memory_db_type: str = "sqlite" initializers: List[Union[str, Dict[str, Any]]] = field(default_factory=list) - initialization_scripts: List[str] = field(default_factory=list) - env_files: List[str] = field(default_factory=list) + initialization_scripts: Optional[List[str]] = None + env_files: Optional[List[str]] = None silent: bool = False def __post_init__(self) -> None: @@ -167,8 +169,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "ConfigurationLoader": Returns: A new ConfigurationLoader instance. """ - # Filter out None values and empty lists to use defaults - filtered_data = {k: v for k, v in data.items() if v is not None and v != []} + # Filter out None values only - empty lists are meaningful ("load nothing") + filtered_data = {k: v for k, v in data.items() if v is not None} return cls(**filtered_data) @staticmethod @@ -210,12 +212,12 @@ def load_with_overrides( logger = logging.getLogger(__name__) - # Start with defaults + # Start with defaults - None means "use defaults", [] means "load nothing" config_data: Dict[str, Any] = { "memory_db_type": "sqlite", "initializers": [], - "initialization_scripts": [], - "env_files": [], + "initialization_scripts": None, # None = use defaults + "env_files": None, # None = use defaults } # 1. Try loading default config file if it exists @@ -228,6 +230,7 @@ def load_with_overrides( {"name": ic.name, "args": ic.args} if ic.args else ic.name for ic in default_config._initializer_configs ] + # Preserve None vs [] distinction from config file config_data["initialization_scripts"] = default_config.initialization_scripts config_data["env_files"] = default_config.env_files except Exception as e: @@ -243,6 +246,7 @@ def load_with_overrides( {"name": ic.name, "args": ic.args} if ic.args else ic.name for ic in explicit_config._initializer_configs ] + # Preserve None vs [] distinction from config file config_data["initialization_scripts"] = explicit_config.initialization_scripts config_data["env_files"] = explicit_config.env_files @@ -322,11 +326,17 @@ def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: Resolve initialization script paths. Returns: - Sequence of Path objects, or None if no scripts configured. + None if field is None (use defaults), empty list if field is [], + or Sequence of resolved Path objects if paths are specified. """ - if not self.initialization_scripts: + # None means "use defaults" - return None to signal this + if self.initialization_scripts is None: return None + # Empty list means "load nothing" - return empty list + if len(self.initialization_scripts) == 0: + return [] + resolved: List[pathlib.Path] = [] for script_str in self.initialization_scripts: script_path = pathlib.Path(script_str) @@ -341,11 +351,17 @@ def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: Resolve environment file paths. Returns: - Sequence of Path objects, or None if no env files configured. + None if field is None (use defaults), empty list if field is [], + or Sequence of resolved Path objects if paths are specified. """ - if not self.env_files: + # None means "use defaults" - return None to signal this + if self.env_files is None: return None + # Empty list means "load nothing" - return empty list + if len(self.env_files) == 0: + return [] + resolved: List[pathlib.Path] = [] for env_str in self.env_files: env_path = pathlib.Path(env_str) diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index a55dab6279..0524084ffb 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -38,8 +38,8 @@ def test_default_values(self): config = ConfigurationLoader() assert config.memory_db_type == "sqlite" assert config.initializers == [] - assert config.initialization_scripts == [] - assert config.env_files == [] + assert config.initialization_scripts is None # None means "use defaults" + assert config.env_files is None # None means "use defaults" assert config.silent is False def test_valid_memory_db_types_snake_case(self): @@ -200,11 +200,18 @@ def test_get_default_config_path(self): class TestConfigurationLoaderResolvers: """Tests for ConfigurationLoader path resolution methods.""" - def test_resolve_initialization_scripts_empty(self): - """Test that empty scripts returns None.""" + def test_resolve_initialization_scripts_none_returns_none(self): + """Test that None (default) returns None to signal 'use defaults'.""" config = ConfigurationLoader() assert config._resolve_initialization_scripts() is None + def test_resolve_initialization_scripts_empty_list_returns_empty_list(self): + """Test that explicit empty list [] returns empty list to signal 'load nothing'.""" + config = ConfigurationLoader(initialization_scripts=[]) + resolved = config._resolve_initialization_scripts() + assert resolved is not None + assert resolved == [] + def test_resolve_initialization_scripts_absolute_path(self): """Test resolving absolute script paths.""" config = ConfigurationLoader(initialization_scripts=["/absolute/path/script.py"]) @@ -224,11 +231,18 @@ def test_resolve_initialization_scripts_relative_path(self): # Check path ends with expected components (works on both Unix and Windows) assert resolved[0].parts[-2:] == ("relative", "script.py") - def test_resolve_env_files_empty(self): - """Test that empty env files returns None.""" + def test_resolve_env_files_none_returns_none(self): + """Test that None (default) returns None to signal 'use defaults'.""" config = ConfigurationLoader() assert config._resolve_env_files() is None + def test_resolve_env_files_empty_list_returns_empty_list(self): + """Test that explicit empty list [] returns empty list to signal 'load nothing'.""" + config = ConfigurationLoader(env_files=[]) + resolved = config._resolve_env_files() + assert resolved is not None + assert resolved == [] + def test_resolve_env_files_absolute_path(self): """Test resolving absolute env file paths.""" config = ConfigurationLoader(env_files=["/path/to/.env"]) From dbb2e3fd1ff03d5563a3b8ae89c3044925b760c5 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 6 Feb 2026 19:31:34 +0000 Subject: [PATCH 10/15] precommit --- pyrit/cli/frontend_core.py | 54 +++++++++++++------------------------- 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 4122a40f32..3d913a7be8 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -107,8 +107,7 @@ def __init__( config_file=config_file, memory_db_type=database, initializers=initializer_names, - initialization_scripts=[ - str(p) for p in initialization_scripts] if initialization_scripts else None, + initialization_scripts=[str(p) for p in initialization_scripts] if initialization_scripts else None, env_files=[str(p) for p in env_files] if env_files else None, ) except ValueError as e: @@ -310,8 +309,7 @@ async def run_scenario_async( if scenario_class is None: available = ", ".join(context.scenario_registry.get_names()) - raise ValueError( - f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") + raise ValueError(f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") # Build initialization kwargs (these go to initialize_async, not __init__) init_kwargs: dict[str, Any] = {} @@ -436,15 +434,13 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if scenario_metadata.aggregate_strategies: agg_strategies = scenario_metadata.aggregate_strategies print(" Aggregate Strategies:") - formatted = _format_wrapped_text( - text=", ".join(agg_strategies), indent=" - ") + formatted = _format_wrapped_text(text=", ".join(agg_strategies), indent=" - ") print(formatted) if scenario_metadata.all_strategies: strategies = scenario_metadata.all_strategies print(f" Available Strategies ({len(strategies)}):") - formatted = _format_wrapped_text( - text=", ".join(strategies), indent=" ") + formatted = _format_wrapped_text(text=", ".join(strategies), indent=" ") print(formatted) if scenario_metadata.default_strategy: @@ -456,8 +452,7 @@ def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: if datasets: size_suffix = f", max {max_size} per dataset" if max_size else "" print(f" Default Datasets ({len(datasets)}{size_suffix}):") - formatted = _format_wrapped_text( - text=", ".join(datasets), indent=" ") + formatted = _format_wrapped_text(text=", ".join(datasets), indent=" ") print(formatted) else: print(" Default Datasets: None") @@ -484,8 +479,7 @@ def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") if initializer_metadata.class_description: print(" Description:") - print(_format_wrapped_text( - text=initializer_metadata.class_description, indent=" ")) + print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) def validate_database(*, database: str) -> str: @@ -503,8 +497,7 @@ def validate_database(*, database: str) -> str: """ valid_databases = [IN_MEMORY, SQLITE, AZURE_SQL] if database not in valid_databases: - raise ValueError( - f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") + raise ValueError(f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") return database @@ -524,8 +517,7 @@ def validate_log_level(*, log_level: str) -> str: valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] level_upper = log_level.upper() if level_upper not in valid_levels: - raise ValueError( - f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") + raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") return level_upper @@ -550,13 +542,11 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int """ # Reject boolean types explicitly (int(True) == 1, int(False) == 0) if isinstance(value, bool): - raise ValueError( - f"{name} must be an integer string, got boolean: {value}") + raise ValueError(f"{name} must be an integer string, got boolean: {value}") # Ensure value is a string if not isinstance(value, str): - raise ValueError( - f"{name} must be a string, got {type(value).__name__}: {value}") + raise ValueError(f"{name} must be a string, got {type(value).__name__}: {value}") # Strip whitespace and validate it looks like an integer value = value.strip() @@ -569,8 +559,7 @@ def validate_integer(value: str, *, name: str = "value", min_value: Optional[int raise ValueError(f"{name} must be an integer, got: {value}") from e if min_value is not None and int_value < min_value: - raise ValueError( - f"{name} must be at least {min_value}, got: {int_value}") + raise ValueError(f"{name} must be at least {min_value}, got: {int_value}") return int_value @@ -614,8 +603,7 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A sig = inspect.signature(validator_func) params = list(sig.parameters.keys()) if not params: - raise ValueError( - f"Validator function {validator_func.__name__} must have at least one parameter") + raise ValueError(f"Validator function {validator_func.__name__} must have at least one parameter") first_param = params[0] def wrapper(value: Any) -> Any: @@ -628,8 +616,7 @@ def wrapper(value: Any) -> Any: raise ap.ArgumentTypeError(str(e)) from e # Preserve function metadata for better debugging - wrapper.__name__ = getattr( - validator_func, "__name__", "argparse_validator") + wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") wrapper.__doc__ = getattr(validator_func, "__doc__", None) return wrapper @@ -689,8 +676,7 @@ def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: validate_database_argparse = _argparse_validator(validate_database) validate_log_level_argparse = _argparse_validator(validate_log_level) positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) -non_negative_int = _argparse_validator( - lambda v: validate_integer(v, min_value=0)) +non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) resolve_env_files_argparse = _argparse_validator(resolve_env_files) @@ -718,8 +704,7 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: # Validate all keys and values are strings for key, value in labels.items(): if not isinstance(key, str) or not isinstance(value, str): - raise ValueError( - f"All label keys and values must be strings. Got: {key}={value}") + raise ValueError(f"All label keys and values must be strings. Got: {key}={value}") return labels @@ -888,15 +873,13 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-concurrency requires a value") - result["max_concurrency"] = validate_integer( - parts[i], name="--max-concurrency", min_value=1) + result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) i += 1 elif parts[i] == "--max-retries": i += 1 if i >= len(parts): raise ValueError("--max-retries requires a value") - result["max_retries"] = validate_integer( - parts[i], name="--max-retries", min_value=0) + result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) i += 1 elif parts[i] == "--memory-labels": i += 1 @@ -927,8 +910,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i += 1 if i >= len(parts): raise ValueError("--max-dataset-size requires a value") - result["max_dataset_size"] = validate_integer( - parts[i], name="--max-dataset-size", min_value=1) + result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) i += 1 else: logger.warning(f"Unknown argument: {parts[i]}") From 3284e194fb8e87b0dc5321262f87f130f85eaa8c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 6 Feb 2026 19:39:42 +0000 Subject: [PATCH 11/15] logging type --- pyrit/cli/frontend_core.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 3d913a7be8..bdf920ccd5 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -97,9 +97,11 @@ def __init__( from pyrit.setup import ConfigurationLoader from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP - # Validate log level early - effective_log_level = log_level if log_level is not None else "WARNING" - self._log_level = validate_log_level(log_level=effective_log_level) + # Validate and convert log level to actual logging level constant + if log_level is None: + self._log_level = logging.WARNING + else: + self._log_level = getattr(logging, validate_log_level(log_level=log_level)) # Load configuration using ConfigurationLoader.load_with_overrides try: @@ -137,7 +139,7 @@ def __init__( self._initialized = False # Configure logging - logging.basicConfig(level=getattr(logging, self._log_level)) + logging.basicConfig(level=self._log_level) async def initialize_async(self) -> None: """Initialize PyRIT and load registries (heavy operation).""" From 5a48b782cc191975d80d5e9be01bc91fddaff625 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 6 Feb 2026 23:23:54 +0000 Subject: [PATCH 12/15] precommit --- pyrit/cli/frontend_core.py | 6 +++--- tests/unit/cli/test_frontend_core.py | 10 +++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index bdf920ccd5..b0fddceac0 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -21,6 +21,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from pyrit.setup import ConfigurationLoader +from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP + try: import termcolor @@ -94,9 +97,6 @@ def __init__( ValueError: If database or log_level are invalid, or if config file is invalid. FileNotFoundError: If an explicitly specified config_file does not exist. """ - from pyrit.setup import ConfigurationLoader - from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP - # Validate and convert log level to actual logging level constant if log_level is None: self._log_level = logging.WARNING diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 3ed4e9e81c..c0689e63e0 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -5,6 +5,7 @@ Unit tests for the frontend_core module. """ +import logging from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -24,7 +25,7 @@ def test_init_with_defaults(self): assert context._database == frontend_core.SQLITE assert context._initialization_scripts is None assert context._initializer_names is None - assert context._log_level == "WARNING" + assert context._log_level == logging.WARNING assert context._initialized is False def test_init_with_all_parameters(self): @@ -45,18 +46,13 @@ def test_init_with_all_parameters(self): assert len(context._initialization_scripts) == 1 assert context._initialization_scripts[0].parts[-2:] == ("test", "script.py") assert context._initializer_names == initializers - assert context._log_level == "DEBUG" + assert context._log_level == logging.DEBUG def test_init_with_invalid_database(self): """Test initialization with invalid database raises ValueError.""" with pytest.raises(ValueError, match="Invalid database type"): frontend_core.FrontendCore(database="InvalidDB") - def test_init_with_invalid_log_level(self): - """Test initialization with invalid log level raises ValueError.""" - with pytest.raises(ValueError, match="Invalid log level"): - frontend_core.FrontendCore(log_level="INVALID") - @patch("pyrit.registry.ScenarioRegistry") @patch("pyrit.registry.InitializerRegistry") @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) From ecc4ab0b36a7418abe9c51ff71317f744adeeae0 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 6 Feb 2026 23:33:40 +0000 Subject: [PATCH 13/15] precommit mypy --- pyrit/cli/pyrit_shell.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index aadaa01ab7..d23ed76aa7 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -99,7 +99,7 @@ def __init__( super().__init__() self.context = context self.default_database = context._database - self.default_log_level = context._log_level + self.default_log_level: Optional[str] = str(context._log_level) if context._log_level is not None else None self.default_env_files = context._env_files # Track scenario execution history: list of (command_string, ScenarioResult) tuples @@ -235,7 +235,7 @@ def do_run(self, line: str) -> None: initialization_scripts=resolved_scripts, initializer_names=args["initializers"], env_files=resolved_env_files, - log_level=args["log_level"] or self.default_log_level, + log_level=str(args["log_level"]) if args["log_level"] else self.default_log_level, ) # Use the existing registries (don't reinitialize) run_context._scenario_registry = self.context._scenario_registry From c183e34ac26de765010069b0a9033f585888bbdc Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Sat, 7 Feb 2026 00:30:28 +0000 Subject: [PATCH 14/15] logging type --- pyrit/cli/frontend_core.py | 24 +++++++++++------------- pyrit/cli/pyrit_shell.py | 4 ++-- tests/unit/cli/test_frontend_core.py | 16 ++++++++-------- tests/unit/cli/test_pyrit_scan.py | 11 ++++++----- 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index b0fddceac0..717e23a624 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -74,7 +74,7 @@ def __init__( initialization_scripts: Optional[list[Path]] = None, initializer_names: Optional[list[str]] = None, env_files: Optional[list[Path]] = None, - log_level: Optional[str] = None, + log_level: Optional[int] = None, ): """ Initialize PyRIT context. @@ -91,17 +91,14 @@ def __init__( initialization_scripts: Optional list of initialization script paths. initializer_names: Optional list of built-in initializer names to run. env_files: Optional list of environment file paths to load in order. - log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to WARNING. + log_level: Logging level constant (e.g., logging.WARNING). Defaults to logging.WARNING. Raises: - ValueError: If database or log_level are invalid, or if config file is invalid. + ValueError: If database is invalid, or if config file is invalid. FileNotFoundError: If an explicitly specified config_file does not exist. """ - # Validate and convert log level to actual logging level constant - if log_level is None: - self._log_level = logging.WARNING - else: - self._log_level = getattr(logging, validate_log_level(log_level=log_level)) + # Use provided log level or default to WARNING + self._log_level = log_level if log_level is not None else logging.WARNING # Load configuration using ConfigurationLoader.load_with_overrides try: @@ -503,15 +500,15 @@ def validate_database(*, database: str) -> str: return database -def validate_log_level(*, log_level: str) -> str: +def validate_log_level(*, log_level: str) -> int: """ - Validate log level. + Validate log level and convert to logging constant. Args: log_level: Log level string (case-insensitive). Returns: - Validated log level in uppercase. + Validated log level as logging constant (e.g., logging.WARNING). Raises: ValueError: If log level is invalid. @@ -520,7 +517,8 @@ def validate_log_level(*, log_level: str) -> str: level_upper = log_level.upper() if level_upper not in valid_levels: raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") - return level_upper + level_value: int = getattr(logging, level_upper) + return level_value def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: @@ -814,7 +812,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: - max_retries: Optional[int] - memory_labels: Optional[dict[str, str]] - database: Optional[str] - - log_level: Optional[str] + - log_level: Optional[int] - dataset_names: Optional[list[str]] - max_dataset_size: Optional[int] diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index d23ed76aa7..2c218f237b 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -99,7 +99,7 @@ def __init__( super().__init__() self.context = context self.default_database = context._database - self.default_log_level: Optional[str] = str(context._log_level) if context._log_level is not None else None + self.default_log_level: Optional[int] = context._log_level self.default_env_files = context._env_files # Track scenario execution history: list of (command_string, ScenarioResult) tuples @@ -235,7 +235,7 @@ def do_run(self, line: str) -> None: initialization_scripts=resolved_scripts, initializer_names=args["initializers"], env_files=resolved_env_files, - log_level=str(args["log_level"]) if args["log_level"] else self.default_log_level, + log_level=args["log_level"] if args["log_level"] else self.default_log_level, ) # Use the existing registries (don't reinitialize) run_context._scenario_registry = self.context._scenario_registry diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index c0689e63e0..5e7ce37e19 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -37,7 +37,7 @@ def test_init_with_all_parameters(self): database=frontend_core.IN_MEMORY, initialization_scripts=scripts, initializer_names=initializers, - log_level="DEBUG", + log_level=logging.DEBUG, ) assert context._database == frontend_core.IN_MEMORY @@ -128,11 +128,11 @@ def test_validate_database_invalid_value(self): def test_validate_log_level_valid_values(self): """Test validate_log_level with valid values.""" - assert frontend_core.validate_log_level(log_level="DEBUG") == "DEBUG" - assert frontend_core.validate_log_level(log_level="INFO") == "INFO" - assert frontend_core.validate_log_level(log_level="warning") == "WARNING" # Case-insensitive - assert frontend_core.validate_log_level(log_level="error") == "ERROR" - assert frontend_core.validate_log_level(log_level="CRITICAL") == "CRITICAL" + assert frontend_core.validate_log_level(log_level="DEBUG") == logging.DEBUG + assert frontend_core.validate_log_level(log_level="INFO") == logging.INFO + assert frontend_core.validate_log_level(log_level="warning") == logging.WARNING # Case-insensitive + assert frontend_core.validate_log_level(log_level="error") == logging.ERROR + assert frontend_core.validate_log_level(log_level="CRITICAL") == logging.CRITICAL def test_validate_log_level_invalid_value(self): """Test validate_log_level with invalid value.""" @@ -207,7 +207,7 @@ def test_validate_database_argparse(self): def test_validate_log_level_argparse(self): """Test validate_log_level_argparse wrapper.""" - assert frontend_core.validate_log_level_argparse("DEBUG") == "DEBUG" + assert frontend_core.validate_log_level_argparse("DEBUG") == logging.DEBUG import argparse @@ -582,7 +582,7 @@ def test_parse_run_arguments_with_log_level(self): """Test parsing with log-level override.""" result = frontend_core.parse_run_arguments(args_string="test_scenario --log-level DEBUG") - assert result["log_level"] == "DEBUG" + assert result["log_level"] == logging.DEBUG def test_parse_run_arguments_with_initialization_scripts(self): """Test parsing with initialization-scripts.""" diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 58517c5235..7b477cb3ec 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -5,6 +5,7 @@ Unit tests for the pyrit_scan CLI module. """ +import logging from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -36,7 +37,7 @@ def test_parse_args_scenario_name_only(self): assert args.scenario_name == "test_scenario" assert args.database == "SQLite" - assert args.log_level == "WARNING" + assert args.log_level == logging.WARNING def test_parse_args_with_database(self): """Test parsing with database option.""" @@ -48,7 +49,7 @@ def test_parse_args_with_log_level(self): """Test parsing with log-level option.""" args = pyrit_scan.parse_args(["test_scenario", "--log-level", "DEBUG"]) - assert args.log_level == "DEBUG" + assert args.log_level == logging.DEBUG def test_parse_args_with_initializers(self): """Test parsing with initializers.""" @@ -117,7 +118,7 @@ def test_parse_args_complex_command(self): assert args.scenario_name == "encoding_scenario" assert args.database == "InMemory" - assert args.log_level == "INFO" + assert args.log_level == logging.INFO assert args.initializers == ["openai_target"] assert args.scenario_strategies == ["base64", "rot13"] assert args.max_concurrency == 10 @@ -304,7 +305,7 @@ def test_main_run_scenario_with_all_options( # Verify FrontendCore was called with correct args call_kwargs = mock_frontend_core.call_args[1] assert call_kwargs["database"] == "InMemory" - assert call_kwargs["log_level"] == "DEBUG" + assert call_kwargs["log_level"] == logging.DEBUG assert call_kwargs["initializer_names"] == ["init1", "init2"] @patch("pyrit.cli.pyrit_scan.asyncio.run") @@ -354,7 +355,7 @@ def test_main_log_level_defaults_to_warning(self, mock_frontend_core: MagicMock) pyrit_scan.main(["--list-scenarios"]) call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["log_level"] == "WARNING" + assert call_kwargs["log_level"] == logging.WARNING def test_main_with_invalid_args(self): """Test main with invalid arguments.""" From 0f638bdaef0aa4c0b8303bcd781d9fbc0873c1bb Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Sat, 7 Feb 2026 00:38:21 +0000 Subject: [PATCH 15/15] logging --- pyrit/cli/pyrit_scan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index ba025c61d7..d73992d7bc 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -8,6 +8,7 @@ """ import asyncio +import logging import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from pathlib import Path @@ -58,7 +59,7 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: parser.add_argument( "--log-level", type=frontend_core.validate_log_level_argparse, - default="WARNING", + default=logging.WARNING, help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", )