164 changes: 164 additions & 0 deletions tests/unit/vertexai/genai/test_evals.py
@@ -569,6 +569,170 @@ def test_loss_analysis_result_show(self, capsys):
assert "c1" in captured.out


def _make_eval_result(
metrics=None,
candidate_names=None,
):
"""Helper to create an EvaluationResult with the given metrics and candidates."""
metrics = metrics or ["task_success_v1"]
candidate_names = candidate_names or ["agent-1"]

metric_results = {}
for m in metrics:
metric_results[m] = common_types.EvalCaseMetricResult(metric_name=m)

eval_case_results = [
common_types.EvalCaseResult(
eval_case_index=0,
response_candidate_results=[
common_types.ResponseCandidateResult(
response_index=0,
metric_results=metric_results,
)
],
)
]
metadata = common_types.EvaluationRunMetadata(
candidate_names=candidate_names,
)
return common_types.EvaluationResult(
eval_case_results=eval_case_results,
metadata=metadata,
)


class TestResolveMetricName:
"""Unit tests for _resolve_metric_name."""

def test_none_returns_none(self):
assert _evals_utils._resolve_metric_name(None) is None

def test_string_passes_through(self):
assert _evals_utils._resolve_metric_name("task_success_v1") == "task_success_v1"

def test_metric_object_extracts_name(self):
metric = common_types.Metric(name="multi_turn_task_success_v1")
assert _evals_utils._resolve_metric_name(metric) == "multi_turn_task_success_v1"

def test_object_with_name_attr(self):
"""Tests that any object with a .name attribute works (e.g., LazyLoadedPrebuiltMetric)."""

class FakeMetric:
name = "tool_use_quality_v1"

assert _evals_utils._resolve_metric_name(FakeMetric()) == "tool_use_quality_v1"

def test_lazy_loaded_prebuilt_metric_resolves_versioned_name(self):
"""Tests that LazyLoadedPrebuiltMetric resolves to the versioned API spec name."""

class FakeLazyMetric:
name = "MULTI_TURN_TASK_SUCCESS"

def _get_api_metric_spec_name(self):
return "multi_turn_task_success_v1"

assert (
_evals_utils._resolve_metric_name(FakeLazyMetric())
== "multi_turn_task_success_v1"
)

def test_lazy_loaded_prebuilt_metric_falls_back_to_name(self):
"""Tests fallback to .name when _get_api_metric_spec_name returns None."""

class FakeLazyMetricNoSpec:
name = "CUSTOM_METRIC"

def _get_api_metric_spec_name(self):
return None

assert (
_evals_utils._resolve_metric_name(FakeLazyMetricNoSpec()) == "CUSTOM_METRIC"
)


class TestResolveLossAnalysisConfig:
"""Unit tests for _resolve_loss_analysis_config."""

def test_auto_infer_single_metric_and_candidate(self):
eval_result = _make_eval_result(
metrics=["task_success_v1"], candidate_names=["agent-1"]
)
resolved = _evals_utils._resolve_loss_analysis_config(eval_result=eval_result)
assert resolved.metric == "task_success_v1"
assert resolved.candidate == "agent-1"

def test_explicit_metric_and_candidate(self):
eval_result = _make_eval_result(
metrics=["m1", "m2"], candidate_names=["c1", "c2"]
)
resolved = _evals_utils._resolve_loss_analysis_config(
eval_result=eval_result, metric="m1", candidate="c2"
)
assert resolved.metric == "m1"
assert resolved.candidate == "c2"

def test_config_provides_metric_and_candidate(self):
eval_result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
config = common_types.LossAnalysisConfig(
metric="m1", candidate="c1", predefined_taxonomy="my_taxonomy"
)
resolved = _evals_utils._resolve_loss_analysis_config(
eval_result=eval_result, config=config
)
assert resolved.metric == "m1"
assert resolved.candidate == "c1"
assert resolved.predefined_taxonomy == "my_taxonomy"

def test_explicit_args_override_config(self):
eval_result = _make_eval_result(
metrics=["m1", "m2"], candidate_names=["c1", "c2"]
)
config = common_types.LossAnalysisConfig(metric="m1", candidate="c1")
resolved = _evals_utils._resolve_loss_analysis_config(
eval_result=eval_result, config=config, metric="m2", candidate="c2"
)
assert resolved.metric == "m2"
assert resolved.candidate == "c2"

def test_error_multiple_metrics_no_explicit(self):
eval_result = _make_eval_result(metrics=["m1", "m2"], candidate_names=["c1"])
with pytest.raises(ValueError, match="multiple metrics"):
_evals_utils._resolve_loss_analysis_config(eval_result=eval_result)

def test_error_multiple_candidates_no_explicit(self):
eval_result = _make_eval_result(metrics=["m1"], candidate_names=["c1", "c2"])
with pytest.raises(ValueError, match="multiple candidates"):
_evals_utils._resolve_loss_analysis_config(eval_result=eval_result)

def test_error_invalid_metric(self):
eval_result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
with pytest.raises(ValueError, match="not found in eval_result"):
_evals_utils._resolve_loss_analysis_config(
eval_result=eval_result, metric="nonexistent"
)

def test_error_invalid_candidate(self):
eval_result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
with pytest.raises(ValueError, match="not found in eval_result"):
_evals_utils._resolve_loss_analysis_config(
eval_result=eval_result, candidate="nonexistent"
)

def test_no_candidates_defaults_to_candidate_1(self):
eval_result = _make_eval_result(metrics=["m1"], candidate_names=[])
eval_result = eval_result.model_copy(
update={"metadata": common_types.EvaluationRunMetadata()}
)
resolved = _evals_utils._resolve_loss_analysis_config(eval_result=eval_result)
assert resolved.metric == "m1"
assert resolved.candidate == "candidate_1"

def test_no_eval_case_results_raises(self):
eval_result = common_types.EvaluationResult()
with pytest.raises(ValueError, match="no metric results"):
_evals_utils._resolve_loss_analysis_config(eval_result=eval_result)


class TestEvals:
"""Unit tests for the GenAI client."""

142 changes: 142 additions & 0 deletions vertexai/_genai/_evals_utils.py
@@ -449,6 +449,148 @@ def _display_loss_analysis_result(
print(df.to_string()) # pylint: disable=print-function


def _resolve_metric_name(
metric: Optional[Any],
) -> Optional[str]:
"""Extracts a metric name string from a metric argument.

Accepts a string, a Metric object, or a LazyLoadedPrebuiltMetric
(RubricMetric) and returns the metric name as a string.

For LazyLoadedPrebuiltMetric (e.g., RubricMetric.MULTI_TURN_TASK_SUCCESS),
this resolves to the API metric spec name (e.g.,
"multi_turn_task_success_v1") so it matches the keys in eval results.

Args:
metric: A metric name string, Metric object, RubricMetric enum value, or
None.

Returns:
The metric name as a string, or None if metric is None.
"""
if metric is None:
return None
if isinstance(metric, str):
return metric
# LazyLoadedPrebuiltMetric: resolve to versioned API spec name.
if hasattr(metric, "_get_api_metric_spec_name"):
spec_name: Optional[str] = metric._get_api_metric_spec_name()
if spec_name:
return spec_name
# Metric objects and other types with a .name attribute.
if hasattr(metric, "name"):
return str(metric.name)
return str(metric)
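
# Illustrative usage sketch (comments only, not part of the module logic),
# assuming types.Metric exposes a .name attribute and RubricMetric values
# expose _get_api_metric_spec_name(), as the docstring above describes:
#
#   _resolve_metric_name("task_success_v1")                     # "task_success_v1"
#   _resolve_metric_name(types.Metric(name="m1"))               # "m1"
#   _resolve_metric_name(RubricMetric.MULTI_TURN_TASK_SUCCESS)  # "multi_turn_task_success_v1"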


def _resolve_loss_analysis_config(
eval_result: types.EvaluationResult,
config: Optional[types.LossAnalysisConfig] = None,
metric: Optional[str] = None,
candidate: Optional[str] = None,
) -> types.LossAnalysisConfig:
"""Resolves and validates the LossAnalysisConfig for generate_loss_clusters.

Auto-infers `metric` and `candidate` from the EvaluationResult when not
explicitly provided. Validates that provided values exist in the eval result.

Args:
eval_result: The EvaluationResult from client.evals.evaluate().
        config: Optional explicit LossAnalysisConfig. If the separate `metric`
            or `candidate` arguments are also provided, those explicit arguments
            take precedence over the corresponding values in config.
metric: Optional metric name override.
candidate: Optional candidate name override.

Returns:
A resolved LossAnalysisConfig with metric and candidate populated.

Raises:
ValueError: If metric/candidate cannot be inferred or are invalid.
"""
# Start from config if provided, otherwise create a new one.
if config is not None:
resolved_metric = metric or config.metric
resolved_candidate = candidate or config.candidate
resolved_config = config.model_copy(
update={"metric": resolved_metric, "candidate": resolved_candidate}
)
else:
resolved_config = types.LossAnalysisConfig(metric=metric, candidate=candidate)

# Collect available metric names from the eval result.
available_metrics: set[str] = set()
if eval_result.eval_case_results:
for case_result in eval_result.eval_case_results:
for resp_cand in case_result.response_candidate_results or []:
for m_name in (resp_cand.metric_results or {}).keys():
available_metrics.add(m_name)

# Collect available candidate names from metadata.
available_candidates: list[str] = []
if eval_result.metadata and eval_result.metadata.candidate_names:
available_candidates = list(eval_result.metadata.candidate_names)

# Auto-infer metric if not provided.
if not resolved_config.metric:
if len(available_metrics) == 1:
resolved_config = resolved_config.model_copy(
update={"metric": next(iter(available_metrics))}
)
elif len(available_metrics) == 0:
raise ValueError(
"Cannot infer metric: no metric results found in eval_result."
" Please provide metric explicitly via"
" config=types.LossAnalysisConfig(metric='...')."
)
else:
raise ValueError(
"Cannot infer metric: multiple metrics found in eval_result:"
f" {sorted(available_metrics)}. Please provide metric"
" explicitly via config=types.LossAnalysisConfig(metric='...')."
)

# Validate metric if provided explicitly.
if available_metrics and resolved_config.metric not in available_metrics:
raise ValueError(
f"Metric '{resolved_config.metric}' not found in eval_result."
f" Available metrics: {sorted(available_metrics)}."
)

# Auto-infer candidate if not provided.
if not resolved_config.candidate:
if len(available_candidates) == 1:
resolved_config = resolved_config.model_copy(
update={"candidate": available_candidates[0]}
)
elif len(available_candidates) == 0:
# Fallback: use default candidate naming convention from SDK.
resolved_config = resolved_config.model_copy(
update={"candidate": "candidate_1"}
)
logger.warning(
"No candidate names found in eval_result.metadata."
" Defaulting to 'candidate_1'. If this is incorrect, provide"
" candidate explicitly via"
" config=types.LossAnalysisConfig(candidate='...')."
)
else:
raise ValueError(
"Cannot infer candidate: multiple candidates found in"
f" eval_result: {available_candidates}. Please provide"
" candidate explicitly via"
" config=types.LossAnalysisConfig(candidate='...')."
)

# Validate candidate if provided explicitly and candidates are known.
if available_candidates and resolved_config.candidate not in available_candidates:
raise ValueError(
f"Candidate '{resolved_config.candidate}' not found in"
f" eval_result. Available candidates: {available_candidates}."
)

return resolved_config
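
# Illustrative precedence sketch (comments only, not part of the module logic).
# Assuming an eval_result whose metric results contain "m1" and "m2" and whose
# metadata lists candidates ["c1", "c2"], explicit keyword arguments win over
# values in config, which in turn win over auto-inference:
#
#   cfg = types.LossAnalysisConfig(metric="m1", candidate="c1")
#   resolved = _resolve_loss_analysis_config(
#       eval_result=eval_result, config=cfg, metric="m2", candidate="c2"
#   )
#   # resolved.metric == "m2", resolved.candidate == "c2"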


def _poll_operation(
api_client: BaseApiClient,
operation: types.GenerateLossClustersOperation,