diff --git a/lib/json_rpc_handler.rb b/lib/json_rpc_handler.rb
index d4be6a7..b788e12 100644
--- a/lib/json_rpc_handler.rb
+++ b/lib/json_rpc_handler.rb
@@ -92,7 +92,7 @@ def process_request(request, id_validation_pattern:, &method_finder)
end
begin
- method = method_finder.call(method_name)
+ method = method_finder.call(method_name, id)
if method.nil?
return error_response(id: id, id_validation_pattern: id_validation_pattern, error: {
diff --git a/lib/mcp/progress.rb b/lib/mcp/progress.rb
index 8843a0d..6762d3c 100644
--- a/lib/mcp/progress.rb
+++ b/lib/mcp/progress.rb
@@ -2,9 +2,10 @@
module MCP
class Progress
- def initialize(notification_target:, progress_token:)
+ def initialize(notification_target:, progress_token:, related_request_id: nil)
@notification_target = notification_target
@progress_token = progress_token
+ @related_request_id = related_request_id
end
def report(progress, total: nil, message: nil)
@@ -16,6 +17,7 @@ def report(progress, total: nil, message: nil)
progress: progress,
total: total,
message: message,
+ related_request_id: @related_request_id,
)
end
end
diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb
index 484b29b..d085b8b 100644
--- a/lib/mcp/server.rb
+++ b/lib/mcp/server.rb
@@ -127,8 +127,8 @@ def initialize(
# When `nil`, progress and logging notifications from tool handlers are silently skipped.
# @return [Hash, nil] The JSON-RPC response, or `nil` for notifications.
def handle(request, session: nil)
- JsonRpcHandler.handle(request) do |method|
- handle_request(request, method, session: session)
+ JsonRpcHandler.handle(request) do |method, request_id|
+ handle_request(request, method, session: session, related_request_id: request_id)
end
end
@@ -140,8 +140,8 @@ def handle(request, session: nil)
# When `nil`, progress and logging notifications from tool handlers are silently skipped.
# @return [String, nil] The JSON-RPC response as JSON, or `nil` for notifications.
def handle_json(request, session: nil)
- JsonRpcHandler.handle_json(request) do |method|
- handle_request(request, method, session: session)
+ JsonRpcHandler.handle_json(request) do |method, request_id|
+ handle_request(request, method, session: session, related_request_id: request_id)
end
end
@@ -220,7 +220,8 @@ def create_sampling_message(
stop_sequences: nil,
metadata: nil,
tools: nil,
- tool_choice: nil
+ tool_choice: nil,
+ related_request_id: nil
)
unless @transport
raise "Cannot send sampling request without a transport."
@@ -371,7 +372,7 @@ def schema_contains_ref?(schema)
end
end
- def handle_request(request, method, session: nil)
+ def handle_request(request, method, session: nil, related_request_id: nil)
handler = @handlers[method]
unless handler
instrument_call("unsupported_method") do
@@ -399,7 +400,7 @@ def handle_request(request, method, session: nil)
when Methods::RESOURCES_TEMPLATES_LIST
{ resourceTemplates: @handlers[Methods::RESOURCES_TEMPLATES_LIST].call(params) }
when Methods::TOOLS_CALL
- call_tool(params, session: session)
+ call_tool(params, session: session, related_request_id: related_request_id)
when Methods::COMPLETION_COMPLETE
complete(params)
when Methods::LOGGING_SET_LEVEL
@@ -499,7 +500,7 @@ def list_tools(request)
@tools.values.map(&:to_h)
end
- def call_tool(request, session: nil)
+ def call_tool(request, session: nil, related_request_id: nil)
tool_name = request[:name]
tool = tools[tool_name]
@@ -531,7 +532,7 @@ def call_tool(request, session: nil)
progress_token = request.dig(:_meta, :progressToken)
- call_tool_with_args(tool, arguments, server_context_with_meta(request), progress_token: progress_token, session: session)
+ call_tool_with_args(tool, arguments, server_context_with_meta(request), progress_token: progress_token, session: session, related_request_id: related_request_id)
rescue RequestHandlerError
raise
rescue => e
@@ -611,12 +612,12 @@ def accepts_server_context?(method_object)
parameters.any? { |type, name| type == :keyrest || name == :server_context }
end
- def call_tool_with_args(tool, arguments, context, progress_token: nil, session: nil)
+ def call_tool_with_args(tool, arguments, context, progress_token: nil, session: nil, related_request_id: nil)
args = arguments&.transform_keys(&:to_sym) || {}
if accepts_server_context?(tool.method(:call))
- progress = Progress.new(notification_target: session, progress_token: progress_token)
- server_context = ServerContext.new(context, progress: progress, notification_target: session)
+ progress = Progress.new(notification_target: session, progress_token: progress_token, related_request_id: related_request_id)
+ server_context = ServerContext.new(context, progress: progress, notification_target: session, related_request_id: related_request_id)
tool.call(**args, server_context: server_context).to_h
else
tool.call(**args).to_h
diff --git a/lib/mcp/server/transports/streamable_http_transport.rb b/lib/mcp/server/transports/streamable_http_transport.rb
index 31ddc89..688e38c 100644
--- a/lib/mcp/server/transports/streamable_http_transport.rb
+++ b/lib/mcp/server/transports/streamable_http_transport.rb
@@ -7,6 +7,12 @@ module MCP
class Server
module Transports
class StreamableHTTPTransport < Transport
+ SSE_HEADERS = {
+ "Content-Type" => "text/event-stream",
+ "Cache-Control" => "no-cache",
+ "Connection" => "keep-alive",
+ }.freeze
+
def initialize(server, stateless: false, session_idle_timeout: nil)
super(server)
# Maps `session_id` to `{ stream: stream_object, server_session: ServerSession, last_active_at: float_from_monotonic_clock }`.
@@ -56,10 +62,11 @@ def close
removed_sessions.each do |session|
close_stream_safely(session[:stream])
+ close_post_request_streams(session)
end
end
- def send_notification(method, params = nil, session_id: nil)
+ def send_notification(method, params = nil, session_id: nil, related_request_id: nil)
# Stateless mode doesn't support notifications
raise "Stateless mode does not support notifications" if @stateless
@@ -74,8 +81,10 @@ def send_notification(method, params = nil, session_id: nil)
result = @mutex.synchronize do
if session_id
# Send to specific session
- session = @sessions[session_id]
- next false unless session && session[:stream]
+ if (session = @sessions[session_id])
+ stream = active_stream(session, related_request_id: related_request_id)
+ end
+ next false unless stream
if session_expired?(session)
cleanup_and_collect_stream(session_id, streams_to_close)
@@ -83,14 +92,19 @@ def send_notification(method, params = nil, session_id: nil)
end
begin
- send_to_stream(session[:stream], notification)
+ send_to_stream(stream, notification)
true
rescue *STREAM_WRITE_ERRORS => e
MCP.configuration.exception_reporter.call(
e,
{ session_id: session_id, error: "Failed to send notification" },
)
- cleanup_and_collect_stream(session_id, streams_to_close)
+ if related_request_id && session[:post_request_streams]&.key?(related_request_id)
+ session[:post_request_streams].delete(related_request_id)
+ streams_to_close << stream
+ else
+ cleanup_and_collect_stream(session_id, streams_to_close)
+ end
false
end
else
@@ -99,7 +113,7 @@ def send_notification(method, params = nil, session_id: nil)
failed_sessions = []
@sessions.each do |sid, session|
- next unless session[:stream]
+ next unless (stream = session[:stream])
if session_expired?(session)
failed_sessions << sid
@@ -107,7 +121,7 @@ def send_notification(method, params = nil, session_id: nil)
end
begin
- send_to_stream(session[:stream], notification)
+ send_to_stream(stream, notification)
sent_count += 1
rescue *STREAM_WRITE_ERRORS => e
MCP.configuration.exception_reporter.call(
@@ -139,7 +153,7 @@ def send_notification(method, params = nil, session_id: nil)
# sends the request via SSE stream, then blocks on `queue.pop`.
# When the client POSTs a response, `handle_response` matches it by `request_id`
# and pushes the result onto the queue, unblocking this thread.
- def send_request(method, params = nil, session_id: nil)
+ def send_request(method, params = nil, session_id: nil, related_request_id: nil)
if @stateless
raise "Stateless mode does not support server-to-client requests."
end
@@ -163,12 +177,17 @@ def send_request(method, params = nil, session_id: nil)
@pending_responses[request_id] = { queue: queue, session_id: session_id }
- if (stream = session[:stream])
+ if (stream = active_stream(session, related_request_id: related_request_id))
begin
send_to_stream(stream, request)
sent = true
rescue *STREAM_WRITE_ERRORS
- cleanup_session_unsafe(session_id)
+ if related_request_id && session[:post_request_streams]&.key?(related_request_id)
+ session[:post_request_streams].delete(related_request_id)
+ close_stream_safely(stream)
+ else
+ cleanup_session_unsafe(session_id)
+ end
end
end
end
@@ -181,7 +200,7 @@ def send_request(method, params = nil, session_id: nil)
# The TypeScript and Python SDKs buffer messages and replay on reconnect.
# Until then, raise to prevent queue.pop from blocking indefinitely.
unless sent
- raise "No active SSE stream for #{method} request."
+ raise "No active stream for #{method} request."
end
response = queue.pop
@@ -229,6 +248,7 @@ def reap_expired_sessions
removed_sessions.each do |session|
close_stream_safely(session[:stream])
+ close_post_request_streams(session)
end
end
@@ -265,7 +285,7 @@ def handle_post(request)
handle_response(body, session_id: session_id)
else
- handle_regular_request(body_string, session_id)
+ handle_regular_request(body_string, session_id, related_request_id: body[:id])
end
end
rescue StandardError => e
@@ -313,7 +333,10 @@ def cleanup_session(session_id)
cleanup_session_unsafe(session_id)
end
- close_stream_safely(session[:stream]) if session
+ if session
+ close_stream_safely(session[:stream])
+ close_post_request_streams(session)
+ end
end
# Removes a session from `@sessions` and returns it. Does not close the stream.
@@ -336,6 +359,7 @@ def cleanup_and_collect_stream(session_id, streams_to_close)
return unless (removed = cleanup_session_unsafe(session_id))
streams_to_close << removed[:stream]
+ removed[:post_request_streams]&.each_value { |stream| streams_to_close << stream }
end
def close_stream_safely(stream)
@@ -344,6 +368,14 @@ def close_stream_safely(stream)
# Ignore close-related errors from already closed/broken streams.
end
+ def close_post_request_streams(session)
+ return unless (post_request_streams = session[:post_request_streams])
+
+ post_request_streams.each_value do |stream|
+ close_stream_safely(stream)
+ end
+ end
+
def extract_session_id(request)
request.env["HTTP_MCP_SESSION_ID"]
end
@@ -443,9 +475,8 @@ def handle_accepted
[202, {}, []]
end
- def handle_regular_request(body_string, session_id)
+ def handle_regular_request(body_string, session_id, related_request_id: nil)
server_session = nil
- stream = nil
unless @stateless
if session_id
@@ -455,21 +486,72 @@ def handle_regular_request(body_string, session_id)
@mutex.synchronize do
session = @sessions[session_id]
server_session = session[:server_session] if session
- stream = session[:stream] if session
end
end
end
- response = if server_session
- server_session.handle_json(body_string)
+ if session_id && !@stateless
+ handle_request_with_sse_response(body_string, session_id, server_session, related_request_id: related_request_id)
else
- @server.handle_json(body_string)
+ response = dispatch_handle_json(body_string, server_session)
+ [200, { "Content-Type" => "application/json" }, [response]]
end
+ end
+
+ # Returns the POST response as an SSE stream so the server can send
+ # JSON-RPC requests and notifications during request processing.
+ # https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server
+ def handle_request_with_sse_response(body_string, session_id, server_session, related_request_id: nil)
+ body = proc do |stream|
+ @mutex.synchronize do
+ session = @sessions[session_id]
+ if session && related_request_id
+ session[:post_request_streams] ||= {}
+ session[:post_request_streams][related_request_id] = stream
+ end
+ end
- if stream
- send_response_to_stream(stream, response, session_id)
+ begin
+ response = dispatch_handle_json(body_string, server_session)
+
+ send_to_stream(stream, response) if response
+ ensure
+ if related_request_id
+ @mutex.synchronize do
+ session = @sessions[session_id]
+ session[:post_request_streams]&.delete(related_request_id) if session
+ end
+ end
+
+ begin
+ stream.close
+ rescue StandardError
+ # Ignore close-related errors from already closed/broken streams.
+ end
+ end
+ end
+
+ [200, SSE_HEADERS, body]
+ end
+
+ # Returns the SSE stream available for server-to-client messages.
+ # When `related_request_id` is given, returns only the POST response
+ # stream for that request (no fallback to GET SSE). This prevents
+ # request-scoped messages from leaking to the wrong stream.
+ # When `related_request_id` is nil, returns the GET SSE stream.
+ def active_stream(session, related_request_id: nil)
+ if related_request_id
+ session.dig(:post_request_streams, related_request_id)
else
- [200, { "Content-Type" => "application/json" }, [response]]
+ session[:stream]
+ end
+ end
+
+ def dispatch_handle_json(body_string, server_session)
+ if server_session
+ server_session.handle_json(body_string)
+ else
+ @server.handle_json(body_string)
end
end
@@ -489,7 +571,13 @@ def validate_and_touch_session(session_id)
nil
end
- close_stream_safely(removed[:stream]) if removed
+ if removed
+ close_stream_safely(removed[:stream])
+
+ removed[:post_request_streams]&.each_value do |stream|
+ close_stream_safely(stream)
+ end
+ end
response
end
@@ -498,19 +586,6 @@ def get_session_stream(session_id)
@mutex.synchronize { @sessions[session_id]&.fetch(:stream, nil) }
end
- def send_response_to_stream(stream, response, session_id)
- message = JSON.parse(response)
- send_to_stream(stream, message)
- handle_accepted
- rescue *STREAM_WRITE_ERRORS => e
- MCP.configuration.exception_reporter.call(
- e,
- { session_id: session_id, error: "Stream closed during response" },
- )
- cleanup_session(session_id)
- [200, { "Content-Type" => "application/json" }, [response]]
- end
-
def session_exists?(session_id)
@mutex.synchronize { @sessions.key?(session_id) }
end
@@ -538,13 +613,7 @@ def session_already_connected_response
def setup_sse_stream(session_id)
body = create_sse_body(session_id)
- headers = {
- "Content-Type" => "text/event-stream",
- "Cache-Control" => "no-cache",
- "Connection" => "keep-alive",
- }
-
- [200, headers, body]
+ [200, SSE_HEADERS, body]
end
def create_sse_body(session_id)
diff --git a/lib/mcp/server_context.rb b/lib/mcp/server_context.rb
index b532555..aadd750 100644
--- a/lib/mcp/server_context.rb
+++ b/lib/mcp/server_context.rb
@@ -2,10 +2,11 @@
module MCP
class ServerContext
- def initialize(context, progress:, notification_target:)
+ def initialize(context, progress:, notification_target:, related_request_id: nil)
@context = context
@progress = progress
@notification_target = notification_target
+ @related_request_id = related_request_id
end
# Reports progress for the current tool operation.
@@ -26,7 +27,7 @@ def report_progress(progress, total: nil, message: nil)
def notify_log_message(data:, level:, logger: nil)
return unless @notification_target
- @notification_target.notify_log_message(data: data, level: level, logger: logger)
+ @notification_target.notify_log_message(data: data, level: level, logger: logger, related_request_id: @related_request_id)
end
# Delegates to the session so the request is scoped to the originating client.
@@ -34,9 +35,9 @@ def notify_log_message(data:, level:, logger: nil)
# does not support sampling.
def create_sampling_message(**kwargs)
if @notification_target.respond_to?(:create_sampling_message)
- @notification_target.create_sampling_message(**kwargs)
+ @notification_target.create_sampling_message(**kwargs, related_request_id: @related_request_id)
elsif @context.respond_to?(:create_sampling_message)
- @context.create_sampling_message(**kwargs)
+ @context.create_sampling_message(**kwargs, related_request_id: @related_request_id)
else
raise NoMethodError, "undefined method 'create_sampling_message' for #{self}"
end
diff --git a/lib/mcp/server_session.rb b/lib/mcp/server_session.rb
index 93e823f..2fe8f77 100644
--- a/lib/mcp/server_session.rb
+++ b/lib/mcp/server_session.rb
@@ -42,13 +42,13 @@ def client_capabilities
end
# Sends a `sampling/createMessage` request scoped to this session.
- def create_sampling_message(**kwargs)
+ def create_sampling_message(related_request_id: nil, **kwargs)
params = @server.build_sampling_params(client_capabilities, **kwargs)
- send_to_transport_request(Methods::SAMPLING_CREATE_MESSAGE, params)
+ send_to_transport_request(Methods::SAMPLING_CREATE_MESSAGE, params, related_request_id: related_request_id)
end
# Sends a progress notification to this session only.
- def notify_progress(progress_token:, progress:, total: nil, message: nil)
+ def notify_progress(progress_token:, progress:, total: nil, message: nil, related_request_id: nil)
params = {
"progressToken" => progress_token,
"progress" => progress,
@@ -56,20 +56,20 @@ def notify_progress(progress_token:, progress:, total: nil, message: nil)
"message" => message,
}.compact
- send_to_transport(Methods::NOTIFICATIONS_PROGRESS, params)
+ send_to_transport(Methods::NOTIFICATIONS_PROGRESS, params, related_request_id: related_request_id)
rescue => e
@server.report_exception(e, notification: "progress")
end
# Sends a log message notification to this session only.
- def notify_log_message(data:, level:, logger: nil)
+ def notify_log_message(data:, level:, logger: nil, related_request_id: nil)
effective_logging = @logging_message_notification || @server.logging_message_notification
return unless effective_logging&.should_notify?(level)
params = { "data" => data, "level" => level }
params["logger"] = logger if logger
- send_to_transport(Methods::NOTIFICATIONS_MESSAGE, params)
+ send_to_transport(Methods::NOTIFICATIONS_MESSAGE, params, related_request_id: related_request_id)
rescue => e
@server.report_exception(e, { notification: "log_message" })
end
@@ -82,9 +82,9 @@ def notify_log_message(data:, level:, logger: nil)
# TODO: When Ruby 2.7 support is dropped, replace with a direct call:
# `@transport.send_notification(method, params, session_id: @session_id)` and
# add `**` to `Transport#send_notification` and `StdioTransport#send_notification`.
- def send_to_transport(method, params)
+ def send_to_transport(method, params, related_request_id: nil)
if @session_id
- @transport.send_notification(method, params, session_id: @session_id)
+ @transport.send_notification(method, params, session_id: @session_id, related_request_id: related_request_id)
else
@transport.send_notification(method, params)
end
@@ -96,9 +96,9 @@ def send_to_transport(method, params)
# TODO: When Ruby 2.7 support is dropped, replace with a direct call:
# `@transport.send_request(method, params, session_id: @session_id)` and
# add `**` to `Transport#send_request` and `StdioTransport#send_request`.
- def send_to_transport_request(method, params)
+ def send_to_transport_request(method, params, related_request_id: nil)
if @session_id
- @transport.send_request(method, params, session_id: @session_id)
+ @transport.send_request(method, params, session_id: @session_id, related_request_id: related_request_id)
else
@transport.send_request(method, params)
end
diff --git a/test/json_rpc_handler_test.rb b/test/json_rpc_handler_test.rb
index 67f47db..169b1fd 100644
--- a/test/json_rpc_handler_test.rb
+++ b/test/json_rpc_handler_test.rb
@@ -621,7 +621,7 @@
@response = JsonRpcHandler.handle(
{ jsonrpc: "2.0", id: "user@example.com", method: "add", params: { a: 1, b: 2 } },
id_validation_pattern: custom_pattern,
- ) { |method_name| @registry[method_name] }
+ ) { |method_name, _request_id| @registry[method_name] }
assert_rpc_success expected_result: 3
assert_equal "user@example.com", @response[:id]
@@ -633,7 +633,7 @@
@response = JsonRpcHandler.handle(
{ jsonrpc: "2.0", id: "id", method: "add", params: { a: 1, b: 2 } },
id_validation_pattern: nil,
- ) { |method_name| @registry[method_name] }
+ ) { |method_name, _request_id| @registry[method_name] }
assert_rpc_success expected_result: 3
assert_equal "", @response[:id]
@@ -733,11 +733,11 @@ def register(method_name, &block)
end
def handle(request)
- @response = JsonRpcHandler.handle(request) { |method_name| @registry[method_name] }
+ @response = JsonRpcHandler.handle(request) { |method_name, _request_id| @registry[method_name] }
end
def handle_json(request_json)
- @response_json = JsonRpcHandler.handle_json(request_json) { |method_name| @registry[method_name] }
+ @response_json = JsonRpcHandler.handle_json(request_json) { |method_name, _request_id| @registry[method_name] }
@response = JSON.parse(@response_json, symbolize_names: true) if @response_json
end
diff --git a/test/mcp/server/transports/streamable_http_transport_test.rb b/test/mcp/server/transports/streamable_http_transport_test.rb
index 507ac45..7435289 100644
--- a/test/mcp/server/transports/streamable_http_transport_test.rb
+++ b/test/mcp/server/transports/streamable_http_transport_test.rb
@@ -7,6 +7,31 @@ module MCP
class Server
module Transports
class StreamableHTTPTransportTest < ActiveSupport::TestCase
+ # A stream that buffers writes and remains readable after close.
+ class TestStream
+ def initialize
+ @buffer = "".dup
+ @closed = false
+ end
+
+ def write(data)
+ raise IOError, "closed stream" if @closed
+
+ @buffer << data
+ end
+
+ def flush
+ end
+
+ def close
+ @closed = true
+ end
+
+ def string
+ @buffer
+ end
+ end
+
setup do
@server = Server.new(
name: "test_server",
@@ -45,9 +70,11 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
response = @transport.handle_request(request)
assert_equal 200, response[0]
- assert_equal({ "Content-Type" => "application/json" }, response[1])
+ assert_equal "text/event-stream", response[1]["Content-Type"]
- body = JSON.parse(response[2][0])
+ io = StringIO.new
+ response[2].call(io)
+ body = JSON.parse(io.string.match(/^data: (.+)$/)[1])
assert_equal "2.0", body["jsonrpc"]
assert_equal "123", body["id"]
assert_equal({}, body["result"])
@@ -114,8 +141,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
assert response[2].is_a?(Proc) # The body should be a Proc for streaming
end
- test "handles POST request when IOError raised" do
- # Create and initialize a session
+ test "handles POST request as SSE even when GET SSE stream is closed" do
init_request = create_rack_request(
"POST",
"/",
@@ -125,7 +151,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
init_response = @transport.handle_request(init_request)
session_id = init_response[1]["Mcp-Session-Id"]
- # Connect with SSE
+ # Connect with SSE then close it
io = StringIO.new
get_request = create_rack_request(
"GET",
@@ -134,13 +160,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
)
response = @transport.handle_request(get_request)
response[2].call(io) if response[2].is_a?(Proc)
-
- # Give the stream time to set up
sleep(0.1)
-
- # Close the stream
io.close
+ # POST request should still return SSE response via POST response stream
request = create_rack_request(
"POST",
"/",
@@ -151,17 +174,12 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
{ jsonrpc: "2.0", method: "ping", id: "456" }.to_json,
)
- # This should handle IOError and return the original response
response = @transport.handle_request(request)
assert_equal 200, response[0]
- assert_equal({ "Content-Type" => "application/json" }, response[1])
-
- # Verify session was cleaned up
- assert_not @transport.instance_variable_get(:@sessions).key?(session_id)
+ assert_equal "text/event-stream", response[1]["Content-Type"]
end
- test "handles POST request when Errno::EPIPE raised" do
- # Create and initialize a session
+ test "handles POST request as SSE even when GET SSE stream has EPIPE" do
init_request = create_rack_request(
"POST",
"/",
@@ -171,10 +189,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
init_response = @transport.handle_request(init_request)
session_id = init_response[1]["Mcp-Session-Id"]
- # Create a pipe to simulate EPIPE condition
+ # Connect GET SSE with a broken pipe
reader, writer = IO.pipe
-
- # Connect with SSE using the writer end of the pipe
get_request = create_rack_request(
"GET",
"/",
@@ -182,13 +198,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
)
response = @transport.handle_request(get_request)
response[2].call(writer) if response[2].is_a?(Proc)
-
- # Give the stream time to set up
sleep(0.1)
-
- # Close the reader end to break the pipe - this will cause EPIPE on write
reader.close
+ # POST request should still return SSE response via POST response stream
request = create_rack_request(
"POST",
"/",
@@ -199,23 +212,18 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
{ jsonrpc: "2.0", method: "ping", id: "789" }.to_json,
)
- # This should handle Errno::EPIPE and return the original response
response = @transport.handle_request(request)
- assert_equal 200, response[0]
- assert_equal({ "Content-Type" => "application/json" }, response[1])
-
- # Verify session was cleaned up
- assert_not @transport.instance_variable_get(:@sessions).key?(session_id)
-
+ assert_equal(200, response[0])
+ assert_equal("text/event-stream", response[1]["Content-Type"])
+ ensure
begin
writer.close
- rescue
+ rescue StandardError
nil
end
end
- test "handles POST request when Errno::ECONNRESET raised" do
- # Create and initialize a session.
+ test "handles POST request as SSE even when GET SSE stream has ECONNRESET" do
init_request = create_rack_request(
"POST",
"/",
@@ -225,12 +233,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
init_response = @transport.handle_request(init_request)
session_id = init_response[1]["Mcp-Session-Id"]
- # Use a mock stream that raises Errno::ECONNRESET on write.
+ # Connect GET SSE with a mock that raises ECONNRESET
mock_stream = Object.new
mock_stream.define_singleton_method(:write) { |_data| raise Errno::ECONNRESET }
mock_stream.define_singleton_method(:close) {}
-
- # Connect with SSE using the mock stream.
get_request = create_rack_request(
"GET",
"/",
@@ -238,10 +244,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
)
response = @transport.handle_request(get_request)
response[2].call(mock_stream) if response[2].is_a?(Proc)
-
- # Give the stream time to set up.
sleep(0.1)
+ # POST request should still return SSE response via POST response stream
request = create_rack_request(
"POST",
"/",
@@ -252,13 +257,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
{ jsonrpc: "2.0", method: "ping", id: "789" }.to_json,
)
- # This should handle Errno::ECONNRESET and return the original response.
response = @transport.handle_request(request)
assert_equal 200, response[0]
- assert_equal({ "Content-Type" => "application/json" }, response[1])
-
- # Verify session was cleaned up.
- assert_not @transport.instance_variable_get(:@sessions).key?(session_id)
+ assert_equal "text/event-stream", response[1]["Content-Type"]
end
test "handles GET request with missing session ID" do
@@ -579,6 +580,54 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
assert_equal({}, @transport.instance_variable_get(:@sessions))
end
+ test "cleanup_session_unsafe closes request_streams" do
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ init_response = @transport.handle_request(init_request)
+ session_id = init_response[1]["Mcp-Session-Id"]
+
+ # Simulate multiple request_streams being set on the session.
+ closed = []
+ 2.times do |i|
+ mock_stream = Object.new
+ mock_stream.define_singleton_method(:close) { closed << i }
+ thread = Thread.new {}
+ thread.join
+ @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams] ||= {}
+ @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams][thread] = mock_stream
+ end
+
+ delete_request = create_rack_request(
+ "DELETE",
+ "/",
+ { "HTTP_MCP_SESSION_ID" => session_id },
+ )
+ @transport.handle_request(delete_request)
+
+ assert_equal [0, 1], closed.sort
+ assert_empty @transport.instance_variable_get(:@sessions)
+ end
+
+ test "broadcast notification skips sessions without GET SSE stream" do
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ @transport.handle_request(init_request)
+
+ # No GET SSE stream connected, only request_streams.
+ # Pass **{} to prevent Ruby 2.7 from converting the Hash to keyword arguments.
+ result = @transport.send_notification("test/notify", { message: "hello" }, **{})
+
+ assert_equal 0, result
+ end
+
test "sends notification to correct session with multiple active sessions" do
# Create first session
init_request1 = create_rack_request(
@@ -653,8 +702,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
result
end
- # Handle request from session 1
- @transport.handle_request(request_as_session1)
+ # Handle request from session 1 (execute SSE proc)
+ response1 = @transport.handle_request(request_as_session1)
+ response1[2].call(StringIO.new) if response1[2].is_a?(Proc)
# Make a request as session 2
request_as_session2 = create_rack_request(
@@ -667,18 +717,17 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
{ jsonrpc: "2.0", method: "ping", id: "890" }.to_json,
)
- # Handle request from session 2
- @transport.handle_request(request_as_session2)
+ # Handle request from session 2 (execute SSE proc)
+ response2_post = @transport.handle_request(request_as_session2)
+ response2_post[2].call(StringIO.new) if response2_post[2].is_a?(Proc)
- # Check that each session received one notification
+ # Broadcast notifications are sent to GET SSE streams (no related_request_id)
io1.rewind
output1 = io1.read
- # Session 1 should have received two notifications (one from each request since we broadcast)
assert_equal 2, output1.scan(/data: {"jsonrpc":"2.0","method":"test_notification","params":{"session":"current"}}/).count
io2.rewind
output2 = io2.read
- # Session 2 should have received two notifications (one from each request since we broadcast)
assert_equal 2, output2.scan(/data: {"jsonrpc":"2.0","method":"test_notification","params":{"session":"current"}}/).count
end
@@ -888,6 +937,85 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
assert_not @transport.instance_variable_get(:@sessions).key?(session_id)
end
+ test "send_notification on broken request_stream removes only that stream, not the session" do
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ init_response = @transport.handle_request(init_request)
+ session_id = init_response[1]["Mcp-Session-Id"]
+
+ # Connect GET SSE.
+ io = StringIO.new
+ get_request = create_rack_request(
+ "GET",
+ "/",
+ { "HTTP_MCP_SESSION_ID" => session_id },
+ )
+ response = @transport.handle_request(get_request)
+ response[2].call(io) if response[2].is_a?(Proc)
+ sleep(0.1)
+
+ # Simulate a broken request_stream.
+ broken_stream = Object.new
+ broken_stream.define_singleton_method(:write) { |_data| raise Errno::EPIPE }
+ broken_stream.define_singleton_method(:close) {}
+ related_id = "req-1"
+ @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams] = { related_id => broken_stream }
+
+ result = @transport.send_notification("test", { msg: "hello" }, session_id: session_id, related_request_id: related_id)
+
+ refute result
+ # Session should still exist.
+ assert @transport.instance_variable_get(:@sessions).key?(session_id)
+ # The broken request_stream should be removed.
+ refute @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams].key?(related_id)
+ # GET SSE stream should still be intact.
+ assert @transport.instance_variable_get(:@sessions)[session_id][:stream]
+ end
+
+ test "active_stream does not fall back to GET SSE when related_request_id is given but request_stream is missing" do
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ init_response = @transport.handle_request(init_request)
+ session_id = init_response[1]["Mcp-Session-Id"]
+
+ # Connect GET SSE.
+ io = StringIO.new
+ get_request = create_rack_request(
+ "GET",
+ "/",
+ { "HTTP_MCP_SESSION_ID" => session_id },
+ )
+ response = @transport.handle_request(get_request)
+ response[2].call(io) if response[2].is_a?(Proc)
+ sleep(0.1)
+
+ # Send notification with a related_request_id that has no matching request_stream.
+ result = @transport.send_notification(
+ "test/notify",
+ { message: "should not arrive" },
+ session_id: session_id,
+ related_request_id: "nonexistent-request-id",
+ )
+
+ # Should return false because no matching request_stream exists.
+ refute result
+
+ # Session should still exist (not cleaned up).
+ assert @transport.instance_variable_get(:@sessions).key?(session_id)
+
+ # GET SSE stream should NOT have received the notification.
+ io.rewind
+ refute_includes io.read, "should not arrive"
+ end
+
test "send_notification broadcast continues when one session raises Errno::ECONNRESET" do
# Create two sessions.
init_request1 = create_rack_request(
@@ -1334,8 +1462,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
assert_nil(body)
end
- test "send_response_to_stream returns 202 when message is sent to stream" do
- # Create and initialize a session
+ test "POST request returns SSE response even with GET SSE connected" do
init_request = create_rack_request(
"POST",
"/",
@@ -1345,7 +1472,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
init_response = @transport.handle_request(init_request)
session_id = init_response[1]["Mcp-Session-Id"]
- # Connect with SSE
+ # Connect with GET SSE
io = StringIO.new
get_request = create_rack_request(
"GET",
@@ -1354,11 +1481,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
)
response = @transport.handle_request(get_request)
response[2].call(io) if response[2].is_a?(Proc)
-
- # Give the stream time to set up
sleep(0.1)
- # Make a regular request that will be routed through send_response_to_stream
+ # POST request should return SSE, not 202
request = create_rack_request(
"POST",
"/",
@@ -1370,9 +1495,13 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
)
response = @transport.handle_request(request)
- assert_equal 202, response[0]
- assert_empty response[1]
- assert_empty response[2]
+ assert_equal 200, response[0]
+ assert_equal "text/event-stream", response[1]["Content-Type"]
+
+ post_io = StringIO.new
+ response[2].call(post_io)
+ body = JSON.parse(post_io.string.match(/^data: (.+)$/)[1])
+ assert_equal "456", body["id"]
end
test "handle post request with a standard error" do
@@ -1436,7 +1565,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
@transport.send_request("sampling/createMessage", { "messages" => [] }, session_id: session_id)
end
- assert_equal("No active SSE stream for sampling/createMessage request.", error.message)
+ assert_equal("No active stream for sampling/createMessage request.", error.message)
end
test "send_request sends via SSE and waits for response" do
@@ -1683,6 +1812,300 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
assert_equal("SSE session closed while waiting for sampling/createMessage response.", error.message)
end
+ test "send_request sends via POST response stream even with GET SSE connected" do
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ init_response = @transport.handle_request(init_request)
+ session_id = init_response[1]["Mcp-Session-Id"]
+
+ # Connect GET SSE.
+ get_io = StringIO.new
+ get_request = create_rack_request(
+ "GET",
+ "/",
+ { "HTTP_MCP_SESSION_ID" => session_id },
+ )
+ get_response = @transport.handle_request(get_request)
+ get_response[2].call(get_io) if get_response[2].is_a?(Proc)
+ sleep(0.1)
+
+ # Set up sampling capability for the session.
+ @transport.instance_variable_get(:@sessions)[session_id][:server_session]
+ .store_client_info(client: { name: "test" }, capabilities: { sampling: {} })
+
+ # Define a tool that calls create_sampling_message.
+ sampling_tool = MCP::Tool.define(
+ name: "sampling_tool",
+ input_schema: { properties: { prompt: { type: "string" } }, required: ["prompt"] },
+ ) do |prompt:, server_context:|
+ result = server_context.create_sampling_message(
+ messages: [{ role: "user", content: { type: "text", text: prompt } }],
+ max_tokens: 100,
+ )
+ MCP::Tool::Response.new([{ type: "text", text: result[:content][:text] }])
+ end
+ @server.tools[sampling_tool.name_value] = sampling_tool
+
+ # Send tools/call via POST (GET SSE is connected).
+ tool_request = create_rack_request(
+ "POST",
+ "/",
+ {
+ "CONTENT_TYPE" => "application/json",
+ "HTTP_MCP_SESSION_ID" => session_id,
+ },
+ {
+ jsonrpc: "2.0",
+ id: "tool-1",
+ method: "tools/call",
+ params: { name: "sampling_tool", arguments: { prompt: "Hello" } },
+ }.to_json,
+ )
+
+ post_stream = TestStream.new
+ result_queue = Queue.new
+ Thread.new do
+ response = @transport.handle_request(tool_request)
+ response[2].call(post_stream)
+ result_queue.push(:done)
+ end
+
+ sleep(0.2)
+
+ # Sampling request should be in POST response stream, not GET SSE.
+ output = post_stream.string
+ data_lines = output.lines.select { |line| line.start_with?("data: ") }
+ sampling_request = JSON.parse(data_lines.first.sub("data: ", ""))
+ assert_equal "sampling/createMessage", sampling_request["method"]
+
+ # GET SSE should NOT have the sampling request.
+ get_io.rewind
+ refute_includes get_io.read, "sampling/createMessage"
+
+ # Simulate client sending sampling result via POST.
+ client_response = create_rack_request(
+ "POST",
+ "/",
+ {
+ "CONTENT_TYPE" => "application/json",
+ "HTTP_MCP_SESSION_ID" => session_id,
+ },
+ {
+ jsonrpc: "2.0",
+ id: sampling_request["id"],
+ result: { role: "assistant", content: { type: "text", text: "Hi from LLM" } },
+ }.to_json,
+ )
+ @transport.handle_request(client_response)
+
+ result_queue.pop
+
+ tool_response_lines = post_stream.string.lines.select { |line| line.start_with?("data: ") }
+ tool_response = JSON.parse(tool_response_lines.last.sub("data: ", ""))
+ assert_equal "tool-1", tool_response["id"]
+ assert_includes tool_response["result"]["content"].first["text"], "Hi from LLM"
+ end
+
+ test "send_request sends via POST response stream when no GET SSE stream" do
+ # Create session without connecting GET SSE.
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ init_response = @transport.handle_request(init_request)
+ session_id = init_response[1]["Mcp-Session-Id"]
+
+ # Set up sampling capability for the session.
+ @transport.instance_variable_get(:@sessions)[session_id][:server_session]
+ .store_client_info(client: { name: "test" }, capabilities: { sampling: {} })
+
+ # Define a tool that calls create_sampling_message.
+ sampling_tool = MCP::Tool.define(
+ name: "sampling_tool",
+ input_schema: { properties: { prompt: { type: "string" } }, required: ["prompt"] },
+ ) do |prompt:, server_context:|
+ result = server_context.create_sampling_message(
+ messages: [{ role: "user", content: { type: "text", text: prompt } }],
+ max_tokens: 100,
+ )
+ MCP::Tool::Response.new([{ type: "text", text: result[:content][:text] }])
+ end
+ @server.tools[sampling_tool.name_value] = sampling_tool
+
+ # Send tools/call via POST (no GET SSE stream).
+ tool_request = create_rack_request(
+ "POST",
+ "/",
+ {
+ "CONTENT_TYPE" => "application/json",
+ "HTTP_MCP_SESSION_ID" => session_id,
+ },
+ {
+ jsonrpc: "2.0",
+ id: "tool-1",
+ method: "tools/call",
+ params: { name: "sampling_tool", arguments: { prompt: "Hello" } },
+ }.to_json,
+ )
+
+ # Process in background since handle_request blocks until tool completes.
+ post_stream = TestStream.new
+ result_queue = Queue.new
+ Thread.new do
+ response = @transport.handle_request(tool_request)
+ response[2].call(post_stream)
+ result_queue.push(:done)
+ end
+
+ sleep(0.2) # Wait for the tool to start and send sampling request.
+
+ # Read the sampling request from the POST response stream.
+ output = post_stream.string
+ data_lines = output.lines.select { |line| line.start_with?("data: ") }
+ sampling_request = JSON.parse(data_lines.first.sub("data: ", ""))
+ assert_equal "sampling/createMessage", sampling_request["method"]
+
+ # Simulate client sending sampling result via POST.
+ client_response = create_rack_request(
+ "POST",
+ "/",
+ {
+ "CONTENT_TYPE" => "application/json",
+ "HTTP_MCP_SESSION_ID" => session_id,
+ },
+ {
+ jsonrpc: "2.0",
+ id: sampling_request["id"],
+ result: { role: "assistant", content: { type: "text", text: "Hi from LLM" } },
+ }.to_json,
+ )
+ @transport.handle_request(client_response)
+
+ result_queue.pop # Wait for tool to complete.
+
+ # Verify the tool result was written to the POST response stream.
+ tool_response_lines = post_stream.string.lines.select { |line| line.start_with?("data: ") }
+ tool_response = JSON.parse(tool_response_lines.last.sub("data: ", ""))
+ assert_equal "tool-1", tool_response["id"]
+ assert_includes tool_response["result"]["content"].first["text"], "Hi from LLM"
+ end
+
+ test "send_notification uses POST response stream when no GET SSE stream" do
+ # Create session without connecting GET SSE.
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ init_response = @transport.handle_request(init_request)
+ session_id = init_response[1]["Mcp-Session-Id"]
+
+ # Define a tool that sends a notification during execution.
+ notification_sent = Queue.new
+ slow_tool = MCP::Tool.define(
+ name: "slow_tool",
+ ) do |server_context:|
+ server_context.notify_log_message(data: "test log", level: "info")
+ notification_sent.push(true)
+ MCP::Tool::Response.new([{ type: "text", text: "done" }])
+ end
+ @server.tools[slow_tool.name_value] = slow_tool
+
+ # Configure logging so notifications are sent.
+ @transport.instance_variable_get(:@sessions)[session_id][:server_session]
+ .configure_logging(MCP::LoggingMessageNotification.new(level: "debug"))
+
+ # Send tools/call via POST (no GET SSE stream).
+ post_stream = TestStream.new
+ result_queue = Queue.new
+ Thread.new do
+ request = create_rack_request(
+ "POST",
+ "/",
+ {
+ "CONTENT_TYPE" => "application/json",
+ "HTTP_MCP_SESSION_ID" => session_id,
+ },
+ {
+ jsonrpc: "2.0",
+ id: "tool-1",
+ method: "tools/call",
+ params: { name: "slow_tool", arguments: {} },
+ }.to_json,
+ )
+ response = @transport.handle_request(request)
+ response[2].call(post_stream)
+ result_queue.push(:done)
+ end
+
+ notification_sent.pop # Wait for tool to send notification.
+ result_queue.pop
+
+ # Verify notification was written to the POST response stream.
+ assert_includes post_stream.string, "notifications/message"
+ assert_includes post_stream.string, "test log"
+ end
+
+ test "progress notification uses POST response stream when no GET SSE stream" do
+ # Create session without connecting GET SSE.
+ init_request = create_rack_request(
+ "POST",
+ "/",
+ { "CONTENT_TYPE" => "application/json" },
+ { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
+ )
+ init_response = @transport.handle_request(init_request)
+ session_id = init_response[1]["Mcp-Session-Id"]
+
+ # Define a tool that reports progress during execution.
+ progress_reported = Queue.new
+ progress_tool = MCP::Tool.define(
+ name: "progress_tool",
+ ) do |server_context:|
+ server_context.report_progress(50, total: 100, message: "halfway")
+ progress_reported.push(true)
+ MCP::Tool::Response.new([{ type: "text", text: "done" }])
+ end
+ @server.tools[progress_tool.name_value] = progress_tool
+
+ # Send tools/call via POST (no GET SSE stream) with a progress token.
+ post_stream = TestStream.new
+ result_queue = Queue.new
+ Thread.new do
+ request = create_rack_request(
+ "POST",
+ "/",
+ {
+ "CONTENT_TYPE" => "application/json",
+ "HTTP_MCP_SESSION_ID" => session_id,
+ },
+ {
+ jsonrpc: "2.0",
+ id: "tool-1",
+ method: "tools/call",
+ params: { name: "progress_tool", arguments: {}, _meta: { progressToken: "token-1" } },
+ }.to_json,
+ )
+ response = @transport.handle_request(request)
+ response[2].call(post_stream)
+ result_queue.push(:done)
+ end
+
+ progress_reported.pop
+ result_queue.pop
+
+ # Verify progress notification was written to the POST response stream.
+ assert_includes post_stream.string, "notifications/progress"
+ assert_includes post_stream.string, "token-1"
+ end
+
test "POST notifications/initialized returns 202 with no body" do
# Create a session first (optional for notification, but keep consistent with flow)
init_request = create_rack_request(
@@ -2228,13 +2651,16 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
params: { name: "log_tool", arguments: {} },
}.to_json,
)
- transport.handle_request(tool_request)
+ tool_response = transport.handle_request(tool_request)
+ post_io = StringIO.new
+ tool_response[2].call(post_io)
- # Session 1 should receive the log notification.
- io1.rewind
- assert_includes io1.read, "secret"
+ # Session 1's POST response stream should contain the log notification.
+ assert_includes post_io.string, "secret"
- # Session 2 should NOT receive the log notification.
+ # GET SSE streams should NOT receive the log notification.
+ io1.rewind
+ refute_includes io1.read, "secret"
io2.rewind
refute_includes io2.read, "secret"
end
@@ -2306,13 +2732,16 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
},
}.to_json,
)
- transport.handle_request(tool_request)
+ tool_response = transport.handle_request(tool_request)
+ post_io = StringIO.new
+ tool_response[2].call(post_io)
- # Session 1 should receive the progress notification.
- io1.rewind
- assert_includes io1.read, "halfway"
+ # Session 1's POST response stream should contain the progress notification.
+ assert_includes post_io.string, "halfway"
- # Session 2 should NOT receive the progress notification.
+ # GET SSE streams should NOT receive the progress notification.
+ io1.rewind
+ refute_includes io1.read, "halfway"
io2.rewind
refute_includes io2.read, "halfway"
end
@@ -2406,7 +2835,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
params: { level: "error" },
}.to_json,
)
- transport.handle_request(set_level1)
+ response1 = transport.handle_request(set_level1)
+ response1[2].call(StringIO.new)
# Session 2 sets log level to "debug".
set_level2 = create_rack_request(
@@ -2420,7 +2850,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
params: { level: "debug" },
}.to_json,
)
- transport.handle_request(set_level2)
+ response2 = transport.handle_request(set_level2)
+ response2[2].call(StringIO.new)
# Session 1 (error level) should not notify for "info", but should for "error".
session1_logging = transport.instance_variable_get(:@sessions)[session1][:server_session].logging_message_notification
diff --git a/test/mcp/server_context_test.rb b/test/mcp/server_context_test.rb
index 605e385..81735d6 100644
--- a/test/mcp/server_context_test.rb
+++ b/test/mcp/server_context_test.rb
@@ -46,6 +46,7 @@ class ServerContextTest < ActiveSupport::TestCase
notification_target.expects(:create_sampling_message).with(
messages: [{ role: "user", content: { type: "text", text: "Hello" } }],
max_tokens: 100,
+ related_request_id: nil,
).returns({ role: "assistant", content: { type: "text", text: "Hi" } })
context = mock
@@ -67,6 +68,7 @@ class ServerContextTest < ActiveSupport::TestCase
context.expects(:create_sampling_message).with(
messages: [{ role: "user", content: { type: "text", text: "Hello" } }],
max_tokens: 100,
+ related_request_id: nil,
).returns({ role: "assistant", content: { type: "text", text: "Fallback" } })
progress = Progress.new(notification_target: notification_target, progress_token: nil)
diff --git a/test/mcp/server_sampling_test.rb b/test/mcp/server_sampling_test.rb
index 57c488d..bf250e4 100644
--- a/test/mcp/server_sampling_test.rb
+++ b/test/mcp/server_sampling_test.rb
@@ -260,7 +260,7 @@ def close; end
max_tokens: 100,
)
end
- assert_equal("No active SSE stream for sampling/createMessage request.", error_with_sampling.message)
+ assert_equal("No active stream for sampling/createMessage request.", error_with_sampling.message)
# Session without sampling capability should be rejected.
session_without_sampling = ServerSession.new(server: @server, transport: transport, session_id: "s2")
@@ -290,7 +290,7 @@ def close; end
max_tokens: 100,
)
end
- assert_equal("No active SSE stream for sampling/createMessage request.", error.message)
+ assert_equal("No active stream for sampling/createMessage request.", error.message)
end
test "session init does not overwrite server global client_capabilities" do
@@ -375,7 +375,18 @@ def close; end
max_tokens: 100,
)
end
- assert_equal("No active SSE stream for sampling/createMessage request.", error.message)
+ assert_equal("No active stream for sampling/createMessage request.", error.message)
+ end
+
+ test "Server#create_sampling_message accepts related_request_id without error" do
+ @server.create_sampling_message(
+ messages: [{ role: "user", content: { type: "text", text: "Hello" } }],
+ max_tokens: 100,
+ related_request_id: "req-1",
+ )
+
+ request = @mock_transport.requests.first
+ assert_equal "sampling/createMessage", request[:method]
end
test "create_sampling_message omits nil optional params" do