Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integration-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ kotlin {
implementation(libs.ktor.server.cio)
implementation(libs.ktor.server.websockets)
implementation(libs.ktor.server.test.host)
implementation(libs.ktor.server.content.negotiation)
implementation(libs.ktor.serialization)
}
}
jvmTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,27 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin
import io.ktor.client.HttpClient
import io.ktor.client.engine.cio.CIO
import io.ktor.client.plugins.sse.SSE
import io.ktor.serialization.kotlinx.json.json
import io.ktor.server.application.install
import io.ktor.server.engine.EmbeddedServer
import io.ktor.server.engine.embeddedServer
import io.ktor.server.plugins.contentnegotiation.ContentNegotiation
import io.ktor.server.routing.delete
import io.ktor.server.routing.get
import io.ktor.server.routing.post
import io.ktor.server.routing.route
import io.ktor.server.routing.routing
import io.modelcontextprotocol.kotlin.sdk.client.Client
import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport
import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport
import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport
import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport
import io.modelcontextprotocol.kotlin.sdk.server.mcp
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.test.utils.Retry
import kotlinx.coroutines.runBlocking
Expand Down Expand Up @@ -44,21 +53,24 @@ abstract class KotlinTestBase {
protected lateinit var serverEngine: EmbeddedServer<*, *>

// Transport selection
protected enum class TransportKind { SSE, STDIO }
protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP }
protected open val transportKind: TransportKind = TransportKind.STDIO

// STDIO-specific fields
private var stdioServerTransport: StdioServerTransport? = null
private var stdioClientInput: Source? = null
private var stdioClientOutput: Sink? = null

// StreamableHTTP-specific fields
private var streamableHttpServerTransport: StreamableHttpServerTransport? = null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a heads-up: this should probably be in the streamable test subclass, but since we already have transport-specific fields here, let’s tackle it separately.


protected abstract fun configureServerCapabilities(): ServerCapabilities
protected abstract fun configureServer()

@BeforeEach
fun setUp() {
setupServer()
if (transportKind == TransportKind.SSE) {
if (transportKind == TransportKind.SSE || transportKind == TransportKind.STREAMABLE_HTTP) {
await
.ignoreExceptions()
.until {
Expand Down Expand Up @@ -98,6 +110,19 @@ abstract class KotlinTestBase {
)
client.connect(transport)
}

TransportKind.STREAMABLE_HTTP -> {
val transport = StreamableHttpClientTransport(
client = HttpClient(CIO) {
install(SSE)
},
url = "http://$host:$port/mcp",
)
client = Client(
Implementation("test", "1.0"),
)
client.connect(transport)
}
}
}

Expand All @@ -111,35 +136,77 @@ abstract class KotlinTestBase {

configureServer()

if (transportKind == TransportKind.SSE) {
serverEngine = embeddedServer(ServerCIO, host = host, port = port) {
install(ServerSSE)
routing {
mcp { server }
when (transportKind) {
TransportKind.SSE -> {
serverEngine = embeddedServer(ServerCIO, host = host, port = port) {
install(ServerSSE)
routing {
mcp { server }
}
}.start(wait = false)
}

TransportKind.STREAMABLE_HTTP -> {
// Create StreamableHTTP server transport
// Using JSON response mode for simpler testing (no SSE session required)
val transport = StreamableHttpServerTransport(
enableJsonResponse = true, // Use JSON response mode for testing
)
// Use stateless mode to skip session validation for simpler testing
transport.setSessionIdGenerator(null)
streamableHttpServerTransport = transport

// IMPORTANT: Create server session BEFORE starting the HTTP server
// This ensures message handlers are set up before any requests come in
runBlocking {
server.createSession(transport)
}

// Start embedded server with routing for StreamableHTTP
serverEngine = embeddedServer(ServerCIO, host = host, port = port) {
// Install ContentNegotiation for JSON serialization
install(ContentNegotiation) {
json(McpJson)
}
routing {
route("/mcp") {
post {
transport.handlePostRequest(null, call)
}
get {
transport.handleGetRequest(null, call)
}
delete {
transport.handleDeleteRequest(call)
}
}
}
}.start(wait = false)
}

TransportKind.STDIO -> {
// Create in-memory stdio pipes: client->server and server->client
val clientToServerOut = PipedOutputStream()
val clientToServerIn = PipedInputStream(clientToServerOut)

val serverToClientOut = PipedOutputStream()
val serverToClientIn = PipedInputStream(serverToClientOut)

// Server transport reads from client and writes to client
val serverTransport = StdioServerTransport(
inputStream = clientToServerIn.asSource().buffered(),
outputStream = serverToClientOut.asSink().buffered(),
)
stdioServerTransport = serverTransport

// Prepare client-side streams for later client initialization
stdioClientInput = serverToClientIn.asSource().buffered()
stdioClientOutput = clientToServerOut.asSink().buffered()

// Start server transport by connecting the server
runBlocking {
server.createSession(serverTransport)
}
}.start(wait = false)
} else {
// Create in-memory stdio pipes: client->server and server->client
val clientToServerOut = PipedOutputStream()
val clientToServerIn = PipedInputStream(clientToServerOut)

val serverToClientOut = PipedOutputStream()
val serverToClientIn = PipedInputStream(serverToClientOut)

// Server transport reads from client and writes to client
val serverTransport = StdioServerTransport(
inputStream = clientToServerIn.asSource().buffered(),
outputStream = serverToClientOut.asSink().buffered(),
)
stdioServerTransport = serverTransport

// Prepare client-side streams for later client initialization
stdioClientInput = serverToClientIn.asSource().buffered()
stdioClientOutput = clientToServerOut.asSink().buffered()

// Start server transport by connecting the server
runBlocking {
server.createSession(serverTransport)
}
}
}
Expand All @@ -160,24 +227,39 @@ abstract class KotlinTestBase {
}

// stop server
if (transportKind == TransportKind.SSE) {
if (::serverEngine.isInitialized) {
try {
serverEngine.stop(500, 1000)
} catch (e: Exception) {
println("Warning: Error during server stop: ${e.message}")
when (transportKind) {
TransportKind.SSE, TransportKind.STREAMABLE_HTTP -> {
if (::serverEngine.isInitialized) {
try {
serverEngine.stop(500, 1000)
} catch (e: Exception) {
println("Warning: Error during server stop: ${e.message}")
}
}
if (transportKind == TransportKind.STREAMABLE_HTTP) {
streamableHttpServerTransport?.let {
try {
runBlocking { it.close() }
} catch (e: Exception) {
println("Warning: Error during streamable http server stop: ${e.message}")
} finally {
streamableHttpServerTransport = null
}
}
}
}
} else {
stdioServerTransport?.let {
try {
runBlocking { it.close() }
} catch (e: Exception) {
println("Warning: Error during stdio server stop: ${e.message}")
} finally {
stdioServerTransport = null
stdioClientInput = null
stdioClientOutput = null

TransportKind.STDIO -> {
stdioServerTransport?.let {
try {
runBlocking { it.close() }
} catch (e: Exception) {
println("Warning: Error during stdio server stop: ${e.message}")
} finally {
stdioServerTransport = null
stdioClientInput = null
stdioClientOutput = null
}
}
}
}
Expand Down
Loading
Loading