
Commit 47783dc

vertex-sdk-bot authored and copybara-github committed
feat: Allow using registered metric resource names in evaluation
PiperOrigin-RevId: 880868820
1 parent 1ecaa9b commit 47783dc

7 files changed

Lines changed: 221 additions & 31 deletions
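
For orientation, here is a minimal usage sketch of what this change enables, assembled from the tests added below. The resource name, project, dataframe, and GCS destination are placeholders; client and types are assumed to be the initialized Vertex GenAI client (exposing .evals) and the types module that the replay tests use.

# Sketch only: identifiers below are placeholders, not values from this commit.
import pandas as pd

prompts_df = pd.DataFrame({"prompt": ["Explain the theory of relativity in one sentence."]})
metric_resource_name = "projects/my-project/locations/us-central1/evaluationMetrics/123"  # placeholder

# Reference a registered metric by resource name in an evaluation run.
run_metric = types.EvaluationRunMetric(
    metric="my_custom_metric",
    metric_resource_name=metric_resource_name,
)
evaluation_run = client.evals.create_evaluation_run(
    dataset=types.EvaluationDataset(eval_dataset_df=prompts_df),
    metrics=[run_metric],
    dest="gs://my-bucket/eval-output",  # placeholder destination
)

# The same registered metric can drive rubric generation via types.Metric.
rubric_metric = types.Metric(
    name="my_custom_metric", metric_resource_name=metric_resource_name
)
eval_dataset = client.evals.generate_rubrics(
    src=prompts_df, rubric_group_name="my_registered_rubrics", metric=rubric_metric
)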


tests/unit/vertexai/genai/replays/test_create_evaluation_run.py

Lines changed: 22 additions & 0 deletions
@@ -238,6 +238,28 @@ def test_create_eval_run_with_inference_configs(client):
     assert evaluation_run.error is None


+def test_create_eval_run_with_metric_resource_name(client):
+    """Tests create_evaluation_run with metric_resource_name."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric = types.EvaluationRunMetric(
+        metric="my_custom_metric",
+        metric_resource_name=metric_resource_name,
+    )
+    evaluation_run = client.evals.create_evaluation_run(
+        dataset=types.EvaluationDataset(
+            eval_dataset_df=INPUT_DF_WITH_CONTEXT_AND_HISTORY
+        ),
+        metrics=[metric],
+        dest=GCS_DEST,
+    )
+    assert isinstance(evaluation_run, types.EvaluationRun)
+    assert evaluation_run.evaluation_config.metrics[0].metric == "my_custom_metric"
+
+
 # Dataframe tests fail in replay mode because of UUID generation mismatch.
 # def test_create_eval_run_data_source_evaluation_dataset(client):
 #     """Tests that create_evaluation_run() creates a correctly structured

tests/unit/vertexai/genai/replays/test_public_generate_rubrics.py

Lines changed: 41 additions & 9 deletions
@@ -143,19 +143,21 @@
 User prompt:
 {prompt}"""

+_PROMPTS_DF = pd.DataFrame(
+    {
+        "prompt": [
+            "Explain the theory of relativity in one sentence.",
+            "Write a short poem about a cat.",
+        ]
+    }
+)
+

 def test_public_method_generate_rubrics(client):
     """Tests the public generate_rubrics method."""
-    prompts_df = pd.DataFrame(
-        {
-            "prompt": [
-                "Explain the theory of relativity in one sentence.",
-                "Write a short poem about a cat.",
-            ]
-        }
-    )
+
     eval_dataset = client.evals.generate_rubrics(
-        src=prompts_df,
+        src=_PROMPTS_DF,
         prompt_template=_TEST_RUBRIC_GENERATION_PROMPT,
         rubric_group_name="text_quality_rubrics",
     )
@@ -176,6 +178,36 @@ def test_public_method_generate_rubrics(client):
     assert isinstance(first_rubric_group["text_quality_rubrics"][0], types.evals.Rubric)


+def test_public_method_generate_rubrics_with_metric(client):
+    """Tests the public generate_rubrics method with a metric."""
+    client._api_client._http_options.api_version = "v1beta1"
+    client._api_client._http_options.base_url = (
+        "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
+    )
+    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
+    metric = types.Metric(
+        name="my_custom_metric", metric_resource_name=metric_resource_name
+    )
+    eval_dataset = client.evals.generate_rubrics(
+        src=_PROMPTS_DF, rubric_group_name="my_registered_rubrics", metric=metric
+    )
+    eval_dataset_df = eval_dataset.eval_dataset_df
+
+    assert isinstance(eval_dataset, types.EvaluationDataset)
+    assert isinstance(eval_dataset_df, pd.DataFrame)
+    assert "rubric_groups" in eval_dataset_df.columns
+    assert len(eval_dataset_df) == 2
+
+    first_rubric_group = eval_dataset_df["rubric_groups"][0]
+    assert isinstance(first_rubric_group, dict)
+    assert "my_registered_rubrics" in first_rubric_group
+    assert isinstance(first_rubric_group["my_registered_rubrics"], list)
+    assert first_rubric_group["my_registered_rubrics"]
+    assert isinstance(
+        first_rubric_group["my_registered_rubrics"][0], types.evals.Rubric
+    )
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),
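
Reading the generated rubrics back out of the returned dataset, per the assertions above (a sketch; the group name matches the test):

df = eval_dataset.eval_dataset_df
for rubric_groups in df["rubric_groups"]:                   # one dict per prompt row
    for rubric in rubric_groups["my_registered_rubrics"]:   # list of types.evals.Rubric
        print(rubric)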

vertexai/_genai/_evals_common.py

Lines changed: 12 additions & 1 deletion
@@ -45,6 +45,7 @@
 from . import _gcs_utils
 from . import evals
 from . import types
+from . import _transformers as t

 logger = logging.getLogger(__name__)

@@ -1328,7 +1329,7 @@ def _resolve_dataset_inputs(


 def _resolve_evaluation_run_metrics(
-    metrics: list[types.EvaluationRunMetric], api_client: Any
+    metrics: Union[list[types.EvaluationRunMetric], list[types.Metric]], api_client: Any
 ) -> list[types.EvaluationRunMetric]:
     """Resolves a list of evaluation run metric instances, loading RubricMetric if necessary."""
     if not metrics:
@@ -1361,6 +1362,16 @@ def _resolve_evaluation_run_metrics(
                     e,
                 )
                 raise
+        elif isinstance(metric_instance, types.Metric):
+            config_dict = t.t_metrics([metric_instance])[0]
+            res_name = config_dict.pop("metric_resource_name", None)
+            resolved_metrics_list.append(
+                types.EvaluationRunMetric(
+                    metric=metric_instance.name,
+                    metric_config=config_dict if config_dict else None,
+                    metric_resource_name=res_name,
+                )
+            )
         else:
             try:
                 metric_name_str = str(metric_instance)

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 14 additions & 2 deletions
@@ -1027,7 +1027,7 @@ def get_metric_result(
         for attempt in range(_MAX_RETRIES):
             try:
                 api_response = self.module._evaluate_instances(
-                    metrics=[self.metric],
+                    metrics_sources=[self.metric],
                     instance=payload.get("instance"),
                     autorater_config=payload.get("autorater_config"),
                 )
@@ -1164,7 +1164,7 @@ def get_metric_result(
         for attempt in range(_MAX_RETRIES):
             try:
                 api_response = self.module._evaluate_instances(
-                    metrics=[self.metric],
+                    metrics_sources=[self.metric],
                     instance=payload.get("instance"),
                 )
                 break
@@ -1242,6 +1242,14 @@ def aggregate(
         )


+class RegisteredMetricHandler(PredefinedMetricHandler):
+    """Metric handler for registered metrics."""
+
+    def __init__(self, module: "evals.Evals", metric: types.Metric):
+        # Skip the parent check for SUPPORTED_PREDEFINED_METRICS
+        MetricHandler.__init__(self, module=module, metric=metric)
+
+
 _METRIC_HANDLER_MAPPING = [
     (
         lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
@@ -1251,6 +1259,10 @@ def aggregate(
         lambda m: m.custom_function and isinstance(m.custom_function, Callable),
         CustomMetricHandler,
     ),
+    (
+        lambda m: getattr(m, "metric_resource_name", None) is not None,
+        RegisteredMetricHandler,
+    ),
     (
         lambda m: m.name in ComputationMetricHandler.SUPPORTED_COMPUTATION_METRICS,
         ComputationMetricHandler,
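
The dispatch logic itself is not part of this diff; assuming first-match-wins over _METRIC_HANDLER_MAPPING, a metric carrying metric_resource_name is picked up by the new predicate before the computation/predefined checks. A hypothetical sketch (pick_handler is illustrative, not part of the module):

def pick_handler(metric, mapping=_METRIC_HANDLER_MAPPING):
    # Illustrative first-match dispatch over (predicate, handler_cls) pairs.
    for predicate, handler_cls in mapping:
        if predicate(metric):
            return handler_cls
    raise ValueError(f"No handler found for metric: {metric!r}")

# e.g. a types.Metric with metric_resource_name set would resolve to RegisteredMetricHandler.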

vertexai/_genai/_transformers.py

Lines changed: 41 additions & 2 deletions
@@ -14,6 +14,7 @@
 #

 """Transformers module for Vertex addons."""
+import re
 from typing import Any

 from google.genai._common import get_value_by_path as getv
@@ -31,15 +32,26 @@ def t_metrics(
     Args:
         metrics: A list of metrics used for evaluation.
        set_default_aggregation_metrics: Whether to set default aggregation metrics.
+
     Returns:
        A list of resolved metric payloads for the evaluation request.
     """
     metrics_payload = []

     for metric in metrics:
+        # Case 1: Registered Metric Resource Name
+        if isinstance(metric, str) and re.match(
+            r"^projects/[^/]+/location/[^/]+/evaluationMetric/[^/]+$", metric
+        ):
+            metrics_payload.append({"metric_resource_name": metric})
+            continue
+
+        # Case 2: Inline Metric Configuration
         metric_payload_item: dict[str, Any] = {}
+        if hasattr(metric, "metric_resource_name") and metric.metric_resource_name:
+            metric_payload_item["metric_resource_name"] = metric.metric_resource_name

-        metric_name = getv(metric, ["name"]).lower()
+        metric_name = getattr(metric, "name", "").lower()

         if set_default_aggregation_metrics:
             metric_payload_item["aggregation_metrics"] = [
@@ -79,9 +91,36 @@ def t_metrics(
                 "return_raw_output": return_raw_output
             }
             metric_payload_item["pointwise_metric_spec"] = pointwise_spec
+        elif "metric_resource_name" in metric_payload_item:
+            # Valid case: Metric is identified by resource name; no inline spec required.
+            pass
         else:
             raise ValueError(
                 f"Unsupported metric type or invalid metric name: {metric_name}"
             )
-        metrics_payload.append(metric_payload_item)
+        metrics_payload.append({"metric": metric_payload_item})
     return metrics_payload
+
+
+def t_metric_sources(metrics: list[Any]) -> list[dict[str, Any]]:
+    """Prepares the MetricSource payload for the evaluation request."""
+    sources_payload = []
+    for metric in metrics:
+        # Check if the 'metric' is a resource name string or contains one
+        resource_name = getattr(metric, "metric_resource_name", None)
+        if (
+            not resource_name
+            and isinstance(metric, str)
+            and re.match(
+                r"^projects/[^/]+/location/[^/]+/evaluationMetric/[^/]+$", metric
+            )
+        ):
+            resource_name = metric
+
+        if resource_name:
+            sources_payload.append({"metric_resource_name": resource_name})
+        else:
+            # Fallback to existing Metric spec transformation
+            metric_payload = t_metrics([metric])[0]
+            sources_payload.append({"metric": metric_payload})
+    return sources_payload
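
To make the wire shape concrete, a hedged sketch of what t_metric_sources yields for resource-name inputs, following the code above (values are placeholders):

registered = types.Metric(
    name="my_custom_metric",
    metric_resource_name="projects/p/locations/us-central1/evaluationMetrics/123",
)
t_metric_sources([registered])
# -> [{"metric_resource_name": "projects/p/locations/us-central1/evaluationMetrics/123"}]

# A bare string is treated as a resource name only when it matches the regex above,
# i.e. with singular "location"/"evaluationMetric" path segments:
t_metric_sources(["projects/p/location/us-central1/evaluationMetric/123"])
# -> [{"metric_resource_name": "projects/p/location/us-central1/evaluationMetric/123"}]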
