From 8278ca658669c6b57b938faca13a70d2304234d9 Mon Sep 17 00:00:00 2001 From: Koichi ITO Date: Wed, 1 Apr 2026 10:10:59 +0900 Subject: [PATCH] Support POST response SSE streams for server-to-client messages ## Motivation and Context The MCP Streamable HTTP specification defines that servers can return POST responses as SSE streams and send server-to-client JSON-RPC requests and notifications through them: > If the input is a JSON-RPC request, the server MUST either return > `Content-Type: text/event-stream`, to initiate an SSE stream, or > `Content-Type: application/json`, to return one JSON object. If the server initiates an SSE stream: > The server MAY send JSON-RPC requests and notifications before > sending the JSON-RPC response. These messages SHOULD relate to the > originating client request. See: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server Previously, when a GET SSE stream was connected, the Ruby SDK sent POST response bodies through the GET stream and returned a 202 HTTP status for the POST itself. When no GET SSE stream was connected, the SDK correctly returned `application/json` responses. The 202 behavior with GET SSE was non-compliant with the specification, which requires POST requests to return `text/event-stream` or `application/json`. Additionally, since GET SSE is optional per the specification, clients that did not establish a GET SSE connection could not receive server-to-client messages (e.g., `sampling/createMessage`, log notifications) during request processing. With this change, `handle_regular_request` always returns the POST response as an SSE stream for stateful sessions. Each POST response stream is stored in `session[:post_request_streams]` keyed by `related_request_id` (the JSON-RPC request ID), enabling correct routing when multiple POST requests are processed concurrently on the same session. Server-to-client messages with a `related_request_id` are routed to the originating POST response stream only; when `related_request_id` is nil, the GET SSE stream is used. The TypeScript and Python SDKs already support this pattern. ### Internal Changes `JsonRpcHandler.handle` and `JsonRpcHandler.handle_json` now pass both `method_name` and `request_id` to the method finder block. This allows `Server#handle_request` to receive `related_request_id` directly from the protocol layer. Without this, `related_request_id` would need to be relayed as a keyword argument through `Server#handle_json`, `ServerSession#handle_json`, and `dispatch_handle_json`, unnecessarily exposing it on public method signatures. This follows the same design as the TypeScript and Python SDKs, where the protocol layer extracts the request ID and propagates it to the handler context. ## How Has This Been Tested? Added tests for POST response stream: - `send_request` via POST response stream (sampling with and without GET SSE) - `send_notification` via POST response stream (logging without GET SSE) - `progress` notification via POST response stream (without GET SSE) - POST request returns SSE response even with GET SSE connected - Session-scoped notifications (log, progress) are sent to POST response stream, not GET SSE stream - `active_stream` does not fall back to GET SSE when `related_request_id` is given but request stream is missing Updated existing tests to handle SSE response format where applicable. ## Breaking Changes (spec compliance fix) POST responses for JSON-RPC requests in stateful sessions now return `Content-Type: text/event-stream` instead of being sent through the GET SSE stream with a 202 HTTP status. Clients that relied on receiving responses via the GET SSE stream will need to read the POST response body instead. This is a spec compliance fix: the MCP specification requires POST requests to return `text/event-stream` or `application/json`, not 202 with the response on a separate GET stream. --- lib/json_rpc_handler.rb | 2 +- lib/mcp/progress.rb | 4 +- lib/mcp/server.rb | 25 +- .../transports/streamable_http_transport.rb | 155 +++-- lib/mcp/server_context.rb | 9 +- lib/mcp/server_session.rb | 20 +- test/json_rpc_handler_test.rb | 16 +- .../streamable_http_transport_test.rb | 575 +++++++++++++++--- test/mcp/server_context_test.rb | 2 + test/mcp/server_sampling_test.rb | 17 +- 10 files changed, 671 insertions(+), 154 deletions(-) diff --git a/lib/json_rpc_handler.rb b/lib/json_rpc_handler.rb index d4be6a76..b788e12c 100644 --- a/lib/json_rpc_handler.rb +++ b/lib/json_rpc_handler.rb @@ -92,7 +92,7 @@ def process_request(request, id_validation_pattern:, &method_finder) end begin - method = method_finder.call(method_name) + method = method_finder.call(method_name, id) if method.nil? return error_response(id: id, id_validation_pattern: id_validation_pattern, error: { diff --git a/lib/mcp/progress.rb b/lib/mcp/progress.rb index 8843a0d2..6762d3c7 100644 --- a/lib/mcp/progress.rb +++ b/lib/mcp/progress.rb @@ -2,9 +2,10 @@ module MCP class Progress - def initialize(notification_target:, progress_token:) + def initialize(notification_target:, progress_token:, related_request_id: nil) @notification_target = notification_target @progress_token = progress_token + @related_request_id = related_request_id end def report(progress, total: nil, message: nil) @@ -16,6 +17,7 @@ def report(progress, total: nil, message: nil) progress: progress, total: total, message: message, + related_request_id: @related_request_id, ) end end diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index 484b29b5..d085b8ba 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -127,8 +127,8 @@ def initialize( # When `nil`, progress and logging notifications from tool handlers are silently skipped. # @return [Hash, nil] The JSON-RPC response, or `nil` for notifications. def handle(request, session: nil) - JsonRpcHandler.handle(request) do |method| - handle_request(request, method, session: session) + JsonRpcHandler.handle(request) do |method, request_id| + handle_request(request, method, session: session, related_request_id: request_id) end end @@ -140,8 +140,8 @@ def handle(request, session: nil) # When `nil`, progress and logging notifications from tool handlers are silently skipped. # @return [String, nil] The JSON-RPC response as JSON, or `nil` for notifications. def handle_json(request, session: nil) - JsonRpcHandler.handle_json(request) do |method| - handle_request(request, method, session: session) + JsonRpcHandler.handle_json(request) do |method, request_id| + handle_request(request, method, session: session, related_request_id: request_id) end end @@ -220,7 +220,8 @@ def create_sampling_message( stop_sequences: nil, metadata: nil, tools: nil, - tool_choice: nil + tool_choice: nil, + related_request_id: nil ) unless @transport raise "Cannot send sampling request without a transport." @@ -371,7 +372,7 @@ def schema_contains_ref?(schema) end end - def handle_request(request, method, session: nil) + def handle_request(request, method, session: nil, related_request_id: nil) handler = @handlers[method] unless handler instrument_call("unsupported_method") do @@ -399,7 +400,7 @@ def handle_request(request, method, session: nil) when Methods::RESOURCES_TEMPLATES_LIST { resourceTemplates: @handlers[Methods::RESOURCES_TEMPLATES_LIST].call(params) } when Methods::TOOLS_CALL - call_tool(params, session: session) + call_tool(params, session: session, related_request_id: related_request_id) when Methods::COMPLETION_COMPLETE complete(params) when Methods::LOGGING_SET_LEVEL @@ -499,7 +500,7 @@ def list_tools(request) @tools.values.map(&:to_h) end - def call_tool(request, session: nil) + def call_tool(request, session: nil, related_request_id: nil) tool_name = request[:name] tool = tools[tool_name] @@ -531,7 +532,7 @@ def call_tool(request, session: nil) progress_token = request.dig(:_meta, :progressToken) - call_tool_with_args(tool, arguments, server_context_with_meta(request), progress_token: progress_token, session: session) + call_tool_with_args(tool, arguments, server_context_with_meta(request), progress_token: progress_token, session: session, related_request_id: related_request_id) rescue RequestHandlerError raise rescue => e @@ -611,12 +612,12 @@ def accepts_server_context?(method_object) parameters.any? { |type, name| type == :keyrest || name == :server_context } end - def call_tool_with_args(tool, arguments, context, progress_token: nil, session: nil) + def call_tool_with_args(tool, arguments, context, progress_token: nil, session: nil, related_request_id: nil) args = arguments&.transform_keys(&:to_sym) || {} if accepts_server_context?(tool.method(:call)) - progress = Progress.new(notification_target: session, progress_token: progress_token) - server_context = ServerContext.new(context, progress: progress, notification_target: session) + progress = Progress.new(notification_target: session, progress_token: progress_token, related_request_id: related_request_id) + server_context = ServerContext.new(context, progress: progress, notification_target: session, related_request_id: related_request_id) tool.call(**args, server_context: server_context).to_h else tool.call(**args).to_h diff --git a/lib/mcp/server/transports/streamable_http_transport.rb b/lib/mcp/server/transports/streamable_http_transport.rb index 31ddc896..688e38c5 100644 --- a/lib/mcp/server/transports/streamable_http_transport.rb +++ b/lib/mcp/server/transports/streamable_http_transport.rb @@ -7,6 +7,12 @@ module MCP class Server module Transports class StreamableHTTPTransport < Transport + SSE_HEADERS = { + "Content-Type" => "text/event-stream", + "Cache-Control" => "no-cache", + "Connection" => "keep-alive", + }.freeze + def initialize(server, stateless: false, session_idle_timeout: nil) super(server) # Maps `session_id` to `{ stream: stream_object, server_session: ServerSession, last_active_at: float_from_monotonic_clock }`. @@ -56,10 +62,11 @@ def close removed_sessions.each do |session| close_stream_safely(session[:stream]) + close_post_request_streams(session) end end - def send_notification(method, params = nil, session_id: nil) + def send_notification(method, params = nil, session_id: nil, related_request_id: nil) # Stateless mode doesn't support notifications raise "Stateless mode does not support notifications" if @stateless @@ -74,8 +81,10 @@ def send_notification(method, params = nil, session_id: nil) result = @mutex.synchronize do if session_id # Send to specific session - session = @sessions[session_id] - next false unless session && session[:stream] + if (session = @sessions[session_id]) + stream = active_stream(session, related_request_id: related_request_id) + end + next false unless stream if session_expired?(session) cleanup_and_collect_stream(session_id, streams_to_close) @@ -83,14 +92,19 @@ def send_notification(method, params = nil, session_id: nil) end begin - send_to_stream(session[:stream], notification) + send_to_stream(stream, notification) true rescue *STREAM_WRITE_ERRORS => e MCP.configuration.exception_reporter.call( e, { session_id: session_id, error: "Failed to send notification" }, ) - cleanup_and_collect_stream(session_id, streams_to_close) + if related_request_id && session[:post_request_streams]&.key?(related_request_id) + session[:post_request_streams].delete(related_request_id) + streams_to_close << stream + else + cleanup_and_collect_stream(session_id, streams_to_close) + end false end else @@ -99,7 +113,7 @@ def send_notification(method, params = nil, session_id: nil) failed_sessions = [] @sessions.each do |sid, session| - next unless session[:stream] + next unless (stream = session[:stream]) if session_expired?(session) failed_sessions << sid @@ -107,7 +121,7 @@ def send_notification(method, params = nil, session_id: nil) end begin - send_to_stream(session[:stream], notification) + send_to_stream(stream, notification) sent_count += 1 rescue *STREAM_WRITE_ERRORS => e MCP.configuration.exception_reporter.call( @@ -139,7 +153,7 @@ def send_notification(method, params = nil, session_id: nil) # sends the request via SSE stream, then blocks on `queue.pop`. # When the client POSTs a response, `handle_response` matches it by `request_id` # and pushes the result onto the queue, unblocking this thread. - def send_request(method, params = nil, session_id: nil) + def send_request(method, params = nil, session_id: nil, related_request_id: nil) if @stateless raise "Stateless mode does not support server-to-client requests." end @@ -163,12 +177,17 @@ def send_request(method, params = nil, session_id: nil) @pending_responses[request_id] = { queue: queue, session_id: session_id } - if (stream = session[:stream]) + if (stream = active_stream(session, related_request_id: related_request_id)) begin send_to_stream(stream, request) sent = true rescue *STREAM_WRITE_ERRORS - cleanup_session_unsafe(session_id) + if related_request_id && session[:post_request_streams]&.key?(related_request_id) + session[:post_request_streams].delete(related_request_id) + close_stream_safely(stream) + else + cleanup_session_unsafe(session_id) + end end end end @@ -181,7 +200,7 @@ def send_request(method, params = nil, session_id: nil) # The TypeScript and Python SDKs buffer messages and replay on reconnect. # Until then, raise to prevent queue.pop from blocking indefinitely. unless sent - raise "No active SSE stream for #{method} request." + raise "No active stream for #{method} request." end response = queue.pop @@ -229,6 +248,7 @@ def reap_expired_sessions removed_sessions.each do |session| close_stream_safely(session[:stream]) + close_post_request_streams(session) end end @@ -265,7 +285,7 @@ def handle_post(request) handle_response(body, session_id: session_id) else - handle_regular_request(body_string, session_id) + handle_regular_request(body_string, session_id, related_request_id: body[:id]) end end rescue StandardError => e @@ -313,7 +333,10 @@ def cleanup_session(session_id) cleanup_session_unsafe(session_id) end - close_stream_safely(session[:stream]) if session + if session + close_stream_safely(session[:stream]) + close_post_request_streams(session) + end end # Removes a session from `@sessions` and returns it. Does not close the stream. @@ -336,6 +359,7 @@ def cleanup_and_collect_stream(session_id, streams_to_close) return unless (removed = cleanup_session_unsafe(session_id)) streams_to_close << removed[:stream] + removed[:post_request_streams]&.each_value { |stream| streams_to_close << stream } end def close_stream_safely(stream) @@ -344,6 +368,14 @@ def close_stream_safely(stream) # Ignore close-related errors from already closed/broken streams. end + def close_post_request_streams(session) + return unless (post_request_streams = session[:post_request_streams]) + + post_request_streams.each_value do |stream| + close_stream_safely(stream) + end + end + def extract_session_id(request) request.env["HTTP_MCP_SESSION_ID"] end @@ -443,9 +475,8 @@ def handle_accepted [202, {}, []] end - def handle_regular_request(body_string, session_id) + def handle_regular_request(body_string, session_id, related_request_id: nil) server_session = nil - stream = nil unless @stateless if session_id @@ -455,21 +486,72 @@ def handle_regular_request(body_string, session_id) @mutex.synchronize do session = @sessions[session_id] server_session = session[:server_session] if session - stream = session[:stream] if session end end end - response = if server_session - server_session.handle_json(body_string) + if session_id && !@stateless + handle_request_with_sse_response(body_string, session_id, server_session, related_request_id: related_request_id) else - @server.handle_json(body_string) + response = dispatch_handle_json(body_string, server_session) + [200, { "Content-Type" => "application/json" }, [response]] end + end + + # Returns the POST response as an SSE stream so the server can send + # JSON-RPC requests and notifications during request processing. + # https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server + def handle_request_with_sse_response(body_string, session_id, server_session, related_request_id: nil) + body = proc do |stream| + @mutex.synchronize do + session = @sessions[session_id] + if session && related_request_id + session[:post_request_streams] ||= {} + session[:post_request_streams][related_request_id] = stream + end + end - if stream - send_response_to_stream(stream, response, session_id) + begin + response = dispatch_handle_json(body_string, server_session) + + send_to_stream(stream, response) if response + ensure + if related_request_id + @mutex.synchronize do + session = @sessions[session_id] + session[:post_request_streams]&.delete(related_request_id) if session + end + end + + begin + stream.close + rescue StandardError + # Ignore close-related errors from already closed/broken streams. + end + end + end + + [200, SSE_HEADERS, body] + end + + # Returns the SSE stream available for server-to-client messages. + # When `related_request_id` is given, returns only the POST response + # stream for that request (no fallback to GET SSE). This prevents + # request-scoped messages from leaking to the wrong stream. + # When `related_request_id` is nil, returns the GET SSE stream. + def active_stream(session, related_request_id: nil) + if related_request_id + session.dig(:post_request_streams, related_request_id) else - [200, { "Content-Type" => "application/json" }, [response]] + session[:stream] + end + end + + def dispatch_handle_json(body_string, server_session) + if server_session + server_session.handle_json(body_string) + else + @server.handle_json(body_string) end end @@ -489,7 +571,13 @@ def validate_and_touch_session(session_id) nil end - close_stream_safely(removed[:stream]) if removed + if removed + close_stream_safely(removed[:stream]) + + removed[:post_request_streams]&.each_value do |stream| + close_stream_safely(stream) + end + end response end @@ -498,19 +586,6 @@ def get_session_stream(session_id) @mutex.synchronize { @sessions[session_id]&.fetch(:stream, nil) } end - def send_response_to_stream(stream, response, session_id) - message = JSON.parse(response) - send_to_stream(stream, message) - handle_accepted - rescue *STREAM_WRITE_ERRORS => e - MCP.configuration.exception_reporter.call( - e, - { session_id: session_id, error: "Stream closed during response" }, - ) - cleanup_session(session_id) - [200, { "Content-Type" => "application/json" }, [response]] - end - def session_exists?(session_id) @mutex.synchronize { @sessions.key?(session_id) } end @@ -538,13 +613,7 @@ def session_already_connected_response def setup_sse_stream(session_id) body = create_sse_body(session_id) - headers = { - "Content-Type" => "text/event-stream", - "Cache-Control" => "no-cache", - "Connection" => "keep-alive", - } - - [200, headers, body] + [200, SSE_HEADERS, body] end def create_sse_body(session_id) diff --git a/lib/mcp/server_context.rb b/lib/mcp/server_context.rb index b532555b..aadd7505 100644 --- a/lib/mcp/server_context.rb +++ b/lib/mcp/server_context.rb @@ -2,10 +2,11 @@ module MCP class ServerContext - def initialize(context, progress:, notification_target:) + def initialize(context, progress:, notification_target:, related_request_id: nil) @context = context @progress = progress @notification_target = notification_target + @related_request_id = related_request_id end # Reports progress for the current tool operation. @@ -26,7 +27,7 @@ def report_progress(progress, total: nil, message: nil) def notify_log_message(data:, level:, logger: nil) return unless @notification_target - @notification_target.notify_log_message(data: data, level: level, logger: logger) + @notification_target.notify_log_message(data: data, level: level, logger: logger, related_request_id: @related_request_id) end # Delegates to the session so the request is scoped to the originating client. @@ -34,9 +35,9 @@ def notify_log_message(data:, level:, logger: nil) # does not support sampling. def create_sampling_message(**kwargs) if @notification_target.respond_to?(:create_sampling_message) - @notification_target.create_sampling_message(**kwargs) + @notification_target.create_sampling_message(**kwargs, related_request_id: @related_request_id) elsif @context.respond_to?(:create_sampling_message) - @context.create_sampling_message(**kwargs) + @context.create_sampling_message(**kwargs, related_request_id: @related_request_id) else raise NoMethodError, "undefined method 'create_sampling_message' for #{self}" end diff --git a/lib/mcp/server_session.rb b/lib/mcp/server_session.rb index 93e823fb..2fe8f77a 100644 --- a/lib/mcp/server_session.rb +++ b/lib/mcp/server_session.rb @@ -42,13 +42,13 @@ def client_capabilities end # Sends a `sampling/createMessage` request scoped to this session. - def create_sampling_message(**kwargs) + def create_sampling_message(related_request_id: nil, **kwargs) params = @server.build_sampling_params(client_capabilities, **kwargs) - send_to_transport_request(Methods::SAMPLING_CREATE_MESSAGE, params) + send_to_transport_request(Methods::SAMPLING_CREATE_MESSAGE, params, related_request_id: related_request_id) end # Sends a progress notification to this session only. - def notify_progress(progress_token:, progress:, total: nil, message: nil) + def notify_progress(progress_token:, progress:, total: nil, message: nil, related_request_id: nil) params = { "progressToken" => progress_token, "progress" => progress, @@ -56,20 +56,20 @@ def notify_progress(progress_token:, progress:, total: nil, message: nil) "message" => message, }.compact - send_to_transport(Methods::NOTIFICATIONS_PROGRESS, params) + send_to_transport(Methods::NOTIFICATIONS_PROGRESS, params, related_request_id: related_request_id) rescue => e @server.report_exception(e, notification: "progress") end # Sends a log message notification to this session only. - def notify_log_message(data:, level:, logger: nil) + def notify_log_message(data:, level:, logger: nil, related_request_id: nil) effective_logging = @logging_message_notification || @server.logging_message_notification return unless effective_logging&.should_notify?(level) params = { "data" => data, "level" => level } params["logger"] = logger if logger - send_to_transport(Methods::NOTIFICATIONS_MESSAGE, params) + send_to_transport(Methods::NOTIFICATIONS_MESSAGE, params, related_request_id: related_request_id) rescue => e @server.report_exception(e, { notification: "log_message" }) end @@ -82,9 +82,9 @@ def notify_log_message(data:, level:, logger: nil) # TODO: When Ruby 2.7 support is dropped, replace with a direct call: # `@transport.send_notification(method, params, session_id: @session_id)` and # add `**` to `Transport#send_notification` and `StdioTransport#send_notification`. - def send_to_transport(method, params) + def send_to_transport(method, params, related_request_id: nil) if @session_id - @transport.send_notification(method, params, session_id: @session_id) + @transport.send_notification(method, params, session_id: @session_id, related_request_id: related_request_id) else @transport.send_notification(method, params) end @@ -96,9 +96,9 @@ def send_to_transport(method, params) # TODO: When Ruby 2.7 support is dropped, replace with a direct call: # `@transport.send_request(method, params, session_id: @session_id)` and # add `**` to `Transport#send_request` and `StdioTransport#send_request`. - def send_to_transport_request(method, params) + def send_to_transport_request(method, params, related_request_id: nil) if @session_id - @transport.send_request(method, params, session_id: @session_id) + @transport.send_request(method, params, session_id: @session_id, related_request_id: related_request_id) else @transport.send_request(method, params) end diff --git a/test/json_rpc_handler_test.rb b/test/json_rpc_handler_test.rb index 67f47dba..169b1fd8 100644 --- a/test/json_rpc_handler_test.rb +++ b/test/json_rpc_handler_test.rb @@ -621,7 +621,7 @@ @response = JsonRpcHandler.handle( { jsonrpc: "2.0", id: "user@example.com", method: "add", params: { a: 1, b: 2 } }, id_validation_pattern: custom_pattern, - ) { |method_name| @registry[method_name] } + ) { |method_name, _request_id| @registry[method_name] } assert_rpc_success expected_result: 3 assert_equal "user@example.com", @response[:id] @@ -633,7 +633,7 @@ @response = JsonRpcHandler.handle( { jsonrpc: "2.0", id: "id", method: "add", params: { a: 1, b: 2 } }, id_validation_pattern: nil, - ) { |method_name| @registry[method_name] } + ) { |method_name, _request_id| @registry[method_name] } assert_rpc_success expected_result: 3 assert_equal "", @response[:id] @@ -733,11 +733,11 @@ def register(method_name, &block) end def handle(request) - @response = JsonRpcHandler.handle(request) { |method_name| @registry[method_name] } + @response = JsonRpcHandler.handle(request) { |method_name, _request_id| @registry[method_name] } end def handle_json(request_json) - @response_json = JsonRpcHandler.handle_json(request_json) { |method_name| @registry[method_name] } + @response_json = JsonRpcHandler.handle_json(request_json) { |method_name, _request_id| @registry[method_name] } @response = JSON.parse(@response_json, symbolize_names: true) if @response_json end diff --git a/test/mcp/server/transports/streamable_http_transport_test.rb b/test/mcp/server/transports/streamable_http_transport_test.rb index 507ac45b..7435289e 100644 --- a/test/mcp/server/transports/streamable_http_transport_test.rb +++ b/test/mcp/server/transports/streamable_http_transport_test.rb @@ -7,6 +7,31 @@ module MCP class Server module Transports class StreamableHTTPTransportTest < ActiveSupport::TestCase + # A stream that buffers writes and remains readable after close. + class TestStream + def initialize + @buffer = "".dup + @closed = false + end + + def write(data) + raise IOError, "closed stream" if @closed + + @buffer << data + end + + def flush + end + + def close + @closed = true + end + + def string + @buffer + end + end + setup do @server = Server.new( name: "test_server", @@ -45,9 +70,11 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase response = @transport.handle_request(request) assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) + assert_equal "text/event-stream", response[1]["Content-Type"] - body = JSON.parse(response[2][0]) + io = StringIO.new + response[2].call(io) + body = JSON.parse(io.string.match(/^data: (.+)$/)[1]) assert_equal "2.0", body["jsonrpc"] assert_equal "123", body["id"] assert_equal({}, body["result"]) @@ -114,8 +141,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert response[2].is_a?(Proc) # The body should be a Proc for streaming end - test "handles POST request when IOError raised" do - # Create and initialize a session + test "handles POST request as SSE even when GET SSE stream is closed" do init_request = create_rack_request( "POST", "/", @@ -125,7 +151,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Connect with SSE + # Connect with SSE then close it io = StringIO.new get_request = create_rack_request( "GET", @@ -134,13 +160,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(io) if response[2].is_a?(Proc) - - # Give the stream time to set up sleep(0.1) - - # Close the stream io.close + # POST request should still return SSE response via POST response stream request = create_rack_request( "POST", "/", @@ -151,17 +174,12 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, ) - # This should handle IOError and return the original response response = @transport.handle_request(request) assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) - - # Verify session was cleaned up - assert_not @transport.instance_variable_get(:@sessions).key?(session_id) + assert_equal "text/event-stream", response[1]["Content-Type"] end - test "handles POST request when Errno::EPIPE raised" do - # Create and initialize a session + test "handles POST request as SSE even when GET SSE stream has EPIPE" do init_request = create_rack_request( "POST", "/", @@ -171,10 +189,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Create a pipe to simulate EPIPE condition + # Connect GET SSE with a broken pipe reader, writer = IO.pipe - - # Connect with SSE using the writer end of the pipe get_request = create_rack_request( "GET", "/", @@ -182,13 +198,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(writer) if response[2].is_a?(Proc) - - # Give the stream time to set up sleep(0.1) - - # Close the reader end to break the pipe - this will cause EPIPE on write reader.close + # POST request should still return SSE response via POST response stream request = create_rack_request( "POST", "/", @@ -199,23 +212,18 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "789" }.to_json, ) - # This should handle Errno::EPIPE and return the original response response = @transport.handle_request(request) - assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) - - # Verify session was cleaned up - assert_not @transport.instance_variable_get(:@sessions).key?(session_id) - + assert_equal(200, response[0]) + assert_equal("text/event-stream", response[1]["Content-Type"]) + ensure begin writer.close - rescue + rescue StandardError nil end end - test "handles POST request when Errno::ECONNRESET raised" do - # Create and initialize a session. + test "handles POST request as SSE even when GET SSE stream has ECONNRESET" do init_request = create_rack_request( "POST", "/", @@ -225,12 +233,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Use a mock stream that raises Errno::ECONNRESET on write. + # Connect GET SSE with a mock that raises ECONNRESET mock_stream = Object.new mock_stream.define_singleton_method(:write) { |_data| raise Errno::ECONNRESET } mock_stream.define_singleton_method(:close) {} - - # Connect with SSE using the mock stream. get_request = create_rack_request( "GET", "/", @@ -238,10 +244,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(mock_stream) if response[2].is_a?(Proc) - - # Give the stream time to set up. sleep(0.1) + # POST request should still return SSE response via POST response stream request = create_rack_request( "POST", "/", @@ -252,13 +257,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "789" }.to_json, ) - # This should handle Errno::ECONNRESET and return the original response. response = @transport.handle_request(request) assert_equal 200, response[0] - assert_equal({ "Content-Type" => "application/json" }, response[1]) - - # Verify session was cleaned up. - assert_not @transport.instance_variable_get(:@sessions).key?(session_id) + assert_equal "text/event-stream", response[1]["Content-Type"] end test "handles GET request with missing session ID" do @@ -579,6 +580,54 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_equal({}, @transport.instance_variable_get(:@sessions)) end + test "cleanup_session_unsafe closes request_streams" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Simulate multiple request_streams being set on the session. + closed = [] + 2.times do |i| + mock_stream = Object.new + mock_stream.define_singleton_method(:close) { closed << i } + thread = Thread.new {} + thread.join + @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams] ||= {} + @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams][thread] = mock_stream + end + + delete_request = create_rack_request( + "DELETE", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + @transport.handle_request(delete_request) + + assert_equal [0, 1], closed.sort + assert_empty @transport.instance_variable_get(:@sessions) + end + + test "broadcast notification skips sessions without GET SSE stream" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + @transport.handle_request(init_request) + + # No GET SSE stream connected, only request_streams. + # Pass **{} to prevent Ruby 2.7 from converting the Hash to keyword arguments. + result = @transport.send_notification("test/notify", { message: "hello" }, **{}) + + assert_equal 0, result + end + test "sends notification to correct session with multiple active sessions" do # Create first session init_request1 = create_rack_request( @@ -653,8 +702,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase result end - # Handle request from session 1 - @transport.handle_request(request_as_session1) + # Handle request from session 1 (execute SSE proc) + response1 = @transport.handle_request(request_as_session1) + response1[2].call(StringIO.new) if response1[2].is_a?(Proc) # Make a request as session 2 request_as_session2 = create_rack_request( @@ -667,18 +717,17 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase { jsonrpc: "2.0", method: "ping", id: "890" }.to_json, ) - # Handle request from session 2 - @transport.handle_request(request_as_session2) + # Handle request from session 2 (execute SSE proc) + response2_post = @transport.handle_request(request_as_session2) + response2_post[2].call(StringIO.new) if response2_post[2].is_a?(Proc) - # Check that each session received one notification + # Broadcast notifications are sent to GET SSE streams (no related_request_id) io1.rewind output1 = io1.read - # Session 1 should have received two notifications (one from each request since we broadcast) assert_equal 2, output1.scan(/data: {"jsonrpc":"2.0","method":"test_notification","params":{"session":"current"}}/).count io2.rewind output2 = io2.read - # Session 2 should have received two notifications (one from each request since we broadcast) assert_equal 2, output2.scan(/data: {"jsonrpc":"2.0","method":"test_notification","params":{"session":"current"}}/).count end @@ -888,6 +937,85 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_not @transport.instance_variable_get(:@sessions).key?(session_id) end + test "send_notification on broken request_stream removes only that stream, not the session" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect GET SSE. + io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = @transport.handle_request(get_request) + response[2].call(io) if response[2].is_a?(Proc) + sleep(0.1) + + # Simulate a broken request_stream. + broken_stream = Object.new + broken_stream.define_singleton_method(:write) { |_data| raise Errno::EPIPE } + broken_stream.define_singleton_method(:close) {} + related_id = "req-1" + @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams] = { related_id => broken_stream } + + result = @transport.send_notification("test", { msg: "hello" }, session_id: session_id, related_request_id: related_id) + + refute result + # Session should still exist. + assert @transport.instance_variable_get(:@sessions).key?(session_id) + # The broken request_stream should be removed. + refute @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams].key?(related_id) + # GET SSE stream should still be intact. + assert @transport.instance_variable_get(:@sessions)[session_id][:stream] + end + + test "active_stream does not fall back to GET SSE when related_request_id is given but request_stream is missing" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect GET SSE. + io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = @transport.handle_request(get_request) + response[2].call(io) if response[2].is_a?(Proc) + sleep(0.1) + + # Send notification with a related_request_id that has no matching request_stream. + result = @transport.send_notification( + "test/notify", + { message: "should not arrive" }, + session_id: session_id, + related_request_id: "nonexistent-request-id", + ) + + # Should return false because no matching request_stream exists. + refute result + + # Session should still exist (not cleaned up). + assert @transport.instance_variable_get(:@sessions).key?(session_id) + + # GET SSE stream should NOT have received the notification. + io.rewind + refute_includes io.read, "should not arrive" + end + test "send_notification broadcast continues when one session raises Errno::ECONNRESET" do # Create two sessions. init_request1 = create_rack_request( @@ -1334,8 +1462,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_nil(body) end - test "send_response_to_stream returns 202 when message is sent to stream" do - # Create and initialize a session + test "POST request returns SSE response even with GET SSE connected" do init_request = create_rack_request( "POST", "/", @@ -1345,7 +1472,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase init_response = @transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - # Connect with SSE + # Connect with GET SSE io = StringIO.new get_request = create_rack_request( "GET", @@ -1354,11 +1481,9 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(get_request) response[2].call(io) if response[2].is_a?(Proc) - - # Give the stream time to set up sleep(0.1) - # Make a regular request that will be routed through send_response_to_stream + # POST request should return SSE, not 202 request = create_rack_request( "POST", "/", @@ -1370,9 +1495,13 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase ) response = @transport.handle_request(request) - assert_equal 202, response[0] - assert_empty response[1] - assert_empty response[2] + assert_equal 200, response[0] + assert_equal "text/event-stream", response[1]["Content-Type"] + + post_io = StringIO.new + response[2].call(post_io) + body = JSON.parse(post_io.string.match(/^data: (.+)$/)[1]) + assert_equal "456", body["id"] end test "handle post request with a standard error" do @@ -1436,7 +1565,7 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase @transport.send_request("sampling/createMessage", { "messages" => [] }, session_id: session_id) end - assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + assert_equal("No active stream for sampling/createMessage request.", error.message) end test "send_request sends via SSE and waits for response" do @@ -1683,6 +1812,300 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_equal("SSE session closed while waiting for sampling/createMessage response.", error.message) end + test "send_request sends via POST response stream even with GET SSE connected" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Connect GET SSE. + get_io = StringIO.new + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + get_response = @transport.handle_request(get_request) + get_response[2].call(get_io) if get_response[2].is_a?(Proc) + sleep(0.1) + + # Set up sampling capability for the session. + @transport.instance_variable_get(:@sessions)[session_id][:server_session] + .store_client_info(client: { name: "test" }, capabilities: { sampling: {} }) + + # Define a tool that calls create_sampling_message. + sampling_tool = MCP::Tool.define( + name: "sampling_tool", + input_schema: { properties: { prompt: { type: "string" } }, required: ["prompt"] }, + ) do |prompt:, server_context:| + result = server_context.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: prompt } }], + max_tokens: 100, + ) + MCP::Tool::Response.new([{ type: "text", text: result[:content][:text] }]) + end + @server.tools[sampling_tool.name_value] = sampling_tool + + # Send tools/call via POST (GET SSE is connected). + tool_request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "sampling_tool", arguments: { prompt: "Hello" } }, + }.to_json, + ) + + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + response = @transport.handle_request(tool_request) + response[2].call(post_stream) + result_queue.push(:done) + end + + sleep(0.2) + + # Sampling request should be in POST response stream, not GET SSE. + output = post_stream.string + data_lines = output.lines.select { |line| line.start_with?("data: ") } + sampling_request = JSON.parse(data_lines.first.sub("data: ", "")) + assert_equal "sampling/createMessage", sampling_request["method"] + + # GET SSE should NOT have the sampling request. + get_io.rewind + refute_includes get_io.read, "sampling/createMessage" + + # Simulate client sending sampling result via POST. + client_response = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: sampling_request["id"], + result: { role: "assistant", content: { type: "text", text: "Hi from LLM" } }, + }.to_json, + ) + @transport.handle_request(client_response) + + result_queue.pop + + tool_response_lines = post_stream.string.lines.select { |line| line.start_with?("data: ") } + tool_response = JSON.parse(tool_response_lines.last.sub("data: ", "")) + assert_equal "tool-1", tool_response["id"] + assert_includes tool_response["result"]["content"].first["text"], "Hi from LLM" + end + + test "send_request sends via POST response stream when no GET SSE stream" do + # Create session without connecting GET SSE. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Set up sampling capability for the session. + @transport.instance_variable_get(:@sessions)[session_id][:server_session] + .store_client_info(client: { name: "test" }, capabilities: { sampling: {} }) + + # Define a tool that calls create_sampling_message. + sampling_tool = MCP::Tool.define( + name: "sampling_tool", + input_schema: { properties: { prompt: { type: "string" } }, required: ["prompt"] }, + ) do |prompt:, server_context:| + result = server_context.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: prompt } }], + max_tokens: 100, + ) + MCP::Tool::Response.new([{ type: "text", text: result[:content][:text] }]) + end + @server.tools[sampling_tool.name_value] = sampling_tool + + # Send tools/call via POST (no GET SSE stream). + tool_request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "sampling_tool", arguments: { prompt: "Hello" } }, + }.to_json, + ) + + # Process in background since handle_request blocks until tool completes. + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + response = @transport.handle_request(tool_request) + response[2].call(post_stream) + result_queue.push(:done) + end + + sleep(0.2) # Wait for the tool to start and send sampling request. + + # Read the sampling request from the POST response stream. + output = post_stream.string + data_lines = output.lines.select { |line| line.start_with?("data: ") } + sampling_request = JSON.parse(data_lines.first.sub("data: ", "")) + assert_equal "sampling/createMessage", sampling_request["method"] + + # Simulate client sending sampling result via POST. + client_response = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: sampling_request["id"], + result: { role: "assistant", content: { type: "text", text: "Hi from LLM" } }, + }.to_json, + ) + @transport.handle_request(client_response) + + result_queue.pop # Wait for tool to complete. + + # Verify the tool result was written to the POST response stream. + tool_response_lines = post_stream.string.lines.select { |line| line.start_with?("data: ") } + tool_response = JSON.parse(tool_response_lines.last.sub("data: ", "")) + assert_equal "tool-1", tool_response["id"] + assert_includes tool_response["result"]["content"].first["text"], "Hi from LLM" + end + + test "send_notification uses POST response stream when no GET SSE stream" do + # Create session without connecting GET SSE. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Define a tool that sends a notification during execution. + notification_sent = Queue.new + slow_tool = MCP::Tool.define( + name: "slow_tool", + ) do |server_context:| + server_context.notify_log_message(data: "test log", level: "info") + notification_sent.push(true) + MCP::Tool::Response.new([{ type: "text", text: "done" }]) + end + @server.tools[slow_tool.name_value] = slow_tool + + # Configure logging so notifications are sent. + @transport.instance_variable_get(:@sessions)[session_id][:server_session] + .configure_logging(MCP::LoggingMessageNotification.new(level: "debug")) + + # Send tools/call via POST (no GET SSE stream). + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "slow_tool", arguments: {} }, + }.to_json, + ) + response = @transport.handle_request(request) + response[2].call(post_stream) + result_queue.push(:done) + end + + notification_sent.pop # Wait for tool to send notification. + result_queue.pop + + # Verify notification was written to the POST response stream. + assert_includes post_stream.string, "notifications/message" + assert_includes post_stream.string, "test log" + end + + test "progress notification uses POST response stream when no GET SSE stream" do + # Create session without connecting GET SSE. + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Define a tool that reports progress during execution. + progress_reported = Queue.new + progress_tool = MCP::Tool.define( + name: "progress_tool", + ) do |server_context:| + server_context.report_progress(50, total: 100, message: "halfway") + progress_reported.push(true) + MCP::Tool::Response.new([{ type: "text", text: "done" }]) + end + @server.tools[progress_tool.name_value] = progress_tool + + # Send tools/call via POST (no GET SSE stream) with a progress token. + post_stream = TestStream.new + result_queue = Queue.new + Thread.new do + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { + jsonrpc: "2.0", + id: "tool-1", + method: "tools/call", + params: { name: "progress_tool", arguments: {}, _meta: { progressToken: "token-1" } }, + }.to_json, + ) + response = @transport.handle_request(request) + response[2].call(post_stream) + result_queue.push(:done) + end + + progress_reported.pop + result_queue.pop + + # Verify progress notification was written to the POST response stream. + assert_includes post_stream.string, "notifications/progress" + assert_includes post_stream.string, "token-1" + end + test "POST notifications/initialized returns 202 with no body" do # Create a session first (optional for notification, but keep consistent with flow) init_request = create_rack_request( @@ -2228,13 +2651,16 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase params: { name: "log_tool", arguments: {} }, }.to_json, ) - transport.handle_request(tool_request) + tool_response = transport.handle_request(tool_request) + post_io = StringIO.new + tool_response[2].call(post_io) - # Session 1 should receive the log notification. - io1.rewind - assert_includes io1.read, "secret" + # Session 1's POST response stream should contain the log notification. + assert_includes post_io.string, "secret" - # Session 2 should NOT receive the log notification. + # GET SSE streams should NOT receive the log notification. + io1.rewind + refute_includes io1.read, "secret" io2.rewind refute_includes io2.read, "secret" end @@ -2306,13 +2732,16 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase }, }.to_json, ) - transport.handle_request(tool_request) + tool_response = transport.handle_request(tool_request) + post_io = StringIO.new + tool_response[2].call(post_io) - # Session 1 should receive the progress notification. - io1.rewind - assert_includes io1.read, "halfway" + # Session 1's POST response stream should contain the progress notification. + assert_includes post_io.string, "halfway" - # Session 2 should NOT receive the progress notification. + # GET SSE streams should NOT receive the progress notification. + io1.rewind + refute_includes io1.read, "halfway" io2.rewind refute_includes io2.read, "halfway" end @@ -2406,7 +2835,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase params: { level: "error" }, }.to_json, ) - transport.handle_request(set_level1) + response1 = transport.handle_request(set_level1) + response1[2].call(StringIO.new) # Session 2 sets log level to "debug". set_level2 = create_rack_request( @@ -2420,7 +2850,8 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase params: { level: "debug" }, }.to_json, ) - transport.handle_request(set_level2) + response2 = transport.handle_request(set_level2) + response2[2].call(StringIO.new) # Session 1 (error level) should not notify for "info", but should for "error". session1_logging = transport.instance_variable_get(:@sessions)[session1][:server_session].logging_message_notification diff --git a/test/mcp/server_context_test.rb b/test/mcp/server_context_test.rb index 605e3852..81735d61 100644 --- a/test/mcp/server_context_test.rb +++ b/test/mcp/server_context_test.rb @@ -46,6 +46,7 @@ class ServerContextTest < ActiveSupport::TestCase notification_target.expects(:create_sampling_message).with( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, + related_request_id: nil, ).returns({ role: "assistant", content: { type: "text", text: "Hi" } }) context = mock @@ -67,6 +68,7 @@ class ServerContextTest < ActiveSupport::TestCase context.expects(:create_sampling_message).with( messages: [{ role: "user", content: { type: "text", text: "Hello" } }], max_tokens: 100, + related_request_id: nil, ).returns({ role: "assistant", content: { type: "text", text: "Fallback" } }) progress = Progress.new(notification_target: notification_target, progress_token: nil) diff --git a/test/mcp/server_sampling_test.rb b/test/mcp/server_sampling_test.rb index 57c488dc..bf250e4f 100644 --- a/test/mcp/server_sampling_test.rb +++ b/test/mcp/server_sampling_test.rb @@ -260,7 +260,7 @@ def close; end max_tokens: 100, ) end - assert_equal("No active SSE stream for sampling/createMessage request.", error_with_sampling.message) + assert_equal("No active stream for sampling/createMessage request.", error_with_sampling.message) # Session without sampling capability should be rejected. session_without_sampling = ServerSession.new(server: @server, transport: transport, session_id: "s2") @@ -290,7 +290,7 @@ def close; end max_tokens: 100, ) end - assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + assert_equal("No active stream for sampling/createMessage request.", error.message) end test "session init does not overwrite server global client_capabilities" do @@ -375,7 +375,18 @@ def close; end max_tokens: 100, ) end - assert_equal("No active SSE stream for sampling/createMessage request.", error.message) + assert_equal("No active stream for sampling/createMessage request.", error.message) + end + + test "Server#create_sampling_message accepts related_request_id without error" do + @server.create_sampling_message( + messages: [{ role: "user", content: { type: "text", text: "Hello" } }], + max_tokens: 100, + related_request_id: "req-1", + ) + + request = @mock_transport.requests.first + assert_equal "sampling/createMessage", request[:method] end test "create_sampling_message omits nil optional params" do