Skip to content

Commit f7733ec

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Update the interface for custom code execution metric while maintaining remote_custom_function support for backward compatibility
PiperOrigin-RevId: 892523805
1 parent 2ad586d commit f7733ec

File tree

6 files changed

+99
-27
lines changed

6 files changed

+99
-27
lines changed

tests/unit/vertexai/genai/replays/test_custom_code_execution_metric.py

Lines changed: 32 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -18,22 +18,32 @@
1818
from vertexai._genai import types
1919
from google.genai import types as genai_types
2020
import pandas as pd
21+
import pytest
2122

2223

23-
def test_custom_code_execution(client):
24-
"""Tests that custom code execution metric produces a correctly structured EvaluationResult."""
25-
26-
code_snippet = """
24+
CODE_SNIPPET = """
2725
def evaluate(instance):
2826
if instance['response'] == instance['reference']:
2927
return 1.0
3028
return 0.0
3129
"""
3230

33-
custom_metric = types.Metric(
34-
name="my_custom_code_metric",
35-
remote_custom_function=code_snippet,
36-
)
31+
32+
@pytest.mark.parametrize(
33+
"custom_metric",
34+
[
35+
types.CodeExecutionMetric(
36+
name="my_custom_code_metric",
37+
custom_function=CODE_SNIPPET,
38+
),
39+
types.Metric(
40+
name="my_custom_code_metric",
41+
remote_custom_function=CODE_SNIPPET,
42+
),
43+
],
44+
)
45+
def test_custom_code_execution(client, custom_metric):
46+
"""Tests that custom code execution metric produces a correctly structured EvaluationResult."""
3747

3848
prompts_df = pd.DataFrame(
3949
{
@@ -69,21 +79,22 @@ def evaluate(instance):
6979
assert case_result.response_candidate_results is not None
7080

7181

72-
def test_custom_code_execution_batch_evaluate(client):
82+
@pytest.mark.parametrize(
83+
"custom_metric",
84+
[
85+
types.CodeExecutionMetric(
86+
name="my_custom_code_metric",
87+
custom_function=CODE_SNIPPET,
88+
),
89+
types.Metric(
90+
name="my_custom_code_metric",
91+
remote_custom_function=CODE_SNIPPET,
92+
),
93+
],
94+
)
95+
def test_custom_code_execution_batch_evaluate(client, custom_metric):
7396
"""Tests that batch_evaluate() works with custom code execution metric."""
7497

75-
code_snippet = """
76-
def evaluate(instance):
77-
if instance['response'] == instance['reference']:
78-
return 1.0
79-
return 0.0
80-
"""
81-
82-
custom_metric = types.Metric(
83-
name="my_custom_code_metric",
84-
remote_custom_function=code_snippet,
85-
)
86-
8798
eval_dataset = types.EvaluationDataset(
8899
gcs_source=genai_types.GcsSource(
89100
uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -845,7 +845,7 @@ def get_metric_result(
845845
score = None
846846
explanation = None
847847
try:
848-
if self.metric.custom_function:
848+
if self.metric.custom_function and callable(self.metric.custom_function):
849849
custom_function_result = self.metric.custom_function(
850850
instance_for_custom_fn
851851
)
@@ -1058,10 +1058,10 @@ def metric_name(self) -> str:
10581058
def __init__(self, module: "evals.Evals", metric: types.Metric):
    """Initialize the handler and verify the metric carries custom code.

    Args:
        module: Forwarded unchanged to the parent handler's initializer.
        metric: The metric to handle. At least one of
            ``remote_custom_function`` or ``custom_function`` must be truthy.

    Raises:
        ValueError: If both ``remote_custom_function`` and
            ``custom_function`` are falsy on the metric.
    """
    super().__init__(module=module, metric=metric)

    # Either field may carry the code: ``remote_custom_function`` is the
    # legacy field kept for backward compatibility, ``custom_function`` is
    # the newer interface.
    has_custom_code = (
        self.metric.remote_custom_function or self.metric.custom_function
    )
    if not has_custom_code:
        # The first literal already ends with a space; the second must not
        # start with one, otherwise the message contains a double space.
        raise ValueError(
            f"CustomCodeExecutionMetricHandler for '{self.metric.name}' needs "
            "custom function to be set."
        )
10661066

10671067
def _build_request_payload(
@@ -1310,7 +1310,17 @@ def aggregate(
13101310

13111311
_METRIC_HANDLER_MAPPING = [
13121312
(
1313-
lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
1313+
lambda m: (
1314+
# Recognize the user-facing class
1315+
isinstance(m, types.CodeExecutionMetric)
1316+
and (hasattr(m, "custom_function") and m.custom_function)
1317+
)
1318+
or (hasattr(m, "remote_custom_function") and m.remote_custom_function)
1319+
# Recognize base Metric objects that have been coerced by Pydantic
1320+
or (
1321+
isinstance(m, types.Metric)
1322+
and isinstance(getattr(m, "custom_function", None), str)
1323+
),
13141324
CustomCodeExecutionMetricHandler,
13151325
),
13161326
(

vertexai/_genai/_evals_metric_loaders.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -370,3 +370,12 @@ def GECKO_TEXT2VIDEO(self) -> LazyLoadedPrebuiltMetric:
370370

371371
PrebuiltMetric = PrebuiltMetricLoader()
372372
RubricMetric = PrebuiltMetric
373+
374+
375+
def CodeExecutionMetric(
    name: str,
    custom_function: str,
    **kwargs: Any,
) -> "types.Metric":
    """Build a ``types.Metric`` configured for code execution.

    Args:
        name: The name given to the resulting metric.
        custom_function: The code to run; stored on the metric's
            ``remote_custom_function`` field.
        **kwargs: Extra keyword arguments passed straight through to
            ``types.Metric``.

    Returns:
        The constructed ``types.Metric`` instance.
    """
    # Imported at call time rather than module load time
    # (presumably to avoid an import cycle - confirm).
    from . import types

    metric = types.Metric(
        name=name,
        remote_custom_function=custom_function,
        **kwargs,
    )
    return metric

vertexai/_genai/_transformers.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,16 @@ def t_metrics(
7373
metric_payload_item["custom_code_execution_spec"] = {
7474
"evaluation_function": metric.remote_custom_function
7575
}
76+
elif (
77+
isinstance(metric, types.CodeExecutionMetric)
78+
or (
79+
isinstance(metric, types.Metric)
80+
and isinstance(getattr(metric, "custom_function", None), str)
81+
)
82+
) and getattr(metric, "custom_function", None):
83+
metric_payload_item["custom_code_execution_spec"] = {
84+
"evaluation_function": metric.custom_function
85+
}
7686
# LLM-based metrics
7787
elif hasattr(metric, "prompt_template") and metric.prompt_template:
7888
llm_based_spec: dict[str, Any] = {
@@ -196,6 +206,16 @@ def t_metric_for_registry(
196206
metric_payload_item["custom_code_execution_spec"] = {
197207
"evaluation_function": metric.remote_custom_function
198208
}
209+
elif (
210+
isinstance(metric, types.CodeExecutionMetric)
211+
or (
212+
isinstance(metric, types.Metric)
213+
and isinstance(getattr(metric, "custom_function", None), str)
214+
)
215+
) and getattr(metric, "custom_function", None):
216+
metric_payload_item["custom_code_execution_spec"] = {
217+
"evaluation_function": metric.custom_function
218+
}
199219

200220
# Map LLM-based metrics to the new llm_based_metric_spec
201221
elif (hasattr(metric, "prompt_template") and metric.prompt_template) or (

vertexai/_genai/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -207,6 +207,7 @@
207207
from .common import Chunk
208208
from .common import ChunkDict
209209
from .common import ChunkOrDict
210+
from .common import CodeExecutionMetric
210211
from .common import CometResult
211212
from .common import CometResultDict
212213
from .common import CometResultOrDict
@@ -2153,6 +2154,7 @@
21532154
"PromptDataDict",
21542155
"PromptDataOrDict",
21552156
"LLMMetric",
2157+
"CodeExecutionMetric",
21562158
"MetricPromptBuilder",
21572159
"RubricContentProperty",
21582160
"RubricContentPropertyDict",

vertexai/_genai/types/common.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1576,7 +1576,7 @@ class Metric(_common.BaseModel):
15761576
"""The metric used for evaluation."""
15771577

15781578
name: Optional[str] = Field(default=None, description="""The name of the metric.""")
1579-
custom_function: Optional[Callable[..., Any]] = Field(
1579+
custom_function: Optional[Union[str, Callable[..., Any]]] = Field(
15801580
default=None,
15811581
description="""The custom function that defines the end-to-end logic for metric computation.""",
15821582
)
@@ -1682,6 +1682,26 @@ def to_yaml_file(self, file_path: str, version: Optional[str] = None) -> None:
16821682
yaml.dump(data_to_dump, f, sort_keys=False, allow_unicode=True)
16831683

16841684

1685+
class CodeExecutionMetric(Metric):
    """A metric that executes custom Python code for evaluation.

    The code is supplied as source text in ``custom_function`` and must
    define a function named ``evaluate``.
    """

    # Narrows the inherited ``custom_function`` field (string or callable on
    # ``Metric``) to source text only.
    custom_function: Optional[str] = Field(
        default=None,
        description="""The Python function code to be executed on the server side.""",
    )

    @field_validator("custom_function")
    @classmethod
    def validate_code(cls, value: Optional[str]) -> Optional[str]:
        """Check that the supplied source defines an ``evaluate`` function.

        Args:
            value: The source text, or ``None``/empty when the field is unset.

        Returns:
            ``value`` unchanged.

        Raises:
            ValueError: If ``value`` is non-empty and no line defines a
                function named exactly ``evaluate``.
        """
        import re

        # A plain substring test for "def evaluate" would also accept
        # ``def evaluate_all(...)``, ``def evaluated(...)`` or a commented-out
        # definition. Require an actual ``def evaluate(`` (optionally
        # ``async``) at the start of a line instead.
        defines_evaluate = re.compile(
            r"^[ \t]*(?:async[ \t]+)?def[ \t]+evaluate[ \t]*\(",
            re.MULTILINE,
        )
        if value and not defines_evaluate.search(value):
            raise ValueError(
                "custom_function must contain a 'def evaluate(instance):' signature."
            )
        return value
1703+
1704+
16851705
class LLMMetric(Metric):
16861706
"""A metric that uses LLM-as-a-judge for evaluation."""
16871707

@@ -1792,7 +1812,7 @@ class MetricDict(TypedDict, total=False):
17921812
name: Optional[str]
17931813
"""The name of the metric."""
17941814

1795-
custom_function: Optional[Callable[..., Any]]
1815+
custom_function: Optional[Union[str, Callable[..., Any]]]
17961816
"""The custom function that defines the end-to-end logic for metric computation."""
17971817

17981818
prompt_template: Optional[str]

0 commit comments

Comments
 (0)