
Commit 142247f

vertex-sdk-bot authored and copybara-github committed
feat: Refactor RegisteredMetricHandler and implement llm_based_metric_spec support
PiperOrigin-RevId: 885215504
1 parent 981a551 · commit 142247f

File tree

4 files changed: +155 −42 lines changed

tests/unit/vertexai/genai/replays/test_evaluate.py

Lines changed: 27 additions & 4 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 # pylint: disable=protected-access,bad-continuation,missing-function-docstring
-
+import re
 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai._genai import types
 from google.genai import types as genai_types
@@ -353,13 +353,22 @@ def test_evaluation_agent_data(client):
     assert case_result.response_candidate_results is not None


-def test_metric_resource_name(client):
+def test_evaluation_metric_resource_name(client):
     """Tests with a metric resource name in types.Metric."""
     client._api_client._http_options.api_version = "v1beta1"
     client._api_client._http_options.base_url = (
         "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
     )
-    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric_resource_name = client.evals.create_evaluation_metric(
+        display_name="test_metric",
+        description="test_description",
+        metric=types.RubricMetric.GENERAL_QUALITY,
+    )
+    assert isinstance(metric_resource_name, str)
+    assert re.match(
+        r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$",
+        metric_resource_name,
+    )
     byor_df = pd.DataFrame(
         {
             "prompt": ["Write a simple story about a dinosaur"],
@@ -375,8 +384,22 @@ def test_metric_resource_name(client):
     )
     assert isinstance(evaluation_result, types.EvaluationResult)
     assert evaluation_result.eval_case_results is not None
-    assert len(evaluation_result.eval_case_results) > 0
+    assert len(evaluation_result.eval_case_results) == 1
     assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"
+    assert evaluation_result.summary_metrics[0].mean_score is not None
+    assert evaluation_result.summary_metrics[0].num_cases_valid == 1
+    assert evaluation_result.summary_metrics[0].num_cases_error == 0
+
+    case_result = evaluation_result.eval_case_results[0]
+    assert case_result.response_candidate_results is not None
+    assert len(case_result.response_candidate_results) == 1
+
+    metric_result = case_result.response_candidate_results[0].metric_results[
+        "my_custom_metric"
+    ]
+    assert metric_result.score is not None
+    assert metric_result.score > 0.2
+    assert metric_result.error_message is None


 pytestmark = pytest_helper.setup(

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 50 additions & 26 deletions
@@ -1281,66 +1281,91 @@ def aggregate(
         )


-class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
+class RegisteredMetricHandler(MetricHandler[types.Metric]):
     """Metric handler for registered metrics."""

     def __init__(
         self,
         module: "evals.Evals",
-        metric: Union[types.MetricSource, types.MetricSourceDict],
+        metric: types.Metric,
     ):
         if isinstance(metric, dict):
             metric = types.MetricSource(**metric)
         super().__init__(module=module, metric=metric)

-    # TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
     def _build_request_payload(
         self, eval_case: types.EvalCase, response_index: int
     ) -> dict[str, Any]:
-        """Builds request payload for registered metric."""
-        if not self.metric.metric:
+        """Builds request payload for registered metric by assembling EvaluationInstance."""
+        response_content = _get_response_from_eval_case(
+            eval_case, response_index, self.metric_name
+        )
+
+        if not response_content and not getattr(eval_case, "agent_data", None):
             raise ValueError(
-                "Registered metric must have an underlying metric definition."
+                f"Response content missing for candidate {response_index}."
+            )
+
+        reference_instance_data = None
+        if eval_case.reference:
+            reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
+                eval_case.reference.response
             )
-        return PredefinedMetricHandler(
-            self.module, metric=self.metric.metric
-        )._build_request_payload(eval_case, response_index)
+
+        extracted_prompt = _get_prompt_from_eval_case(eval_case)
+        prompt_instance_data = PredefinedMetricHandler._content_to_instance_data(
+            extracted_prompt
+        )
+
+        instance_payload = types.EvaluationInstance(
+            prompt=prompt_instance_data,
+            response=PredefinedMetricHandler._content_to_instance_data(
+                response_content
+            ),
+            reference=reference_instance_data,
+            rubric_groups=eval_case.rubric_groups,
+            agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
+        )
+
+        request_payload = {
+            "instance": instance_payload,
+        }
+        return request_payload

     @property
     def metric_name(self) -> str:
-        # Resolve name from resource name or internal metric name
-        if isinstance(self.metric, types.MetricSource):
-            if self.metric.metric and self.metric.metric.name:
-                return self.metric.metric.name
-            if self.metric.metric_resource_name:
-                return self.metric.metric_resource_name
-            return "unknown"
-        else:  # Should be Metric
-            metric_like = self.metric
-            if metric_like.name:
-                return metric_like.name
-            if metric_like.metric_resource_name:
-                return metric_like.metric_resource_name
-            return "unknown"
+        return self.metric.name or "unknown_metric"

     @override
     def get_metric_result(
         self, eval_case: types.EvalCase, response_index: int
     ) -> types.EvalCaseMetricResult:
-        """Processes a single evaluation case for a registered metric."""
+        """Processes a single evaluation case using a MetricSource reference."""
         metric_name = self.metric_name
+        metric_source = types.MetricSource(
+            metric_resource_name=self.metric.metric_resource_name
+        )
+
         try:
             payload = self._build_request_payload(eval_case, response_index)
             for attempt in range(_MAX_RETRIES):
                 try:
                     api_response = self.module._evaluate_instances(
-                        metric_sources=[self.metric],
+                        metric_sources=[metric_source],
                         instance=payload.get("instance"),
                         autorater_config=payload.get("autorater_config"),
                     )
                     break
                 except genai_errors.ClientError as e:
                     if e.code == 429:
+                        logger.warning(
+                            "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s"
+                            " seconds...",
+                            attempt + 1,
+                            _MAX_RETRIES,
+                            e,
+                            2**attempt,
+                        )
                         if attempt == _MAX_RETRIES - 1:
                             return types.EvalCaseMetricResult(
                                 metric_name=metric_name,
@@ -1377,7 +1402,6 @@ def aggregate(
         self, eval_case_metric_results: list[types.EvalCaseMetricResult]
     ) -> types.AggregatedMetricResult:
         """Aggregates the metric results for a registered metric."""
-        logger.debug("Aggregating results for registered metric: %s", self.metric_name)
        return _default_aggregate_scores(
            self.metric_name, eval_case_metric_results, calculate_pass_rate=True
        )
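Two behavioral details stand out in this hunk: the handler now passes a `types.MetricSource` built from the registered metric's resource name to `_evaluate_instances`, and 429 responses are logged with exponential backoff timing before the handler gives up after `_MAX_RETRIES` attempts. A standalone sketch of that retry pattern (a generic illustration only; `call_evaluate` and the retry constant here are stand-ins, not the SDK's internals):

import logging
import time

logger = logging.getLogger(__name__)
_MAX_RETRIES = 3  # stand-in value; the SDK defines its own constant


def call_with_backoff(call_evaluate):
    """Retries a callable on HTTP 429, waiting 2**attempt seconds between tries.

    `call_evaluate` is a hypothetical zero-argument callable whose exceptions
    carry a `.code` attribute, mirroring genai_errors.ClientError above.
    """
    for attempt in range(_MAX_RETRIES):
        try:
            return call_evaluate()
        except Exception as e:
            if getattr(e, "code", None) != 429 or attempt == _MAX_RETRIES - 1:
                raise
            logger.warning(
                "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s seconds...",
                attempt + 1,
                _MAX_RETRIES,
                e,
                2**attempt,
            )
            time.sleep(2**attempt)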

vertexai/_genai/_transformers.py

Lines changed: 59 additions & 11 deletions
@@ -25,13 +25,6 @@
 _METRIC_RES_NAME_RE = r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$"


-def t_metric(
-    metric: "types.MetricSubclass",
-) -> dict[str, Any]:
-    """Prepares the metric payload for a single metric."""
-    return t_metrics([metric])[0]
-
-
 def t_metrics(
     metrics: "list[types.MetricSubclass]",
     set_default_aggregation_metrics: bool = False,
@@ -82,16 +75,19 @@ def t_metrics(
             }
         # Pointwise metrics
         elif hasattr(metric, "prompt_template") and metric.prompt_template:
-            pointwise_spec = {"metric_prompt_template": metric.prompt_template}
+            llm_based_spec = {"metric_prompt_template": metric.prompt_template}
             system_instruction = getv(metric, ["judge_model_system_instruction"])
             if system_instruction:
-                pointwise_spec["system_instruction"] = system_instruction
+                llm_based_spec["system_instruction"] = system_instruction
+            rubric_group_name = getv(metric, ["rubric_group_name"])
+            if rubric_group_name:
+                llm_based_spec["rubric_group_key"] = rubric_group_name
             return_raw_output = getv(metric, ["return_raw_output"])
             if return_raw_output:
-                pointwise_spec["custom_output_format_config"] = {
+                llm_based_spec["custom_output_format_config"] = {
                     "return_raw_output": return_raw_output
                 }
-            metric_payload_item["pointwise_metric_spec"] = pointwise_spec
+            metric_payload_item["llm_based_metric_spec"] = llm_based_spec
         elif getattr(metric, "metric_resource_name", None) is not None:
             # Safe pass
             pass
@@ -127,3 +123,55 @@ def t_metric_sources(metrics: list[Any]) -> list[dict[str, Any]]:
         metric_payload = t_metrics([metric])[0]
         sources_payload.append({"metric": metric_payload})
     return sources_payload
+
+
+def t_metric_for_registry(
+    metric: "types.Metric",
+) -> dict[str, Any]:
+    """Prepares the metric payload specifically for EvaluationMetric registration."""
+    metric_payload_item: dict[str, Any] = {}
+    metric_name = getattr(metric, "name", None)
+    if metric_name:
+        metric_name = metric_name.lower()
+
+    # Handle standard computation metrics
+    if metric_name == "exact_match":
+        metric_payload_item["exact_match_spec"] = {}
+    elif metric_name == "bleu":
+        metric_payload_item["bleu_spec"] = {}
+    elif metric_name and metric_name.startswith("rouge"):
+        rouge_type = metric_name.replace("_", "")
+        metric_payload_item["rouge_spec"] = {"rouge_type": rouge_type}
+    # API Pre-defined metrics
+    elif metric_name and metric_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
+        metric_payload_item["predefined_metric_spec"] = {
+            "metric_spec_name": metric_name,
+            "metric_spec_parameters": metric.metric_spec_parameters,
+        }
+    # Custom Code Execution Metric
+    elif hasattr(metric, "remote_custom_function") and metric.remote_custom_function:
+        metric_payload_item["custom_code_execution_spec"] = {
+            "evaluation_function": metric.remote_custom_function
+        }
+
+    # Map LLM-based metrics to the new llm_based_metric_spec
+    elif (hasattr(metric, "prompt_template") and metric.prompt_template) or (
+        hasattr(metric, "rubric_group_name") and metric.rubric_group_name
+    ):
+        llm_based_spec = {}
+
+        if hasattr(metric, "prompt_template") and metric.prompt_template:
+            llm_based_spec["metric_prompt_template"] = metric.prompt_template
+        system_instruction = getv(metric, ["judge_model_system_instruction"])
+        if system_instruction:
+            llm_based_spec["system_instruction"] = system_instruction
+        rubric_group_name = getv(metric, ["rubric_group_name"])
+        if rubric_group_name:
+            llm_based_spec["rubric_group_key"] = rubric_group_name
+
+        metric_payload_item["llm_based_metric_spec"] = llm_based_spec
+
+    else:
+        raise ValueError(f"Unsupported metric type: {metric_name}")
+
+    return metric_payload_item
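To make the branch behavior above concrete, here is a hedged sketch of the payload shapes `t_metric_for_registry` should produce, assuming `types.Metric` accepts `name` and `prompt_template` as keyword fields (consistent with the attributes the function reads, but not shown in this diff):

from vertexai._genai import _transformers as t
from vertexai._genai import types

# ROUGE variants map to a rouge_spec; the underscore is dropped from the name.
print(t.t_metric_for_registry(types.Metric(name="rouge_1")))
# expected shape: {"rouge_spec": {"rouge_type": "rouge1"}}

# A prompt-template metric now lands in llm_based_metric_spec rather than the
# old pointwise_metric_spec.
print(
    t.t_metric_for_registry(
        types.Metric(
            name="my_custom_metric",
            prompt_template="Rate the response from 0 to 1 for fluency.",
        )
    )
)
# expected shape: {"llm_based_metric_spec": {"metric_prompt_template": "..."}}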

vertexai/_genai/evals.py

Lines changed: 19 additions & 1 deletion
@@ -79,7 +79,11 @@ def _CreateEvaluationMetricParameters_to_vertex(
     setv(to_object, ["description"], getv(from_object, ["description"]))

     if getv(from_object, ["metric"]) is not None:
-        setv(to_object, ["metric"], t.t_metric(getv(from_object, ["metric"])))
+        setv(
+            to_object,
+            ["metric"],
+            t.t_metric_for_registry(getv(from_object, ["metric"])),
+        )

     if getv(from_object, ["config"]) is not None:
         setv(to_object, ["config"], getv(from_object, ["config"]))
@@ -2376,6 +2380,13 @@ def create_evaluation_metric(
         )
         metric = resolved_metrics[0]

+        # Add fallback logic for display_name
+        if display_name is None and metric:
+            if isinstance(metric, dict):
+                display_name = metric.get("name")
+            else:
+                display_name = getattr(metric, "name", None)
+
         result = self._create_evaluation_metric(
             display_name=display_name,
             description=description,
@@ -3549,6 +3560,13 @@ async def create_evaluation_metric(
         )
         metric = resolved_metrics[0]

+        # Add fallback logic for display_name
+        if display_name is None and metric:
+            if isinstance(metric, dict):
+                display_name = metric.get("name")
+            else:
+                display_name = getattr(metric, "name", None)
+
         result = await self._create_evaluation_metric(
             display_name=display_name,
             description=description,
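With the fallback above, callers of `create_evaluation_metric` (sync and async) can omit `display_name` when the supplied metric already carries a name. A minimal standalone sketch of that fallback logic (the helper name is illustrative, not part of the SDK):

from typing import Any, Optional, Union


def resolve_display_name(
    display_name: Optional[str],
    metric: Union[dict[str, Any], Any, None],
) -> Optional[str]:
    """Uses the metric's own name when no explicit display_name is provided."""
    if display_name is None and metric:
        if isinstance(metric, dict):
            return metric.get("name")
        return getattr(metric, "name", None)
    return display_name


# resolve_display_name(None, {"name": "my_custom_metric"}) -> "my_custom_metric"
# resolve_display_name("custom", {"name": "ignored"}) -> "custom"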
