From 3119e71a4de71df46c1ab4817732b011a9fcddfa Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Wed, 8 Apr 2026 14:09:20 -0700 Subject: [PATCH] chore: GenAI Client(evals) - Improve retry budget, add jitter, and expand retryable errors PiperOrigin-RevId: 896691317 --- tests/unit/vertexai/genai/test_evals.py | 192 +++++++++++++++++- vertexai/_genai/_evals_common.py | 7 +- vertexai/_genai/_evals_metric_handlers.py | 234 +++++++++++++--------- vertexai/_genai/_evals_utils.py | 53 ++++- vertexai/_genai/evals.py | 4 + vertexai/_genai/types/common.py | 11 + 6 files changed, 393 insertions(+), 108 deletions(-) diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 4f09093fa7..cf2a662d47 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -6194,6 +6194,8 @@ def test_predefined_metric_retry_fail_on_resource_exhausted( genai_errors.ClientError(code=429, response_json=error_response_json), genai_errors.ClientError(code=429, response_json=error_response_json), genai_errors.ClientError(code=429, response_json=error_response_json), + genai_errors.ClientError(code=429, response_json=error_response_json), + genai_errors.ClientError(code=429, response_json=error_response_json), ] result = _evals_common._execute_evaluation( @@ -6202,18 +6204,13 @@ def test_predefined_metric_retry_fail_on_resource_exhausted( metrics=[metric], ) - assert mock_private_evaluate_instances.call_count == 3 - assert mock_sleep.call_count == 2 + assert mock_private_evaluate_instances.call_count == 5 + assert mock_sleep.call_count == 4 assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "summarization_quality" assert summary_metric.mean_score is None assert summary_metric.num_cases_error == 1 - assert ( - "Judge model resource exhausted after 3 retries" - ) in result.eval_case_results[0].response_candidate_results[0].metric_results[ - "summarization_quality" - ].error_message class TestEvaluationDataset: @@ -6734,3 +6731,184 @@ def test_create_evaluation_set_with_agent_data( candidate_response = candidate_responses[0] assert candidate_response["candidate"] == "test-candidate" assert candidate_response["agent_data"] == agent_data + + +class TestRateLimiter: + """Tests for the RateLimiter class in _evals_utils.""" + + def test_rate_limiter_init(self): + """Tests that RateLimiter initializes correctly.""" + limiter = _evals_utils.RateLimiter(rate=10.0) + assert limiter.seconds_per_event == pytest.approx(0.1) + + def test_rate_limiter_invalid_rate(self): + """Tests that RateLimiter raises ValueError for non-positive rate.""" + with pytest.raises(ValueError, match="Rate must be a positive number"): + _evals_utils.RateLimiter(rate=0) + with pytest.raises(ValueError, match="Rate must be a positive number"): + _evals_utils.RateLimiter(rate=-1) + + @mock.patch("time.sleep", return_value=None) + @mock.patch("time.monotonic") + def test_rate_limiter_sleep_and_advance(self, mock_monotonic, mock_sleep): + """Tests that sleep_and_advance properly throttles calls.""" + # With rate=10 (0.1s interval): + # - __init__ at t=0: _next_allowed = 0.0 + # - first call at t=0: no delay, _next_allowed = 0.1 + # - second call at t=0.01: delay = 0.1 - 0.01 = 0.09 + mock_monotonic.side_effect = [ + 0.0, # __init__: time.monotonic() + 0.0, # first sleep_and_advance: now + 0.01, # second sleep_and_advance: now + ] + limiter = _evals_utils.RateLimiter(rate=10.0) + limiter.sleep_and_advance() # First call - should not sleep + limiter.sleep_and_advance() # Second call - should sleep + assert mock_sleep.call_count == 1 + # Verify sleep was called with approximately the right delay + sleep_delay = mock_sleep.call_args[0][0] + assert 0.08 < sleep_delay <= 0.1 + + def test_rate_limiter_no_sleep_when_enough_time_passed(self): + """Tests that no sleep occurs when enough time has passed.""" + import time as real_time + + limiter = _evals_utils.RateLimiter(rate=1000.0) # Very high rate + # With rate=1000, interval is 0.001s - should not sleep + start = real_time.time() + for _ in range(5): + limiter.sleep_and_advance() + elapsed = real_time.time() - start + # 5 calls at 1000 QPS should take ~0.005s, certainly under 1s + assert elapsed < 1.0 + + +class TestCallWithRetry: + """Tests for the shared _call_with_retry helper.""" + + @mock.patch("time.sleep", return_value=None) + def test_call_with_retry_success_on_first_try(self, mock_sleep): + """Tests that _call_with_retry returns immediately on success.""" + fn = mock.Mock(return_value="success") + result = _evals_metric_handlers._call_with_retry(fn, "test_metric") + assert result == "success" + assert fn.call_count == 1 + assert mock_sleep.call_count == 0 + + @mock.patch("time.sleep", return_value=None) + def test_call_with_retry_success_after_retries(self, mock_sleep): + """Tests that _call_with_retry succeeds after transient failures.""" + error_json = {"error": {"code": 429, "message": "exhausted"}} + fn = mock.Mock( + side_effect=[ + genai_errors.ClientError(code=429, response_json=error_json), + genai_errors.ClientError(code=429, response_json=error_json), + "success", + ] + ) + result = _evals_metric_handlers._call_with_retry(fn, "test_metric") + assert result == "success" + assert fn.call_count == 3 + assert mock_sleep.call_count == 2 + + @mock.patch("time.sleep", return_value=None) + def test_call_with_retry_raises_after_max_retries(self, mock_sleep): + """Tests that _call_with_retry raises after exhausting retries.""" + error_json = {"error": {"code": 429, "message": "exhausted"}} + fn = mock.Mock( + side_effect=genai_errors.ClientError(code=429, response_json=error_json) + ) + with pytest.raises(genai_errors.ClientError): + _evals_metric_handlers._call_with_retry(fn, "test_metric") + assert fn.call_count == 5 # _MAX_RETRIES + assert mock_sleep.call_count == 4 + + @mock.patch("time.sleep", return_value=None) + def test_call_with_retry_retries_on_server_error(self, mock_sleep): + """Tests retry on 503 ServiceUnavailable (ServerError).""" + error_json = {"error": {"code": 503, "message": "unavailable"}} + fn = mock.Mock( + side_effect=[ + genai_errors.ServerError(code=503, response_json=error_json), + "success", + ] + ) + result = _evals_metric_handlers._call_with_retry(fn, "test_metric") + assert result == "success" + assert fn.call_count == 2 + + @mock.patch("time.sleep", return_value=None) + def test_call_with_retry_no_retry_on_non_retryable(self, mock_sleep): + """Tests that non-retryable errors are raised immediately.""" + error_json = {"error": {"code": 400, "message": "bad request"}} + fn = mock.Mock( + side_effect=genai_errors.ClientError(code=400, response_json=error_json) + ) + with pytest.raises(genai_errors.ClientError): + _evals_metric_handlers._call_with_retry(fn, "test_metric") + assert fn.call_count == 1 + assert mock_sleep.call_count == 0 + + +class TestComputationMetricRetry: + """Tests for retry behavior in ComputationMetricHandler.""" + + @mock.patch.object( + _evals_metric_handlers.ComputationMetricHandler, + "SUPPORTED_COMPUTATION_METRICS", + frozenset(["bleu"]), + ) + @mock.patch("time.sleep", return_value=None) + # fmt: off + @mock.patch( + "vertexai._genai.evals.Evals.evaluate_instances" + ) + # fmt: on + def test_computation_metric_retry_on_resource_exhausted( + self, + mock_evaluate_instances, + mock_sleep, + mock_api_client_fixture, + ): + """Tests that ComputationMetricHandler retries on 429.""" + dataset_df = pd.DataFrame( + [ + { + "prompt": "Test prompt", + "response": "Test response", + "reference": "Test reference", + } + ] + ) + input_dataset = vertexai_genai_types.EvaluationDataset( + eval_dataset_df=dataset_df + ) + metric = vertexai_genai_types.Metric(name="bleu") + error_response_json = { + "error": { + "code": 429, + "message": "Resource exhausted.", + "status": "RESOURCE_EXHAUSTED", + } + } + mock_bleu_result = mock.MagicMock() + mock_bleu_result.model_dump.return_value = { + "bleu_results": {"bleu_metric_values": [{"score": 0.85}]} + } + mock_evaluate_instances.side_effect = [ + genai_errors.ClientError(code=429, response_json=error_response_json), + genai_errors.ClientError(code=429, response_json=error_response_json), + mock_bleu_result, + ] + + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, + dataset=input_dataset, + metrics=[metric], + ) + + assert mock_evaluate_instances.call_count == 3 + assert mock_sleep.call_count == 2 + summary_metric = result.summary_metrics[0] + assert summary_metric.metric_name == "bleu" + assert summary_metric.mean_score == 0.85 diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index b51017a889..3a585d287f 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -1532,6 +1532,7 @@ def _execute_evaluation( # type: ignore[no-untyped-def] dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = None, dest: Optional[str] = None, location: Optional[str] = None, + evaluation_service_qps: Optional[float] = None, **kwargs, ) -> types.EvaluationResult: """Evaluates a dataset using the provided metrics. @@ -1544,6 +1545,9 @@ def _execute_evaluation( # type: ignore[no-untyped-def] dest: The destination to save the evaluation results. location: The location to use for the evaluation. If not specified, the location configured in the client will be used. + evaluation_service_qps: The rate limit (queries per second) for calls + to the evaluation service. Defaults to 10. Increase this value if + your project has a higher EvaluateInstances API quota. **kwargs: Extra arguments to pass to evaluation, such as `agent_info`. Returns: @@ -1619,7 +1623,8 @@ def _execute_evaluation( # type: ignore[no-untyped-def] logger.info("Running Metric Computation...") t1 = time.perf_counter() evaluation_result = _evals_metric_handlers.compute_metrics_and_aggregate( - evaluation_run_config + evaluation_run_config, + evaluation_service_qps=evaluation_service_qps, ) t2 = time.perf_counter() logger.info("Evaluation took: %f seconds", t2 - t1) diff --git a/vertexai/_genai/_evals_metric_handlers.py b/vertexai/_genai/_evals_metric_handlers.py index 9d72bafc86..4571802dbc 100644 --- a/vertexai/_genai/_evals_metric_handlers.py +++ b/vertexai/_genai/_evals_metric_handlers.py @@ -19,6 +19,7 @@ from concurrent import futures import json import logging +import random import statistics import time from typing import Any, Callable, Generic, Optional, TypeVar, Union @@ -31,17 +32,80 @@ from . import _evals_common from . import _evals_constant +from . import _evals_utils from . import evals from . import types logger = logging.getLogger(__name__) -_MAX_RETRIES = 3 - +_MAX_RETRIES = 5 +# HTTP status codes that are safe to retry with backoff. +_RETRYABLE_STATUS_CODES = frozenset( + { + 408, # RequestTimeout (DEADLINE_EXCEEDED) + 409, # Conflict / Aborted (ABORTED) + 429, # TooManyRequests / ResourceExhausted (RESOURCE_EXHAUSTED) + 499, # Client Closed Request (CANCELLED) + 500, # InternalServerError (INTERNAL) + 502, # BadGateway + 503, # ServiceUnavailable (UNAVAILABLE) + 504, # GatewayTimeout (DEADLINE_EXCEEDED) + } +) +R = TypeVar("R") T = TypeVar("T", types.Metric, types.MetricSource, types.LLMMetric) +def _call_with_retry( + fn: Callable[[], R], + metric_name: str, +) -> R: + """Calls ``fn()`` with exponential backoff + jitter on retryable errors. + + Retries up to ``_MAX_RETRIES`` times on errors whose HTTP status code is + in ``_RETRYABLE_STATUS_CODES`` (Aborted, DeadlineExceeded, + ResourceExhausted, ServiceUnavailable, Cancelled). Non-retryable errors + are re-raised immediately. If all retries are exhausted the last + exception is re-raised so the caller can decide how to handle it. + + Args: + fn: A zero-argument callable that performs the API call. + metric_name: Name of the metric, used for log messages. + + Returns: + The return value of ``fn()``. + + Raises: + genai_errors.APIError: If all retries are exhausted or the error is + not retryable. + """ + for attempt in range(_MAX_RETRIES): + try: + return fn() + except genai_errors.APIError as e: + if e.code in _RETRYABLE_STATUS_CODES: + backoff = 2**attempt + random.uniform(0, 1) + logger.warning( + "Retryable error (code=%s) on attempt %d/%d for metric" + " '%s': %s. Retrying in %.1f seconds...", + e.code, + attempt + 1, + _MAX_RETRIES, + metric_name, + e, + backoff, + ) + if attempt == _MAX_RETRIES - 1: + raise + time.sleep(backoff) + else: + raise + raise genai_errors.APIError( + code=504, response_json={"message": "Retries exhausted"} + ) + + def _has_tool_call(events: Optional[list[Any]]) -> bool: """Checks if any event in events has a function call.""" if not events: @@ -344,9 +408,12 @@ def get_metric_result( metric_name, eval_case.model_dump(exclude_none=True), ) - response = self.module.evaluate_instances( - metric_config=self._build_request_payload(eval_case, response_index) - ).model_dump(exclude_none=True) + response = _call_with_retry( + lambda: self.module.evaluate_instances( + metric_config=self._build_request_payload(eval_case, response_index) + ).model_dump(exclude_none=True), + metric_name, + ) logger.debug("response: %s", response) score = None for _, result_value in response.items(): @@ -459,8 +526,11 @@ def get_metric_result( metric_name, eval_case, ) - api_response = self.module.evaluate_instances( - metric_config=self._build_request_payload(eval_case, response_index) + api_response = _call_with_retry( + lambda: self.module.evaluate_instances( + metric_config=self._build_request_payload(eval_case, response_index) + ), + metric_name, ) logger.debug("API Response: %s", api_response) @@ -678,18 +748,13 @@ def get_metric_result( instance = _build_evaluation_instance( eval_case, response_content, prompt_template=self.metric.prompt_template ) - for attempt in range(_MAX_RETRIES): - try: - api_response = self.module._evaluate_instances( - metrics=[self.metric], - instance=instance, - ) - break - except genai_errors.ClientError as e: - if e.code == 429 and attempt < _MAX_RETRIES - 1: - time.sleep(2**attempt) - continue - raise e + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metrics=[self.metric], + instance=instance, + ), + self.metric_name, + ) if api_response and api_response.metric_results: result = api_response.metric_results[0] @@ -976,32 +1041,14 @@ def get_metric_result( metric_name = self.metric.name try: payload = self._build_request_payload(eval_case, response_index) - for attempt in range(_MAX_RETRIES): - try: - api_response = self.module._evaluate_instances( - metrics=[self.metric], - instance=payload.get("instance"), - autorater_config=payload.get("autorater_config"), - ) - break - except genai_errors.ClientError as e: - if e.code == 429: - logger.warning( - "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s" - " seconds...", - attempt + 1, - _MAX_RETRIES, - e, - 2**attempt, - ) - if attempt == _MAX_RETRIES - 1: - return types.EvalCaseMetricResult( - metric_name=metric_name, - error_message=f"Judge model resource exhausted after {_MAX_RETRIES} retries: {e}", - ) - time.sleep(2**attempt) - else: - raise e + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metrics=[self.metric], + instance=payload.get("instance"), + autorater_config=payload.get("autorater_config"), + ), + metric_name, + ) if ( api_response @@ -1115,31 +1162,13 @@ def get_metric_result( metric_name = self.metric.name try: payload = self._build_request_payload(eval_case, response_index) - for attempt in range(_MAX_RETRIES): - try: - api_response = self.module._evaluate_instances( - metrics=[self.metric], - instance=payload.get("instance"), - ) - break - except genai_errors.ClientError as e: - if e.code == 429: - logger.warning( - "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s" - " seconds...", - attempt + 1, - _MAX_RETRIES, - e, - 2**attempt, - ) - if attempt == _MAX_RETRIES - 1: - return types.EvalCaseMetricResult( - metric_name=metric_name, - error_message=f"Resource exhausted after {_MAX_RETRIES} retries: {e}", - ) - time.sleep(2**attempt) - else: - raise e + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metrics=[self.metric], + instance=payload.get("instance"), + ), + metric_name, + ) if ( api_response @@ -1259,32 +1288,14 @@ def get_metric_result( try: payload = self._build_request_payload(eval_case, response_index) - for attempt in range(_MAX_RETRIES): - try: - api_response = self.module._evaluate_instances( - metric_sources=[metric_source], - instance=payload.get("instance"), - autorater_config=payload.get("autorater_config"), - ) - break - except genai_errors.ClientError as e: - if e.code == 429: - logger.warning( - "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s" - " seconds...", - attempt + 1, - _MAX_RETRIES, - e, - 2**attempt, - ) - if attempt == _MAX_RETRIES - 1: - return types.EvalCaseMetricResult( - metric_name=metric_name, - error_message=f"Judge model resource exhausted after {_MAX_RETRIES} retries: {e}", - ) - time.sleep(2**attempt) - else: - raise e + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metric_sources=[metric_source], + instance=payload.get("instance"), + autorater_config=payload.get("autorater_config"), + ), + metric_name, + ) if api_response and api_response.metric_results: result_data = api_response.metric_results[0] @@ -1498,10 +1509,29 @@ class EvaluationRunConfig(_common.BaseModel): """The number of response candidates for the evaluation run.""" +def _rate_limited_get_metric_result( + rate_limiter: _evals_utils.RateLimiter, + handler: MetricHandler[Any], + eval_case: types.EvalCase, + response_index: int, +) -> types.EvalCaseMetricResult: + """Wraps a handler's get_metric_result with rate limiting.""" + rate_limiter.sleep_and_advance() + return handler.get_metric_result(eval_case, response_index) + + def compute_metrics_and_aggregate( evaluation_run_config: EvaluationRunConfig, + evaluation_service_qps: Optional[float] = None, ) -> types.EvaluationResult: - """Computes metrics and aggregates them for a given evaluation run config.""" + """Computes metrics and aggregates them for a given evaluation run config. + + Args: + evaluation_run_config: The configuration for the evaluation run. + evaluation_service_qps: Optional QPS limit for the evaluation service. + Defaults to _DEFAULT_EVAL_SERVICE_QPS (10). Users with higher + quotas can increase this value. + """ metric_handlers = [] all_futures = [] results_by_case_response_metric: collections.defaultdict[ @@ -1511,6 +1541,12 @@ def compute_metrics_and_aggregate( execution_errors = [] case_indices_with_errors = set() + if evaluation_service_qps is not None and evaluation_service_qps <= 0: + raise ValueError("evaluation_service_qps must be a positive number.") + qps = evaluation_service_qps or _evals_utils._DEFAULT_EVAL_SERVICE_QPS + rate_limiter = _evals_utils.RateLimiter(rate=qps) + logger.info("Rate limiting evaluation service requests to %.1f QPS.", qps) + for eval_metric in evaluation_run_config.metrics: metric_handlers.append( get_handler_for_metric(evaluation_run_config.evals_module, eval_metric) @@ -1553,7 +1589,9 @@ def compute_metrics_and_aggregate( for response_index in range(actual_num_candidates_for_case): try: future = executor.submit( - metric_handler_instance.get_metric_result, + _rate_limited_get_metric_result, + rate_limiter, + metric_handler_instance, eval_case, response_index, ) diff --git a/vertexai/_genai/_evals_utils.py b/vertexai/_genai/_evals_utils.py index 9d4dd4fc71..8754e280e0 100644 --- a/vertexai/_genai/_evals_utils.py +++ b/vertexai/_genai/_evals_utils.py @@ -15,9 +15,11 @@ """Utility functions for evals.""" import abc +import json import logging import os -import json +import threading +import time from typing import Any, Optional, Union from google.genai._api_client import BaseApiClient @@ -36,12 +38,59 @@ GCS_PREFIX = "gs://" BQ_PREFIX = "bq://" +_DEFAULT_EVAL_SERVICE_QPS = 10 + + +class RateLimiter: + """Helper class for rate-limiting requests to Vertex AI to improve QoS. + + Implements a token bucket algorithm to limit the rate at which API calls + can occur. Designed for cases where the batch size is always 1 for traffic + shaping and rate limiting. + + Attributes: + seconds_per_event: The time interval (in seconds) between events to + maintain the desired rate. + last: The timestamp of the last event. + _lock: A lock to ensure thread safety. + """ + + def __init__(self, rate: float) -> None: + """Initializes the rate limiter. + + Args: + rate: The number of queries allowed per second. + + Raises: + ValueError: If the rate is not positive. + """ + if not rate or rate <= 0: + raise ValueError("Rate must be a positive number") + self.seconds_per_event = 1.0 / rate + self._next_allowed = time.monotonic() + self._lock = threading.Lock() + + def sleep_and_advance(self) -> None: + """Blocks the current thread until the next event can be admitted. + + The lock is held only long enough to reserve a time slot. The + actual sleep happens outside the lock so that multiple threads + can be sleeping concurrently with staggered wake-up times. + """ + with self._lock: + now = time.monotonic() + wait_until = max(now, self._next_allowed) + delay = wait_until - now + self._next_allowed = wait_until + self.seconds_per_event + + if delay > 0: + time.sleep(delay) class EvalDatasetLoader: """A loader for datasets from various sources, using a shared client.""" - def __init__(self, api_client: BaseApiClient): + def __init__(self, api_client: BaseApiClient) -> None: self.api_client = api_client self.gcs_utils = _gcs_utils.GcsUtils(self.api_client) self.bigquery_utils = _bigquery_utils.BigQueryUtils(self.api_client) diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py index ce516c8cf6..a5a9ce0581 100644 --- a/vertexai/_genai/evals.py +++ b/vertexai/_genai/evals.py @@ -2024,6 +2024,9 @@ def evaluate( - dataset_schema: Schema to use for the dataset. If not specified, the dataset schema will be inferred from the dataset automatically. - dest: Destination path for storing evaluation results. + - evaluation_service_qps: The rate limit (queries per second) for + calls to the evaluation service. Defaults to 10. Increase this + value if your project has a higher EvaluateInstances API quota. **kwargs: Extra arguments to pass to evaluation, such as `agent_info`. Returns: @@ -2065,6 +2068,7 @@ def evaluate( dataset_schema=config.dataset_schema, dest=config.dest, location=location, + evaluation_service_qps=getattr(config, "evaluation_service_qps", None), **kwargs, ) diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index d017fda1d1..c72fafa485 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -15442,6 +15442,12 @@ class EvaluateMethodConfig(_common.BaseModel): dest: Optional[str] = Field( default=None, description="""The destination path for the evaluation results.""" ) + evaluation_service_qps: Optional[float] = Field( + default=None, + description="""The rate limit (queries per second) for calls to the + evaluation service. Defaults to 10. Increase this value if your + project has a higher EvaluateInstances API quota.""", + ) class EvaluateMethodConfigDict(TypedDict, total=False): @@ -15458,6 +15464,11 @@ class EvaluateMethodConfigDict(TypedDict, total=False): dest: Optional[str] """The destination path for the evaluation results.""" + evaluation_service_qps: Optional[float] + """The rate limit (queries per second) for calls to the + evaluation service. Defaults to 10. Increase this value if your + project has a higher EvaluateInstances API quota.""" + EvaluateMethodConfigOrDict = Union[EvaluateMethodConfig, EvaluateMethodConfigDict]