diff --git a/build.gradle.kts b/build.gradle.kts index 4fe64dd1a..72c55539b 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -24,7 +24,7 @@ subprojects { apply(plugin = "org.jlleitschuh.gradle.ktlint") apply(plugin = "org.jetbrains.kotlinx.kover") - if (name != "conformance-test") { + if (name != "conformance-test" && name != "docs") { apply(plugin = "dev.detekt") detekt { diff --git a/conformance-test/README.md b/conformance-test/README.md index 0f08d751f..58d0d1a3a 100644 --- a/conformance-test/README.md +++ b/conformance-test/README.md @@ -110,7 +110,7 @@ Tests the conformance server against all server scenarios: ## Known SDK Limitations -9 scenarios are expected to fail due to current SDK limitations (tracked in [ +8 scenarios are expected to fail due to current SDK limitations (tracked in [ `conformance-baseline.yml`](conformance-baseline.yml). | Scenario | Suite | Root Cause | @@ -123,6 +123,5 @@ Tests the conformance server against all server scenarios: | `elicitation-sep1330-enums` | server | *(same as above)* | | `resources-templates-read` | server | SDK does not implement `addResourceTemplate()` with URI pattern matching; resources are looked up by exact URI | | `elicitation-sep1034-client-defaults` | client | SDK does not fill in `default` values from the elicitation request schema before sending the response | -| `sse-retry` | client | Transport does not respect the SSE `retry` field timing or send `Last-Event-ID` on reconnection | These failures reveal SDK gaps and are intentionally not fixed in this module. diff --git a/conformance-test/conformance-baseline.yml b/conformance-test/conformance-baseline.yml index 9126f0d34..cc06a389b 100644 --- a/conformance-test/conformance-baseline.yml +++ b/conformance-test/conformance-baseline.yml @@ -11,4 +11,3 @@ server: client: - elicitation-sep1034-client-defaults - - sse-retry diff --git a/kotlin-sdk-client/api/kotlin-sdk-client.api b/kotlin-sdk-client/api/kotlin-sdk-client.api index 268a2b8df..da0e752a7 100644 --- a/kotlin-sdk-client/api/kotlin-sdk-client.api +++ b/kotlin-sdk-client/api/kotlin-sdk-client.api @@ -62,6 +62,18 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/KtorClientKt { public static synthetic fun mcpSseTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/SseClientTransport; } +public final class io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions { + public synthetic fun (JJDIILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (JJDILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun equals (Ljava/lang/Object;)Z + public final fun getInitialReconnectionDelay-UwyO8pc ()J + public final fun getMaxReconnectionDelay-UwyO8pc ()J + public final fun getMaxRetries ()I + public final fun getReconnectionDelayMultiplier ()D + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + public final class io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport { public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -88,6 +100,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTranspor } public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport { + public fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getProtocolVersion ()Ljava/lang/String; @@ -106,8 +120,12 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpError } public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensionsKt { + public static final fun mcpStreamableHttp (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public static final fun mcpStreamableHttp-BZiP2OM (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun mcpStreamableHttp-BZiP2OM$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static final fun mcpStreamableHttpTransport (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; + public static synthetic fun mcpStreamableHttpTransport$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; public static final fun mcpStreamableHttpTransport-5_5nbZA (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; public static synthetic fun mcpStreamableHttpTransport-5_5nbZA$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport; } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions.kt new file mode 100644 index 000000000..95c5bdefa --- /dev/null +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions.kt @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +/** + * Options for controlling SSE reconnection behavior. + * + * @property initialReconnectionDelay The initial delay before the first reconnection attempt. + * @property maxReconnectionDelay The maximum delay between reconnection attempts. + * @property reconnectionDelayMultiplier The factor by which the delay grows on each attempt. + * @property maxRetries The maximum number of reconnection attempts per disconnect. + */ +public class ReconnectionOptions( + public val initialReconnectionDelay: Duration = 1.seconds, + public val maxReconnectionDelay: Duration = 30.seconds, + public val reconnectionDelayMultiplier: Double = 1.5, + public val maxRetries: Int = 2, +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as ReconnectionOptions + + if (reconnectionDelayMultiplier != other.reconnectionDelayMultiplier) return false + if (maxRetries != other.maxRetries) return false + if (initialReconnectionDelay != other.initialReconnectionDelay) return false + if (maxReconnectionDelay != other.maxReconnectionDelay) return false + + return true + } + + override fun hashCode(): Int { + var result = reconnectionDelayMultiplier.hashCode() + result = 31 * result + maxRetries + result = 31 * result + initialReconnectionDelay.hashCode() + result = 31 * result + maxReconnectionDelay.hashCode() + return result + } + + override fun toString(): String = + "ReconnectionOptions(initialReconnectionDelay=$initialReconnectionDelay, maxReconnectionDelay=$maxReconnectionDelay, reconnectionDelayMultiplier=$reconnectionDelayMultiplier, maxRetries=$maxRetries)" +} diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index 8526042cf..80f37aedc 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -38,9 +38,13 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.delay +import kotlinx.coroutines.isActive import kotlinx.coroutines.launch -import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.math.pow import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds private const val MCP_SESSION_ID_HEADER = "mcp-session-id" private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" @@ -52,31 +56,58 @@ private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" public class StreamableHttpError(public val code: Int? = null, message: String? = null) : Exception("Streamable HTTP error: $message") +private sealed interface ConnectResult { + data class Success(val session: ClientSSESession) : ConnectResult + data object NonRetryable : ConnectResult + data object Failed : ConnectResult +} + /** * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It will connect to a server using HTTP POST for sending messages and HTTP GET with Server-Sent Events * for receiving messages. */ -@OptIn(ExperimentalAtomicApi::class) +@Suppress("TooManyFunctions") public class StreamableHttpClientTransport( private val client: HttpClient, private val url: String, - private val reconnectionTime: Duration? = null, + private val reconnectionOptions: ReconnectionOptions = ReconnectionOptions(), private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, ) : AbstractClientTransport() { + @Deprecated( + "Use constructor with ReconnectionOptions", + replaceWith = ReplaceWith( + "StreamableHttpClientTransport(client, url, " + + "ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder)", + "kotlin.time.Duration.Companion.seconds", + "io.modelcontextprotocol.kotlin.sdk.client.ReconnectionOptions", + ), + ) + public constructor( + client: HttpClient, + url: String, + reconnectionTime: Duration?, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, + ) : this(client, url, ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder) + override val logger: KLogger = KotlinLogging.logger {} public var sessionId: String? = null private set public var protocolVersion: String? = null - private var sseSession: ClientSSESession? = null private var sseJob: Job? = null private val scope by lazy { CoroutineScope(SupervisorJob() + Dispatchers.Default) } - private var lastEventId: String? = null + /** Result of an SSE stream collection. Reconnect when [hasPrimingEvent] is true and [receivedResponse] is false. */ + private data class SseStreamResult( + val hasPrimingEvent: Boolean, + val receivedResponse: Boolean, + val lastEventId: String? = null, + val serverRetryDelay: Duration? = null, + ) override suspend fun initialize() { logger.debug { "Client transport is starting..." } @@ -85,7 +116,7 @@ public class StreamableHttpClientTransport( /** * Sends a single message with optional resumption support */ - @Suppress("ReturnCount", "CyclomaticComplexMethod") + @Suppress("ReturnCount", "CyclomaticComplexMethod", "LongMethod", "TooGenericExceptionCaught", "ThrowsCount") override suspend fun performSend(message: JSONRPCMessage, options: TransportSendOptions?) { logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" } @@ -133,18 +164,25 @@ public class StreamableHttpClientTransport( } } - ContentType.Text.EventStream -> handleInlineSse( - response, - onResumptionToken = options?.onResumptionToken, - replayMessageId = if (message is JSONRPCRequest) message.id else null, - ) + ContentType.Text.EventStream -> { + val replayMessageId = if (message is JSONRPCRequest) message.id else null + val result = handleInlineSse(response, replayMessageId, options?.onResumptionToken) + if (result.hasPrimingEvent && !result.receivedResponse) { + startSseSession( + resumptionToken = result.lastEventId, + replayMessageId = replayMessageId, + onResumptionToken = options?.onResumptionToken, + initialServerRetryDelay = result.serverRetryDelay, + ) + } + } else -> { val body = response.bodyAsText() if (response.contentType() == null && body.isBlank()) return val ct = response.contentType()?.toString() ?: "" - val error = StreamableHttpError(-1, "Unexpected content type: $$ct") + val error = StreamableHttpError(-1, "Unexpected content type: $ct") _onError(error) throw error } @@ -169,11 +207,6 @@ public class StreamableHttpClientTransport( override suspend fun closeResources() { logger.debug { "Client transport closing." } - - // Try to terminate session if we have one - terminateSession() - - sseSession?.cancel() sseJob?.cancelAndJoin() scope.cancel() } @@ -201,55 +234,120 @@ public class StreamableHttpClientTransport( } sessionId = null - lastEventId = null logger.debug { "Session terminated successfully" } } - private suspend fun startSseSession( + private fun startSseSession( resumptionToken: String? = null, replayMessageId: RequestId? = null, onResumptionToken: ((String) -> Unit)? = null, + initialServerRetryDelay: Duration? = null, ) { - sseSession?.cancel() - sseJob?.cancelAndJoin() + // Cancel-and-replace: cancel() signals the previous job, join() inside + // the new coroutine ensures it completes before we start collecting. + // This is intentionally non-suspend to avoid blocking performSend. + val previousJob = sseJob + previousJob?.cancel() + sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { + previousJob?.join() + var lastEventId = resumptionToken + var serverRetryDelay = initialServerRetryDelay + var attempt = 0 + var needsDelay = initialServerRetryDelay != null + + @Suppress("LoopWithTooManyJumpStatements") + while (isActive) { + // Delay before (re)connection: skip only for first fresh SSE connection + if (needsDelay) { + delay(getNextReconnectionDelay(attempt, serverRetryDelay)) + } + needsDelay = true + + // Connect + val session = when (val cr = connectSse(lastEventId)) { + is ConnectResult.Success -> { + attempt = 0 + cr.session + } + + ConnectResult.NonRetryable -> return@launch + + ConnectResult.Failed -> { + // Give up after maxRetries consecutive failed connection attempts + if (++attempt >= reconnectionOptions.maxRetries) { + _onError(StreamableHttpError(null, "Maximum reconnection attempts exceeded")) + return@launch + } + continue + } + } + // Collect + val result = collectSse(session, replayMessageId, onResumptionToken) + lastEventId = result.lastEventId ?: lastEventId + serverRetryDelay = result.serverRetryDelay ?: serverRetryDelay + if (result.receivedResponse) break + } + } + } + + @Suppress("TooGenericExceptionCaught") + private suspend fun connectSse(lastEventId: String?): ConnectResult { logger.debug { "Client attempting to start SSE session at url: $url" } - try { - sseSession = client.sseSession( - urlString = url, - reconnectionTime = reconnectionTime, - ) { + return try { + val session = client.sseSession(urlString = url, showRetryEvents = true) { method = HttpMethod.Get applyCommonHeaders(this) - // sseSession will add ContentType.Text.EventStream automatically accept(ContentType.Application.Json) - (resumptionToken ?: lastEventId)?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) } + lastEventId?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) } requestBuilder() } logger.debug { "Client SSE session started successfully." } + ConnectResult.Success(session) + } catch (e: CancellationException) { + throw e } catch (e: SSEClientException) { - val responseStatus = e.response?.status - val responseContentType = e.response?.contentType() + if (isNonRetryableSseError(e)) { + ConnectResult.NonRetryable + } else { + logger.debug { "SSE connection failed: ${e.message}" } + ConnectResult.Failed + } + } catch (e: Exception) { + logger.debug { "SSE connection failed: ${e.message}" } + ConnectResult.Failed + } + } + + private fun getNextReconnectionDelay(attempt: Int, serverRetryDelay: Duration?): Duration { + // Per SSE specification, the server-sent `retry` field sets the reconnection time + // for all subsequent attempts, taking priority over exponential backoff. + serverRetryDelay?.let { return it } + val delay = reconnectionOptions.initialReconnectionDelay * + reconnectionOptions.reconnectionDelayMultiplier.pow(attempt) + return delay.coerceAtMost(reconnectionOptions.maxReconnectionDelay) + } + + /** + * Checks if an SSE session error is non-retryable (404, 405, JSON-only). + * Returns `true` if non-retryable (should stop trying), `false` otherwise. + */ + private fun isNonRetryableSseError(e: SSEClientException): Boolean { + val responseStatus = e.response?.status + val responseContentType = e.response?.contentType() - // 404 or 405 means server doesn't support SSE at GET endpoint - this is expected and valid - if (responseStatus == HttpStatusCode.NotFound || responseStatus == HttpStatusCode.MethodNotAllowed) { + return when { + responseStatus == HttpStatusCode.NotFound || responseStatus == HttpStatusCode.MethodNotAllowed -> { logger.info { "Server returned ${responseStatus.value} for GET/SSE, stream disabled." } - return + true } - // If server returns application/json, it means it doesn't support SSE for this session - // This is valid per spec - server can choose to only use JSON responses - if (responseContentType?.match(ContentType.Application.Json) == true) { + responseContentType?.match(ContentType.Application.Json) == true -> { logger.info { "Server returned application/json for GET/SSE, using JSON-only mode." } - return + true } - _onError(e) - throw e - } - - sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) { - sseSession?.let { collectSse(it, replayMessageId, onResumptionToken) } + else -> false } } @@ -265,11 +363,17 @@ public class StreamableHttpClientTransport( session: ClientSSESession, replayMessageId: RequestId?, onResumptionToken: ((String) -> Unit)?, - ) { + ): SseStreamResult { + var hasPrimingEvent = false + var receivedResponse = false + var localLastEventId: String? = null + var localServerRetryDelay: Duration? = null try { session.incoming.collect { event -> + event.retry?.let { localServerRetryDelay = it.milliseconds } event.id?.let { - lastEventId = it + localLastEventId = it + hasPrimingEvent = true onResumptionToken?.invoke(it) } logger.trace { "Client received SSE event: event=${event.event}, data=${event.data}, id=${event.id}" } @@ -278,6 +382,7 @@ public class StreamableHttpClientTransport( event.data?.takeIf { it.isNotEmpty() }?.let { json -> runCatching { McpJson.decodeFromString(json) } .onSuccess { msg -> + if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { _onMessage(msg.copy(id = replayMessageId)) } else { @@ -295,6 +400,7 @@ public class StreamableHttpClientTransport( } catch (t: Throwable) { _onError(t) } + return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } @Suppress("CyclomaticComplexMethod") @@ -302,17 +408,22 @@ public class StreamableHttpClientTransport( response: HttpResponse, replayMessageId: RequestId?, onResumptionToken: ((String) -> Unit)?, - ) { + ): SseStreamResult { logger.trace { "Handling inline SSE from POST response" } val channel = response.bodyAsChannel() + var hasPrimingEvent = false + var receivedResponse = false + var localLastEventId: String? = null + var localServerRetryDelay: Duration? = null val sb = StringBuilder() var id: String? = null var eventName: String? = null suspend fun dispatch(id: String?, eventName: String?, data: String) { id?.let { - lastEventId = it + localLastEventId = it + hasPrimingEvent = true onResumptionToken?.invoke(it) } if (data.isBlank()) { @@ -321,6 +432,7 @@ public class StreamableHttpClientTransport( if (eventName == null || eventName == "message") { runCatching { McpJson.decodeFromString(data) } .onSuccess { msg -> + if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { _onMessage(msg.copy(id = replayMessageId)) } else { @@ -351,9 +463,16 @@ public class StreamableHttpClientTransport( } when { line.startsWith("id:") -> id = line.substringAfter("id:").trim() + line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() + line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim()) + + line.startsWith("retry:") -> line.substringAfter("retry:").trim().toLongOrNull()?.let { + localServerRetryDelay = it.milliseconds + } } } + return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt index b64a22062..a618f1823 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpMcpKtorClientExtensions.kt @@ -6,21 +6,67 @@ import io.modelcontextprotocol.kotlin.sdk.LIB_VERSION import io.modelcontextprotocol.kotlin.sdk.shared.IMPLEMENTATION_NAME import io.modelcontextprotocol.kotlin.sdk.types.Implementation import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds /** * Returns a new Streamable HTTP transport for the Model Context Protocol using the provided HttpClient. * * @param url URL of the MCP server. - * @param reconnectionTime Optional duration to wait before attempting to reconnect. + * @param reconnectionOptions Options for controlling SSE reconnection behavior. * @param requestBuilder Optional lambda to configure the HTTP request. * @return A [StreamableHttpClientTransport] configured for MCP communication. */ public fun HttpClient.mcpStreamableHttpTransport( url: String, - reconnectionTime: Duration? = null, + reconnectionOptions: ReconnectionOptions = ReconnectionOptions(), requestBuilder: HttpRequestBuilder.() -> Unit = {}, ): StreamableHttpClientTransport = - StreamableHttpClientTransport(this, url, reconnectionTime, requestBuilder = requestBuilder) + StreamableHttpClientTransport(this, url, reconnectionOptions, requestBuilder = requestBuilder) + +/** + * Returns a new Streamable HTTP transport for the Model Context Protocol using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param reconnectionTime Optional duration to wait before attempting to reconnect. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A [StreamableHttpClientTransport] configured for MCP communication. + */ +@Deprecated( + "Use overload with ReconnectionOptions", + replaceWith = ReplaceWith( + "mcpStreamableHttpTransport(url, " + + "ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder)", + ), +) +public fun HttpClient.mcpStreamableHttpTransport( + url: String, + reconnectionTime: Duration?, + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): StreamableHttpClientTransport = StreamableHttpClientTransport( + this, + url, + ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), + requestBuilder = requestBuilder, +) + +/** + * Creates and connects an MCP client over Streamable HTTP using the provided HttpClient. + * + * @param url URL of the MCP server. + * @param reconnectionOptions Options for controlling SSE reconnection behavior. + * @param requestBuilder Optional lambda to configure the HTTP request. + * @return A connected [Client] ready for MCP communication. + */ +public suspend fun HttpClient.mcpStreamableHttp( + url: String, + reconnectionOptions: ReconnectionOptions = ReconnectionOptions(), + requestBuilder: HttpRequestBuilder.() -> Unit = {}, +): Client { + val transport = mcpStreamableHttpTransport(url, reconnectionOptions, requestBuilder) + val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) + client.connect(transport) + return client +} /** * Creates and connects an MCP client over Streamable HTTP using the provided HttpClient. @@ -30,12 +76,23 @@ public fun HttpClient.mcpStreamableHttpTransport( * @param requestBuilder Optional lambda to configure the HTTP request. * @return A connected [Client] ready for MCP communication. */ +@Deprecated( + "Use overload with ReconnectionOptions", + replaceWith = ReplaceWith( + "mcpStreamableHttp(url, " + + "ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder)", + ), +) public suspend fun HttpClient.mcpStreamableHttp( url: String, - reconnectionTime: Duration? = null, + reconnectionTime: Duration?, requestBuilder: HttpRequestBuilder.() -> Unit = {}, ): Client { - val transport = mcpStreamableHttpTransport(url, reconnectionTime, requestBuilder) + val transport = mcpStreamableHttpTransport( + url, + ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), + requestBuilder, + ) val client = Client(Implementation(name = IMPLEMENTATION_NAME, version = LIB_VERSION)) client.connect(transport) return client diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt index e303326ff..55df3e06e 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt @@ -23,9 +23,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.Implementation import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay @@ -657,6 +659,203 @@ class StreamableHttpClientTransportTest { receivedErrors shouldHaveSize 0 } + @Test + fun testInlineSseRetryParsing() = runTest { + val transport = createTransport { request -> + if (request.method == HttpMethod.Post) { + val sseContent = buildString { + appendLine("retry: 5000") + appendLine("id: ev-1") + appendLine("event: message") + appendLine("""data: {"jsonrpc":"2.0","id":"req-1","result":{"tools":[]}}""") + appendLine() + } + + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } else { + respond("", HttpStatusCode.OK) + } + } + + val receivedMessages = mutableListOf() + val responseReceived = CompletableDeferred() + + transport.onMessage { message -> + receivedMessages.add(message) + if (message is JSONRPCResponse && !responseReceived.isCompleted) { + responseReceived.complete(Unit) + } + } + + transport.start() + + transport.send( + JSONRPCRequest( + id = "req-1", + method = "test", + params = buildJsonObject { }, + ), + ) + + eventually { + responseReceived.await() + } + + receivedMessages shouldHaveSize 1 + val response = receivedMessages[0] as JSONRPCResponse + response.id shouldBe RequestId.StringId("req-1") + + transport.close() + } + + @Test + fun testInlineSseHasPrimingEventTracking() = runTest { + val transport = createTransport { request -> + if (request.method == HttpMethod.Post) { + val sseContent = buildString { + // Event with id = priming event + appendLine("id: priming-1") + appendLine("event: message") + appendLine( + """data: {"jsonrpc":"2.0","method":"notifications/progress",""" + + """"params":{"progressToken":"t1","progress":50}}""", + ) + appendLine() + // Notification without id + appendLine("event: message") + appendLine("""data: {"jsonrpc":"2.0","method":"notifications/tools/list_changed"}""") + appendLine() + } + + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } else { + respond("", HttpStatusCode.OK) + } + } + + val receivedMessages = mutableListOf() + val twoMessagesReceived = CompletableDeferred() + + transport.onMessage { message -> + receivedMessages.add(message) + if (receivedMessages.size >= 2 && !twoMessagesReceived.isCompleted) { + twoMessagesReceived.complete(Unit) + } + } + + transport.start() + + transport.send( + JSONRPCRequest( + id = "test-1", + method = "test", + params = buildJsonObject { }, + ), + ) + + eventually { + twoMessagesReceived.await() + } + + receivedMessages shouldHaveSize 2 + // Both should be notifications (no JSONRPCResponse → POST-to-GET reconnect would be triggered) + receivedMessages[0].shouldBeInstanceOf() + receivedMessages[1].shouldBeInstanceOf() + + transport.close() + } + + @Test + fun testInlineSseResponseStopsReconnection() = runTest { + val transport = createTransport { request -> + if (request.method == HttpMethod.Post) { + val sseContent = buildString { + appendLine("id: ev-1") + appendLine("event: message") + appendLine("""data: {"jsonrpc":"2.0","id":"req-1","result":{"tools":[]}}""") + appendLine() + } + + respond( + content = ByteReadChannel(sseContent), + status = HttpStatusCode.OK, + headers = headersOf( + HttpHeaders.ContentType, + ContentType.Text.EventStream.toString(), + ), + ) + } else { + respond("", HttpStatusCode.OK) + } + } + + val receivedMessages = mutableListOf() + val responseReceived = CompletableDeferred() + + transport.onMessage { message -> + receivedMessages.add(message) + if (message is JSONRPCResponse && !responseReceived.isCompleted) { + responseReceived.complete(Unit) + } + } + + transport.start() + + transport.send( + JSONRPCRequest( + id = "req-1", + method = "tools/list", + params = buildJsonObject { }, + ), + ) + + eventually { + responseReceived.await() + } + + receivedMessages shouldHaveSize 1 + // Response received → no reconnection triggered (hasPrimingEvent=true, receivedResponse=true) + val response = receivedMessages[0] as JSONRPCResponse + response.id shouldBe RequestId.StringId("req-1") + + transport.close() + } + + @Suppress("DEPRECATION") + @Test + fun testDeprecatedConstructorStillWorks() = runTest { + val mockEngine = MockEngine { _ -> + respond( + content = "", + status = HttpStatusCode.Accepted, + ) + } + val httpClient = HttpClient(mockEngine) { + install(SSE) + } + + val transport = + StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp", reconnectionTime = 2.seconds) + + transport.start() + transport.send(JSONRPCNotification(method = "test")) + transport.close() + } + private suspend fun setupTransportAndCollectMessages( transport: StreamableHttpClientTransport, ): Pair, MutableList> { diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt index 6513b469b..a1fe89e66 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTest.kt @@ -1,6 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.client import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.shouldBe import io.ktor.http.ContentType import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode @@ -203,7 +204,31 @@ internal class StreamableHttpClientTest : AbstractStreamableHttpClientTest() { meta = EmptyJsonObject, ) + client.close() + } + + @Test + fun `terminateSession sends DELETE request`() = runBlocking { + val client = Client( + clientInfo = Implementation(name = "client1", version = "1.0.0"), + options = ClientOptions(capabilities = ClientCapabilities()), + ) + val sessionId = Uuid.random().toString() + + mockMcp.onInitialize(clientName = "client1", sessionId = sessionId) + mockMcp.handleJSONRPCRequest( + jsonRpcMethod = "notifications/initialized", + expectedSessionId = sessionId, + sessionId = sessionId, + statusCode = HttpStatusCode.Accepted, + ) + mockMcp.handleSubscribeWithGet(sessionId) { emptyFlow() } + + connect(client) + mockMcp.mockUnsubscribeRequest(sessionId = sessionId) + (client.transport as StreamableHttpClientTransport).terminateSession() + (client.transport as StreamableHttpClientTransport).sessionId shouldBe null client.close() } @@ -257,8 +282,6 @@ internal class StreamableHttpClientTest : AbstractStreamableHttpClientTest() { buildJsonObject {} } - mockMcp.mockUnsubscribeRequest(sessionId = sessionId) - connect(client) delay(1.seconds) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 580fcc0e5..9074c4c26 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -45,6 +45,7 @@ internal const val MCP_SESSION_ID_HEADER = "mcp-session-id" private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB +private const val MIN_PRIMING_EVENT_PROTOCOL_VERSION = "2025-11-25" /** * A holder for an active request call. @@ -388,7 +389,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat if (!configuration.enableJsonResponse) { call.appendSseHeaders() flushSse(session) // flush headers immediately - maybeSendPrimingEvent(streamId, session) + maybeSendPrimingEvent(streamId, session, call.request.header(MCP_PROTOCOL_VERSION_HEADER)) } streamMutex.withLock { @@ -451,7 +452,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat call.appendSseHeaders() flushSse(sseSession) // flush headers immediately streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(sseSession, call) - maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession) + maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession, call.request.header(MCP_PROTOCOL_VERSION_HEADER)) sseSession.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(STANDALONE_SSE_STREAM_ID) } @@ -702,12 +703,20 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } @Suppress("TooGenericExceptionCaught") - private suspend fun maybeSendPrimingEvent(streamId: String, session: ServerSSESession?) { - val store = configuration.eventStore ?: return - val sseSession = session ?: return + private suspend fun maybeSendPrimingEvent( + streamId: String, + session: ServerSSESession?, + clientProtocolVersion: String? = null, + ) { + val store = configuration.eventStore + if (store == null || session == null) return + // Priming events have empty data which older clients cannot handle. + // Only send priming events to clients with protocol version >= 2025-11-25 + // which includes the fix for handling empty SSE data. + if (clientProtocolVersion != null && clientProtocolVersion < MIN_PRIMING_EVENT_PROTOCOL_VERSION) return try { val primingEventId = store.storeEvent(streamId, JSONRPCEmptyMessage) - sseSession.send( + session.send( id = primingEventId, retry = configuration.retryInterval?.inWholeMilliseconds, data = "",