Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
protocol_version: str | None = None,
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
Expand All @@ -133,6 +134,7 @@ def __init__(
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._initialize_result: types.InitializeResult | None = None
self._experimental_features: ExperimentalClientFeatures | None = None
self._protocol_version = protocol_version or types.LATEST_PROTOCOL_VERSION

# Experimental: Task handlers (use defaults if not provided)
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
Expand Down Expand Up @@ -168,7 +170,7 @@ async def initialize(self) -> types.InitializeResult:
result = await self.send_request(
types.InitializeRequest(
params=types.InitializeRequestParams(
protocol_version=types.LATEST_PROTOCOL_VERSION,
protocol_version=self._protocol_version,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
Expand Down
119 changes: 119 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,125 @@ async def mock_server():
assert result.protocol_version == LATEST_PROTOCOL_VERSION


@pytest.mark.anyio
async def test_client_session_custom_protocol_version():
"""Test that custom protocol_version is sent during initialization.

This allows connecting to servers that require a specific protocol version,
such as Snowflake's managed MCP server which requires "2025-06-18".
See: https://github.com/modelcontextprotocol/python-sdk/issues/2307
"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

custom_protocol_version = "2025-06-18"
received_protocol_version = None

async def mock_server():
nonlocal received_protocol_version

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request, JSONRPCRequest)
request = client_request_adapter.validate_python(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request, InitializeRequest)
received_protocol_version = request.params.protocol_version

result = InitializeResult(
protocol_version=custom_protocol_version,
capabilities=ServerCapabilities(),
server_info=Implementation(name="mock-server", version="0.1.0"),
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
protocol_version=custom_protocol_version,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
result = await session.initialize()

# Assert that the custom protocol version was sent and received
assert received_protocol_version == custom_protocol_version
assert result.protocol_version == custom_protocol_version


@pytest.mark.anyio
async def test_client_session_default_protocol_version():
"""Test that LATEST_PROTOCOL_VERSION is used when protocol_version is not specified."""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

received_protocol_version = None

async def mock_server():
nonlocal received_protocol_version

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request, JSONRPCRequest)
request = client_request_adapter.validate_python(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request, InitializeRequest)
received_protocol_version = request.params.protocol_version

result = InitializeResult(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
server_info=Implementation(name="mock-server", version="0.1.0"),
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.id,
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()

async with (
ClientSession(server_to_client_receive, client_to_server_send) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()

# Assert that the default (latest) protocol version was sent
assert received_protocol_version == LATEST_PROTOCOL_VERSION


@pytest.mark.anyio
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None):
Expand Down
Loading