Skip to content

Commit c47e32b

Browse files
VoidChecksumVoidChecksum
authored andcommitted
fix: propagate transport errors in ClientSession default message handler
The default message_handler silently swallowed transport-level errors (like SSE read timeouts), causing tool invocation tasks to hang indefinitely. Raise exceptions in the default handler so errors propagate to waiting callers. Fixes #1401 Github-Issue: #1401
1 parent e1fd62e commit c47e32b

File tree

2 files changed

+175
-7
lines changed

2 files changed

+175
-7
lines changed

src/mcp/client/session.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,16 @@ async def __call__(self, params: types.LoggingMessageNotificationParams) -> None
5050
class MessageHandlerFnT(Protocol):
5151
async def __call__(
5252
self,
53-
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
53+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
5454
) -> None: ... # pragma: no branch
5555

5656

5757
async def _default_message_handler(
58-
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
58+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
5959
) -> None:
60+
if isinstance(message, Exception):
61+
logger.exception("Transport error received", exc_info=message)
62+
raise message
6063
await anyio.lowlevel.checkpoint()
6164

6265

@@ -152,7 +155,10 @@ async def initialize(self) -> types.InitializeResult:
152155
else None
153156
)
154157
elicitation = (
155-
types.ElicitationCapability(form=types.FormElicitationCapability(), url=types.UrlElicitationCapability())
158+
types.ElicitationCapability(
159+
form=types.FormElicitationCapability(),
160+
url=types.UrlElicitationCapability(),
161+
)
156162
if self._elicitation_callback is not _default_elicitation_callback
157163
else None
158164
)
@@ -459,7 +465,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
459465

460466
async def _handle_incoming(
461467
self,
462-
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
468+
req: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
463469
) -> None:
464470
"""Handle incoming messages by forwarding to the message handler."""
465471
await self._message_handler(req)

tests/client/test_session.py

Lines changed: 165 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from __future__ import annotations
22

3+
import unittest.mock
4+
35
import anyio
46
import pytest
57

68
from mcp import types
7-
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
9+
from mcp.client.session import (
10+
DEFAULT_CLIENT_INFO,
11+
ClientSession,
12+
_default_message_handler,
13+
)
814
from mcp.shared._context import RequestContext
15+
from mcp.shared.exceptions import MCPError
916
from mcp.shared.message import SessionMessage
1017
from mcp.shared.session import RequestResponder
1118
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -78,7 +85,7 @@ async def mock_server():
7885

7986
# Create a message handler to catch exceptions
8087
async def message_handler( # pragma: no cover
81-
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
88+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
8289
) -> None:
8390
if isinstance(message, Exception):
8491
raise message
@@ -658,7 +665,10 @@ async def mock_server():
658665
assert "_meta" in jsonrpc_request.params
659666
assert jsonrpc_request.params["_meta"] == meta
660667

661-
result = CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False)
668+
result = CallToolResult(
669+
content=[TextContent(type="text", text="Called successfully")],
670+
is_error=False,
671+
)
662672

663673
# Send the tools/call result
664674
await server_to_client_send.send(
@@ -706,3 +716,155 @@ async def mock_server():
706716
await session.initialize()
707717

708718
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)
719+
720+
721+
@pytest.mark.anyio
722+
async def test_default_message_handler_raises_on_exception():
723+
"""The default message handler must re-raise exceptions so they propagate."""
724+
error = RuntimeError("transport error")
725+
with pytest.raises(RuntimeError, match="transport error"):
726+
await _default_message_handler(error)
727+
728+
729+
@pytest.mark.anyio
730+
async def test_default_message_handler_checkpoints_on_non_exception():
731+
"""The default message handler yields control for non-exception messages."""
732+
checkpointed = False
733+
original_checkpoint = anyio.lowlevel.checkpoint
734+
735+
async def mock_checkpoint():
736+
nonlocal checkpointed
737+
checkpointed = True
738+
await original_checkpoint()
739+
740+
notification = types.ProgressNotification(
741+
params=types.ProgressNotificationParams(
742+
progress_token="tok",
743+
progress=0.5,
744+
)
745+
)
746+
with unittest.mock.patch("anyio.lowlevel.checkpoint", mock_checkpoint):
747+
await _default_message_handler(notification)
748+
749+
assert checkpointed
750+
751+
752+
@pytest.mark.anyio
753+
async def test_transport_error_propagates_to_waiting_send_request():
754+
"""A transport-level exception sent on the read stream must unblock pending
755+
send_request callers with a CONNECTION_CLOSED error rather than hanging."""
756+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
757+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
758+
759+
async def mock_server():
760+
# Consume the initialize request
761+
await client_to_server_receive.receive()
762+
763+
init_result = InitializeResult(
764+
protocol_version=LATEST_PROTOCOL_VERSION,
765+
capabilities=ServerCapabilities(),
766+
server_info=Implementation(name="mock-server", version="0.1.0"),
767+
)
768+
await server_to_client_send.send(
769+
SessionMessage(
770+
JSONRPCResponse(
771+
jsonrpc="2.0",
772+
id=0,
773+
result=init_result.model_dump(by_alias=True, mode="json", exclude_none=True),
774+
)
775+
)
776+
)
777+
# Consume initialized notification
778+
await client_to_server_receive.receive()
779+
780+
# Consume the tools/call request — but don't reply; inject a transport
781+
# error instead so the waiting send_request is unblocked.
782+
await client_to_server_receive.receive()
783+
await server_to_client_send.send(ConnectionError("SSE read timeout"))
784+
785+
async with (
786+
ClientSession(server_to_client_receive, client_to_server_send) as session,
787+
anyio.create_task_group() as tg,
788+
client_to_server_send,
789+
client_to_server_receive,
790+
server_to_client_send,
791+
server_to_client_receive,
792+
):
793+
tg.start_soon(mock_server)
794+
await session.initialize()
795+
796+
with pytest.raises(MCPError):
797+
with anyio.fail_after(5):
798+
await session.call_tool("any_tool")
799+
800+
801+
@pytest.mark.anyio
802+
async def test_custom_message_handler_not_affected_by_default_behavior():
803+
"""A user-supplied message_handler that silently ignores exceptions must not
804+
be overridden by the new default behavior."""
805+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
806+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
807+
808+
exceptions_received: list[Exception] = []
809+
810+
async def silent_handler(
811+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
812+
) -> None:
813+
"""Custom handler that records exceptions but does NOT raise them."""
814+
if isinstance(message, Exception):
815+
exceptions_received.append(message)
816+
817+
handler_saw_error = anyio.Event()
818+
819+
async def recording_handler(
820+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
821+
) -> None:
822+
if isinstance(message, Exception):
823+
exceptions_received.append(message)
824+
handler_saw_error.set()
825+
826+
async def mock_server():
827+
# Consume initialize request
828+
await client_to_server_receive.receive()
829+
830+
init_result = InitializeResult(
831+
protocol_version=LATEST_PROTOCOL_VERSION,
832+
capabilities=ServerCapabilities(),
833+
server_info=Implementation(name="mock-server", version="0.1.0"),
834+
)
835+
await server_to_client_send.send(
836+
SessionMessage(
837+
JSONRPCResponse(
838+
jsonrpc="2.0",
839+
id=0,
840+
result=init_result.model_dump(by_alias=True, mode="json", exclude_none=True),
841+
)
842+
)
843+
)
844+
# Consume initialized notification
845+
await client_to_server_receive.receive()
846+
847+
# Inject a transport error without replying to any pending request
848+
await server_to_client_send.send(ValueError("custom handler test"))
849+
850+
async with (
851+
ClientSession(
852+
server_to_client_receive,
853+
client_to_server_send,
854+
message_handler=recording_handler,
855+
) as session,
856+
anyio.create_task_group() as tg,
857+
client_to_server_send,
858+
client_to_server_receive,
859+
server_to_client_send,
860+
server_to_client_receive,
861+
):
862+
tg.start_soon(mock_server)
863+
await session.initialize()
864+
865+
with anyio.fail_after(5):
866+
await handler_saw_error.wait()
867+
868+
assert len(exceptions_received) == 1
869+
assert isinstance(exceptions_received[0], ValueError)
870+
assert str(exceptions_received[0]) == "custom handler test"

0 commit comments

Comments
 (0)