
Commit 2ae435b

vertex-sdk-bot authored and copybara-github committed
feat: Allow using registered metric resource names in evaluation
PiperOrigin-RevId: 880868820
1 parent 3d05ffa commit 2ae435b
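
In practical terms, evaluation metrics that have already been registered with the Evaluation service can now be referenced by resource name instead of being redefined inline. A minimal usage sketch distilled from the new replay tests in this commit; the project number, metric ID, and client construction (vertexai.Client, vertexai.types) are illustrative assumptions, not part of the diff:

import pandas as pd

import vertexai
from vertexai import types

client = vertexai.Client(project="my-project", location="us-central1")

dataset = pd.DataFrame(
    {
        "prompt": ["Write a simple story about a dinosaur"],
        "response": ["Once upon a time, there was a T-Rex named Rexy."],
    }
)

# Point at a previously registered metric by resource name; no inline spec needed.
metric = types.Metric(
    name="my_custom_metric",
    metric_resource_name="projects/123/locations/us-central1/evaluationMetrics/456",
)

result = client.evals.evaluate(dataset=dataset, metrics=[metric])
print(result.summary_metrics[0].metric_name)  # -> "my_custom_metric"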

File tree

9 files changed: +334 -19 lines changed

tests/unit/vertexai/genai/replays/test_create_evaluation_run.py

Lines changed: 22 additions & 0 deletions
@@ -238,6 +238,28 @@ def test_create_eval_run_with_inference_configs(client):
     assert evaluation_run.error is None
 
 
+def test_create_eval_run_with_metric_resource_name(client):
+    """Tests create_evaluation_run with metric_resource_name."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric = types.EvaluationRunMetric(
+        metric="my_custom_metric",
+        metric_resource_name=metric_resource_name,
+    )
+    evaluation_run = client.evals.create_evaluation_run(
+        dataset=types.EvaluationDataset(
+            eval_dataset_df=INPUT_DF_WITH_CONTEXT_AND_HISTORY
+        ),
+        metrics=[metric],
+        dest=GCS_DEST,
+    )
+    assert isinstance(evaluation_run, types.EvaluationRun)
+    assert evaluation_run.evaluation_config.metrics[0].metric == "my_custom_metric"
+
+
 # Dataframe tests fail in replay mode because of UUID generation mismatch.
 # def test_create_eval_run_data_source_evaluation_dataset(client):
 #     """Tests that create_evaluation_run() creates a correctly structured

tests/unit/vertexai/genai/replays/test_evaluate.py

Lines changed: 26 additions & 0 deletions
@@ -353,6 +353,32 @@ def test_evaluation_agent_data(client):
     assert case_result.response_candidate_results is not None
 
 
+def test_metric_resource_name(client):
+    """Tests with a metric resource name in types.Metric."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    byor_df = pd.DataFrame(
+        {
+            "prompt": ["Write a simple story about a dinosaur"],
+            "response": ["Once upon a time, there was a T-Rex named Rexy."],
+        }
+    )
+    metric = types.Metric(
+        name="my_custom_metric", metric_resource_name=metric_resource_name
+    )
+    evaluation_result = client.evals.evaluate(
+        dataset=byor_df,
+        metrics=[metric],
+    )
+    assert isinstance(evaluation_result, types.EvaluationResult)
+    assert evaluation_result.eval_case_results is not None
+    assert len(evaluation_result.eval_case_results) > 0
+    assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),

tests/unit/vertexai/genai/replays/test_public_generate_rubrics.py

Lines changed: 41 additions & 9 deletions
@@ -143,19 +143,21 @@
 User prompt:
 {prompt}"""
 
+_PROMPTS_DF = pd.DataFrame(
+    {
+        "prompt": [
+            "Explain the theory of relativity in one sentence.",
+            "Write a short poem about a cat.",
+        ]
+    }
+)
+
 
 def test_public_method_generate_rubrics(client):
     """Tests the public generate_rubrics method."""
-    prompts_df = pd.DataFrame(
-        {
-            "prompt": [
-                "Explain the theory of relativity in one sentence.",
-                "Write a short poem about a cat.",
-            ]
-        }
-    )
+
     eval_dataset = client.evals.generate_rubrics(
-        src=prompts_df,
+        src=_PROMPTS_DF,
         prompt_template=_TEST_RUBRIC_GENERATION_PROMPT,
         rubric_group_name="text_quality_rubrics",
     )
@@ -176,6 +178,36 @@ def test_public_method_generate_rubrics(client):
     assert isinstance(first_rubric_group["text_quality_rubrics"][0], types.evals.Rubric)
 
 
+def test_public_method_generate_rubrics_with_metric(client):
+    """Tests the public generate_rubrics method with a metric."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric = types.Metric(
+        name="my_custom_metric", metric_resource_name=metric_resource_name
+    )
+    eval_dataset = client.evals.generate_rubrics(
+        src=_PROMPTS_DF, rubric_group_name="my_registered_rubrics", metric=metric
+    )
+    eval_dataset_df = eval_dataset.eval_dataset_df
+
+    assert isinstance(eval_dataset, types.EvaluationDataset)
+    assert isinstance(eval_dataset_df, pd.DataFrame)
+    assert "rubric_groups" in eval_dataset_df.columns
+    assert len(eval_dataset_df) == 2
+
+    first_rubric_group = eval_dataset_df["rubric_groups"][0]
+    assert isinstance(first_rubric_group, dict)
+    assert "my_registered_rubrics" in first_rubric_group
+    assert isinstance(first_rubric_group["my_registered_rubrics"], list)
+    assert first_rubric_group["my_registered_rubrics"]
+    assert isinstance(
+        first_rubric_group["my_registered_rubrics"][0], types.evals.Rubric
+    )
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),

vertexai/_genai/_evals_common.py

Lines changed: 12 additions & 1 deletion
@@ -45,6 +45,7 @@
 from . import _gcs_utils
 from . import evals
 from . import types
+from . import _transformers as t
 
 logger = logging.getLogger(__name__)
 
@@ -1328,7 +1329,7 @@ def _resolve_dataset_inputs(
 
 
 def _resolve_evaluation_run_metrics(
-    metrics: list[types.EvaluationRunMetric], api_client: Any
+    metrics: Union[list[types.EvaluationRunMetric], list[types.Metric]], api_client: Any
 ) -> list[types.EvaluationRunMetric]:
     """Resolves a list of evaluation run metric instances, loading RubricMetric if necessary."""
     if not metrics:
@@ -1361,6 +1362,16 @@
                 e,
             )
             raise
+        elif isinstance(metric_instance, types.Metric):
+            config_dict = t.t_metrics([metric_instance])[0]
+            res_name = getattr(metric_instance, "metric_resource_name", None)
+            resolved_metrics_list.append(
+                types.EvaluationRunMetric(
+                    metric=metric_instance.name,
+                    metric_config=config_dict if config_dict else None,
+                    metric_resource_name=res_name,
+                )
+            )
         else:
             try:
                 metric_name_str = str(metric_instance)
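
Net effect of the new branch: _resolve_evaluation_run_metrics now also accepts plain types.Metric objects and normalizes each into an EvaluationRunMetric, serializing the spec through t.t_metrics and carrying metric_resource_name through when set. A simplified stand-in for that branch (not an SDK helper; t and types are the modules imported in _evals_common.py):

def _metric_to_run_metric(metric_instance: types.Metric) -> types.EvaluationRunMetric:
    # Serialize the metric spec; may be empty for a resource-name-only metric
    # (the "Safe pass" branch in t_metrics, see the _transformers.py diff below).
    config_dict = t.t_metrics([metric_instance])[0]
    return types.EvaluationRunMetric(
        metric=metric_instance.name,
        metric_config=config_dict if config_dict else None,  # {} collapses to None
        metric_resource_name=getattr(metric_instance, "metric_resource_name", None),
    )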

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 60 additions & 0 deletions
@@ -1242,6 +1242,62 @@ def aggregate(
         )
 
 
+class RegisteredMetricHandler(PredefinedMetricHandler):
+    """Metric handler for registered metrics."""
+
+    def __init__(self, module: "evals.Evals", metric: types.Metric):
+        MetricHandler.__init__(self, module=module, metric=metric)
+
+    @override
+    def get_metric_result(
+        self, eval_case: types.EvalCase, response_index: int
+    ) -> types.EvalCaseMetricResult:
+        """Processes a single evaluation case for a registered metric."""
+        metric_name = self.metric.name
+        try:
+            payload = self._build_request_payload(eval_case, response_index)
+            for attempt in range(_MAX_RETRIES):
+                try:
+                    api_response = self.module._evaluate_instances(
+                        metric_sources=[self.metric],
+                        instance=payload.get("instance"),
+                        autorater_config=payload.get("autorater_config"),
+                    )
+                    break
+                except genai_errors.ClientError as e:
+                    if e.code == 429:
+                        if attempt == _MAX_RETRIES - 1:
+                            return types.EvalCaseMetricResult(
+                                metric_name=metric_name,
+                                error_message=f"Judge model resource exhausted after {_MAX_RETRIES} retries: {e}",
+                            )
+                        time.sleep(2**attempt)
+                    else:
+                        raise e
+
+            if api_response and api_response.metric_results:
+                result_data = api_response.metric_results[0]
+                error_message = None
+                if result_data.error and getattr(result_data.error, "code"):
+                    error_message = f"Error in metric result: {result_data.error}"
+                return types.EvalCaseMetricResult(
+                    metric_name=metric_name,
+                    score=result_data.score,
+                    explanation=result_data.explanation,
+                    rubric_verdicts=result_data.rubric_verdicts,
+                    error_message=error_message,
+                )
+            else:
+                return types.EvalCaseMetricResult(
+                    metric_name=metric_name,
+                    error_message="Metric results missing in API response.",
+                )
+        except Exception as e:
+            return types.EvalCaseMetricResult(
+                metric_name=metric_name, error_message=str(e)
+            )
+
+
 _METRIC_HANDLER_MAPPING = [
     (
         lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
@@ -1251,6 +1307,10 @@ def aggregate(
         lambda m: m.custom_function and isinstance(m.custom_function, Callable),
         CustomMetricHandler,
     ),
+    (
+        lambda m: getattr(m, "metric_resource_name", None) is not None,
+        RegisteredMetricHandler,
+    ),
     (
         lambda m: m.name in ComputationMetricHandler.SUPPORTED_COMPUTATION_METRICS,
         ComputationMetricHandler,
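
Handlers are chosen by walking _METRIC_HANDLER_MAPPING in order and taking the first predicate that matches; the selection loop itself is not part of this diff, so the sketch below assumes that behavior. Because the new entry sits before the computation-metric check, a metric that carries a metric_resource_name is routed to RegisteredMetricHandler even if its name would otherwise match a computation metric such as "bleu".

# Assumed shape of the dispatch that consumes _METRIC_HANDLER_MAPPING;
# the actual loop lives outside this diff.
def _get_handler(module: "evals.Evals", metric: types.Metric) -> MetricHandler:
    for predicate, handler_cls in _METRIC_HANDLER_MAPPING:
        if predicate(metric):
            return handler_cls(module=module, metric=metric)
    raise ValueError(f"Unsupported metric: {metric.name}")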

vertexai/_genai/_transformers.py

Lines changed: 38 additions & 3 deletions
@@ -14,13 +14,16 @@
 #
 
 """Transformers module for Vertex addons."""
+import re
 from typing import Any
 
 from google.genai._common import get_value_by_path as getv
 
 from . import _evals_constant
 from . import types
 
+_METRIC_RES_NAME_RE = r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$"
+
 
 def t_metrics(
     metrics: list["types.MetricSubclass"],
@@ -39,7 +42,8 @@
     for metric in metrics:
         metric_payload_item: dict[str, Any] = {}
 
-        metric_name = getv(metric, ["name"]).lower()
+        metric_id = getv(metric, ["metric"]) or getv(metric, ["name"])
+        metric_name = metric_id.lower() if metric_id else None
 
         if set_default_aggregation_metrics:
             metric_payload_item["aggregation_metrics"] = [
@@ -51,11 +55,13 @@
             metric_payload_item["exact_match_spec"] = {}
         elif metric_name == "bleu":
             metric_payload_item["bleu_spec"] = {}
-        elif metric_name.startswith("rouge"):
+        elif metric_name and metric_name.startswith("rouge"):
            rouge_type = metric_name.replace("_", "")
            metric_payload_item["rouge_spec"] = {"rouge_type": rouge_type}
         # API Pre-defined metrics
-        elif metric_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
+        elif (
+            metric_name and metric_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS
+        ):
             metric_payload_item["predefined_metric_spec"] = {
                 "metric_spec_name": metric_name,
                 "metric_spec_parameters": metric.metric_spec_parameters,
@@ -79,9 +85,38 @@
                 "return_raw_output": return_raw_output
             }
             metric_payload_item["pointwise_metric_spec"] = pointwise_spec
+        elif getattr(metric, "metric_resource_name", None) is not None:
+            # Safe pass
+            pass
         else:
             raise ValueError(
                 f"Unsupported metric type or invalid metric name: {metric_name}"
             )
         metrics_payload.append(metric_payload_item)
     return metrics_payload
+
+
+def t_metric_sources(metrics: list[Any]) -> list[dict[str, Any]]:
+    """Prepares the MetricSource payload."""
+    sources_payload = []
+    for metric in metrics:
+        resource_name = getattr(metric, "metric_resource_name", None)
+        if (
+            not resource_name
+            and isinstance(metric, str)
+            and re.match(_METRIC_RES_NAME_RE, metric)
+        ):
+            resource_name = metric
+
+        if resource_name:
+            sources_payload.append({"metric_resource_name": resource_name})
+        else:
+            if hasattr(metric, "metric") and not isinstance(metric, str):
+                metric = metric.metric
+
+            if not hasattr(metric, "name"):
+                metric = types.Metric(name=str(metric))
+
+            metric_payload = t_metrics([metric])[0]
+            sources_payload.append({"metric": metric_payload})
+    return sources_payload
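
t_metric_sources accepts a heterogeneous list: anything carrying a metric_resource_name (or a bare string matching _METRIC_RES_NAME_RE) becomes a metric_resource_name source, and everything else is coerced to a types.Metric and serialized through t_metrics. A hypothetical call showing the three input forms; the resource name is illustrative:

resource = "projects/123/locations/us-central1/evaluationMetrics/456"
sources = t_metric_sources(
    [
        resource,  # bare resource-name string
        types.Metric(
            name="my_custom_metric", metric_resource_name=resource
        ),  # Metric carrying a resource name
        types.Metric(name="bleu"),  # plain metric, serialized via t_metrics
    ]
)
# -> [{"metric_resource_name": resource},
#     {"metric_resource_name": resource},
#     {"metric": {"bleu_spec": {}, ...}}]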
