From 3bacaaee977e1fcf79ef8b650fac1e94345d1587 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Thu, 22 Jan 2026 10:35:06 -0800 Subject: [PATCH 1/5] first draft --- .env_example | 10 + doc/api.rst | 1 + pyrit/prompt_target/common/prompt_target.py | 2 - pyrit/registry/__init__.py | 4 + .../registry/instance_registries/__init__.py | 6 + .../instance_registries/target_registry.py | 145 +++++++++ pyrit/setup/initializers/__init__.py | 2 + pyrit/setup/initializers/airt_targets.py | 223 ++++++++++++++ tests/unit/registry/test_target_registry.py | 291 ++++++++++++++++++ 9 files changed, 682 insertions(+), 2 deletions(-) create mode 100644 pyrit/registry/instance_registries/target_registry.py create mode 100644 pyrit/setup/initializers/airt_targets.py create mode 100644 tests/unit/registry/test_target_registry.py diff --git a/.env_example b/.env_example index 2d63d66913..d77940bcde 100644 --- a/.env_example +++ b/.env_example @@ -35,6 +35,16 @@ AZURE_OPENAI_GPT4_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4_CHAT_KEY="xxxxx" AZURE_OPENAI_GPT4_CHAT_MODEL="deployment-name" +# Endpoints that host models with fewer safety mechanisms (e.g. via adversarial fine tuning +# or content filters turned off) can be defined below and used in adversarial attack testing scenarios. +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY="xxxxx" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL="deployment-name" + +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2="https://xxxxx.openai.azure.com/openai/v1" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2="xxxxx" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2="deployment-name" + AZURE_FOUNDRY_DEEPSEEK_ENDPOINT="https://xxxxx.eastus2.models.ai.azure.com" AZURE_FOUNDRY_DEEPSEEK_KEY="xxxxx" diff --git a/doc/api.rst b/doc/api.rst index c475c19231..bf94fd6d3a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -681,6 +681,7 @@ API Reference PyRITInitializer AIRTInitializer + AIRTTargetInitializer SimpleInitializer LoadDefaultDatasets ScenarioObjectiveListInitializer diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 78b466e042..0a140d8c92 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -95,8 +95,6 @@ def get_identifier(self) -> Dict[str, Any]: Get an identifier dictionary for this prompt target. This includes essential attributes needed for scorer evaluation and registry tracking. - Subclasses should override this method to include additional relevant attributes - (e.g., temperature, top_p) when available. Returns: Dict[str, Any]: A dictionary containing identification attributes. diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index ba05660206..4111d2cc24 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -21,6 +21,8 @@ BaseInstanceRegistry, ScorerMetadata, ScorerRegistry, + TargetMetadata, + TargetRegistry, ) from pyrit.registry.name_utils import class_name_to_registry_name, registry_name_to_class_name @@ -41,4 +43,6 @@ "ScenarioRegistry", "ScorerMetadata", "ScorerRegistry", + "TargetMetadata", + "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index b2b1fad0f6..00d62cfe91 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -18,6 +18,10 @@ ScorerMetadata, ScorerRegistry, ) +from pyrit.registry.instance_registries.target_registry import ( + TargetMetadata, + TargetRegistry, +) __all__ = [ # Base class @@ -25,4 +29,6 @@ # Concrete registries "ScorerRegistry", "ScorerMetadata", + "TargetRegistry", + "TargetMetadata", ] diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py new file mode 100644 index 0000000000..1523801315 --- /dev/null +++ b/pyrit/registry/instance_registries/target_registry.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target registry for discovering and managing PyRIT prompt targets. + +Targets are registered explicitly via initializers as pre-configured instances. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Optional + +from pyrit.registry.base import RegistryItemMetadata +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, +) +from pyrit.registry.name_utils import class_name_to_registry_name + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class TargetMetadata(RegistryItemMetadata): + """ + Metadata describing a registered target instance. + + Unlike ScenarioMetadata/InitializerMetadata which describe classes, + TargetMetadata describes an already-instantiated prompt target. + + Use get() to retrieve the actual target instance. + """ + + target_identifier: Dict[str, Any] + + +class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetMetadata]): + """ + Registry for managing available prompt target instances. + + This registry stores pre-configured PromptTarget instances (not classes). + Targets are registered explicitly via initializers after being instantiated + with their required parameters (e.g., endpoint, API keys). + + Targets are identified by their snake_case name derived from the class name, + or a custom name provided during registration. + """ + + @classmethod + def get_registry_singleton(cls) -> "TargetRegistry": + """ + Get the singleton instance of the TargetRegistry. + + Returns: + The singleton TargetRegistry instance. + """ + return super().get_registry_singleton() # type: ignore[return-value] + + def register_instance( + self, + target: "PromptTarget", + *, + name: Optional[str] = None, + ) -> None: + """ + Register a target instance. + + Note: Unlike ScenarioRegistry and InitializerRegistry which register classes, + TargetRegistry registers pre-configured instances. + + Args: + target: The pre-configured target instance (not a class). + name: Optional custom registry name. If not provided, + derived from class name with identifier hash appended + (e.g., OpenAIChatTarget -> openai_chat_abc123). + """ + if name is None: + base_name = class_name_to_registry_name(target.__class__.__name__, suffix="Target") + # Append identifier hash for uniqueness + identifier_hash = self._compute_identifier_hash(target)[:8] + name = f"{base_name}_{identifier_hash}" + + self.register(target, name=name) + logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") + + def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: + """ + Get a registered target instance by name. + + Note: This returns an already-instantiated target, not a class. + + Args: + name: The registry name of the target. + + Returns: + The target instance, or None if not found. + """ + return self.get(name) + + def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetMetadata: + """ + Build metadata for a target instance. + + Args: + name: The registry name of the target. + instance: The target instance. + + Returns: + TargetMetadata describing the target. + """ + # Get description from docstring + doc = instance.__class__.__doc__ or "" + description = " ".join(doc.split()) if doc else "No description available" + + # Get identifier from the target + target_identifier = instance.get_identifier() + + return TargetMetadata( + name=name, + class_name=instance.__class__.__name__, + description=description, + target_identifier=target_identifier, + ) + + @staticmethod + def _compute_identifier_hash(target: "PromptTarget") -> str: + """ + Compute a hash from the target's identifier for unique naming. + + Args: + target: The target instance. + + Returns: + A hex string hash of the identifier. + """ + identifier = target.get_identifier() + identifier_str = json.dumps(identifier, sort_keys=True) + return hashlib.sha256(identifier_str.encode()).hexdigest() diff --git a/pyrit/setup/initializers/__init__.py b/pyrit/setup/initializers/__init__.py index 1c0cbd4683..6b1c63c484 100644 --- a/pyrit/setup/initializers/__init__.py +++ b/pyrit/setup/initializers/__init__.py @@ -4,6 +4,7 @@ """PyRIT initializers package.""" from pyrit.setup.initializers.airt import AIRTInitializer +from pyrit.setup.initializers.airt_targets import AIRTTargetInitializer from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets from pyrit.setup.initializers.scenarios.objective_list import ScenarioObjectiveListInitializer @@ -13,6 +14,7 @@ __all__ = [ "PyRITInitializer", "AIRTInitializer", + "AIRTTargetInitializer", "SimpleInitializer", "LoadDefaultDatasets", "ScenarioObjectiveListInitializer", diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py new file mode 100644 index 0000000000..7029484f45 --- /dev/null +++ b/pyrit/setup/initializers/airt_targets.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +AIRT Target Initializer for registering pre-configured targets from environment variables. + +This module provides the AIRTTargetInitializer class that registers available +targets into the TargetRegistry based on environment variable configuration. +""" + +import logging +import os +from dataclasses import dataclass +from typing import Any, List, Optional, Type + +from pyrit.prompt_target import ( + OpenAIChatTarget, + OpenAIImageTarget, + OpenAIResponseTarget, + OpenAITTSTarget, + OpenAIVideoTarget, + PromptShieldTarget, + PromptTarget, + RealtimeTarget, +) +from pyrit.registry import TargetRegistry +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +logger = logging.getLogger(__name__) + + +@dataclass +class TargetConfig: + """Configuration for a target to be registered.""" + + registry_name: str + target_class: Type[PromptTarget] + endpoint_var: str + key_var: str + model_var: Optional[str] = None + underlying_model_var: Optional[str] = None + + +# Define all supported target configurations +TARGET_CONFIGS: List[TargetConfig] = [ + TargetConfig( + registry_name="default_openai_frontend", + target_class=OpenAIChatTarget, + endpoint_var="DEFAULT_OPENAI_FRONTEND_ENDPOINT", + key_var="DEFAULT_OPENAI_FRONTEND_KEY", + model_var="DEFAULT_OPENAI_FRONTEND_MODEL", + underlying_model_var="DEFAULT_OPENAI_FRONTEND_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_chat", + target_class=OpenAIChatTarget, + endpoint_var="OPENAI_CHAT_ENDPOINT", + key_var="OPENAI_CHAT_KEY", + model_var="OPENAI_CHAT_MODEL", + underlying_model_var="OPENAI_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_responses", + target_class=OpenAIResponseTarget, + endpoint_var="OPENAI_RESPONSES_ENDPOINT", + key_var="OPENAI_RESPONSES_KEY", + model_var="OPENAI_RESPONSES_MODEL", + underlying_model_var="OPENAI_RESPONSES_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_gpt4o_unsafe_chat", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT", + key_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY", + model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL", + underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_gpt4o_unsafe_chat2", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", + key_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2", + model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", + underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2", + ), + TargetConfig( + registry_name="openai_realtime", + target_class=RealtimeTarget, + endpoint_var="OPENAI_REALTIME_ENDPOINT", + key_var="OPENAI_REALTIME_API_KEY", + model_var="OPENAI_REALTIME_MODEL", + underlying_model_var="OPENAI_REALTIME_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_image", + target_class=OpenAIImageTarget, + endpoint_var="OPENAI_IMAGE_ENDPOINT", + key_var="OPENAI_IMAGE_API_KEY", + model_var="OPENAI_IMAGE_MODEL", + underlying_model_var="OPENAI_IMAGE_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_tts", + target_class=OpenAITTSTarget, + endpoint_var="OPENAI_TTS_ENDPOINT", + key_var="OPENAI_TTS_KEY", + model_var="OPENAI_TTS_MODEL", + underlying_model_var="OPENAI_TTS_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="openai_video", + target_class=OpenAIVideoTarget, + endpoint_var="OPENAI_VIDEO_ENDPOINT", + key_var="OPENAI_VIDEO_KEY", + model_var="OPENAI_VIDEO_MODEL", + underlying_model_var="OPENAI_VIDEO_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_content_safety", + target_class=PromptShieldTarget, + endpoint_var="AZURE_CONTENT_SAFETY_API_ENDPOINT", + key_var="AZURE_CONTENT_SAFETY_API_KEY", + ), +] + + +class AIRTTargetInitializer(PyRITInitializer): + """ + AIRT Target Initializer for registering pre-configured targets. + + This initializer scans for known endpoint environment variables and registers + the corresponding targets into the TargetRegistry. Unlike AIRTInitializer, + this initializer does not require any environment variables - it simply + registers whatever endpoints are available. + + Supported Endpoints: + - DEFAULT_OPENAI_FRONTEND_ENDPOINT: Default OpenAI frontend (OpenAIChatTarget) + - OPENAI_CHAT_ENDPOINT: OpenAI Chat API (OpenAIChatTarget) + - OPENAI_RESPONSES_ENDPOINT: OpenAI Responses API (OpenAIResponseTarget) + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI GPT-4o unsafe (OpenAIChatTarget) + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI GPT-4o unsafe secondary (OpenAIChatTarget) + - OPENAI_REALTIME_ENDPOINT: OpenAI Realtime API (RealtimeTarget) + - OPENAI_IMAGE_ENDPOINT: OpenAI Image Generation (OpenAIImageTarget) + - OPENAI_TTS_ENDPOINT: OpenAI Text-to-Speech (OpenAITTSTarget) + - OPENAI_VIDEO_ENDPOINT: OpenAI Video Generation (OpenAIVideoTarget) + - AZURE_CONTENT_SAFETY_API_ENDPOINT: Azure Content Safety (PromptShieldTarget) + + Example: + initializer = AIRTTargetInitializer() + await initializer.initialize_async() + """ + + def __init__(self) -> None: + """Initialize the AIRT Target Initializer.""" + super().__init__() + + @property + def name(self) -> str: + """Get the name of this initializer.""" + return "AIRT Target Initializer" + + @property + def description(self) -> str: + """Get the description of this initializer.""" + return ( + "Instantiates a collection of (AI Red Team suggested) targets from " + "available environment variables and adds them to the TargetRegistry" + ) + + @property + def required_env_vars(self) -> List[str]: + """ + Get list of required environment variables. + + Returns empty list since this initializer is optional - it registers + whatever endpoints are available without requiring any. + """ + return [] + + async def initialize_async(self) -> None: + """ + Register available targets based on environment variables. + + Scans for known endpoint environment variables and registers the + corresponding targets into the TargetRegistry. + """ + for config in TARGET_CONFIGS: + self._register_target(config) + + def _register_target(self, config: TargetConfig) -> None: + """ + Register a target if its required environment variables are set. + + Args: + config: The target configuration specifying env vars and target class. + """ + endpoint = os.getenv(config.endpoint_var) + api_key = os.getenv(config.key_var) + + if not endpoint or not api_key: + return + + model_name = os.getenv(config.model_var) if config.model_var else None + underlying_model = os.getenv(config.underlying_model_var) if config.underlying_model_var else None + + # Build kwargs for the target constructor + kwargs: dict[str, Any] = { + "endpoint": endpoint, + "api_key": api_key, + } + + # Only add model_name if the target supports it (PromptShieldTarget doesn't) + if model_name: + kwargs["model_name"] = model_name + + # Add underlying_model if specified (for Azure deployments where name differs from model) + if underlying_model: + kwargs["underlying_model"] = underlying_model + + target = config.target_class(**kwargs) + registry = TargetRegistry.get_registry_singleton() + registry.register_instance(target, name=config.registry_name) + logger.info(f"Registered target: {config.registry_name}") diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py new file mode 100644 index 0000000000..da313be975 --- /dev/null +++ b/tests/unit/registry/test_target_registry.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +import pytest + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import PromptTarget +from pyrit.registry.instance_registries.target_registry import TargetRegistry + + +class MockPromptTarget(PromptTarget): + """Mock PromptTarget for testing.""" + + def __init__(self, *, model_name: str = "mock_model") -> None: + super().__init__(model_name=model_name) + + async def send_prompt_async( + self, + *, + message: Message, + ) -> list[Message]: + return [ + MessagePiece( + role="assistant", + original_value="mock response", + ).to_message() + ] + + def _validate_request(self, *, message: Message) -> None: + pass + + async def dispose_async(self) -> None: + pass + + +class MockChatTarget(PromptTarget): + """Mock chat target for testing different target types.""" + + def __init__(self, *, endpoint: str = "http://test") -> None: + super().__init__(endpoint=endpoint) + + async def send_prompt_async( + self, + *, + message: Message, + ) -> list[Message]: + return [ + MessagePiece( + role="assistant", + original_value="chat response", + ).to_message() + ] + + def _validate_request(self, *, message: Message) -> None: + pass + + async def dispose_async(self) -> None: + pass + + +class TestTargetRegistrySingleton: + """Tests for the singleton pattern in TargetRegistry.""" + + def setup_method(self): + """Reset the singleton before each test.""" + TargetRegistry.reset_instance() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_get_registry_singleton_returns_same_instance(self): + """Test that get_registry_singleton returns the same singleton each time.""" + instance1 = TargetRegistry.get_registry_singleton() + instance2 = TargetRegistry.get_registry_singleton() + + assert instance1 is instance2 + + def test_get_registry_singleton_returns_target_registry_type(self): + """Test that get_registry_singleton returns a TargetRegistry instance.""" + instance = TargetRegistry.get_registry_singleton() + assert isinstance(instance, TargetRegistry) + + def test_reset_instance_clears_singleton(self): + """Test that reset_instance clears the singleton.""" + instance1 = TargetRegistry.get_registry_singleton() + TargetRegistry.reset_instance() + instance2 = TargetRegistry.get_registry_singleton() + + assert instance1 is not instance2 + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryRegisterInstance: + """Tests for register_instance functionality in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_register_instance_with_custom_name(self): + """Test registering a target with a custom name.""" + target = MockPromptTarget() + self.registry.register_instance(target, name="custom_target") + + assert "custom_target" in self.registry + assert self.registry.get("custom_target") is target + + def test_register_instance_generates_name_from_class(self): + """Test that register_instance generates a name from class name when not provided.""" + target = MockPromptTarget() + self.registry.register_instance(target) + + # Name should be derived from class name with hash suffix + names = self.registry.get_names() + assert len(names) == 1 + assert names[0].startswith("mock_prompt_") + + def test_register_instance_multiple_targets_unique_names(self): + """Test registering multiple targets generates unique names.""" + target1 = MockPromptTarget() + target2 = MockChatTarget() + + self.registry.register_instance(target1) + self.registry.register_instance(target2) + + assert len(self.registry) == 2 + + def test_register_instance_same_target_type_different_config(self): + """Test that same target class with different configs can be registered.""" + target1 = MockPromptTarget(model_name="model_a") + target2 = MockPromptTarget(model_name="model_b") + + # Register with explicit names + self.registry.register_instance(target1, name="target_1") + self.registry.register_instance(target2, name="target_2") + + assert len(self.registry) == 2 + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryGetInstanceByName: + """Tests for get_instance_by_name functionality in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + self.target = MockPromptTarget() + self.registry.register_instance(self.target, name="test_target") + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_get_instance_by_name_returns_target(self): + """Test getting a registered target by name.""" + result = self.registry.get_instance_by_name("test_target") + assert result is self.target + + def test_get_instance_by_name_nonexistent_returns_none(self): + """Test that getting a non-existent target returns None.""" + result = self.registry.get_instance_by_name("nonexistent") + assert result is None + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryBuildMetadata: + """Tests for _build_metadata functionality in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_build_metadata_includes_class_name(self): + """Test that metadata includes the class name.""" + target = MockPromptTarget() + self.registry.register_instance(target, name="mock_target") + + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].class_name == "MockPromptTarget" + assert metadata[0].name == "mock_target" + + def test_build_metadata_includes_target_identifier(self): + """Test that metadata includes the target_identifier.""" + target = MockPromptTarget(model_name="test_model") + self.registry.register_instance(target, name="mock_target") + + metadata = self.registry.list_metadata() + assert hasattr(metadata[0], "target_identifier") + assert isinstance(metadata[0].target_identifier, dict) + assert metadata[0].target_identifier.get("model_name") == "test_model" + + def test_build_metadata_description_from_docstring(self): + """Test that description is derived from the target's docstring.""" + target = MockPromptTarget() + self.registry.register_instance(target, name="mock_target") + + metadata = self.registry.list_metadata() + # MockPromptTarget has a docstring + assert "Mock PromptTarget for testing" in metadata[0].description + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryListMetadata: + """Tests for list_metadata in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry with multiple targets.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + + self.target1 = MockPromptTarget(model_name="model_a") + self.target2 = MockPromptTarget(model_name="model_b") + self.chat_target = MockChatTarget() + + self.registry.register_instance(self.target1, name="target_1") + self.registry.register_instance(self.target2, name="target_2") + self.registry.register_instance(self.chat_target, name="chat_target") + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_list_metadata_returns_all_registered(self): + """Test that list_metadata returns metadata for all registered targets.""" + metadata = self.registry.list_metadata() + assert len(metadata) == 3 + + def test_list_metadata_filter_by_class_name(self): + """Test filtering metadata by class_name.""" + mock_metadata = self.registry.list_metadata(include_filters={"class_name": "MockPromptTarget"}) + + assert len(mock_metadata) == 2 + for m in mock_metadata: + assert m.class_name == "MockPromptTarget" + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryComputeIdentifierHash: + """Tests for _compute_identifier_hash functionality.""" + + def setup_method(self): + """Reset the singleton before each test.""" + TargetRegistry.reset_instance() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_compute_identifier_hash_deterministic(self): + """Test that identifier hash is deterministic for same config.""" + target1 = MockPromptTarget(model_name="same_model") + target2 = MockPromptTarget(model_name="same_model") + + hash1 = TargetRegistry._compute_identifier_hash(target1) + hash2 = TargetRegistry._compute_identifier_hash(target2) + + assert hash1 == hash2 + + def test_compute_identifier_hash_different_for_different_config(self): + """Test that identifier hash is different for different configs.""" + target1 = MockPromptTarget(model_name="model_a") + target2 = MockPromptTarget(model_name="model_b") + + hash1 = TargetRegistry._compute_identifier_hash(target1) + hash2 = TargetRegistry._compute_identifier_hash(target2) + + assert hash1 != hash2 + + def test_compute_identifier_hash_is_string(self): + """Test that identifier hash returns a hex string.""" + target = MockPromptTarget() + hash_value = TargetRegistry._compute_identifier_hash(target) + + assert isinstance(hash_value, str) + # Should be a valid hex string (SHA256 = 64 hex chars) + assert len(hash_value) == 64 + assert all(c in "0123456789abcdef" for c in hash_value) From a572bbc5f4d6b2266e6f7026007705fd40551dfc Mon Sep 17 00:00:00 2001 From: jsong468 Date: Fri, 30 Jan 2026 14:26:10 -0800 Subject: [PATCH 2/5] use target identifier --- pyrit/identifiers/target_identifier.py | 3 + pyrit/prompt_target/common/prompt_target.py | 4 + pyrit/registry/__init__.py | 2 - .../registry/instance_registries/__init__.py | 2 - .../instance_registries/target_registry.py | 62 +----- pyrit/setup/initializers/airt_targets.py | 4 +- .../identifiers/test_target_identifier.py | 57 +++++ tests/unit/registry/test_target_registry.py | 82 +++---- .../setup/test_airt_targets_initializer.py | 202 ++++++++++++++++++ 9 files changed, 309 insertions(+), 109 deletions(-) create mode 100644 tests/unit/setup/test_airt_targets_initializer.py diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index f08ad8709d..b8924fb0c0 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -34,6 +34,9 @@ class TargetIdentifier(Identifier): max_requests_per_minute: Optional[int] = None """Maximum number of requests per minute.""" + supports_conversation_history: bool = False + """Whether the target supports explicit setting of conversation history (is a PromptChatTarget).""" + target_specific_params: Optional[Dict[str, Any]] = None """Additional target-specific parameters.""" diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 8cd80f47d4..653d008e65 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -122,6 +122,9 @@ def _create_identifier( elif self._model_name: model_name = self._model_name + # Late import to avoid circular dependency + from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget + return TargetIdentifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, @@ -132,6 +135,7 @@ def _create_identifier( temperature=temperature, top_p=top_p, max_requests_per_minute=self._max_requests_per_minute, + supports_conversation_history=isinstance(self, PromptChatTarget), target_specific_params=target_specific_params, ) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index afc842b971..5f2fe7536f 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -21,7 +21,6 @@ from pyrit.registry.instance_registries import ( BaseInstanceRegistry, ScorerRegistry, - TargetMetadata, TargetRegistry, ) @@ -41,6 +40,5 @@ "ScenarioMetadata", "ScenarioRegistry", "ScorerRegistry", - "TargetMetadata", "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index 9ea512c501..2cf50693cf 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -18,7 +18,6 @@ ScorerRegistry, ) from pyrit.registry.instance_registries.target_registry import ( - TargetMetadata, TargetRegistry, ) @@ -27,6 +26,5 @@ "BaseInstanceRegistry", # Concrete registries "ScorerRegistry", - "TargetMetadata", "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py index 1523801315..3fcdbb3160 100644 --- a/pyrit/registry/instance_registries/target_registry.py +++ b/pyrit/registry/instance_registries/target_registry.py @@ -9,17 +9,13 @@ from __future__ import annotations -import hashlib -import json import logging -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Optional -from pyrit.registry.base import RegistryItemMetadata +from pyrit.identifiers import TargetIdentifier from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) -from pyrit.registry.name_utils import class_name_to_registry_name if TYPE_CHECKING: from pyrit.prompt_target import PromptTarget @@ -27,21 +23,7 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class TargetMetadata(RegistryItemMetadata): - """ - Metadata describing a registered target instance. - - Unlike ScenarioMetadata/InitializerMetadata which describe classes, - TargetMetadata describes an already-instantiated prompt target. - - Use get() to retrieve the actual target instance. - """ - - target_identifier: Dict[str, Any] - - -class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetMetadata]): +class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetIdentifier]): """ Registry for managing available prompt target instances. @@ -82,10 +64,7 @@ def register_instance( (e.g., OpenAIChatTarget -> openai_chat_abc123). """ if name is None: - base_name = class_name_to_registry_name(target.__class__.__name__, suffix="Target") - # Append identifier hash for uniqueness - identifier_hash = self._compute_identifier_hash(target)[:8] - name = f"{base_name}_{identifier_hash}" + name = target.get_identifier().unique_name self.register(target, name=name) logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") @@ -104,7 +83,7 @@ def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: """ return self.get(name) - def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetMetadata: + def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetIdentifier: """ Build metadata for a target instance. @@ -113,33 +92,6 @@ def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetMetadata instance: The target instance. Returns: - TargetMetadata describing the target. - """ - # Get description from docstring - doc = instance.__class__.__doc__ or "" - description = " ".join(doc.split()) if doc else "No description available" - - # Get identifier from the target - target_identifier = instance.get_identifier() - - return TargetMetadata( - name=name, - class_name=instance.__class__.__name__, - description=description, - target_identifier=target_identifier, - ) - - @staticmethod - def _compute_identifier_hash(target: "PromptTarget") -> str: - """ - Compute a hash from the target's identifier for unique naming. - - Args: - target: The target instance. - - Returns: - A hex string hash of the identifier. + TargetIdentifier describing the target. """ - identifier = target.get_identifier() - identifier_str = json.dumps(identifier, sort_keys=True) - return hashlib.sha256(identifier_str.encode()).hexdigest() + return instance.get_identifier() diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py index 7029484f45..be98380ae1 100644 --- a/pyrit/setup/initializers/airt_targets.py +++ b/pyrit/setup/initializers/airt_targets.py @@ -210,11 +210,11 @@ def _register_target(self, config: TargetConfig) -> None: } # Only add model_name if the target supports it (PromptShieldTarget doesn't) - if model_name: + if model_name is not None: kwargs["model_name"] = model_name # Add underlying_model if specified (for Azure deployments where name differs from model) - if underlying_model: + if underlying_model is not None: kwargs["underlying_model"] = underlying_model target = config.target_class(**kwargs) diff --git a/tests/unit/identifiers/test_target_identifier.py b/tests/unit/identifiers/test_target_identifier.py index 0541b36be5..148c60983d 100644 --- a/tests/unit/identifiers/test_target_identifier.py +++ b/tests/unit/identifiers/test_target_identifier.py @@ -500,6 +500,63 @@ def test_can_use_as_dict_key(self): assert d[identifier] == "value" +class TestTargetIdentifierSupportsConversationHistory: + """Test the supports_conversation_history field in TargetIdentifier.""" + + def test_supports_conversation_history_defaults_to_false(self): + """Test that supports_conversation_history defaults to False.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + assert identifier.supports_conversation_history is False + + def test_supports_conversation_history_included_in_hash(self): + """Test that supports_conversation_history affects the hash.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(supports_conversation_history=False, **base_args) + identifier2 = TargetIdentifier(supports_conversation_history=True, **base_args) + + assert identifier1.hash != identifier2.hash + + def test_supports_conversation_history_in_to_dict(self): + """Test that supports_conversation_history is included in to_dict.""" + identifier = TargetIdentifier( + class_name="TestChatTarget", + class_module="pyrit.prompt_target.test_chat_target", + class_description="A test chat target", + identifier_type="instance", + supports_conversation_history=True, + ) + + result = identifier.to_dict() + + assert result["supports_conversation_history"] is True + + def test_supports_conversation_history_from_dict(self): + """Test that supports_conversation_history is restored from dict.""" + data = { + "class_name": "TestChatTarget", + "class_module": "pyrit.prompt_target.test_chat_target", + "class_description": "A test chat target", + "identifier_type": "instance", + "supports_conversation_history": True, + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.supports_conversation_history is True + + class TestTargetIdentifierNormalize: """Test the normalize class method for TargetIdentifier.""" diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index da313be975..8e32411b89 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -4,8 +4,10 @@ import pytest +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.registry.instance_registries.target_registry import TargetRegistry @@ -30,15 +32,12 @@ async def send_prompt_async( def _validate_request(self, *, message: Message) -> None: pass - async def dispose_async(self) -> None: - pass - -class MockChatTarget(PromptTarget): - """Mock chat target for testing different target types.""" +class MockPromptChatTarget(PromptChatTarget): + """Mock PromptChatTarget for testing conversation history support.""" - def __init__(self, *, endpoint: str = "http://test") -> None: - super().__init__(endpoint=endpoint) + def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http://chat-test") -> None: + super().__init__(model_name=model_name, endpoint=endpoint) async def send_prompt_async( self, @@ -55,8 +54,8 @@ async def send_prompt_async( def _validate_request(self, *, message: Message) -> None: pass - async def dispose_async(self) -> None: - pass + def is_json_response_supported(self) -> bool: + return False class TestTargetRegistrySingleton: @@ -125,7 +124,7 @@ def test_register_instance_generates_name_from_class(self): def test_register_instance_multiple_targets_unique_names(self): """Test registering multiple targets generates unique names.""" target1 = MockPromptTarget() - target2 = MockChatTarget() + target2 = MockPromptChatTarget() self.registry.register_instance(target1) self.registry.register_instance(target2) @@ -184,33 +183,31 @@ def teardown_method(self): TargetRegistry.reset_instance() def test_build_metadata_includes_class_name(self): - """Test that metadata includes the class name.""" + """Test that metadata (TargetIdentifier) includes the class name.""" target = MockPromptTarget() self.registry.register_instance(target, name="mock_target") metadata = self.registry.list_metadata() assert len(metadata) == 1 + assert isinstance(metadata[0], TargetIdentifier) assert metadata[0].class_name == "MockPromptTarget" - assert metadata[0].name == "mock_target" - def test_build_metadata_includes_target_identifier(self): - """Test that metadata includes the target_identifier.""" + def test_build_metadata_includes_model_name(self): + """Test that metadata includes the model_name.""" target = MockPromptTarget(model_name="test_model") self.registry.register_instance(target, name="mock_target") metadata = self.registry.list_metadata() - assert hasattr(metadata[0], "target_identifier") - assert isinstance(metadata[0].target_identifier, dict) - assert metadata[0].target_identifier.get("model_name") == "test_model" + assert metadata[0].model_name == "test_model" def test_build_metadata_description_from_docstring(self): - """Test that description is derived from the target's docstring.""" + """Test that class_description is derived from the target's docstring.""" target = MockPromptTarget() self.registry.register_instance(target, name="mock_target") metadata = self.registry.list_metadata() # MockPromptTarget has a docstring - assert "Mock PromptTarget for testing" in metadata[0].description + assert "Mock PromptTarget for testing" in metadata[0].class_description @pytest.mark.usefixtures("patch_central_database") @@ -224,7 +221,7 @@ def setup_method(self): self.target1 = MockPromptTarget(model_name="model_a") self.target2 = MockPromptTarget(model_name="model_b") - self.chat_target = MockChatTarget() + self.chat_target = MockPromptChatTarget() self.registry.register_instance(self.target1, name="target_1") self.registry.register_instance(self.target2, name="target_2") @@ -249,43 +246,32 @@ def test_list_metadata_filter_by_class_name(self): @pytest.mark.usefixtures("patch_central_database") -class TestTargetRegistryComputeIdentifierHash: - """Tests for _compute_identifier_hash functionality.""" +class TestTargetRegistrySupportsConversationHistory: + """Tests for supports_conversation_history field in TargetIdentifier.""" def setup_method(self): - """Reset the singleton before each test.""" + """Reset and get a fresh registry for each test.""" TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() def teardown_method(self): """Reset the singleton after each test.""" TargetRegistry.reset_instance() - def test_compute_identifier_hash_deterministic(self): - """Test that identifier hash is deterministic for same config.""" - target1 = MockPromptTarget(model_name="same_model") - target2 = MockPromptTarget(model_name="same_model") - - hash1 = TargetRegistry._compute_identifier_hash(target1) - hash2 = TargetRegistry._compute_identifier_hash(target2) - - assert hash1 == hash2 + def test_registered_chat_target_has_supports_conversation_history_true(self): + """Test that registered chat targets have supports_conversation_history=True in metadata.""" + chat_target = MockPromptChatTarget() + self.registry.register_instance(chat_target, name="chat_target") - def test_compute_identifier_hash_different_for_different_config(self): - """Test that identifier hash is different for different configs.""" - target1 = MockPromptTarget(model_name="model_a") - target2 = MockPromptTarget(model_name="model_b") - - hash1 = TargetRegistry._compute_identifier_hash(target1) - hash2 = TargetRegistry._compute_identifier_hash(target2) - - assert hash1 != hash2 + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].supports_conversation_history is True - def test_compute_identifier_hash_is_string(self): - """Test that identifier hash returns a hex string.""" + def test_registered_non_chat_target_has_supports_conversation_history_false(self): + """Test that registered non-chat targets have supports_conversation_history=False in metadata.""" target = MockPromptTarget() - hash_value = TargetRegistry._compute_identifier_hash(target) + self.registry.register_instance(target, name="prompt_target") - assert isinstance(hash_value, str) - # Should be a valid hex string (SHA256 = 64 hex chars) - assert len(hash_value) == 64 - assert all(c in "0123456789abcdef" for c in hash_value) + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].supports_conversation_history is False diff --git a/tests/unit/setup/test_airt_targets_initializer.py b/tests/unit/setup/test_airt_targets_initializer.py new file mode 100644 index 0000000000..48a537313f --- /dev/null +++ b/tests/unit/setup/test_airt_targets_initializer.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os + +import pytest + +from pyrit.registry import TargetRegistry +from pyrit.setup.initializers import AIRTTargetInitializer +from pyrit.setup.initializers.airt_targets import TARGET_CONFIGS + + +class TestAIRTTargetInitializerBasic: + """Tests for AIRTTargetInitializer class - basic functionality.""" + + def test_can_be_created(self): + """Test that AIRTTargetInitializer can be instantiated.""" + init = AIRTTargetInitializer() + assert init is not None + assert init.name == "AIRT Target Initializer" + assert init.execution_order == 1 + + def test_required_env_vars_is_empty(self): + """Test that no env vars are required (initializer is optional).""" + init = AIRTTargetInitializer() + assert init.required_env_vars == [] + + +@pytest.mark.usefixtures("patch_central_database") +class TestAIRTTargetInitializerInitialize: + """Tests for AIRTTargetInitializer.initialize_async method.""" + + def setup_method(self) -> None: + """Reset registry before each test.""" + TargetRegistry.reset_instance() + # Clear all target-related env vars + self._clear_env_vars() + + def teardown_method(self) -> None: + """Clean up after each test.""" + TargetRegistry.reset_instance() + self._clear_env_vars() + + def _clear_env_vars(self) -> None: + """Clear all environment variables used by TARGET_CONFIGS.""" + for config in TARGET_CONFIGS: + for var in [config.endpoint_var, config.key_var, config.model_var, config.underlying_model_var]: + if var and var in os.environ: + del os.environ[var] + + @pytest.mark.asyncio + async def test_initialize_runs_without_error_no_env_vars(self): + """Test that initialize runs without errors when no env vars are set.""" + init = AIRTTargetInitializer() + await init.initialize_async() + + # No targets should be registered + registry = TargetRegistry.get_registry_singleton() + assert len(registry) == 0 + + @pytest.mark.asyncio + async def test_registers_target_when_env_vars_set(self): + """Test that a target is registered when its env vars are set.""" + os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["OPENAI_CHAT_KEY"] = "test_key" + os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "openai_chat" in registry + target = registry.get_instance_by_name("openai_chat") + assert target is not None + assert target._model_name == "gpt-4o" + + @pytest.mark.asyncio + async def test_does_not_register_target_without_endpoint(self): + """Test that target is not registered if endpoint is missing.""" + # Only set key, not endpoint + os.environ["OPENAI_CHAT_KEY"] = "test_key" + os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "openai_chat" not in registry + + @pytest.mark.asyncio + async def test_does_not_register_target_without_api_key(self): + """Test that target is not registered if api_key is missing.""" + # Only set endpoint, not key + os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "openai_chat" not in registry + + @pytest.mark.asyncio + async def test_registers_multiple_targets(self): + """Test that multiple targets are registered when their env vars are set.""" + # Set up openai_chat + os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["OPENAI_CHAT_KEY"] = "test_key" + os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + + # Set up openai_image + os.environ["OPENAI_IMAGE_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["OPENAI_IMAGE_API_KEY"] = "test_image_key" + os.environ["OPENAI_IMAGE_MODEL"] = "dall-e-3" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert len(registry) == 2 + assert "openai_chat" in registry + assert "openai_image" in registry + + @pytest.mark.asyncio + async def test_registers_azure_content_safety_without_model(self): + """Test that PromptShieldTarget is registered without model_name (it doesn't use one).""" + os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test.cognitiveservices.azure.com" + os.environ["AZURE_CONTENT_SAFETY_API_KEY"] = "test_safety_key" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "azure_content_safety" in registry + + @pytest.mark.asyncio + async def test_underlying_model_passed_when_set(self): + """Test that underlying_model is passed to target when env var is set.""" + os.environ["OPENAI_CHAT_ENDPOINT"] = "https://my-deployment.openai.azure.com" + os.environ["OPENAI_CHAT_KEY"] = "test_key" + os.environ["OPENAI_CHAT_MODEL"] = "my-deployment-name" + os.environ["OPENAI_CHAT_UNDERLYING_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + target = registry.get_instance_by_name("openai_chat") + assert target is not None + assert target._model_name == "my-deployment-name" + assert target._underlying_model == "gpt-4o" + + +@pytest.mark.usefixtures("patch_central_database") +class TestAIRTTargetInitializerTargetConfigs: + """Tests verifying TARGET_CONFIGS covers expected targets.""" + + def test_target_configs_not_empty(self): + """Test that TARGET_CONFIGS has configurations defined.""" + assert len(TARGET_CONFIGS) > 0 + + def test_all_configs_have_required_fields(self): + """Test that all TARGET_CONFIGS have required fields.""" + for config in TARGET_CONFIGS: + assert config.registry_name, f"Config missing registry_name" + assert config.target_class, f"Config {config.registry_name} missing target_class" + assert config.endpoint_var, f"Config {config.registry_name} missing endpoint_var" + assert config.key_var, f"Config {config.registry_name} missing key_var" + + def test_expected_targets_in_configs(self): + """Test that expected target names are in TARGET_CONFIGS.""" + registry_names = [config.registry_name for config in TARGET_CONFIGS] + + # Verify key targets are configured + assert "openai_chat" in registry_names + assert "openai_image" in registry_names + assert "openai_tts" in registry_names + assert "azure_content_safety" in registry_names + + +class TestAIRTTargetInitializerGetInfo: + """Tests for AIRTTargetInitializer.get_info_async method.""" + + @pytest.mark.asyncio + async def test_get_info_returns_expected_structure(self): + """Test that get_info_async returns expected structure.""" + info = await AIRTTargetInitializer.get_info_async() + + assert isinstance(info, dict) + assert info["name"] == "AIRT Target Initializer" + assert info["class"] == "AIRTTargetInitializer" + assert "description" in info + assert isinstance(info["description"], str) + + @pytest.mark.asyncio + async def test_get_info_required_env_vars_empty_or_not_present(self): + """Test that get_info has empty or no required_env_vars (since none are required).""" + info = await AIRTTargetInitializer.get_info_async() + + # required_env_vars may be omitted or empty since this initializer has no requirements + if "required_env_vars" in info: + assert info["required_env_vars"] == [] From 96d071aaa0d1fced66876e0aabd6153b443f7b23 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Fri, 30 Jan 2026 14:37:56 -0800 Subject: [PATCH 3/5] update registry notebook --- doc/code/registry/2_instance_registry.ipynb | 76 ++++++++++++++++----- doc/code/registry/2_instance_registry.py | 43 +++++++++--- 2 files changed, 93 insertions(+), 26 deletions(-) diff --git a/doc/code/registry/2_instance_registry.ipynb b/doc/code/registry/2_instance_registry.ipynb index 24a8b1bb68..52ce374054 100644 --- a/doc/code/registry/2_instance_registry.ipynb +++ b/doc/code/registry/2_instance_registry.ipynb @@ -35,10 +35,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Found default environment files: ['C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env', 'C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env.local']\n", - "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env\n", - "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env.local\n", - "Registered scorers: ['self_ask_refusal_d9007ba2']\n" + "Found default environment files: ['C:\\\\Users\\\\songjustin\\\\.pyrit\\\\.env']\n", + "Loaded environment file: C:\\Users\\songjustin\\.pyrit\\.env\n", + "Registered scorers: ['self_ask_refusal_scorer::94a582f5']\n" ] } ], @@ -83,7 +82,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Retrieved scorer: \n", + "Retrieved scorer: \n", "Scorer type: SelfAskRefusalScorer\n" ] } @@ -118,7 +117,7 @@ "output_type": "stream", "text": [ "\n", - "self_ask_refusal_d9007ba2:\n", + "self_ask_refusal_scorer::94a582f5:\n", " Class: SelfAskRefusalScorer\n", " Type: true_false\n", " Description: A self-ask scorer that detects refusal in AI responses. This...\n", @@ -126,7 +125,7 @@ "\u001b[1m 📊 Scorer Information\u001b[0m\n", "\u001b[37m ▸ Scorer Identifier\u001b[0m\n", "\u001b[36m • Scorer Type: SelfAskRefusalScorer\u001b[0m\n", - "\u001b[36m • Target Model: gpt-40\u001b[0m\n", + "\u001b[36m • Target Model: gpt-4o\u001b[0m\n", "\u001b[36m • Temperature: None\u001b[0m\n", "\u001b[36m • Score Aggregator: OR_\u001b[0m\n", "\n", @@ -141,12 +140,12 @@ "# Get metadata for all registered scorers\n", "metadata = registry.list_metadata()\n", "for item in metadata:\n", - " print(f\"\\n{item.name}:\")\n", + " print(f\"\\n{item.unique_name}:\")\n", " print(f\" Class: {item.class_name}\")\n", " print(f\" Type: {item.scorer_type}\")\n", - " print(f\" Description: {item.description[:60]}...\")\n", + " print(f\" Description: {item.class_description[:60]}...\")\n", "\n", - " ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier)" + " ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item)" ] }, { @@ -169,26 +168,69 @@ "name": "stdout", "output_type": "stream", "text": [ - "True/False scorers: ['self_ask_refusal_d9007ba2']\n", - "Refusal scorers: ['self_ask_refusal_d9007ba2']\n", - "True/False refusal scorers: ['self_ask_refusal_d9007ba2']\n" + "True/False scorers: ['self_ask_refusal_scorer::94a582f5']\n", + "Refusal scorers: ['self_ask_refusal_scorer::94a582f5']\n", + "True/False refusal scorers: ['self_ask_refusal_scorer::94a582f5']\n" ] } ], "source": [ "# Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer)\n", "true_false_scorers = registry.list_metadata(include_filters={\"scorer_type\": \"true_false\"})\n", - "print(f\"True/False scorers: {[m.name for m in true_false_scorers]}\")\n", + "print(f\"True/False scorers: {[m.unique_name for m in true_false_scorers]}\")\n", "\n", "# Filter by class_name\n", "refusal_scorers = registry.list_metadata(include_filters={\"class_name\": \"SelfAskRefusalScorer\"})\n", - "print(f\"Refusal scorers: {[m.name for m in refusal_scorers]}\")\n", + "print(f\"Refusal scorers: {[m.unique_name for m in refusal_scorers]}\")\n", "\n", "# Combine multiple filters (AND logic)\n", "specific_scorers = registry.list_metadata(\n", " include_filters={\"scorer_type\": \"true_false\", \"class_name\": \"SelfAskRefusalScorer\"}\n", ")\n", - "print(f\"True/False refusal scorers: {[m.name for m in specific_scorers]}\")" + "print(f\"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Using Target Initializer\n", + "\n", + "You can optionally use the `AIRTTargetInitializer` to automatically configure and register targets that use commonly used environment variables (from `.env_example`). This initializer does not strictly require any environment variables - it simply registers whatever endpoints are available." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found default environment files: ['C:\\\\Users\\\\songjustin\\\\.pyrit\\\\.env']\n", + "Loaded environment file: C:\\Users\\songjustin\\.pyrit\\.env\n", + "Registered targets after AIRT initialization: ['azure_content_safety', 'azure_gpt4o_unsafe_chat', 'azure_gpt4o_unsafe_chat2', 'default_openai_frontend', 'openai_chat', 'openai_image', 'openai_realtime', 'openai_responses', 'openai_tts', 'openai_video']\n" + ] + } + ], + "source": [ + "from pyrit.registry import TargetRegistry\n", + "from pyrit.setup import initialize_pyrit_async\n", + "from pyrit.setup.initializers import AIRTTargetInitializer\n", + "\n", + "# Using built-in initializer\n", + "await initialize_pyrit_async( # type: ignore\n", + " memory_db_type=\"InMemory\", initializers=[AIRTTargetInitializer()]\n", + ")\n", + "\n", + "# Get the registry singleton\n", + "registry = TargetRegistry.get_registry_singleton()\n", + "# List registered targets\n", + "target_names = registry.get_names()\n", + "print(f\"Registered targets after AIRT initialization: {target_names}\")" ] } ], @@ -203,7 +245,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/doc/code/registry/2_instance_registry.py b/doc/code/registry/2_instance_registry.py index c20755730c..d645529f25 100644 --- a/doc/code/registry/2_instance_registry.py +++ b/doc/code/registry/2_instance_registry.py @@ -5,11 +5,15 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.18.1 +# jupytext_version: 1.17.2 +# kernelspec: +# display_name: pyrit-dev +# language: python +# name: python3 # --- # %% [markdown] -# ## Why Instance Registries? +# # Why Instance Registries? # # Some components need configuration that can't easily be passed at instantiation time. For example, scorers often need: # - A configured `chat_target` for LLM-based scoring @@ -19,7 +23,7 @@ # Instance registries let initializers register fully-configured instances that are ready to use. # %% [markdown] -# # Listing Available Instances +# ## Listing Available Instances # # Use `get_names()` to see registered instances, or `list_metadata()` for details. @@ -67,12 +71,12 @@ # Get metadata for all registered scorers metadata = registry.list_metadata() for item in metadata: - print(f"\n{item.name}:") + print(f"\n{item.unique_name}:") print(f" Class: {item.class_name}") print(f" Type: {item.scorer_type}") - print(f" Description: {item.description[:60]}...") + print(f" Description: {item.class_description[:60]}...") - ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier) + ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item) # %% [markdown] # ## Filtering @@ -82,14 +86,35 @@ # %% # Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer) true_false_scorers = registry.list_metadata(include_filters={"scorer_type": "true_false"}) -print(f"True/False scorers: {[m.name for m in true_false_scorers]}") +print(f"True/False scorers: {[m.unique_name for m in true_false_scorers]}") # Filter by class_name refusal_scorers = registry.list_metadata(include_filters={"class_name": "SelfAskRefusalScorer"}) -print(f"Refusal scorers: {[m.name for m in refusal_scorers]}") +print(f"Refusal scorers: {[m.unique_name for m in refusal_scorers]}") # Combine multiple filters (AND logic) specific_scorers = registry.list_metadata( include_filters={"scorer_type": "true_false", "class_name": "SelfAskRefusalScorer"} ) -print(f"True/False refusal scorers: {[m.name for m in specific_scorers]}") +print(f"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}") + +# %% [markdown] +# ## Using Target Initializer +# +# You can optionally use the `AIRTTargetInitializer` to automatically configure and register targets that use commonly used environment variables (from `.env_example`). This initializer does not strictly require any environment variables - it simply registers whatever endpoints are available. + +# %% +from pyrit.registry import TargetRegistry +from pyrit.setup import initialize_pyrit_async +from pyrit.setup.initializers import AIRTTargetInitializer + +# Using built-in initializer +await initialize_pyrit_async( # type: ignore + memory_db_type="InMemory", initializers=[AIRTTargetInitializer()] +) + +# Get the registry singleton +registry = TargetRegistry.get_registry_singleton() +# List registered targets +target_names = registry.get_names() +print(f"Registered targets after AIRT initialization: {target_names}") From 3a353c2e9197ab960373acb21ed878a271e7c61f Mon Sep 17 00:00:00 2001 From: jsong468 Date: Fri, 6 Feb 2026 10:28:41 -0800 Subject: [PATCH 4/5] add all of env_example --- .env_example | 20 +- pyrit/setup/initializers/airt_targets.py | 311 ++++++++++++++---- .../setup/test_airt_targets_initializer.py | 85 +++-- 3 files changed, 325 insertions(+), 91 deletions(-) diff --git a/.env_example b/.env_example index d77940bcde..281b3db223 100644 --- a/.env_example +++ b/.env_example @@ -19,37 +19,45 @@ PLATFORM_OPENAI_CHAT_GPT4O_MODEL="gpt-4o" AZURE_OPENAI_GPT4O_ENDPOINT="https://xxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4O_KEY="xxxxx" AZURE_OPENAI_GPT4O_MODEL="deployment-name" -# Since deployment name may be custom and differ from the actual underlying model, -# you can specify the underlying model for identifier purposes +# Since Azure deployment name may be custom and differ from the actual underlying model, +# you can specify the underlying model for identifier purposes. If not specified, +# identifiers will default to the value of the standard MODEL environment variable. AZURE_OPENAI_GPT4O_UNDERLYING_MODEL="gpt-4o" AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_INTEGRATION_TEST_KEY="xxxxx" AZURE_OPENAI_INTEGRATION_TEST_MODEL="deployment-name" +AZURE_OPENAI_INTEGRATION_TEST_UNDERLYING_MODEL="" AZURE_OPENAI_GPT3_5_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT3_5_CHAT_KEY="xxxxx" AZURE_OPENAI_GPT3_5_CHAT_MODEL="deployment-name" +AZURE_OPENAI_GPT3_5_CHAT_UNDERLYING_MODEL="" AZURE_OPENAI_GPT4_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4_CHAT_KEY="xxxxx" AZURE_OPENAI_GPT4_CHAT_MODEL="deployment-name" +AZURE_OPENAI_GPT4_CHAT_UNDERLYING_MODEL="" # Endpoints that host models with fewer safety mechanisms (e.g. via adversarial fine tuning # or content filters turned off) can be defined below and used in adversarial attack testing scenarios. AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY="xxxxx" AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL="deployment-name" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL="" AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2="xxxxx" AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2="deployment-name" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2="" AZURE_FOUNDRY_DEEPSEEK_ENDPOINT="https://xxxxx.eastus2.models.ai.azure.com" AZURE_FOUNDRY_DEEPSEEK_KEY="xxxxx" +AZURE_FOUNDRY_DEEPSEEK_MODEL="" AZURE_FOUNDRY_PHI4_ENDPOINT="https://xxxxx.models.ai.azure.com" AZURE_CHAT_PHI4_KEY="xxxxx" +AZURE_FOUNDRY_PHI4_MODEL="" AZURE_FOUNDRY_MISTRAL_LARGE_ENDPOINT="https://xxxxx.services.ai.azure.com/openai/v1/" AZURE_FOUNDRY_MISTRAL_LARGE_KEY="xxxxx" @@ -85,6 +93,7 @@ AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT="https://xxxxxxxxx.azure.com/openai/v1" AZURE_OPENAI_GPT5_COMPLETION_ENDPOINT="https://xxxxxxxxx.azure.com/openai/v1" AZURE_OPENAI_GPT5_KEY="xxxxxxx" AZURE_OPENAI_GPT5_MODEL="gpt-5" +AZURE_OPENAI_GPT5_UNDERLYING_MODEL="gpt-5" PLATFORM_OPENAI_RESPONSES_ENDPOINT="https://api.openai.com/v1" PLATFORM_OPENAI_RESPONSES_KEY="sk-xxxxx" @@ -93,6 +102,7 @@ PLATFORM_OPENAI_RESPONSES_MODEL="o4-mini" AZURE_OPENAI_RESPONSES_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_RESPONSES_KEY="xxxxx" AZURE_OPENAI_RESPONSES_MODEL="o4-mini" +AZURE_OPENAI_RESPONSES_UNDERLYING_MODEL="o4-mini" OPENAI_RESPONSES_ENDPOINT=${PLATFORM_OPENAI_RESPONSES_ENDPOINT} OPENAI_RESPONSES_KEY=${PLATFORM_OPENAI_RESPONSES_KEY} @@ -113,6 +123,7 @@ PLATFORM_OPENAI_REALTIME_MODEL="gpt-4o-realtime-preview" AZURE_OPENAI_REALTIME_ENDPOINT = "wss://xxxx.openai.azure.com/openai/v1" AZURE_OPENAI_REALTIME_API_KEY = "xxxxx" AZURE_OPENAI_REALTIME_MODEL = "gpt-4o-realtime-preview" +AZURE_OPENAI_REALTIME_UNDERLYING_MODEL = "gpt-4o-realtime-preview" OPENAI_REALTIME_ENDPOINT = ${PLATFORM_OPENAI_REALTIME_ENDPOINT} OPENAI_REALTIME_API_KEY = ${PLATFORM_OPENAI_REALTIME_API_KEY} @@ -129,10 +140,12 @@ OPENAI_REALTIME_UNDERLYING_MODEL = "" OPENAI_IMAGE_ENDPOINT1 = "https://xxxxx.openai.azure.com/openai/v1" OPENAI_IMAGE_API_KEY1 = "xxxxxx" OPENAI_IMAGE_MODEL1 = "deployment-name" +OPENAI_IMAGE_UNDERLYING_MODEL1 = "dall-e-3" OPENAI_IMAGE_ENDPOINT2 = "https://api.openai.com/v1" OPENAI_IMAGE_API_KEY2 = "sk-xxxxx" OPENAI_IMAGE_MODEL2 = "dall-e-3" +OPENAI_IMAGE_UNDERLYING_MODEL2 = "dall-e-3" OPENAI_IMAGE_ENDPOINT = ${OPENAI_IMAGE_ENDPOINT2} OPENAI_IMAGE_API_KEY = ${OPENAI_IMAGE_API_KEY2} @@ -150,10 +163,12 @@ OPENAI_IMAGE_UNDERLYING_MODEL = "" OPENAI_TTS_ENDPOINT1 = "https://xxxxx.openai.azure.com/openai/v1" OPENAI_TTS_KEY1 = "xxxxxxx" OPENAI_TTS_MODEL1 = "tts" +OPENAI_TTS_UNDERLYING_MODEL1 = "tts" OPENAI_TTS_ENDPOINT2 = "https://api.openai.com/v1" OPENAI_TTS_KEY2 = "xxxxxx" OPENAI_TTS_MODEL2 = "tts-1" +OPENAI_TTS_UNDERLYING_MODEL2 = "tts-1" OPENAI_TTS_ENDPOINT = ${OPENAI_TTS_ENDPOINT2} OPENAI_TTS_KEY = ${OPENAI_TTS_KEY2} @@ -171,6 +186,7 @@ OPENAI_TTS_UNDERLYING_MODEL = "" AZURE_OPENAI_VIDEO_ENDPOINT="https://xxxxx.cognitiveservices.azure.com/openai/v1" AZURE_OPENAI_VIDEO_KEY="xxxxxxx" AZURE_OPENAI_VIDEO_MODEL="sora-2" +AZURE_OPENAI_VIDEO_UNDERLYING_MODEL="sora-2" OPENAI_VIDEO_ENDPOINT = ${AZURE_OPENAI_VIDEO_ENDPOINT} OPENAI_VIDEO_KEY = ${AZURE_OPENAI_VIDEO_KEY} diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py index be98380ae1..b919d79d3f 100644 --- a/pyrit/setup/initializers/airt_targets.py +++ b/pyrit/setup/initializers/airt_targets.py @@ -6,6 +6,10 @@ This module provides the AIRTTargetInitializer class that registers available targets into the TargetRegistry based on environment variable configuration. + +Note: This module only includes PRIMARY endpoint configurations from .env_example. + Alias configurations (those using ${...} syntax) are excluded since they + reference other primary configurations. """ import logging @@ -14,7 +18,9 @@ from typing import Any, List, Optional, Type from pyrit.prompt_target import ( + AzureMLChatTarget, OpenAIChatTarget, + OpenAICompletionTarget, OpenAIImageTarget, OpenAIResponseTarget, OpenAITTSTarget, @@ -36,36 +42,56 @@ class TargetConfig: registry_name: str target_class: Type[PromptTarget] endpoint_var: str - key_var: str + key_var: str = "" # Empty string means no auth required model_var: Optional[str] = None underlying_model_var: Optional[str] = None -# Define all supported target configurations +# Define all supported target configurations. +# Only PRIMARY configurations are included here - alias configurations that use ${...} +# syntax in .env_example are excluded since they reference other primary configurations. TARGET_CONFIGS: List[TargetConfig] = [ + # ============================================ + # OpenAI Chat Targets (OpenAIChatTarget) + # ============================================ TargetConfig( - registry_name="default_openai_frontend", + registry_name="platform_openai_chat", target_class=OpenAIChatTarget, - endpoint_var="DEFAULT_OPENAI_FRONTEND_ENDPOINT", - key_var="DEFAULT_OPENAI_FRONTEND_KEY", - model_var="DEFAULT_OPENAI_FRONTEND_MODEL", - underlying_model_var="DEFAULT_OPENAI_FRONTEND_UNDERLYING_MODEL", + endpoint_var="PLATFORM_OPENAI_CHAT_ENDPOINT", + key_var="PLATFORM_OPENAI_CHAT_API_KEY", + model_var="PLATFORM_OPENAI_CHAT_GPT4O_MODEL", ), TargetConfig( - registry_name="openai_chat", + registry_name="azure_openai_gpt4o", target_class=OpenAIChatTarget, - endpoint_var="OPENAI_CHAT_ENDPOINT", - key_var="OPENAI_CHAT_KEY", - model_var="OPENAI_CHAT_MODEL", - underlying_model_var="OPENAI_CHAT_UNDERLYING_MODEL", + endpoint_var="AZURE_OPENAI_GPT4O_ENDPOINT", + key_var="AZURE_OPENAI_GPT4O_KEY", + model_var="AZURE_OPENAI_GPT4O_MODEL", + underlying_model_var="AZURE_OPENAI_GPT4O_UNDERLYING_MODEL", ), TargetConfig( - registry_name="openai_responses", - target_class=OpenAIResponseTarget, - endpoint_var="OPENAI_RESPONSES_ENDPOINT", - key_var="OPENAI_RESPONSES_KEY", - model_var="OPENAI_RESPONSES_MODEL", - underlying_model_var="OPENAI_RESPONSES_UNDERLYING_MODEL", + registry_name="azure_openai_integration_test", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT", + key_var="AZURE_OPENAI_INTEGRATION_TEST_KEY", + model_var="AZURE_OPENAI_INTEGRATION_TEST_MODEL", + underlying_model_var="AZURE_OPENAI_INTEGRATION_TEST_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_openai_gpt35_chat", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT3_5_CHAT_ENDPOINT", + key_var="AZURE_OPENAI_GPT3_5_CHAT_KEY", + model_var="AZURE_OPENAI_GPT3_5_CHAT_MODEL", + underlying_model_var="AZURE_OPENAI_GPT3_5_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_openai_gpt4_chat", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4_CHAT_ENDPOINT", + key_var="AZURE_OPENAI_GPT4_CHAT_KEY", + model_var="AZURE_OPENAI_GPT4_CHAT_MODEL", + underlying_model_var="AZURE_OPENAI_GPT4_CHAT_UNDERLYING_MODEL", ), TargetConfig( registry_name="azure_gpt4o_unsafe_chat", @@ -84,37 +110,168 @@ class TargetConfig: underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2", ), TargetConfig( - registry_name="openai_realtime", + registry_name="azure_foundry_deepseek", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_FOUNDRY_DEEPSEEK_ENDPOINT", + key_var="AZURE_FOUNDRY_DEEPSEEK_KEY", + model_var="AZURE_FOUNDRY_DEEPSEEK_MODEL", + ), + TargetConfig( + registry_name="azure_foundry_phi4", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_FOUNDRY_PHI4_ENDPOINT", + key_var="AZURE_CHAT_PHI4_KEY", + model_var="AZURE_FOUNDRY_PHI4_MODEL", + ), + TargetConfig( + registry_name="azure_foundry_mistral_large", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_FOUNDRY_MISTRAL_LARGE_ENDPOINT", + key_var="AZURE_FOUNDRY_MISTRAL_LARGE_KEY", + model_var="AZURE_FOUNDRY_MISTRAL_LARGE_MODEL", + ), + TargetConfig( + registry_name="groq", + target_class=OpenAIChatTarget, + endpoint_var="GROQ_ENDPOINT", + key_var="GROQ_KEY", + model_var="GROQ_LLAMA_MODEL", + ), + TargetConfig( + registry_name="open_router", + target_class=OpenAIChatTarget, + endpoint_var="OPEN_ROUTER_ENDPOINT", + key_var="OPEN_ROUTER_KEY", + model_var="OPEN_ROUTER_CLAUDE_MODEL", + ), + TargetConfig( + registry_name="ollama", + target_class=OpenAIChatTarget, + endpoint_var="OLLAMA_CHAT_ENDPOINT", + model_var="OLLAMA_MODEL", + ), + TargetConfig( + registry_name="google_gemini", + target_class=OpenAIChatTarget, + endpoint_var="GOOGLE_GEMINI_ENDPOINT", + key_var="GOOGLE_GEMINI_API_KEY", + model_var="GOOGLE_GEMINI_MODEL", + ), + # ============================================ + # OpenAI Responses Targets (OpenAIResponseTarget) + # ============================================ + TargetConfig( + registry_name="azure_openai_gpt5_responses", + target_class=OpenAIResponseTarget, + endpoint_var="AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT", + key_var="AZURE_OPENAI_GPT5_KEY", + model_var="AZURE_OPENAI_GPT5_MODEL", + underlying_model_var="AZURE_OPENAI_GPT5_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="platform_openai_responses", + target_class=OpenAIResponseTarget, + endpoint_var="PLATFORM_OPENAI_RESPONSES_ENDPOINT", + key_var="PLATFORM_OPENAI_RESPONSES_KEY", + model_var="PLATFORM_OPENAI_RESPONSES_MODEL", + ), + TargetConfig( + registry_name="azure_openai_responses", + target_class=OpenAIResponseTarget, + endpoint_var="AZURE_OPENAI_RESPONSES_ENDPOINT", + key_var="AZURE_OPENAI_RESPONSES_KEY", + model_var="AZURE_OPENAI_RESPONSES_MODEL", + underlying_model_var="AZURE_OPENAI_RESPONSES_UNDERLYING_MODEL", + ), + # ============================================ + # Realtime Targets (RealtimeTarget) + # ============================================ + TargetConfig( + registry_name="platform_openai_realtime", + target_class=RealtimeTarget, + endpoint_var="PLATFORM_OPENAI_REALTIME_ENDPOINT", + key_var="PLATFORM_OPENAI_REALTIME_API_KEY", + model_var="PLATFORM_OPENAI_REALTIME_MODEL", + ), + TargetConfig( + registry_name="azure_openai_realtime", target_class=RealtimeTarget, - endpoint_var="OPENAI_REALTIME_ENDPOINT", - key_var="OPENAI_REALTIME_API_KEY", - model_var="OPENAI_REALTIME_MODEL", - underlying_model_var="OPENAI_REALTIME_UNDERLYING_MODEL", + endpoint_var="AZURE_OPENAI_REALTIME_ENDPOINT", + key_var="AZURE_OPENAI_REALTIME_API_KEY", + model_var="AZURE_OPENAI_REALTIME_MODEL", + underlying_model_var="AZURE_OPENAI_REALTIME_UNDERLYING_MODEL", + ), + # ============================================ + # Image Targets (OpenAIImageTarget) + # ============================================ + TargetConfig( + registry_name="openai_image_azure", + target_class=OpenAIImageTarget, + endpoint_var="OPENAI_IMAGE_ENDPOINT1", + key_var="OPENAI_IMAGE_API_KEY1", + model_var="OPENAI_IMAGE_MODEL1", + underlying_model_var="OPENAI_IMAGE_UNDERLYING_MODEL1", ), TargetConfig( - registry_name="openai_image", + registry_name="openai_image_platform", target_class=OpenAIImageTarget, - endpoint_var="OPENAI_IMAGE_ENDPOINT", - key_var="OPENAI_IMAGE_API_KEY", - model_var="OPENAI_IMAGE_MODEL", - underlying_model_var="OPENAI_IMAGE_UNDERLYING_MODEL", + endpoint_var="OPENAI_IMAGE_ENDPOINT2", + key_var="OPENAI_IMAGE_API_KEY2", + model_var="OPENAI_IMAGE_MODEL2", + underlying_model_var="OPENAI_IMAGE_UNDERLYING_MODEL2", + ), + # ============================================ + # TTS Targets (OpenAITTSTarget) + # ============================================ + TargetConfig( + registry_name="openai_tts_azure", + target_class=OpenAITTSTarget, + endpoint_var="OPENAI_TTS_ENDPOINT1", + key_var="OPENAI_TTS_KEY1", + model_var="OPENAI_TTS_MODEL1", + underlying_model_var="OPENAI_TTS_UNDERLYING_MODEL1", ), TargetConfig( - registry_name="openai_tts", + registry_name="openai_tts_platform", target_class=OpenAITTSTarget, - endpoint_var="OPENAI_TTS_ENDPOINT", - key_var="OPENAI_TTS_KEY", - model_var="OPENAI_TTS_MODEL", - underlying_model_var="OPENAI_TTS_UNDERLYING_MODEL", + endpoint_var="OPENAI_TTS_ENDPOINT2", + key_var="OPENAI_TTS_KEY2", + model_var="OPENAI_TTS_MODEL2", + underlying_model_var="OPENAI_TTS_UNDERLYING_MODEL2", ), + # ============================================ + # Video Targets (OpenAIVideoTarget) + # ============================================ TargetConfig( - registry_name="openai_video", + registry_name="azure_openai_video", target_class=OpenAIVideoTarget, - endpoint_var="OPENAI_VIDEO_ENDPOINT", - key_var="OPENAI_VIDEO_KEY", - model_var="OPENAI_VIDEO_MODEL", - underlying_model_var="OPENAI_VIDEO_UNDERLYING_MODEL", + endpoint_var="AZURE_OPENAI_VIDEO_ENDPOINT", + key_var="AZURE_OPENAI_VIDEO_KEY", + model_var="AZURE_OPENAI_VIDEO_MODEL", + underlying_model_var="AZURE_OPENAI_VIDEO_UNDERLYING_MODEL", ), + # ============================================ + # Completion Targets (OpenAICompletionTarget) + # ============================================ + TargetConfig( + registry_name="openai_completion", + target_class=OpenAICompletionTarget, + endpoint_var="OPENAI_COMPLETION_ENDPOINT", + key_var="OPENAI_COMPLETION_API_KEY", + model_var="OPENAI_COMPLETION_MODEL", + ), + # ============================================ + # Azure ML Targets (AzureMLChatTarget) + # ============================================ + TargetConfig( + registry_name="azure_ml_phi", + target_class=AzureMLChatTarget, + endpoint_var="AZURE_ML_PHI_ENDPOINT", + key_var="AZURE_ML_PHI_KEY", + ), + # ============================================ + # Safety Targets (PromptShieldTarget) + # ============================================ TargetConfig( registry_name="azure_content_safety", target_class=PromptShieldTarget, @@ -129,21 +286,56 @@ class AIRTTargetInitializer(PyRITInitializer): AIRT Target Initializer for registering pre-configured targets. This initializer scans for known endpoint environment variables and registers - the corresponding targets into the TargetRegistry. Unlike AIRTInitializer, - this initializer does not require any environment variables - it simply - registers whatever endpoints are available. - - Supported Endpoints: - - DEFAULT_OPENAI_FRONTEND_ENDPOINT: Default OpenAI frontend (OpenAIChatTarget) - - OPENAI_CHAT_ENDPOINT: OpenAI Chat API (OpenAIChatTarget) - - OPENAI_RESPONSES_ENDPOINT: OpenAI Responses API (OpenAIResponseTarget) - - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI GPT-4o unsafe (OpenAIChatTarget) - - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI GPT-4o unsafe secondary (OpenAIChatTarget) - - OPENAI_REALTIME_ENDPOINT: OpenAI Realtime API (RealtimeTarget) - - OPENAI_IMAGE_ENDPOINT: OpenAI Image Generation (OpenAIImageTarget) - - OPENAI_TTS_ENDPOINT: OpenAI Text-to-Speech (OpenAITTSTarget) - - OPENAI_VIDEO_ENDPOINT: OpenAI Video Generation (OpenAIVideoTarget) - - AZURE_CONTENT_SAFETY_API_ENDPOINT: Azure Content Safety (PromptShieldTarget) + the corresponding targets into the TargetRegistry. It only includes PRIMARY + endpoint configurations - alias configurations (those using ${...} syntax in + .env_example) are excluded since they reference other primary configurations. + + Supported Endpoints by Category: + + **OpenAI Chat Targets (OpenAIChatTarget):** + - PLATFORM_OPENAI_CHAT_* - Platform OpenAI Chat API + - AZURE_OPENAI_GPT4O_* - Azure OpenAI GPT-4o + - AZURE_OPENAI_INTEGRATION_TEST_* - Integration test endpoint + - AZURE_OPENAI_GPT3_5_CHAT_* - Azure OpenAI GPT-3.5 + - AZURE_OPENAI_GPT4_CHAT_* - Azure OpenAI GPT-4 + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_* - Azure OpenAI GPT-4o unsafe + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_*2 - Azure OpenAI GPT-4o unsafe secondary + - AZURE_FOUNDRY_DEEPSEEK_* - Azure AI Foundry DeepSeek + - AZURE_FOUNDRY_PHI4_* - Azure AI Foundry Phi-4 + - AZURE_FOUNDRY_MISTRAL_LARGE_* - Azure AI Foundry Mistral Large + - GROQ_* - Groq API + - OPEN_ROUTER_* - OpenRouter API + - OLLAMA_* - Ollama local + - GOOGLE_GEMINI_* - Google Gemini (OpenAI-compatible) + + **OpenAI Responses Targets (OpenAIResponseTarget):** + - AZURE_OPENAI_GPT5_RESPONSES_* - Azure OpenAI GPT-5 Responses + - PLATFORM_OPENAI_RESPONSES_* - Platform OpenAI Responses + - AZURE_OPENAI_RESPONSES_* - Azure OpenAI Responses + + **Realtime Targets (RealtimeTarget):** + - PLATFORM_OPENAI_REALTIME_* - Platform OpenAI Realtime + - AZURE_OPENAI_REALTIME_* - Azure OpenAI Realtime + + **Image Targets (OpenAIImageTarget):** + - OPENAI_IMAGE_*1 - Azure OpenAI Image + - OPENAI_IMAGE_*2 - Platform OpenAI Image + + **TTS Targets (OpenAITTSTarget):** + - OPENAI_TTS_*1 - Azure OpenAI TTS + - OPENAI_TTS_*2 - Platform OpenAI TTS + + **Video Targets (OpenAIVideoTarget):** + - AZURE_OPENAI_VIDEO_* - Azure OpenAI Video + + **Completion Targets (OpenAICompletionTarget):** + - OPENAI_COMPLETION_* - OpenAI Completion + + **Azure ML Targets (AzureMLChatTarget):** + - AZURE_ML_PHI_* - Azure ML Phi + + **Safety Targets (PromptShieldTarget):** + - AZURE_CONTENT_SAFETY_* - Azure Content Safety Example: initializer = AIRTTargetInitializer() @@ -195,11 +387,18 @@ def _register_target(self, config: TargetConfig) -> None: config: The target configuration specifying env vars and target class. """ endpoint = os.getenv(config.endpoint_var) - api_key = os.getenv(config.key_var) - - if not endpoint or not api_key: + if not endpoint: return + # If key_var is empty, use placeholder (for targets like Ollama that don't require auth) + # If key_var is set, look up the env var and skip registration if not found + if config.key_var: + api_key = os.getenv(config.key_var) + if not api_key: + return + else: + api_key = "" + model_name = os.getenv(config.model_var) if config.model_var else None underlying_model = os.getenv(config.underlying_model_var) if config.underlying_model_var else None diff --git a/tests/unit/setup/test_airt_targets_initializer.py b/tests/unit/setup/test_airt_targets_initializer.py index 48a537313f..356a6388d5 100644 --- a/tests/unit/setup/test_airt_targets_initializer.py +++ b/tests/unit/setup/test_airt_targets_initializer.py @@ -61,16 +61,16 @@ async def test_initialize_runs_without_error_no_env_vars(self): @pytest.mark.asyncio async def test_registers_target_when_env_vars_set(self): """Test that a target is registered when its env vars are set.""" - os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" - os.environ["OPENAI_CHAT_KEY"] = "test_key" - os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + os.environ["PLATFORM_OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["PLATFORM_OPENAI_CHAT_API_KEY"] = "test_key" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" init = AIRTTargetInitializer() await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "openai_chat" in registry - target = registry.get_instance_by_name("openai_chat") + assert "platform_openai_chat" in registry + target = registry.get_instance_by_name("platform_openai_chat") assert target is not None assert target._model_name == "gpt-4o" @@ -78,48 +78,48 @@ async def test_registers_target_when_env_vars_set(self): async def test_does_not_register_target_without_endpoint(self): """Test that target is not registered if endpoint is missing.""" # Only set key, not endpoint - os.environ["OPENAI_CHAT_KEY"] = "test_key" - os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + os.environ["PLATFORM_OPENAI_CHAT_API_KEY"] = "test_key" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" init = AIRTTargetInitializer() await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "openai_chat" not in registry + assert "platform_openai_chat" not in registry @pytest.mark.asyncio async def test_does_not_register_target_without_api_key(self): - """Test that target is not registered if api_key is missing.""" + """Test that target is not registered if api_key env var is missing.""" # Only set endpoint, not key - os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" - os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + os.environ["PLATFORM_OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" init = AIRTTargetInitializer() await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - assert "openai_chat" not in registry + assert "platform_openai_chat" not in registry @pytest.mark.asyncio async def test_registers_multiple_targets(self): """Test that multiple targets are registered when their env vars are set.""" - # Set up openai_chat - os.environ["OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" - os.environ["OPENAI_CHAT_KEY"] = "test_key" - os.environ["OPENAI_CHAT_MODEL"] = "gpt-4o" + # Set up platform_openai_chat + os.environ["PLATFORM_OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["PLATFORM_OPENAI_CHAT_API_KEY"] = "test_key" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" - # Set up openai_image - os.environ["OPENAI_IMAGE_ENDPOINT"] = "https://api.openai.com/v1" - os.environ["OPENAI_IMAGE_API_KEY"] = "test_image_key" - os.environ["OPENAI_IMAGE_MODEL"] = "dall-e-3" + # Set up openai_image_platform (uses ENDPOINT2/KEY2/MODEL2) + os.environ["OPENAI_IMAGE_ENDPOINT2"] = "https://api.openai.com/v1" + os.environ["OPENAI_IMAGE_API_KEY2"] = "test_image_key" + os.environ["OPENAI_IMAGE_MODEL2"] = "dall-e-3" init = AIRTTargetInitializer() await init.initialize_async() registry = TargetRegistry.get_registry_singleton() assert len(registry) == 2 - assert "openai_chat" in registry - assert "openai_image" in registry + assert "platform_openai_chat" in registry + assert "openai_image_platform" in registry @pytest.mark.asyncio async def test_registers_azure_content_safety_without_model(self): @@ -136,20 +136,35 @@ async def test_registers_azure_content_safety_without_model(self): @pytest.mark.asyncio async def test_underlying_model_passed_when_set(self): """Test that underlying_model is passed to target when env var is set.""" - os.environ["OPENAI_CHAT_ENDPOINT"] = "https://my-deployment.openai.azure.com" - os.environ["OPENAI_CHAT_KEY"] = "test_key" - os.environ["OPENAI_CHAT_MODEL"] = "my-deployment-name" - os.environ["OPENAI_CHAT_UNDERLYING_MODEL"] = "gpt-4o" + os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] = "https://my-deployment.openai.azure.com" + os.environ["AZURE_OPENAI_GPT4O_KEY"] = "test_key" + os.environ["AZURE_OPENAI_GPT4O_MODEL"] = "my-deployment-name" + os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"] = "gpt-4o" init = AIRTTargetInitializer() await init.initialize_async() registry = TargetRegistry.get_registry_singleton() - target = registry.get_instance_by_name("openai_chat") + target = registry.get_instance_by_name("azure_openai_gpt4o") assert target is not None assert target._model_name == "my-deployment-name" assert target._underlying_model == "gpt-4o" + @pytest.mark.asyncio + async def test_registers_ollama_without_api_key(self): + """Test that Ollama target is registered without requiring an API key.""" + os.environ["OLLAMA_CHAT_ENDPOINT"] = "http://127.0.0.1:11434/v1" + os.environ["OLLAMA_MODEL"] = "llama2" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "ollama" in registry + target = registry.get_instance_by_name("ollama") + assert target is not None + assert target._model_name == "llama2" + @pytest.mark.usefixtures("patch_central_database") class TestAIRTTargetInitializerTargetConfigs: @@ -160,22 +175,26 @@ def test_target_configs_not_empty(self): assert len(TARGET_CONFIGS) > 0 def test_all_configs_have_required_fields(self): - """Test that all TARGET_CONFIGS have required fields.""" + """Test that all TARGET_CONFIGS have required fields (key_var is optional for some).""" for config in TARGET_CONFIGS: assert config.registry_name, f"Config missing registry_name" assert config.target_class, f"Config {config.registry_name} missing target_class" assert config.endpoint_var, f"Config {config.registry_name} missing endpoint_var" - assert config.key_var, f"Config {config.registry_name} missing key_var" + # key_var is optional for targets like Ollama that don't require auth def test_expected_targets_in_configs(self): """Test that expected target names are in TARGET_CONFIGS.""" registry_names = [config.registry_name for config in TARGET_CONFIGS] - # Verify key targets are configured - assert "openai_chat" in registry_names - assert "openai_image" in registry_names - assert "openai_tts" in registry_names + # Verify key targets are configured (using new primary config names) + assert "platform_openai_chat" in registry_names + assert "azure_openai_gpt4o" in registry_names + assert "openai_image_platform" in registry_names + assert "openai_tts_platform" in registry_names assert "azure_content_safety" in registry_names + assert "ollama" in registry_names + assert "groq" in registry_names + assert "google_gemini" in registry_names class TestAIRTTargetInitializerGetInfo: From 47d0dd0f7dc29cc87750762c3d104d482f2076f1 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Fri, 6 Feb 2026 10:52:07 -0800 Subject: [PATCH 5/5] update ollama key --- pyrit/setup/initializers/airt_targets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py index b919d79d3f..f421c53c6e 100644 --- a/pyrit/setup/initializers/airt_targets.py +++ b/pyrit/setup/initializers/airt_targets.py @@ -397,7 +397,7 @@ def _register_target(self, config: TargetConfig) -> None: if not api_key: return else: - api_key = "" + api_key = "not-needed" model_name = os.getenv(config.model_var) if config.model_var else None underlying_model = os.getenv(config.underlying_model_var) if config.underlying_model_var else None