Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,9 @@ def _create_call_tool_coroutine(
async def _call_as_task() -> MCPCallToolResult:
# When task-augmented execution is used, use the read_timeout_seconds parameter
# (which is a timedelta) for the polling timeout.
return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout=read_timeout_seconds)
return await self._call_tool_as_task_and_poll_async(
name, arguments, poll_timeout=read_timeout_seconds, meta=meta
)

return _call_as_task()
else:
Expand Down Expand Up @@ -1100,6 +1102,7 @@ async def _call_tool_as_task_and_poll_async(
arguments: dict[str, Any] | None = None,
ttl: timedelta | None = None,
poll_timeout: timedelta | None = None,
meta: dict[str, Any] | None = None,
) -> MCPCallToolResult:
"""Call a tool using task-augmented execution and poll until completion.

Expand All @@ -1113,6 +1116,7 @@ async def _call_tool_as_task_and_poll_async(
arguments: Optional arguments to pass to the tool.
ttl: Task time-to-live. Uses configured value if not specified.
poll_timeout: Timeout for polling. Uses configured value if not specified.
meta: Optional metadata to pass to the tool call per MCP spec (_meta).

Returns:
MCPCallToolResult: The final tool result after task completion.
Expand All @@ -1133,6 +1137,7 @@ async def _call_tool_as_task_and_poll_async(
name=name,
arguments=arguments,
ttl=ttl_ms,
meta=meta,
)
task_id = create_result.task.taskId
self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id)
Expand Down
74 changes: 74 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,77 @@ async def poll(task_id):
result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={})
assert result["status"] == "success"
assert "Done" in result["content"][0].get("text", "")


class TestTaskMetaForwarding:
"""Tests for meta parameter forwarding in task-augmented execution."""

def _setup_task_tool_with_meta(self, mock_session, tool_name: str) -> MagicMock:
"""Helper to set up a mock task-enabled tool and return the experimental mock."""
mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True))
mock_tool = MCPTool(
name=tool_name,
description="A test tool",
inputSchema={"type": "object"},
execution=ToolExecution(taskSupport="optional"),
)
mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None))
mock_create_result = MagicMock()
mock_create_result.task.taskId = "test-task-id"
mock_session.experimental = MagicMock()
mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result)

async def successful_poll(task_id):
yield MagicMock(status="completed", statusMessage=None)

mock_session.experimental.poll_task = successful_poll
mock_session.experimental.get_task_result = AsyncMock(
return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False)
)

return mock_session.experimental

def test_call_tool_sync_forwards_meta_to_task(self, mock_transport, mock_session):
"""Test that call_tool_sync forwards meta to call_tool_as_task."""
experimental = self._setup_task_tool_with_meta(mock_session, "meta_tool")
meta = {"com.example/request_id": "abc-123"}

with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client:
client.list_tools_sync()
client.call_tool_sync(
tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta
)

experimental.call_tool_as_task.assert_called_once()
call_kwargs = experimental.call_tool_as_task.call_args
assert call_kwargs.kwargs.get("meta") == meta

@pytest.mark.asyncio
async def test_call_tool_async_forwards_meta_to_task(self, mock_transport, mock_session):
"""Test that call_tool_async forwards meta to call_tool_as_task."""
experimental = self._setup_task_tool_with_meta(mock_session, "meta_tool")
meta = {"com.example/trace_id": "xyz-456"}

with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client:
client.list_tools_sync()
await client.call_tool_async(
tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta
)

experimental.call_tool_as_task.assert_called_once()
call_kwargs = experimental.call_tool_as_task.call_args
assert call_kwargs.kwargs.get("meta") == meta

def test_call_tool_sync_forwards_none_meta_to_task(self, mock_transport, mock_session):
"""Test that call_tool_sync forwards None meta to call_tool_as_task when not provided."""
experimental = self._setup_task_tool_with_meta(mock_session, "no_meta_tool")

with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client:
client.list_tools_sync()
client.call_tool_sync(
tool_use_id="test-id", name="no_meta_tool", arguments={"param": "value"}
)

experimental.call_tool_as_task.assert_called_once()
call_kwargs = experimental.call_tool_as_task.call_args
assert call_kwargs.kwargs.get("meta") is None
Loading