From 44c8c59434d9f9f38251e636ec091cb7f91abaff Mon Sep 17 00:00:00 2001 From: Andrew Klatzke Date: Fri, 3 Apr 2026 14:06:20 -0800 Subject: [PATCH] feat: ground truth optimization path --- .../src/ldai_optimization/__init__.py | 4 + .../src/ldai_optimization/client.py | 369 +++++++++++++++- .../src/ldai_optimization/dataclasses.py | 120 +++-- packages/optimization/tests/test_client.py | 413 ++++++++++++++++++ 4 files changed, 860 insertions(+), 46 deletions(-) diff --git a/packages/optimization/src/ldai_optimization/__init__.py b/packages/optimization/src/ldai_optimization/__init__.py index 5ee175f..87401b3 100644 --- a/packages/optimization/src/ldai_optimization/__init__.py +++ b/packages/optimization/src/ldai_optimization/__init__.py @@ -6,6 +6,8 @@ from ldai_optimization.client import OptimizationClient from ldai_optimization.dataclasses import ( AIJudgeCallConfig, + GroundTruthOptimizationOptions, + GroundTruthSample, OptimizationContext, OptimizationFromConfigOptions, OptimizationJudge, @@ -20,6 +22,8 @@ __all__ = [ '__version__', 'AIJudgeCallConfig', + 'GroundTruthOptimizationOptions', + 'GroundTruthSample', 'LDApiError', 'OptimizationClient', 'OptimizationContext', diff --git a/packages/optimization/src/ldai_optimization/client.py b/packages/optimization/src/ldai_optimization/client.py index ff3dde9..8f0b287 100644 --- a/packages/optimization/src/ldai_optimization/client.py +++ b/packages/optimization/src/ldai_optimization/client.py @@ -6,7 +6,7 @@ import os import random import uuid -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from ldai import AIAgentConfig, AIJudgeConfig, AIJudgeConfigDefault, LDAIClient from ldai.models import LDMessage, ModelConfig @@ -15,6 +15,8 @@ from ldai_optimization.dataclasses import ( AIJudgeCallConfig, AutoCommitConfig, + GroundTruthOptimizationOptions, + GroundTruthSample, JudgeResult, OptimizationContext, OptimizationFromConfigOptions, @@ -99,7 +101,12 @@ 
def __init__(self, ldClient: LDAIClient) -> None: def _initialize_class_members_from_config( self, agent_config: AIAgentConfig ) -> None: - self._current_instructions = agent_config.instructions or "" + if not agent_config.instructions: + raise ValueError( + f"Agent '{agent_config.key}' has no instructions configured. " + "Ensure the agent flag has instructions set before running an optimization." + ) + self._current_instructions = agent_config.instructions self._current_parameters: Dict[str, Any] = ( agent_config.model._parameters if agent_config.model else None ) or {} @@ -193,7 +200,7 @@ def _safe_status_update( if self._options.on_status_update: try: self._options.on_status_update(status, context.copy_without_history()) - except Exception as e: + except Exception: logger.exception( "[Iteration %d] -> on_status_update callback failed", iteration ) @@ -333,6 +340,7 @@ async def _call_judges( user_input: str, variables: Optional[Dict[str, Any]] = None, agent_tools: Optional[List[ToolDefinition]] = None, + expected_response: Optional[str] = None, ) -> Dict[str, JudgeResult]: """ Call all judges in parallel (auto-path). @@ -347,6 +355,8 @@ async def _call_judges( self._history when judges run) :param variables: The variable set that was used during the agent generation :param agent_tools: Normalised list of tool dicts that were available to the agent + :param expected_response: Optional ground truth expected response. When provided, + judges are instructed to factor it into their scoring alongside acceptance criteria. 
:return: Dictionary of judge results (score and rationale) """ if not self._options.judges: @@ -385,6 +395,7 @@ async def _call_judges( user_input=user_input, variables=resolved_variables, agent_tools=resolved_agent_tools, + expected_response=expected_response, ) judge_results[judge_key] = result else: @@ -397,6 +408,7 @@ async def _call_judges( user_input=user_input, variables=resolved_variables, agent_tools=resolved_agent_tools, + expected_response=expected_response, ) judge_results[judge_key] = result @@ -415,7 +427,7 @@ async def _call_judges( "PASSED" if passed else "FAILED", f" | {result.rationale}" if result.rationale else "", ) - except Exception as e: + except Exception: logger.exception( "[Iteration %d] -> Judge %s evaluation failed", iteration, judge_key ) @@ -439,6 +451,7 @@ async def _evaluate_config_judge( user_input: str, variables: Optional[Dict[str, Any]] = None, agent_tools: Optional[List[ToolDefinition]] = None, + expected_response: Optional[str] = None, ) -> JudgeResult: """ Evaluate using a config-type judge (with judge_key). @@ -451,6 +464,8 @@ async def _evaluate_config_judge( :param user_input: The user's question for this turn :param variables: The variable set that was used during agent generation :param agent_tools: Normalised list of tool dicts that were available to the agent + :param expected_response: Optional ground truth expected response. When provided, + injected into template variables and judge messages. 
:return: The judge result with score and rationale """ # Config-type judge: fetch judge config on-demand from LaunchDarkly SDK @@ -467,6 +482,8 @@ async def _evaluate_config_judge( "message_history": message_history_text, "response_to_evaluate": completion_response, } + if expected_response is not None: + template_variables["expected_response"] = expected_response assert optimization_judge.judge_key is not None judge_config = self._judge_config( @@ -514,6 +531,13 @@ async def _evaluate_config_judge( else f"Here is the response to evaluate: {completion_response}" ) + if expected_response is not None: + judge_user_input += ( + f"\n\nHere is the expected response: {expected_response}" + "\n\nEvaluate the actual response against both the acceptance criteria AND " + "how closely it matches the expected response. Factor both into your score." + ) + # Rebuild the message list with the updated system content so completions users # receive the same scoring instructions that are baked into `instructions`. updated_messages: List[LDMessage] = [ @@ -600,6 +624,7 @@ async def _evaluate_acceptance_judge( user_input: str, variables: Optional[Dict[str, Any]] = None, agent_tools: Optional[List[ToolDefinition]] = None, + expected_response: Optional[str] = None, ) -> JudgeResult: """ Evaluate using an acceptance statement judge. @@ -612,6 +637,8 @@ async def _evaluate_acceptance_judge( :param user_input: The user's question for this turn :param variables: The variable set that was used during agent generation :param agent_tools: Normalised list of tool dicts that were available to the agent + :param expected_response: Optional ground truth expected response. When provided, + injected into instructions and judge message so the judge can score actual vs. expected. 
:return: The judge result with score and rationale """ if not optimization_judge.acceptance_statement: @@ -668,6 +695,12 @@ async def _evaluate_acceptance_judge( ] judge_user_input = f"Here is the response to evaluate: {completion_response}" + if expected_response is not None: + judge_user_input += ( + f"\n\nHere is the expected response: {expected_response}" + "\n\nEvaluate the actual response against both the acceptance statement AND " + "how closely it matches the expected response. Factor both into your score." + ) judge_call_config = AIJudgeCallConfig( key=judge_key, @@ -728,6 +761,11 @@ async def _get_agent_config( raw_instructions = raw_variation.get( "instructions", agent_config.instructions ) + if not raw_instructions: + raise ValueError( + f"Agent '{agent_key}' has no instructions configured. " + "Ensure the agent flag has instructions set before running an optimization." + ) self._initial_instructions = raw_instructions agent_config = dataclasses.replace( @@ -753,6 +791,233 @@ async def optimize_from_options( agent_config = await self._get_agent_config(agent_key, context) return await self._run_optimization(agent_config, options) + async def optimize_from_ground_truth_options( + self, agent_key: str, options: GroundTruthOptimizationOptions + ) -> List[OptimizationContext]: + """Execute a ground truth optimization on the given agent. + + Unlike optimize_from_options (which tests random choices until one passes), + this path evaluates all N ground truth samples in each attempt and only + succeeds when every sample passes its judges. A new variation is generated + whenever any sample fails, and all N samples are re-evaluated from scratch + with the updated configuration, up to max_attempts. + + :param agent_key: Identifier of the agent to optimize. + :param options: Ground truth optimization options including the ordered sample list. + :return: List of OptimizationContexts from the final attempt (one per sample). 
+ """ + self._agent_key = agent_key + context = random.choice(options.context_choices) + agent_config = await self._get_agent_config(agent_key, context) + return await self._run_ground_truth_optimization(agent_config, options) + + async def _run_ground_truth_optimization( + self, + agent_config: AIAgentConfig, + gt_options: GroundTruthOptimizationOptions, + ) -> List[OptimizationContext]: + """Run the ground truth optimization loop. + + Uses the "bridge" pattern to reuse existing internal methods (judge evaluation, + variation generation, status callbacks) for the ground truth optimization. + + :param agent_config: Agent configuration from LaunchDarkly. + :param gt_options: Ground truth options supplied by the caller. + :return: List of OptimizationContexts from the final attempt (one per sample). + """ + bridge = OptimizationOptions( + context_choices=gt_options.context_choices, + max_attempts=gt_options.max_attempts, + model_choices=gt_options.model_choices, + judge_model=gt_options.judge_model, + variable_choices=[s.variables for s in gt_options.ground_truth_responses], + handle_agent_call=gt_options.handle_agent_call, + handle_judge_call=gt_options.handle_judge_call, + judges=gt_options.judges, + on_turn=gt_options.on_turn, + on_passing_result=gt_options.on_passing_result, + on_failing_result=gt_options.on_failing_result, + on_status_update=gt_options.on_status_update, + ) + self._options = bridge + self._agent_config = agent_config + self._initialize_class_members_from_config(agent_config) + + # Seed from the first model choice on the first iteration + # so agent calls never receive an empty model string. 
+ if not self._current_model and bridge.model_choices: + self._current_model = bridge.model_choices[0] + logger.debug( + "[GT] -> No model in agent config; defaulting to first model choice: %s", + self._current_model, + ) + + samples = gt_options.ground_truth_responses + n = len(samples) + + initial_context = self._create_optimization_context( + iteration=0, + variables=samples[0].variables, + ) + self._safe_status_update("init", initial_context, 0) + + # Attempt tracks the current "batch" loop that runs + # through all N samples. Iteration in this context refers to the + # total number of batch runs so far. + attempt = 0 + while True: + attempt += 1 + logger.info( + "[GT Attempt %d/%d] -> Starting ground truth run (%d samples, model=%s)", + attempt, + gt_options.max_attempts, + n, + self._current_model, + ) + + attempt_results: List[OptimizationContext] = [] + all_passed = True + failed_count = 0 + + # Now iterate through each individual sample in the batch, + # creating a new context for each sample + running judges etc. + for i, sample in enumerate(samples): + linear_iter = (attempt - 1) * n + i + 1 + truncated = len(sample.user_input) > 100 + logger.info( + "[GT Attempt %d] -> Sample %d/%d (user_input=%.100s%s)", + attempt, + i + 1, + n, + sample.user_input, + "..." 
if truncated else "", + ) + + optimize_context = self._create_optimization_context( + iteration=linear_iter, + user_input=sample.user_input, + variables=sample.variables, + ) + + self._safe_status_update("generating", optimize_context, linear_iter) + optimize_context = await self._execute_agent_turn( + optimize_context, + linear_iter, + expected_response=sample.expected_response, + ) + + # Per-sample pass/fail check + if self._options.on_turn is not None: + try: + sample_passed = self._options.on_turn(optimize_context) + except Exception: + logger.exception( + "[GT Attempt %d] -> Sample %d on_turn evaluation failed", + attempt, + i + 1, + ) + sample_passed = False + else: + sample_passed = self._evaluate_response(optimize_context) + + if not sample_passed: + logger.info( + "[GT Attempt %d] -> Sample %d/%d FAILED", + attempt, + i + 1, + n, + ) + all_passed = False + failed_count += 1 + else: + logger.debug( + "[GT Attempt %d] -> Sample %d/%d passed", + attempt, + i + 1, + n, + ) + + attempt_results.append(optimize_context) + + if gt_options.on_sample_result is not None: + try: + gt_options.on_sample_result(optimize_context) + except Exception: + logger.exception( + "[GT Attempt %d] -> on_sample_result callback failed for sample %d", + attempt, + i + 1, + ) + + last_ctx = attempt_results[-1] + + if all_passed: + logger.info( + "[GT Attempt %d] -> All %d samples passed — optimization succeeded", + attempt, + n, + ) + self._safe_status_update("success", last_ctx, last_ctx.iteration) + if self._options.on_passing_result: + try: + self._options.on_passing_result(last_ctx) + except Exception: + logger.exception( + "[GT Attempt %d] -> on_passing_result callback failed", attempt + ) + return attempt_results + + # We've hit max attempts for the batches, bail at this point + if attempt >= gt_options.max_attempts: + logger.warning( + "[GT Optimization] -> Failed after %d attempt(s) — not all samples passed", + attempt, + ) + self._safe_status_update("failure", last_ctx, 
last_ctx.iteration) + if self._options.on_failing_result: + try: + self._options.on_failing_result(last_ctx) + except Exception: + logger.exception( + "[GT Attempt %d] -> on_failing_result callback failed", attempt + ) + return attempt_results + + # Append all N results to history so the variation generator has full context + # from all of the previous samples + self._history.extend(attempt_results) + + logger.info( + "[GT Attempt %d] -> %d/%d samples failed — generating new variation", + attempt, + failed_count, + n, + ) + try: + await self._generate_new_variation(last_ctx.iteration, last_ctx.current_variables) + except Exception: + logger.exception( + "[GT Attempt %d] -> Variation generation failed", attempt + ) + self._safe_status_update("failure", last_ctx, last_ctx.iteration) + if self._options.on_failing_result: + try: + self._options.on_failing_result(last_ctx) + except Exception: + logger.exception( + "[GT Attempt %d] -> on_failing_result callback failed", attempt + ) + return attempt_results + + self._safe_status_update("turn completed", last_ctx, last_ctx.iteration) + + # Every branch inside the while True loop returns explicitly (success, max-attempts + # exhaustion, or variation-generation failure). This line is structurally unreachable, + # but without it type checkers infer the return type as List[OptimizationContext] | None + # because they don't always treat `while True` as exhaustive. The RuntimeError makes + # the intent unambiguous and causes a loud failure if that invariant is ever broken. 
+ raise RuntimeError("unreachable: ground truth loop exited without returning") + def _apply_new_variation_response( self, response_data: Dict[str, Any], @@ -820,12 +1085,22 @@ def _apply_new_variation_response( else: old_model = self._current_model self._current_model = model_value - logger.info( - "[Iteration %d] -> Model updated from '%s' to '%s'", - iteration, - old_model, - self._current_model, - ) + + # Log regardless of whether we change the model so that logs + # are consistently structured + if old_model != self._current_model: + logger.info( + "[Iteration %d] -> Model updated from '%s' to '%s'", + iteration, + old_model, + self._current_model, + ) + else: + logger.debug( + "[Iteration %d] -> Keeping model '%s'", + iteration, + self._current_model, + ) logger.debug( "[Iteration %d] -> New variation generated: instructions='%s', model=%s, parameters=%s", @@ -957,6 +1232,8 @@ async def optimize_from_config( optimization_options = self._build_options_from_config( config, options, api_client, optimization_id, run_id ) + if isinstance(optimization_options, GroundTruthOptimizationOptions): + return await self._run_ground_truth_optimization(agent_config, optimization_options) return await self._run_optimization(agent_config, optimization_options) def _build_options_from_config( @@ -966,8 +1243,13 @@ def _build_options_from_config( api_client: LDApiClient, optimization_id: str, run_id: str, - ) -> OptimizationOptions: - """Map a fetched AgentOptimization config + user options into OptimizationOptions. + ) -> "Union[OptimizationOptions, GroundTruthOptimizationOptions]": + """Map a fetched AgentOptimization config + user options into the appropriate options type. + + When the config contains groundTruthResponses, the three lists (groundTruthResponses, + userInputOptions, variableChoices) are zipped by index into GroundTruthSample objects + and a GroundTruthOptimizationOptions is returned. Otherwise a standard OptimizationOptions + is returned. 
Acceptance statements and judge configs from the API are merged into a single judges dict. An on_status_update closure is injected to persist each iteration @@ -979,7 +1261,7 @@ def _build_options_from_config( :param api_client: Initialised LDApiClient for result persistence. :param optimization_id: UUID id of the parent agent_optimization record. :param run_id: UUID that groups all result records for this run. - :return: A fully populated OptimizationOptions ready for _run_optimization. + :return: OptimizationOptions or GroundTruthOptimizationOptions. """ judges: Dict[str, OptimizationJudge] = {} @@ -996,7 +1278,8 @@ def _build_options_from_config( judge_key=judge["key"], ) - has_ground_truth = bool(config.get("groundTruthResponses")) + raw_ground_truth: List[str] = config.get("groundTruthResponses") or [] + has_ground_truth = bool(raw_ground_truth) if not judges and not has_ground_truth and options.on_turn is None: raise ValueError( "The optimization config has no acceptance statements, judges, or ground truth " @@ -1004,9 +1287,6 @@ def _build_options_from_config( "evaluate optimization results." ) - variable_choices: List[Dict[str, Any]] = config["variableChoices"] or [{}] - user_input_options: Optional[List[str]] = config["userInputOptions"] or None - project_key = options.project_key config_version: int = config["version"] @@ -1048,6 +1328,48 @@ def _persist_and_forward( except Exception: logger.exception("User on_status_update callback failed for status=%s", status) + # If we have ground truth responses, we provide a different + # configuration options type that contains the bundled GroundTruthSamples + # so that the ultimate output is correctly formatted. 
+ if has_ground_truth: + user_inputs: List[str] = config["userInputOptions"] or [] + variable_choices_raw: List[Dict[str, Any]] = config["variableChoices"] or [] + + if len(raw_ground_truth) != len(user_inputs) or len(raw_ground_truth) != len(variable_choices_raw): + raise ValueError( + f"groundTruthResponses ({len(raw_ground_truth)}), userInputOptions " + f"({len(user_inputs)}), and variableChoices ({len(variable_choices_raw)}) " + "must all have the same length when groundTruthResponses is provided." + ) + + gt_samples = [ + GroundTruthSample( + user_input=user_inputs[idx], + expected_response=raw_ground_truth[idx], + variables=variable_choices_raw[idx], + ) + for idx in range(len(raw_ground_truth)) + ] + + return GroundTruthOptimizationOptions( + context_choices=options.context_choices, + ground_truth_responses=gt_samples, + max_attempts=config["maxAttempts"], + model_choices=[_strip_provider_prefix(m) for m in config["modelChoices"]], + judge_model=_strip_provider_prefix(config["judgeModel"]), + handle_agent_call=options.handle_agent_call, + handle_judge_call=options.handle_judge_call, + judges=judges or None, + on_turn=options.on_turn, + on_sample_result=options.on_sample_result, + on_passing_result=options.on_passing_result, + on_failing_result=options.on_failing_result, + on_status_update=_persist_and_forward, + ) + + variable_choices: List[Dict[str, Any]] = config["variableChoices"] or [{}] + user_input_options: Optional[List[str]] = config["userInputOptions"] or None + return OptimizationOptions( context_choices=options.context_choices, max_attempts=config["maxAttempts"], @@ -1068,6 +1390,7 @@ async def _execute_agent_turn( self, optimize_context: OptimizationContext, iteration: int, + expected_response: Optional[str] = None, ) -> OptimizationContext: """ Run the agent call and judge scoring for one optimization turn. @@ -1079,6 +1402,8 @@ async def _execute_agent_turn( :param optimize_context: The context for this turn (instructions, model, history, etc.) 
:param iteration: Current iteration number for logging and status callbacks + :param expected_response: Optional ground truth expected response. When provided, + injected into judge context so judges can score actual vs. expected. :return: Updated context with completion_response and scores filled in """ logger.info( @@ -1116,6 +1441,7 @@ async def _execute_agent_turn( user_input=optimize_context.user_input or "", variables=optimize_context.current_variables, agent_tools=agent_tools, + expected_response=expected_response, ) return dataclasses.replace( @@ -1214,6 +1540,15 @@ async def _run_optimization( self._agent_config = agent_config self._initialize_class_members_from_config(agent_config) + # If the LD flag doesn't carry a model name, seed from the first model choice + # so agent calls never receive an empty model string. + if not self._current_model and options.model_choices: + self._current_model = options.model_choices[0] + logger.debug( + "[Optimization] -> No model in agent config; defaulting to first model choice: %s", + self._current_model, + ) + initial_context = self._create_optimization_context( iteration=0, variables=random.choice(options.variable_choices), diff --git a/packages/optimization/src/ldai_optimization/dataclasses.py b/packages/optimization/src/ldai_optimization/dataclasses.py index 2635053..fdca939 100644 --- a/packages/optimization/src/ldai_optimization/dataclasses.py +++ b/packages/optimization/src/ldai_optimization/dataclasses.py @@ -217,6 +217,16 @@ class OptimizationJudgeContext: Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext, Dict[str, Callable[..., Any]]], Awaitable[str]], ] +_StatusLiteral = Literal[ + "init", + "generating", + "evaluating", + "generating variation", + "turn completed", + "success", + "failure", +] + @dataclass class OptimizationOptions: @@ -250,23 +260,88 @@ class OptimizationOptions: ) on_passing_result: Optional[Callable[[OptimizationContext], None]] = None on_failing_result: 
Optional[Callable[[OptimizationContext], None]] = None + # called to provide status updates during the optimization flow + on_status_update: Optional[Callable[[_StatusLiteral, OptimizationContext], None]] = None + + def __post_init__(self): + """Validate required options.""" + if len(self.context_choices) < 1: + raise ValueError("context_choices must have at least 1 context") + if len(self.model_choices) < 1: + raise ValueError("model_choices must have at least 1 model") + if self.judges is None and self.on_turn is None: + raise ValueError("Either judges or on_turn must be provided") + if self.judge_model is None: + raise ValueError("judge_model must be provided") + + +@dataclass +class GroundTruthSample: + """A single ground truth evaluation sample for use with optimize_from_ground_truth_options. + + Each sample ties together the user input, expected response, and variable set for one + evaluation. Samples are evaluated in order; the optimization only passes if all samples + pass their judges in the same attempt. + + :param user_input: The user message to send to the agent for this evaluation. + :param expected_response: The ideal response the agent should produce. Injected into + judge context so judges can score actual vs. expected. + :param variables: Variable set interpolated into the agent instructions for this sample. + Defaults to an empty dict if no placeholders are used. + """ + + user_input: str + expected_response: str + variables: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class GroundTruthOptimizationOptions: + """Options for optimize_from_ground_truth_options. + + Mirrors OptimizationOptions but replaces variable_choices / user_input_options with + ground_truth_responses. Each GroundTruthSample bundles the user input, expected + response, and variable set for one evaluation. All N samples must pass their judges + in the same attempt for the optimization to succeed. 
+ + :param context_choices: One or more LD evaluation contexts to use. + :param ground_truth_responses: Ordered list of ground truth samples to evaluate. + At least 1 required. All samples share the same instructions and model being optimized. + :param max_attempts: Maximum number of variation attempts before the run is marked failed. + :param model_choices: Model IDs the variation generator may select from. At least 1 required. + :param judge_model: Model used for judge evaluation. Should remain consistent across attempts. + :param handle_agent_call: Callback that invokes the agent and returns its response. + :param handle_judge_call: Callback that invokes a judge LLM and returns its response. + :param judges: Auto-judges (config judges and/or acceptance statements) to score each response. + :param on_turn: Optional manual pass/fail callback applied per sample; skips judge scoring when provided. + :param on_sample_result: Called with each sample's OptimizationContext as results arrive, + before the overall pass/fail decision is made for the attempt. + :param on_passing_result: Called once with the last context when all N samples pass. + :param on_failing_result: Called once with the last context when max attempts are exhausted. + :param on_status_update: Called on each status transition during the run. 
+ """ + + context_choices: List[Context] + ground_truth_responses: List[GroundTruthSample] + max_attempts: int + model_choices: List[str] + judge_model: str + handle_agent_call: HandleAgentCall + handle_judge_call: HandleJudgeCall + judges: Optional[Dict[str, OptimizationJudge]] = None + on_turn: Optional[Callable[[OptimizationContext], bool]] = None + on_sample_result: Optional[Callable[[OptimizationContext], None]] = None + on_passing_result: Optional[Callable[[OptimizationContext], None]] = None + on_failing_result: Optional[Callable[[OptimizationContext], None]] = None on_status_update: Optional[ Callable[ [ - Literal[ - "init", - "generating", - "evaluating", - "generating variation", - "turn completed", - "success", - "failure", - ], + _StatusLiteral, OptimizationContext, ], None, ] - ] = None # called to provide status updates during the optimization flow + ] = None def __post_init__(self): """Validate required options.""" @@ -274,10 +349,10 @@ def __post_init__(self): raise ValueError("context_choices must have at least 1 context") if len(self.model_choices) < 1: raise ValueError("model_choices must have at least 1 model") + if len(self.ground_truth_responses) < 1: + raise ValueError("ground_truth_responses must have at least 1 sample") if self.judges is None and self.on_turn is None: raise ValueError("Either judges or on_turn must be provided") - if self.judge_model is None: - raise ValueError("judge_model must be provided") @dataclass @@ -293,6 +368,8 @@ class OptimizationFromConfigOptions: :param handle_agent_call: Callback that invokes the agent and returns its response. :param handle_judge_call: Callback that invokes a judge and returns its response. :param on_turn: Optional manual pass/fail callback; when provided, judge scoring is skipped. + :param on_sample_result: Ground truth path only. Called with each sample's + OptimizationContext as results arrive during a ground truth run. 
:param on_passing_result: Called with the winning OptimizationContext on success. :param on_failing_result: Called with the final OptimizationContext on failure. :param on_status_update: Called on each status transition; chained after the @@ -306,25 +383,10 @@ class OptimizationFromConfigOptions: handle_agent_call: HandleAgentCall handle_judge_call: HandleJudgeCall on_turn: Optional[Callable[["OptimizationContext"], bool]] = None + on_sample_result: Optional[Callable[["OptimizationContext"], None]] = None on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None on_failing_result: Optional[Callable[["OptimizationContext"], None]] = None - on_status_update: Optional[ - Callable[ - [ - Literal[ - "init", - "generating", - "evaluating", - "generating variation", - "turn completed", - "success", - "failure", - ], - "OptimizationContext", - ], - None, - ] - ] = None + on_status_update: Optional[Callable[[_StatusLiteral, "OptimizationContext"], None]] = None base_url: Optional[str] = None def __post_init__(self): diff --git a/packages/optimization/tests/test_client.py b/packages/optimization/tests/test_client.py index 5ece830..8423e7f 100644 --- a/packages/optimization/tests/test_client.py +++ b/packages/optimization/tests/test_client.py @@ -12,6 +12,8 @@ from ldai_optimization.client import OptimizationClient from ldai_optimization.dataclasses import ( AIJudgeCallConfig, + GroundTruthOptimizationOptions, + GroundTruthSample, JudgeResult, OptimizationContext, OptimizationFromConfigOptions, @@ -1393,3 +1395,414 @@ async def test_returns_optimization_context_on_success(self): assert isinstance(result, OptimizationContext) assert result.completion_response == "The answer is 4." 
# ---------------------------------------------------------------------------
# GroundTruthSample / GroundTruthOptimizationOptions dataclass validation
# ---------------------------------------------------------------------------


class TestGroundTruthSampleDataclass:
    """Construction checks for ``GroundTruthSample``."""

    def test_required_fields(self):
        """With only the two required fields given, ``variables`` is an empty dict."""
        s = GroundTruthSample(user_input="hi", expected_response="hello")
        assert s.user_input == "hi"
        assert s.expected_response == "hello"
        assert s.variables == {}

    def test_variables_populated(self):
        """Explicitly supplied ``variables`` are stored as given."""
        s = GroundTruthSample(user_input="hi", expected_response="hello", variables={"lang": "en"})
        assert s.variables == {"lang": "en"}


class TestGroundTruthOptimizationOptionsValidation:
    """Constructor validation of ``GroundTruthOptimizationOptions``.

    Each test starts from a valid baseline (see ``_make``) and overrides the
    single field under test.
    """

    def _make(self, **overrides) -> GroundTruthOptimizationOptions:
        """Build options from a valid single-sample baseline, applying *overrides*."""
        defaults = dict(
            context_choices=[LD_CONTEXT],
            ground_truth_responses=[
                GroundTruthSample(user_input="q1", expected_response="a1"),
            ],
            max_attempts=3,
            model_choices=["gpt-4o"],
            judge_model="gpt-4o",
            handle_agent_call=AsyncMock(return_value="ans"),
            handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE),
            judges={
                "acc": OptimizationJudge(threshold=0.8, acceptance_statement="Be accurate.")
            },
        )
        defaults.update(overrides)
        return GroundTruthOptimizationOptions(**defaults)

    def test_valid_options_created(self):
        opts = self._make()
        assert len(opts.ground_truth_responses) == 1

    def test_raises_empty_context_choices(self):
        with pytest.raises(ValueError, match="context_choices"):
            self._make(context_choices=[])

    def test_raises_empty_model_choices(self):
        with pytest.raises(ValueError, match="model_choices"):
            self._make(model_choices=[])

    def test_raises_empty_ground_truth_responses(self):
        with pytest.raises(ValueError, match="ground_truth_responses"):
            self._make(ground_truth_responses=[])

    def test_raises_no_judges_and_no_on_turn(self):
        with pytest.raises(ValueError, match="judges or on_turn"):
            self._make(judges=None, on_turn=None)

    def test_on_turn_satisfies_criteria_requirement(self):
        """An ``on_turn`` callback alone is accepted when no judges are given."""
        opts = self._make(judges=None, on_turn=lambda ctx: True)
        assert opts.on_turn is not None


# ---------------------------------------------------------------------------
# _run_ground_truth_optimization / optimize_from_ground_truth_options
# ---------------------------------------------------------------------------


def _make_gt_options(**overrides) -> GroundTruthOptimizationOptions:
    """Build two-sample ground-truth options whose judge passes by default.

    Keyword *overrides* replace individual baseline fields.
    """
    defaults = dict(
        context_choices=[LD_CONTEXT],
        ground_truth_responses=[
            GroundTruthSample(user_input="What is 2+2?", expected_response="4", variables={"lang": "English"}),
            GroundTruthSample(user_input="What is 3+3?", expected_response="6", variables={"lang": "English"}),
        ],
        max_attempts=3,
        model_choices=["gpt-4o", "gpt-4o-mini"],
        judge_model="gpt-4o",
        handle_agent_call=AsyncMock(return_value="The answer is correct."),
        handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE),
        judges={
            "acc": OptimizationJudge(threshold=0.8, acceptance_statement="Be accurate.")
        },
    )
    defaults.update(overrides)
    return GroundTruthOptimizationOptions(**defaults)


class TestRunGroundTruthOptimization:
    """End-to-end checks of ``optimize_from_ground_truth_options`` with mocked handlers."""

    def setup_method(self):
        self.mock_ldai = _make_ldai_client()

    def _make_client(self) -> OptimizationClient:
        """Return a client wired to this test's mocked LD AI client."""
        return _make_client(self.mock_ldai)

    async def test_returns_list_of_contexts_on_success(self):
        client = self._make_client()
        opts = _make_gt_options()
        results = await client.optimize_from_ground_truth_options("test-agent", opts)
        assert isinstance(results, list)
        assert len(results) == 2
        for ctx in results:
            assert isinstance(ctx, OptimizationContext)

    async def test_each_context_has_correct_user_input(self):
        client = self._make_client()
        opts = _make_gt_options()
        results = await client.optimize_from_ground_truth_options("test-agent", opts)
        assert results[0].user_input == "What is 2+2?"
        assert results[1].user_input == "What is 3+3?"

    async def test_completion_response_set_on_each_context(self):
        client = self._make_client()
        opts = _make_gt_options(handle_agent_call=AsyncMock(return_value="42"))
        results = await client.optimize_from_ground_truth_options("test-agent", opts)
        for ctx in results:
            assert ctx.completion_response == "42"

    async def test_on_sample_result_called_per_sample(self):
        client = self._make_client()
        sample_results = []
        opts = _make_gt_options(on_sample_result=lambda ctx: sample_results.append(ctx))
        await client.optimize_from_ground_truth_options("test-agent", opts)
        assert len(sample_results) == 2

    async def test_on_passing_result_called_once_on_success(self):
        client = self._make_client()
        passing_calls = []
        opts = _make_gt_options(on_passing_result=lambda ctx: passing_calls.append(ctx))
        await client.optimize_from_ground_truth_options("test-agent", opts)
        assert len(passing_calls) == 1

    async def test_on_failing_result_called_when_max_attempts_exceeded(self):
        client = self._make_client()
        failing_calls = []
        opts = _make_gt_options(
            handle_judge_call=AsyncMock(return_value=JUDGE_FAIL_RESPONSE),
            max_attempts=2,
            on_failing_result=lambda ctx: failing_calls.append(ctx),
        )
        results = await client.optimize_from_ground_truth_options("test-agent", opts)
        assert isinstance(results, list)
        assert len(failing_calls) == 1

    async def test_generates_variation_when_any_sample_fails(self):
        """One failing sample in attempt 1 leads to a second attempt that passes."""
        client = self._make_client()
        opts = _make_gt_options(
            # Scripted replies are consumed in call order, mirroring the agent mock below.
            handle_judge_call=AsyncMock(side_effect=[
                JUDGE_PASS_RESPONSE,  # sample 1 attempt 1 — pass
                JUDGE_FAIL_RESPONSE,  # sample 2 attempt 1 — fail → trigger variation
                JUDGE_PASS_RESPONSE,  # sample 1 attempt 2 — pass
                JUDGE_PASS_RESPONSE,  # sample 2 attempt 2 — pass
            ]),
            handle_agent_call=AsyncMock(side_effect=[
                "ans1", "ans2",      # attempt 1 samples
                VARIATION_RESPONSE,  # variation generation
                "ans3", "ans4",      # attempt 2 samples
            ]),
            max_attempts=3,
        )
        results = await client.optimize_from_ground_truth_options("test-agent", opts)
        assert isinstance(results, list)
        assert len(results) == 2

    async def test_iteration_numbers_are_linear_and_unique(self):
        client = self._make_client()
        opts = _make_gt_options()
        results = await client.optimize_from_ground_truth_options("test-agent", opts)
        iterations = [ctx.iteration for ctx in results]
        # NOTE(review): only uniqueness is asserted here; ordering ("linear") is
        # not checked — confirm whether an ordering assertion is wanted.
        assert len(set(iterations)) == len(iterations)

    async def test_on_sample_result_exception_does_not_abort(self):
        """A raising ``on_sample_result`` callback must not stop the run."""
        client = self._make_client()

        def bad_callback(ctx):
            raise RuntimeError("boom")

        opts = _make_gt_options(on_sample_result=bad_callback)
        results = await client.optimize_from_ground_truth_options("test-agent", opts)
        assert len(results) == 2

    async def test_variables_from_samples_used_per_evaluation(self):
        """Each agent call sees the variables of the sample being evaluated."""
        client = self._make_client()
        received_contexts = []

        async def capture_agent_call(key, config, ctx, tools):
            received_contexts.append(ctx)
            return "response"

        opts = _make_gt_options(
            ground_truth_responses=[
                GroundTruthSample(user_input="q1", expected_response="a1", variables={"lang": "English"}),
                GroundTruthSample(user_input="q2", expected_response="a2", variables={"lang": "French"}),
            ],
            handle_agent_call=capture_agent_call,
        )
        await client.optimize_from_ground_truth_options("test-agent", opts)
        assert received_contexts[0].current_variables == {"lang": "English"}
        assert received_contexts[1].current_variables == {"lang": "French"}

    async def test_model_falls_back_to_first_model_choice_when_agent_config_has_no_model(self):
        """When the LD agent config has no model name the first model_choices entry is used."""
        config_without_model = _make_agent_config(model_name="")
        mock_ldai = _make_ldai_client(agent_config=config_without_model)
        client = _make_client(mock_ldai)

        observed_models = []

        async def capture(key, config, ctx, tools):
            observed_models.append(config.model.name if config.model else None)
            return "answer"

        opts = _make_gt_options(
            handle_agent_call=capture,
            model_choices=["gpt-4o", "gpt-4o-mini"],
        )
        await client.optimize_from_ground_truth_options("test-agent", opts)
        assert all(m == "gpt-4o" for m in observed_models), (
            f"Expected all agent calls to use 'gpt-4o' (fallback), got: {observed_models}"
        )

    async def test_missing_instructions_raises_value_error(self):
        """An agent config with no instructions raises ValueError before the loop starts."""
        config_no_instructions = _make_agent_config(instructions="")
        mock_ldai = _make_ldai_client(agent_config=config_no_instructions)
        # variation() also needs to return no instructions so the fallback doesn't hide the gap.
        mock_ldai._client.variation.return_value = {"instructions": ""}
        client = _make_client(mock_ldai)

        opts = _make_gt_options()
        with pytest.raises(ValueError, match="has no instructions configured"):
            await client.optimize_from_ground_truth_options("test-agent", opts)


# ---------------------------------------------------------------------------
# expected_response in judge evaluation
# ---------------------------------------------------------------------------


class TestExpectedResponseInJudges:
    """Checks how ``expected_response`` shows up in the messages sent to a judge."""

    def setup_method(self):
        self.client = _make_client()
        self.client._agent_key = "test-agent"
        self.client._options = _make_options()
        self.client._agent_config = _make_agent_config()
        self.client._initialize_class_members_from_config(_make_agent_config())

    async def _capture_judge_configs(self, **turn_kwargs):
        """Run one agent turn with a single acceptance judge and record judge configs.

        :param turn_kwargs: Extra keyword arguments forwarded unchanged to
            ``_execute_agent_turn`` (e.g. ``expected_response``); omitted keys
            are not passed at all, so the callee's own defaults apply.
        :return: The list of configs passed to ``handle_judge_call``, in call order.
        """
        captured_configs = []

        async def capture_judge_call(key, config, ctx, tools):
            captured_configs.append(config)
            return JUDGE_PASS_RESPONSE

        self.client._options = _make_options(
            judges={
                "acc": OptimizationJudge(threshold=0.8, acceptance_statement="Be accurate.")
            },
            handle_judge_call=capture_judge_call,
        )
        await self.client._execute_agent_turn(
            self.client._create_optimization_context(iteration=1, variables={}),
            1,
            **turn_kwargs,
        )
        return captured_configs

    async def test_expected_response_included_in_acceptance_judge_user_message(self):
        captured_configs = await self._capture_judge_configs(
            expected_response="The expected answer is 42.",
        )
        assert len(captured_configs) == 1
        user_msg = captured_configs[0].messages[-1].content
        assert "The expected answer is 42." in user_msg

    async def test_expected_response_in_acceptance_judge_user_message(self):
        captured_configs = await self._capture_judge_configs(
            expected_response="gold standard",
        )
        user_msg = captured_configs[0].messages[1].content
        assert "gold standard" in user_msg
        assert "expected response" in user_msg.lower()
        # Scoring instructions should now live in the user message, not the system prompt
        system_msg = captured_configs[0].messages[0].content
        assert "gold standard" not in system_msg

    async def test_no_expected_response_leaves_judge_messages_unchanged(self):
        captured_configs = await self._capture_judge_configs()
        user_msg = captured_configs[0].messages[-1].content
        assert "expected response" not in user_msg.lower()


# ---------------------------------------------------------------------------
# _build_options_from_config — ground truth path
# ---------------------------------------------------------------------------


# API-shaped optimization config in which userInputOptions, variableChoices and
# groundTruthResponses all have two entries.
_API_CONFIG_WITH_GT: Dict[str, Any] = {
    "id": "opt-gt-uuid",
    "key": "my-gt-optimization",
    "aiConfigKey": "my-agent",
    "maxAttempts": 3,
    "modelChoices": ["gpt-4o"],
    "judgeModel": "gpt-4o",
    "variableChoices": [{"lang": "English"}, {"lang": "French"}],
    "acceptanceStatements": [{"statement": "Be accurate.", "threshold": 0.9}],
    "judges": [],
    "userInputOptions": ["What is 2+2?", "What is 3+3?"],
    "groundTruthResponses": ["4", "6"],
    "version": 1,
    "createdAt": 1700000000,
}


class TestBuildOptionsFromConfigGroundTruth:
    """``_build_options_from_config`` behaviour when ``groundTruthResponses`` is present."""

    def setup_method(self):
        self.client = _make_client()
        self.client._agent_key = "my-agent"
        self.client._initialize_class_members_from_config(_make_agent_config())
        self.client._options = _make_options()
        self.api_client = _make_mock_api_client()

    def _build(self, config=None, options=None):
        """Call ``_build_options_from_config``; falsy arguments fall back to the GT defaults."""
        return self.client._build_options_from_config(
            config or dict(_API_CONFIG_WITH_GT),
            options or _make_from_config_options(),
            self.api_client,
            optimization_id="opt-gt-uuid",
            run_id="run-uuid-789",
        )

    def test_returns_ground_truth_options_when_gt_present(self):
        result = self._build()
        assert isinstance(result, GroundTruthOptimizationOptions)

    def test_samples_zipped_by_index(self):
        """Inputs, expected responses and variables are paired by list position."""
        result = self._build()
        assert isinstance(result, GroundTruthOptimizationOptions)
        assert len(result.ground_truth_responses) == 2
        s0 = result.ground_truth_responses[0]
        assert s0.user_input == "What is 2+2?"
        assert s0.expected_response == "4"
        assert s0.variables == {"lang": "English"}
        s1 = result.ground_truth_responses[1]
        assert s1.user_input == "What is 3+3?"
        assert s1.expected_response == "6"
        assert s1.variables == {"lang": "French"}

    def test_model_choices_have_prefix_stripped(self):
        config = dict(_API_CONFIG_WITH_GT)
        config["modelChoices"] = ["OpenAI.gpt-4o"]
        result = self._build(config=config)
        assert isinstance(result, GroundTruthOptimizationOptions)
        assert result.model_choices == ["gpt-4o"]

    def test_raises_on_mismatched_lengths(self):
        config = dict(_API_CONFIG_WITH_GT)
        config["userInputOptions"] = ["only one input"]
        with pytest.raises(ValueError, match="same length"):
            self._build(config=config)

    def test_returns_standard_options_when_no_gt(self):
        config = dict(_API_CONFIG)  # no groundTruthResponses
        result = self._build(config=config)
        assert isinstance(result, OptimizationOptions)

    async def test_optimize_from_config_dispatches_to_gt_run(self):
        """A fetched config with ground truth yields one result per sample."""
        mock_ldai = _make_ldai_client()
        # The API key stays patched for both client construction and the run.
        with patch.dict("os.environ", {"LAUNCHDARKLY_API_KEY": "test-key"}):
            client = _make_client(mock_ldai)
            mock_api = _make_mock_api_client()
            mock_api.get_agent_optimization = MagicMock(return_value=dict(_API_CONFIG_WITH_GT))

            with patch("ldai_optimization.client.LDApiClient", return_value=mock_api):
                options = _make_from_config_options(
                    handle_agent_call=AsyncMock(return_value="correct answer"),
                    handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE),
                )
                result = await client.optimize_from_config("my-gt-opt", options)

            assert isinstance(result, list)
            assert len(result) == 2