diff --git a/google/genai/models.py b/google/genai/models.py
index 4a62a917e..dc4365854 100644
--- a/google/genai/models.py
+++ b/google/genai/models.py
@@ -37,6 +37,39 @@
 logger = logging.getLogger('google_genai.models')
 
 
+def _filter_thought_parts(
+    return_value: 'types.GenerateContentResponse',
+    parameter_model: 'types._GenerateContentParameters',
+) -> 'types.GenerateContentResponse':
+  """Removes thought parts when the request set include_thoughts=False.
+
+  Client-side safeguard for responses that still carry parts flagged as
+  thoughts even though ThinkingConfig.include_thoughts was explicitly False.
+  Such parts are dropped from every candidate; the response object is
+  modified in place and returned.
+
+  Filtering is strictly opt-out: a missing config, a missing ThinkingConfig,
+  or include_thoughts left as None or True leaves the response untouched.
+
+  Args:
+    return_value: The GenerateContentResponse to filter (mutated in place).
+    parameter_model: The request parameters, used to read ThinkingConfig.
+
+  Returns:
+    The same response object, without thought parts when include_thoughts
+    is False, otherwise unchanged.
+  """
+  thinking_config = getattr(parameter_model.config, 'thinking_config', None)
+  # Only an explicit False opts out; None means "unset" and must not filter.
+  if getattr(thinking_config, 'include_thoughts', None) is not False:
+    return return_value
+  for candidate in return_value.candidates or []:
+    content = candidate.content
+    if content and content.parts:
+      content.parts = [part for part in content.parts if not part.thought]
+  return return_value
+
+
 def _PersonGeneration_to_mldev_enum_validate(enum_value: Any) -> None:
   if enum_value in set(['ALLOW_ALL']):
     raise ValueError(f'{enum_value} enum value is not supported in Gemini API.')
@@ -4725,7 +4758,7 @@ def _generate_content(
         headers=response.headers
     )
     self._api_client._verify_response(return_value)
-    return return_value
+    return _filter_thought_parts(return_value, parameter_model)
 
   def _generate_content_stream(
       self,
@@ -4826,7 +4859,7 @@ def _generate_content_stream(
           headers=response.headers
       )
       self._api_client._verify_response(return_value)
-      yield return_value
+      yield _filter_thought_parts(return_value, parameter_model)
 
   def _embed_content(
       self,
@@ -6891,7 +6924,7 @@ async def _generate_content(
         headers=response.headers
     )
     self._api_client._verify_response(return_value)
-    return return_value
+    return _filter_thought_parts(return_value, parameter_model)
 
   async def _generate_content_stream(
       self,
@@ -6995,7 +7028,7 @@ async def async_generator():  # type: ignore[no-untyped-def]
             headers=response.headers
         )
         self._api_client._verify_response(return_value)
-        yield return_value
+        yield _filter_thought_parts(return_value, parameter_model)
 
     return async_generator()  # type: ignore[no-untyped-call, no-any-return]
 
diff --git a/google/genai/tests/models/test_filter_thought_parts.py b/google/genai/tests/models/test_filter_thought_parts.py
new file mode 100644
index 000000000..7e07509d3
--- /dev/null
+++ b/google/genai/tests/models/test_filter_thought_parts.py
@@ -0,0 +1,147 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for _filter_thought_parts.
+
+Verifies client-side filtering of thought parts when include_thoughts=False
+is set in ThinkingConfig. Regression test for:
+https://github.com/googleapis/python-genai/issues/2239
+"""
+
+from ... import _transformers as t
+from ... import types
+from ...models import _filter_thought_parts
+
+
+def _make_response_with_thoughts() -> types.GenerateContentResponse:
+  """Build a synthetic response resembling Vertex AI image gen with thoughts."""
+  parts = [
+      types.Part(text='thinking step 1', thought=True),
+      types.Part(text='thinking step 2', thought=True),
+      types.Part(inline_data=types.Blob(mime_type='image/png', data=b'draft'), thought=True),
+      types.Part(text='self-critique', thought=True),
+      types.Part(inline_data=types.Blob(mime_type='image/png', data=b'final')),
+  ]
+  content = types.Content(role='model', parts=parts)
+  candidate = types.Candidate(content=content)
+  return types.GenerateContentResponse(candidates=[candidate])
+
+
+def _make_parameter_model(include_thoughts) -> types._GenerateContentParameters:
+  """Build request parameters whose ThinkingConfig sets include_thoughts."""
+  return types._GenerateContentParameters(
+      model='gemini-3.1-flash-image-preview',
+      contents=t.t_contents('Draw a red car'),
+      config=types.GenerateContentConfig(
+          thinking_config=types.ThinkingConfig(include_thoughts=include_thoughts)
+      ),
+  )
+
+
+class TestFilterThoughtParts:
+
+  def test_include_thoughts_false_removes_thought_parts(self):
+    """When include_thoughts=False, all parts with thought=True are removed."""
+    response = _make_response_with_thoughts()
+    parameter_model = _make_parameter_model(include_thoughts=False)
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    parts = result.candidates[0].content.parts
+    assert all(not part.thought for part in parts), (
+        'Expected no thought parts but found some'
+    )
+    assert len(parts) == 1, f'Expected 1 non-thought part, got {len(parts)}'
+    assert parts[0].inline_data is not None
+    assert parts[0].inline_data.data == b'final'
+
+  def test_include_thoughts_true_preserves_all_parts(self):
+    """When include_thoughts=True, no parts are filtered."""
+    response = _make_response_with_thoughts()
+    parameter_model = _make_parameter_model(include_thoughts=True)
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    parts = result.candidates[0].content.parts
+    assert len(parts) == 5, f'Expected 5 parts, got {len(parts)}'
+
+  def test_include_thoughts_none_preserves_all_parts(self):
+    """When include_thoughts is unset, no parts are filtered."""
+    response = _make_response_with_thoughts()
+    parameter_model = _make_parameter_model(include_thoughts=None)
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    parts = result.candidates[0].content.parts
+    assert len(parts) == 5
+
+  def test_no_thinking_config_preserves_all_parts(self):
+    """When ThinkingConfig is absent entirely, no parts are filtered."""
+    response = _make_response_with_thoughts()
+    parameter_model = types._GenerateContentParameters(
+        model='gemini-3.1-flash-image-preview',
+        contents=t.t_contents('Draw a red car'),
+        config=types.GenerateContentConfig(),
+    )
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    parts = result.candidates[0].content.parts
+    assert len(parts) == 5
+
+  def test_no_config_preserves_all_parts(self):
+    """When config is None entirely, no parts are filtered."""
+    response = _make_response_with_thoughts()
+    parameter_model = types._GenerateContentParameters(
+        model='gemini-3.1-flash-image-preview',
+        contents=t.t_contents('Draw a red car'),
+    )
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    parts = result.candidates[0].content.parts
+    assert len(parts) == 5
+
+  def test_empty_candidates_is_safe(self):
+    """Response with no candidates does not raise."""
+    response = types.GenerateContentResponse(candidates=[])
+    parameter_model = _make_parameter_model(include_thoughts=False)
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    assert result.candidates == []
+
+  def test_no_thought_parts_in_response(self):
+    """If API returns no thought parts, filtering is a no-op."""
+    parts = [
+        types.Part(inline_data=types.Blob(mime_type='image/png', data=b'final')),
+    ]
+    content = types.Content(role='model', parts=parts)
+    candidate = types.Candidate(content=content)
+    response = types.GenerateContentResponse(candidates=[candidate])
+    parameter_model = _make_parameter_model(include_thoughts=False)
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    assert len(result.candidates[0].content.parts) == 1
+
+  def test_candidate_without_content_is_safe(self):
+    """A candidate that carries no content is skipped without raising."""
+    response = types.GenerateContentResponse(candidates=[types.Candidate()])
+    parameter_model = _make_parameter_model(include_thoughts=False)
+
+    result = _filter_thought_parts(response, parameter_model)
+
+    assert result.candidates[0].content is None