Skip to content

Commit 82b9c8e

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Allow RegisteredMetricHandler to build payloads without a full inline metric definition
PiperOrigin-RevId: 885215504
1 parent 394253a commit 82b9c8e

2 files changed

Lines changed: 29 additions & 16 deletions

File tree

tests/unit/vertexai/genai/replays/test_evaluate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def test_metric_resource_name(client):
377377
assert evaluation_result.eval_case_results is not None
378378
assert len(evaluation_result.eval_case_results) > 0
379379
assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"
380+
assert evaluation_result.summary_metrics[0].mean_score is not None
380381

381382

382383
pytestmark = pytest_helper.setup(

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -965,10 +965,13 @@ def _eval_case_to_agent_data(
965965
events=events,
966966
)
967967

968-
def _build_request_payload(
969-
self, eval_case: types.EvalCase, response_index: int
968+
@staticmethod
969+
def build_evaluation_request_payload(
970+
eval_case: types.EvalCase,
971+
metric: types.Metric,
972+
response_index: int,
970973
) -> dict[str, Any]:
971-
"""Builds the request parameters for evaluate instances request."""
974+
"""Shared logic to build the EvaluationInstance and AutoraterConfig."""
972975
response_content = _get_response_from_eval_case(
973976
eval_case, response_index, self.metric.name
974977
)
@@ -978,7 +981,7 @@ def _build_request_payload(
978981
f"Response content missing for candidate {response_index}."
979982
)
980983

981-
if self.metric.name == "tool_use_quality_v1":
984+
if metric.name == "tool_use_quality_v1":
982985
if not _has_tool_call(eval_case.intermediate_events):
983986
logger.warning(
984987
"Metric 'tool_use_quality_v1' requires tool usage in "
@@ -994,7 +997,7 @@ def _build_request_payload(
994997

995998
extracted_prompt = _get_prompt_from_eval_case(eval_case)
996999
prompt_instance_data = None
997-
if self.metric.name is not None and self.metric.name.startswith("multi_turn"):
1000+
if metric.name is not None and metric.name.startswith("multi_turn"):
9981001
prompt_contents = []
9991002
if eval_case.conversation_history:
10001003
for message in eval_case.conversation_history:
@@ -1050,6 +1053,14 @@ def _build_request_payload(
10501053
)
10511054
return request_payload
10521055

1056+
def _build_request_payload(
1057+
self, eval_case: types.EvalCase, response_index: int
1058+
) -> dict[str, Any]:
1059+
"""Delegates to the shared static method."""
1060+
return self.build_evaluation_request_payload(
1061+
eval_case, self.metric, response_index
1062+
)
1063+
10531064
@override
10541065
def get_metric_result(
10551066
self, eval_case: types.EvalCase, response_index: int
@@ -1281,30 +1292,31 @@ def aggregate(
12811292
)
12821293

12831294

1284-
class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
1295+
class RegisteredMetricHandler(MetricHandler[Union[types.MetricSource, types.Metric]]):
12851296
"""Metric handler for registered metrics."""
12861297

12871298
def __init__(
12881299
self,
12891300
module: "evals.Evals",
1290-
metric: Union[types.MetricSource, types.MetricSourceDict],
1301+
metric: Union[types.MetricSource, types.Metric, types.MetricSourceDict],
12911302
):
12921303
if isinstance(metric, dict):
12931304
metric = types.MetricSource(**metric)
12941305
super().__init__(module=module, metric=metric)
12951306

1296-
# TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
12971307
def _build_request_payload(
12981308
self, eval_case: types.EvalCase, response_index: int
12991309
) -> dict[str, Any]:
1300-
"""Builds request payload for registered metric."""
1301-
if not self.metric.metric:
1302-
raise ValueError(
1303-
"Registered metric must have an underlying metric definition."
1304-
)
1305-
return PredefinedMetricHandler(
1306-
self.module, metric=self.metric.metric
1307-
)._build_request_payload(eval_case, response_index)
1310+
"""Builds request payload using shared predefined logic."""
1311+
inline_metric = getattr(self.metric, "metric", None) or (
1312+
self.metric if isinstance(self.metric, types.Metric) else None
1313+
)
1314+
1315+
metric_for_building = inline_metric or types.Metric(name="registered_metric")
1316+
1317+
return PredefinedMetricHandler.build_evaluation_request_payload(
1318+
eval_case, metric_for_building, response_index
1319+
)
13081320

13091321
@property
13101322
def metric_name(self) -> str:

0 commit comments

Comments
 (0)