diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 062df5491c..875e0fa5ca 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -29,6 +29,7 @@ from ._tools import FunctionTool from ._types import ( + ChatOptions, Content, Message, ) @@ -146,7 +147,10 @@ def _parse_tool_result_from_mcp( Converts each content item in the MCP result to its appropriate Content form. Text items become ``Content(type="text")`` and media - items (images, audio) are preserved as rich Content. + items (images, audio) are preserved as rich Content. When no content + items are produced but ``structuredContent`` is present on the result, + the structured payload is serialised as JSON text so it is still + surfaced to the caller. Args: mcp_type: The MCP CallToolResult object to convert. @@ -192,6 +196,18 @@ def _parse_tool_result_from_mcp( case _: result.append(Content.from_text(str(item))) + if not result and mcp_type.structuredContent is not None: + try: + text = json.dumps(mcp_type.structuredContent) + except (TypeError, ValueError): + text = str(mcp_type.structuredContent) + result.append( + Content.from_text( + text, + additional_properties={"structured_content": mcp_type.structuredContent}, + ) + ) + if not result: result.append(Content.from_text("null")) return result @@ -649,6 +665,11 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None: error_msg = f"Failed to connect to MCP server: {ex}" raise ToolException(error_msg, inner_exception=ex) from ex try: + sampling_capabilities = None + if self.client is not None: + sampling_capabilities = types.SamplingCapability( + tools=types.SamplingToolsCapability(), + ) session = await self._exit_stack.enter_async_context( ClientSession( read_stream=transport[0], @@ -659,6 +680,7 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None: message_handler=self.message_handler, logging_callback=self.logging_callback, sampling_callback=self.sampling_callback, + sampling_capabilities=sampling_capabilities, ) ) except Exception as ex: @@ -732,14 +754,35 @@ async def sampling_callback( messages: list[Message] = [] for msg in params.messages: messages.append(_parse_message_from_mcp(msg)) + + options: ChatOptions[None] = {} + if params.systemPrompt is not None: + options["instructions"] = params.systemPrompt + if params.tools is not None: + options["tools"] = [ + FunctionTool( + name=tool.name, + description=tool.description or "", + input_model=tool.inputSchema, + ) + for tool in params.tools + ] + if params.toolChoice is not None and params.toolChoice.mode is not None: + options["tool_choice"] = params.toolChoice.mode + + if params.temperature is not None: + options["temperature"] = params.temperature + options["max_tokens"] = params.maxTokens + if params.stopSequences is not None: + options["stop"] = params.stopSequences + try: response = await self.client.get_response( messages, - temperature=params.temperature, - max_tokens=params.maxTokens, - stop=params.stopSequences, + options=options or None, ) except Exception as ex: + logger.debug("Sampling callback error: %s", ex, exc_info=True) return types.ErrorData( code=types.INTERNAL_ERROR, message=f"Failed to get chat message content: {ex}", diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index fa9e1130f0..be9ec718cc 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. # type: ignore[reportPrivateUsage] +import json import logging import os from contextlib import _AsyncGeneratorContextManager # type: ignore @@ -246,6 +247,74 @@ def test_parse_tool_result_from_mcp_blob_plain_base64(): assert "dGVzdCBkYXRh" in result[0].uri +def test_parse_tool_result_from_mcp_structured_content(): + """Test that structuredContent is ignored when content items are present.""" + structured = {"name": "Pasta Carbonara", "ingredients": ["pasta", "eggs", "cheese"]} + mcp_result = types.CallToolResult( + content=[types.TextContent(type="text", text="Here is a recipe")], + structuredContent=structured, + ) + result = _parse_tool_result_from_mcp(mcp_result) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert result[0].text == "Here is a recipe" + + +def test_parse_tool_result_from_mcp_structured_content_only(): + """Test that structuredContent alone (no regular content) produces a text Content.""" + structured = {"temperature": 72, "unit": "F"} + mcp_result = types.CallToolResult( + content=[], + structuredContent=structured, + ) + result = _parse_tool_result_from_mcp(mcp_result) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert json.loads(result[0].text) == structured + assert result[0].additional_properties is not None + assert result[0].additional_properties["structured_content"] == structured + + +def test_parse_tool_result_from_mcp_structured_content_nested(): + """Test that structuredContent with nested complex types serialises correctly.""" + structured = { + "recipe": { + "name": "Pasta Carbonara", + "ingredients": [{"item": "pasta", "amount": 200}, {"item": "eggs", "amount": 3}], + "metadata": {"origin": "Italy", "tags": ["quick", "classic"]}, + } + } + mcp_result = types.CallToolResult( + content=[], + structuredContent=structured, + ) + result = _parse_tool_result_from_mcp(mcp_result) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert json.loads(result[0].text) == structured + + +def test_parse_tool_result_from_mcp_structured_content_non_serializable(): + """Test that structuredContent with non-JSON-serializable values falls back to str().""" + from datetime import datetime + + structured = {"timestamp": datetime(2025, 1, 1)} + mcp_result = types.CallToolResult( + content=[], + structuredContent=structured, + ) + result = _parse_tool_result_from_mcp(mcp_result) + assert len(result) == 1 + assert result[0].text == str(structured) + assert result[0].additional_properties["structured_content"] == structured + + def test_mcp_content_types_to_ai_content_text(): """Test conversion of MCP text content to AI content.""" mcp_content = types.TextContent(type="text", text="Sample text") @@ -1562,12 +1631,15 @@ async def test_mcp_tool_sampling_callback_chat_client_exception(): params.temperature = None params.maxTokens = None params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None result = await tool.sampling_callback(Mock(), params) assert isinstance(result, types.ErrorData) assert result.code == types.INTERNAL_ERROR - assert "Failed to get chat message content: Chat client error" in result.message + assert "Failed to get chat message content" in result.message async def test_mcp_tool_sampling_callback_no_valid_content(): @@ -1605,6 +1677,9 @@ async def test_mcp_tool_sampling_callback_no_valid_content(): params.temperature = None params.maxTokens = None params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None result = await tool.sampling_callback(Mock(), params) @@ -1613,6 +1688,361 @@ async def test_mcp_tool_sampling_callback_no_valid_content(): assert "Failed to get right content types from the response." in result.message +async def test_mcp_tool_sampling_callback_forwards_system_prompt(): + """Test sampling callback passes systemPrompt as instructions in options.""" + from agent_framework import Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = None + params.maxTokens = None + params.stopSequences = None + params.systemPrompt = "You are a helpful assistant" + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + assert options.get("instructions") == "You are a helpful assistant" + + +async def test_mcp_tool_sampling_callback_forwards_tools(): + """Test sampling callback converts MCP tools to FunctionTools and passes them in options.""" + from agent_framework import FunctionTool, Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + mcp_tool = types.Tool( + name="get_weather", + description="Get weather", + inputSchema={"type": "object", "properties": {"city": {"type": "string"}}}, + ) + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = None + params.maxTokens = None + params.stopSequences = None + params.systemPrompt = None + params.tools = [mcp_tool] + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + tools = options.get("tools") + assert tools is not None + assert len(tools) == 1 + assert isinstance(tools[0], FunctionTool) + assert tools[0].name == "get_weather" + assert tools[0].description == "Get weather" + + +async def test_mcp_tool_sampling_callback_forwards_tool_choice(): + """Test sampling callback passes toolChoice mode in options.""" + from agent_framework import Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = None + params.maxTokens = None + params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = types.ToolChoice(mode="required") + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + assert options.get("tool_choice") == "required" + + +async def test_mcp_tool_sampling_callback_forwards_empty_system_prompt(): + """Test sampling callback forwards empty string systemPrompt as instructions.""" + from agent_framework import Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = None + params.maxTokens = None + params.stopSequences = None + params.systemPrompt = "" + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + assert options.get("instructions") == "" + + +async def test_mcp_tool_sampling_callback_forwards_empty_tools_list(): + """Test sampling callback forwards empty tools list in options.""" + from agent_framework import Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = None + params.maxTokens = None + params.stopSequences = None + params.systemPrompt = None + params.tools = [] + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + assert options.get("tools") == [] + + +async def test_mcp_tool_sampling_callback_forwards_generation_params_in_options(): + """Test sampling callback passes temperature, max_tokens, and stop in options.""" + from agent_framework import Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = 0.7 + params.maxTokens = 256 + params.stopSequences = ["STOP"] + params.systemPrompt = None + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + assert options.get("temperature") == 0.7 + assert options.get("max_tokens") == 256 + assert options.get("stop") == ["STOP"] + # These should not be passed as top-level kwargs + assert "temperature" not in call_kwargs.kwargs + assert "max_tokens" not in call_kwargs.kwargs + assert "stop" not in call_kwargs.kwargs + + +async def test_mcp_tool_sampling_callback_omits_temperature_when_none(): + """Test sampling callback does not set temperature in options when it is None.""" + from agent_framework import Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = None + params.maxTokens = 100 + params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + assert "temperature" not in options + assert options.get("max_tokens") == 100 + assert "stop" not in options + + +async def test_mcp_tool_sampling_callback_always_passes_max_tokens(): + """Test sampling callback always sets max_tokens in options since maxTokens is a required int field.""" + from agent_framework import Message + + tool = MCPStdioTool(name="test_tool", command="python") + + mock_chat_client = AsyncMock() + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text("response")])] + mock_response.model_id = "test-model" + mock_chat_client.get_response.return_value = mock_response + + tool.client = mock_chat_client + + params = Mock() + mock_message = Mock() + mock_message.role = "user" + mock_message.content = Mock() + mock_message.content.text = "Test question" + params.messages = [mock_message] + params.temperature = None + params.maxTokens = 200 + params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + call_kwargs = mock_chat_client.get_response.call_args + options = call_kwargs.kwargs.get("options") or {} + assert options["max_tokens"] == 200 + + +async def test_connect_sampling_capabilities_with_client(): + """Test connect() passes sampling_capabilities to ClientSession when client is set.""" + tool = MCPStdioTool(name="test", command="test-command", load_tools=False, load_prompts=False) + tool.client = Mock() + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + with patch("agent_framework._mcp.ClientSession") as mock_session_class: + mock_session = AsyncMock() + mock_session._request_id = 1 + + session_cm = AsyncMock() + session_cm.__aenter__ = AsyncMock(return_value=mock_session) + session_cm.__aexit__ = AsyncMock(return_value=None) + mock_session_class.return_value = session_cm + + await tool.connect() + + call_kwargs = mock_session_class.call_args.kwargs + sampling_caps = call_kwargs.get("sampling_capabilities") + assert sampling_caps is not None + assert isinstance(sampling_caps, types.SamplingCapability) + assert sampling_caps.tools is not None + assert isinstance(sampling_caps.tools, types.SamplingToolsCapability) + + +async def test_connect_no_sampling_capabilities_without_client(): + """Test connect() does not pass sampling_capabilities when no client is set.""" + tool = MCPStdioTool(name="test", command="test-command", load_tools=False, load_prompts=False) + # No client set + + mock_transport = (Mock(), Mock()) + mock_context_manager = Mock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + tool.get_mcp_client = Mock(return_value=mock_context_manager) + + with patch("agent_framework._mcp.ClientSession") as mock_session_class: + mock_session = AsyncMock() + mock_session._request_id = 1 + + session_cm = AsyncMock() + session_cm.__aenter__ = AsyncMock(return_value=mock_session) + session_cm.__aexit__ = AsyncMock(return_value=None) + mock_session_class.return_value = session_cm + + await tool.connect() + + call_kwargs = mock_session_class.call_args.kwargs + assert call_kwargs.get("sampling_capabilities") is None + + # Test error handling in connect() method