|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import Any, Callable |
| 6 | + |
| 7 | + |
class StopRun(Exception):
    """Raised to halt a run; carries a short reason code for the caller."""

    def __init__(self, reason: str):
        # Keep the code on the instance as well as in ``args`` so handlers can
        # read ``exc.reason`` without parsing ``str(exc)``.
        self.reason = reason
        super().__init__(reason)
| 12 | + |
| 13 | + |
@dataclass(frozen=True)
class Budget:
    """Immutable limits applied to a single run."""

    # Overall time limit in seconds. Not read in this part of the file;
    # presumably enforced by the caller that drives the run (TODO confirm).
    max_seconds: int = 25

    # Upper bound on the number of planned actions. Presumably handed to
    # ``validate_plan(max_actions=...)`` by the caller (TODO confirm).
    max_actions: int = 8

    # How long ``PolicyGateway.dispatch`` waits for one tool call's result.
    action_timeout_seconds: float = 1.2

    # Ceiling ``PolicyGateway.evaluate`` applies to ``max_recipients`` on
    # ``send_status_update`` actions.
    max_recipients_per_send: int = 50000
| 20 | + |
| 21 | + |
@dataclass(frozen=True)
class Decision:
    """Immutable verdict produced by ``PolicyGateway.evaluate``."""

    # Verdict label; ``evaluate`` emits "allow", "deny", "rewrite" or
    # "escalate".
    kind: str

    # Short code explaining the verdict, e.g. "policy_pass".
    reason: str

    # Replacement action (``id`` / ``tool`` / ``args``) attached to "rewrite"
    # and "escalate" verdicts; ``None`` for the others.
    enforced_action: dict[str, Any] | None = None
| 27 | + |
| 28 | + |
def _normalize_action(raw: Any) -> dict[str, Any]:
    """Validate one raw action and return a cleaned ``id``/``tool``/``args`` dict.

    ``id`` and ``tool`` must be non-blank strings and are returned stripped of
    surrounding whitespace; ``args`` must be a dict and is returned as a
    shallow copy. Any other keys present in ``raw`` are dropped.

    Raises:
        StopRun: ``invalid_action:not_object`` / ``:id`` / ``:tool`` / ``:args``
            for the first check that fails, in that order.
    """
    if not isinstance(raw, dict):
        raise StopRun("invalid_action:not_object")

    # The two string fields share one rule; checking them from a table keeps
    # the original order (id before tool) and the exact error codes.
    string_fields = (
        ("id", "invalid_action:id"),
        ("tool", "invalid_action:tool"),
    )
    cleaned: dict[str, Any] = {}
    for key, error_code in string_fields:
        value = raw.get(key)
        if not (isinstance(value, str) and value.strip()):
            raise StopRun(error_code)
        cleaned[key] = value.strip()

    payload = raw.get("args")
    if not isinstance(payload, dict):
        raise StopRun("invalid_action:args")

    # Shallow copy so later edits to the result never touch the caller's dict.
    cleaned["args"] = dict(payload)
    return cleaned
| 49 | + |
| 50 | + |
def validate_plan(raw_actions: Any, *, max_actions: int) -> list[dict[str, Any]]:
    """Check a raw plan and return its actions in normalized form.

    Args:
        raw_actions: Expected to be a non-empty list of action objects.
        max_actions: Largest number of actions the plan may contain.

    Returns:
        One normalized action dict per input item, in the same order.

    Raises:
        StopRun: ``invalid_plan:actions`` if the plan is not a non-empty list,
            ``invalid_plan:too_many_actions`` if it exceeds ``max_actions``,
            or whatever ``_normalize_action`` raises for a malformed item.
    """
    is_nonempty_list = isinstance(raw_actions, list) and bool(raw_actions)
    if not is_nonempty_list:
        raise StopRun("invalid_plan:actions")

    if len(raw_actions) > max_actions:
        raise StopRun("invalid_plan:too_many_actions")

    return list(map(_normalize_action, raw_actions))
| 57 | + |
| 58 | + |
def validate_tool_observation(raw: Any, *, tool_name: str) -> dict[str, Any]:
    """Return the ``data`` payload of a successful tool result.

    Args:
        raw: Value returned by the tool; expected to be a dict with
            ``status == "ok"`` and a dict under ``data``.
        tool_name: Name embedded in the error code when validation fails.

    Returns:
        The ``data`` dict itself (the same object, not a copy).

    Raises:
        StopRun: ``tool_invalid_output:<tool>`` when ``raw`` or its ``data``
            is not a dict; ``tool_status_not_ok:<tool>`` when the status is
            anything other than ``"ok"``.
    """
    if isinstance(raw, dict):
        # Status is checked before the payload, so a failed call is reported
        # as "not ok" even when it also lacks a well-formed ``data``.
        if raw.get("status") != "ok":
            raise StopRun(f"tool_status_not_ok:{tool_name}")
        payload = raw.get("data")
        if isinstance(payload, dict):
            return payload

    # Reached for a non-dict result or a non-dict ``data`` field.
    raise StopRun(f"tool_invalid_output:{tool_name}")
| 68 | + |
| 69 | + |
class PolicyGateway:
    """Policy check and time-bounded execution for planned tool actions.

    ``evaluate`` classifies one action as ``allow``, ``deny``, ``rewrite`` or
    ``escalate``. ``dispatch`` runs a tool callable on a worker thread with a
    per-call timeout and validates what it returns.

    The gateway owns a thread pool; release it with ``close()`` or by using
    the gateway as a context manager.
    """

    def __init__(
        self,
        *,
        allowed_tools_policy: set[str],
        allowed_tools_execution: set[str],
        budget: Budget,
    ):
        """Store private copies of both allowlists and create the worker pool.

        Args:
            allowed_tools_policy: Tool names permitted by policy.
            allowed_tools_execution: Tool names permitted for execution.
            budget: Limits read by ``evaluate`` (recipient cap) and
                ``dispatch`` (per-call timeout).
        """
        # Copy so later mutation of the caller's sets cannot widen the policy.
        self.allowed_tools_policy = set(allowed_tools_policy)
        self.allowed_tools_execution = set(allowed_tools_execution)
        self.budget = budget
        self.allowed_templates = {"incident_p1_v2", "incident_p2_v1"}
        self._pool = ThreadPoolExecutor(max_workers=4)

    def close(self) -> None:
        """Shut the pool down without waiting; drop calls that have not started."""
        self._pool.shutdown(wait=False, cancel_futures=True)

    def __enter__(self) -> PolicyGateway:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def evaluate(self, *, action: dict[str, Any], state: dict[str, Any]) -> Decision:
        """Decide whether ``action`` may run as given, altered, or not at all.

        Checks run in this order: policy allowlist, the hard block on
        ``export_customer_data``, execution allowlist. Every permitted tool
        other than ``send_status_update`` is allowed unchanged. For
        ``send_status_update`` the arguments are sanitised (template
        allowlist, recipient cap, ``free_text`` removal) and an external email
        to all customers is escalated with a narrowed replacement action.

        Args:
            action: Raw action; validated with ``_normalize_action``.
            state: Accepted for interface compatibility; not consulted.

        Returns:
            A ``Decision``. ``enforced_action`` is set only for ``rewrite``
            and ``escalate``.

        Raises:
            StopRun: If ``action`` is malformed.
        """
        del state
        normalized = _normalize_action(action)
        tool = normalized["tool"]

        if tool not in self.allowed_tools_policy:
            return Decision(kind="deny", reason="tool_denied_policy")
        if tool == "export_customer_data":
            # Blocked even when an allowlist mistakenly includes it.
            return Decision(kind="deny", reason="pii_export_blocked")
        if tool not in self.allowed_tools_execution:
            return Decision(kind="deny", reason="tool_denied_execution")

        if tool != "send_status_update":
            return Decision(kind="allow", reason="policy_pass")

        rewrite_reasons: list[str] = []
        rewritten = dict(normalized["args"])

        # The isinstance guard keeps an unhashable value (e.g. a list) from
        # raising TypeError in the set membership test.
        template_id = rewritten.get("template_id")
        if not isinstance(template_id, str) or template_id not in self.allowed_templates:
            rewritten["template_id"] = "incident_p1_v2"
            rewrite_reasons.append("template_allowlist")

        cap = self.budget.max_recipients_per_send
        try:
            recipients = int(rewritten.get("max_recipients", cap))
        except (TypeError, ValueError, OverflowError):
            # Unparseable input (including float infinity) falls back to the
            # cap. Record it as a rewrite: an "allow" verdict carries no
            # enforced action, so the corrected value would otherwise be
            # discarded and the original invalid argument left in place.
            recipients = cap
            rewrite_reasons.append("recipient_cap")
        else:
            if recipients > cap:
                recipients = cap
                rewrite_reasons.append("recipient_cap")
        rewritten["max_recipients"] = recipients

        if "free_text" in rewritten:
            rewritten.pop("free_text", None)
            rewrite_reasons.append("free_text_removed")

        is_mass_broadcast = (
            rewritten.get("channel") == "external_email"
            and rewritten.get("audience_segment") == "all_customers"
        )
        if is_mass_broadcast:
            # Narrow the blast radius in the proposed replacement.
            rewritten["channel"] = "status_page"
            rewritten["audience_segment"] = "enterprise_active"
        elif not rewrite_reasons:
            return Decision(kind="allow", reason="policy_pass")

        enforced = {
            "id": normalized["id"],
            "tool": normalized["tool"],
            "args": rewritten,
        }
        if is_mass_broadcast:
            return Decision(
                kind="escalate",
                reason="mass_external_broadcast",
                enforced_action=enforced,
            )
        return Decision(
            kind="rewrite",
            reason=f"policy_rewrite:{','.join(rewrite_reasons)}",
            enforced_action=enforced,
        )

    def dispatch(
        self,
        *,
        tool_name: str,
        tool_fn: Callable[..., dict[str, Any]],
        args: dict[str, Any],
    ) -> dict[str, Any]:
        """Run ``tool_fn(**args)`` on the pool and return its validated data.

        Args:
            tool_name: Name used in error codes.
            tool_fn: Callable to execute.
            args: Keyword arguments passed to ``tool_fn``.

        Returns:
            The ``data`` dict of the tool's result, as checked by
            ``validate_tool_observation``.

        Raises:
            StopRun: ``tool_timeout:<tool>`` if no result arrives within
                ``budget.action_timeout_seconds``;
                ``tool_error:<tool>:<ExceptionType>`` if the tool raises; or a
                validation error for a malformed result.
        """
        future = self._pool.submit(tool_fn, **args)
        try:
            raw = future.result(timeout=self.budget.action_timeout_seconds)
        except FuturesTimeoutError as exc:
            # If the call is still queued behind busy workers, cancelling
            # stops it from running after the run has already been stopped.
            # A call that is already executing cannot be interrupted and
            # keeps running on its thread; cancel() is then a no-op.
            future.cancel()
            raise StopRun(f"tool_timeout:{tool_name}") from exc
        except Exception as exc:  # noqa: BLE001
            raise StopRun(f"tool_error:{tool_name}:{type(exc).__name__}") from exc

        return validate_tool_observation(raw, tool_name=tool_name)
| 170 | + return validate_tool_observation(raw, tool_name=tool_name) |
0 commit comments