Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,32 @@
from vertexai._genai import types
from google.genai import types as genai_types
import pandas as pd
import pytest


def test_custom_code_execution(client):
"""Tests that custom code execution metric produces a correctly structured EvaluationResult."""

code_snippet = """
CODE_SNIPPET = """
def evaluate(instance):
if instance['response'] == instance['reference']:
return 1.0
return 0.0
"""

custom_metric = types.Metric(
name="my_custom_code_metric",
remote_custom_function=code_snippet,
)

@pytest.mark.parametrize(
"custom_metric",
[
types.CodeExecutionMetric(
name="my_custom_code_metric",
custom_function=CODE_SNIPPET,
),
types.Metric(
name="my_custom_code_metric",
remote_custom_function=CODE_SNIPPET,
),
],
)
def test_custom_code_execution(client, custom_metric):
"""Tests that custom code execution metric produces a correctly structured EvaluationResult."""

prompts_df = pd.DataFrame(
{
Expand Down Expand Up @@ -69,21 +79,22 @@ def evaluate(instance):
assert case_result.response_candidate_results is not None


def test_custom_code_execution_batch_evaluate(client):
@pytest.mark.parametrize(
"custom_metric",
[
types.CodeExecutionMetric(
name="my_custom_code_metric",
custom_function=CODE_SNIPPET,
),
types.Metric(
name="my_custom_code_metric",
remote_custom_function=CODE_SNIPPET,
),
],
)
def test_custom_code_execution_batch_evaluate(client, custom_metric):
"""Tests that batch_evaluate() works with custom code execution metric."""

code_snippet = """
def evaluate(instance):
if instance['response'] == instance['reference']:
return 1.0
return 0.0
"""

custom_metric = types.Metric(
name="my_custom_code_metric",
remote_custom_function=code_snippet,
)

eval_dataset = types.EvaluationDataset(
gcs_source=genai_types.GcsSource(
uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]
Expand Down
18 changes: 14 additions & 4 deletions vertexai/_genai/_evals_metric_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ def get_metric_result(
score = None
explanation = None
try:
if self.metric.custom_function:
if self.metric.custom_function and callable(self.metric.custom_function):
custom_function_result = self.metric.custom_function(
instance_for_custom_fn
)
Expand Down Expand Up @@ -1058,10 +1058,10 @@ def metric_name(self) -> str:
def __init__(self, module: "evals.Evals", metric: types.Metric):
super().__init__(module=module, metric=metric)

if not self.metric.remote_custom_function:
if not self.metric.remote_custom_function and not self.metric.custom_function:
raise ValueError(
f"CustomCodeExecutionMetricHandler for '{self.metric.name}' needs "
" Metric.remote_custom_function to be set."
" custom function to be set."
)

def _build_request_payload(
Expand Down Expand Up @@ -1310,7 +1310,17 @@ def aggregate(

_METRIC_HANDLER_MAPPING = [
(
lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
lambda m: (
# Recognize the user-facing class
isinstance(m, types.CodeExecutionMetric)
and (hasattr(m, "custom_function") and m.custom_function)
)
or (hasattr(m, "remote_custom_function") and m.remote_custom_function)
# Recognize base Metric objects that have been coerced by Pydantic
or (
isinstance(m, types.Metric)
and isinstance(getattr(m, "custom_function", None), str)
),
CustomCodeExecutionMetricHandler,
),
(
Expand Down
9 changes: 9 additions & 0 deletions vertexai/_genai/_evals_metric_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,12 @@ def GECKO_TEXT2VIDEO(self) -> LazyLoadedPrebuiltMetric:

PrebuiltMetric = PrebuiltMetricLoader()
RubricMetric = PrebuiltMetric


def CodeExecutionMetric(
    name: str, custom_function: str, **kwargs: Any
) -> "types.Metric":
    """Builds a ``Metric`` configured for remote custom-code execution.

    Args:
        name: Display name of the metric.
        custom_function: Source code of the evaluation function; stored on
            ``Metric.remote_custom_function`` so it runs server-side.
        **kwargs: Extra keyword arguments forwarded to ``types.Metric``.

    Returns:
        A ``types.Metric`` with ``remote_custom_function`` set to
        ``custom_function``.
    """
    # Imported lazily to avoid a circular import at module load time.
    from . import types

    metric = types.Metric(
        name=name,
        remote_custom_function=custom_function,
        **kwargs,
    )
    return metric
20 changes: 20 additions & 0 deletions vertexai/_genai/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ def t_metrics(
metric_payload_item["custom_code_execution_spec"] = {
"evaluation_function": metric.remote_custom_function
}
elif (
isinstance(metric, types.CodeExecutionMetric)
or (
isinstance(metric, types.Metric)
and isinstance(getattr(metric, "custom_function", None), str)
)
) and getattr(metric, "custom_function", None):
metric_payload_item["custom_code_execution_spec"] = {
"evaluation_function": metric.custom_function
}
# LLM-based metrics
elif hasattr(metric, "prompt_template") and metric.prompt_template:
llm_based_spec: dict[str, Any] = {
Expand Down Expand Up @@ -196,6 +206,16 @@ def t_metric_for_registry(
metric_payload_item["custom_code_execution_spec"] = {
"evaluation_function": metric.remote_custom_function
}
elif (
isinstance(metric, types.CodeExecutionMetric)
or (
isinstance(metric, types.Metric)
and isinstance(getattr(metric, "custom_function", None), str)
)
) and getattr(metric, "custom_function", None):
metric_payload_item["custom_code_execution_spec"] = {
"evaluation_function": metric.custom_function
}

# Map LLM-based metrics to the new llm_based_metric_spec
elif (hasattr(metric, "prompt_template") and metric.prompt_template) or (
Expand Down
2 changes: 2 additions & 0 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@
from .common import Chunk
from .common import ChunkDict
from .common import ChunkOrDict
from .common import CodeExecutionMetric
from .common import CometResult
from .common import CometResultDict
from .common import CometResultOrDict
Expand Down Expand Up @@ -2153,6 +2154,7 @@
"PromptDataDict",
"PromptDataOrDict",
"LLMMetric",
"CodeExecutionMetric",
"MetricPromptBuilder",
"RubricContentProperty",
"RubricContentPropertyDict",
Expand Down
24 changes: 22 additions & 2 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,7 +1576,7 @@ class Metric(_common.BaseModel):
"""The metric used for evaluation."""

name: Optional[str] = Field(default=None, description="""The name of the metric.""")
custom_function: Optional[Callable[..., Any]] = Field(
custom_function: Optional[Union[str, Callable[..., Any]]] = Field(
default=None,
description="""The custom function that defines the end-to-end logic for metric computation.""",
)
Expand Down Expand Up @@ -1682,6 +1682,26 @@ def to_yaml_file(self, file_path: str, version: Optional[str] = None) -> None:
yaml.dump(data_to_dump, f, sort_keys=False, allow_unicode=True)


class CodeExecutionMetric(Metric):
    """A metric that executes custom Python code for evaluation.

    Unlike the base ``Metric``, ``custom_function`` here is always a string
    of Python source code (not a callable); it is sent to the service and
    executed server-side.
    """

    custom_function: Optional[str] = Field(
        default=None,
        description="""The Python function code to be executed on the server side.""",
    )

    @field_validator("custom_function")
    @classmethod
    def validate_code(cls, value: Optional[str]) -> Optional[str]:
        """Rejects code that does not define an ``evaluate`` function.

        NOTE: this is a substring check, not a parse — it only requires the
        text ``def evaluate`` to appear somewhere in the snippet.
        """
        if value and "def evaluate" not in value:
            raise ValueError(
                "custom_function must define an 'evaluate' function, "
                "e.g. 'def evaluate(instance): ...'."
            )
        return value


class LLMMetric(Metric):
"""A metric that uses LLM-as-a-judge for evaluation."""

Expand Down Expand Up @@ -1792,7 +1812,7 @@ class MetricDict(TypedDict, total=False):
name: Optional[str]
"""The name of the metric."""

custom_function: Optional[Callable[..., Any]]
custom_function: Optional[Union[str, Callable[..., Any]]]
"""The custom function that defines the end-to-end logic for metric computation."""

prompt_template: Optional[str]
Expand Down
Loading