diff --git a/integration-test/build.gradle.kts b/integration-test/build.gradle.kts index d5b8ef81..1b763edc 100644 --- a/integration-test/build.gradle.kts +++ b/integration-test/build.gradle.kts @@ -21,6 +21,8 @@ kotlin { implementation(libs.ktor.server.cio) implementation(libs.ktor.server.websockets) implementation(libs.ktor.server.test.host) + implementation(libs.ktor.server.content.negotiation) + implementation(libs.ktor.serialization) } } jvmTest { diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt index 43b11176..1a8ffe0d 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -3,18 +3,27 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin import io.ktor.client.HttpClient import io.ktor.client.engine.cio.CIO import io.ktor.client.plugins.sse.SSE +import io.ktor.serialization.kotlinx.json.json import io.ktor.server.application.install import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.ktor.server.routing.delete +import io.ktor.server.routing.get +import io.ktor.server.routing.post +import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport +import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport import io.modelcontextprotocol.kotlin.sdk.server.mcp import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.test.utils.Retry import kotlinx.coroutines.runBlocking @@ -44,7 +53,7 @@ abstract class KotlinTestBase { protected lateinit var serverEngine: EmbeddedServer<*, *> // Transport selection - protected enum class TransportKind { SSE, STDIO } + protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP } protected open val transportKind: TransportKind = TransportKind.STDIO // STDIO-specific fields @@ -52,13 +61,16 @@ abstract class KotlinTestBase { private var stdioClientInput: Source? = null private var stdioClientOutput: Sink? = null + // StreamableHTTP-specific fields + private var streamableHttpServerTransport: StreamableHttpServerTransport? = null + protected abstract fun configureServerCapabilities(): ServerCapabilities protected abstract fun configureServer() @BeforeEach fun setUp() { setupServer() - if (transportKind == TransportKind.SSE) { + if (transportKind == TransportKind.SSE || transportKind == TransportKind.STREAMABLE_HTTP) { await .ignoreExceptions() .until { @@ -98,6 +110,19 @@ abstract class KotlinTestBase { ) client.connect(transport) } + + TransportKind.STREAMABLE_HTTP -> { + val transport = StreamableHttpClientTransport( + client = HttpClient(CIO) { + install(SSE) + }, + url = "http://$host:$port/mcp", + ) + client = Client( + Implementation("test", "1.0"), + ) + client.connect(transport) + } } } @@ -111,35 +136,77 @@ abstract class KotlinTestBase { configureServer() - if (transportKind == TransportKind.SSE) { - serverEngine = embeddedServer(ServerCIO, host = host, port = port) { - install(ServerSSE) - routing { - mcp { server } + when (transportKind) { + TransportKind.SSE -> { + serverEngine = embeddedServer(ServerCIO, host = host, port = port) { + install(ServerSSE) + routing { + mcp { server } + } + }.start(wait = false) + } + + TransportKind.STREAMABLE_HTTP -> { + // Create StreamableHTTP server transport + // Using JSON response mode for simpler testing (no SSE session required) + val transport = StreamableHttpServerTransport( + enableJsonResponse = true, // Use JSON response mode for testing + ) + // Use stateless mode to skip session validation for simpler testing + transport.setSessionIdGenerator(null) + streamableHttpServerTransport = transport + + // IMPORTANT: Create server session BEFORE starting the HTTP server + // This ensures message handlers are set up before any requests come in + runBlocking { + server.createSession(transport) + } + + // Start embedded server with routing for StreamableHTTP + serverEngine = embeddedServer(ServerCIO, host = host, port = port) { + // Install ContentNegotiation for JSON serialization + install(ContentNegotiation) { + json(McpJson) + } + routing { + route("/mcp") { + post { + transport.handlePostRequest(null, call) + } + get { + transport.handleGetRequest(null, call) + } + delete { + transport.handleDeleteRequest(call) + } + } + } + }.start(wait = false) + } + + TransportKind.STDIO -> { + // Create in-memory stdio pipes: client->server and server->client + val clientToServerOut = PipedOutputStream() + val clientToServerIn = PipedInputStream(clientToServerOut) + + val serverToClientOut = PipedOutputStream() + val serverToClientIn = PipedInputStream(serverToClientOut) + + // Server transport reads from client and writes to client + val serverTransport = StdioServerTransport( + inputStream = clientToServerIn.asSource().buffered(), + outputStream = serverToClientOut.asSink().buffered(), + ) + stdioServerTransport = serverTransport + + // Prepare client-side streams for later client initialization + stdioClientInput = serverToClientIn.asSource().buffered() + stdioClientOutput = clientToServerOut.asSink().buffered() + + // Start server transport by connecting the server + runBlocking { + server.createSession(serverTransport) } - }.start(wait = false) - } else { - // Create in-memory stdio pipes: client->server and server->client - val clientToServerOut = PipedOutputStream() - val clientToServerIn = PipedInputStream(clientToServerOut) - - val serverToClientOut = PipedOutputStream() - val serverToClientIn = PipedInputStream(serverToClientOut) - - // Server transport reads from client and writes to client - val serverTransport = StdioServerTransport( - inputStream = clientToServerIn.asSource().buffered(), - outputStream = serverToClientOut.asSink().buffered(), - ) - stdioServerTransport = serverTransport - - // Prepare client-side streams for later client initialization - stdioClientInput = serverToClientIn.asSource().buffered() - stdioClientOutput = clientToServerOut.asSink().buffered() - - // Start server transport by connecting the server - runBlocking { - server.createSession(serverTransport) } } } @@ -160,24 +227,39 @@ abstract class KotlinTestBase { } // stop server - if (transportKind == TransportKind.SSE) { - if (::serverEngine.isInitialized) { - try { - serverEngine.stop(500, 1000) - } catch (e: Exception) { - println("Warning: Error during server stop: ${e.message}") + when (transportKind) { + TransportKind.SSE, TransportKind.STREAMABLE_HTTP -> { + if (::serverEngine.isInitialized) { + try { + serverEngine.stop(500, 1000) + } catch (e: Exception) { + println("Warning: Error during server stop: ${e.message}") + } + } + if (transportKind == TransportKind.STREAMABLE_HTTP) { + streamableHttpServerTransport?.let { + try { + runBlocking { it.close() } + } catch (e: Exception) { + println("Warning: Error during streamable http server stop: ${e.message}") + } finally { + streamableHttpServerTransport = null + } + } } } - } else { - stdioServerTransport?.let { - try { - runBlocking { it.close() } - } catch (e: Exception) { - println("Warning: Error during stdio server stop: ${e.message}") - } finally { - stdioServerTransport = null - stdioClientInput = null - stdioClientOutput = null + + TransportKind.STDIO -> { + stdioServerTransport?.let { + try { + runBlocking { it.close() } + } catch (e: Exception) { + println("Warning: Error during stdio server stop: ${e.message}") + } finally { + stdioServerTransport = null + stdioClientInput = null + stdioClientOutput = null + } } } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractPromptSecurityIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractPromptSecurityIntegrationTest.kt new file mode 100644 index 00000000..e953ea28 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractPromptSecurityIntegrationTest.kt @@ -0,0 +1,248 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.security + +import io.kotest.assertions.withClue +import io.kotest.matchers.shouldBe +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.KotlinTestBase +import io.modelcontextprotocol.kotlin.sdk.integration.utils.AuthorizationRules +import io.modelcontextprotocol.kotlin.sdk.integration.utils.MockAuthorizationWrapper +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractPromptSecurityIntegrationTest : KotlinTestBase() { + + private val publicPromptName = "public-prompt" + private val secretPromptName = "secret-prompt" + private val restrictedPromptName = "restricted-prompt" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts( + listChanged = true, + ), + ) + + override fun configureServer() { + configureServerWithAuthorization( + allowedPrompts = setOf(publicPromptName, restrictedPromptName), + ) + } + + protected fun configureServerWithAuthorization( + allowedPrompts: Set? = null, + deniedPrompts: Set? = null, + ) { + val authWrapper = MockAuthorizationWrapper( + AuthorizationRules( + allowedPrompts = allowedPrompts, + deniedPrompts = deniedPrompts, + ), + ) + + server.addPrompt( + name = publicPromptName, + description = "A public prompt that authorized users can access", + arguments = listOf( + PromptArgument( + name = "query", + description = "The query to process", + required = false, + ), + ), + ) { request -> + if (!authWrapper.isAllowed("prompts", "get", mapOf("name" to publicPromptName))) { + throw authWrapper.createDeniedError("Access denied to prompt: $publicPromptName") + } + + val query = request.params.arguments?.get("query") ?: "default query" + + GetPromptResult( + description = "Public prompt response", + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "Query: $query"), + ), + ), + ) + } + + server.addPrompt( + name = secretPromptName, + description = "A secret prompt that requires special permissions", + arguments = listOf(), + ) { request -> + if (!authWrapper.isAllowed("prompts", "get", mapOf("name" to secretPromptName))) { + throw authWrapper.createDeniedError("Access denied to prompt: $secretPromptName") + } + + GetPromptResult( + description = "Secret prompt response", + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "This is secret information"), + ), + ), + ) + } + + server.addPrompt( + name = restrictedPromptName, + description = "A restricted prompt that some users can access", + arguments = listOf(), + ) { request -> + if (!authWrapper.isAllowed("prompts", "get", mapOf("name" to restrictedPromptName))) { + throw authWrapper.createDeniedError("Access denied to prompt: $restrictedPromptName") + } + + GetPromptResult( + description = "Restricted prompt response", + messages = listOf( + PromptMessage( + role = Role.User, + content = TextContent(text = "This is restricted information"), + ), + ), + ) + } + } + + @Test + fun testListPromptsAllowed() = runBlocking { + val result = client.listPrompts() + + assertNotNull(result, "List prompts result should not be null") + assertTrue(result.prompts.isNotEmpty(), "Prompts list should not be empty") + + val promptNames = result.prompts.map { it.name } + assertTrue(promptNames.contains(publicPromptName), "Public prompt should be listed") + assertTrue(promptNames.contains(secretPromptName), "Secret prompt should be listed") + assertTrue(promptNames.contains(restrictedPromptName), "Restricted prompt should be listed") + } + + @Test + fun testListPromptsDenied() { + runBlocking { + val result = client.listPrompts() + assertNotNull(result, "List prompts should succeed") + } + } + + @Test + fun testGetPromptAllowed() = runBlocking { + val result = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = publicPromptName, + arguments = mapOf("query" to "test query"), + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals("Public prompt response", result.description) + assertTrue(result.messages.isNotEmpty(), "Messages should not be empty") + + val userMessage = result.messages.first() + assertEquals(Role.User, userMessage.role) + val content = userMessage.content as TextContent + assertTrue(content.text.contains("test query"), "Response should contain the query") + } + + @Test + fun testGetPromptDenied() { + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = secretPromptName, + arguments = emptyMap(), + ), + ), + ) + } + } + + withClue("Exception message should mention access denied") { + exception.message?.lowercase()?.contains("access denied") shouldBe true + } + } + + @Test + fun testGetPromptPartialAccess(): Unit = runBlocking { + val publicResult = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = publicPromptName, + arguments = emptyMap(), + ), + ), + ) + assertNotNull(publicResult, "Public prompt should be accessible") + + val restrictedResult = client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = restrictedPromptName, + arguments = emptyMap(), + ), + ), + ) + assertNotNull(restrictedResult, "Restricted prompt should be accessible") + + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = secretPromptName, + arguments = emptyMap(), + ), + ), + ) + } + } + + withClue("Secret prompt should be denied") { + exception.message?.lowercase()?.contains("access denied") shouldBe true + } + } + + @Test + fun testUnauthorizedAfterInitialization(): Unit = runBlocking { + assertNotNull(client, "Client should be initialized") + + val listResult = client.listPrompts() + assertNotNull(listResult, "List prompts should succeed") + + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + GetPromptRequestParams( + name = secretPromptName, + arguments = emptyMap(), + ), + ), + ) + } + } + + withClue("Authorization should be checked on prompt access, not initialization") { + exception.message?.lowercase()?.contains("access denied") shouldBe true + } + } +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractResourceSecurityIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractResourceSecurityIntegrationTest.kt new file mode 100644 index 00000000..38a4030f --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractResourceSecurityIntegrationTest.kt @@ -0,0 +1,226 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.security + +import io.kotest.assertions.withClue +import io.kotest.matchers.shouldBe +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.KotlinTestBase +import io.modelcontextprotocol.kotlin.sdk.integration.utils.AuthorizationRules +import io.modelcontextprotocol.kotlin.sdk.integration.utils.MockAuthorizationWrapper +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractResourceSecurityIntegrationTest : KotlinTestBase() { + + private val publicResourceUri = "test://public-resource.txt" + private val secretResourceUri = "test://secret-resource.txt" + private val restrictedResourceUri = "test://restricted-resource.txt" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + subscribe = true, + listChanged = true, + ), + ) + + override fun configureServer() { + configureServerWithAuthorization( + allowedResources = setOf(publicResourceUri, restrictedResourceUri), + ) + } + + protected fun configureServerWithAuthorization( + allowedResources: Set? = null, + deniedResources: Set? = null, + ) { + val authWrapper = MockAuthorizationWrapper( + AuthorizationRules( + allowedResources = allowedResources, + deniedResources = deniedResources, + ), + ) + + server.addResource( + uri = publicResourceUri, + name = "Public Resource", + description = "A public resource that authorized users can access", + mimeType = "text/plain", + ) { request -> + if (!authWrapper.isAllowed("resources", "read", mapOf("uri" to publicResourceUri))) { + throw authWrapper.createDeniedError("Access denied to resource: $publicResourceUri") + } + + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "Public resource content", + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = secretResourceUri, + name = "Secret Resource", + description = "A secret resource that requires special permissions", + mimeType = "text/plain", + ) { request -> + if (!authWrapper.isAllowed("resources", "read", mapOf("uri" to secretResourceUri))) { + throw authWrapper.createDeniedError("Access denied to resource: $secretResourceUri") + } + + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "Secret resource content", + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = restrictedResourceUri, + name = "Restricted Resource", + description = "A restricted resource", + mimeType = "text/plain", + ) { request -> + if (!authWrapper.isAllowed("resources", "read", mapOf("uri" to restrictedResourceUri))) { + throw authWrapper.createDeniedError("Access denied to resource: $restrictedResourceUri") + } + + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "Restricted resource content", + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + } + + @Test + fun testListResourcesAllowed() = runBlocking { + val result = client.listResources() + + assertNotNull(result, "List resources result should not be null") + assertTrue(result.resources.isNotEmpty(), "Resources list should not be empty") + + val publicResource = result.resources.find { it.uri == publicResourceUri } + assertNotNull(publicResource, "Public resource should be in the list") + assertEquals("Public Resource", publicResource.name) + } + + @Test + fun testListResourcesDenied() { + runBlocking { + val result = client.listResources() + assertNotNull(result, "List should still work in default configuration") + } + } + + @Test + fun testReadResourceAllowed() = runBlocking { + val result = client.readResource( + ReadResourceRequest( + ReadResourceRequestParams( + uri = publicResourceUri, + ), + ), + ) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Contents should not be empty") + + val content = result.contents.first() as TextResourceContents + assertEquals("Public resource content", content.text) + assertEquals("text/plain", content.mimeType) + } + + @Test + fun testReadResourceDenied() { + val exception = assertThrows { + runBlocking { + client.readResource( + ReadResourceRequest( + ReadResourceRequestParams( + uri = secretResourceUri, + ), + ), + ) + } + } + + withClue("Exception message should mention access denied") { + exception.message?.lowercase()?.contains("access denied") shouldBe true + } + } + + @Test + fun testReadResourcePartialAccess(): Unit = runBlocking { + val publicResult = client.readResource( + ReadResourceRequest( + ReadResourceRequestParams( + uri = publicResourceUri, + ), + ), + ) + assertNotNull(publicResult, "Public resource should be accessible") + + val exception = assertThrows { + runBlocking { + client.readResource( + ReadResourceRequest( + ReadResourceRequestParams( + uri = secretResourceUri, + ), + ), + ) + } + } + withClue("Should be denied access to secret resource") { + exception.message?.lowercase()?.contains("access denied") shouldBe true + } + + val restrictedResult = client.readResource( + ReadResourceRequest( + ReadResourceRequestParams( + uri = restrictedResourceUri, + ), + ), + ) + assertNotNull(restrictedResult, "Restricted resource should be accessible with proper permissions") + } + + @Test + fun testUnauthorizedAfterInitialization(): Unit = runBlocking { + val exception = assertThrows { + runBlocking { + client.readResource( + ReadResourceRequest( + ReadResourceRequestParams( + uri = secretResourceUri, + ), + ), + ) + } + } + + withClue("Unauthorized operations should fail after successful initialization") { + exception.message?.lowercase()?.contains("access denied") shouldBe true + } + } +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractToolSecurityIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractToolSecurityIntegrationTest.kt new file mode 100644 index 00000000..8c1b09c3 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/AbstractToolSecurityIntegrationTest.kt @@ -0,0 +1,225 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.security + +import io.kotest.assertions.withClue +import io.kotest.matchers.shouldBe +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.KotlinTestBase +import io.modelcontextprotocol.kotlin.sdk.integration.utils.AuthorizationRules +import io.modelcontextprotocol.kotlin.sdk.integration.utils.MockAuthorizationWrapper +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractToolSecurityIntegrationTest : KotlinTestBase() { + + private val publicToolName = "public-tool" + private val secretToolName = "secret-tool" + private val restrictedToolName = "restricted-tool" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools( + listChanged = true, + ), + ) + + override fun configureServer() { + configureServerWithAuthorization( + allowedTools = setOf(publicToolName, restrictedToolName), + ) + } + + protected fun configureServerWithAuthorization( + allowedTools: Set? = null, + deniedTools: Set? = null, + ) { + val authWrapper = MockAuthorizationWrapper( + AuthorizationRules( + allowedTools = allowedTools, + deniedTools = deniedTools, + ), + ) + + server.addTool( + name = publicToolName, + description = "A public tool that authorized users can access", + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "Input text") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + if (!authWrapper.isAllowed("tools", "call", mapOf("name" to publicToolName))) { + throw authWrapper.createDeniedError("Access denied to tool: $publicToolName") + } + + CallToolResult( + content = listOf( + TextContent( + text = "Public tool result", + ), + ), + ) + } + + server.addTool( + name = secretToolName, + description = "A secret tool that requires special permissions", + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "Input text") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + if (!authWrapper.isAllowed("tools", "call", mapOf("name" to secretToolName))) { + throw authWrapper.createDeniedError("Access denied to tool: $secretToolName") + } + + CallToolResult( + content = listOf( + TextContent( + text = "Secret tool result", + ), + ), + ) + } + + server.addTool( + name = restrictedToolName, + description = "A restricted tool", + inputSchema = ToolSchema( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "Input text") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + if (!authWrapper.isAllowed("tools", "call", mapOf("name" to restrictedToolName))) { + throw authWrapper.createDeniedError("Access denied to tool: $restrictedToolName") + } + + CallToolResult( + content = listOf( + TextContent( + text = "Restricted tool result", + ), + ), + ) + } + } + + @Test + fun testListToolsAllowed() = runBlocking { + val result = client.listTools() + + assertNotNull(result, "List tools result should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + val publicTool = result.tools.find { it.name == publicToolName } + assertNotNull(publicTool, "Public tool should be in the list") + assertEquals("A public tool that authorized users can access", publicTool.description) + } + + @Test + fun testListToolsDenied() { + runBlocking { + val result = client.listTools() + assertNotNull(result, "List should still work in default configuration") + } + } + + @Test + fun testCallToolAllowed() = runBlocking { + val result = client.callTool(publicToolName, mapOf("text" to "test")) + + assertNotNull(result, "Call tool result should not be null") + assertTrue(result.content.isNotEmpty(), "Contents should not be empty") + + val content = result.content.first() as TextContent + assertEquals("Public tool result", content.text) + } + + @Test + fun testCallToolDenied() = runBlocking { + val result = client.callTool(secretToolName, mapOf("text" to "test")) + + assertNotNull(result, "Call tool result should not be null") + withClue("Tool call should have isError=true for denied access") { + result.isError shouldBe true + } + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Error content should be present in the result") + withClue("Error message should mention access denied") { + textContent.text.lowercase().contains("access denied") shouldBe true + } + } + + @Test + fun testCallToolPartialAccess() = runBlocking { + val publicResult = client.callTool(publicToolName, mapOf("text" to "test")) + assertNotNull(publicResult, "Public tool should be accessible") + withClue("Public tool should succeed without error") { + (publicResult.isError ?: false) shouldBe false + } + + val secretResult = client.callTool(secretToolName, mapOf("text" to "test")) + assertNotNull(secretResult, "Secret tool call should return a result") + withClue("Secret tool should have isError=true for denied access") { + secretResult.isError shouldBe true + } + val secretTextContent = secretResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(secretTextContent, "Error content should be present") + withClue("Should be denied access to secret tool") { + secretTextContent.text.lowercase().contains("access denied") shouldBe true + } + + val restrictedResult = client.callTool(restrictedToolName, mapOf("text" to "test")) + assertNotNull(restrictedResult, "Restricted tool should be accessible with proper permissions") + withClue("Restricted tool should succeed without error") { + (restrictedResult.isError ?: false) shouldBe false + } + } + + @Test + fun testUnauthorizedAfterInitialization() = runBlocking { + val result = client.callTool(secretToolName, mapOf("text" to "test")) + + assertNotNull(result, "Call tool result should not be null") + withClue("Tool call should have isError=true for unauthorized access") { + result.isError shouldBe true + } + + val textContent = result.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Error content should be present") + withClue("Unauthorized operations should fail after successful initialization") { + textContent.text.lowercase().contains("access denied") shouldBe true + } + } +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/PromptSecurityIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/PromptSecurityIntegrationTest.kt new file mode 100644 index 00000000..22a9cde4 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/PromptSecurityIntegrationTest.kt @@ -0,0 +1,5 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.security + +class PromptSecurityIntegrationTest : AbstractPromptSecurityIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/ResourceSecurityIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/ResourceSecurityIntegrationTest.kt new file mode 100644 index 00000000..6e70507b --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/ResourceSecurityIntegrationTest.kt @@ -0,0 +1,5 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.security + +class ResourceSecurityIntegrationTest : AbstractResourceSecurityIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/ToolSecurityIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/ToolSecurityIntegrationTest.kt new file mode 100644 index 00000000..3a588251 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/security/ToolSecurityIntegrationTest.kt @@ -0,0 +1,5 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.security + +class ToolSecurityIntegrationTest : AbstractToolSecurityIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/PromptIntegrationTestStreamableHttp.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/PromptIntegrationTestStreamableHttp.kt new file mode 100644 index 00000000..8f4e3344 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/PromptIntegrationTestStreamableHttp.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.streamablehttp + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractPromptIntegrationTest + +class PromptIntegrationTestStreamableHttp : AbstractPromptIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STREAMABLE_HTTP +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/ResourceIntegrationTestStreamableHttp.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/ResourceIntegrationTestStreamableHttp.kt new file mode 100644 index 00000000..d72bcfdc --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/ResourceIntegrationTestStreamableHttp.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.streamablehttp + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractResourceIntegrationTest + +class ResourceIntegrationTestStreamableHttp : AbstractResourceIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STREAMABLE_HTTP +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/ToolIntegrationTestStreamableHttp.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/ToolIntegrationTestStreamableHttp.kt new file mode 100644 index 00000000..288ff075 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/ToolIntegrationTestStreamableHttp.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.streamablehttp + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest + +class ToolIntegrationTestStreamableHttp : AbstractToolIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STREAMABLE_HTTP +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/MockAuthorizationWrapper.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/MockAuthorizationWrapper.kt new file mode 100644 index 00000000..718ab71b --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/MockAuthorizationWrapper.kt @@ -0,0 +1,173 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import io.modelcontextprotocol.kotlin.sdk.types.McpException + +/** + * Test-only authorization wrapper for validating security test scenarios. + * + * According to the MCP specification, authorization is an application-level responsibility. + * This class demonstrates how applications can implement authorization logic by wrapping + * server request handlers with permission checks. + * + * This is NOT production code - it's a test utility to verify that servers can implement + * authorization at the application level and properly reject unauthorized requests. + * + * Example usage: + * ```kotlin + * val authWrapper = MockAuthorizationWrapper( + * AuthorizationRules( + * allowedPrompts = setOf("public-prompt"), + * deniedPrompts = setOf("secret-prompt") + * ) + * ) + * + * server.addPrompt(...) { request -> + * if (!authWrapper.isAllowed("prompts", "get", mapOf("name" to request.params.name))) { + * throw authWrapper.createDeniedError() + * } + * // Normal handler logic + * } + * ``` + */ +class MockAuthorizationWrapper(private val rules: AuthorizationRules) { + companion object { + /** Custom error code for authorization denied (server-defined range: -32000 to -32099) */ + const val ERROR_CODE_AUTHORIZATION_DENIED = -32002 + } + + /** + * Checks if a given operation on a feature is allowed based on the configured rules. + * + * @param feature The feature being accessed (e.g., "prompts", "resources", "tools", "logging") + * @param operation The operation being performed (e.g., "list", "get", "call", "read") + * @param params Additional parameters for the operation (e.g., prompt name, resource URI) + * @return true if the operation is allowed, false otherwise + */ + fun isAllowed(feature: String, operation: String, params: Map): Boolean = when (feature) { + "prompts" -> { + val promptName = params["name"] as? String + checkPromptAccess(promptName) + } + + "resources" -> { + val resourceUri = params["uri"] as? String + checkResourceAccess(resourceUri) + } + + "tools" -> { + val toolName = params["name"] as? String + checkToolAccess(toolName) + } + + "logging" -> { + // Logging access is controlled separately + checkLoggingAccess() + } + + else -> false // Unknown feature, deny by default + } + + /** + * Creates an MCP exception for authorization denied errors. + * This exception should be thrown when a request is rejected due to insufficient permissions. + * + * @param reason Optional reason for the denial (defaults to generic message) + * @return McpException with authorization denied error code + */ + fun createDeniedError(reason: String = "Access denied: insufficient permissions"): McpException = McpException( + code = ERROR_CODE_AUTHORIZATION_DENIED, + message = reason, + data = null, + ) + + private fun checkPromptAccess(promptName: String?): Boolean { + if (promptName == null) { + // Listing prompts - check if any prompts are allowed + return rules.allowedPrompts == null || rules.allowedPrompts.isNotEmpty() + } + + // Check denied list first (explicit deny takes precedence) + if (rules.deniedPrompts?.contains(promptName) == true) { + return false + } + + // Check allowed list (null means all allowed, empty means none allowed) + return when { + rules.allowedPrompts == null -> true + + // No restrictions + rules.allowedPrompts.isEmpty() -> false + + // Empty set means deny all + else -> rules.allowedPrompts.contains(promptName) + } + } + + private fun checkResourceAccess(resourceUri: String?): Boolean { + if (resourceUri == null) { + // Listing resources - check if any resources are allowed + return rules.allowedResources == null || rules.allowedResources.isNotEmpty() + } + + // Check denied list first + if (rules.deniedResources?.contains(resourceUri) == true) { + return false + } + + // Check allowed list + return when { + rules.allowedResources == null -> true + rules.allowedResources.isEmpty() -> false + else -> rules.allowedResources.contains(resourceUri) + } + } + + private fun checkToolAccess(toolName: String?): Boolean { + if (toolName == null) { + // Listing tools - check if any tools are allowed + return rules.allowedTools == null || rules.allowedTools.isNotEmpty() + } + + // Check denied list first + if (rules.deniedTools?.contains(toolName) == true) { + return false + } + + // Check allowed list + return when { + rules.allowedTools == null -> true + rules.allowedTools.isEmpty() -> false + else -> rules.allowedTools.contains(toolName) + } + } + + private fun checkLoggingAccess(): Boolean { + // Simple boolean check for logging + return rules.allowLogging + } +} + +/** + * Configuration for authorization rules in integration tests. + * + * For each feature (prompts, resources, tools), you can specify: + * - allowedXxx: Set of allowed items (null = all allowed, empty = none allowed) + * - deniedXxx: Set of explicitly denied items (takes precedence over allowed) + * + * @property allowedPrompts Set of allowed prompt names, or null to allow all + * @property deniedPrompts Set of explicitly denied prompt names + * @property allowedResources Set of allowed resource URIs, or null to allow all + * @property deniedResources Set of explicitly denied resource URIs + * @property allowedTools Set of allowed tool names, or null to allow all + * @property deniedTools Set of explicitly denied tool names + * @property allowLogging Whether logging operations are allowed + */ +data class AuthorizationRules( + val allowedPrompts: Set? = null, + val deniedPrompts: Set? = null, + val allowedResources: Set? = null, + val deniedResources: Set? = null, + val allowedTools: Set? = null, + val deniedTools: Set? = null, + val allowLogging: Boolean = true, +) diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/StreamableHttpTestUtils.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/StreamableHttpTestUtils.kt new file mode 100644 index 00000000..84c6eec3 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/StreamableHttpTestUtils.kt @@ -0,0 +1,127 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import io.ktor.http.Headers +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Utility functions for StreamableHTTP transport integration tests. + * + * These utilities help with: + * - Session ID validation and extraction + * - HTTP header manipulation + * - Test assertions for HTTP responses + * + * Note: Some utilities are placeholders for future advanced test scenarios + * that require more complex test infrastructure. + */ +@OptIn(ExperimentalUuidApi::class) +object StreamableHttpTestUtils { + + private const val MCP_SESSION_ID_HEADER = "mcp-session-id" + private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" + private const val MCP_RESUMPTION_TOKEN_HEADER = "last-event-id" + + /** + * Validates that a session ID is in the correct format (UUID). + * + * @param sessionId The session ID to validate + * @return true if the session ID is a valid UUID, false otherwise + */ + fun validateSessionId(sessionId: String): Boolean = try { + Uuid.parse(sessionId) + true + } catch (e: IllegalArgumentException) { + false + } + + /** + * Extracts the MCP session ID from HTTP headers. + * + * @param headers The HTTP headers to search + * @return The session ID if present, null otherwise + */ + fun extractSessionIdFromHeaders(headers: Headers): String? = headers[MCP_SESSION_ID_HEADER] + + /** + * Extracts the MCP protocol version from HTTP headers. + * + * @param headers The HTTP headers to search + * @return The protocol version if present, null otherwise + */ + fun extractProtocolVersionFromHeaders(headers: Headers): String? = headers[MCP_PROTOCOL_VERSION_HEADER] + + /** + * Extracts the Last-Event-ID (resumption token) from HTTP headers. + * + * @param headers The HTTP headers to search + * @return The last event ID if present, null otherwise + */ + fun extractLastEventIdFromHeaders(headers: Headers): String? = headers[MCP_RESUMPTION_TOKEN_HEADER] + + /** + * Generates a random session ID in UUID format. + * + * @return A new UUID string + */ + fun generateSessionId(): String = Uuid.random().toString() + + /** + * Validates that a message is within the StreamableHTTP size limit (4MB). + * + * @param messageSize The size of the message in bytes + * @return true if the message is within the limit, false otherwise + */ + fun isWithinMessageSizeLimit(messageSize: Int): Boolean = messageSize <= MAXIMUM_MESSAGE_SIZE + + /** + * Gets the maximum message size allowed by StreamableHTTP transport. + * + * @return The maximum message size in bytes (4MB) + */ + fun getMaximumMessageSize(): Int = MAXIMUM_MESSAGE_SIZE + + private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB + + // TODO: Future utilities for advanced testing scenarios + + // TODO: Creates a custom HTTP client with specific headers for testing. + // + // This would be useful for testing: + // - Invalid session IDs + // - Custom Host/Origin headers for DNS rebinding tests + // - Missing required headers + // + // Implementation requires creating a Ktor HttpClient with custom configuration. + // fun createCustomHttpClient(customHeaders: Map): HttpClient + + // TODO: Sends a direct HTTP request to the server for low-level testing. + // + // This would be useful for testing: + // - Malformed requests + // - Invalid JSON + // - HTTP-level errors (404, 403, 400, etc.) + // + // Implementation requires direct Ktor client usage outside of MCP client abstraction. + // suspend fun sendDirectHttpRequest(url: String, method: HttpMethod, headers: Map, body: String?): HttpResponse + + // TODO: Creates a StreamableHTTP client with a pre-set session ID. + // + // This would be useful for testing: + // - Session resumption + // - Invalid session ID handling + // - Session expiration scenarios + // + // Implementation requires ability to inject session ID into client transport. + // fun createClientWithSessionId(url: String, sessionId: String): Client + + // TODO: Asserts that a request was rejected with a specific HTTP status code. + // + // This would be useful for testing: + // - Authorization failures (403) + // - Invalid requests (400) + // - Not found errors (404) + // + // Implementation requires catching and analyzing transport-level errors. + // suspend fun assertRejectedWithStatus(expectedStatus: HttpStatusCode, block: suspend () -> Unit) +} diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt index af2110a0..e5eeba03 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockMcp.kt @@ -19,7 +19,7 @@ import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.putJsonObject -const val MCP_SESSION_ID_HEADER = "Mcp-Session-Id" +const val MCP_SESSION_ID_HEADER = "mcp-session-id" internal class MockMcp(verbose: Boolean = false) {