diff --git a/packages/optimization/src/ldai_optimization/__init__.py b/packages/optimization/src/ldai_optimization/__init__.py index a0fc60a..5ee175f 100644 --- a/packages/optimization/src/ldai_optimization/__init__.py +++ b/packages/optimization/src/ldai_optimization/__init__.py @@ -7,19 +7,23 @@ from ldai_optimization.dataclasses import ( AIJudgeCallConfig, OptimizationContext, + OptimizationFromConfigOptions, OptimizationJudge, OptimizationJudgeContext, OptimizationOptions, ToolDefinition, ) +from ldai_optimization.ld_api_client import LDApiError __version__ = "0.0.0" __all__ = [ '__version__', 'AIJudgeCallConfig', + 'LDApiError', 'OptimizationClient', 'OptimizationContext', + 'OptimizationFromConfigOptions', 'OptimizationJudge', 'OptimizationJudgeContext', 'OptimizationOptions', diff --git a/packages/optimization/src/ldai_optimization/client.py b/packages/optimization/src/ldai_optimization/client.py index f78c2a7..ff3dde9 100644 --- a/packages/optimization/src/ldai_optimization/client.py +++ b/packages/optimization/src/ldai_optimization/client.py @@ -5,6 +5,7 @@ import logging import os import random +import uuid from typing import Any, Dict, List, Literal, Optional from ldai import AIAgentConfig, AIJudgeConfig, AIJudgeConfigDefault, LDAIClient @@ -16,11 +17,17 @@ AutoCommitConfig, JudgeResult, OptimizationContext, + OptimizationFromConfigOptions, OptimizationJudge, OptimizationJudgeContext, OptimizationOptions, ToolDefinition, ) +from ldai_optimization.ld_api_client import ( + AgentOptimizationConfig, + LDApiClient, + OptimizationResultPayload, +) from ldai_optimization.prompts import ( build_message_history_text, build_new_variation_prompt, @@ -39,6 +46,34 @@ logger = logging.getLogger(__name__) +def _strip_provider_prefix(model: str) -> str: + """Strip the provider prefix from a model identifier returned by the LD API. + + API model keys are formatted as "Provider.model-name" (e.g. "OpenAI.gpt-5", + "Anthropic.claude-opus-4.6"). 
Only the part after the first period is needed + by the underlying LLM clients. If no period is present the string is returned + unchanged. + + :param model: Raw model string from the API. + :return: Model name with provider prefix removed. + """ + return model.split(".", 1)[-1] + + +# Maps SDK status strings to the API status/activity values expected by +# agent_optimization_result records. Defined at module level to avoid +# allocating the dict on every on_status_update invocation. +_OPTIMIZATION_STATUS_MAP: Dict[str, Dict[str, str]] = { + "init": {"status": "RUNNING", "activity": "PENDING"}, + "generating": {"status": "RUNNING", "activity": "GENERATING"}, + "evaluating": {"status": "RUNNING", "activity": "EVALUATING"}, + "generating variation": {"status": "RUNNING", "activity": "GENERATING_VARIATION"}, + "turn completed": {"status": "RUNNING", "activity": "COMPLETED"}, + "success": {"status": "PASSED", "activity": "COMPLETED"}, + "failure": {"status": "FAILED", "activity": "COMPLETED"}, +} + + class OptimizationClient: _options: OptimizationOptions _ldClient: LDAIClient @@ -883,21 +918,151 @@ async def _generate_new_variation( ) async def optimize_from_config( - self, agent_key: str, optimization_config_key: str + self, optimization_config_key: str, options: OptimizationFromConfigOptions ) -> Any: - """Optimize an agent from a configuration. + """Optimize an agent using a configuration fetched from the LaunchDarkly API. - :param agent_key: Identifier of the agent to optimize. - :param optimization_config_key: Identifier of the optimization configuration to use. - :return: Optimization result. + The agent key, judge configuration, model choices, and other optimization + parameters are all sourced from the remote agent optimization config. The + caller only needs to provide the execution callbacks and evaluation contexts. + + Iteration results are automatically persisted to the LaunchDarkly API so + the UI can display live run progress. 
+ + :param optimization_config_key: Key of the agent optimization config to fetch. + :param options: User-provided callbacks and evaluation contexts. + :return: Optimization result (OptimizationContext from the final iteration). """ if not self._has_api_key: raise ValueError( "LAUNCHDARKLY_API_KEY is not set, so optimize_from_config is not available" ) - self._agent_key = agent_key - raise NotImplementedError + assert self._api_key is not None + api_client = LDApiClient( + self._api_key, + **({"base_url": options.base_url} if options.base_url else {}), + ) + config = api_client.get_agent_optimization(options.project_key, optimization_config_key) + + self._agent_key = config["aiConfigKey"] + optimization_id: str = config["id"] + run_id = str(uuid.uuid4()) + + context = random.choice(options.context_choices) + # _get_agent_config calls _initialize_class_members_from_config internally; + # _run_optimization calls it again to reset history before the loop starts. + agent_config = await self._get_agent_config(self._agent_key, context) + + optimization_options = self._build_options_from_config( + config, options, api_client, optimization_id, run_id + ) + return await self._run_optimization(agent_config, optimization_options) + + def _build_options_from_config( + self, + config: AgentOptimizationConfig, + options: OptimizationFromConfigOptions, + api_client: LDApiClient, + optimization_id: str, + run_id: str, + ) -> OptimizationOptions: + """Map a fetched AgentOptimization config + user options into OptimizationOptions. + + Acceptance statements and judge configs from the API are merged into a single + judges dict. An on_status_update closure is injected to persist each iteration + result to the LaunchDarkly API; any user-supplied on_status_update is chained + after the persistence call. + + :param config: Validated AgentOptimizationConfig from the API. + :param options: User-provided options from optimize_from_config. 
+ :param api_client: Initialised LDApiClient for result persistence. + :param optimization_id: UUID id of the parent agent_optimization record. + :param run_id: UUID that groups all result records for this run. + :return: A fully populated OptimizationOptions ready for _run_optimization. + """ + judges: Dict[str, OptimizationJudge] = {} + + for i, stmt in enumerate(config["acceptanceStatements"]): + key = f"acceptance-statement-{i}" + judges[key] = OptimizationJudge( + threshold=float(stmt.get("threshold", 0.95)), + acceptance_statement=stmt["statement"], + ) + + for judge in config["judges"]: + judges[judge["key"]] = OptimizationJudge( + threshold=float(judge.get("threshold", 0.95)), + judge_key=judge["key"], + ) + + has_ground_truth = bool(config.get("groundTruthResponses")) + if not judges and not has_ground_truth and options.on_turn is None: + raise ValueError( + "The optimization config has no acceptance statements, judges, or ground truth " + "responses, and no on_turn callback was provided. At least one is required to " + "evaluate optimization results." + ) + + variable_choices: List[Dict[str, Any]] = config["variableChoices"] or [{}] + user_input_options: Optional[List[str]] = config["userInputOptions"] or None + + project_key = options.project_key + config_version: int = config["version"] + + def _persist_and_forward( + status: Literal[ + "init", + "generating", + "evaluating", + "generating variation", + "turn completed", + "success", + "failure", + ], + ctx: OptimizationContext, + ) -> None: + # _safe_status_update (the caller) already wraps this entire function in + # a try/except, so errors here are caught and logged without aborting the run. 
+ mapped = _OPTIMIZATION_STATUS_MAP.get( + status, {"status": "RUNNING", "activity": "PENDING"} + ) + snapshot = ctx.copy_without_history() + payload: OptimizationResultPayload = { + "run_id": run_id, + "config_optimization_version": config_version, + "status": mapped["status"], + "activity": mapped["activity"], + "iteration": snapshot.iteration, + "instructions": snapshot.current_instructions, + "parameters": snapshot.current_parameters, + "completion_response": snapshot.completion_response, + "scores": {k: v.to_json() for k, v in snapshot.scores.items()}, + "user_input": snapshot.user_input, + } + api_client.post_agent_optimization_result(project_key, optimization_id, payload) + + if options.on_status_update: + try: + options.on_status_update(status, ctx) + except Exception: + logger.exception("User on_status_update callback failed for status=%s", status) + + return OptimizationOptions( + context_choices=options.context_choices, + max_attempts=config["maxAttempts"], + model_choices=[_strip_provider_prefix(m) for m in config["modelChoices"]], + judge_model=_strip_provider_prefix(config["judgeModel"]), + variable_choices=variable_choices, + handle_agent_call=options.handle_agent_call, + handle_judge_call=options.handle_judge_call, + judges=judges or None, + user_input_options=user_input_options, + on_turn=options.on_turn, + on_passing_result=options.on_passing_result, + on_failing_result=options.on_failing_result, + on_status_update=_persist_and_forward, + ) async def _execute_agent_turn( self, diff --git a/packages/optimization/src/ldai_optimization/dataclasses.py b/packages/optimization/src/ldai_optimization/dataclasses.py index 944f7ec..2635053 100644 --- a/packages/optimization/src/ldai_optimization/dataclasses.py +++ b/packages/optimization/src/ldai_optimization/dataclasses.py @@ -204,6 +204,20 @@ class OptimizationJudgeContext: variables: Dict[str, Any] = field(default_factory=dict) # variable set used during agent generation +# Shared callback type aliases 
used by both OptimizationOptions and +# OptimizationFromConfigOptions to avoid duplicating the full signatures. +# Placed here so all referenced types (OptimizationContext, AIJudgeCallConfig, +# OptimizationJudgeContext) are already defined above. +HandleAgentCall = Union[ + Callable[[str, AIAgentConfig, OptimizationContext, Dict[str, Callable[..., Any]]], str], + Callable[[str, AIAgentConfig, OptimizationContext, Dict[str, Callable[..., Any]]], Awaitable[str]], +] +HandleJudgeCall = Union[ + Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext, Dict[str, Callable[..., Any]]], str], + Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext, Dict[str, Callable[..., Any]]], Awaitable[str]], +] + + @dataclass class OptimizationOptions: """Options for agent optimization.""" @@ -218,14 +232,8 @@ class OptimizationOptions: Dict[str, Any] ] # choices of interpolated variables to be chosen at random per turn, 1 min required # Actual agent/completion (judge) calls - Required - handle_agent_call: Union[ - Callable[[str, AIAgentConfig, OptimizationContext, Dict[str, Callable[..., Any]]], str], - Callable[[str, AIAgentConfig, OptimizationContext, Dict[str, Callable[..., Any]]], Awaitable[str]], - ] - handle_judge_call: Union[ - Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext, Dict[str, Callable[..., Any]]], str], - Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext, Dict[str, Callable[..., Any]]], Awaitable[str]], - ] + handle_agent_call: HandleAgentCall + handle_judge_call: HandleJudgeCall # Criteria for pass/fail - Optional user_input_options: Optional[List[str]] = ( None # optional list of user input messages to randomly select from @@ -270,3 +278,56 @@ def __post_init__(self): raise ValueError("Either judges or on_turn must be provided") if self.judge_model is None: raise ValueError("judge_model must be provided") + + +@dataclass +class OptimizationFromConfigOptions: + """User-provided options for optimize_from_config. 
+ + Fields that come from the LaunchDarkly API (max_attempts, model_choices, + judge_model, variable_choices, user_input_options, judges) are omitted here + and sourced from the fetched agent optimization config instead. + + :param project_key: LaunchDarkly project key used to build API paths. + :param context_choices: One or more LD evaluation contexts to use. + :param handle_agent_call: Callback that invokes the agent and returns its response. + :param handle_judge_call: Callback that invokes a judge and returns its response. + :param on_turn: Optional manual pass/fail callback; when provided, judge scoring is skipped. + :param on_passing_result: Called with the winning OptimizationContext on success. + :param on_failing_result: Called with the final OptimizationContext on failure. + :param on_status_update: Called on each status transition; chained after the + automatic result-persistence POST so it always runs after the record is saved. + :param base_url: Base URL of the LaunchDarkly instance. Defaults to + https://app.launchdarkly.com. Override to target a staging instance. 
+ """ + + project_key: str + context_choices: List[Context] + handle_agent_call: HandleAgentCall + handle_judge_call: HandleJudgeCall + on_turn: Optional[Callable[["OptimizationContext"], bool]] = None + on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None + on_failing_result: Optional[Callable[["OptimizationContext"], None]] = None + on_status_update: Optional[ + Callable[ + [ + Literal[ + "init", + "generating", + "evaluating", + "generating variation", + "turn completed", + "success", + "failure", + ], + "OptimizationContext", + ], + None, + ] + ] = None + base_url: Optional[str] = None + + def __post_init__(self): + """Validate required options.""" + if len(self.context_choices) < 1: + raise ValueError("context_choices must have at least 1 context") diff --git a/packages/optimization/src/ldai_optimization/ld_api_client.py b/packages/optimization/src/ldai_optimization/ld_api_client.py new file mode 100644 index 0000000..8a457cc --- /dev/null +++ b/packages/optimization/src/ldai_optimization/ld_api_client.py @@ -0,0 +1,252 @@ +"""Internal LaunchDarkly REST API client for the optimization package.""" + +import json +import logging +import urllib.error +import urllib.request +from typing import Any, Dict, List, Optional, TypedDict + +logger = logging.getLogger(__name__) + +_BASE_URL = "https://app.launchdarkly.com" + + +class LDApiError(Exception): + """Raised when the LaunchDarkly REST API returns an error or is unreachable. + + Attributes: + status_code: HTTP status code, or None for network-level failures. + path: The API path that was requested. 
+ """ + + def __init__(self, message: str, status_code: Optional[int] = None, path: str = "") -> None: + super().__init__(message) + self.status_code = status_code + self.path = path + + +_HTTP_ERROR_HINTS: Dict[int, str] = { + 401: "Authentication failed — check that LAUNCHDARKLY_API_KEY is set correctly.", + 403: "Authorization failed — check that your API key has the required permissions.", + 404: "Resource not found — check that the project key and optimization config key are correct.", + 429: "Rate limit exceeded — too many requests to the LaunchDarkly API.", +} + +_REQUIRED_STRING_FIELDS = ("id", "key", "aiConfigKey", "judgeModel") +_REQUIRED_INT_FIELDS = ("maxAttempts", "version", "createdAt") +_REQUIRED_LIST_FIELDS = ( + "modelChoices", + "variableChoices", + "acceptanceStatements", + "judges", + "userInputOptions", +) + + +# --------------------------------------------------------------------------- +# API response shapes +# --------------------------------------------------------------------------- + +class _AcceptanceStatement(TypedDict): + statement: str + threshold: float + + +class _AgentOptimizationJudge(TypedDict): + key: str + threshold: float + + +class _AgentOptimizationConfigRequired(TypedDict): + id: str + key: str + aiConfigKey: str + maxAttempts: int + modelChoices: List[str] + judgeModel: str + variableChoices: List[Dict[str, Any]] + acceptanceStatements: List[_AcceptanceStatement] + judges: List[_AgentOptimizationJudge] + userInputOptions: List[str] + version: int + createdAt: int + + +class AgentOptimizationConfig(_AgentOptimizationConfigRequired, total=False): + """Typed representation of the AgentOptimization API response.""" + + groundTruthResponses: List[str] + metricKey: str + + +# --------------------------------------------------------------------------- +# POST payload shape +# --------------------------------------------------------------------------- + +class _OptimizationResultPayloadRequired(TypedDict): + run_id: str + 
config_optimization_version: int + status: str + activity: str + iteration: int + instructions: str + parameters: Dict[str, Any] + completion_response: str + scores: Dict[str, Any] + + +class OptimizationResultPayload(_OptimizationResultPayloadRequired, total=False): + """Typed payload for a single agent_optimization_result POST request. + + Required fields are always sent. Optional fields are omitted when not + available. Fields that require separate tracking instrumentation + (variation, generation_tokens, evaluation_tokens, generation_latency, + evaluation_latencies) are deferred. + + created_variation_key is only present on the final result record of a + successful run, populated once a winning variation is committed to LD. + """ + + user_input: Optional[str] + created_variation_key: str + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + +def _parse_agent_optimization(data: Any) -> AgentOptimizationConfig: + """Validate and cast a raw API response dict to AgentOptimizationConfig. + + :param data: Parsed JSON response from the GET endpoint. + :return: The same dict narrowed to AgentOptimizationConfig. + :raises ValueError: If required fields are missing or have wrong types. 
+ """ + if not isinstance(data, dict): + raise ValueError( + f"Expected a JSON object from AgentOptimization API, got {type(data).__name__}" + ) + + errors: List[str] = [] + + for field in _REQUIRED_STRING_FIELDS: + if field not in data: + errors.append(f"missing required field '{field}'") + elif not isinstance(data[field], str): + errors.append( + f"field '{field}' must be a string, got {type(data[field]).__name__}" + ) + + for field in _REQUIRED_INT_FIELDS: + if field not in data: + errors.append(f"missing required field '{field}'") + elif not isinstance(data[field], int): + errors.append( + f"field '{field}' must be an integer, got {type(data[field]).__name__}" + ) + + for field in _REQUIRED_LIST_FIELDS: + if field not in data: + errors.append(f"missing required field '{field}'") + elif not isinstance(data[field], list): + errors.append( + f"field '{field}' must be a list, got {type(data[field]).__name__}" + ) + + if not errors and "modelChoices" in data and isinstance(data["modelChoices"], list): + if len(data["modelChoices"]) < 1: + errors.append("field 'modelChoices' must have at least 1 entry") + + if errors: + raise ValueError( + f"Invalid AgentOptimization response: {'; '.join(errors)}" + ) + + return data # type: ignore[return-value] + + +# --------------------------------------------------------------------------- +# Client +# --------------------------------------------------------------------------- + +class LDApiClient: + """Thin wrapper around the LaunchDarkly REST API for agent-optimization endpoints.""" + + def __init__(self, api_key: str, base_url: str = _BASE_URL) -> None: + self._api_key = api_key + self._base_url = base_url.rstrip("/") + + def _auth_headers(self) -> Dict[str, str]: + return {"Authorization": self._api_key} + + def _request(self, method: str, path: str, body: Any = None) -> Any: + url = f"{self._base_url}{path}" + data = json.dumps(body).encode() if body is not None else None + headers = self._auth_headers() + if data is not 
None: + headers["Content-Type"] = "application/json" + req = urllib.request.Request(url, data=data, headers=headers, method=method) + try: + with urllib.request.urlopen(req) as resp: + raw = resp.read() + return json.loads(raw) if raw else None + except urllib.error.HTTPError as exc: + body_excerpt = exc.read(500).decode(errors="replace") + hint = _HTTP_ERROR_HINTS.get(exc.code, "") + detail = f"{hint} (API response: {body_excerpt})" if hint else f"API response: {body_excerpt}" + raise LDApiError( + f"LaunchDarkly API error {exc.code} {exc.msg} for {method} {path}. {detail}", + status_code=exc.code, + path=path, + ) from exc + except urllib.error.URLError as exc: + raise LDApiError( + f"Could not reach LaunchDarkly API at {url}: {exc.reason}. " + "Check your network connection and the base_url setting.", + path=path, + ) from exc + + def get_agent_optimization( + self, project_key: str, optimization_key: str + ) -> AgentOptimizationConfig: + """Fetch and validate a single agent optimization config by key. + + :param project_key: LaunchDarkly project key. + :param optimization_key: Key of the agent optimization config. + :return: Validated AgentOptimizationConfig. + :raises LDApiError: On non-200 HTTP responses or network errors. + :raises ValueError: If the response is missing required fields. + """ + path = f"/api/v2/projects/{project_key}/agent-optimizations/{optimization_key}" + raw = self._request("GET", path) + return _parse_agent_optimization(raw) + + def post_agent_optimization_result( + self, project_key: str, optimization_id: str, payload: OptimizationResultPayload + ) -> None: + """Persist an iteration result record for the given optimization run. + + Errors are caught and logged rather than raised so that persistence + failures never abort an in-progress optimization run. + + :param project_key: LaunchDarkly project key. + :param optimization_id: UUID id of the parent agent_optimization record. + :param payload: Typed result payload for this iteration. 
+ """ + path = f"/api/v2/projects/{project_key}/agent-optimizations/{optimization_id}/results" + try: + self._request("POST", path, body=payload) + except LDApiError as exc: + logger.debug( + "Failed to persist optimization result (optimization_id=%s, iteration=%s): %s", + optimization_id, + payload.get("iteration"), + exc, + ) + except Exception as exc: + logger.debug( + "Unexpected error persisting optimization result (optimization_id=%s, iteration=%s): %s", + optimization_id, + payload.get("iteration"), + exc, + ) diff --git a/packages/optimization/tests/test_client.py b/packages/optimization/tests/test_client.py index 7ccb406..5ece830 100644 --- a/packages/optimization/tests/test_client.py +++ b/packages/optimization/tests/test_client.py @@ -14,6 +14,7 @@ AIJudgeCallConfig, JudgeResult, OptimizationContext, + OptimizationFromConfigOptions, OptimizationJudge, OptimizationJudgeContext, OptimizationOptions, @@ -1006,3 +1007,389 @@ def test_section_appears_in_full_prompt(self): ) assert "Facts only." 
in prompt assert "ACCEPTANCE CRITERIA" in prompt + + +# --------------------------------------------------------------------------- +# _build_options_from_config helpers +# --------------------------------------------------------------------------- + +_API_CONFIG: Dict[str, Any] = { + "id": "opt-uuid-123", + "key": "my-optimization", + "aiConfigKey": "my-agent", + "maxAttempts": 3, + "modelChoices": ["gpt-4o", "gpt-4o-mini"], + "judgeModel": "gpt-4o", + "variableChoices": [{"language": "English"}], + "acceptanceStatements": [{"statement": "Be accurate.", "threshold": 0.9}], + "judges": [], + "userInputOptions": ["What is 2+2?"], + "version": 2, + "createdAt": 1700000000, +} + + +def _make_from_config_options(**overrides: Any) -> OptimizationFromConfigOptions: + defaults: Dict[str, Any] = dict( + project_key="my-project", + context_choices=[LD_CONTEXT], + handle_agent_call=AsyncMock(return_value="The answer is 4."), + handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE), + ) + defaults.update(overrides) + return OptimizationFromConfigOptions(**defaults) + + +def _make_mock_api_client() -> MagicMock: + mock = MagicMock() + mock.post_agent_optimization_result = MagicMock() + return mock + + +# --------------------------------------------------------------------------- +# _build_options_from_config +# --------------------------------------------------------------------------- + + +class TestBuildOptionsFromConfig: + def setup_method(self): + self.client = _make_client() + self.client._agent_key = "my-agent" + self.client._initialize_class_members_from_config(_make_agent_config()) + self.client._options = _make_options() + self.api_client = _make_mock_api_client() + + def _build(self, config=None, options=None) -> OptimizationOptions: + return self.client._build_options_from_config( + config or dict(_API_CONFIG), + options or _make_from_config_options(), + self.api_client, + optimization_id="opt-uuid-123", + run_id="run-uuid-456", + ) + + def 
test_acceptance_statements_mapped_to_judges(self): + result = self._build() + assert "acceptance-statement-0" in result.judges + judge = result.judges["acceptance-statement-0"] + assert judge.acceptance_statement == "Be accurate." + assert judge.threshold == 0.9 + + def test_multiple_acceptance_statements_get_indexed_keys(self): + config = dict(_API_CONFIG, acceptanceStatements=[ + {"statement": "First.", "threshold": 0.8}, + {"statement": "Second.", "threshold": 0.7}, + ]) + result = self._build(config=config) + assert "acceptance-statement-0" in result.judges + assert "acceptance-statement-1" in result.judges + assert result.judges["acceptance-statement-0"].acceptance_statement == "First." + assert result.judges["acceptance-statement-1"].acceptance_statement == "Second." + + def test_judges_mapped_by_key(self): + config = dict(_API_CONFIG, acceptanceStatements=[], judges=[ + {"key": "accuracy", "threshold": 0.85}, + ]) + result = self._build(config=config) + assert "accuracy" in result.judges + judge = result.judges["accuracy"] + assert judge.judge_key == "accuracy" + assert judge.threshold == 0.85 + + def test_acceptance_statements_and_judges_merged(self): + config = dict(_API_CONFIG, + acceptanceStatements=[{"statement": "Be brief.", "threshold": 0.8}], + judges=[{"key": "accuracy", "threshold": 0.9}], + ) + result = self._build(config=config) + assert "acceptance-statement-0" in result.judges + assert "accuracy" in result.judges + + def test_raises_when_no_judges_no_ground_truth_no_on_turn(self): + config = dict(_API_CONFIG, acceptanceStatements=[], judges=[]) + with pytest.raises(ValueError, match="no acceptance statements, judges, or ground truth"): + self._build(config=config) + + def test_ground_truth_responses_alone_does_not_pass_no_criteria_check(self): + # groundTruthResponses is not yet implemented as standalone criteria; + # OptimizationOptions still requires judges or on_turn. 
+ config = dict(_API_CONFIG, acceptanceStatements=[], judges=[], groundTruthResponses=["4"]) + with pytest.raises((ValueError, Exception)): + self._build(config=config) + + def test_on_turn_satisfies_no_judges_requirement(self): + config = dict(_API_CONFIG, acceptanceStatements=[], judges=[]) + options = _make_from_config_options(on_turn=lambda ctx: True) + result = self._build(config=config, options=options) + assert result.on_turn is not None + + def test_empty_variable_choices_defaults_to_single_empty_dict(self): + config = dict(_API_CONFIG, variableChoices=[]) + result = self._build(config=config) + assert result.variable_choices == [{}] + + def test_non_empty_variable_choices_passed_through(self): + result = self._build() + assert result.variable_choices == [{"language": "English"}] + + def test_empty_user_input_options_becomes_none(self): + config = dict(_API_CONFIG, userInputOptions=[]) + result = self._build(config=config) + assert result.user_input_options is None + + def test_non_empty_user_input_options_passed_through(self): + result = self._build() + assert result.user_input_options == ["What is 2+2?"] + + def test_max_attempts_from_config(self): + result = self._build() + assert result.max_attempts == 3 + + def test_model_choices_provider_prefix_stripped(self): + config = dict(_API_CONFIG, modelChoices=["OpenAI.gpt-4o", "Anthropic.claude-opus-4-5"]) + result = self._build(config=config) + assert result.model_choices == ["gpt-4o", "claude-opus-4-5"] + + def test_judge_model_provider_prefix_stripped(self): + config = dict(_API_CONFIG, judgeModel="OpenAI.gpt-4o") + result = self._build(config=config) + assert result.judge_model == "gpt-4o" + + def test_model_choices_without_prefix_unchanged(self): + result = self._build() + assert result.model_choices == ["gpt-4o", "gpt-4o-mini"] + + def test_judge_model_without_prefix_unchanged(self): + result = self._build() + assert result.judge_model == "gpt-4o" + + def 
test_model_with_multiple_dots_only_prefix_stripped(self): + config = dict(_API_CONFIG, judgeModel="Anthropic.claude-opus-4.6") + result = self._build(config=config) + assert result.judge_model == "claude-opus-4.6" + + def test_callbacks_forwarded_from_options(self): + handle_agent = AsyncMock(return_value="ok") + handle_judge = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + options = _make_from_config_options( + handle_agent_call=handle_agent, + handle_judge_call=handle_judge, + on_passing_result=MagicMock(), + on_failing_result=MagicMock(), + ) + result = self._build(options=options) + assert result.handle_agent_call is handle_agent + assert result.handle_judge_call is handle_judge + assert result.on_passing_result is options.on_passing_result + assert result.on_failing_result is options.on_failing_result + + def test_persist_and_forward_posts_result_on_status_update(self): + result = self._build() + ctx = OptimizationContext( + scores={}, + completion_response="The answer is 4.", + current_instructions="Be helpful.", + current_parameters={"temperature": 0.7}, + current_variables={"language": "English"}, + current_model="gpt-4o", + user_input="What is 2+2?", + iteration=1, + ) + result.on_status_update("generating", ctx) + self.api_client.post_agent_optimization_result.assert_called_once() + call_args = self.api_client.post_agent_optimization_result.call_args + assert call_args[0][0] == "my-project" + assert call_args[0][1] == "opt-uuid-123" + + def test_persist_and_forward_payload_has_correct_field_names(self): + result = self._build() + ctx = OptimizationContext( + scores={"j": JudgeResult(score=0.9, rationale="Good.")}, + completion_response="Paris.", + current_instructions="Be helpful.", + current_parameters={"temperature": 0.5}, + current_variables={}, + current_model="gpt-4o", + user_input="Capital of France?", + iteration=2, + ) + result.on_status_update("evaluating", ctx) + payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert 
payload["instructions"] == "Be helpful." + assert payload["parameters"] == {"temperature": 0.5} + assert payload["completion_response"] == "Paris." + assert payload["user_input"] == "Capital of France?" + assert payload["iteration"] == 2 + assert "j" in payload["scores"] + + def test_persist_and_forward_includes_run_id_and_version(self): + result = self._build() + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update("generating", ctx) + payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert payload["run_id"] == "run-uuid-456" + assert payload["config_optimization_version"] == 2 + + @pytest.mark.parametrize("sdk_status,expected_status,expected_activity", [ + ("init", "RUNNING", "PENDING"), + ("generating", "RUNNING", "GENERATING"), + ("evaluating", "RUNNING", "EVALUATING"), + ("generating variation", "RUNNING", "GENERATING_VARIATION"), + ("turn completed", "RUNNING", "COMPLETED"), + ("success", "PASSED", "COMPLETED"), + ("failure", "FAILED", "COMPLETED"), + ]) + def test_status_mapping(self, sdk_status, expected_status, expected_activity): + result = self._build() + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update(sdk_status, ctx) + payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert payload["status"] == expected_status + assert payload["activity"] == expected_activity + + def test_user_on_status_update_chained_after_post(self): + call_order = [] + self.api_client.post_agent_optimization_result.side_effect = ( + lambda *a, **kw: call_order.append("post") + ) + user_cb = MagicMock(side_effect=lambda s, c: call_order.append("user")) + options = _make_from_config_options(on_status_update=user_cb) + result = self._build(options=options) + ctx = 
OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update("generating", ctx) + assert call_order == ["post", "user"] + + def test_user_on_status_update_exception_does_not_propagate(self): + options = _make_from_config_options( + on_status_update=MagicMock(side_effect=RuntimeError("cb boom")) + ) + result = self._build(options=options) + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update("generating", ctx) # must not raise + + def test_payload_history_not_included(self): + result = self._build() + ctx = OptimizationContext( + scores={}, completion_response="", current_instructions="", + current_parameters={}, current_variables={}, iteration=1, + ) + result.on_status_update("generating", ctx) + payload = self.api_client.post_agent_optimization_result.call_args[0][2] + assert "history" not in payload + + +# --------------------------------------------------------------------------- +# optimize_from_config +# --------------------------------------------------------------------------- + + +class TestOptimizeFromConfig: + def setup_method(self): + self.mock_ldai = _make_ldai_client() + + def _make_client_with_key(self) -> OptimizationClient: + with patch.dict("os.environ", {"LAUNCHDARKLY_API_KEY": "test-api-key"}): + return _make_client(self.mock_ldai) + + def _make_client_without_key(self) -> OptimizationClient: + with patch.dict("os.environ", {}, clear=True): + import os + os.environ.pop("LAUNCHDARKLY_API_KEY", None) + client = OptimizationClient(self.mock_ldai) + client._has_api_key = False + client._api_key = None + return client + + async def test_raises_without_api_key(self): + client = self._make_client_without_key() + options = _make_from_config_options() + with pytest.raises(ValueError, match="LAUNCHDARKLY_API_KEY is not 
set"): + await client.optimize_from_config("my-opt", options) + + async def test_fetches_config_and_uses_ai_config_key(self): + client = self._make_client_with_key() + mock_api = _make_mock_api_client() + mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG)) + + with patch("ldai_optimization.client.LDApiClient", return_value=mock_api): + options = _make_from_config_options() + await client.optimize_from_config("my-opt", options) + + mock_api.get_agent_optimization.assert_called_once_with("my-project", "my-opt") + assert client._agent_key == "my-agent" + + async def test_posts_result_on_each_status_event(self): + client = self._make_client_with_key() + mock_api = _make_mock_api_client() + mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG)) + + with patch("ldai_optimization.client.LDApiClient", return_value=mock_api): + options = _make_from_config_options() + await client.optimize_from_config("my-opt", options) + + assert mock_api.post_agent_optimization_result.call_count >= 1 + + async def test_user_on_status_update_called_during_run(self): + client = self._make_client_with_key() + mock_api = _make_mock_api_client() + mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG)) + statuses = [] + + with patch("ldai_optimization.client.LDApiClient", return_value=mock_api): + options = _make_from_config_options( + on_status_update=lambda status, ctx: statuses.append(status) + ) + await client.optimize_from_config("my-opt", options) + + assert "generating" in statuses + assert "success" in statuses + + async def test_custom_base_url_passed_to_api_client(self): + client = self._make_client_with_key() + + with patch("ldai_optimization.client.LDApiClient") as MockLDApiClient: + instance = _make_mock_api_client() + instance.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG)) + MockLDApiClient.return_value = instance + options = _make_from_config_options(base_url="https://staging.launchdarkly.com") + 
await client.optimize_from_config("my-opt", options) + + MockLDApiClient.assert_called_once_with( + "test-api-key", base_url="https://staging.launchdarkly.com" + ) + + async def test_no_base_url_does_not_pass_kwarg(self): + client = self._make_client_with_key() + + with patch("ldai_optimization.client.LDApiClient") as MockLDApiClient: + instance = _make_mock_api_client() + instance.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG)) + MockLDApiClient.return_value = instance + options = _make_from_config_options() + await client.optimize_from_config("my-opt", options) + + MockLDApiClient.assert_called_once_with("test-api-key") + + async def test_returns_optimization_context_on_success(self): + client = self._make_client_with_key() + mock_api = _make_mock_api_client() + mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG)) + + with patch("ldai_optimization.client.LDApiClient", return_value=mock_api): + options = _make_from_config_options() + result = await client.optimize_from_config("my-opt", options) + + assert isinstance(result, OptimizationContext) + assert result.completion_response == "The answer is 4." 
"""Tests for ldai_optimization.ld_api_client."""

import json
import urllib.error
import urllib.request
from io import BytesIO
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import pytest

from ldai_optimization.ld_api_client import (
    AgentOptimizationConfig,
    LDApiClient,
    LDApiError,
    OptimizationResultPayload,
    _parse_agent_optimization,
)

# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

# A fully-valid AgentOptimization API response; tests copy and mutate it.
_BASE_CONFIG: Dict[str, Any] = {
    "id": "opt-uuid-123",
    "key": "my-optimization",
    "aiConfigKey": "my-agent",
    "maxAttempts": 3,
    "modelChoices": ["gpt-4o", "gpt-4o-mini"],
    "judgeModel": "gpt-4o",
    "variableChoices": [{"language": "English"}],
    "acceptanceStatements": [{"statement": "Be accurate.", "threshold": 0.9}],
    "judges": [],
    "userInputOptions": ["What is 2+2?"],
    "version": 1,
    "createdAt": 1700000000,
}


def _make_config(**overrides: Any) -> Dict[str, Any]:
    """Return a fresh copy of _BASE_CONFIG with *overrides* applied."""
    return {**_BASE_CONFIG, **overrides}


def _mock_urlopen(response_data: Any, status: int = 200) -> MagicMock:
    """Return a context-manager mock whose .read() returns JSON-encoded response_data."""
    mock_resp = MagicMock()
    mock_resp.read.return_value = json.dumps(response_data).encode()
    mock_resp.__enter__ = lambda s: s
    mock_resp.__exit__ = MagicMock(return_value=False)
    return mock_resp


# ---------------------------------------------------------------------------
# _parse_agent_optimization
# ---------------------------------------------------------------------------


class TestParseAgentOptimization:
    def test_valid_config_is_returned_unchanged(self):
        config = _make_config()
        result = _parse_agent_optimization(config)
        assert result["id"] == "opt-uuid-123"
        assert result["aiConfigKey"] == "my-agent"

    def test_optional_fields_not_required(self):
        config = _make_config()
        # groundTruthResponses and metricKey are optional — should not raise
        assert "groundTruthResponses" not in config
        assert "metricKey" not in config
        _parse_agent_optimization(config)  # must not raise

    def test_raises_on_non_dict_input(self):
        with pytest.raises(ValueError, match="Expected a JSON object"):
            _parse_agent_optimization(["not", "a", "dict"])

    def test_raises_on_none_input(self):
        with pytest.raises(ValueError, match="Expected a JSON object"):
            _parse_agent_optimization(None)

    @pytest.mark.parametrize("field", ["id", "key", "aiConfigKey", "judgeModel"])
    def test_raises_on_missing_required_string_field(self, field: str):
        config = _make_config()
        del config[field]
        with pytest.raises(ValueError, match=f"missing required field '{field}'"):
            _parse_agent_optimization(config)

    @pytest.mark.parametrize("field", ["maxAttempts", "version", "createdAt"])
    def test_raises_on_missing_required_int_field(self, field: str):
        config = _make_config()
        del config[field]
        with pytest.raises(ValueError, match=f"missing required field '{field}'"):
            _parse_agent_optimization(config)

    @pytest.mark.parametrize(
        "field",
        ["modelChoices", "variableChoices", "acceptanceStatements", "judges", "userInputOptions"],
    )
    def test_raises_on_missing_required_list_field(self, field: str):
        config = _make_config()
        del config[field]
        with pytest.raises(ValueError, match=f"missing required field '{field}'"):
            _parse_agent_optimization(config)

    def test_raises_on_wrong_type_for_string_field(self):
        config = _make_config(aiConfigKey=123)
        with pytest.raises(ValueError, match="field 'aiConfigKey' must be a string"):
            _parse_agent_optimization(config)

    def test_raises_on_wrong_type_for_int_field(self):
        config = _make_config(maxAttempts="three")
        with pytest.raises(ValueError, match="field 'maxAttempts' must be an integer"):
            _parse_agent_optimization(config)

    def test_raises_on_wrong_type_for_list_field(self):
        config = _make_config(modelChoices="gpt-4o")
        with pytest.raises(ValueError, match="field 'modelChoices' must be a list"):
            _parse_agent_optimization(config)

    def test_raises_when_model_choices_is_empty(self):
        config = _make_config(modelChoices=[])
        with pytest.raises(ValueError, match="at least 1 entry"):
            _parse_agent_optimization(config)

    def test_collects_multiple_errors_in_one_raise(self):
        # The validator should aggregate all problems into one ValueError.
        config = _make_config()
        del config["id"]
        del config["maxAttempts"]
        config["modelChoices"] = "bad"
        with pytest.raises(ValueError) as exc_info:
            _parse_agent_optimization(config)
        msg = str(exc_info.value)
        assert "id" in msg
        assert "maxAttempts" in msg
        assert "modelChoices" in msg


# ---------------------------------------------------------------------------
# LDApiClient._request
# ---------------------------------------------------------------------------


class TestLDApiClientRequest:
    def test_get_does_not_send_content_type(self):
        client = LDApiClient("test-key")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen({})) as mock_open:
            client._request("GET", "/some/path")
        req: urllib.request.Request = mock_open.call_args[0][0]
        # FIX: urllib.request.Request stores header names via str.capitalize(),
        # so the key is "Content-type", never "Content-Type"; the previous
        # check `"Content-Type" not in req.headers` was vacuously true.
        # get_header() performs the normalized lookup for us.
        assert req.get_header("Content-type") is None

    def test_post_sends_content_type(self):
        client = LDApiClient("test-key")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen({})) as mock_open:
            client._request("POST", "/some/path", body={"key": "value"})
        req: urllib.request.Request = mock_open.call_args[0][0]
        assert req.get_header("Content-type") == "application/json"

    def test_authorization_header_always_sent(self):
        client = LDApiClient("my-api-key")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen({})) as mock_open:
            client._request("GET", "/path")
        req: urllib.request.Request = mock_open.call_args[0][0]
        assert req.get_header("Authorization") == "my-api-key"

    def test_raises_ld_api_error_on_http_error(self):
        client = LDApiClient("test-key")
        http_error = urllib.error.HTTPError(
            url="http://x", code=404, msg="Not Found", hdrs=MagicMock(), fp=BytesIO(b"not found body")
        )
        with patch("urllib.request.urlopen", side_effect=http_error):
            with pytest.raises(LDApiError) as exc_info:
                client._request("GET", "/missing")
        assert exc_info.value.status_code == 404
        assert "404" in str(exc_info.value)

    def test_raises_ld_api_error_on_url_error(self):
        client = LDApiClient("test-key")
        url_error = urllib.error.URLError(reason="Connection refused")
        with patch("urllib.request.urlopen", side_effect=url_error):
            with pytest.raises(LDApiError) as exc_info:
                client._request("GET", "/path")
        # Network-level failures carry no HTTP status.
        assert exc_info.value.status_code is None
        assert "Connection refused" in str(exc_info.value)

    def test_401_error_includes_api_key_hint(self):
        client = LDApiClient("test-key")
        http_error = urllib.error.HTTPError(
            url="http://x", code=401, msg="Unauthorized", hdrs=MagicMock(), fp=BytesIO(b"")
        )
        with patch("urllib.request.urlopen", side_effect=http_error):
            with pytest.raises(LDApiError, match="LAUNCHDARKLY_API_KEY"):
                client._request("GET", "/path")

    def test_404_error_includes_key_hint(self):
        client = LDApiClient("test-key")
        http_error = urllib.error.HTTPError(
            url="http://x", code=404, msg="Not Found", hdrs=MagicMock(), fp=BytesIO(b"")
        )
        with patch("urllib.request.urlopen", side_effect=http_error):
            with pytest.raises(LDApiError, match="project key"):
                client._request("GET", "/path")

    def test_custom_base_url_used_in_request(self):
        client = LDApiClient("test-key", base_url="https://staging.launchdarkly.com")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen({})) as mock_open:
            client._request("GET", "/api/v2/test")
        req: urllib.request.Request = mock_open.call_args[0][0]
        assert req.full_url.startswith("https://staging.launchdarkly.com")

    def test_trailing_slash_stripped_from_base_url(self):
        client = LDApiClient("test-key", base_url="https://app.launchdarkly.com/")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen({})) as mock_open:
            client._request("GET", "/api/v2/test")
        req: urllib.request.Request = mock_open.call_args[0][0]
        # Strip the scheme's "//" before checking for an accidental double slash.
        assert "//" not in req.full_url.replace("https://", "")


# ---------------------------------------------------------------------------
# LDApiClient.get_agent_optimization
# ---------------------------------------------------------------------------


class TestGetAgentOptimization:
    def test_requests_correct_path(self):
        client = LDApiClient("test-key")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen(_make_config())) as mock_open:
            client.get_agent_optimization("my-project", "my-opt-key")
        req: urllib.request.Request = mock_open.call_args[0][0]
        assert "/api/v2/projects/my-project/agent-optimizations/my-opt-key" in req.full_url

    def test_returns_validated_config(self):
        client = LDApiClient("test-key")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen(_make_config())):
            result = client.get_agent_optimization("proj", "opt")
        assert result["aiConfigKey"] == "my-agent"
        assert result["maxAttempts"] == 3

    def test_raises_on_invalid_response(self):
        client = LDApiClient("test-key")
        bad_response = {"id": "x"}  # missing many required fields
        with patch("urllib.request.urlopen", return_value=_mock_urlopen(bad_response)):
            with pytest.raises(ValueError, match="Invalid AgentOptimization response"):
                client.get_agent_optimization("proj", "opt")

    def test_raises_ld_api_error_on_http_404(self):
        client = LDApiClient("test-key")
        http_error = urllib.error.HTTPError(
            url="http://x", code=404, msg="Not Found", hdrs=MagicMock(), fp=BytesIO(b"not found")
        )
        with patch("urllib.request.urlopen", side_effect=http_error):
            with pytest.raises(LDApiError) as exc_info:
                client.get_agent_optimization("proj", "missing-key")
        assert exc_info.value.status_code == 404


# ---------------------------------------------------------------------------
# LDApiClient.post_agent_optimization_result
# ---------------------------------------------------------------------------


class TestPostAgentOptimizationResult:
    def _make_payload(self) -> OptimizationResultPayload:
        """Return a minimal valid result payload."""
        return {
            "run_id": "run-abc",
            "config_optimization_version": 1,
            "status": "RUNNING",
            "activity": "GENERATING",
            "iteration": 1,
            "instructions": "You are a helpful assistant.",
            "parameters": {"temperature": 0.7},
            "completion_response": "The answer is 4.",
            "scores": {},
        }

    def test_requests_correct_path(self):
        client = LDApiClient("test-key")
        with patch("urllib.request.urlopen", return_value=_mock_urlopen({})) as mock_open:
            client.post_agent_optimization_result("my-project", "opt-uuid", self._make_payload())
        req: urllib.request.Request = mock_open.call_args[0][0]
        assert "/api/v2/projects/my-project/agent-optimizations/opt-uuid/results" in req.full_url

    def test_sends_payload_as_json_body(self):
        client = LDApiClient("test-key")
        payload = self._make_payload()
        with patch("urllib.request.urlopen", return_value=_mock_urlopen({})) as mock_open:
            client.post_agent_optimization_result("proj", "opt-id", payload)
        req: urllib.request.Request = mock_open.call_args[0][0]
        sent = json.loads(req.data.decode())
        assert sent["run_id"] == "run-abc"
        assert sent["status"] == "RUNNING"
        assert sent["instructions"] == "You are a helpful assistant."

    def test_swallows_http_errors_without_raising(self):
        # Result posting is best-effort: failures must never abort the run.
        client = LDApiClient("test-key")
        http_error = urllib.error.HTTPError(
            url="http://x", code=500, msg="Server Error", hdrs=MagicMock(), fp=BytesIO(b"err")
        )
        with patch("urllib.request.urlopen", side_effect=http_error):
            # must not raise
            client.post_agent_optimization_result("proj", "opt-id", self._make_payload())

    def test_swallows_url_errors_without_raising(self):
        client = LDApiClient("test-key")
        with patch("urllib.request.urlopen", side_effect=urllib.error.URLError("timeout")):
            client.post_agent_optimization_result("proj", "opt-id", self._make_payload())