diff --git a/models/src/agent_control_models/evaluation.py b/models/src/agent_control_models/evaluation.py index 07ab4810..458c91a5 100644 --- a/models/src/agent_control_models/evaluation.py +++ b/models/src/agent_control_models/evaluation.py @@ -127,8 +127,6 @@ class EvaluationResponse(BaseModel): default=None, description="List of controls that were evaluated but did not match (if any)", ) - - class EvaluationResult(EvaluationResponse): """ Client-side result model for evaluation analysis. diff --git a/sdks/python/src/agent_control/agents.py b/sdks/python/src/agent_control/agents.py index 3d63cc5e..f6ca8d58 100644 --- a/sdks/python/src/agent_control/agents.py +++ b/sdks/python/src/agent_control/agents.py @@ -16,7 +16,9 @@ async def register_agent( steps: list[dict[str, Any]] | None = None, conflict_mode: Literal["strict", "overwrite"] = "overwrite", ) -> dict[str, Any]: - """Register an agent with the server via /initAgent endpoint.""" + """Register an agent with the server via /initAgent endpoint. + + """ ensure_evaluators_discovered() agent_dict = agent.to_dict() @@ -27,7 +29,12 @@ async def register_agent( "conflict_mode": conflict_mode, } - response = await client.http_client.post("/api/v1/agents/initAgent", json=payload) + headers = None + response = await client.http_client.post( + "/api/v1/agents/initAgent", + json=payload, + headers=headers, + ) response.raise_for_status() return cast(dict[str, Any], response.json()) diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index eb8ff349..4237baf9 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -1,14 +1,12 @@ """Evaluation check operations for Agent Control SDK.""" from dataclasses import dataclass -from datetime import UTC, datetime from typing import Any, Literal, cast from agent_control_engine import list_evaluators from agent_control_engine.core import ControlEngine from agent_control_models import ( ControlDefinition, - ControlExecutionEvent, ControlMatch, EvaluationRequest, EvaluationResponse, @@ -19,139 +17,11 @@ from ._state import state from .client import AgentControlClient -from .observability import add_event, get_logger, is_observability_enabled +from .evaluation_events import build_control_execution_events, enqueue_observability_events +from .observability import is_observability_enabled from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name -_logger = get_logger(__name__) - -# Fallback IDs used when trace context is missing. -# All-zero values are invalid trace/span IDs per OpenTelemetry. -_FALLBACK_TRACE_ID = "0" * 32 -_FALLBACK_SPAN_ID = "0" * 16 -_trace_warning_logged = False - - -def _observability_metadata( - control_def: ControlDefinition, -) -> tuple[str | None, str | None, dict[str, object]]: - """Return representative event fields plus full composite context.""" - identity = control_def.observability_identity() - return ( - identity.selector_path, - identity.evaluator_name, - { - "primary_evaluator": identity.evaluator_name, - "primary_selector_path": identity.selector_path, - "leaf_count": identity.leaf_count, - "all_evaluators": identity.all_evaluators, - "all_selector_paths": identity.all_selector_paths, - }, - ) - - -def _map_applies_to(step_type: str) -> Literal["llm_call", "tool_call"]: - return "tool_call" if step_type == "tool" else "llm_call" - - -def _emit_local_events( - local_result: "EvaluationResponse", - request: "EvaluationRequest", - local_controls: list["_ControlAdapter"], - trace_id: str | None, - span_id: str | None, - agent_name: str | None, -) -> None: - """Emit observability events for locally-evaluated controls. - - Mirrors the server's _emit_observability_events() so that SDK-evaluated - controls are visible in the observability pipeline. - - When trace_id/span_id are missing, fallback all-zero IDs are used so events - are still recorded (but clearly marked as uncorrelated). - - Only runs when observability is enabled. - """ - if not is_observability_enabled(): - return - - global _trace_warning_logged # noqa: PLW0603 - if not trace_id or not span_id: - if not _trace_warning_logged: - _logger.warning( - "Emitting local control events without trace context; " - "events will use fallback IDs and cannot be correlated with traces. " - "Pass trace_id/span_id for full observability." - ) - _trace_warning_logged = True - trace_id = trace_id or _FALLBACK_TRACE_ID - span_id = span_id or _FALLBACK_SPAN_ID - - applies_to = _map_applies_to(request.step.type) - control_lookup = {c.id: c for c in local_controls} - now = datetime.now(UTC) - resolved_agent_name = agent_name or request.agent_name - - def _emit_matches(matches: list[ControlMatch] | None, matched: bool) -> None: - if not matches: - return - for match in matches: - ctrl = control_lookup.get(match.control_id) - event_metadata = dict(match.result.metadata or {}) - selector_path = None - evaluator_name = None - if ctrl: - selector_path, evaluator_name, identity_metadata = _observability_metadata( - ctrl.control - ) - event_metadata.update(identity_metadata) - add_event( - ControlExecutionEvent( - control_execution_id=match.control_execution_id, - trace_id=trace_id, - span_id=span_id, - agent_name=resolved_agent_name, - control_id=match.control_id, - control_name=match.control_name, - check_stage=request.stage, - applies_to=applies_to, - action=match.action, - matched=matched, - confidence=match.result.confidence, - timestamp=now, - evaluator_name=evaluator_name, - selector_path=selector_path, - error_message=match.result.error if not matched else None, - metadata=event_metadata, - ) - ) - - _emit_matches(local_result.matches, matched=True) - _emit_matches(local_result.errors, matched=False) - _emit_matches(local_result.non_matches, matched=False) - - -async def check_evaluation( - client: AgentControlClient, - agent_name: str, - step: "Step", - stage: Literal["pre", "post"], -) -> EvaluationResult: - """Check if agent interaction is safe.""" - normalized_name = ensure_agent_name(agent_name) - - request = EvaluationRequest( - agent_name=normalized_name, - step=step, - stage=stage, - ) - request_payload = request.model_dump(mode="json") - - response = await client.http_client.post("/api/v1/evaluation", json=request_payload) - response.raise_for_status() - - return cast(EvaluationResult, EvaluationResult.from_dict(response.json())) - @dataclass class _ControlAdapter: @@ -159,7 +29,7 @@ class _ControlAdapter: id: int name: str - control: "ControlDefinition" + control: ControlDefinition def _get_applicable_controls( @@ -176,19 +46,26 @@ def _get_applicable_controls( return cast(list[_ControlAdapter], applicable_controls) +def _build_server_control_lookup( + server_control_payloads: list[dict[str, Any]], +) -> dict[int, ControlDefinition]: + """Build a best-effort lookup of server control definitions.""" + control_lookup: dict[int, ControlDefinition] = {} + + for control in server_control_payloads: + try: + control_lookup[control["id"]] = ControlDefinition.model_validate(control["control"]) + except Exception: + continue + + return control_lookup + + def _has_applicable_prefiltered_server_controls( server_control_payloads: list[dict[str, Any]], request: EvaluationRequest, ) -> bool: - """Return whether any partitioned server control applies to this request. - - The caller is responsible for partitioning raw control payloads by - ``execution`` before calling this helper. This function only inspects the - server-control subset and does not re-check ``execution`` itself. - - If any server control payload cannot be parsed locally, this returns True so - the SDK still defers to the server for authoritative handling. - """ + """Return whether any partitioned server control applies to this request.""" parsed_server_controls: list[_ControlAdapter] = [] for control in server_control_payloads: @@ -202,7 +79,6 @@ def _has_applicable_prefiltered_server_controls( ) ) except Exception: - # Preserve existing fail-open behavior for malformed server controls. return True if not parsed_server_controls: @@ -218,10 +94,10 @@ def _has_applicable_prefiltered_server_controls( def _merge_results( - local_result: "EvaluationResponse", - server_result: "EvaluationResponse", -) -> "EvaluationResult": - """Merge local and server evaluation results.""" + local_result: EvaluationResponse, + server_result: EvaluationResponse, +) -> EvaluationResult: + """Merge local and server evaluation results into one SDK-facing result.""" is_safe = local_result.is_safe and server_result.is_safe confidence = min(local_result.confidence, server_result.confidence) @@ -255,42 +131,77 @@ def _merge_results( ) +def _cached_server_control_lookup( + agent_name: str, + client: AgentControlClient, +) -> dict[int, ControlDefinition]: + """Return cached server controls for the active session when they are trustworthy.""" + current_agent = state.current_agent + if current_agent is None or current_agent.agent_name != agent_name: + return {} + if state.server_controls is None: + return {} + if state.server_url is not None: + if client.base_url.rstrip("/") != state.server_url.rstrip("/"): + return {} + return _build_server_control_lookup(state.server_controls) + + +async def check_evaluation( + client: AgentControlClient, + agent_name: str, + step: Step, + stage: Literal["pre", "post"], +) -> EvaluationResult: + """Check if agent interaction is safe through the public SDK helper. + + The server returns only evaluation semantics. When SDK observability is + enabled, this helper reconstructs server-side control-execution events + from the response and enqueues them through the built-in SDK batcher. + """ + normalized_name = ensure_agent_name(agent_name) + resolved_trace_id, resolved_span_id = get_trace_and_span_ids() + request = EvaluationRequest( + agent_name=normalized_name, + step=step, + stage=stage, + ) + request_payload = request.model_dump(mode="json") + + response = await client.http_client.post( + "/api/v1/evaluation", + json=request_payload, + headers=None, + ) + response.raise_for_status() + + evaluation_response = EvaluationResponse.model_validate(response.json()) + + if is_observability_enabled(): + server_events = build_control_execution_events( + evaluation_response, + request, + _cached_server_control_lookup(normalized_name, client), + resolved_trace_id, + resolved_span_id, + normalized_name, + ) + enqueue_observability_events(server_events) + + return cast(EvaluationResult, EvaluationResult.from_dict(evaluation_response.model_dump())) + + async def check_evaluation_with_local( client: AgentControlClient, agent_name: str, - step: "Step", + step: Step, stage: Literal["pre", "post"], controls: list[dict[str, Any]], trace_id: str | None = None, span_id: str | None = None, event_agent_name: str | None = None, ) -> EvaluationResult: - """ - Check if agent interaction is safe, running local controls first. - - This function executes controls with execution="sdk" locally in the SDK, - then calls the server for execution="server" controls. If a local control - denies, it short-circuits and returns immediately without calling the server. - - Note on parse errors: If a local control fails to parse/validate, it is - skipped (logged as WARNING) and the error is included in result.errors. - This does NOT affect is_safe or confidence—callers concerned with safety - should check result.errors for any parse failures. - - Args: - client: AgentControlClient instance - agent_name: Normalized agent identifier - step: Step payload to evaluate - stage: 'pre' for pre-execution check, 'post' for post-execution check - controls: List of control dicts from initAgent response - (each has 'id', 'name', 'control' keys) - - Returns: - EvaluationResult with safety analysis (merged from local + server) - - Raises: - httpx.HTTPError: If server request fails - """ + """Evaluate controls with local-first execution and SDK-owned event emission.""" normalized_name = ensure_agent_name(agent_name) resolved_trace_id = trace_id resolved_span_id = span_id @@ -299,7 +210,6 @@ async def check_evaluation_with_local( resolved_trace_id = trace_id or current_trace_id resolved_span_id = span_id or current_span_id - # Partition controls by local flag local_controls: list[_ControlAdapter] = [] parse_errors: list[ControlMatch] = [] available_evaluators = list_evaluators() @@ -344,12 +254,6 @@ async def check_evaluation_with_local( except Exception as exc: control_id = control.get("id", -1) control_name = control.get("name", "unknown") - _logger.warning( - "Skipping invalid local control '%s' (id=%s): %s", - control_name, - control_id, - exc, - ) parse_errors.append( ControlMatch( control_id=control_id, @@ -374,16 +278,12 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if not parse_errors: return result combined_errors = (result.errors or []) + parse_errors - return EvaluationResult( - is_safe=result.is_safe, - confidence=result.confidence, - reason=result.reason, - matches=result.matches, - errors=combined_errors, - non_matches=result.non_matches, - ) + return result.model_copy(update={"errors": combined_errors}) + + should_emit_events = is_observability_enabled() local_result: EvaluationResponse | None = None + local_events = [] applicable_local_controls = _get_applicable_controls( local_controls, request, @@ -392,27 +292,24 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if applicable_local_controls: engine = ControlEngine(applicable_local_controls, context="sdk") local_result = await engine.process(request) - - _emit_local_events( - local_result, - request, - applicable_local_controls, - resolved_trace_id, - resolved_span_id, - agent_name=event_agent_name, - ) + if should_emit_events: + local_control_lookup = { + control.id: control.control for control in applicable_local_controls + } + local_events = build_control_execution_events( + local_result, + request, + local_control_lookup, + resolved_trace_id, + resolved_span_id, + event_agent_name, + ) if not local_result.is_safe: - return _with_parse_errors( - EvaluationResult( - is_safe=local_result.is_safe, - confidence=local_result.confidence, - reason=local_result.reason, - matches=local_result.matches, - errors=local_result.errors, - non_matches=local_result.non_matches, - ) - ) + result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) + if should_emit_events: + enqueue_observability_events(local_events) + return result if _has_applicable_prefiltered_server_controls(server_control_payloads, request): request_payload = request.model_dump(mode="json", exclude_none=True) @@ -422,39 +319,47 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if resolved_span_id: headers["X-Span-Id"] = resolved_span_id - response = await client.http_client.post( - "/api/v1/evaluation", - json=request_payload, - headers=headers, - ) - response.raise_for_status() - server_result = EvaluationResponse.model_validate(response.json()) + try: + response = await client.http_client.post( + "/api/v1/evaluation", + json=request_payload, + headers=headers, + ) + response.raise_for_status() + server_result = EvaluationResponse.model_validate(response.json()) + except Exception: + if should_emit_events and local_events: + enqueue_observability_events(local_events) + raise - if local_result is not None: - return _with_parse_errors(_merge_results(local_result, server_result)) - - return _with_parse_errors( - EvaluationResult( - is_safe=server_result.is_safe, - confidence=server_result.confidence, - reason=server_result.reason, - matches=server_result.matches, - errors=server_result.errors, - non_matches=server_result.non_matches, + server_events = [] + if should_emit_events: + server_control_lookup = _build_server_control_lookup(server_control_payloads) + server_events = build_control_execution_events( + server_result, + request, + server_control_lookup, + resolved_trace_id, + resolved_span_id, + event_agent_name, ) - ) + + if local_result is not None: + result = _with_parse_errors(_merge_results(local_result, server_result)) + if should_emit_events: + enqueue_observability_events(local_events + server_events) + return result + + result = _with_parse_errors(EvaluationResult.model_validate(server_result.model_dump())) + if should_emit_events: + enqueue_observability_events(server_events) + return result if local_result is not None: - return _with_parse_errors( - EvaluationResult( - is_safe=local_result.is_safe, - confidence=local_result.confidence, - reason=local_result.reason, - matches=local_result.matches, - errors=local_result.errors, - non_matches=local_result.non_matches, - ) - ) + result = _with_parse_errors(EvaluationResult.model_validate(local_result.model_dump())) + if should_emit_events: + enqueue_observability_events(local_events) + return result return _with_parse_errors(EvaluationResult(is_safe=True, confidence=1.0)) @@ -471,58 +376,10 @@ async def evaluate_controls( trace_id: str | None = None, span_id: str | None = None, ) -> EvaluationResult: - """ - Evaluate controls for a step. - - This convenience function evaluates controls (both local SDK-executed and - server-executed) for a given step. - - Args: - step_name: Name of the step (e.g., "chat", "search_db") - input: Input data for the step (for pre-stage evaluation) - output: Output data from the step (for post-stage evaluation) - context: Additional context metadata - step_type: Type of step - "llm" or "tool" (default: "llm") - stage: When to evaluate - "pre" or "post" (default: "pre") - agent_name: Agent name (required) - trace_id: Optional OpenTelemetry trace ID for observability - span_id: Optional OpenTelemetry span ID for observability - - Returns: - EvaluationResult with is_safe, confidence, reason, matches, errors - - Raises: - httpx.HTTPError: If server request fails - - Example: - import agent_control - - # Evaluate controls for an agent - result = await agent_control.evaluate_controls( - "chat", - input="User message here", - stage="pre", - agent_name="customer-service-bot" - ) - - # With trace/span IDs for observability - result = await agent_control.evaluate_controls( - "chat", - input="User message", - stage="pre", - agent_name="customer-service-bot", - trace_id="4bf92f3577b34da6a3ce929d0e0e4736", - span_id="00f067aa0ba902b7" - ) - """ - # Ensure server_url is set (for mypy type narrowing) + """Evaluate controls for a step.""" if state.server_url is None: - raise RuntimeError( - "Server URL not configured. Call agent_control.init() first." - ) + raise RuntimeError("Server URL not configured. Call agent_control.init() first.") - # Build Step dict (input and output are required by Step model) - # Tool steps require dict input/output, LLM steps use strings default_value = {} if step_type == "tool" else "" step_dict: dict[str, Any] = { "type": step_type, @@ -533,15 +390,11 @@ async def evaluate_controls( if context is not None: step_dict["context"] = context - # Convert to Step object if models available - step_obj = Step(**step_dict) # type: ignore - - # Get controls from server cache + step_obj = Step(**step_dict) # type: ignore[arg-type] resolved_controls = state.server_controls or [] - # Evaluate using local + server controls async with AgentControlClient(base_url=state.server_url, api_key=state.api_key) as client: - result = await check_evaluation_with_local( + return await check_evaluation_with_local( client=client, agent_name=agent_name, step=step_obj, @@ -551,5 +404,3 @@ async def evaluate_controls( span_id=span_id, event_agent_name=agent_name, ) - - return result diff --git a/sdks/python/src/agent_control/evaluation_events.py b/sdks/python/src/agent_control/evaluation_events.py new file mode 100644 index 00000000..0c3f9438 --- /dev/null +++ b/sdks/python/src/agent_control/evaluation_events.py @@ -0,0 +1,214 @@ +"""Derived control-execution event reconstruction for SDK evaluation flows.""" + +from datetime import UTC, datetime +from typing import Literal + +from agent_control_models import ( + ControlDefinition, + ControlExecutionEvent, + ControlMatch, + EvaluationRequest, + EvaluationResponse, +) + +from .observability import add_event, get_logger, is_observability_enabled + +_logger = get_logger(__name__) + +# All-zero values are invalid trace/span IDs per OpenTelemetry and make it +# obvious that the event could not be correlated to an external trace. +_FALLBACK_TRACE_ID = "0" * 32 +_FALLBACK_SPAN_ID = "0" * 16 +_trace_warning_logged = False + + +def observability_metadata( + control_def: ControlDefinition, +) -> tuple[str | None, str | None, dict[str, object]]: + """Return representative event fields plus full composite context.""" + identity = control_def.observability_identity() + return ( + identity.selector_path, + identity.evaluator_name, + { + "primary_evaluator": identity.evaluator_name, + "primary_selector_path": identity.selector_path, + "leaf_count": identity.leaf_count, + "all_evaluators": identity.all_evaluators, + "all_selector_paths": identity.all_selector_paths, + }, + ) + + +def map_applies_to(step_type: str) -> Literal["llm_call", "tool_call"]: + """Map Agent Control step types to observability applies_to values.""" + return "tool_call" if step_type == "tool" else "llm_call" + + +def _resolve_event_trace_context( + trace_id: str | None, + span_id: str | None, +) -> tuple[str, str]: + """Return event IDs, applying fallback IDs and a one-time warning if needed.""" + global _trace_warning_logged # noqa: PLW0603 + + if trace_id and span_id: + return trace_id, span_id + + if not _trace_warning_logged: + _logger.warning( + "Emitting control events without trace context; events will use fallback " + "IDs and cannot be correlated with traces. Pass trace_id/span_id for " + "full observability." + ) + _trace_warning_logged = True + + return trace_id or _FALLBACK_TRACE_ID, span_id or _FALLBACK_SPAN_ID + + +def _build_events_for_matches( + matches: list[ControlMatch] | None, + *, + matched: bool, + include_error_message: bool, + request: EvaluationRequest, + control_lookup: dict[int, ControlDefinition], + trace_id: str, + span_id: str, + agent_name: str, + now: datetime, +) -> list[ControlExecutionEvent]: + if not matches: + return [] + + applies_to = map_applies_to(request.step.type) + events: list[ControlExecutionEvent] = [] + + for match in matches: + control_def = control_lookup.get(match.control_id) + event_metadata = dict(match.result.metadata or {}) + selector_path = None + evaluator_name = None + + if control_def is not None: + selector_path, evaluator_name, identity_metadata = observability_metadata(control_def) + event_metadata.update(identity_metadata) + + events.append( + ControlExecutionEvent( + control_execution_id=match.control_execution_id, + trace_id=trace_id, + span_id=span_id, + agent_name=agent_name, + control_id=match.control_id, + control_name=match.control_name, + check_stage=request.stage, + applies_to=applies_to, + action=match.action, + matched=matched, + confidence=match.result.confidence, + timestamp=now, + evaluator_name=evaluator_name, + selector_path=selector_path, + error_message=match.result.error if include_error_message else None, + metadata=event_metadata, + ) + ) + + return events + + +def build_control_execution_events( + response: EvaluationResponse, + request: EvaluationRequest, + control_lookup: dict[int, ControlDefinition], + trace_id: str | None, + span_id: str | None, + agent_name: str | None, +) -> list[ControlExecutionEvent]: + """Reconstruct control execution events from an evaluation response. + + This is the shared reconstruction step used by both supported event + creation styles: + - the default SDK observability path, where reconstructed local events are + queued into the existing SDK batcher + - the merged-event path, where local and server events are reconstructed in + the SDK and queued together through the existing SDK batcher + + Args: + response: Evaluation response containing matches, errors, and + non-matches. + request: Original evaluation request used to derive stage and + ``applies_to``. + control_lookup: Parsed controls keyed by control ID. + trace_id: Optional trace ID for correlation. + span_id: Optional span ID for correlation. + agent_name: Optional override for the agent name stamped on events. + + Returns: + A list of reconstructed ``ControlExecutionEvent`` objects. + """ + resolved_trace_id, resolved_span_id = _resolve_event_trace_context(trace_id, span_id) + resolved_agent_name = agent_name or request.agent_name + now = datetime.now(UTC) + + events: list[ControlExecutionEvent] = [] + events.extend( + _build_events_for_matches( + response.matches, + matched=True, + include_error_message=True, + request=request, + control_lookup=control_lookup, + trace_id=resolved_trace_id, + span_id=resolved_span_id, + agent_name=resolved_agent_name, + now=now, + ) + ) + events.extend( + _build_events_for_matches( + response.errors, + matched=False, + include_error_message=True, + request=request, + control_lookup=control_lookup, + trace_id=resolved_trace_id, + span_id=resolved_span_id, + agent_name=resolved_agent_name, + now=now, + ) + ) + events.extend( + _build_events_for_matches( + response.non_matches, + matched=False, + include_error_message=False, + request=request, + control_lookup=control_lookup, + trace_id=resolved_trace_id, + span_id=resolved_span_id, + agent_name=resolved_agent_name, + now=now, + ) + ) + return events + + +def enqueue_observability_events(events: list[ControlExecutionEvent]) -> None: + """Enqueue reconstructed events through the existing SDK observability path. + + This preserves the built-in SDK behavior of forwarding events through the + existing observability batcher. + + Args: + events: Reconstructed control execution events to enqueue. + + Returns: + None. + """ + if not is_observability_enabled(): + return + + for event in events: + add_event(event) diff --git a/sdks/python/src/agent_control/telemetry/__init__.py b/sdks/python/src/agent_control/telemetry/__init__.py index 8d2ccf90..c488d4a2 100644 --- a/sdks/python/src/agent_control/telemetry/__init__.py +++ b/sdks/python/src/agent_control/telemetry/__init__.py @@ -1,5 +1,4 @@ """Telemetry interfaces for provider-agnostic tracing.""" - from .trace_context import ( TraceContext, TraceContextProvider, diff --git a/sdks/python/tests/test_evaluation.py b/sdks/python/tests/test_evaluation.py index e9842313..4c7a647b 100644 --- a/sdks/python/tests/test_evaluation.py +++ b/sdks/python/tests/test_evaluation.py @@ -66,6 +66,7 @@ def json(self) -> dict[str, object]: }, "stage": "pre", }, + headers=None, ) diff --git a/sdks/python/tests/test_init_step_merge.py b/sdks/python/tests/test_init_step_merge.py index 669e1236..f423593f 100644 --- a/sdks/python/tests/test_init_step_merge.py +++ b/sdks/python/tests/test_init_step_merge.py @@ -20,9 +20,11 @@ class DoesNotExist: ... @pytest.fixture(autouse=True) def _clean_registry() -> Generator[None, None, None]: """Ensure each test starts with an empty step registry.""" + agent_control._reset_state() clear() yield clear() + agent_control._reset_state() def test_init_passes_merged_steps_to_register_agent( @@ -189,6 +191,34 @@ def test_init_logs_agent_updated_when_registration_already_exists( assert agent_name in caplog.text +def test_init_registers_agent_without_merge_events_arg() -> None: + register_agent_mock = AsyncMock(return_value={"created": True, "controls": []}) + health_check_mock = AsyncMock(return_value={"status": "healthy"}) + + with patch( + "agent_control.__init__.AgentControlClient.health_check", + new=health_check_mock, + ), patch( + "agent_control.__init__.agents.register_agent", + new=register_agent_mock, + ): + agent_control.init( + agent_name=f"agent-{uuid4().hex[:12]}", + policy_refresh_interval_seconds=0, + ) + + assert register_agent_mock.await_args is not None + assert "merge_events" not in register_agent_mock.await_args.kwargs + + +def test_init_omits_merge_events_from_public_signature() -> None: + import inspect + + signature = inspect.signature(agent_control.init) + + assert "merge_events" not in signature.parameters + + @pytest.mark.asyncio async def test_refresh_controls_calls_agent_controls_endpoint() -> None: # Given: an initialized SDK agent session with network-facing calls mocked. diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index cb792987..a90ea785 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -1,4 +1,4 @@ -"""Tests for observability updates: event emission, non_matches propagation, applies_to mapping.""" +"""Tests for reconstructed control-execution events in SDK evaluation flows.""" from unittest.mock import AsyncMock, MagicMock, patch @@ -6,47 +6,34 @@ from agent_control import evaluation from agent_control.evaluation import ( _ControlAdapter, - _emit_local_events, - _map_applies_to, + _build_server_control_lookup, + _has_applicable_prefiltered_server_controls, _merge_results, ) +from agent_control.evaluation_events import ( + build_control_execution_events, + enqueue_observability_events, + map_applies_to, +) from agent_control.telemetry.trace_context import ( clear_trace_context_provider, set_trace_context_provider, ) from agent_control_models import ControlDefinition -# ============================================================================= -# _map_applies_to tests -# ============================================================================= - class TestMapAppliesTo: - """Tests for _map_applies_to helper.""" - def test_maps_tool_to_tool_call(self): - assert _map_applies_to("tool") == "tool_call" + assert map_applies_to("tool") == "tool_call" def test_maps_llm_to_llm_call(self): - assert _map_applies_to("llm") == "llm_call" - - def test_maps_unknown_to_llm_call(self): - """Unknown types default to llm_call (matches server pattern).""" - assert _map_applies_to("unknown") == "llm_call" - assert _map_applies_to("") == "llm_call" - - -# ============================================================================= -# _merge_results tests -# ============================================================================= + assert map_applies_to("llm") == "llm_call" class TestMergeResults: - """Tests for _merge_results combining non_matches.""" - def _make_response(self, **kwargs): - """Create a mock EvaluationResponse.""" from agent_control_models import EvaluationResponse + defaults = { "is_safe": True, "confidence": 1.0, @@ -60,6 +47,7 @@ def _make_response(self, **kwargs): def _make_match(self, control_id, control_name="ctrl", action="observe", matched=True): from agent_control_models import ControlMatch, EvaluatorResult + return ControlMatch( control_id=control_id, control_name=control_name, @@ -67,89 +55,94 @@ def _make_match(self, control_id, control_name="ctrl", action="observe", matched result=EvaluatorResult(matched=matched, confidence=0.9), ) - def test_combines_non_matches(self): - """non_matches from both sides should be combined.""" - nm1 = self._make_match(1, "ctrl-1", matched=False) - nm2 = self._make_match(2, "ctrl-2", matched=False) - - local = self._make_response(non_matches=[nm1]) - server = self._make_response(non_matches=[nm2]) + def test_combines_matches_errors_and_non_matches(self): + local = self._make_response( + matches=[self._make_match(1)], + errors=[self._make_match(2, matched=False)], + ) + server = self._make_response(non_matches=[self._make_match(3, matched=False)]) result = _merge_results(local, server) - assert result.non_matches is not None - assert len(result.non_matches) == 2 - ids = {nm.control_id for nm in result.non_matches} - assert ids == {1, 2} - def test_non_matches_none_when_both_empty(self): - local = self._make_response() - server = self._make_response() - result = _merge_results(local, server) - assert result.non_matches is None + assert [match.control_id for match in result.matches or []] == [1] + assert [match.control_id for match in result.errors or []] == [2] + assert [match.control_id for match in result.non_matches or []] == [3] - def test_non_matches_from_one_side(self): - nm = self._make_match(1, matched=False) - local = self._make_response(non_matches=[nm]) - server = self._make_response() - result = _merge_results(local, server) - assert result.non_matches is not None - assert len(result.non_matches) == 1 - def test_still_combines_matches_and_errors(self): - m1 = self._make_match(1, "m1") - m2 = self._make_match(2, "m2") - e1 = self._make_match(3, "e1", matched=False) +class TestEvaluationHelpers: + def test_build_server_control_lookup_skips_unparseable_controls(self): + lookup = _build_server_control_lookup( + [ + { + "id": 1, + "name": "ctrl-1", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "observe"}, + "execution": "server", + }, + }, + { + "id": 2, + "name": "ctrl-2", + "control": { + "condition": {"selector": {"path": "input"}}, + "action": {"decision": "observe"}, + "execution": "server", + }, + }, + ] + ) - local = self._make_response(matches=[m1], errors=[e1]) - server = self._make_response(matches=[m2]) + assert list(lookup.keys()) == [1] - result = _merge_results(local, server) - assert len(result.matches) == 2 - assert len(result.errors) == 1 + def test_has_applicable_prefiltered_server_controls_returns_true_for_malformed_payload(self): + from agent_control_models import EvaluationRequest + request = EvaluationRequest( + agent_name="agent-000000000001", + step={"type": "llm", "name": "test-step", "input": "hello"}, + stage="pre", + ) -# ============================================================================= -# _emit_local_events tests -# ============================================================================= + assert _has_applicable_prefiltered_server_controls( + [ + { + "id": 1, + "name": "bad-server-ctrl", + "control": { + "condition": {"selector": {"path": "input"}}, + "action": {"decision": "observe"}, + "execution": "server", + }, + } + ], + request, + ) is True -class TestEmitLocalEvents: - """Tests for _emit_local_events helper.""" - def _make_control_adapter(self, id, name, evaluator_name="regex", selector_path="input"): - """Create a _ControlAdapter for testing.""" - control_def = ControlDefinition( - execution="sdk", - condition={ - "evaluator": {"name": evaluator_name, "config": {"pattern": "test"}}, - "selector": {"path": selector_path}, - }, - action={"decision": "deny"}, - ) - return _ControlAdapter(id=id, name=name, control=control_def) - def _make_response(self, matches=None, errors=None, non_matches=None): - from agent_control_models import EvaluationResponse - return EvaluationResponse( - is_safe=not bool(matches), - confidence=1.0 if not matches else 0.5, - matches=matches, - errors=errors, - non_matches=non_matches, - ) - def _make_match(self, control_id, control_name="ctrl", action="deny", matched=True): - from agent_control_models import ControlMatch, EvaluatorResult - return ControlMatch( - control_id=control_id, - control_name=control_name, - action=action, - result=EvaluatorResult(matched=matched, confidence=0.9), + +class TestBuildControlExecutionEvents: + def _make_control(self, id, name, condition): + return _ControlAdapter( + id=id, + name=name, + control=ControlDefinition( + execution="sdk", + condition=condition, + action={"decision": "allow"}, + ), ) def _make_request(self, step_type="llm"): from agent_control_models import EvaluationRequest - # Tool steps require object input, LLM steps accept string + step_input = {"query": "hello"} if step_type == "tool" else "hello" return EvaluationRequest( agent_name="agent-000000000001", @@ -157,111 +150,66 @@ def _make_request(self, step_type="llm"): stage="pre", ) - def test_emits_events_when_observability_enabled(self): - """Should call add_event for each match/error/non_match.""" - from agent_control.evaluation import _emit_local_events + def _make_match(self, control_id, control_name="ctrl", action="allow", matched=True): + from agent_control_models import ControlMatch, EvaluatorResult - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - non_match = self._make_match(2, "ctrl-2", matched=False) - response = self._make_response(matches=[match], non_matches=[non_match]) - request = self._make_request() + return ControlMatch( + control_id=control_id, + control_name=control_name, + action=action, + result=EvaluatorResult(matched=matched, confidence=0.9), + ) - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, - [ctrl, self._make_control_adapter(2, "ctrl-2")], - "trace123", "span456", "test-agent", - ) - assert mock_add.call_count == 2 - # Verify event fields for the match - event = mock_add.call_args_list[0][0][0] - assert event.trace_id == "trace123" - assert event.span_id == "span456" - assert event.agent_name == "test-agent" - assert event.matched is True - assert event.evaluator_name == "regex" - assert event.selector_path == "input" - - def test_skips_when_observability_disabled(self): - """Should not call add_event when observability is disabled.""" - from agent_control.evaluation import _emit_local_events - - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) - request = self._make_request() + def _make_response(self, matches=None, errors=None, non_matches=None): + from agent_control_models import EvaluationResponse - with patch("agent_control.evaluation.is_observability_enabled", return_value=False), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, [ctrl], - "trace123", "span456", "test-agent", - ) - mock_add.assert_not_called() - - def test_maps_tool_step_to_tool_call(self): - """Should set applies_to='tool_call' for tool steps.""" - from agent_control.evaluation import _emit_local_events - - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) - request = self._make_request(step_type="tool") - - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, request, [ctrl], - "trace123", "span456", "test-agent", - ) - event = mock_add.call_args_list[0][0][0] - assert event.applies_to == "tool_call" - - def test_uses_fallback_ids_when_trace_context_missing(self): - """Should emit events with all-zero fallback IDs when trace context is absent.""" - import agent_control.evaluation as eval_mod - from agent_control.evaluation import ( - _FALLBACK_SPAN_ID, - _FALLBACK_TRACE_ID, - _emit_local_events, + return EvaluationResponse( + is_safe=not bool(matches), + confidence=1.0 if not matches else 0.5, + matches=matches, + errors=errors, + non_matches=non_matches, ) - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) + def test_builds_events_with_trace_context(self): + response = self._make_response(matches=[self._make_match(1, "ctrl-1")]) request = self._make_request() + control_lookup = { + 1: self._make_control( + 1, + "ctrl-1", + { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + ).control + } - # Reset the once-only warning flag so the warning fires in this test - eval_mod._trace_warning_logged = False + events = build_control_execution_events( + response, + request, + control_lookup, + "trace123", + "span456", + "test-agent", + ) - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add, \ - patch("agent_control.evaluation._logger") as mock_logger: - _emit_local_events( - response, request, [ctrl], - None, None, "test-agent", - ) - assert mock_add.call_count == 1 - event = mock_add.call_args_list[0][0][0] - assert event.trace_id == _FALLBACK_TRACE_ID - assert event.span_id == _FALLBACK_SPAN_ID - assert event.trace_id == "0" * 32 - assert event.span_id == "0" * 16 - # Warning should have been logged - mock_logger.warning.assert_called_once() - assert "fallback" in mock_logger.warning.call_args[0][0].lower() - - def test_composite_control_emits_representative_leaf_metadata(self): - """Composite local controls should emit stable representative metadata.""" - # Given: a composite local control and a non-match response for that control - ctrl = _ControlAdapter( - id=1, - name="composite-ctrl", - control=ControlDefinition( - execution="sdk", - condition={ + assert len(events) == 1 + event = events[0] + assert event.trace_id == "trace123" + assert event.span_id == "span456" + assert event.agent_name == "test-agent" + assert event.evaluator_name == "regex" + assert event.selector_path == "input" + + def test_composite_control_uses_representative_observability_identity(self): + response = self._make_response(non_matches=[self._make_match(1, "ctrl-1", matched=False)]) + request = self._make_request() + control_lookup = { + 1: self._make_control( + 1, + "ctrl-1", + { "and": [ { "selector": {"path": "input"}, @@ -273,27 +221,20 @@ def test_composite_control_emits_representative_leaf_metadata(self): }, ] }, - action={"decision": "observe"}, - ), - ) - non_match = self._make_match(1, "composite-ctrl", action="observe", matched=False) - response = self._make_response(non_matches=[non_match]) - request = self._make_request() + ).control + } - # When: emitting local observability events - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event") as mock_add: - _emit_local_events( - response, - request, - [ctrl], - "trace123", - "span456", - "test-agent", - ) - event = mock_add.call_args_list[0][0][0] + events = build_control_execution_events( + response, + request, + control_lookup, + "trace123", + "span456", + "test-agent", + ) - # Then: the first leaf becomes the event identity and full context is preserved + assert len(events) == 1 + event = events[0] assert event.evaluator_name == "regex" assert event.selector_path == "input" assert event.metadata["primary_evaluator"] == "regex" @@ -302,52 +243,102 @@ def test_composite_control_emits_representative_leaf_metadata(self): assert event.metadata["all_evaluators"] == ["regex"] assert event.metadata["all_selector_paths"] == ["input", "output"] - def test_fallback_warning_logged_only_once(self): - """The missing-trace-context warning should fire only on the first call.""" - import agent_control.evaluation as eval_mod - from agent_control.evaluation import _emit_local_events + def test_preserves_error_message_parity_by_result_category(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult - ctrl = self._make_control_adapter(1, "ctrl-1") - match = self._make_match(1, "ctrl-1") - response = self._make_response(matches=[match]) request = self._make_request() + control_lookup = { + 1: self._make_control( + 1, + "ctrl-1", + { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + ).control + } + response = EvaluationResponse( + is_safe=False, + confidence=0.5, + matches=[ + ControlMatch( + control_id=1, + control_name="ctrl-1", + action="allow", + result=EvaluatorResult( + matched=True, + confidence=0.9, + metadata={"server_error_message": "match-error"}, + ), + ) + ], + errors=[ + ControlMatch( + control_id=1, + control_name="ctrl-1", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.2, error="eval-error"), + ) + ], + non_matches=[ + ControlMatch( + control_id=1, + control_name="ctrl-1", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.1, error="ignored-error"), + ) + ], + ) - eval_mod._trace_warning_logged = False + events = build_control_execution_events( + response, + request, + control_lookup, + "trace123", + "span456", + "test-agent", + ) - with patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ - patch("agent_control.evaluation.add_event"), \ - patch("agent_control.evaluation._logger") as mock_logger: - _emit_local_events(response, request, [ctrl], None, None, "agent-test-a1") - _emit_local_events(response, request, [ctrl], None, None, "agent-test-a1") - assert mock_logger.warning.call_count == 1 + assert events[0].error_message is None + assert events[1].error_message == "eval-error" + assert events[2].error_message is None + def test_enqueue_observability_events_uses_existing_batcher(self): + from agent_control_models import ControlExecutionEvent -# ============================================================================= -# check_evaluation_with_local event emission + header forwarding -# ============================================================================= + events = [ + ControlExecutionEvent( + trace_id="a" * 32, + span_id="b" * 16, + agent_name="agent-000000000001", + control_id=1, + control_name="ctrl-1", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=1.0, + ) + ] + with patch("agent_control.evaluation_events.is_observability_enabled", return_value=True), \ + patch("agent_control.evaluation_events.add_event") as mock_add: + enqueue_observability_events(events) + + mock_add.assert_called_once_with(events[0]) -class TestCheckEvaluationWithLocal: - """Tests for check_evaluation_with_local event emission and non_matches.""" +class TestCheckEvaluationWithLocal: def teardown_method(self) -> None: clear_trace_context_provider() @pytest.mark.asyncio - async def test_emits_events_when_trace_context_provided(self): - """Should emit observability events when trace_id and span_id are passed.""" - from agent_control_models import ( - ControlMatch, - EvaluationResponse, - EvaluatorResult, - Step, - ) + async def test_delivers_local_events_in_oss_mode(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step mock_response = EvaluationResponse( is_safe=True, confidence=1.0, - matches=None, - errors=None, non_matches=[ ControlMatch( control_id=1, @@ -357,7 +348,6 @@ async def test_emits_events_when_trace_context_provided(self): ) ], ) - mock_engine = MagicMock() mock_engine.process = AsyncMock(return_value=mock_response) @@ -380,7 +370,8 @@ async def test_emits_events_when_trace_context_provided(self): with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation._emit_local_events") as mock_emit: + patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", @@ -391,30 +382,32 @@ async def test_emits_events_when_trace_context_provided(self): span_id="def456", event_agent_name="test-agent", ) - - mock_emit.assert_called_once() - call_args = mock_emit.call_args - assert call_args[0][2] is not None # local_controls - assert call_args[0][3] == "abc123" # trace_id - assert call_args[0][4] == "def456" # span_id - assert call_args.kwargs["agent_name"] == "test-agent" - - # Also verify non_matches propagated + mock_enqueue.assert_called_once() + delivered_events = mock_enqueue.call_args.args[0] + assert len(delivered_events) == 1 + assert delivered_events[0].trace_id == "abc123" + assert delivered_events[0].span_id == "def456" assert result.non_matches is not None assert len(result.non_matches) == 1 @pytest.mark.asyncio - async def test_emits_events_without_trace_context(self): - """Should resolve trace context from the provider when IDs are omitted.""" - from agent_control_models import EvaluationResponse, Step + async def test_resolves_provider_trace_context_for_local_events(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step mock_response = EvaluationResponse( - is_safe=True, confidence=1.0, matches=None, errors=None, non_matches=None, + is_safe=True, + confidence=1.0, + non_matches=[ + ControlMatch( + control_id=1, + control_name="test-ctrl", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.1), + ) + ], ) - mock_engine = MagicMock() mock_engine.process = AsyncMock(return_value=mock_response) - controls = [{ "id": 1, "name": "test-ctrl", @@ -431,35 +424,29 @@ async def test_emits_events_without_trace_context(self): client = MagicMock() client.http_client = AsyncMock() step = Step(type="llm", name="test-step", input="hello") - set_trace_context_provider( - lambda: { - "trace_id": "a" * 32, - "span_id": "b" * 16, - } - ) + set_trace_context_provider(lambda: {"trace_id": "a" * 32, "span_id": "b" * 16}) with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ - patch("agent_control.evaluation._emit_local_events") as mock_emit: + patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", step=step, stage="pre", controls=controls, - # No trace_id/span_id ) - mock_emit.assert_called_once() - call_args = mock_emit.call_args - assert call_args[0][3] == "a" * 32 - assert call_args[0][4] == "b" * 16 + + delivered_events = mock_enqueue.call_args.args[0] + assert delivered_events[0].trace_id == "a" * 32 + assert delivered_events[0].span_id == "b" * 16 @pytest.mark.asyncio - async def test_forwards_trace_headers_to_server(self): - """Server POST should include X-Trace-Id and X-Span-Id headers.""" + async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): + """Server POST should resolve trace headers from the provider when omitted.""" from agent_control_models import Step - # Only server controls, no local controls controls = [{ "id": 1, "name": "server-ctrl", @@ -487,6 +474,12 @@ async def test_forwards_trace_headers_to_server(self): client.http_client = AsyncMock() client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") + set_trace_context_provider( + lambda: { + "trace_id": "c" * 32, + "span_id": "d" * 16, + } + ) with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): await evaluation.check_evaluation_with_local( @@ -495,68 +488,362 @@ async def test_forwards_trace_headers_to_server(self): step=step, stage="pre", controls=controls, - trace_id="aaaa1111bbbb2222cccc3333dddd4444", - span_id="eeee5555ffff6666", ) - # Verify POST was called with headers call_kwargs = client.http_client.post.call_args headers = call_kwargs.kwargs.get("headers", {}) - assert headers["X-Trace-Id"] == "aaaa1111bbbb2222cccc3333dddd4444" - assert headers["X-Span-Id"] == "eeee5555ffff6666" + assert headers["X-Trace-Id"] == "c" * 32 + assert headers["X-Span-Id"] == "d" * 16 + + +class TestCheckEvaluation: @pytest.mark.asyncio - async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): - """Server POST should resolve trace headers from the provider when omitted.""" + async def test_check_evaluation_enqueues_reconstructed_server_events_when_observability_enabled(self): from agent_control_models import Step + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 0.9, + "matches": None, + "errors": None, + "non_matches": [ + { + "control_id": 1, + "control_name": "ctrl-1", + "action": "observe", + "control_execution_id": "ce-1", + "result": {"matched": False, "confidence": 0.1}, + } + ], + } + + client = MagicMock() + client.base_url = "http://localhost:8000" + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.is_observability_enabled", return_value=True), patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + ) + + call_kwargs = client.http_client.post.call_args.kwargs + assert call_kwargs["headers"] is None + mock_enqueue.assert_called_once() + assert result.is_safe is True + assert result.confidence == 0.9 + + @pytest.mark.asyncio + async def test_skips_local_event_reconstruction_when_observability_disabled(self): + from agent_control_models import EvaluationResponse, Step + controls = [{ "id": 1, - "name": "server-ctrl", + "name": "local-ctrl", "control": { "condition": { "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "deny"}, - "execution": "server", + "action": {"decision": "allow"}, + "execution": "sdk", }, }] + mock_response = EvaluationResponse(is_safe=True, confidence=1.0) + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=mock_response) + + client = MagicMock() + client.http_client = AsyncMock() + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), patch("agent_control.evaluation.is_observability_enabled", return_value=False), patch("agent_control.evaluation.build_control_execution_events") as mock_build, patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + ) + + mock_build.assert_not_called() + mock_enqueue.assert_not_called() + assert result.is_safe is True + assert result.confidence == 1.0 + + @pytest.mark.asyncio + async def test_check_evaluation_skips_enqueue_when_observability_disabled(self): + from agent_control_models import Step + mock_http_response = MagicMock() + mock_http_response.raise_for_status = MagicMock() mock_http_response.json.return_value = { "is_safe": True, - "confidence": 1.0, + "confidence": 0.9, "matches": None, "errors": None, "non_matches": None, } + + client = MagicMock() + client.base_url = "http://localhost:8000" + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.is_observability_enabled", return_value=False), patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + ) + + call_kwargs = client.http_client.post.call_args.kwargs + assert call_kwargs["headers"] is None + mock_enqueue.assert_not_called() + assert result.is_safe is True + assert result.confidence == 0.9 + + +# ============================================================================= +# Merged Event Creation +# ============================================================================= + + +class TestMergedEventCreation: + """Tests for SDK-side merged event reconstruction and enqueueing.""" + + @pytest.mark.asyncio + async def test_merged_event_mode_enqueues_reconstructed_local_and_server_events_once(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step + + local_response = EvaluationResponse( + is_safe=True, + confidence=1.0, + matches=[ + ControlMatch( + control_id=1, + control_name="local-ctrl", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.8), + ) + ], + ) + server_response = { + "is_safe": True, + "confidence": 0.9, + "matches": [ + { + "control_id": 2, + "control_name": "server-ctrl", + "action": "allow", + "control_execution_id": "ce-server", + "result": {"matched": False, "confidence": 0.4}, + } + ], + "errors": None, + "non_matches": None, + } + + controls = [ + { + "id": 1, + "name": "local-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "sdk", + }, + }, + { + "id": 2, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "server", + }, + }, + ] + + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=local_response) + mock_http_response = MagicMock() mock_http_response.raise_for_status = MagicMock() + mock_http_response.json.return_value = server_response client = MagicMock() client.http_client = AsyncMock() client.http_client.post = AsyncMock(return_value=mock_http_response) step = Step(type="llm", name="test-step", input="hello") - set_trace_context_provider( - lambda: { - "trace_id": "c" * 32, - "span_id": "d" * 16, - } + + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ + patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ + patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + trace_id="abc123", + span_id="def456", + event_agent_name="test-agent", + ) + mock_enqueue.assert_called_once() + merged_events = mock_enqueue.call_args.args[0] + assert len(merged_events) == 2 + assert {event.control_id for event in merged_events} == {1, 2} + assert result.matches is not None + assert len(result.matches) == 2 + + @pytest.mark.asyncio + async def test_merged_event_mode_enqueues_local_events_before_reraising_server_failure(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step + + local_response = EvaluationResponse( + is_safe=True, + confidence=1.0, + matches=[ + ControlMatch( + control_id=1, + control_name="local-ctrl", + action="allow", + result=EvaluatorResult(matched=False, confidence=0.8), + ) + ], ) - with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): - await evaluation.check_evaluation_with_local( + controls = [ + { + "id": 1, + "name": "local-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "sdk", + }, + }, + { + "id": 2, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "server", + }, + }, + ] + + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=local_response) + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(side_effect=RuntimeError("server unavailable")) + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), patch("agent_control.evaluation.is_observability_enabled", return_value=True), patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + with pytest.raises(RuntimeError, match="server unavailable"): + await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + trace_id="abc123", + span_id="def456", + event_agent_name="test-agent", + ) + + mock_enqueue.assert_called_once() + local_events = mock_enqueue.call_args.args[0] + assert len(local_events) == 1 + assert local_events[0].control_id == 1 + assert local_events[0].trace_id == "abc123" + assert local_events[0].span_id == "def456" + + @pytest.mark.asyncio + async def test_merged_event_mode_enqueues_only_local_events_when_no_server_controls_apply(self): + from agent_control_models import ControlMatch, EvaluationResponse, EvaluatorResult, Step + + local_response = EvaluationResponse( + is_safe=True, + confidence=1.0, + matches=[ + ControlMatch( + control_id=1, + control_name="local-ctrl", + action="allow", + result=EvaluatorResult(matched=True, confidence=0.8), + ) + ], + ) + controls = [ + { + "id": 1, + "name": "local-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "allow"}, + "execution": "sdk", + }, + } + ] + + mock_engine = MagicMock() + mock_engine.process = AsyncMock(return_value=local_response) + client = MagicMock() + client.http_client = AsyncMock() + step = Step(type="llm", name="test-step", input="hello") + + with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ + patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ + patch("agent_control.evaluation.is_observability_enabled", return_value=True), \ + patch("agent_control.evaluation.enqueue_observability_events") as mock_enqueue: + result = await evaluation.check_evaluation_with_local( client=client, agent_name="agent-000000000001", step=step, stage="pre", controls=controls, + trace_id="abc123", + span_id="def456", + event_agent_name="test-agent", ) - call_kwargs = client.http_client.post.call_args - headers = call_kwargs.kwargs.get("headers", {}) - assert headers["X-Trace-Id"] == "c" * 32 - assert headers["X-Span-Id"] == "d" * 16 + client.http_client.post.assert_not_called() + mock_enqueue.assert_called_once() + merged_events = mock_enqueue.call_args.args[0] + assert len(merged_events) == 1 + assert merged_events[0].control_id == 1 + assert result.matches is not None + assert len(result.matches) == 1 # ============================================================================= @@ -572,7 +859,6 @@ async def test_non_matches_populated_in_stats(self): """non_matches should be properly converted to dicts for stats tracking.""" from agent_control.control_decorators import ControlContext - # Simulate a result dict with non_matches result = { "is_safe": True, "confidence": 1.0, diff --git a/sdks/python/tests/test_trace_context.py b/sdks/python/tests/test_trace_context.py index f08306e0..2c1d727f 100644 --- a/sdks/python/tests/test_trace_context.py +++ b/sdks/python/tests/test_trace_context.py @@ -63,3 +63,14 @@ def test_get_trace_context_from_provider_returns_none_for_empty_ids() -> None: ) assert get_trace_context_from_provider() is None + + +def test_get_trace_context_from_provider_returns_none_for_non_string_ids() -> None: + set_trace_context_provider( # type: ignore[arg-type] + lambda: { + "trace_id": 123, + "span_id": b"abc", + } + ) + + assert get_trace_context_from_provider() is None diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 508b00ed..9537567d 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -23,6 +23,7 @@ import * as errors from "../models/errors/index.js"; import { ResponseValidationError } from "../models/errors/response-validation-error.js"; import { SDKValidationError } from "../models/errors/sdk-validation-error.js"; import * as models from "../models/index.js"; +import * as operations from "../models/operations/index.js"; import { APICall, APIPromise } from "../types/async.js"; import { Result } from "../types/fp.js"; @@ -49,7 +50,7 @@ import { Result } from "../types/fp.js"; */ export function agentsInit( client: AgentControlSDKCore, - request: models.InitAgentRequest, + request: operations.InitAgentApiV1AgentsInitAgentPostRequest, options?: RequestOptions, ): APIPromise< Result< @@ -74,7 +75,7 @@ export function agentsInit( async function $do( client: AgentControlSDKCore, - request: models.InitAgentRequest, + request: operations.InitAgentApiV1AgentsInitAgentPostRequest, options?: RequestOptions, ): Promise< [ @@ -95,14 +96,18 @@ async function $do( > { const parsed = safeParse( request, - (value) => z.parse(models.InitAgentRequest$outboundSchema, value), + (value) => + z.parse( + operations.InitAgentApiV1AgentsInitAgentPostRequest$outboundSchema, + value, + ), "Input validation failed", ); if (!parsed.ok) { return [parsed, { status: "invalid" }]; } const payload = parsed.value; - const body = encodeJSON("body", payload, { explode: true }); + const body = encodeJSON("body", payload.body, { explode: true }); const path = pathToFunc("/api/v1/agents/initAgent")(); diff --git a/sdks/typescript/src/generated/models/operations/index.ts b/sdks/typescript/src/generated/models/operations/index.ts index a8706eef..80d25cc6 100644 --- a/sdks/typescript/src/generated/models/operations/index.ts +++ b/sdks/typescript/src/generated/models/operations/index.ts @@ -16,6 +16,7 @@ export * from "./get-control-api-v1-controls-control-id-get.js"; export * from "./get-control-data-api-v1-controls-control-id-data-get.js"; export * from "./get-control-stats-api-v1-observability-stats-controls-control-id-get.js"; export * from "./get-stats-api-v1-observability-stats-get.js"; +export * from "./init-agent-api-v1-agents-init-agent-post.js"; export * from "./list-agent-controls-api-v1-agents-agent-name-controls-get.js"; export * from "./list-agent-evaluators-api-v1-agents-agent-name-evaluators-get.js"; export * from "./list-agents-api-v1-agents-get.js"; diff --git a/sdks/typescript/src/generated/models/operations/init-agent-api-v1-agents-init-agent-post.ts b/sdks/typescript/src/generated/models/operations/init-agent-api-v1-agents-init-agent-post.ts new file mode 100644 index 00000000..b611fef9 --- /dev/null +++ b/sdks/typescript/src/generated/models/operations/init-agent-api-v1-agents-init-agent-post.ts @@ -0,0 +1,39 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../../lib/primitives.js"; +import * as models from "../index.js"; + +export type InitAgentApiV1AgentsInitAgentPostRequest = { + body: models.InitAgentRequest; +}; + +/** @internal */ +export type InitAgentApiV1AgentsInitAgentPostRequest$Outbound = { + body: models.InitAgentRequest$Outbound; +}; + +/** @internal */ +export const InitAgentApiV1AgentsInitAgentPostRequest$outboundSchema: + z.ZodMiniType< + InitAgentApiV1AgentsInitAgentPostRequest$Outbound, + InitAgentApiV1AgentsInitAgentPostRequest + > = z.pipe( + z.object({ + body: models.InitAgentRequest$outboundSchema, + }), + z.transform((v) => remap$(v, {})), + ); + +export function initAgentApiV1AgentsInitAgentPostRequestToJSON( + initAgentApiV1AgentsInitAgentPostRequest: + InitAgentApiV1AgentsInitAgentPostRequest, +): string { + return JSON.stringify( + InitAgentApiV1AgentsInitAgentPostRequest$outboundSchema.parse( + initAgentApiV1AgentsInitAgentPostRequest, + ), + ); +} diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index d606b384..78353323 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -75,7 +75,7 @@ export class Agents extends ClientSDK { * InitAgentResponse with created flag and active controls (policy-derived + direct) */ async init( - request: models.InitAgentRequest, + request: operations.InitAgentApiV1AgentsInitAgentPostRequest, options?: RequestOptions, ): Promise { return unwrapAsync(agentsInit( diff --git a/sdks/typescript/tests/client-api.test.ts b/sdks/typescript/tests/client-api.test.ts index fe5a98db..ff156fab 100644 --- a/sdks/typescript/tests/client-api.test.ts +++ b/sdks/typescript/tests/client-api.test.ts @@ -144,9 +144,11 @@ describe("AgentControlClient API wiring", () => { }); await client.agents.init({ - agent: { - agentId: "550e8400-e29b-41d4-a716-446655440000", - agentName: "test-agent", + body: { + agent: { + agentId: "550e8400-e29b-41d4-a716-446655440000", + agentName: "test-agent", + }, }, }); diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index d88a9f0c..53a76a24 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -35,7 +35,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth import RequireAPIKey, require_admin_key from ..db import get_async_db from ..errors import ( APIValidationError, @@ -447,7 +447,9 @@ async def list_agents( response_description="Agent registration status with active controls", ) async def init_agent( - request: InitAgentRequest, db: AsyncSession = Depends(get_async_db) + request: InitAgentRequest, + client: RequireAPIKey, + db: AsyncSession = Depends(get_async_db), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index c92ea315..18b945fc 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -1,39 +1,28 @@ """Evaluation analysis endpoints.""" -import time -from datetime import UTC, datetime -from typing import Literal - from agent_control_engine.core import ControlEngine from agent_control_models import ( ControlDefinition, - ControlExecutionEvent, ControlMatch, EvaluationRequest, EvaluationResponse, ) from agent_control_models.errors import ErrorCode, ValidationErrorItem -from fastapi import APIRouter, Depends, Header, Request +from fastapi import APIRouter, Depends from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..config import observability_settings +from ..auth import RequireAPIKey from ..db import get_async_db from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent -from ..observability.ingest.base import EventIngestor from ..services.controls import list_controls_for_agent -from .observability import get_event_ingestor router = APIRouter(prefix="/evaluation", tags=["evaluation"]) _logger = get_logger(__name__) -# OTEL-standard invalid IDs - used when client doesn't provide trace context. -# These are immediately recognizable as "not traced" and can be filtered in queries. -INVALID_TRACE_ID = "0" * 32 # 128-bit, 32 hex chars -INVALID_SPAN_ID = "0" * 16 # 64-bit, 16 hex chars SAFE_EVALUATOR_ERROR = "Evaluation failed due to an internal evaluator error." SAFE_EVALUATOR_TIMEOUT_ERROR = "Evaluation timed out before completion." SAFE_INVALID_STEP_REGEX_ERROR = "Control configuration error: invalid step name regex." @@ -126,24 +115,6 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) -def _observability_metadata( - control_def: ControlDefinition, -) -> tuple[str | None, str | None, dict[str, object]]: - """Return representative event fields plus full composite context.""" - identity = control_def.observability_identity() - return ( - identity.selector_path, - identity.evaluator_name, - { - "primary_evaluator": identity.evaluator_name, - "primary_selector_path": identity.selector_path, - "leaf_count": identity.leaf_count, - "all_evaluators": identity.all_evaluators, - "all_selector_paths": identity.all_selector_paths, - }, - ) - - @router.post( "", response_model=EvaluationResponse, @@ -152,41 +123,18 @@ def _observability_metadata( ) async def evaluate( request: EvaluationRequest, - req: Request, + client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - x_trace_id: str | None = Header(default=None, alias="X-Trace-Id"), - x_span_id: str | None = Header(default=None, alias="X-Span-Id"), ) -> EvaluationResponse: """Analyze content for safety and control violations. - Runs all controls assigned to the agent via policy through the - evaluation engine. Controls are evaluated in parallel with - cancel-on-deny for efficiency. - - Custom evaluators must be deployed as Evaluator classes - with the engine. Their schemas are registered via initAgent. - - Optionally accepts X-Trace-Id and X-Span-Id headers for - OpenTelemetry-compatible distributed tracing. + This endpoint is intentionally evaluation-only. It returns the semantic + ``EvaluationResponse`` and does not build or ingest observability events + on the server; SDKs reconstruct and emit those events separately through + the observability ingestion endpoint. """ - start_time = time.perf_counter() - - # Use provided trace/span IDs or fall back to OTEL invalid IDs. - # Invalid IDs make it obvious that trace context wasn't provided by the client. - if not x_trace_id or not x_span_id: - _logger.warning( - "Missing trace context headers (X-Trace-Id, X-Span-Id). " - "Using invalid IDs - observability data will not be traceable." - ) - trace_id = x_trace_id or INVALID_TRACE_ID - span_id = x_span_id or INVALID_SPAN_ID + del client # Authentication is still required by dependency injection. - # Determine payload type for observability based on step type - applies_to: Literal["llm_call", "tool_call"] = ( - "tool_call" if request.step.type == "tool" else "llm_call" - ) - - # Fetch agent to get the name agent_result = await db.execute( select(Agent).where(Agent.name == request.agent_name) ) @@ -199,22 +147,14 @@ async def evaluate( resource_id=request.agent_name, hint="Register the agent via initAgent before evaluating.", ) - agent_name = agent.name - # Fetch controls for the agent (already validated as ControlDefinition) api_controls = await list_controls_for_agent( request.agent_name, db, allow_invalid_step_name_regex=True, ) - - # Build control lookup for observability - control_lookup = {c.id: c for c in api_controls} - - # Adapt controls for the engine engine_controls = [ControlAdapter(c.id, c.name, c.control) for c in api_controls] - # Execute Control Engine (parallel with cancel-on-deny) engine = ControlEngine(engine_controls) try: raw_response = await engine.process(request) @@ -235,155 +175,4 @@ async def evaluate( ], ) - # Calculate total execution time - total_duration_ms = (time.perf_counter() - start_time) * 1000 - - # Emit observability events if enabled - if observability_settings.enabled: - # Get ingestor from app.state (None if not initialized) - try: - ingestor = get_event_ingestor(req) - except RuntimeError: - ingestor = None - - await _emit_observability_events( - response=raw_response, - request=request, - trace_id=trace_id, - span_id=span_id, - agent_name=agent_name, - applies_to=applies_to, - control_lookup=control_lookup, - total_duration_ms=total_duration_ms, - ingestor=ingestor, - ) - return _sanitize_evaluation_response(raw_response) - - -async def _emit_observability_events( - response: EvaluationResponse, - request: EvaluationRequest, - trace_id: str, - span_id: str, - agent_name: str, - applies_to: Literal["llm_call", "tool_call"], - control_lookup: dict, - total_duration_ms: float, - ingestor: EventIngestor | None, -) -> None: - """Create and enqueue observability events for all evaluated controls. - - Uses control_execution_id from the engine response to ensure correlation - between SDK logs and server observability events. - """ - events: list[ControlExecutionEvent] = [] - now = datetime.now(UTC) - - # Process matches (controls that matched) - if response.matches: - for match in response.matches: - ctrl = control_lookup.get(match.control_id) - event_metadata = dict(match.result.metadata or {}) - selector_path = None - evaluator_name = None - if ctrl: - selector_path, evaluator_name, identity_metadata = _observability_metadata( - ctrl.control - ) - event_metadata.update(identity_metadata) - events.append( - ControlExecutionEvent( - control_execution_id=match.control_execution_id, - trace_id=trace_id, - span_id=span_id, - agent_name=agent_name, - control_id=match.control_id, - control_name=match.control_name, - check_stage=request.stage, - applies_to=applies_to, - action=match.action, - matched=True, - confidence=match.result.confidence, - timestamp=now, - evaluator_name=evaluator_name, - selector_path=selector_path, - error_message=match.result.error, - metadata=event_metadata, - ) - ) - - # Process errors (controls that failed during evaluation) - if response.errors: - for error in response.errors: - ctrl = control_lookup.get(error.control_id) - event_metadata = dict(error.result.metadata or {}) - selector_path = None - evaluator_name = None - if ctrl: - selector_path, evaluator_name, identity_metadata = _observability_metadata( - ctrl.control - ) - event_metadata.update(identity_metadata) - events.append( - ControlExecutionEvent( - control_execution_id=error.control_execution_id, - trace_id=trace_id, - span_id=span_id, - agent_name=agent_name, - control_id=error.control_id, - control_name=error.control_name, - check_stage=request.stage, - applies_to=applies_to, - action=error.action, - matched=False, - confidence=error.result.confidence, - timestamp=now, - evaluator_name=evaluator_name, - selector_path=selector_path, - error_message=error.result.error, - metadata=event_metadata, - ) - ) - - # Process non-matches (controls that were evaluated but did not match) - if response.non_matches: - for non_match in response.non_matches: - ctrl = control_lookup.get(non_match.control_id) - event_metadata = dict(non_match.result.metadata or {}) - selector_path = None - evaluator_name = None - if ctrl: - selector_path, evaluator_name, identity_metadata = _observability_metadata( - ctrl.control - ) - event_metadata.update(identity_metadata) - events.append( - ControlExecutionEvent( - control_execution_id=non_match.control_execution_id, - trace_id=trace_id, - span_id=span_id, - agent_name=agent_name, - control_id=non_match.control_id, - control_name=non_match.control_name, - check_stage=request.stage, - applies_to=applies_to, - action=non_match.action, - matched=False, - confidence=non_match.result.confidence, - timestamp=now, - evaluator_name=evaluator_name, - selector_path=selector_path, - error_message=None, - metadata=event_metadata, - ) - ) - - # Ingest events - if events and ingestor: - result = await ingestor.ingest(events) - if result.dropped > 0: - _logger.warning( - f"Dropped {result.dropped} observability events, " - f"processed {result.processed}" - ) diff --git a/server/tests/test_evaluation_error_handling.py b/server/tests/test_evaluation_error_handling.py index 942dca66..1df795da 100644 --- a/server/tests/test_evaluation_error_handling.py +++ b/server/tests/test_evaluation_error_handling.py @@ -1,9 +1,13 @@ """End-to-end tests for evaluator error handling.""" -import logging import uuid from unittest.mock import AsyncMock, MagicMock -from agent_control_models import ControlMatch, EvaluationRequest, EvaluatorResult, Step +from agent_control_models import ( + ControlMatch, + EvaluationRequest, + EvaluatorResult, + Step, +) from fastapi.testclient import TestClient from agent_control_server.endpoints.evaluation import ( @@ -11,7 +15,6 @@ SAFE_EVALUATOR_TIMEOUT_ERROR, _sanitize_control_match, ) -from agent_control_server.observability.ingest.base import IngestResult from .utils import create_and_assign_policy @@ -165,12 +168,11 @@ def mock_get_evaluator_instance(config): assert data["matches"] is None or len(data["matches"]) == 0 -def test_evaluation_observability_receives_raw_errors_while_api_response_is_sanitized( +def test_evaluation_response_is_sanitized_without_server_side_observability( client: TestClient, monkeypatch, ) -> None: - """Observability should ingest raw evaluator diagnostics while API clients see safe text.""" - # Given: an agent with a deny control and an evaluator that crashes at runtime + """Evaluation stays pure and returns only sanitized semantics.""" control_data = { "description": "Test control", "enabled": True, @@ -190,7 +192,6 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani mock_evaluator.get_timeout_seconds = MagicMock(return_value=30.0) import agent_control_engine.core as core_module - import agent_control_server.endpoints.evaluation as evaluation_module monkeypatch.setattr( core_module, @@ -198,11 +199,6 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani lambda _config: mock_evaluator, ) - emit_mock = AsyncMock() - monkeypatch.setattr(evaluation_module, "_emit_observability_events", emit_mock) - monkeypatch.setattr(evaluation_module.observability_settings, "enabled", True) - - # When: sending an evaluation request payload = Step(type="llm", name="test-step", input="test content", output=None) req = EvaluationRequest( agent_name=agent_name, @@ -211,7 +207,6 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani ) resp = client.post("/api/v1/evaluation", json=req.model_dump(mode="json")) - # Then: the API response remains sanitized assert resp.status_code == 200 data = resp.json() assert data["errors"] is not None @@ -219,17 +214,6 @@ def test_evaluation_observability_receives_raw_errors_while_api_response_is_sani assert data["errors"][0]["control_name"] == control_name assert data["errors"][0]["result"]["error"] == SAFE_EVALUATOR_ERROR - # And: observability receives the raw engine response with unsanitized diagnostics - emit_mock.assert_awaited_once() - raw_response = emit_mock.await_args.kwargs["response"] - assert raw_response.errors is not None - raw_error = raw_response.errors[0] - assert raw_error.control_name == control_name - assert raw_error.result.error == "RuntimeError: Simulated evaluator crash" - raw_trace = raw_error.result.metadata["condition_trace"] - assert raw_trace["error"] == "RuntimeError: Simulated evaluator crash" - assert raw_trace["message"] == "Evaluation failed: RuntimeError: Simulated evaluator crash" - def test_sanitize_control_match_redacts_nested_condition_trace_errors() -> None: # Given: a control match whose nested condition trace contains raw evaluator errors @@ -343,32 +327,24 @@ async def raise_value_error(*_args, **_kwargs): assert body["errors"][0]["message"] == "Invalid evaluation request or control configuration." -def test_evaluation_warns_when_observability_drops_events( - client: TestClient, app, caplog -) -> None: - # Given: an agent with a control that will match + +def test_evaluation_ignores_merge_headers_and_remains_pure(client: TestClient) -> None: + """/evaluation should return only semantic results regardless of merge headers.""" agent_name, _ = create_and_assign_policy(client) - class DroppingIngestor: - async def ingest(self, events): # type: ignore[no-untyped-def] - return IngestResult(received=len(events), processed=0, dropped=len(events)) - - previous_ingestor = getattr(app.state, "event_ingestor", None) - app.state.event_ingestor = DroppingIngestor() - try: - # And: a log capture for the evaluation warning - caplog.set_level(logging.WARNING, logger="agent_control_server.endpoints.evaluation") - - # When: sending an evaluation request - payload = Step(type="llm", name="test-step", input="x", output=None) - req = EvaluationRequest(agent_name=agent_name, step=payload, stage="pre") - resp = client.post("/api/v1/evaluation", json=req.model_dump(mode="json")) - - # Then: the evaluation succeeds but logs a dropped-events warning - assert resp.status_code == 200 - assert any("Dropped" in record.message for record in caplog.records) - finally: - if previous_ingestor is None: - del app.state.event_ingestor - else: - app.state.event_ingestor = previous_ingestor + payload = Step(type="llm", name="test-step", input="x", output=None) + req = EvaluationRequest(agent_name=agent_name, step=payload, stage="pre") + resp = client.post( + "/api/v1/evaluation", + json=req.model_dump(mode="json"), + headers={ + "X-Agent-Control-Merge-Events": "true", + "X-Trace-Id": "a" * 32, + "X-Span-Id": "b" * 16, + }, + ) + + assert resp.status_code == 200 + body = resp.json() + assert "events" not in body + assert body["is_safe"] is False diff --git a/server/unit_tests/test_endpoint_helpers.py b/server/unit_tests/test_endpoint_helpers.py index 0b0b214f..4d245206 100644 --- a/server/unit_tests/test_endpoint_helpers.py +++ b/server/unit_tests/test_endpoint_helpers.py @@ -1,22 +1,14 @@ """Unit tests for endpoint helpers that don't require the DB test fixture.""" from types import SimpleNamespace -from unittest.mock import AsyncMock -import pytest -from agent_control_models import ( - ControlDefinition, - ControlMatch, - EvaluationRequest, - EvaluationResponse, - EvaluatorResult, -) +from agent_control_models import ControlDefinition, ControlMatch, EvaluatorResult from agent_control_server.endpoints.agents import ( _find_referencing_controls_for_removed_evaluators, ) from agent_control_server.endpoints.evaluation import ( ControlAdapter, - _emit_observability_events, + _sanitize_control_match, ) @@ -55,10 +47,9 @@ def test_find_referencing_controls_dedupes_composite_matches() -> None: assert referencing_controls == [("composite-ctrl", "custom")] -@pytest.mark.asyncio -async def test_emit_observability_events_uses_representative_leaf_for_composites() -> None: - # Given: a composite control with two leaves and existing condition metadata - control = ControlAdapter( +def test_sanitize_control_match_redacts_nested_condition_trace_errors() -> None: + # Given: a composite control whose condition trace includes a raw evaluator error + _ = ControlAdapter( id=1, name="composite-ctrl", control=ControlDefinition( @@ -78,53 +69,36 @@ async def test_emit_observability_events_uses_representative_leaf_for_composites action={"decision": "observe"}, ), ) - response = EvaluationResponse( - is_safe=True, - confidence=1.0, - non_matches=[ - ControlMatch( - control_id=1, - control_name="composite-ctrl", - action="observe", - result=EvaluatorResult( - matched=False, - confidence=0.9, - metadata={"condition_trace": {"kind": "and"}}, - ), - ) - ], - ) - request = EvaluationRequest( - agent_name="agent-000000000001", - step={"type": "llm", "name": "test-step", "input": "hello"}, - stage="pre", - ) - ingestor = SimpleNamespace( - ingest=AsyncMock(return_value=SimpleNamespace(dropped=0, processed=1)) + match = ControlMatch( + control_id=1, + control_name="composite-ctrl", + action="observe", + result=EvaluatorResult( + matched=False, + confidence=0.9, + error="RuntimeError: secret evaluator failure", + metadata={ + "condition_trace": { + "type": "and", + "children": [ + { + "type": "leaf", + "error": "RuntimeError: secret evaluator failure", + "message": "Evaluation failed: RuntimeError: secret evaluator failure", + } + ], + } + }, + ), ) - # When: emitting observability events - await _emit_observability_events( - response=response, - request=request, - trace_id="trace123", - span_id="span456", - agent_name="agent-000000000001", - applies_to="llm_call", - control_lookup={1: control}, - total_duration_ms=5.0, - ingestor=ingestor, - ) + # When: sanitizing the control match for API output + sanitized = _sanitize_control_match(match) - # Then: the first leaf becomes the event identity and full context is retained - events = ingestor.ingest.await_args.args[0] - assert len(events) == 1 - event = events[0] - assert event.evaluator_name == "regex" - assert event.selector_path == "input" - assert event.metadata["condition_trace"] == {"kind": "and"} - assert event.metadata["primary_evaluator"] == "regex" - assert event.metadata["primary_selector_path"] == "input" - assert event.metadata["leaf_count"] == 2 - assert event.metadata["all_evaluators"] == ["regex", "list"] - assert event.metadata["all_selector_paths"] == ["input", "output"] + # Then: top-level and nested errors are redacted to the safe public message + assert sanitized.result.error is not None + assert "secret evaluator failure" not in sanitized.result.error + trace = sanitized.result.metadata["condition_trace"] + child = trace["children"][0] + assert child["error"] == sanitized.result.error + assert child["message"] == sanitized.result.error