diff --git a/.devcontainer/devcontainer_setup.sh b/.devcontainer/devcontainer_setup.sh index 3657b804ff..095ec9ca29 100644 --- a/.devcontainer/devcontainer_setup.sh +++ b/.devcontainer/devcontainer_setup.sh @@ -3,6 +3,21 @@ set -e MYPY_CACHE="/workspace/.mypy_cache" VIRTUAL_ENV="/opt/venv" +PYRIT_CONFIG_DIR="/home/vscode/.pyrit" +PYRIT_CONFIG_FILE="$PYRIT_CONFIG_DIR/.pyrit_conf" + +# Create the .pyrit config directory and copy example config if not exists +if [ ! -d "$PYRIT_CONFIG_DIR" ]; then + echo "Creating PyRIT config directory..." + mkdir -p "$PYRIT_CONFIG_DIR" +fi + +if [ ! -f "$PYRIT_CONFIG_FILE" ] && [ -f "/workspace/.pyrit_conf_example" ]; then + echo "Copying example PyRIT config file..." + cp /workspace/.pyrit_conf_example "$PYRIT_CONFIG_FILE" + echo "✅ Created $PYRIT_CONFIG_FILE from example. Edit as needed." +fi + # Create the mypy cache directory if it doesn't exist if [ ! -d "$MYPY_CACHE" ]; then echo "Creating mypy cache directory..." diff --git a/.pyrit_conf_example b/.pyrit_conf_example new file mode 100644 index 0000000000..46014434f8 --- /dev/null +++ b/.pyrit_conf_example @@ -0,0 +1,78 @@ +# PyRIT Configuration File Example +# ================================ +# This is a YAML-formatted configuration file. Copy to ~/.pyrit/.pyrit_conf +# or specify a custom path when loading via --config-file. +# +# For documentation on configuration options, see: +# https://github.com/Azure/PyRIT/blob/main/doc/setup/configuration.md + +# Memory Database Type +# -------------------- +# Specifies which database backend to use for storing prompts and results. +# Options: in_memory, sqlite, azure_sql (case-insensitive) +# - in_memory: Temporary in-memory database (data lost on exit) +# - sqlite: Persistent local SQLite database (default) +# - azure_sql: Azure SQL database (requires connection string in env vars) +memory_db_type: sqlite + +# Initializers +# ------------ +# List of built-in initializers to run during PyRIT initialization. +# Initializers configure default values for converters, scorers, and targets. +# Names are normalized to snake_case (e.g., "SimpleInitializer" -> "simple"). +# +# Available initializers: +# - simple: Basic OpenAI configuration (requires OPENAI_CHAT_* env vars) +# - airt: AI Red Team setup with Azure OpenAI (requires AZURE_OPENAI_* env vars) +# - load_default_datasets: Loads default datasets for all registered scenarios +# - objective_list: Sets default objectives for scenarios +# - openai_objective_target: Sets up OpenAI target for scenarios +# +# Each initializer can be specified as: +# - A simple string (name only) +# - A dictionary with 'name' and optional 'args' for constructor arguments +# +# Example: +# initializers: +# - simple +# - name: airt +# args: +# some_param: value +initializers: + - simple + +# Initialization Scripts +# ---------------------- +# List of paths to custom Python scripts containing PyRITInitializer subclasses. +# Paths can be absolute or relative to the current working directory. +# +# Behavior: +# - Omit this field (or set to null): No custom scripts loaded (default) +# - Set to []: Explicitly load no scripts (same as omitting) +# - Set to list of paths: Load the specified scripts +# +# Example: +# initialization_scripts: +# - /path/to/my_custom_initializer.py +# - ./local_initializer.py + +# Environment Files +# ----------------- +# List of .env file paths to load during initialization. +# Later files override values from earlier files. +# +# Behavior: +# - Omit this field (or set to null): Load default .env and .env.local from ~/.pyrit/ if they exist +# - Set to []: Explicitly load NO environment files +# - Set to list of paths: Load only the specified files +# +# Example: +# env_files: +# - /path/to/.env +# - /path/to/.env.local + +# Silent Mode +# ----------- +# If true, suppresses print statements during initialization. +# Useful for non-interactive environments or when embedding PyRIT in other tools. +silent: false diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 6cfc8b5a45..717e23a624 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -21,6 +21,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from pyrit.setup import ConfigurationLoader +from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP + try: import termcolor @@ -66,31 +69,66 @@ class FrontendCore: def __init__( self, *, - database: str = SQLITE, + config_file: Optional[Path] = None, + database: Optional[str] = None, initialization_scripts: Optional[list[Path]] = None, initializer_names: Optional[list[str]] = None, env_files: Optional[list[Path]] = None, - log_level: str = "WARNING", + log_level: Optional[int] = None, ): """ Initialize PyRIT context. + Configuration is loaded in the following order (later values override earlier): + 1. Default config file (~/.pyrit/.pyrit_conf) if it exists + 2. Explicit config_file argument if provided + 3. Individual CLI arguments (database, initializers, etc.) + Args: + config_file: Optional path to a YAML-formatted configuration file. + The file uses .pyrit_conf extension but is YAML format. database: Database type (InMemory, SQLite, or AzureSQL). initialization_scripts: Optional list of initialization script paths. initializer_names: Optional list of built-in initializer names to run. env_files: Optional list of environment file paths to load in order. - log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to WARNING. + log_level: Logging level constant (e.g., logging.WARNING). Defaults to logging.WARNING. Raises: - ValueError: If database or log_level are invalid. + ValueError: If database is invalid, or if config file is invalid. + FileNotFoundError: If an explicitly specified config_file does not exist. """ - # Validate inputs - self._database = validate_database(database=database) - self._initialization_scripts = initialization_scripts - self._initializer_names = initializer_names - self._env_files = env_files - self._log_level = validate_log_level(log_level=log_level) + # Use provided log level or default to WARNING + self._log_level = log_level if log_level is not None else logging.WARNING + + # Load configuration using ConfigurationLoader.load_with_overrides + try: + config = ConfigurationLoader.load_with_overrides( + config_file=config_file, + memory_db_type=database, + initializers=initializer_names, + initialization_scripts=[str(p) for p in initialization_scripts] if initialization_scripts else None, + env_files=[str(p) for p in env_files] if env_files else None, + ) + except ValueError as e: + # Re-raise with user-friendly message for CLI users + error_msg = str(e) + if "memory_db_type" in error_msg: + raise ValueError( + f"Invalid database type '{database}'. Must be one of: InMemory, SQLite, AzureSQL" + ) from e + raise + + # Store the merged configuration + self._config = config + + # Extract values from config for internal use + # Use canonical mapping from configuration_loader + self._database = _MEMORY_DB_TYPE_MAP[config.memory_db_type] + self._initialization_scripts = config._resolve_initialization_scripts() + self._initializer_names = ( + [ic.name for ic in config._initializer_configs] if config._initializer_configs else None + ) + self._env_files = config._resolve_env_files() # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None @@ -98,7 +136,7 @@ def __init__( self._initialized = False # Configure logging - logging.basicConfig(level=getattr(logging, self._log_level)) + logging.basicConfig(level=self._log_level) async def initialize_async(self) -> None: """Initialize PyRIT and load registries (heavy operation).""" @@ -462,15 +500,15 @@ def validate_database(*, database: str) -> str: return database -def validate_log_level(*, log_level: str) -> str: +def validate_log_level(*, log_level: str) -> int: """ - Validate log level. + Validate log level and convert to logging constant. Args: log_level: Log level string (case-insensitive). Returns: - Validated log level in uppercase. + Validated log level as logging constant (e.g., logging.WARNING). Raises: ValueError: If log level is invalid. @@ -479,7 +517,8 @@ def validate_log_level(*, log_level: str) -> str: level_upper = log_level.upper() if level_upper not in valid_levels: raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") - return level_upper + level_value: int = getattr(logging, level_upper) + return level_value def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: @@ -734,6 +773,11 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path # Shared argument help text ARG_HELP = { + "config_file": ( + "Path to a YAML configuration file. Allows specifying database, initializers (with args), " + "initialization scripts, and env files. CLI arguments override config file values. " + "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." + ), "initializers": "Built-in initializer names to run before the scenario (e.g., openai_objective_target)", "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", "env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files " @@ -768,7 +812,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: - max_retries: Optional[int] - memory_labels: Optional[dict[str, str]] - database: Optional[str] - - log_level: Optional[str] + - log_level: Optional[int] - dataset_names: Optional[list[str]] - max_dataset_size: Optional[int] diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index df342ce0cd..d73992d7bc 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -8,8 +8,10 @@ """ import asyncio +import logging import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from pathlib import Path from typing import Optional from pyrit.cli import frontend_core @@ -34,6 +36,9 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: # Run a scenario with built-in initializers pyrit_scan foundry --initializers openai_objective_target load_default_datasets + # Run with a configuration file (recommended for complex setups) + pyrit_scan foundry --config-file ./my_config.yaml + # Run with custom initialization scripts pyrit_scan garak.encoding --initialization-scripts ./my_config.py @@ -45,10 +50,16 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: formatter_class=RawDescriptionHelpFormatter, ) + parser.add_argument( + "--config-file", + type=Path, + help=frontend_core.ARG_HELP["config_file"], + ) + parser.add_argument( "--log-level", type=frontend_core.validate_log_level_argparse, - default="WARNING", + default=logging.WARNING, help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", ) @@ -182,6 +193,7 @@ def main(args: Optional[list[str]] = None) -> int: return 1 context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, database=parsed_args.database, initialization_scripts=initialization_scripts, env_files=env_files, @@ -194,7 +206,10 @@ def main(args: Optional[list[str]] = None) -> int: # Discover from scenarios directory scenarios_path = frontend_core.get_default_initializer_discovery_path() - context = frontend_core.FrontendCore(log_level=parsed_args.log_level) + context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, + log_level=parsed_args.log_level, + ) return asyncio.run(frontend_core.print_initializers_list_async(context=context, discovery_path=scenarios_path)) # Verify scenario was provided @@ -218,6 +233,7 @@ def main(args: Optional[list[str]] = None) -> int: # Create context with initializers context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, database=parsed_args.database, initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index bcb0743425..2c218f237b 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -14,7 +14,8 @@ import cmd import sys import threading -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from pyrit.models.scenario_result import ScenarioResult @@ -98,7 +99,7 @@ def __init__( super().__init__() self.context = context self.default_database = context._database - self.default_log_level = context._log_level + self.default_log_level: Optional[int] = context._log_level self.default_env_files = context._env_files # Track scenario execution history: list of (command_string, ScenarioResult) tuples @@ -217,16 +218,16 @@ def do_run(self, line: str) -> None: return # Resolve env files if provided - resolved_env_files = None + resolved_env_files: Optional[list[Path]] = None if args["env_files"]: try: - resolved_env_files = frontend_core.resolve_env_files(env_file_paths=args["env_files"]) + resolved_env_files = list(frontend_core.resolve_env_files(env_file_paths=args["env_files"])) except ValueError as e: print(f"Error: {e}") return else: # Use default env files from shell startup - resolved_env_files = self.default_env_files + resolved_env_files = list(self.default_env_files) if self.default_env_files else None # Create a context for this run with overrides run_context = frontend_core.FrontendCore( @@ -234,7 +235,7 @@ def do_run(self, line: str) -> None: initialization_scripts=resolved_scripts, initializer_names=args["initializers"], env_files=resolved_env_files, - log_level=args["log_level"] or self.default_log_level, + log_level=args["log_level"] if args["log_level"] else self.default_log_level, ) # Use the existing registries (don't reinitialize) run_context._scenario_registry = self.context._scenario_registry @@ -453,6 +454,12 @@ def main() -> int: description="PyRIT Interactive Shell - Load modules once, run commands instantly", ) + parser.add_argument( + "--config-file", + type=Path, + help=frontend_core.ARG_HELP["config_file"], + ) + parser.add_argument( "--database", choices=[frontend_core.IN_MEMORY, frontend_core.SQLITE, frontend_core.AZURE_SQL], @@ -488,6 +495,7 @@ def main() -> int: # Create context (initializers are specified per-run, not at startup) context = frontend_core.FrontendCore( + config_file=args.config_file, database=args.database, initialization_scripts=None, initializer_names=None, diff --git a/pyrit/common/path.py b/pyrit/common/path.py index 4094ba8a4b..fcdc4c92ec 100644 --- a/pyrit/common/path.py +++ b/pyrit/common/path.py @@ -33,6 +33,10 @@ def in_git_repo() -> bool: CONFIGURATION_DIRECTORY_PATH = pathlib.Path.home() / ".pyrit" +# Default configuration file name and path +DEFAULT_CONFIG_FILENAME = ".pyrit_conf" +DEFAULT_CONFIG_PATH = CONFIGURATION_DIRECTORY_PATH / DEFAULT_CONFIG_FILENAME + # Points to the root of the project HOME_PATH = pathlib.Path(PYRIT_PATH, "..").resolve() diff --git a/pyrit/setup/__init__.py b/pyrit/setup/__init__.py index 4ecdbd9d43..2929a59ea3 100644 --- a/pyrit/setup/__init__.py +++ b/pyrit/setup/__init__.py @@ -3,6 +3,7 @@ """Module containing initialization PyRIT.""" +from pyrit.setup.configuration_loader import ConfigurationLoader, initialize_from_config_async from pyrit.setup.initialization import AZURE_SQL, IN_MEMORY, SQLITE, MemoryDatabaseType, initialize_pyrit_async __all__ = [ @@ -10,5 +11,7 @@ "SQLITE", "IN_MEMORY", "initialize_pyrit_async", + "initialize_from_config_async", "MemoryDatabaseType", + "ConfigurationLoader", ] diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py new file mode 100644 index 0000000000..c4de19e1ff --- /dev/null +++ b/pyrit/setup/configuration_loader.py @@ -0,0 +1,427 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Configuration loader for PyRIT initialization. + +This module provides the ConfigurationLoader class that loads PyRIT configuration +from YAML files and initializes PyRIT accordingly. +""" + +import pathlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union + +from pyrit.common.path import DEFAULT_CONFIG_PATH +from pyrit.common.yaml_loadable import YamlLoadable +from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.setup.initialization import ( + AZURE_SQL, + IN_MEMORY, + SQLITE, + initialize_pyrit_async, +) + +if TYPE_CHECKING: + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + +# Type alias for YAML-serializable values that can be passed as initializer args +# This matches what YAML can represent: primitives, lists, and nested dicts +YamlPrimitive = Union[str, int, float, bool, None] +YamlValue = Union[YamlPrimitive, List["YamlValue"], Dict[str, "YamlValue"]] + +# Mapping from snake_case config values to internal constants +_MEMORY_DB_TYPE_MAP: Dict[str, str] = { + "in_memory": IN_MEMORY, + "sqlite": SQLITE, + "azure_sql": AZURE_SQL, +} + + +@dataclass +class InitializerConfig: + """ + Configuration for a single initializer. + + Attributes: + name: The name of the initializer (must be registered in InitializerRegistry). + args: Optional dictionary of YAML-serializable arguments to pass to the initializer constructor. + """ + + name: str + args: Optional[Dict[str, YamlValue]] = None + + +@dataclass +class ConfigurationLoader(YamlLoadable): + """ + Loader for PyRIT configuration from YAML files. + + This class loads configuration from a YAML file and provides methods to + initialize PyRIT with the loaded configuration. + + Attributes: + memory_db_type: The type of memory database (in_memory, sqlite, azure_sql). + initializers: List of initializer configurations (name + optional args). + initialization_scripts: List of paths to custom initialization scripts. + None means "use defaults", [] means "load nothing". + env_files: List of environment file paths to load. + None means "use defaults (.env, .env.local)", [] means "load nothing". + silent: Whether to suppress initialization messages. + + Example YAML configuration: + memory_db_type: sqlite + + initializers: + - simple + - name: airt + args: + some_param: value + + initialization_scripts: + - /path/to/custom_initializer.py + + env_files: + - /path/to/.env + - /path/to/.env.local + + silent: false + """ + + memory_db_type: str = "sqlite" + initializers: List[Union[str, Dict[str, Any]]] = field(default_factory=list) + initialization_scripts: Optional[List[str]] = None + env_files: Optional[List[str]] = None + silent: bool = False + + def __post_init__(self) -> None: + """Validate and normalize the configuration after loading.""" + self._normalize_memory_db_type() + self._normalize_initializers() + + def _normalize_memory_db_type(self) -> None: + """ + Normalize and validate memory_db_type. + + Converts the input to lowercase snake_case and validates against known types. + Stores the normalized snake_case value for config consistency, but maps + to internal constants when initializing. + + Raises: + ValueError: If the memory_db_type is not a valid database type. + """ + # Normalize to lowercase + normalized = self.memory_db_type.lower().replace("-", "_") + + # Also handle PascalCase inputs (e.g., "InMemory" -> "in_memory") + if normalized not in _MEMORY_DB_TYPE_MAP: + # Try converting from PascalCase + normalized = class_name_to_snake_case(self.memory_db_type) + + if normalized not in _MEMORY_DB_TYPE_MAP: + valid_types = list(_MEMORY_DB_TYPE_MAP.keys()) + raise ValueError( + f"Invalid memory_db_type '{self.memory_db_type}'. Must be one of: {', '.join(valid_types)}" + ) + + # Store normalized snake_case value + self.memory_db_type = normalized + + def _normalize_initializers(self) -> None: + """ + Normalize initializer entries to InitializerConfig objects. + + Converts initializer names to snake_case for consistent registry lookup. + + Raises: + ValueError: If an initializer entry is missing a 'name' field or has an invalid type. + """ + normalized: List[InitializerConfig] = [] + for entry in self.initializers: + if isinstance(entry, str): + # Simple string entry: normalize name to snake_case + name = class_name_to_snake_case(entry) + normalized.append(InitializerConfig(name=name)) + elif isinstance(entry, dict): + # Dict entry: name and optional args + if "name" not in entry: + raise ValueError(f"Initializer configuration must have a 'name' field. Got: {entry}") + name = class_name_to_snake_case(entry["name"]) + normalized.append( + InitializerConfig( + name=name, + args=entry.get("args"), + ) + ) + else: + raise ValueError(f"Initializer entry must be a string or dict, got: {type(entry).__name__}") + self._initializer_configs = normalized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ConfigurationLoader": + """ + Create a ConfigurationLoader from a dictionary. + + Args: + data: Dictionary containing configuration values. + + Returns: + A new ConfigurationLoader instance. + """ + # Filter out None values only - empty lists are meaningful ("load nothing") + filtered_data = {k: v for k, v in data.items() if v is not None} + return cls(**filtered_data) + + @staticmethod + def load_with_overrides( + config_file: Optional[pathlib.Path] = None, + *, + memory_db_type: Optional[str] = None, + initializers: Optional[Sequence[Union[str, Dict[str, Any]]]] = None, + initialization_scripts: Optional[Sequence[str]] = None, + env_files: Optional[Sequence[str]] = None, + ) -> "ConfigurationLoader": + """ + Load configuration with optional overrides. + + This factory method implements a 3-layer configuration precedence: + 1. Default config file (~/.pyrit/.pyrit_conf) if it exists + 2. Explicit config_file argument if provided + 3. Individual override arguments (non-None values take precedence) + + This is a staticmethod (not classmethod) because it's a pure factory function + that doesn't need access to class state and can be reused by multiple interfaces + (CLI, shell, programmatic API). + + Args: + config_file: Optional path to a YAML-formatted configuration file. + memory_db_type: Override for database type (in_memory, sqlite, azure_sql). + initializers: Override for initializer list. + initialization_scripts: Override for initialization script paths. + env_files: Override for environment file paths. + + Returns: + A merged ConfigurationLoader instance. + + Raises: + FileNotFoundError: If an explicitly specified config_file does not exist. + ValueError: If the configuration is invalid. + """ + import logging + + logger = logging.getLogger(__name__) + + # Start with defaults - None means "use defaults", [] means "load nothing" + config_data: Dict[str, Any] = { + "memory_db_type": "sqlite", + "initializers": [], + "initialization_scripts": None, # None = use defaults + "env_files": None, # None = use defaults + } + + # 1. Try loading default config file if it exists + default_config_path = DEFAULT_CONFIG_PATH + if default_config_path.exists(): + try: + default_config = ConfigurationLoader.from_yaml_file(default_config_path) + config_data["memory_db_type"] = default_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in default_config._initializer_configs + ] + # Preserve None vs [] distinction from config file + config_data["initialization_scripts"] = default_config.initialization_scripts + config_data["env_files"] = default_config.env_files + except Exception as e: + logger.warning(f"Failed to load default config file {default_config_path}: {e}") + + # 2. Load explicit config file if provided (overrides default) + if config_file is not None: + if not config_file.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file}") + explicit_config = ConfigurationLoader.from_yaml_file(config_file) + config_data["memory_db_type"] = explicit_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in explicit_config._initializer_configs + ] + # Preserve None vs [] distinction from config file + config_data["initialization_scripts"] = explicit_config.initialization_scripts + config_data["env_files"] = explicit_config.env_files + + # 3. Apply overrides (non-None values take precedence) + if memory_db_type is not None: + # Normalize to snake_case + normalized_db = memory_db_type.lower().replace("-", "_") + if normalized_db == "inmemory": + normalized_db = "in_memory" + elif normalized_db == "azuresql": + normalized_db = "azure_sql" + config_data["memory_db_type"] = normalized_db + + if initializers is not None: + config_data["initializers"] = initializers + + if initialization_scripts is not None: + config_data["initialization_scripts"] = initialization_scripts + + if env_files is not None: + config_data["env_files"] = env_files + + return ConfigurationLoader.from_dict(config_data) + + @classmethod + def get_default_config_path(cls) -> pathlib.Path: + """ + Get the default configuration file path. + + Returns: + Path to the default config file in ~/.pyrit/.pyrit_conf + """ + return DEFAULT_CONFIG_PATH + + def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: + """ + Resolve initializer names to PyRITInitializer instances. + + Uses the InitializerRegistry to look up initializer classes by name + and instantiate them with optional arguments. + + Returns: + Sequence of PyRITInitializer instances. + + Raises: + ValueError: If an initializer name is not found in the registry. + """ + from pyrit.registry import InitializerRegistry + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + if not self._initializer_configs: + return [] + + registry = InitializerRegistry() + resolved: List[PyRITInitializer] = [] + + for config in self._initializer_configs: + initializer_class = registry.get_class(config.name) + if initializer_class is None: + available = ", ".join(sorted(registry.get_names())) + raise ValueError( + f"Initializer '{config.name}' not found in registry.\nAvailable initializers: {available}" + ) + + # Instantiate with args if provided + if config.args: + instance = initializer_class(**config.args) + else: + instance = initializer_class() + + resolved.append(instance) + + return resolved + + def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: + """ + Resolve initialization script paths. + + Returns: + None if field is None (use defaults), empty list if field is [], + or Sequence of resolved Path objects if paths are specified. + """ + # None means "use defaults" - return None to signal this + if self.initialization_scripts is None: + return None + + # Empty list means "load nothing" - return empty list + if len(self.initialization_scripts) == 0: + return [] + + resolved: List[pathlib.Path] = [] + for script_str in self.initialization_scripts: + script_path = pathlib.Path(script_str) + if not script_path.is_absolute(): + script_path = pathlib.Path.cwd() / script_path + resolved.append(script_path) + + return resolved + + def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: + """ + Resolve environment file paths. + + Returns: + None if field is None (use defaults), empty list if field is [], + or Sequence of resolved Path objects if paths are specified. + """ + # None means "use defaults" - return None to signal this + if self.env_files is None: + return None + + # Empty list means "load nothing" - return empty list + if len(self.env_files) == 0: + return [] + + resolved: List[pathlib.Path] = [] + for env_str in self.env_files: + env_path = pathlib.Path(env_str) + if not env_path.is_absolute(): + env_path = pathlib.Path.cwd() / env_path + resolved.append(env_path) + + return resolved + + async def initialize_pyrit_async(self) -> None: + """ + Initialize PyRIT with the loaded configuration. + + This method resolves all initializer names to instances and calls + the core initialize_pyrit_async function. + + Raises: + ValueError: If configuration is invalid or initializers cannot be resolved. + """ + resolved_initializers = self._resolve_initializers() + resolved_scripts = self._resolve_initialization_scripts() + resolved_env_files = self._resolve_env_files() + + # Map snake_case memory_db_type to internal constant + internal_memory_db_type = _MEMORY_DB_TYPE_MAP[self.memory_db_type] + + await initialize_pyrit_async( + memory_db_type=internal_memory_db_type, + initialization_scripts=resolved_scripts, + initializers=resolved_initializers if resolved_initializers else None, + env_files=resolved_env_files, + silent=self.silent, + ) + + +async def initialize_from_config_async( + config_path: Optional[Union[str, pathlib.Path]] = None, +) -> ConfigurationLoader: + """ + Initialize PyRIT from a configuration file. + + This is a convenience function that loads a ConfigurationLoader from + a YAML file and initializes PyRIT. + + Args: + config_path: Path to the configuration file. If None, uses the default + path (~/.pyrit/.pyrit_conf). Can be a string or pathlib.Path. + + Returns: + The loaded ConfigurationLoader instance. + + Raises: + FileNotFoundError: If the configuration file does not exist. + ValueError: If the configuration is invalid. + """ + if config_path is None: + config_path = ConfigurationLoader.get_default_config_path() + elif isinstance(config_path, str): + config_path = pathlib.Path(config_path) + + config = ConfigurationLoader.from_yaml_file(config_path) + await config.initialize_pyrit_async() + return config diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index fca9dd6439..5e7ce37e19 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -5,6 +5,7 @@ Unit tests for the frontend_core module. """ +import logging from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -24,7 +25,7 @@ def test_init_with_defaults(self): assert context._database == frontend_core.SQLITE assert context._initialization_scripts is None assert context._initializer_names is None - assert context._log_level == "WARNING" + assert context._log_level == logging.WARNING assert context._initialized is False def test_init_with_all_parameters(self): @@ -36,24 +37,22 @@ def test_init_with_all_parameters(self): database=frontend_core.IN_MEMORY, initialization_scripts=scripts, initializer_names=initializers, - log_level="DEBUG", + log_level=logging.DEBUG, ) assert context._database == frontend_core.IN_MEMORY - assert context._initialization_scripts == scripts + # Check path ends with expected components (Windows adds drive letter to Unix-style paths) + assert context._initialization_scripts is not None + assert len(context._initialization_scripts) == 1 + assert context._initialization_scripts[0].parts[-2:] == ("test", "script.py") assert context._initializer_names == initializers - assert context._log_level == "DEBUG" + assert context._log_level == logging.DEBUG def test_init_with_invalid_database(self): """Test initialization with invalid database raises ValueError.""" with pytest.raises(ValueError, match="Invalid database type"): frontend_core.FrontendCore(database="InvalidDB") - def test_init_with_invalid_log_level(self): - """Test initialization with invalid log level raises ValueError.""" - with pytest.raises(ValueError, match="Invalid log level"): - frontend_core.FrontendCore(log_level="INVALID") - @patch("pyrit.registry.ScenarioRegistry") @patch("pyrit.registry.InitializerRegistry") @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) @@ -129,11 +128,11 @@ def test_validate_database_invalid_value(self): def test_validate_log_level_valid_values(self): """Test validate_log_level with valid values.""" - assert frontend_core.validate_log_level(log_level="DEBUG") == "DEBUG" - assert frontend_core.validate_log_level(log_level="INFO") == "INFO" - assert frontend_core.validate_log_level(log_level="warning") == "WARNING" # Case-insensitive - assert frontend_core.validate_log_level(log_level="error") == "ERROR" - assert frontend_core.validate_log_level(log_level="CRITICAL") == "CRITICAL" + assert frontend_core.validate_log_level(log_level="DEBUG") == logging.DEBUG + assert frontend_core.validate_log_level(log_level="INFO") == logging.INFO + assert frontend_core.validate_log_level(log_level="warning") == logging.WARNING # Case-insensitive + assert frontend_core.validate_log_level(log_level="error") == logging.ERROR + assert frontend_core.validate_log_level(log_level="CRITICAL") == logging.CRITICAL def test_validate_log_level_invalid_value(self): """Test validate_log_level with invalid value.""" @@ -208,7 +207,7 @@ def test_validate_database_argparse(self): def test_validate_log_level_argparse(self): """Test validate_log_level_argparse wrapper.""" - assert frontend_core.validate_log_level_argparse("DEBUG") == "DEBUG" + assert frontend_core.validate_log_level_argparse("DEBUG") == logging.DEBUG import argparse @@ -583,7 +582,7 @@ def test_parse_run_arguments_with_log_level(self): """Test parsing with log-level override.""" result = frontend_core.parse_run_arguments(args_string="test_scenario --log-level DEBUG") - assert result["log_level"] == "DEBUG" + assert result["log_level"] == logging.DEBUG def test_parse_run_arguments_with_initialization_scripts(self): """Test parsing with initialization-scripts.""" diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 58517c5235..7b477cb3ec 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -5,6 +5,7 @@ Unit tests for the pyrit_scan CLI module. """ +import logging from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -36,7 +37,7 @@ def test_parse_args_scenario_name_only(self): assert args.scenario_name == "test_scenario" assert args.database == "SQLite" - assert args.log_level == "WARNING" + assert args.log_level == logging.WARNING def test_parse_args_with_database(self): """Test parsing with database option.""" @@ -48,7 +49,7 @@ def test_parse_args_with_log_level(self): """Test parsing with log-level option.""" args = pyrit_scan.parse_args(["test_scenario", "--log-level", "DEBUG"]) - assert args.log_level == "DEBUG" + assert args.log_level == logging.DEBUG def test_parse_args_with_initializers(self): """Test parsing with initializers.""" @@ -117,7 +118,7 @@ def test_parse_args_complex_command(self): assert args.scenario_name == "encoding_scenario" assert args.database == "InMemory" - assert args.log_level == "INFO" + assert args.log_level == logging.INFO assert args.initializers == ["openai_target"] assert args.scenario_strategies == ["base64", "rot13"] assert args.max_concurrency == 10 @@ -304,7 +305,7 @@ def test_main_run_scenario_with_all_options( # Verify FrontendCore was called with correct args call_kwargs = mock_frontend_core.call_args[1] assert call_kwargs["database"] == "InMemory" - assert call_kwargs["log_level"] == "DEBUG" + assert call_kwargs["log_level"] == logging.DEBUG assert call_kwargs["initializer_names"] == ["init1", "init2"] @patch("pyrit.cli.pyrit_scan.asyncio.run") @@ -354,7 +355,7 @@ def test_main_log_level_defaults_to_warning(self, mock_frontend_core: MagicMock) pyrit_scan.main(["--list-scenarios"]) call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["log_level"] == "WARNING" + assert call_kwargs["log_level"] == logging.WARNING def test_main_with_invalid_args(self): """Test main with invalid arguments.""" diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py new file mode 100644 index 0000000000..0524084ffb --- /dev/null +++ b/tests/unit/setup/test_configuration_loader.py @@ -0,0 +1,369 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pathlib +import tempfile +from unittest import mock + +import pytest + +from pyrit.setup.configuration_loader import ( + ConfigurationLoader, + InitializerConfig, + initialize_from_config_async, +) + + +class TestInitializerConfig: + """Tests for InitializerConfig dataclass.""" + + def test_initializer_config_with_name_only(self): + """Test creating InitializerConfig with just a name.""" + config = InitializerConfig(name="simple") + assert config.name == "simple" + assert config.args is None + + def test_initializer_config_with_args(self): + """Test creating InitializerConfig with name and args.""" + config = InitializerConfig(name="custom", args={"param1": "value1"}) + assert config.name == "custom" + assert config.args == {"param1": "value1"} + + +class TestConfigurationLoader: + """Tests for ConfigurationLoader class.""" + + def test_default_values(self): + """Test default configuration values.""" + config = ConfigurationLoader() + assert config.memory_db_type == "sqlite" + assert config.initializers == [] + assert config.initialization_scripts is None # None means "use defaults" + assert config.env_files is None # None means "use defaults" + assert config.silent is False + + def test_valid_memory_db_types_snake_case(self): + """Test all valid memory database types in snake_case.""" + for db_type in ["in_memory", "sqlite", "azure_sql"]: + config = ConfigurationLoader(memory_db_type=db_type) + assert config.memory_db_type == db_type + + def test_memory_db_type_normalization_from_pascal_case(self): + """Test that PascalCase memory_db_type is normalized to snake_case.""" + config = ConfigurationLoader(memory_db_type="InMemory") + assert config.memory_db_type == "in_memory" + + config = ConfigurationLoader(memory_db_type="SQLite") + assert config.memory_db_type == "sqlite" + + config = ConfigurationLoader(memory_db_type="AzureSQL") + assert config.memory_db_type == "azure_sql" + + def test_memory_db_type_normalization_case_insensitive(self): + """Test that memory_db_type normalization is case-insensitive.""" + config = ConfigurationLoader(memory_db_type="SQLITE") + assert config.memory_db_type == "sqlite" + + config = ConfigurationLoader(memory_db_type="In_Memory") + assert config.memory_db_type == "in_memory" + + def test_invalid_memory_db_type_raises_error(self): + """Test that invalid memory_db_type raises ValueError.""" + with pytest.raises(ValueError, match="Invalid memory_db_type"): + ConfigurationLoader(memory_db_type="InvalidType") + + def test_initializer_as_string(self): + """Test initializers specified as simple strings.""" + config = ConfigurationLoader(initializers=["simple", "airt"]) + assert len(config._initializer_configs) == 2 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[0].args is None + assert config._initializer_configs[1].name == "airt" + + def test_initializer_as_dict_with_name_only(self): + """Test initializers specified as dicts with only name.""" + config = ConfigurationLoader(initializers=[{"name": "simple"}]) + assert len(config._initializer_configs) == 1 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[0].args is None + + def test_initializer_as_dict_with_args(self): + """Test initializers specified as dicts with name and args.""" + config = ConfigurationLoader(initializers=[{"name": "custom", "args": {"param1": "value1", "param2": 42}}]) + assert len(config._initializer_configs) == 1 + assert config._initializer_configs[0].name == "custom" + assert config._initializer_configs[0].args == {"param1": "value1", "param2": 42} + + def test_mixed_initializer_formats(self): + """Test initializers with mixed string and dict formats.""" + config = ConfigurationLoader( + initializers=[ + "simple", + {"name": "airt"}, + {"name": "custom", "args": {"key": "value"}}, + ] + ) + assert len(config._initializer_configs) == 3 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[1].name == "airt" + assert config._initializer_configs[2].name == "custom" + assert config._initializer_configs[2].args == {"key": "value"} + + def test_initializer_name_normalization_from_pascal_case(self): + """Test that PascalCase initializer names are normalized to snake_case.""" + config = ConfigurationLoader(initializers=["SimpleInitializer", "AIRTInitializer"]) + assert config._initializer_configs[0].name == "simple_initializer" + assert config._initializer_configs[1].name == "airt_initializer" + + def test_initializer_name_normalization_preserves_snake_case(self): + """Test that snake_case names are preserved.""" + config = ConfigurationLoader(initializers=["simple_initializer", "airt_init"]) + assert config._initializer_configs[0].name == "simple_initializer" + assert config._initializer_configs[1].name == "airt_init" + + def test_initializer_name_already_snake_case(self): + """Test that snake_case names remain unchanged.""" + config = ConfigurationLoader(initializers=["load_default_datasets", "objective_list"]) + assert config._initializer_configs[0].name == "load_default_datasets" + assert config._initializer_configs[1].name == "objective_list" + + def test_initializer_dict_without_name_raises_error(self): + """Test that dict initializer without 'name' raises ValueError.""" + with pytest.raises(ValueError, match="must have a 'name' field"): + ConfigurationLoader(initializers=[{"args": {"key": "value"}}]) + + def test_initializer_invalid_type_raises_error(self): + """Test that invalid initializer type raises ValueError.""" + with pytest.raises(ValueError, match="must be a string or dict"): + ConfigurationLoader(initializers=[123]) # type: ignore + + def test_from_dict_with_all_fields(self): + """Test from_dict with all configuration fields.""" + data = { + "memory_db_type": "InMemory", + "initializers": ["simple"], + "initialization_scripts": ["/path/to/script.py"], + "env_files": ["/path/to/.env"], + "silent": True, + } + config = ConfigurationLoader.from_dict(data) + assert config.memory_db_type == "in_memory" # Normalized to snake_case + assert config.initializers == ["simple"] + assert config.initialization_scripts == ["/path/to/script.py"] + assert config.env_files == ["/path/to/.env"] + assert config.silent is True + + def test_from_dict_filters_none_values(self): + """Test that from_dict filters out None values.""" + data = { + "memory_db_type": "SQLite", + "initializers": None, + "env_files": [], + } + config = ConfigurationLoader.from_dict(data) + assert config.memory_db_type == "sqlite" # Normalized to snake_case + assert config.initializers == [] # Uses default, not None + + def test_from_yaml_file(self): + """Test loading configuration from a YAML file.""" + yaml_content = """ +memory_db_type: in_memory +initializers: + - simple + - name: airt + args: + key: value +silent: true +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(yaml_content) + yaml_path = f.name + + try: + config = ConfigurationLoader.from_yaml_file(yaml_path) + assert config.memory_db_type == "in_memory" + assert len(config._initializer_configs) == 2 + assert config._initializer_configs[0].name == "simple" + assert config._initializer_configs[1].name == "airt" + assert config._initializer_configs[1].args == {"key": "value"} + assert config.silent is True + finally: + pathlib.Path(yaml_path).unlink() + + def test_get_default_config_path(self): + """Test get_default_config_path returns expected path.""" + default_path = ConfigurationLoader.get_default_config_path() + assert default_path.name == ".pyrit_conf" + assert ".pyrit" in str(default_path) + + +class TestConfigurationLoaderResolvers: + """Tests for ConfigurationLoader path resolution methods.""" + + def test_resolve_initialization_scripts_none_returns_none(self): + """Test that None (default) returns None to signal 'use defaults'.""" + config = ConfigurationLoader() + assert config._resolve_initialization_scripts() is None + + def test_resolve_initialization_scripts_empty_list_returns_empty_list(self): + """Test that explicit empty list [] returns empty list to signal 'load nothing'.""" + config = ConfigurationLoader(initialization_scripts=[]) + resolved = config._resolve_initialization_scripts() + assert resolved is not None + assert resolved == [] + + def test_resolve_initialization_scripts_absolute_path(self): + """Test resolving absolute script paths.""" + config = ConfigurationLoader(initialization_scripts=["/absolute/path/script.py"]) + resolved = config._resolve_initialization_scripts() + assert resolved is not None + assert len(resolved) == 1 + # Check path ends with expected components (Windows adds drive letter to Unix-style paths) + assert resolved[0].parts[-3:] == ("absolute", "path", "script.py") + + def test_resolve_initialization_scripts_relative_path(self): + """Test resolving relative script paths (converted to absolute).""" + config = ConfigurationLoader(initialization_scripts=["relative/script.py"]) + resolved = config._resolve_initialization_scripts() + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].is_absolute() + # Check path ends with expected components (works on both Unix and Windows) + assert resolved[0].parts[-2:] == ("relative", "script.py") + + def test_resolve_env_files_none_returns_none(self): + """Test that None (default) returns None to signal 'use defaults'.""" + config = ConfigurationLoader() + assert config._resolve_env_files() is None + + def test_resolve_env_files_empty_list_returns_empty_list(self): + """Test that explicit empty list [] returns empty list to signal 'load nothing'.""" + config = ConfigurationLoader(env_files=[]) + resolved = config._resolve_env_files() + assert resolved is not None + assert resolved == [] + + def test_resolve_env_files_absolute_path(self): + """Test resolving absolute env file paths.""" + config = ConfigurationLoader(env_files=["/path/to/.env"]) + resolved = config._resolve_env_files() + assert resolved is not None + assert len(resolved) == 1 + # Check path ends with expected components (Windows adds drive letter to Unix-style paths) + assert resolved[0].parts[-3:] == ("path", "to", ".env") + + +@pytest.mark.usefixtures("patch_central_database") +class TestConfigurationLoaderInitialization: + """Tests for ConfigurationLoader.initialize_pyrit_async method.""" + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") + async def test_initialize_pyrit_async_basic(self, mock_init): + """Test basic initialization with minimal configuration.""" + config = ConfigurationLoader(memory_db_type="in_memory") + await config.initialize_pyrit_async() + + mock_init.assert_called_once() + call_kwargs = mock_init.call_args.kwargs + # Should map snake_case to internal constant + assert call_kwargs["memory_db_type"] == "InMemory" + assert call_kwargs["initialization_scripts"] is None + assert call_kwargs["initializers"] is None + assert call_kwargs["env_files"] is None + assert call_kwargs["silent"] is False + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async") + @mock.patch("pyrit.registry.InitializerRegistry") + async def test_initialize_pyrit_async_with_initializers(self, mock_registry_cls, mock_init): + """Test initialization with initializers resolved from registry.""" + # Setup mock registry + mock_registry = mock.MagicMock() + mock_registry_cls.return_value = mock_registry + + # Mock an initializer class + mock_initializer_class = mock.MagicMock() + mock_initializer_instance = mock.MagicMock() + mock_initializer_class.return_value = mock_initializer_instance + mock_registry.get_class.return_value = mock_initializer_class + + config = ConfigurationLoader( + memory_db_type="in_memory", + initializers=["simple"], + ) + await config.initialize_pyrit_async() + + # Verify registry was used to resolve initializer + mock_registry.get_class.assert_called_once_with("simple") + mock_initializer_class.assert_called_once_with() + + # Verify initialize was called with resolved initializers + mock_init.assert_called_once() + call_kwargs = mock_init.call_args.kwargs + assert call_kwargs["initializers"] == [mock_initializer_instance] + + @pytest.mark.asyncio + @mock.patch("pyrit.registry.InitializerRegistry") + async def test_initialize_pyrit_async_unknown_initializer_raises_error(self, mock_registry_cls): + """Test that unknown initializer name raises ValueError.""" + mock_registry = mock.MagicMock() + mock_registry_cls.return_value = mock_registry + mock_registry.get_class.return_value = None + mock_registry.get_names.return_value = ["simple", "airt"] + + config = ConfigurationLoader( + memory_db_type="in_memory", + initializers=["unknown_initializer"], + ) + + with pytest.raises(ValueError, match="not found in registry"): + await config.initialize_pyrit_async() + + +@pytest.mark.usefixtures("patch_central_database") +class TestInitializeFromConfigAsync: + """Tests for initialize_from_config_async function.""" + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async") + async def test_initialize_from_config_with_path(self, mock_init, mock_from_yaml): + """Test initialize_from_config_async with explicit path.""" + mock_config = ConfigurationLoader() + mock_from_yaml.return_value = mock_config + + result = await initialize_from_config_async("/path/to/config.yaml") + + mock_from_yaml.assert_called_once_with(pathlib.Path("/path/to/config.yaml")) + mock_init.assert_called_once() + assert result is mock_config + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async") + async def test_initialize_from_config_with_string_path(self, mock_init, mock_from_yaml): + """Test initialize_from_config_async with string path.""" + mock_config = ConfigurationLoader() + mock_from_yaml.return_value = mock_config + + result = await initialize_from_config_async("/path/to/config.yaml") + + # Should convert string to Path + call_args = mock_from_yaml.call_args[0][0] + assert isinstance(call_args, pathlib.Path) + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.get_default_config_path") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file") + @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async") + async def test_initialize_from_config_default_path(self, mock_init, mock_from_yaml, mock_default_path): + """Test initialize_from_config_async uses default path when none specified.""" + mock_config = ConfigurationLoader() + mock_from_yaml.return_value = mock_config + mock_default_path.return_value = pathlib.Path("/default/path/.pyrit_conf") + + await initialize_from_config_async() + + mock_default_path.assert_called_once() + mock_from_yaml.assert_called_once_with(pathlib.Path("/default/path/.pyrit_conf"))