2121import logging
2222import statistics
2323import time
24- from typing import Any , Callable , Optional , TypeVar , Union
24+ from typing import Any , Callable , Generic , Optional , TypeVar , Union
2525
2626from google .genai import errors as genai_errors
2727from google .genai import _common
3939_MAX_RETRIES = 3
4040
4141
42+ T = TypeVar ("T" , types .Metric , types .MetricSource , types .LLMMetric )
43+
44+
4245def _has_tool_call (intermediate_events : Optional [list [types .evals .Event ]]) -> bool :
4346 """Checks if any event in intermediate_events has a function call."""
4447 if not intermediate_events :
@@ -149,12 +152,18 @@ def _default_aggregate_scores(
149152 )
150153
151154
152- class MetricHandler (abc .ABC ):
155+ class MetricHandler (abc .ABC , Generic [ T ] ):
153156 """Abstract base class for metric handlers."""
154157
155- def __init__ (self , module : "evals.Evals" , metric : types . Metric ):
158+ def __init__ (self , module : "evals.Evals" , metric : T ):
156159 self .module = module
157- self .metric = metric
160+ self .metric : T = metric
161+
@property
@abc.abstractmethod
def metric_name(self) -> str:
    """Returns the name under which this handler's metric is reported.

    Declared abstract so that each concrete handler decides how the name is
    derived from its own metric object.

    Returns:
        The metric name as a string.

    Raises:
        NotImplementedError: Always, when the base implementation is invoked
            directly.
    """
    raise NotImplementedError()
158167
159168 @abc .abstractmethod
160169 def get_metric_result (
@@ -171,7 +180,7 @@ def aggregate(
171180 raise NotImplementedError ()
172181
173182
174- class ComputationMetricHandler (MetricHandler ):
183+ class ComputationMetricHandler (MetricHandler [ types . Metric ] ):
175184 """Metric handler for computation metrics."""
176185
177186 SUPPORTED_COMPUTATION_METRICS = frozenset (
@@ -188,6 +197,10 @@ class ComputationMetricHandler(MetricHandler):
188197 }
189198 )
190199
@property
def metric_name(self) -> str:
    """Returns the computation metric's name, or "unknown_metric" if it is unset or empty."""
    name = self.metric.name
    return name if name else "unknown_metric"
203+
191204 def __init__ (self , module : "evals.Evals" , metric : types .Metric ):
192205 super ().__init__ (module = module , metric = metric )
193206 if self .metric .name not in self .SUPPORTED_COMPUTATION_METRICS :
@@ -299,11 +312,15 @@ def aggregate(
299312 return _default_aggregate_scores (self .metric .name , eval_case_metric_results )
300313
301314
302- class TranslationMetricHandler (MetricHandler ):
315+ class TranslationMetricHandler (MetricHandler [ types . Metric ] ):
303316 """Metric handler for translation metrics."""
304317
305318 SUPPORTED_TRANSLATION_METRICS = frozenset ({"comet" , "metricx" })
306319
@property
def metric_name(self) -> str:
    """Returns the translation metric's name, or "unknown_metric" if it is unset or empty."""
    name = self.metric.name
    return name if name else "unknown_metric"
323+
307324 def __init__ (self , module : "evals.Evals" , metric : types .Metric ):
308325 super ().__init__ (module = module , metric = metric )
309326
@@ -469,9 +486,13 @@ def aggregate(
469486 return _default_aggregate_scores (self .metric .name , eval_case_metric_results )
470487
471488
472- class LLMMetricHandler (MetricHandler ):
489+ class LLMMetricHandler (MetricHandler [ types . LLMMetric ] ):
473490 """Metric handler for LLM metrics."""
474491
@property
def metric_name(self) -> str:
    """Returns the LLM metric's name, or "unknown_metric" if it is unset or empty."""
    name = self.metric.name
    return name if name else "unknown_metric"
495+
475496 def __init__ (self , module : "evals.Evals" , metric : types .LLMMetric ):
476497 super ().__init__ (module = module , metric = metric )
477498
@@ -750,9 +771,13 @@ def aggregate(
750771 return _default_aggregate_scores (self .metric .name , eval_case_metric_results )
751772
752773
753- class CustomMetricHandler (MetricHandler ):
774+ class CustomMetricHandler (MetricHandler [ types . Metric ] ):
754775 """Metric handler for custom metrics."""
755776
@property
def metric_name(self) -> str:
    """Returns the custom metric's name, or "unknown_metric" if it is unset or empty."""
    name = self.metric.name
    return name if name else "unknown_metric"
780+
756781 def __init__ (self , module : "evals.Evals" , metric : types .Metric ):
757782 super ().__init__ (module = module , metric = metric )
758783
@@ -853,9 +878,13 @@ def aggregate(
853878 return _default_aggregate_scores (self .metric .name , eval_case_metric_results )
854879
855880
856- class PredefinedMetricHandler (MetricHandler ):
881+ class PredefinedMetricHandler (MetricHandler [ types . Metric ] ):
857882 """Metric handler for predefined metrics."""
858883
@property
def metric_name(self) -> str:
    """Returns the predefined metric's name, or "unknown_metric" if it is unset or empty."""
    name = self.metric.name
    return name if name else "unknown_metric"
887+
859888 def __init__ (self , module : "evals.Evals" , metric : types .Metric ):
860889 super ().__init__ (module = module , metric = metric )
861890 if self .metric .name not in _evals_constant .SUPPORTED_PREDEFINED_METRICS :
@@ -1106,9 +1135,13 @@ def aggregate(
11061135 )
11071136
11081137
1109- class CustomCodeExecutionMetricHandler (MetricHandler ):
1138+ class CustomCodeExecutionMetricHandler (MetricHandler [ types . Metric ] ):
11101139 """Metric handler for custom code execution metrics."""
11111140
@property
def metric_name(self) -> str:
    """Returns the code-execution metric's name, or "unknown_metric" if it is unset or empty."""
    name = self.metric.name
    return name if name else "unknown_metric"
1144+
11121145 def __init__ (self , module : "evals.Evals" , metric : types .Metric ):
11131146 super ().__init__ (module = module , metric = metric )
11141147
@@ -1242,6 +1275,108 @@ def aggregate(
12421275 )
12431276
12441277
class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
    """Metric handler for registered metrics.

    Holds a ``types.MetricSource`` and scores each evaluation case by sending
    that source, together with a per-case request payload, to
    ``module._evaluate_instances``.
    """

    def __init__(
        self,
        module: "evals.Evals",
        metric: Union[types.MetricSource, types.MetricSourceDict],
    ):
        """Initializes the handler.

        Args:
            module: The evals module used to issue evaluation requests.
            metric: The metric source. A plain dict is converted to
                ``types.MetricSource`` before being stored.
        """
        if isinstance(metric, dict):
            metric = types.MetricSource(**metric)
        super().__init__(module=module, metric=metric)

    # TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
    def _build_request_payload(
        self, eval_case: types.EvalCase, response_index: int
    ) -> dict[str, Any]:
        """Builds request payload for registered metric.

        Delegates to a ``PredefinedMetricHandler`` constructed around the
        metric definition nested inside the metric source.

        Args:
            eval_case: The evaluation case to build the payload for.
            response_index: Index of the response within the case to evaluate.

        Returns:
            The payload dict produced by
            ``PredefinedMetricHandler._build_request_payload``.

        Raises:
            ValueError: If the metric source has no underlying metric.
        """
        if not self.metric.metric:
            raise ValueError(
                "Registered metric must have an underlying metric definition."
            )
        return PredefinedMetricHandler(
            self.module, metric=self.metric.metric
        )._build_request_payload(eval_case, response_index)

    @property
    def metric_name(self) -> str:
        """Returns the name under which this metric's results are reported.

        For a ``types.MetricSource`` the nested metric's name is preferred,
        then the metric resource name. The literal "unknown" is returned when
        neither is set.
        """
        if isinstance(self.metric, types.MetricSource):
            inner_metric = self.metric.metric
            if inner_metric and inner_metric.name:
                return inner_metric.name
            return self.metric.metric_resource_name or "unknown"
        # Fallback for a Metric-like object that is not a MetricSource.
        # NOTE(review): self.metric is annotated as MetricSource, so confirm
        # whether this path is still reachable from the handler mapping.
        metric_like = self.metric
        return metric_like.name or metric_like.metric_resource_name or "unknown"

    @override
    def get_metric_result(
        self, eval_case: types.EvalCase, response_index: int
    ) -> types.EvalCaseMetricResult:
        """Processes a single evaluation case for a registered metric.

        The request is attempted up to ``_MAX_RETRIES`` times. A client error
        with code 429 triggers a sleep of ``2 ** attempt`` seconds before the
        next attempt. Any other failure is reported through the returned
        result's ``error_message`` rather than raised.

        Args:
            eval_case: The evaluation case to score.
            response_index: Index of the response within the case to score.

        Returns:
            An ``EvalCaseMetricResult`` carrying either the score, explanation
            and rubric verdicts from the first metric result, or an error
            message.
        """
        metric_name = self.metric_name
        try:
            payload = self._build_request_payload(eval_case, response_index)
            # Bound up front so the check after the loop is well-defined even
            # if the loop body never executes.
            api_response = None
            for attempt in range(_MAX_RETRIES):
                try:
                    api_response = self.module._evaluate_instances(
                        metric_sources=[self.metric],
                        instance=payload.get("instance"),
                        autorater_config=payload.get("autorater_config"),
                    )
                    break
                except genai_errors.ClientError as e:
                    if e.code != 429:
                        # Not a quota error: let the outer handler report it.
                        raise
                    if attempt == _MAX_RETRIES - 1:
                        return types.EvalCaseMetricResult(
                            metric_name=metric_name,
                            error_message=f"Judge model resource exhausted after {_MAX_RETRIES} retries: {e}",
                        )
                    time.sleep(2**attempt)

            if not (api_response and api_response.metric_results):
                return types.EvalCaseMetricResult(
                    metric_name=metric_name,
                    error_message="Metric results missing in API response.",
                )

            result_data = api_response.metric_results[0]
            error_message = None
            # Default of None: an error object without a `code` attribute must
            # not raise and discard the score that was already obtained.
            if result_data.error and getattr(result_data.error, "code", None):
                error_message = f"Error in metric result: {result_data.error}"
            return types.EvalCaseMetricResult(
                metric_name=metric_name,
                score=result_data.score,
                explanation=result_data.explanation,
                rubric_verdicts=result_data.rubric_verdicts,
                error_message=error_message,
            )
        except Exception as e:
            # Deliberate best-effort: a failure on one case is recorded on its
            # result instead of propagating to the caller.
            return types.EvalCaseMetricResult(
                metric_name=metric_name, error_message=str(e)
            )

    @override
    def aggregate(
        self, eval_case_metric_results: list[types.EvalCaseMetricResult]
    ) -> types.AggregatedMetricResult:
        """Aggregates the metric results for a registered metric.

        Args:
            eval_case_metric_results: Per-case results produced for this metric.

        Returns:
            The result of ``_default_aggregate_scores`` called with
            ``calculate_pass_rate=True``.
        """
        logger.debug("Aggregating results for registered metric: %s", self.metric_name)
        return _default_aggregate_scores(
            self.metric_name, eval_case_metric_results, calculate_pass_rate=True
        )
1378+
1379+
12451380_METRIC_HANDLER_MAPPING = [
12461381 (
12471382 lambda m : hasattr (m , "remote_custom_function" ) and m .remote_custom_function ,
@@ -1251,6 +1386,10 @@ def aggregate(
12511386 lambda m : m .custom_function and isinstance (m .custom_function , Callable ),
12521387 CustomMetricHandler ,
12531388 ),
1389+ (
1390+ lambda m : getattr (m , "metric_resource_name" , None ) is not None ,
1391+ RegisteredMetricHandler ,
1392+ ),
12541393 (
12551394 lambda m : m .name in ComputationMetricHandler .SUPPORTED_COMPUTATION_METRICS ,
12561395 ComputationMetricHandler ,
@@ -1337,14 +1476,14 @@ def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]:
13371476
13381477
13391478def _aggregate_metric_results (
1340- metric_handlers : list [MetricHandler ],
1479+ metric_handlers : list [MetricHandler [ Any ] ],
13411480 eval_case_results : list [types .EvalCaseResult ],
13421481) -> list [types .AggregatedMetricResult ]:
13431482 """Aggregates results by calling the aggregate method of each handler."""
13441483 aggregated_metric_results = []
13451484 logger .info ("Aggregating results per metric..." )
13461485 for handler in metric_handlers :
1347- metric_name = handler .metric . name
1486+ metric_name = handler .metric_name
13481487 results_for_this_metric : list [types .EvalCaseMetricResult ] = []
13491488 for case_result in eval_case_results :
13501489 if case_result .response_candidate_results :
@@ -1473,12 +1612,12 @@ def compute_metrics_and_aggregate(
14731612 "response %d for metric %s." ,
14741613 eval_case_index ,
14751614 response_index ,
1476- metric_handler_instance .metric . name ,
1615+ metric_handler_instance .metric_name ,
14771616 )
14781617 all_futures .append (
14791618 (
14801619 future ,
1481- metric_handler_instance .metric . name ,
1620+ metric_handler_instance .metric_name ,
14821621 eval_case_index ,
14831622 response_index ,
14841623 )
@@ -1489,25 +1628,25 @@ def compute_metrics_and_aggregate(
14891628 "response %d for metric %s: %s" ,
14901629 eval_case_index ,
14911630 response_index ,
1492- metric_handler_instance .metric . name ,
1631+ metric_handler_instance .metric_name ,
14931632 e ,
14941633 exc_info = True ,
14951634 )
14961635 submission_errors .append (
14971636 (
1498- metric_handler_instance .metric . name ,
1637+ metric_handler_instance .metric_name ,
14991638 eval_case_index ,
15001639 response_index ,
15011640 f"Error: { e } " ,
15021641 )
15031642 )
15041643 error_result = types .EvalCaseMetricResult (
1505- metric_name = metric_handler_instance .metric . name ,
1644+ metric_name = metric_handler_instance .metric_name ,
15061645 error_message = f"Submission Error: { e } " ,
15071646 )
15081647 results_by_case_response_metric [eval_case_index ][
15091648 response_index
1510- ][metric_handler_instance .metric . name ] = error_result
1649+ ][metric_handler_instance .metric_name ] = error_result
15111650 case_indices_with_errors .add (eval_case_index )
15121651 pbar .update (1 )
15131652
0 commit comments