diff --git a/models/src/agent_control_models/__init__.py b/models/src/agent_control_models/__init__.py index 45a455d8..e4d9b726 100644 --- a/models/src/agent_control_models/__init__.py +++ b/models/src/agent_control_models/__init__.py @@ -12,6 +12,8 @@ expand_action_filter, normalize_action, normalize_action_list, + validate_action, + validate_action_list, ) from .agent import ( BUILTIN_STEP_TYPES, @@ -116,6 +118,8 @@ "SteeringContext", "normalize_action", "normalize_action_list", + "validate_action", + "validate_action_list", "expand_action_filter", # Error models "ProblemDetail", diff --git a/models/src/agent_control_models/actions.py b/models/src/agent_control_models/actions.py index a890c6f9..972f5008 100644 --- a/models/src/agent_control_models/actions.py +++ b/models/src/agent_control_models/actions.py @@ -7,6 +7,7 @@ type ActionDecision = Literal["deny", "steer", "observe"] +_CANONICAL_ACTIONS = frozenset({"deny", "steer", "observe"}) _OBSERVE_ACTION_ALIASES = frozenset({"allow", "observe", "warn", "log"}) _ACTION_QUERY_EXPANSION: dict[ActionDecision, tuple[str, ...]] = { "deny": ("deny",), @@ -15,15 +16,44 @@ } +def validate_action(action: str) -> ActionDecision: + """Validate that *action* is one of the canonical action values. + + Use this on public API boundaries (control create/update, query filters) + where legacy values should be rejected. + """ + if action in _CANONICAL_ACTIONS: + return cast(ActionDecision, action) + raise ValueError( + f"Invalid action {action!r}. Must be one of: deny, steer, observe." + ) + + +def validate_action_list(actions: Sequence[str]) -> list[ActionDecision]: + """Validate a list of actions, preserving order and removing duplicates.""" + validated: list[ActionDecision] = [] + seen: set[ActionDecision] = set() + for action in actions: + canonical = validate_action(action) + if canonical in seen: + continue + seen.add(canonical) + validated.append(canonical) + return validated + + def normalize_action(action: str) -> ActionDecision: - """Normalize a public or legacy action name to the canonical action.""" + """Normalize a stored or legacy action name to the canonical action. + + Use this on internal read paths (deserializing DB rows, server responses) + where historical data may contain legacy values. + """ if action in _OBSERVE_ACTION_ALIASES: return "observe" if action in ("deny", "steer"): return cast(ActionDecision, action) raise ValueError( - "Invalid action. Expected one of: deny, steer, observe " - "(legacy aliases allow/warn/log are also accepted temporarily)." + f"Invalid action {action!r}. Expected one of: deny, steer, observe." ) diff --git a/models/src/agent_control_models/controls.py b/models/src/agent_control_models/controls.py index 0ed11b1b..928db418 100644 --- a/models/src/agent_control_models/controls.py +++ b/models/src/agent_control_models/controls.py @@ -10,7 +10,7 @@ import re2 from pydantic import ConfigDict, Field, ValidationInfo, field_validator, model_validator -from .actions import ActionDecision, normalize_action +from .actions import ActionDecision, normalize_action, validate_action from .base import BaseModel @@ -280,8 +280,8 @@ class ControlAction(BaseModel): @field_validator("decision", mode="before") @classmethod - def normalize_decision(cls, value: str) -> ActionDecision: - return normalize_action(value) + def validate_decision(cls, value: str) -> ActionDecision: + return validate_action(value) MAX_CONDITION_DEPTH = 6 diff --git a/models/src/agent_control_models/observability.py b/models/src/agent_control_models/observability.py index a78e4b1e..dbd11fac 100644 --- a/models/src/agent_control_models/observability.py +++ b/models/src/agent_control_models/observability.py @@ -14,7 +14,11 @@ from pydantic import Field, field_validator -from .actions import ActionDecision, normalize_action, normalize_action_list +from .actions import ( + ActionDecision, + normalize_action, + validate_action_list, +) from .agent import AGENT_NAME_MIN_LENGTH, AGENT_NAME_PATTERN, normalize_agent_name from .base import BaseModel @@ -343,12 +347,12 @@ def validate_and_normalize_agent_name( @field_validator("actions", mode="before") @classmethod - def normalize_actions_filter( + def validate_actions_filter( cls, value: list[str] | None ) -> list[ActionDecision] | None: if value is None: return None - return normalize_action_list(value) + return validate_action_list(value) class EventQueryResponse(BaseModel): diff --git a/models/tests/test_actions.py b/models/tests/test_actions.py index c9a8c0cf..2b0c5775 100644 --- a/models/tests/test_actions.py +++ b/models/tests/test_actions.py @@ -1,4 +1,4 @@ -"""Tests for shared control-action compatibility behavior.""" +"""Tests for shared control-action types, validation, and normalization.""" from __future__ import annotations @@ -11,50 +11,240 @@ EvaluatorResult, expand_action_filter, ) +from agent_control_models.actions import normalize_action, validate_action from pydantic import ValidationError -def test_event_query_actions_normalize_and_expand_for_legacy_observability() -> None: - # Given: a query that mixes canonical and legacy advisory action names - query = EventQueryRequest( - actions=["warn", "observe", "deny", "log", "deny", "steer", "allow", "steer"] - ) +# --------------------------------------------------------------------------- +# validate_action (strict, for API boundaries) +# --------------------------------------------------------------------------- - # When: expanding the normalized public action filter for stored event rows - expanded = expand_action_filter(query.actions or []) - # Then: the public filter is canonicalized, deduped, and expanded for legacy rows - assert query.actions == ["observe", "deny", "steer"] - assert expanded == ["observe", "allow", "warn", "log", "deny", "steer"] +class TestValidateAction: + """Tests for the strict validate_action used on public API boundaries.""" + @pytest.mark.parametrize("action", ["deny", "steer", "observe"]) + def test_accepts_canonical_actions(self, action: str) -> None: + # Given: a canonical action name + # When: validating the action + result = validate_action(action) -def test_invalid_action_is_rejected_across_public_model_boundaries() -> None: - # Given: the same invalid action at each public model boundary - invalid_action = "block" - invalid_builders = [ - lambda: ControlAction.model_validate({"decision": invalid_action}), - lambda: ControlMatch( - control_id=123, - control_name="pii-check", - action=invalid_action, + # Then: the same canonical value is returned + assert result == action + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_rejects_legacy_actions(self, legacy: str) -> None: + # Given: a legacy action name that is no longer accepted at API boundaries + # When / Then: validation raises ValueError + with pytest.raises(ValueError, match="Invalid action"): + validate_action(legacy) + + def test_rejects_unknown_action(self) -> None: + # Given: a completely unknown action name + # When / Then: validation raises ValueError + with pytest.raises(ValueError, match="Invalid action"): + validate_action("block") + + +# --------------------------------------------------------------------------- +# normalize_action (lenient, for internal read paths) +# --------------------------------------------------------------------------- + + +class TestNormalizeAction: + """Tests for the lenient normalize_action used on read paths.""" + + @pytest.mark.parametrize("action", ["deny", "steer", "observe"]) + def test_passes_canonical_actions(self, action: str) -> None: + # Given: a canonical action name + # When: normalizing + result = normalize_action(action) + + # Then: the same value is returned unchanged + assert result == action + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_normalizes_legacy_to_observe(self, legacy: str) -> None: + # Given: a legacy advisory action stored in a historical DB row + # When: normalizing on the read path + result = normalize_action(legacy) + + # Then: it maps to the canonical "observe" action + assert result == "observe" + + def test_rejects_unknown_action(self) -> None: + # Given: a completely unknown action name + # When / Then: normalization raises ValueError even on the lenient path + with pytest.raises(ValueError, match="Invalid action"): + normalize_action("block") + + +# --------------------------------------------------------------------------- +# ControlAction (API input boundary — strict) +# --------------------------------------------------------------------------- + + +class TestControlActionValidation: + """ControlAction.decision uses strict validation (rejects legacy values).""" + + @pytest.mark.parametrize("action", ["deny", "steer", "observe"]) + def test_accepts_canonical_actions(self, action: str) -> None: + # Given: a control action payload with a canonical decision + # When: validating via Pydantic + ca = ControlAction.model_validate({"decision": action}) + + # Then: the decision is accepted as-is + assert ca.decision == action + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_rejects_legacy_actions(self, legacy: str) -> None: + # Given: a control action payload using a legacy decision value + # When / Then: Pydantic validation rejects it at the API boundary + with pytest.raises(ValidationError, match="Invalid action"): + ControlAction.model_validate({"decision": legacy}) + + def test_rejects_unknown_action(self) -> None: + # Given: a control action payload with an unknown decision + # When / Then: Pydantic validation rejects it + with pytest.raises(ValidationError, match="Invalid action"): + ControlAction.model_validate({"decision": "block"}) + + +# --------------------------------------------------------------------------- +# EventQueryRequest.actions (API input boundary — strict) +# --------------------------------------------------------------------------- + + +class TestEventQueryRequestValidation: + """EventQueryRequest.actions uses strict validation.""" + + def test_accepts_canonical_actions(self) -> None: + # Given: a query filter with all three canonical action values + # When: constructing the query request + query = EventQueryRequest(actions=["deny", "steer", "observe"]) + + # Then: all actions are accepted + assert query.actions == ["deny", "steer", "observe"] + + def test_deduplicates_actions(self) -> None: + # Given: a query filter with duplicate action values + # When: constructing the query request + query = EventQueryRequest(actions=["deny", "deny", "observe"]) + + # Then: duplicates are removed while preserving order + assert query.actions == ["deny", "observe"] + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_rejects_legacy_actions(self, legacy: str) -> None: + # Given: a query filter using a legacy action value + # When / Then: Pydantic validation rejects it at the API boundary + with pytest.raises(ValidationError, match="Invalid action"): + EventQueryRequest(actions=[legacy]) + + def test_rejects_unknown_action(self) -> None: + # Given: a query filter with an unknown action + # When / Then: Pydantic validation rejects it + with pytest.raises(ValidationError, match="Invalid action"): + EventQueryRequest(actions=["block"]) + + +# --------------------------------------------------------------------------- +# ControlMatch / ControlExecutionEvent (read path — lenient normalization) +# --------------------------------------------------------------------------- + + +class TestReadPathNormalization: + """Internal read-path models normalize legacy values from DB rows.""" + + @pytest.mark.parametrize("legacy,expected", [ + ("allow", "observe"), + ("warn", "observe"), + ("log", "observe"), + ("observe", "observe"), + ("deny", "deny"), + ("steer", "steer"), + ]) + def test_control_match_normalizes_legacy(self, legacy: str, expected: str) -> None: + # Given: a ControlMatch deserialized from a DB row with a legacy action + # When: constructing the model + match = ControlMatch( + control_id=1, + control_name="test", + action=legacy, result=EvaluatorResult(matched=True, confidence=0.9), - ), - lambda: ControlExecutionEvent( - trace_id="trace-123", - span_id="span-123", + ) + + # Then: the action is normalized to the canonical value + assert match.action == expected + + @pytest.mark.parametrize("legacy,expected", [ + ("allow", "observe"), + ("warn", "observe"), + ("log", "observe"), + ("observe", "observe"), + ("deny", "deny"), + ("steer", "steer"), + ]) + def test_control_execution_event_normalizes_legacy( + self, legacy: str, expected: str + ) -> None: + # Given: a ControlExecutionEvent deserialized from a historical event row + # When: constructing the model + event = ControlExecutionEvent( + trace_id="4bf92f3577b34da6a3ce929d0e0e4736", + span_id="00f067aa0ba902b7", agent_name="test-agent", - control_id=123, - control_name="pii-check", + control_id=1, + control_name="test", check_stage="pre", applies_to="llm_call", - action=invalid_action, + action=legacy, matched=True, confidence=0.9, - ), - lambda: EventQueryRequest(actions=[invalid_action]), - ] + ) + + # Then: the action is normalized to the canonical value + assert event.action == expected - for build_invalid_model in invalid_builders: - # When / Then: validation fails before the invalid action can enter the system + def test_control_match_rejects_unknown(self) -> None: + # Given: a ControlMatch with a completely unknown action + # When / Then: validation rejects it even on the lenient read path with pytest.raises(ValidationError, match="Invalid action"): - build_invalid_model() + ControlMatch( + control_id=1, + control_name="test", + action="block", + result=EvaluatorResult(matched=True, confidence=0.9), + ) + + +# --------------------------------------------------------------------------- +# expand_action_filter (internal query expansion) +# --------------------------------------------------------------------------- + + +class TestExpandActionFilter: + """expand_action_filter expands canonical actions for SQL queries against historical data.""" + + def test_observe_expands_to_include_legacy(self) -> None: + # Given: a canonical "observe" filter + # When: expanding for SQL WHERE clause against historical events + expanded = expand_action_filter(["observe"]) + + # Then: it includes all legacy advisory action values stored in old rows + assert expanded == ["observe", "allow", "warn", "log"] + + def test_deny_and_steer_do_not_expand(self) -> None: + # Given: deny and steer filters (no legacy aliases) + # When: expanding + # Then: they map only to themselves + assert expand_action_filter(["deny"]) == ["deny"] + assert expand_action_filter(["steer"]) == ["steer"] + + def test_full_expansion(self) -> None: + # Given: all three canonical actions + # When: expanding + expanded = expand_action_filter(["deny", "steer", "observe"]) + + # Then: deny and steer are unchanged, observe expands to include legacy + assert expanded == ["deny", "steer", "observe", "allow", "warn", "log"] diff --git a/server/tests/test_observability_models.py b/server/tests/test_observability_models.py index deb30fa1..4a7d9252 100644 --- a/server/tests/test_observability_models.py +++ b/server/tests/test_observability_models.py @@ -307,10 +307,20 @@ def test_filter_by_trace_id(self): assert query.trace_id == "4bf92f3577b34da6a3ce929d0e0e4736" def test_filter_by_actions(self): - """Test filtering by actions.""" - query = EventQueryRequest(actions=["deny", "warn"]) + """Test filtering by canonical actions.""" + # Given: a query with canonical action filter values + query = EventQueryRequest(actions=["deny", "observe"]) + + # Then: the actions are accepted as-is assert query.actions == ["deny", "observe"] + def test_filter_by_actions_rejects_legacy(self): + """Test that legacy action values are rejected in query filters.""" + # Given: a query filter that includes the legacy "warn" value + # When / Then: validation rejects it at the API boundary + with pytest.raises(ValidationError, match="Invalid action"): + EventQueryRequest(actions=["deny", "warn"]) + def test_limit_bounds(self): """Test limit bounds.""" with pytest.raises(ValidationError):