Skip to content

Commit 92c693b

Browse files
authored
fix: cancel in-flight handlers when transport closes in server.run() (#2306)
1 parent 883d893 commit 92c693b

File tree

3 files changed

+199
-16
lines changed

3 files changed

+199
-16
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -387,16 +387,23 @@ async def run(
387387
await stack.enter_async_context(task_support.run())
388388

389389
async with anyio.create_task_group() as tg:
390-
async for message in session.incoming_messages:
391-
logger.debug("Received message: %s", message)
392-
393-
tg.start_soon(
394-
self._handle_message,
395-
message,
396-
session,
397-
lifespan_context,
398-
raise_exceptions,
399-
)
390+
try:
391+
async for message in session.incoming_messages:
392+
logger.debug("Received message: %s", message)
393+
394+
tg.start_soon(
395+
self._handle_message,
396+
message,
397+
session,
398+
lifespan_context,
399+
raise_exceptions,
400+
)
401+
finally:
402+
# Transport closed: cancel in-flight handlers. Without this the
403+
# TG join waits for them, and when they eventually try to
404+
# respond they hit a closed write stream (the session's
405+
# _receive_loop closed it when the read stream ended).
406+
tg.cancel_scope.cancel()
400407

401408
async def _handle_message(
402409
self,
@@ -470,16 +477,32 @@ async def _handle_request(
470477
except MCPError as err:
471478
response = err.error
472479
except anyio.get_cancelled_exc_class():
473-
logger.info("Request %s cancelled - duplicate response suppressed", message.request_id)
474-
return
480+
if message.cancelled:
481+
# Client sent CancelledNotification; responder.cancel() already
482+
# sent an error response, so skip the duplicate.
483+
logger.info("Request %s cancelled - duplicate response suppressed", message.request_id)
484+
return
485+
# Transport-close cancellation from the TG in run(); re-raise so the
486+
# TG swallows its own cancellation.
487+
raise
475488
except Exception as err:
476489
if raise_exceptions: # pragma: no cover
477490
raise err
478491
response = types.ErrorData(code=0, message=str(err))
492+
else: # pragma: no cover
493+
response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found")
479494

495+
try:
480496
await message.respond(response)
481-
else: # pragma: no cover
482-
await message.respond(types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found"))
497+
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
498+
# Transport closed between handler unblocking and respond. Happens
499+
# when _receive_loop's finally wakes a handler blocked on
500+
# send_request: the handler runs to respond() before run()'s TG
501+
# cancel fires, but after the write stream closed. Closed if our
502+
# end closed (_receive_loop's async-with exit); Broken if the peer
503+
# end closed first (streamable_http terminate()).
504+
logger.debug("Response for %s dropped - transport closed", message.request_id)
505+
return
483506

484507
logger.debug("Response sent")
485508

src/mcp/shared/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __exit__(
105105
) -> None:
106106
"""Exit the context manager, performing cleanup and notifying completion."""
107107
try:
108-
if self._completed: # pragma: no branch
108+
if self._completed:
109109
self._on_complete(self)
110110
finally:
111111
self._entered = False
@@ -418,7 +418,9 @@ async def _receive_loop(self) -> None:
418418
finally:
419419
# after the read stream is closed, we need to send errors
420420
# to any pending requests
421-
for id, stream in self._response_streams.items():
421+
# Snapshot: stream.send() wakes the waiter, whose finally pops
422+
# from _response_streams before the next __next__() call.
423+
for id, stream in list(self._response_streams.items()):
422424
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
423425
try:
424426
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))

tests/server/test_cancel_handling.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@
66
from mcp import Client
77
from mcp.server import Server, ServerRequestContext
88
from mcp.shared.exceptions import MCPError
9+
from mcp.shared.message import SessionMessage
910
from mcp.types import (
11+
LATEST_PROTOCOL_VERSION,
1012
CallToolRequest,
1113
CallToolRequestParams,
1214
CallToolResult,
1315
CancelledNotification,
1416
CancelledNotificationParams,
17+
ClientCapabilities,
18+
Implementation,
19+
InitializeRequestParams,
20+
JSONRPCNotification,
21+
JSONRPCRequest,
1522
ListToolsResult,
1623
PaginatedRequestParams,
1724
TextContent,
@@ -90,3 +97,154 @@ async def first_request():
9097
assert isinstance(content, TextContent)
9198
assert content.text == "Call number: 2"
9299
assert call_count == 2
100+
101+
102+
@pytest.mark.anyio
103+
async def test_server_cancels_in_flight_handlers_on_transport_close():
104+
"""When the transport closes mid-request, server.run() must cancel in-flight
105+
handlers rather than join on them.
106+
107+
Without the cancel, the task group waits for the handler, which then tries
108+
to respond through a write stream that _receive_loop already closed,
109+
raising ClosedResourceError and crashing server.run() with exit code 1.
110+
111+
This drives server.run() with raw memory streams because InMemoryTransport
112+
wraps it in its own finally-cancel (_memory.py) which masks the bug.
113+
"""
114+
handler_started = anyio.Event()
115+
handler_cancelled = anyio.Event()
116+
server_run_returned = anyio.Event()
117+
118+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
119+
handler_started.set()
120+
try:
121+
await anyio.sleep_forever()
122+
finally:
123+
handler_cancelled.set()
124+
# unreachable: sleep_forever only exits via cancellation
125+
raise AssertionError # pragma: no cover
126+
127+
server = Server("test", on_call_tool=handle_call_tool)
128+
129+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
130+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
131+
132+
async def run_server():
133+
await server.run(server_read, server_write, server.create_initialization_options())
134+
server_run_returned.set()
135+
136+
init_req = JSONRPCRequest(
137+
jsonrpc="2.0",
138+
id=1,
139+
method="initialize",
140+
params=InitializeRequestParams(
141+
protocol_version=LATEST_PROTOCOL_VERSION,
142+
capabilities=ClientCapabilities(),
143+
client_info=Implementation(name="test", version="1.0"),
144+
).model_dump(by_alias=True, mode="json", exclude_none=True),
145+
)
146+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
147+
call_req = JSONRPCRequest(
148+
jsonrpc="2.0",
149+
id=2,
150+
method="tools/call",
151+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
152+
)
153+
154+
with anyio.fail_after(5):
155+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
156+
tg.start_soon(run_server)
157+
158+
await to_server.send(SessionMessage(init_req))
159+
await from_server.receive() # init response
160+
await to_server.send(SessionMessage(initialized))
161+
await to_server.send(SessionMessage(call_req))
162+
163+
await handler_started.wait()
164+
165+
# Close the server's input stream — this is what stdin EOF does.
166+
# server.run()'s incoming_messages loop ends, finally-cancel fires,
167+
# handler gets CancelledError, server.run() returns.
168+
await to_server.aclose()
169+
170+
await server_run_returned.wait()
171+
172+
assert handler_cancelled.is_set()
173+
174+
175+
@pytest.mark.anyio
176+
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
177+
"""When the transport closes while handlers are blocked on server→client
178+
requests (sampling, roots, elicitation), server.run() must still exit cleanly.
179+
180+
Two bugs covered:
181+
1. _receive_loop's finally iterates _response_streams with await checkpoints
182+
inside; the woken handler's send_request finally pops from that dict
183+
before the next __next__() — RuntimeError: dictionary changed size.
184+
2. The woken handler's MCPError is caught in _handle_request, which falls
185+
through to respond() against a write stream _receive_loop already closed.
186+
"""
187+
handlers_started = 0
188+
both_started = anyio.Event()
189+
server_run_returned = anyio.Event()
190+
191+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
192+
nonlocal handlers_started
193+
handlers_started += 1
194+
if handlers_started == 2:
195+
both_started.set()
196+
# Blocks on send_request waiting for a client response that never comes.
197+
# _receive_loop's finally will wake this with CONNECTION_CLOSED.
198+
await ctx.session.list_roots()
199+
raise AssertionError # pragma: no cover
200+
201+
server = Server("test", on_call_tool=handle_call_tool)
202+
203+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
204+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
205+
206+
async def run_server():
207+
await server.run(server_read, server_write, server.create_initialization_options())
208+
server_run_returned.set()
209+
210+
init_req = JSONRPCRequest(
211+
jsonrpc="2.0",
212+
id=1,
213+
method="initialize",
214+
params=InitializeRequestParams(
215+
protocol_version=LATEST_PROTOCOL_VERSION,
216+
capabilities=ClientCapabilities(),
217+
client_info=Implementation(name="test", version="1.0"),
218+
).model_dump(by_alias=True, mode="json", exclude_none=True),
219+
)
220+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
221+
222+
with anyio.fail_after(5):
223+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
224+
tg.start_soon(run_server)
225+
226+
await to_server.send(SessionMessage(init_req))
227+
await from_server.receive() # init response
228+
await to_server.send(SessionMessage(initialized))
229+
230+
# Two tool calls → two handlers → two _response_streams entries.
231+
for rid in (2, 3):
232+
call_req = JSONRPCRequest(
233+
jsonrpc="2.0",
234+
id=rid,
235+
method="tools/call",
236+
params=CallToolRequestParams(name="t", arguments={}).model_dump(by_alias=True, mode="json"),
237+
)
238+
await to_server.send(SessionMessage(call_req))
239+
240+
await both_started.wait()
241+
# Drain the two roots/list requests so send_request's _write_stream.send()
242+
# completes and both handlers are parked at response_stream_reader.receive().
243+
await from_server.receive()
244+
await from_server.receive()
245+
246+
await to_server.aclose()
247+
248+
# Without the fixes: RuntimeError (dict mutation) or ClosedResourceError
249+
# (respond after write-stream close) escapes run_server and this hangs.
250+
await server_run_returned.wait()

0 commit comments

Comments
 (0)