diff --git a/lib/json_rpc_handler.rb b/lib/json_rpc_handler.rb index d4be6a7..b788e12 100644 --- a/lib/json_rpc_handler.rb +++ b/lib/json_rpc_handler.rb @@ -92,7 +92,7 @@ def process_request(request, id_validation_pattern:, &method_finder) end begin - method = method_finder.call(method_name) + method = method_finder.call(method_name, id) if method.nil? return error_response(id: id, id_validation_pattern: id_validation_pattern, error: { diff --git a/lib/mcp/progress.rb b/lib/mcp/progress.rb index 8843a0d..6762d3c 100644 --- a/lib/mcp/progress.rb +++ b/lib/mcp/progress.rb @@ -2,9 +2,10 @@ module MCP class Progress - def initialize(notification_target:, progress_token:) + def initialize(notification_target:, progress_token:, related_request_id: nil) @notification_target = notification_target @progress_token = progress_token + @related_request_id = related_request_id end def report(progress, total: nil, message: nil) @@ -16,6 +17,7 @@ def report(progress, total: nil, message: nil) progress: progress, total: total, message: message, + related_request_id: @related_request_id, ) end end diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index 484b29b..d085b8b 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -127,8 +127,8 @@ def initialize( # When `nil`, progress and logging notifications from tool handlers are silently skipped. # @return [Hash, nil] The JSON-RPC response, or `nil` for notifications. def handle(request, session: nil) - JsonRpcHandler.handle(request) do |method| - handle_request(request, method, session: session) + JsonRpcHandler.handle(request) do |method, request_id| + handle_request(request, method, session: session, related_request_id: request_id) end end @@ -140,8 +140,8 @@ def handle(request, session: nil) # When `nil`, progress and logging notifications from tool handlers are silently skipped. # @return [String, nil] The JSON-RPC response as JSON, or `nil` for notifications. def handle_json(request, session: nil) - JsonRpcHandler.handle_json(request) do |method| - handle_request(request, method, session: session) + JsonRpcHandler.handle_json(request) do |method, request_id| + handle_request(request, method, session: session, related_request_id: request_id) end end @@ -220,7 +220,8 @@ def create_sampling_message( stop_sequences: nil, metadata: nil, tools: nil, - tool_choice: nil + tool_choice: nil, + related_request_id: nil ) unless @transport raise "Cannot send sampling request without a transport." @@ -371,7 +372,7 @@ def schema_contains_ref?(schema) end end - def handle_request(request, method, session: nil) + def handle_request(request, method, session: nil, related_request_id: nil) handler = @handlers[method] unless handler instrument_call("unsupported_method") do @@ -399,7 +400,7 @@ def handle_request(request, method, session: nil) when Methods::RESOURCES_TEMPLATES_LIST { resourceTemplates: @handlers[Methods::RESOURCES_TEMPLATES_LIST].call(params) } when Methods::TOOLS_CALL - call_tool(params, session: session) + call_tool(params, session: session, related_request_id: related_request_id) when Methods::COMPLETION_COMPLETE complete(params) when Methods::LOGGING_SET_LEVEL @@ -499,7 +500,7 @@ def list_tools(request) @tools.values.map(&:to_h) end - def call_tool(request, session: nil) + def call_tool(request, session: nil, related_request_id: nil) tool_name = request[:name] tool = tools[tool_name] @@ -531,7 +532,7 @@ def call_tool(request, session: nil) progress_token = request.dig(:_meta, :progressToken) - call_tool_with_args(tool, arguments, server_context_with_meta(request), progress_token: progress_token, session: session) + call_tool_with_args(tool, arguments, server_context_with_meta(request), progress_token: progress_token, session: session, related_request_id: related_request_id) rescue RequestHandlerError raise rescue => e @@ -611,12 +612,12 @@ def accepts_server_context?(method_object) parameters.any? { |type, name| type == :keyrest || name == :server_context } end - def call_tool_with_args(tool, arguments, context, progress_token: nil, session: nil) + def call_tool_with_args(tool, arguments, context, progress_token: nil, session: nil, related_request_id: nil) args = arguments&.transform_keys(&:to_sym) || {} if accepts_server_context?(tool.method(:call)) - progress = Progress.new(notification_target: session, progress_token: progress_token) - server_context = ServerContext.new(context, progress: progress, notification_target: session) + progress = Progress.new(notification_target: session, progress_token: progress_token, related_request_id: related_request_id) + server_context = ServerContext.new(context, progress: progress, notification_target: session, related_request_id: related_request_id) tool.call(**args, server_context: server_context).to_h else tool.call(**args).to_h diff --git a/lib/mcp/server/transports/streamable_http_transport.rb b/lib/mcp/server/transports/streamable_http_transport.rb index 31ddc89..688e38c 100644 --- a/lib/mcp/server/transports/streamable_http_transport.rb +++ b/lib/mcp/server/transports/streamable_http_transport.rb @@ -7,6 +7,12 @@ module MCP class Server module Transports class StreamableHTTPTransport < Transport + SSE_HEADERS = { + "Content-Type" => "text/event-stream", + "Cache-Control" => "no-cache", + "Connection" => "keep-alive", + }.freeze + def initialize(server, stateless: false, session_idle_timeout: nil) super(server) # Maps `session_id` to `{ stream: stream_object, server_session: ServerSession, last_active_at: float_from_monotonic_clock }`. @@ -56,10 +62,11 @@ def close removed_sessions.each do |session| close_stream_safely(session[:stream]) + close_post_request_streams(session) end end - def send_notification(method, params = nil, session_id: nil) + def send_notification(method, params = nil, session_id: nil, related_request_id: nil) # Stateless mode doesn't support notifications raise "Stateless mode does not support notifications" if @stateless @@ -74,8 +81,10 @@ def send_notification(method, params = nil, session_id: nil) result = @mutex.synchronize do if session_id # Send to specific session - session = @sessions[session_id] - next false unless session && session[:stream] + if (session = @sessions[session_id]) + stream = active_stream(session, related_request_id: related_request_id) + end + next false unless stream if session_expired?(session) cleanup_and_collect_stream(session_id, streams_to_close) @@ -83,14 +92,19 @@ def send_notification(method, params = nil, session_id: nil) end begin - send_to_stream(session[:stream], notification) + send_to_stream(stream, notification) true rescue *STREAM_WRITE_ERRORS => e MCP.configuration.exception_reporter.call( e, { session_id: session_id, error: "Failed to send notification" }, ) - cleanup_and_collect_stream(session_id, streams_to_close) + if related_request_id && session[:post_request_streams]&.key?(related_request_id) + session[:post_request_streams].delete(related_request_id) + streams_to_close << stream + else + cleanup_and_collect_stream(session_id, streams_to_close) + end false end else @@ -99,7 +113,7 @@ def send_notification(method, params = nil, session_id: nil) failed_sessions = [] @sessions.each do |sid, session| - next unless session[:stream] + next unless (stream = session[:stream]) if session_expired?(session) failed_sessions << sid @@ -107,7 +121,7 @@ def send_notification(method, params = nil, session_id: nil) end begin - send_to_stream(session[:stream], notification) + send_to_stream(stream, notification) sent_count += 1 rescue *STREAM_WRITE_ERRORS => e MCP.configuration.exception_reporter.call( @@ -139,7 +153,7 @@ def send_notification(method, params = nil, session_id: nil) # sends the request via SSE stream, then blocks on `queue.pop`. # When the client POSTs a response, `handle_response` matches it by `request_id` # and pushes the result onto the queue, unblocking this thread. - def send_request(method, params = nil, session_id: nil) + def send_request(method, params = nil, session_id: nil, related_request_id: nil) if @stateless raise "Stateless mode does not support server-to-client requests." end @@ -163,12 +177,17 @@ def send_request(method, params = nil, session_id: nil) @pending_responses[request_id] = { queue: queue, session_id: session_id } - if (stream = session[:stream]) + if (stream = active_stream(session, related_request_id: related_request_id)) begin send_to_stream(stream, request) sent = true rescue *STREAM_WRITE_ERRORS - cleanup_session_unsafe(session_id) + if related_request_id && session[:post_request_streams]&.key?(related_request_id) + session[:post_request_streams].delete(related_request_id) + close_stream_safely(stream) + else + cleanup_session_unsafe(session_id) + end end end end @@ -181,7 +200,7 @@ def send_request(method, params = nil, session_id: nil) # The TypeScript and Python SDKs buffer messages and replay on reconnect. # Until then, raise to prevent queue.pop from blocking indefinitely. unless sent - raise "No active SSE stream for #{method} request." + raise "No active stream for #{method} request." end response = queue.pop @@ -229,6 +248,7 @@ def reap_expired_sessions removed_sessions.each do |session| close_stream_safely(session[:stream]) + close_post_request_streams(session) end end @@ -265,7 +285,7 @@ def handle_post(request) handle_response(body, session_id: session_id) else - handle_regular_request(body_string, session_id) + handle_regular_request(body_string, session_id, related_request_id: body[:id]) end end rescue StandardError => e @@ -313,7 +333,10 @@ def cleanup_session(session_id) cleanup_session_unsafe(session_id) end - close_stream_safely(session[:stream]) if session + if session + close_stream_safely(session[:stream]) + close_post_request_streams(session) + end end # Removes a session from `@sessions` and returns it. Does not close the stream. @@ -336,6 +359,7 @@ def cleanup_and_collect_stream(session_id, streams_to_close) return unless (removed = cleanup_session_unsafe(session_id)) streams_to_close << removed[:stream] + removed[:post_request_streams]&.each_value { |stream| streams_to_close << stream } end def close_stream_safely(stream) @@ -344,6 +368,14 @@ def close_stream_safely(stream) # Ignore close-related errors from already closed/broken streams. end + def close_post_request_streams(session) + return unless (post_request_streams = session[:post_request_streams]) + + post_request_streams.each_value do |stream| + close_stream_safely(stream) + end + end + def extract_session_id(request) request.env["HTTP_MCP_SESSION_ID"] end @@ -443,9 +475,8 @@ def handle_accepted [202, {}, []] end - def handle_regular_request(body_string, session_id) + def handle_regular_request(body_string, session_id, related_request_id: nil) server_session = nil - stream = nil unless @stateless if session_id @@ -455,21 +486,72 @@ def handle_regular_request(body_string, session_id) @mutex.synchronize do session = @sessions[session_id] server_session = session[:server_session] if session - stream = session[:stream] if session end end end - response = if server_session - server_session.handle_json(body_string) + if session_id && !@stateless + handle_request_with_sse_response(body_string, session_id, server_session, related_request_id: related_request_id) else - @server.handle_json(body_string) + response = dispatch_handle_json(body_string, server_session) + [200, { "Content-Type" => "application/json" }, [response]] end + end + + # Returns the POST response as an SSE stream so the server can send + # JSON-RPC requests and notifications during request processing. + # https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server + def handle_request_with_sse_response(body_string, session_id, server_session, related_request_id: nil) + body = proc do |stream| + @mutex.synchronize do + session = @sessions[session_id] + if session && related_request_id + session[:post_request_streams] ||= {} + session[:post_request_streams][related_request_id] = stream + end + end - if stream - send_response_to_stream(stream, response, session_id) + begin + response = dispatch_handle_json(body_string, server_session) + + send_to_stream(stream, response) if response + ensure + if related_request_id + @mutex.synchronize do + session = @sessions[session_id] + session[:post_request_streams]&.delete(related_request_id) if session + end + end + + begin + stream.close + rescue StandardError + # Ignore close-related errors from already closed/broken streams. + end + end + end + + [200, SSE_HEADERS, body] + end + + # Returns the SSE stream available for server-to-client messages. + # When `related_request_id` is given, returns only the POST response + # stream for that request (no fallback to GET SSE). This prevents + # request-scoped messages from leaking to the wrong stream. + # When `related_request_id` is nil, returns the GET SSE stream. + def active_stream(session, related_request_id: nil) + if related_request_id + session.dig(:post_request_streams, related_request_id) else - [200, { "Content-Type" => "application/json" }, [response]] + session[:stream] + end + end + + def dispatch_handle_json(body_string, server_session) + if server_session + server_session.handle_json(body_string) + else + @server.handle_json(body_string) end end @@ -489,7 +571,13 @@ def validate_and_touch_session(session_id) nil end - close_stream_safely(removed[:stream]) if removed + if removed + close_stream_safely(removed[:stream]) + + removed[:post_request_streams]&.each_value do |stream| + close_stream_safely(stream) + end + end response end @@ -498,19 +586,6 @@ def get_session_stream(session_id) @mutex.synchronize { @sessions[session_id]&.fetch(:stream, nil) } end - def send_response_to_stream(stream, response, session_id) - message = JSON.parse(response) - send_to_stream(stream, message) - handle_accepted - rescue *STREAM_WRITE_ERRORS => e - MCP.configuration.exception_reporter.call( - e, - { session_id: session_id, error: "Stream closed during response" }, - ) - cleanup_session(session_id) - [200, { "Content-Type" => "application/json" }, [response]] - end - def session_exists?(session_id) @mutex.synchronize { @sessions.key?(session_id) } end @@ -538,13 +613,7 @@ def session_already_connected_response def setup_sse_stream(session_id) body = create_sse_body(session_id) - headers = { - "Content-Type" => "text/event-stream", - "Cache-Control" => "no-cache", - "Connection" => "keep-alive", - } - - [200, headers, body] + [200, SSE_HEADERS, body] end def create_sse_body(session_id) diff --git a/lib/mcp/server_context.rb b/lib/mcp/server_context.rb index b532555..aadd750 100644 --- a/lib/mcp/server_context.rb +++ b/lib/mcp/server_context.rb @@ -2,10 +2,11 @@ module MCP class ServerContext - def initialize(context, progress:, notification_target:) + def initialize(context, progress:, notification_target:, related_request_id: nil) @context = context @progress = progress @notification_target = notification_target + @related_request_id = related_request_id end # Reports progress for the current tool operation. @@ -26,7 +27,7 @@ def report_progress(progress, total: nil, message: nil) def notify_log_message(data:, level:, logger: nil) return unless @notification_target - @notification_target.notify_log_message(data: data, level: level, logger: logger) + @notification_target.notify_log_message(data: data, level: level, logger: logger, related_request_id: @related_request_id) end # Delegates to the session so the request is scoped to the originating client. @@ -34,9 +35,9 @@ def notify_log_message(data:, level:, logger: nil) # does not support sampling. def create_sampling_message(**kwargs) if @notification_target.respond_to?(:create_sampling_message) - @notification_target.create_sampling_message(**kwargs) + @notification_target.create_sampling_message(**kwargs, related_request_id: @related_request_id) elsif @context.respond_to?(:create_sampling_message) - @context.create_sampling_message(**kwargs) + @context.create_sampling_message(**kwargs, related_request_id: @related_request_id) else raise NoMethodError, "undefined method 'create_sampling_message' for #{self}" end diff --git a/lib/mcp/server_session.rb b/lib/mcp/server_session.rb index 93e823f..2fe8f77 100644 --- a/lib/mcp/server_session.rb +++ b/lib/mcp/server_session.rb @@ -42,13 +42,13 @@ def client_capabilities end # Sends a `sampling/createMessage` request scoped to this session. - def create_sampling_message(**kwargs) + def create_sampling_message(related_request_id: nil, **kwargs) params = @server.build_sampling_params(client_capabilities, **kwargs) - send_to_transport_request(Methods::SAMPLING_CREATE_MESSAGE, params) + send_to_transport_request(Methods::SAMPLING_CREATE_MESSAGE, params, related_request_id: related_request_id) end # Sends a progress notification to this session only. - def notify_progress(progress_token:, progress:, total: nil, message: nil) + def notify_progress(progress_token:, progress:, total: nil, message: nil, related_request_id: nil) params = { "progressToken" => progress_token, "progress" => progress, @@ -56,20 +56,20 @@ def notify_progress(progress_token:, progress:, total: nil, message: nil) "message" => message, }.compact - send_to_transport(Methods::NOTIFICATIONS_PROGRESS, params) + send_to_transport(Methods::NOTIFICATIONS_PROGRESS, params, related_request_id: related_request_id) rescue => e @server.report_exception(e, notification: "progress") end # Sends a log message notification to this session only. - def notify_log_message(data:, level:, logger: nil) + def notify_log_message(data:, level:, logger: nil, related_request_id: nil) effective_logging = @logging_message_notification || @server.logging_message_notification return unless effective_logging&.should_notify?(level) params = { "data" => data, "level" => level } params["logger"] = logger if logger - send_to_transport(Methods::NOTIFICATIONS_MESSAGE, params) + send_to_transport(Methods::NOTIFICATIONS_MESSAGE, params, related_request_id: related_request_id) rescue => e @server.report_exception(e, { notification: "log_message" }) end @@ -82,9 +82,9 @@ def notify_log_message(data:, level:, logger: nil) # TODO: When Ruby 2.7 support is dropped, replace with a direct call: # `@transport.send_notification(method, params, session_id: @session_id)` and # add `**` to `Transport#send_notification` and `StdioTransport#send_notification`. - def send_to_transport(method, params) + def send_to_transport(method, params, related_request_id: nil) if @session_id - @transport.send_notification(method, params, session_id: @session_id) + @transport.send_notification(method, params, session_id: @session_id, related_request_id: related_request_id) else @transport.send_notification(method, params) end @@ -96,9 +96,9 @@ def send_to_transport(method, params) # TODO: When Ruby 2.7 support is dropped, replace with a direct call: # `@transport.send_request(method, params, session_id: @session_id)` and # add `**` to `Transport#send_request` and `StdioTransport#send_request`. - def send_to_transport_request(method, params) + def send_to_transport_request(method, params, related_request_id: nil) if @session_id - @transport.send_request(method, params, session_id: @session_id) + @transport.send_request(method, params, session_id: @session_id, related_request_id: related_request_id) else @transport.send_request(method, params) end diff --git a/test/json_rpc_handler_test.rb b/test/json_rpc_handler_test.rb index 67f47db..169b1fd 100644 --- a/test/json_rpc_handler_test.rb +++ b/test/json_rpc_handler_test.rb @@ -621,7 +621,7 @@ @response = JsonRpcHandler.handle( { jsonrpc: "2.0", id: "user@example.com", method: "add", params: { a: 1, b: 2 } }, id_validation_pattern: custom_pattern, - ) { |method_name| @registry[method_name] } + ) { |method_name, _request_id| @registry[method_name] } assert_rpc_success expected_result: 3 assert_equal "user@example.com", @response[:id] @@ -633,7 +633,7 @@ @response = JsonRpcHandler.handle( { jsonrpc: "2.0", id: "id", method: "add", params: { a: 1, b: 2 } }, id_validation_pattern: nil, - ) { |method_name| @registry[method_name] } + ) { |method_name, _request_id| @registry[method_name] } assert_rpc_success expected_result: 3 assert_equal "", @response[:id] @@ -733,11 +733,11 @@ def register(method_name, &block) end def handle(request) - @response = JsonRpcHandler.handle(request) { |method_name| @registry[method_name] } + @response = JsonRpcHandler.handle(request) { |method_name, _request_id| @registry[method_name] } end def handle_json(request_json) - @response_json = JsonRpcHandler.handle_json(request_json) { |method_name| @registry[method_name] } + @response_json = JsonRpcHandler.handle_json(request_json) { |method_name, _request_id| @registry[method_name] } @response = JSON.parse(@response_json, symbolize_names: true) if @response_json end diff --git a/test/mcp/server/transports/streamable_http_transport_test.rb b/test/mcp/server/transports/streamable_http_transport_test.rb index 507ac45..7435289 100644 --- a/test/mcp/server/transports/streamable_http_transport_test.rb +++ b/test/mcp/server/transports/streamable_http_transport_test.rb @@ -7,6 +7,31 @@ module MCP class Server module Transports class StreamableHTTPTransportTest < ActiveSupport::TestCase + # A stream that buffers writes and remains readable after close. + class TestStream + def initialize + @buffer = "".dup + @closed = false + end + + def write(data) + raise IOError, "closed stream" if @closed + + @buffer << data + end + + def flush + end + + def close + @closed = true + end + + def string + @buffer + end + end + setup do @server = Server.new( name: "test_server", @@ -45,9 +70,11 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase response = @transport.handle_request(request) assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) + assert_equal "text/event-stream", response[1]["Content-Type"] - body = JSON.parse(response[2][0]) + io = StringIO.new + response[2].call(io) + body = JSON.parse(io.string.match(/^data: (.+)$/)[1]) assert_equal "2.0", body["jsonrpc"] assert_equal "123", body["id"] assert_equal({}, body["result"]) @@ -114,8 +141,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert response[2].is_a?(Proc) # The body should be a Proc for streaming end - test "handles POST request when IOError raised" do - # Create and initialize a session + test "handles POST request as SSE even when GET SSE stream is closed" do init_request = create_rack_request( "POST", "/", @@ -125,7 +151,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Connect with SSE + # Connect with SSE then close it io = StringIO.new get_request = create_rack_request( "GET", @@ -134,13 +160,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(io) if response[2].is_a?(Proc) - - # Give the stream time to set up sleep(0.1) - - # Close the stream io.close + # POST request should still return SSE response via POST response stream request = create_rack_request( "POST", "/", @@ -151,17 +174,12 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, ) - # This should handle IOError and return the original response response = @transport.handle_request(request) assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) - - # Verify session was cleaned up - assert_not @transport.instance_variable_get(:@sessions).key?(session_id) + assert_equal "text/event-stream", response[1]["Content-Type"] end - test "handles POST request when Errno::EPIPE raised" do - # Create and initialize a session + test "handles POST request as SSE even when GET SSE stream has EPIPE" do init_request = create_rack_request( "POST", "/", @@ -171,10 +189,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Create a pipe to simulate EPIPE condition + # Connect GET SSE with a broken pipe reader, writer = IO.pipe - - # Connect with SSE using the writer end of the pipe get_request = create_rack_request( "GET", "/", @@ -182,13 +198,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(writer) if response[2].is_a?(Proc) - - # Give the stream time to set up sleep(0.1) - - # Close the reader end to break the pipe - this will cause EPIPE on write reader.close + # POST request should still return SSE response via POST response stream request = create_rack_request( "POST", "/", @@ -199,23 +212,18 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "789" }.to_json, ) - # This should handle Errno::EPIPE and return the original response response = @transport.handle_request(request) - assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) - - # Verify session was cleaned up - assert_not @transport.instance_variable_get(:@sessions).key?(session_id) - + assert_equal(200, response[0]) + assert_equal("text/event-stream", response[1]["Content-Type"]) + ensure begin writer.close - rescue + rescue StandardError nil end end - test "handles POST request when Errno::ECONNRESET raised" do - # Create and initialize a session. + test "handles POST request as SSE even when GET SSE stream has ECONNRESET" do init_request = create_rack_request( "POST", "/", @@ -225,12 +233,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Use a mock stream that raises Errno::ECONNRESET on write. + # Connect GET SSE with a mock that raises ECONNRESET mock_stream = Object.new mock_stream.define_singleton_method(:write) { |_data| raise Errno::ECONNRESET } mock_stream.define_singleton_method(:close) {} - - # Connect with SSE using the mock stream. get_request = create_rack_request( "GET", "/", @@ -238,10 +244,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(mock_stream) if response[2].is_a?(Proc) - - # Give the stream time to set up. sleep(0.1) + # POST request should still return SSE response via POST response stream request = create_rack_request( "POST", "/", @@ -252,13 +257,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "789" }.to_json, ) - # This should handle Errno::ECONNRESET and return the original response. response = @transport.handle_request(request) assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) - - # Verify session was cleaned up. - assert_not @transport.instance_variable_get(:@sessions).key?(session_id) + assert_equal "text/event-stream", response[1]["Content-Type"] end test "handles GET request with missing session ID" do @@ -579,6 +580,54 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_equal({}, @transport.instance_variable_get(:@sessions)) end + test "cleanup_session_unsafe closes request_streams" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Simulate multiple request_streams being set on the session. + closed = [] + 2.times do |i| + mock_stream = Object.new + mock_stream.define_singleton_method(:close) { closed << i } + thread = Thread.new {} + thread.join + @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams] ||= {} + @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams][thread] = mock_stream + end + + delete_request = create_rack_request( + "DELETE", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + @transport.handle_request(delete_request) + + assert_equal [0, 1], closed.sort + assert_empty @transport.instance_variable_get(:@sessions) + end + + test "broadcast notification skips sessions without GET SSE stream" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + @transport.handle_request(init_request) + + # No GET SSE stream connected, only request_streams. + # Pass **{} to prevent Ruby 2.7 from converting the Hash to keyword arguments. + result = @transport.send_notification("test/notify", { message: "hello" }, **{}) + + assert_equal 0, result + end + test "sends notification to correct session with multiple active sessions" do # Create first session init_request1 = create_rack_request( @@ -653,8 +702,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase result end - # Handle request from session 1 - @transport.handle_request(request_as_session1) + # Handle request from session 1 (execute SSE proc) + response1 = @transport.handle_request(request_as_session1) + response1[2].call(StringIO.new) if response1[2].is_a?(Proc) # Make a request as session 2 request_as_session2 = create_rack_request( @@ -667,18 +717,17 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "890" }.to_json, ) - # Handle request from session 2 - @transport.handle_request(request_as_session2) + # Handle request from session 2 (execute SSE proc) + response2_post = @transport.handle_request(request_as_session2) + response2_post[2].call(StringIO.new) if response2_post[2].is_a?(Proc) - # Check that each session received one notification + # Broadcast notifications are sent to GET SSE streams (no related_request_id) io1.rewind output1 = io1.read - # Session 1 should have received two notifications (one from each request since we broadcast) assert_equal 2, output1.scan(/data: {"jsonrpc":"2.0","method":"test_notification","params":{"session":"current"}}/).count io2.rewind output2 = io2.read - # Session 2 should have received two notifications (one from each request since we broadcast) assert_equal 2, output2.scan(/data: {"jsonrpc":"2.0","method":"test_notification","params":{"session":"current"}}/).count end @@ -888,6 +937,85 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_not @transport.instance_variable_get(:@sessions).key?(session_id) end + test "send_notification on broken request_stream removes only that stream, not the session" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect GET SSE. + io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = @transport.handle_request(get_request) + response[2].call(io) if response[2].is_a?(Proc) + sleep(0.1) + + # Simulate a broken request_stream. + broken_stream = Object.new + broken_stream.define_singleton_method(:write) { |_data| raise Errno::EPIPE } + broken_stream.define_singleton_method(:close) {} + related_id = "req-1" + @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams] = { related_id => broken_stream } + + result = @transport.send_notification("test", { msg: "hello" }, session_id: session_id, related_request_id: related_id) + + refute result + # Session should still exist. + assert @transport.instance_variable_get(:@sessions).key?(session_id) + # The broken request_stream should be removed. + refute @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams].key?(related_id) + # GET SSE stream should still be intact. + assert @transport.instance_variable_get(:@sessions)[session_id][:stream] + end + + test "active_stream does not fall back to GET SSE when related_request_id is given but request_stream is missing" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect GET SSE. + io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = @transport.handle_request(get_request) + response[2].call(io) if response[2].is_a?(Proc) + sleep(0.1) + + # Send notification with a related_request_id that has no matching request_stream. + result = @transport.send_notification( + "test/notify", + { message: "should not arrive" }, + session_id: session_id, + related_request_id: "nonexistent-request-id", + ) + + # Should return false because no matching request_stream exists. + refute result + + # Session should still exist (not cleaned up). + assert @transport.instance_variable_get(:@sessions).key?(session_id) + + # GET SSE stream should NOT have received the notification. + io.rewind + refute_includes io.read, "should not arrive" + end + test "send_notification broadcast continues when one session raises Errno::ECONNRESET" do # Create two sessions. init_request1 = create_rack_request( @@ -1334,8 +1462,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_nil(body) end - test "send_response_to_stream returns 202 when message is sent to stream" do - # Create and initialize a session + test "POST request returns SSE response even with GET SSE connected" do init_request = create_rack_request( "POST", "/", @@ -1345,7 +1472,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Connect with SSE + # Connect with GET SSE io = StringIO.new get_request = create_rack_request( "GET", @@ -1354,11 +1481,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(io) if response[2].is_a?(Proc) - - # Give the stream time to set up sleep(0.1) - # Make a regular request that will be routed through send_response_to_stream + # POST request should return SSE, not 202 request = create_rack_request( "POST", "/", @@ -1370,9 +1495,13 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(request) - assert_equal 202, response[0] - assert_empty response[1] - assert_empty response[2] + assert_equal 200, response[0] + assert_equal "text/event-stream", response[1]["Content-Type"] + + post_io = StringIO.new + response[2].call(post_io) + body = JSON.parse(post_io.string.match(/^data: (.+)$/)[1]) + assert_equal "456", body["id"] end test "handle post request with a standard error" do @@ -1436,7 +1565,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase @transport.send_request("sampling/createMessage", { "messages" => [] }, session_id: session_id) end - assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + assert_equal("No active stream for sampling/createMessage request.", error.message) end test "send_request sends via SSE and waits for response" do @@ -1683,6 +1812,300 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_equal("SSE session closed while waiting for sampling/createMessage response.", error.message) end + test "send_request sends via POST response stream even with GET SSE connected" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect GET SSE. + get_io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + get_response = @transport.handle_request(get_request) + get_response[2].call(get_io) if get_response[2].is_a?(Proc) + sleep(0.1) + + # Set up sampling capability for the session. + @transport.instance_variable_get(:@sessions)[session_id][:server_session] + .store_client_info(client: { name: "test" }, capabilities: { sampling: {} }) + + # Define a tool that calls create_sampling_message. + sampling_tool = MCP::Tool.define( + name: "sampling_tool", + input_schema: { properties: { prompt: { type: "string" } }, required: ["prompt"] }, + ) do |prompt:, server_context:| + result = server_context.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: prompt } }], + max_tokens: 100, + ) + MCP::Tool::Response.new([{ type: "text", text: result[:content][:text] }]) + end + @server.tools[sampling_tool.name_value] = sampling_tool + + # Send tools/call via POST (GET SSE is connected). + tool_request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "sampling_tool", arguments: { prompt: "Hello" } }, + }.to_json, + ) + + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + response = @transport.handle_request(tool_request) + response[2].call(post_stream) + result_queue.push(:done) + end + + sleep(0.2) + + # Sampling request should be in POST response stream, not GET SSE. + output = post_stream.string + data_lines = output.lines.select { |line| line.start_with?("data: ") } + sampling_request = JSON.parse(data_lines.first.sub("data: ", "")) + assert_equal "sampling/createMessage", sampling_request["method"] + + # GET SSE should NOT have the sampling request. + get_io.rewind + refute_includes get_io.read, "sampling/createMessage" + + # Simulate client sending sampling result via POST. + client_response = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: sampling_request["id"], + result: { role: "assistant", content: { type: "text", text: "Hi from LLM" } }, + }.to_json, + ) + @transport.handle_request(client_response) + + result_queue.pop + + tool_response_lines = post_stream.string.lines.select { |line| line.start_with?("data: ") } + tool_response = JSON.parse(tool_response_lines.last.sub("data: ", "")) + assert_equal "tool-1", tool_response["id"] + assert_includes tool_response["result"]["content"].first["text"], "Hi from LLM" + end + + test "send_request sends via POST response stream when no GET SSE stream" do + # Create session without connecting GET SSE. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Set up sampling capability for the session. + @transport.instance_variable_get(:@sessions)[session_id][:server_session] + .store_client_info(client: { name: "test" }, capabilities: { sampling: {} }) + + # Define a tool that calls create_sampling_message. + sampling_tool = MCP::Tool.define( + name: "sampling_tool", + input_schema: { properties: { prompt: { type: "string" } }, required: ["prompt"] }, + ) do |prompt:, server_context:| + result = server_context.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: prompt } }], + max_tokens: 100, + ) + MCP::Tool::Response.new([{ type: "text", text: result[:content][:text] }]) + end + @server.tools[sampling_tool.name_value] = sampling_tool + + # Send tools/call via POST (no GET SSE stream). + tool_request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "sampling_tool", arguments: { prompt: "Hello" } }, + }.to_json, + ) + + # Process in background since handle_request blocks until tool completes. + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + response = @transport.handle_request(tool_request) + response[2].call(post_stream) + result_queue.push(:done) + end + + sleep(0.2) # Wait for the tool to start and send sampling request. + + # Read the sampling request from the POST response stream. + output = post_stream.string + data_lines = output.lines.select { |line| line.start_with?("data: ") } + sampling_request = JSON.parse(data_lines.first.sub("data: ", "")) + assert_equal "sampling/createMessage", sampling_request["method"] + + # Simulate client sending sampling result via POST. + client_response = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: sampling_request["id"], + result: { role: "assistant", content: { type: "text", text: "Hi from LLM" } }, + }.to_json, + ) + @transport.handle_request(client_response) + + result_queue.pop # Wait for tool to complete. + + # Verify the tool result was written to the POST response stream. + tool_response_lines = post_stream.string.lines.select { |line| line.start_with?("data: ") } + tool_response = JSON.parse(tool_response_lines.last.sub("data: ", "")) + assert_equal "tool-1", tool_response["id"] + assert_includes tool_response["result"]["content"].first["text"], "Hi from LLM" + end + + test "send_notification uses POST response stream when no GET SSE stream" do + # Create session without connecting GET SSE. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Define a tool that sends a notification during execution. + notification_sent = Queue.new + slow_tool = MCP::Tool.define( + name: "slow_tool", + ) do |server_context:| + server_context.notify_log_message(data: "test log", level: "info") + notification_sent.push(true) + MCP::Tool::Response.new([{ type: "text", text: "done" }]) + end + @server.tools[slow_tool.name_value] = slow_tool + + # Configure logging so notifications are sent. + @transport.instance_variable_get(:@sessions)[session_id][:server_session] + .configure_logging(MCP::LoggingMessageNotification.new(level: "debug")) + + # Send tools/call via POST (no GET SSE stream). + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "slow_tool", arguments: {} }, + }.to_json, + ) + response = @transport.handle_request(request) + response[2].call(post_stream) + result_queue.push(:done) + end + + notification_sent.pop # Wait for tool to send notification. + result_queue.pop + + # Verify notification was written to the POST response stream. + assert_includes post_stream.string, "notifications/message" + assert_includes post_stream.string, "test log" + end + + test "progress notification uses POST response stream when no GET SSE stream" do + # Create session without connecting GET SSE. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Define a tool that reports progress during execution. + progress_reported = Queue.new + progress_tool = MCP::Tool.define( + name: "progress_tool", + ) do |server_context:| + server_context.report_progress(50, total: 100, message: "halfway") + progress_reported.push(true) + MCP::Tool::Response.new([{ type: "text", text: "done" }]) + end + @server.tools[progress_tool.name_value] = progress_tool + + # Send tools/call via POST (no GET SSE stream) with a progress token. + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "progress_tool", arguments: {}, _meta: { progressToken: "token-1" } }, + }.to_json, + ) + response = @transport.handle_request(request) + response[2].call(post_stream) + result_queue.push(:done) + end + + progress_reported.pop + result_queue.pop + + # Verify progress notification was written to the POST response stream. + assert_includes post_stream.string, "notifications/progress" + assert_includes post_stream.string, "token-1" + end + test "POST notifications/initialized returns 202 with no body" do # Create a session first (optional for notification, but keep consistent with flow) init_request = create_rack_request( @@ -2228,13 +2651,16 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase params: { name: "log_tool", arguments: {} }, }.to_json, ) - transport.handle_request(tool_request) + tool_response = transport.handle_request(tool_request) + post_io = StringIO.new + tool_response[2].call(post_io) - # Session 1 should receive the log notification. - io1.rewind - assert_includes io1.read, "secret" + # Session 1's POST response stream should contain the log notification. + assert_includes post_io.string, "secret" - # Session 2 should NOT receive the log notification. + # GET SSE streams should NOT receive the log notification. + io1.rewind + refute_includes io1.read, "secret" io2.rewind refute_includes io2.read, "secret" end @@ -2306,13 +2732,16 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase }, }.to_json, ) - transport.handle_request(tool_request) + tool_response = transport.handle_request(tool_request) + post_io = StringIO.new + tool_response[2].call(post_io) - # Session 1 should receive the progress notification. - io1.rewind - assert_includes io1.read, "halfway" + # Session 1's POST response stream should contain the progress notification. + assert_includes post_io.string, "halfway" - # Session 2 should NOT receive the progress notification. + # GET SSE streams should NOT receive the progress notification. + io1.rewind + refute_includes io1.read, "halfway" io2.rewind refute_includes io2.read, "halfway" end @@ -2406,7 +2835,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase params: { level: "error" }, }.to_json, ) - transport.handle_request(set_level1) + response1 = transport.handle_request(set_level1) + response1[2].call(StringIO.new) # Session 2 sets log level to "debug". set_level2 = create_rack_request( @@ -2420,7 +2850,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase params: { level: "debug" }, }.to_json, ) - transport.handle_request(set_level2) + response2 = transport.handle_request(set_level2) + response2[2].call(StringIO.new) # Session 1 (error level) should not notify for "info", but should for "error". session1_logging = transport.instance_variable_get(:@sessions)[session1][:server_session].logging_message_notification diff --git a/test/mcp/server_context_test.rb b/test/mcp/server_context_test.rb index 605e385..81735d6 100644 --- a/test/mcp/server_context_test.rb +++ b/test/mcp/server_context_test.rb @@ -46,6 +46,7 @@ class ServerContextTest < ActiveSupport::TestCase notification_target.expects(:create_sampling_message).with( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, + related_request_id: nil, ).returns({ role: "assistant", content: { type: "text", text: "Hi" } }) context = mock @@ -67,6 +68,7 @@ class ServerContextTest < ActiveSupport::TestCase context.expects(:create_sampling_message).with( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, + related_request_id: nil, ).returns({ role: "assistant", content: { type: "text", text: "Fallback" } }) progress = Progress.new(notification_target: notification_target, progress_token: nil) diff --git a/test/mcp/server_sampling_test.rb b/test/mcp/server_sampling_test.rb index 57c488d..bf250e4 100644 --- a/test/mcp/server_sampling_test.rb +++ b/test/mcp/server_sampling_test.rb @@ -260,7 +260,7 @@ def close; end max_tokens: 100, ) end - assert_equal("No active SSE stream for sampling/createMessage request.", error_with_sampling.message) + assert_equal("No active stream for sampling/createMessage request.", error_with_sampling.message) # Session without sampling capability should be rejected. session_without_sampling = ServerSession.new(server: @server, transport: transport, session_id: "s2") @@ -290,7 +290,7 @@ def close; end max_tokens: 100, ) end - assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + assert_equal("No active stream for sampling/createMessage request.", error.message) end test "session init does not overwrite server global client_capabilities" do @@ -375,7 +375,18 @@ def close; end max_tokens: 100, ) end - assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + assert_equal("No active stream for sampling/createMessage request.", error.message) + end + + test "Server#create_sampling_message accepts related_request_id without error" do + @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + related_request_id: "req-1", + ) + + request = @mock_transport.requests.first + assert_equal "sampling/createMessage", request[:method] end test "create_sampling_message omits nil optional params" do