Skip to content

Commit f853708

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Allow RegisteredMetricHandler to build payloads without a full inline metric definition
PiperOrigin-RevId: 885215504
1 parent 394253a commit f853708

2 files changed

Lines changed: 62 additions & 45 deletions

File tree

tests/unit/vertexai/genai/replays/test_evaluate.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def test_evaluation_agent_data(client):
353353
assert case_result.response_candidate_results is not None
354354

355355

356-
def test_metric_resource_name(client):
356+
def test_evaluation_metric_resource_name(client):
357357
"""Tests with a metric resource name in types.Metric."""
358358
client._api_client._http_options.api_version = "v1beta1"
359359
client._api_client._http_options.base_url = (
@@ -375,8 +375,22 @@ def test_metric_resource_name(client):
375375
)
376376
assert isinstance(evaluation_result, types.EvaluationResult)
377377
assert evaluation_result.eval_case_results is not None
378-
assert len(evaluation_result.eval_case_results) > 0
378+
assert len(evaluation_result.eval_case_results) == 1
379379
assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"
380+
assert evaluation_result.summary_metrics[0].mean_score is not None
381+
assert evaluation_result.summary_metrics[0].num_cases_valid == 1
382+
assert evaluation_result.summary_metrics[0].num_cases_error == 0
383+
384+
case_result = evaluation_result.eval_case_results[0]
385+
assert case_result.response_candidate_results is not None
386+
assert len(case_result.response_candidate_results) == 1
387+
388+
metric_result = case_result.response_candidate_results[0].metric_results[
389+
"my_custom_metric"
390+
]
391+
assert metric_result.score is not None
392+
assert metric_result.score > 0.5
393+
assert metric_result.error_message is None
380394

381395

382396
pytestmark = pytest_helper.setup(

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,74 +1281,78 @@ def aggregate(
12811281
)
12821282

12831283

1284-
class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
1284+
class RegisteredMetricHandler(MetricHandler[types.Metric]):
12851285
"""Metric handler for registered metrics."""
12861286

12871287
def __init__(
12881288
self,
12891289
module: "evals.Evals",
1290-
metric: Union[types.MetricSource, types.MetricSourceDict],
1290+
metric: types.Metric,
12911291
):
12921292
if isinstance(metric, dict):
12931293
metric = types.MetricSource(**metric)
12941294
super().__init__(module=module, metric=metric)
12951295

1296-
# TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
12971296
def _build_request_payload(
12981297
self, eval_case: types.EvalCase, response_index: int
12991298
) -> dict[str, Any]:
1300-
"""Builds request payload for registered metric."""
1301-
if not self.metric.metric:
1299+
"""Builds request payload for registered metric by assembling EvaluationInstance."""
1300+
response_content = _get_response_from_eval_case(
1301+
eval_case, response_index, self.metric_name
1302+
)
1303+
1304+
if not response_content and not getattr(eval_case, "agent_data", None):
13021305
raise ValueError(
1303-
"Registered metric must have an underlying metric definition."
1306+
f"Response content missing for candidate {response_index}."
13041307
)
1305-
return PredefinedMetricHandler(
1306-
self.module, metric=self.metric.metric
1307-
)._build_request_payload(eval_case, response_index)
1308+
1309+
reference_instance_data = None
1310+
if eval_case.reference:
1311+
reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
1312+
eval_case.reference.response
1313+
)
1314+
1315+
extracted_prompt = _get_prompt_from_eval_case(eval_case)
1316+
prompt_instance_data = PredefinedMetricHandler._content_to_instance_data(
1317+
extracted_prompt
1318+
)
1319+
1320+
instance_payload = types.EvaluationInstance(
1321+
prompt=prompt_instance_data,
1322+
response=PredefinedMetricHandler._content_to_instance_data(
1323+
response_content
1324+
),
1325+
reference=reference_instance_data,
1326+
rubric_groups=eval_case.rubric_groups,
1327+
agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
1328+
)
1329+
1330+
request_payload = {
1331+
"instance": instance_payload,
1332+
}
1333+
return request_payload
13081334

13091335
@property
13101336
def metric_name(self) -> str:
1311-
# Resolve name from resource name or internal metric name
1312-
if isinstance(self.metric, types.MetricSource):
1313-
if self.metric.metric and self.metric.metric.name:
1314-
return self.metric.metric.name
1315-
if self.metric.metric_resource_name:
1316-
return self.metric.metric_resource_name
1317-
return "unknown"
1318-
else: # Should be Metric
1319-
metric_like = self.metric
1320-
if metric_like.name:
1321-
return metric_like.name
1322-
if metric_like.metric_resource_name:
1323-
return metric_like.metric_resource_name
1324-
return "unknown"
1337+
return self.metric.name
13251338

13261339
@override
13271340
def get_metric_result(
13281341
self, eval_case: types.EvalCase, response_index: int
13291342
) -> types.EvalCaseMetricResult:
1330-
"""Processes a single evaluation case for a registered metric."""
1343+
"""Processes a single evaluation case using a MetricSource reference."""
13311344
metric_name = self.metric_name
1345+
metric_source = types.MetricSource(
1346+
metric_resource_name=self.metric.metric_resource_name
1347+
)
1348+
13321349
try:
13331350
payload = self._build_request_payload(eval_case, response_index)
1334-
for attempt in range(_MAX_RETRIES):
1335-
try:
1336-
api_response = self.module._evaluate_instances(
1337-
metric_sources=[self.metric],
1338-
instance=payload.get("instance"),
1339-
autorater_config=payload.get("autorater_config"),
1340-
)
1341-
break
1342-
except genai_errors.ClientError as e:
1343-
if e.code == 429:
1344-
if attempt == _MAX_RETRIES - 1:
1345-
return types.EvalCaseMetricResult(
1346-
metric_name=metric_name,
1347-
error_message=f"Judge model resource exhausted after {_MAX_RETRIES} retries: {e}",
1348-
)
1349-
time.sleep(2**attempt)
1350-
else:
1351-
raise e
1351+
api_response = self.module._evaluate_instances(
1352+
metric_sources=[metric_source],
1353+
instance=payload.get("instance"),
1354+
autorater_config=payload.get("autorater_config"),
1355+
)
13521356

13531357
if api_response and api_response.metric_results:
13541358
result_data = api_response.metric_results[0]
@@ -1377,7 +1381,6 @@ def aggregate(
13771381
self, eval_case_metric_results: list[types.EvalCaseMetricResult]
13781382
) -> types.AggregatedMetricResult:
13791383
"""Aggregates the metric results for a registered metric."""
1380-
logger.debug("Aggregating results for registered metric: %s", self.metric_name)
13811384
return _default_aggregate_scores(
13821385
self.metric_name, eval_case_metric_results, calculate_pass_rate=True
13831386
)

0 commit comments

Comments (0)