Commit 506af0f

vertex-sdk-bot authored and copybara-github committed
feat: Allow using registered metric resource names in evaluation
PiperOrigin-RevId: 880868820
1 parent 50d804f · commit 506af0f
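
In practice, this lets callers reference a metric already registered with the Evaluation service by its resource name. A minimal usage sketch, adapted from the replay test added in this commit (the client construction and import paths are assumptions for illustration, not part of the commit; the resource name is the placeholder value used by the test):

    import pandas as pd
    import vertexai
    from vertexai import types  # import path assumed from the test file

    # Hypothetical client setup; the test uses a preconfigured replay client.
    client = vertexai.Client(project="my-project", location="us-central1")

    # Reference a metric previously registered with the Evaluation service by
    # its full resource name instead of defining it inline.
    metric = types.Metric(
        name="my_custom_metric",
        metric_resource_name=(
            "projects/977012026409/locations/us-central1/"
            "evaluationMetrics/6048334299558576128"
        ),
    )

    df = pd.DataFrame(
        {
            "prompt": ["Write a simple story about a dinosaur"],
            "response": ["Once upon a time, there was a T-Rex named Rexy."],
        }
    )

    result = client.evals.evaluate(dataset=df, metrics=[metric])
    assert result.summary_metrics[0].metric_name == "my_custom_metric"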

File tree

7 files changed: +276 -28 lines changed

tests/unit/vertexai/genai/replays/test_evaluate.py

Lines changed: 26 additions & 0 deletions
@@ -353,6 +353,32 @@ def test_evaluation_agent_data(client):
     assert case_result.response_candidate_results is not None
 
 
+def test_metric_resource_name(client):
+    """Tests with a metric resource name in types.Metric."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    byor_df = pd.DataFrame(
+        {
+            "prompt": ["Write a simple story about a dinosaur"],
+            "response": ["Once upon a time, there was a T-Rex named Rexy."],
+        }
+    )
+    metric = types.Metric(
+        name="my_custom_metric", metric_resource_name=metric_resource_name
+    )
+    evaluation_result = client.evals.evaluate(
+        dataset=byor_df,
+        metrics=[metric],
+    )
+    assert isinstance(evaluation_result, types.EvaluationResult)
+    assert evaluation_result.eval_case_results is not None
+    assert len(evaluation_result.eval_case_results) > 0
+    assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),

vertexai/_genai/_evals_common.py

Lines changed: 1 addition & 1 deletion
@@ -1364,7 +1364,7 @@ def _resolve_evaluation_run_metrics(
             raise
         elif isinstance(metric_instance, types.Metric):
             config_dict = t.t_metrics([metric_instance])[0]
-            res_name = config_dict.pop("metric_resource_name", None)
+            res_name = getattr(metric_instance, "metric_resource_name", None)
             resolved_metrics_list.append(
                 types.EvaluationRunMetric(
                     metric=metric_instance.name,
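
The one-line fix above reads metric_resource_name directly off the types.Metric instance instead of popping it from the dict produced by t.t_metrics; a plausible reading is that the serializer does not carry the field through, so the pop always returned None. A minimal self-contained illustration of that failure mode, using hypothetical stand-ins (FakeMetric and serialize are not SDK code):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeMetric:
        # Stand-in for types.Metric; illustrative only.
        name: str
        metric_resource_name: Optional[str] = None

    def serialize(metric: FakeMetric) -> dict:
        # Stand-in for t.t_metrics: suppose it keeps only recognized config keys.
        return {"metric": metric.name}

    m = FakeMetric("my_metric", "projects/p/locations/l/evaluationMetrics/123")

    # Old path: the key never makes it into the serialized dict, so pop() loses it.
    assert serialize(m).pop("metric_resource_name", None) is None

    # New path: read the attribute straight from the object, so it is preserved.
    assert getattr(m, "metric_resource_name", None) == m.metric_resource_name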

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 157 additions & 18 deletions
@@ -21,7 +21,7 @@
 import logging
 import statistics
 import time
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Generic, Optional, TypeVar, Union
 
 from google.genai import errors as genai_errors
 from google.genai import _common
@@ -39,6 +39,9 @@
 _MAX_RETRIES = 3
 
 
+T = TypeVar("T", types.Metric, types.MetricSource, types.LLMMetric)
+
+
 def _has_tool_call(intermediate_events: Optional[list[types.evals.Event]]) -> bool:
     """Checks if any event in intermediate_events has a function call."""
     if not intermediate_events:
@@ -149,12 +152,18 @@ def _default_aggregate_scores(
     )
 
 
-class MetricHandler(abc.ABC):
+class MetricHandler(abc.ABC, Generic[T]):
     """Abstract base class for metric handlers."""
 
-    def __init__(self, module: "evals.Evals", metric: types.Metric):
+    def __init__(self, module: "evals.Evals", metric: T):
         self.module = module
-        self.metric = metric
+        self.metric: T = metric
+
+    @property
+    @abc.abstractmethod
+    def metric_name(self) -> str:
+        """Returns the name of the metric polymorphically."""
+        raise NotImplementedError()
 
     @abc.abstractmethod
     def get_metric_result(
@@ -171,7 +180,7 @@ def aggregate(
         raise NotImplementedError()
 
 
-class ComputationMetricHandler(MetricHandler):
+class ComputationMetricHandler(MetricHandler[types.Metric]):
     """Metric handler for computation metrics."""
 
     SUPPORTED_COMPUTATION_METRICS = frozenset(
@@ -188,6 +197,10 @@ class ComputationMetricHandler(MetricHandler):
         }
     )
 
+    @property
+    def metric_name(self) -> str:
+        return self.metric.name or "unknown_metric"
+
     def __init__(self, module: "evals.Evals", metric: types.Metric):
         super().__init__(module=module, metric=metric)
         if self.metric.name not in self.SUPPORTED_COMPUTATION_METRICS:
@@ -299,11 +312,15 @@ def aggregate(
         return _default_aggregate_scores(self.metric.name, eval_case_metric_results)
 
 
-class TranslationMetricHandler(MetricHandler):
+class TranslationMetricHandler(MetricHandler[types.Metric]):
     """Metric handler for translation metrics."""
 
     SUPPORTED_TRANSLATION_METRICS = frozenset({"comet", "metricx"})
 
+    @property
+    def metric_name(self) -> str:
+        return self.metric.name or "unknown_metric"
+
     def __init__(self, module: "evals.Evals", metric: types.Metric):
         super().__init__(module=module, metric=metric)
 
@@ -469,9 +486,13 @@ def aggregate(
         return _default_aggregate_scores(self.metric.name, eval_case_metric_results)
 
 
-class LLMMetricHandler(MetricHandler):
+class LLMMetricHandler(MetricHandler[types.LLMMetric]):
     """Metric handler for LLM metrics."""
 
+    @property
+    def metric_name(self) -> str:
+        return self.metric.name or "unknown_metric"
+
     def __init__(self, module: "evals.Evals", metric: types.LLMMetric):
         super().__init__(module=module, metric=metric)
 
@@ -750,9 +771,13 @@ def aggregate(
         return _default_aggregate_scores(self.metric.name, eval_case_metric_results)
 
 
-class CustomMetricHandler(MetricHandler):
+class CustomMetricHandler(MetricHandler[types.Metric]):
     """Metric handler for custom metrics."""
 
+    @property
+    def metric_name(self) -> str:
+        return self.metric.name or "unknown_metric"
+
     def __init__(self, module: "evals.Evals", metric: types.Metric):
         super().__init__(module=module, metric=metric)
 
@@ -853,9 +878,13 @@ def aggregate(
         return _default_aggregate_scores(self.metric.name, eval_case_metric_results)
 
 
-class PredefinedMetricHandler(MetricHandler):
+class PredefinedMetricHandler(MetricHandler[types.Metric]):
     """Metric handler for predefined metrics."""
 
+    @property
+    def metric_name(self) -> str:
+        return self.metric.name or "unknown_metric"
+
     def __init__(self, module: "evals.Evals", metric: types.Metric):
         super().__init__(module=module, metric=metric)
         if self.metric.name not in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
@@ -1106,9 +1135,13 @@ def aggregate(
         )
 
 
-class CustomCodeExecutionMetricHandler(MetricHandler):
+class CustomCodeExecutionMetricHandler(MetricHandler[types.Metric]):
     """Metric handler for custom code execution metrics."""
 
+    @property
+    def metric_name(self) -> str:
+        return self.metric.name or "unknown_metric"
+
     def __init__(self, module: "evals.Evals", metric: types.Metric):
         super().__init__(module=module, metric=metric)
 
@@ -1242,6 +1275,108 @@ def aggregate(
         )
 
 
+class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
+    """Metric handler for registered metrics."""
+
+    def __init__(
+        self,
+        module: "evals.Evals",
+        metric: Union[types.MetricSource, types.MetricSourceDict],
+    ):
+        if isinstance(metric, dict):
+            metric = types.MetricSource(**metric)
+        super().__init__(module=module, metric=metric)
+
+    # TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
+    def _build_request_payload(
+        self, eval_case: types.EvalCase, response_index: int
+    ) -> dict[str, Any]:
+        """Builds request payload for registered metric."""
+        if not self.metric.metric:
+            raise ValueError(
+                "Registered metric must have an underlying metric definition."
+            )
+        return PredefinedMetricHandler(
+            self.module, metric=self.metric.metric
+        )._build_request_payload(eval_case, response_index)
+
+    @property
+    def metric_name(self) -> str:
+        # Resolve name from resource name or internal metric name
+        if isinstance(self.metric, types.MetricSource):
+            if self.metric.metric and self.metric.metric.name:
+                return self.metric.metric.name
+            if self.metric.metric_resource_name:
+                return self.metric.metric_resource_name
+            return "unknown"
+        else:  # Should be Metric
+            metric_like = self.metric
+            if metric_like.name:
+                return metric_like.name
+            if metric_like.metric_resource_name:
+                return metric_like.metric_resource_name
+            return "unknown"
+
+    @override
+    def get_metric_result(
+        self, eval_case: types.EvalCase, response_index: int
+    ) -> types.EvalCaseMetricResult:
+        """Processes a single evaluation case for a registered metric."""
+        metric_name = self.metric_name
+        try:
+            payload = self._build_request_payload(eval_case, response_index)
+            for attempt in range(_MAX_RETRIES):
+                try:
+                    api_response = self.module._evaluate_instances(
+                        metric_sources=[self.metric],
+                        instance=payload.get("instance"),
+                        autorater_config=payload.get("autorater_config"),
+                    )
+                    break
+                except genai_errors.ClientError as e:
+                    if e.code == 429:
+                        if attempt == _MAX_RETRIES - 1:
+                            return types.EvalCaseMetricResult(
+                                metric_name=metric_name,
+                                error_message=f"Judge model resource exhausted after {_MAX_RETRIES} retries: {e}",
+                            )
+                        time.sleep(2**attempt)
+                    else:
+                        raise e
+
+            if api_response and api_response.metric_results:
+                result_data = api_response.metric_results[0]
+                error_message = None
+                if result_data.error and getattr(result_data.error, "code"):
+                    error_message = f"Error in metric result: {result_data.error}"
+                return types.EvalCaseMetricResult(
+                    metric_name=metric_name,
+                    score=result_data.score,
+                    explanation=result_data.explanation,
+                    rubric_verdicts=result_data.rubric_verdicts,
+                    error_message=error_message,
+                )
+            else:
+                return types.EvalCaseMetricResult(
+                    metric_name=metric_name,
+                    error_message="Metric results missing in API response.",
+                )
+        except Exception as e:
+            return types.EvalCaseMetricResult(
+                metric_name=metric_name, error_message=str(e)
+            )
+
+    @override
+    def aggregate(
+        self, eval_case_metric_results: list[types.EvalCaseMetricResult]
+    ) -> types.AggregatedMetricResult:
+        """Aggregates the metric results for a registered metric."""
+        logger.debug("Aggregating results for registered metric: %s", self.metric_name)
+        return _default_aggregate_scores(
+            self.metric_name, eval_case_metric_results, calculate_pass_rate=True
+        )
+
+
 _METRIC_HANDLER_MAPPING = [
     (
         lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
@@ -1251,6 +1386,10 @@ def aggregate(
         lambda m: m.custom_function and isinstance(m.custom_function, Callable),
         CustomMetricHandler,
     ),
+    (
+        lambda m: getattr(m, "metric_resource_name", None) is not None,
+        RegisteredMetricHandler,
+    ),
     (
         lambda m: m.name in ComputationMetricHandler.SUPPORTED_COMPUTATION_METRICS,
         ComputationMetricHandler,
@@ -1337,14 +1476,14 @@ def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]:
 
 
 def _aggregate_metric_results(
-    metric_handlers: list[MetricHandler],
+    metric_handlers: list[MetricHandler[Any]],
    eval_case_results: list[types.EvalCaseResult],
 ) -> list[types.AggregatedMetricResult]:
     """Aggregates results by calling the aggregate method of each handler."""
     aggregated_metric_results = []
     logger.info("Aggregating results per metric...")
     for handler in metric_handlers:
-        metric_name = handler.metric.name
+        metric_name = handler.metric_name
         results_for_this_metric: list[types.EvalCaseMetricResult] = []
         for case_result in eval_case_results:
             if case_result.response_candidate_results:
@@ -1473,12 +1612,12 @@ def compute_metrics_and_aggregate(
                     "response %d for metric %s.",
                     eval_case_index,
                     response_index,
-                    metric_handler_instance.metric.name,
+                    metric_handler_instance.metric_name,
                 )
                 all_futures.append(
                     (
                         future,
-                        metric_handler_instance.metric.name,
+                        metric_handler_instance.metric_name,
                         eval_case_index,
                         response_index,
                     )
@@ -1489,25 +1628,25 @@
                     "response %d for metric %s: %s",
                     eval_case_index,
                     response_index,
-                    metric_handler_instance.metric.name,
+                    metric_handler_instance.metric_name,
                     e,
                     exc_info=True,
                 )
                 submission_errors.append(
                     (
-                        metric_handler_instance.metric.name,
+                        metric_handler_instance.metric_name,
                         eval_case_index,
                         response_index,
                         f"Error: {e}",
                     )
                 )
                 error_result = types.EvalCaseMetricResult(
-                    metric_name=metric_handler_instance.metric.name,
+                    metric_name=metric_handler_instance.metric_name,
                     error_message=f"Submission Error: {e}",
                 )
                 results_by_case_response_metric[eval_case_index][
                     response_index
-                ][metric_handler_instance.metric.name] = error_result
+                ][metric_handler_instance.metric_name] = error_result
                 case_indices_with_errors.add(eval_case_index)
                 pbar.update(1)

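Taken together, the handler changes follow a familiar shape: a Generic[T] abstract base, a polymorphic metric_name property on each concrete handler, and first-match predicate dispatch, which is why the metric_resource_name check is inserted ahead of the name-based checks in _METRIC_HANDLER_MAPPING. A stripped-down sketch of that shape, using simplified hypothetical stand-ins (Metric, Handler, RegisteredHandler are not the SDK's classes):

    import abc
    from typing import Any, Callable, Generic, Optional, TypeVar

    class Metric:
        # Stand-in for types.Metric; illustrative only.
        def __init__(self, name: Optional[str] = None,
                     metric_resource_name: Optional[str] = None):
            self.name = name
            self.metric_resource_name = metric_resource_name

    T = TypeVar("T", bound=Metric)

    class Handler(abc.ABC, Generic[T]):
        def __init__(self, metric: T):
            self.metric: T = metric

        @property
        @abc.abstractmethod
        def metric_name(self) -> str:
            """Resolved polymorphically, so callers never reach into .metric.name."""

    class RegisteredHandler(Handler[Metric]):
        @property
        def metric_name(self) -> str:
            # Prefer the explicit name, fall back to the resource name.
            return self.metric.name or self.metric.metric_resource_name or "unknown"

    # First predicate that matches wins, so the resource-name check must come
    # before any name-based checks (mirrors _METRIC_HANDLER_MAPPING above).
    HANDLER_MAPPING: list[tuple[Callable[[Metric], Any], type]] = [
        (lambda m: getattr(m, "metric_resource_name", None) is not None,
         RegisteredHandler),
    ]

    def pick_handler(metric: Metric) -> Handler:
        for predicate, handler_cls in HANDLER_MAPPING:
            if predicate(metric):
                return handler_cls(metric)
        raise ValueError(f"No handler for metric {metric.name!r}")

    h = pick_handler(Metric(metric_resource_name="projects/p/locations/l/evaluationMetrics/1"))
    print(type(h).__name__, h.metric_name)  # RegisteredHandler, then the resource name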