From 5cec6062751c37bc65c89eeef78fd769bc45abd5 Mon Sep 17 00:00:00 2001 From: Hashwanth S Date: Thu, 2 Apr 2026 00:23:51 -0700 Subject: [PATCH] [azure-ai-inference] Support dict response_format in _get_internal_response_format Fix _get_internal_response_format to handle OpenAI-style dict response_format (e.g. {"type": "json_schema", "json_schema": {...}}). This enables LangChain's AzureAIChatCompletionsModel and other frameworks that pass response_format as a dict instead of a string or JsonSchemaFormat object. Supports dict types: "text", "json_object", and "json_schema". Raises ValueError for unsupported dict types. Fixes #44201 --- .../azure/ai/inference/_patch.py | 308 ++++++++++++++---- .../tests/test_response_format_conversion.py | 115 +++++++ 2 files changed, 359 insertions(+), 64 deletions(-) create mode 100644 sdk/ai/azure-ai-inference/tests/test_response_format_conversion.py diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py index da95cf93daf9..fe473cf6bd3f 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py @@ -19,7 +19,7 @@ 9. Simplify how chat completions "response_format" is set. Define "response_format" as a flat Union of strings and JsonSchemaFormat object, instead of using auto-generated base/derived classes named ChatCompletionsResponseFormatXxxInternal. -10. Allow UserMessage("my message") in addition to UserMessage(content="my message"). Same applies to +10. Allow UserMessage("my message") in addition to UserMessage(content="my message"). Same applies to AssistantMessage, SystemMessage, DeveloperMessage and ToolMessage. """ @@ -28,7 +28,19 @@ import sys from io import IOBase -from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, Iterable +from typing import ( + Any, + Dict, + Union, + IO, + List, + Literal, + Optional, + overload, + Type, + TYPE_CHECKING, + Iterable, +) from azure.core.pipeline import PipelineResponse from azure.core.credentials import AzureKeyCredential @@ -73,7 +85,9 @@ def _get_internal_response_format( - response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] + response_format: Optional[ + Union[Literal["text", "json_object"], _models.JsonSchemaFormat] + ], ) -> Optional[_models._models.ChatCompletionsResponseFormat]: """ Internal helper method to convert between the public response format type that's supported in the `complete` method, @@ -103,6 +117,32 @@ def _get_internal_response_format( json_schema=response_format ) ) + elif isinstance(response_format, dict): + rf_type = response_format.get("type") + if rf_type == "text": + internal_response_format = ( + _models._models.ChatCompletionsResponseFormatText() # pylint: disable=protected-access + ) + elif rf_type == "json_object": + internal_response_format = ( + _models._models.ChatCompletionsResponseFormatJsonObject() # pylint: disable=protected-access + ) + elif rf_type == "json_schema": + json_schema_data = response_format.get("json_schema", {}) + internal_response_format = ( + _models._models.ChatCompletionsResponseFormatJsonSchema( # pylint: disable=protected-access + json_schema=_models.JsonSchemaFormat( + name=json_schema_data.get("name", ""), + schema=json_schema_data.get("schema", {}), + description=json_schema_data.get("description"), + strict=json_schema_data.get("strict"), + ) + ) + ) + else: + raise ValueError( + f"Unsupported `response_format` type in dict: {rf_type}" + ) else: raise ValueError(f"Unsupported `response_format` {response_format}") @@ -112,7 +152,9 @@ def _get_internal_response_format( def load_client( - endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any + endpoint: str, + credential: Union[AzureKeyCredential, "TokenCredential"], + **kwargs: Any, ) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]: """ Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route @@ -197,10 +239,14 @@ def load_client( ) return image_embedding_client - raise ValueError(f"No client available to support AI model type `{model_info.model_type}`") + raise ValueError( + f"No client available to support AI model type `{model_info.model_type}`" + ) -class ChatCompletionsClient(ChatCompletionsClientGenerated): # pylint: disable=too-many-instance-attributes +class ChatCompletionsClient( + ChatCompletionsClientGenerated +): # pylint: disable=too-many-instance-attributes """ChatCompletionsClient. :param endpoint: Service endpoint URL for AI model inference. Required. @@ -292,11 +338,17 @@ def __init__( temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, - response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + response_format: Optional[ + Union[Literal["text", "json_object"], _models.JsonSchemaFormat] + ] = None, stop: Optional[List[str]] = None, tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + Union[ + str, + _models.ChatCompletionsToolChoicePreset, + _models.ChatCompletionsNamedToolChoice, + ] ] = None, seed: Optional[int] = None, model: Optional[str] = None, @@ -347,11 +399,17 @@ def complete( temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, - response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + response_format: Optional[ + Union[Literal["text", "json_object"], _models.JsonSchemaFormat] + ] = None, stop: Optional[List[str]] = None, tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + Union[ + str, + _models.ChatCompletionsToolChoicePreset, + _models.ChatCompletionsNamedToolChoice, + ] ] = None, seed: Optional[int] = None, model: Optional[str] = None, @@ -370,11 +428,17 @@ def complete( temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, - response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + response_format: Optional[ + Union[Literal["text", "json_object"], _models.JsonSchemaFormat] + ] = None, stop: Optional[List[str]] = None, tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + Union[ + str, + _models.ChatCompletionsToolChoicePreset, + _models.ChatCompletionsNamedToolChoice, + ] ] = None, seed: Optional[int] = None, model: Optional[str] = None, @@ -393,17 +457,25 @@ def complete( temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, - response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + response_format: Optional[ + Union[Literal["text", "json_object"], _models.JsonSchemaFormat] + ] = None, stop: Optional[List[str]] = None, tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + Union[ + str, + _models.ChatCompletionsToolChoicePreset, + _models.ChatCompletionsNamedToolChoice, + ] ] = None, seed: Optional[int] = None, model: Optional[str] = None, model_extras: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + ) -> Union[ + Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions + ]: # pylint: disable=line-too-long """Gets chat completions for the provided chat messages. Completions support a wide variety of tasks and generate text that continues from or @@ -503,7 +575,9 @@ def complete( *, content_type: str = "application/json", **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + ) -> Union[ + Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions + ]: # pylint: disable=line-too-long """Gets chat completions for the provided chat messages. Completions support a wide variety of tasks and generate text that continues from or @@ -527,7 +601,9 @@ def complete( *, content_type: str = "application/json", **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + ) -> Union[ + Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions + ]: # pylint: disable=line-too-long # pylint: disable=too-many-locals """Gets chat completions for the provided chat messages. @@ -549,24 +625,34 @@ def complete( self, body: Union[JSON, IO[bytes]] = _Unset, *, - messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]] = _Unset, + messages: Union[ + List[_models.ChatRequestMessage], List[Dict[str, Any]] + ] = _Unset, stream: Optional[bool] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, - response_format: Optional[Union[Literal["text", "json_object"], _models.JsonSchemaFormat]] = None, + response_format: Optional[ + Union[Literal["text", "json_object"], _models.JsonSchemaFormat] + ] = None, stop: Optional[List[str]] = None, tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] + Union[ + str, + _models.ChatCompletionsToolChoicePreset, + _models.ChatCompletionsNamedToolChoice, + ] ] = None, seed: Optional[int] = None, model: Optional[str] = None, model_extras: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: + ) -> Union[ + Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions + ]: # pylint: disable=line-too-long # pylint: disable=too-many-locals """Gets chat completions for the provided chat messages. @@ -671,7 +757,9 @@ def complete( _params = kwargs.pop("params", {}) or {} _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) internal_response_format = _get_internal_response_format(response_format) @@ -681,28 +769,52 @@ def complete( body = { "messages": messages, "stream": stream, - "frequency_penalty": frequency_penalty if frequency_penalty is not None else self._frequency_penalty, - "max_tokens": max_tokens if max_tokens is not None else self._max_tokens, + "frequency_penalty": ( + frequency_penalty + if frequency_penalty is not None + else self._frequency_penalty + ), + "max_tokens": ( + max_tokens if max_tokens is not None else self._max_tokens + ), "model": model if model is not None else self._model, - "presence_penalty": presence_penalty if presence_penalty is not None else self._presence_penalty, + "presence_penalty": ( + presence_penalty + if presence_penalty is not None + else self._presence_penalty + ), "response_format": ( - internal_response_format if internal_response_format is not None else self._internal_response_format + internal_response_format + if internal_response_format is not None + else self._internal_response_format ), "seed": seed if seed is not None else self._seed, "stop": stop if stop is not None else self._stop, - "temperature": temperature if temperature is not None else self._temperature, - "tool_choice": tool_choice if tool_choice is not None else self._tool_choice, + "temperature": ( + temperature if temperature is not None else self._temperature + ), + "tool_choice": ( + tool_choice if tool_choice is not None else self._tool_choice + ), "tools": tools if tools is not None else self._tools, "top_p": top_p if top_p is not None else self._top_p, } if model_extras is not None and bool(model_extras): body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + _extra_parameters = ( + _models._enums.ExtraParameters.PASS_THROUGH + ) # pylint: disable=protected-access elif self._model_extras is not None and bool(self._model_extras): body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + _extra_parameters = ( + _models._enums.ExtraParameters.PASS_THROUGH + ) # pylint: disable=protected-access body = {k: v for k, v in body.items() if v is not None} - elif isinstance(body, dict) and "stream" in body and isinstance(body["stream"], bool): + elif ( + isinstance(body, dict) + and "stream" in body + and isinstance(body["stream"], bool) + ): stream = body["stream"] content_type = content_type or "application/json" _content = None @@ -720,13 +832,17 @@ def complete( params=_params, ) path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + "endpoint": self._serialize.url( + "self._config.endpoint", self._config.endpoint, "str", skip_quote=True + ), } _request.url = self._client.format_url(_request.url, **path_format_arguments) _stream = stream or False - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) ) response = pipeline_response.http_response @@ -734,13 +850,17 @@ def complete( if response.status_code not in [200]: if _stream: response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) raise HttpResponseError(response=response) if _stream: return _models.StreamingChatCompletions(response) - return _deserialize(_models._patch.ChatCompletions, response.json()) # pylint: disable=protected-access + return _deserialize( + _models._patch.ChatCompletions, response.json() + ) # pylint: disable=protected-access @distributed_trace def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: @@ -756,7 +876,9 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """ if not self._model_info: try: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init + self._model_info = self._get_model_info( + **kwargs + ) # pylint: disable=attribute-defined-outside-init except ResourceNotFoundError as error: error.message = "Model information is not available on this endpoint (`/info` route not supported)." raise error @@ -765,7 +887,11 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: def __str__(self) -> str: # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + return ( + super().__str__() + f"\n{self._model_info}" + if self._model_info + else super().__str__() + ) class EmbeddingsClient(EmbeddingsClientGenerated): @@ -982,24 +1108,38 @@ def embed( _params = kwargs.pop("params", {}) or {} _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) if body is _Unset: if input is _Unset: raise TypeError("missing required argument: input") body = { "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, + "dimensions": ( + dimensions if dimensions is not None else self._dimensions + ), + "encoding_format": ( + encoding_format + if encoding_format is not None + else self._encoding_format + ), + "input_type": ( + input_type if input_type is not None else self._input_type + ), "model": model if model is not None else self._model, } if model_extras is not None and bool(model_extras): body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + _extra_parameters = ( + _models._enums.ExtraParameters.PASS_THROUGH + ) # pylint: disable=protected-access elif self._model_extras is not None and bool(self._model_extras): body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + _extra_parameters = ( + _models._enums.ExtraParameters.PASS_THROUGH + ) # pylint: disable=protected-access body = {k: v for k, v in body.items() if v is not None} content_type = content_type or "application/json" _content = None @@ -1017,13 +1157,17 @@ def embed( params=_params, ) path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + "endpoint": self._serialize.url( + "self._config.endpoint", self._config.endpoint, "str", skip_quote=True + ), } _request.url = self._client.format_url(_request.url, **path_format_arguments) _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) ) response = pipeline_response.http_response @@ -1031,14 +1175,17 @@ def embed( if response.status_code not in [200]: if _stream: response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) raise HttpResponseError(response=response) if _stream: deserialized = response.iter_bytes() else: deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access + _models._patch.EmbeddingsResult, + response.json(), # pylint: disable=protected-access ) return deserialized # type: ignore @@ -1057,7 +1204,9 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """ if not self._model_info: try: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init + self._model_info = self._get_model_info( + **kwargs + ) # pylint: disable=attribute-defined-outside-init except ResourceNotFoundError as error: error.message = "Model information is not available on this endpoint (`/info` route not supported)." raise error @@ -1066,7 +1215,11 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: def __str__(self) -> str: # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + return ( + super().__str__() + f"\n{self._model_info}" + if self._model_info + else super().__str__() + ) class ImageEmbeddingsClient(ImageEmbeddingsClientGenerated): @@ -1283,24 +1436,38 @@ def embed( _params = kwargs.pop("params", {}) or {} _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) if body is _Unset: if input is _Unset: raise TypeError("missing required argument: input") body = { "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, + "dimensions": ( + dimensions if dimensions is not None else self._dimensions + ), + "encoding_format": ( + encoding_format + if encoding_format is not None + else self._encoding_format + ), + "input_type": ( + input_type if input_type is not None else self._input_type + ), "model": model if model is not None else self._model, } if model_extras is not None and bool(model_extras): body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + _extra_parameters = ( + _models._enums.ExtraParameters.PASS_THROUGH + ) # pylint: disable=protected-access elif self._model_extras is not None and bool(self._model_extras): body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access + _extra_parameters = ( + _models._enums.ExtraParameters.PASS_THROUGH + ) # pylint: disable=protected-access body = {k: v for k, v in body.items() if v is not None} content_type = content_type or "application/json" _content = None @@ -1318,13 +1485,17 @@ def embed( params=_params, ) path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), + "endpoint": self._serialize.url( + "self._config.endpoint", self._config.endpoint, "str", skip_quote=True + ), } _request.url = self._client.format_url(_request.url, **path_format_arguments) _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) ) response = pipeline_response.http_response @@ -1332,14 +1503,17 @@ def embed( if response.status_code not in [200]: if _stream: response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) raise HttpResponseError(response=response) if _stream: deserialized = response.iter_bytes() else: deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access + _models._patch.EmbeddingsResult, + response.json(), # pylint: disable=protected-access ) return deserialized # type: ignore @@ -1358,7 +1532,9 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """ if not self._model_info: try: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init + self._model_info = self._get_model_info( + **kwargs + ) # pylint: disable=attribute-defined-outside-init except ResourceNotFoundError as error: error.message = "Model information is not available on this endpoint (`/info` route not supported)." raise error @@ -1367,7 +1543,11 @@ def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: def __str__(self) -> str: # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() + return ( + super().__str__() + f"\n{self._model_info}" + if self._model_info + else super().__str__() + ) __all__: List[str] = [ diff --git a/sdk/ai/azure-ai-inference/tests/test_response_format_conversion.py b/sdk/ai/azure-ai-inference/tests/test_response_format_conversion.py new file mode 100644 index 000000000000..927bff48118e --- /dev/null +++ b/sdk/ai/azure-ai-inference/tests/test_response_format_conversion.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Unit tests for _get_internal_response_format dict handling (issue #44201).""" + +import pytest +from azure.ai.inference._patch import _get_internal_response_format +from azure.ai.inference import models as _models +from azure.ai.inference.models import _models as _internal_models + + +class TestGetInternalResponseFormatDict: + """Tests for OpenAI-style dict response_format support.""" + + def test_dict_text_format(self): + """Dict with type='text' should produce ChatCompletionsResponseFormatText.""" + result = _get_internal_response_format({"type": "text"}) + assert isinstance(result, _internal_models.ChatCompletionsResponseFormatText) + + def test_dict_json_object_format(self): + """Dict with type='json_object' should produce ChatCompletionsResponseFormatJsonObject.""" + result = _get_internal_response_format({"type": "json_object"}) + assert isinstance( + result, _internal_models.ChatCompletionsResponseFormatJsonObject + ) + + def test_dict_json_schema_format(self): + """Dict with type='json_schema' should produce ChatCompletionsResponseFormatJsonSchema.""" + response_format = { + "type": "json_schema", + "json_schema": { + "name": "ContactInfo", + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string"}, + }, + "required": ["name", "email"], + "additionalProperties": False, + }, + "strict": True, + }, + } + result = _get_internal_response_format(response_format) + assert isinstance( + result, _internal_models.ChatCompletionsResponseFormatJsonSchema + ) + assert result.json_schema.name == "ContactInfo" + assert result.json_schema.schema["type"] == "object" + assert result.json_schema.strict is True + + def test_dict_json_schema_with_description(self): + """Dict with json_schema including description should preserve it.""" + response_format = { + "type": "json_schema", + "json_schema": { + "name": "PersonInfo", + "description": "Information about a person", + "schema": {"type": "object", "properties": {}}, + }, + } + result = _get_internal_response_format(response_format) + assert isinstance( + result, _internal_models.ChatCompletionsResponseFormatJsonSchema + ) + assert result.json_schema.name == "PersonInfo" + assert result.json_schema.description == "Information about a person" + + def test_dict_unsupported_type_raises(self): + """Dict with unknown type should raise ValueError.""" + with pytest.raises( + ValueError, match="Unsupported `response_format` type in dict" + ): + _get_internal_response_format({"type": "xml"}) + + def test_dict_missing_type_raises(self): + """Dict without type key should raise ValueError.""" + with pytest.raises( + ValueError, match="Unsupported `response_format` type in dict" + ): + _get_internal_response_format({"foo": "bar"}) + + +class TestGetInternalResponseFormatExisting: + """Verify existing behavior is preserved.""" + + def test_string_text(self): + result = _get_internal_response_format("text") + assert isinstance(result, _internal_models.ChatCompletionsResponseFormatText) + + def test_string_json_object(self): + result = _get_internal_response_format("json_object") + assert isinstance( + result, _internal_models.ChatCompletionsResponseFormatJsonObject + ) + + def test_json_schema_format_object(self): + schema = _models.JsonSchemaFormat( + name="TestSchema", + schema={"type": "object", "properties": {}}, + ) + result = _get_internal_response_format(schema) + assert isinstance( + result, _internal_models.ChatCompletionsResponseFormatJsonSchema + ) + assert result.json_schema.name == "TestSchema" + + def test_none_returns_none(self): + result = _get_internal_response_format(None) + assert result is None + + def test_unsupported_type_raises(self): + with pytest.raises(ValueError, match="Unsupported `response_format`"): + _get_internal_response_format(12345)