-
Notifications
You must be signed in to change notification settings - Fork 393
Feat/expose kagent agents in mcp #1201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f80bce2
9daf77d
cc34412
69aba67
83f4f84
c6730e3
88bc88c
5c33551
5ea346f
048f3d3
700e244
190ae09
aeaa6e9
c5b4c9d
3fcef83
42c7809
ccd0f5d
5257db2
6e46d37
e820e70
cf93df5
a9928f0
e8f2d72
be7aa78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| package mcp | ||
|
|
||
| import ( | ||
| "context" | ||
| "fmt" | ||
| "net/http" | ||
| "os" | ||
| "strings" | ||
| "sync" | ||
| "sync/atomic" | ||
| "time" | ||
|
|
||
| "github.com/kagent-dev/kagent/go/cli/internal/config" | ||
| "github.com/kagent-dev/kagent/go/internal/a2a" | ||
| "github.com/kagent-dev/kagent/go/internal/version" | ||
| "github.com/mark3labs/mcp-go/mcp" | ||
| mcpserver "github.com/mark3labs/mcp-go/server" | ||
| "github.com/spf13/cobra" | ||
| a2aclient "trpc.group/trpc-go/trpc-a2a-go/client" | ||
| "trpc.group/trpc-go/trpc-a2a-go/protocol" | ||
| ) | ||
|
|
||
| var ( | ||
| serveAgentsTransport string | ||
| serveAgentsHost string | ||
| serveAgentsPort int | ||
| ) | ||
|
|
||
| var a2aContextBySessionAndAgent sync.Map | ||
|
|
||
| var fallbackInvocationCounter uint64 | ||
|
|
||
| var ServeAgentsCmd = &cobra.Command{ | ||
| Use: "serve-mcp", | ||
| Short: "Serve kagent agents via MCP", | ||
| RunE: func(cmd *cobra.Command, args []string) error { | ||
| cfg, err := config.Get() | ||
| if err != nil { | ||
| return fmt.Errorf("config: %w", err) | ||
| } | ||
| hooks := &mcpserver.Hooks{} | ||
| hooks.AddOnUnregisterSession(func(ctx context.Context, session mcpserver.ClientSession) { | ||
| sessionID := session.SessionID() | ||
| a2aContextBySessionAndAgent.Range(func(key, _ any) bool { | ||
| keyStr, ok := key.(string) | ||
| if !ok { | ||
| return true | ||
| } | ||
| if strings.HasPrefix(keyStr, sessionID+"|") { | ||
| a2aContextBySessionAndAgent.Delete(key) | ||
| } | ||
| return true | ||
| }) | ||
| }) | ||
| s := mcpserver.NewMCPServer( | ||
| "kagent-agents", | ||
| version.Version, | ||
| mcpserver.WithToolCapabilities(false), | ||
| mcpserver.WithHooks(hooks), | ||
| ) | ||
|
|
||
| s.AddTool(mcp.NewTool("list_agents", | ||
| mcp.WithDescription("List invokable kagent agents (accepted + deploymentReady)"), | ||
| ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { | ||
| resp, err := cfg.Client().Agent.ListAgents(ctx) | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("list agents", err), nil | ||
| } | ||
| type agentSummary struct { | ||
| Ref string `json:"ref"` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Id doesn't seem to be useful, maybe just ref (name) and description? |
||
| Description string `json:"description,omitempty"` | ||
| } | ||
| agents := make([]agentSummary, 0) | ||
| for _, agent := range resp.Data { | ||
| if !agent.Accepted || !agent.DeploymentReady || agent.Agent == nil { | ||
| continue | ||
| } | ||
| ref := agent.Agent.Namespace + "/" + agent.Agent.Name | ||
| agents = append(agents, agentSummary{Ref: ref, Description: agent.Agent.Spec.Description}) | ||
| } | ||
| if len(agents) == 0 { | ||
| return mcp.NewToolResultStructured(agents, "No invokable agents found."), nil | ||
| } | ||
|
|
||
| var fallbackText strings.Builder | ||
| for i, agent := range agents { | ||
| if i > 0 { | ||
| fallbackText.WriteByte('\n') | ||
| } | ||
| fallbackText.WriteString(agent.Ref) | ||
| if agent.Description != "" { | ||
| fallbackText.WriteString(" - ") | ||
| fallbackText.WriteString(agent.Description) | ||
| } | ||
| } | ||
|
|
||
| return mcp.NewToolResultStructured(agents, fallbackText.String()), nil | ||
| }) | ||
|
|
||
| s.AddTool(mcp.NewTool("invoke_agent", | ||
| mcp.WithDescription("Invoke a kagent agent via A2A"), | ||
| mcp.WithString("agent", mcp.Description("Agent name (or namespace/name)"), mcp.Required()), | ||
| mcp.WithString("task", mcp.Description("Task to run"), mcp.Required()), | ||
| ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { | ||
| agentRef, err := request.RequireString("agent") | ||
| if err != nil { | ||
| return mcp.NewToolResultError(err.Error()), nil | ||
| } | ||
| task, err := request.RequireString("task") | ||
| if err != nil { | ||
| return mcp.NewToolResultError(err.Error()), nil | ||
| } | ||
| agentNS, agentName, ok := strings.Cut(agentRef, "/") | ||
| if !ok { | ||
| agentNS, agentName = cfg.Namespace, agentRef | ||
| } | ||
| agentRef = agentNS + "/" + agentName | ||
|
|
||
| sessionID := "" | ||
| if session := mcpserver.ClientSessionFromContext(ctx); session != nil { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of falling back to unknown, use a unique invocation ID per session if none is available to keep the context separate. When callers with proper session support use this they will get unknown as session and it will cause unexpected behaviour with multiple concurrent users like potentially wrong context history. |
||
| sessionID = session.SessionID() | ||
| } else if headerSessionID := request.Header.Get(mcpserver.HeaderKeySessionID); headerSessionID != "" { | ||
| sessionID = headerSessionID | ||
| } | ||
| if sessionID == "" { | ||
| sessionID = fmt.Sprintf("invocation-%d", atomic.AddUint64(&fallbackInvocationCounter, 1)) | ||
| } | ||
| contextKey := sessionID + "|" + agentRef | ||
| var contextIDPtr *string | ||
| if prior, ok := a2aContextBySessionAndAgent.Load(contextKey); ok { | ||
| if priorStr, ok := prior.(string); ok && priorStr != "" { | ||
| contextIDPtr = &priorStr | ||
| } | ||
| } | ||
|
|
||
| a2aURL := fmt.Sprintf("%s/api/a2a/%s/%s", cfg.KAgentURL, agentNS, agentName) | ||
| client, err := a2aclient.NewA2AClient(a2aURL, a2aclient.WithTimeout(cfg.Timeout)) | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("a2a client", err), nil | ||
| } | ||
| result, err := client.SendMessage(ctx, protocol.SendMessageParams{Message: protocol.Message{ | ||
| Kind: protocol.KindMessage, Role: protocol.MessageRoleUser, ContextID: contextIDPtr, Parts: []protocol.Part{protocol.NewTextPart(task)}, | ||
| }}) | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("a2a send", err), nil | ||
| } | ||
|
|
||
| var responseText, newContextID string | ||
| switch a2aResult := result.Result.(type) { | ||
| case *protocol.Message: | ||
| responseText = a2a.ExtractText(*a2aResult) | ||
| if a2aResult.ContextID != nil { | ||
| newContextID = *a2aResult.ContextID | ||
| } | ||
| case *protocol.Task: | ||
| newContextID = a2aResult.ContextID | ||
| if a2aResult.Status.Message != nil { | ||
| responseText = a2a.ExtractText(*a2aResult.Status.Message) | ||
| } | ||
| for _, artifact := range a2aResult.Artifacts { | ||
| responseText += a2a.ExtractText(protocol.Message{Parts: artifact.Parts}) | ||
| } | ||
| } | ||
| if responseText == "" { | ||
| raw, err := result.MarshalJSON() | ||
| if err != nil { | ||
| return mcp.NewToolResultErrorFromErr("marshal result", err), nil | ||
| } | ||
| responseText = string(raw) | ||
| } | ||
| if newContextID != "" { | ||
| a2aContextBySessionAndAgent.Store(contextKey, newContextID) | ||
| } | ||
| return mcp.NewToolResultStructured(map[string]any{ | ||
| "agent": agentRef, | ||
| "context_id": newContextID, | ||
| "text": responseText, | ||
| }, responseText), nil | ||
| }) | ||
|
|
||
| switch strings.ToLower(serveAgentsTransport) { | ||
| case "stdio": | ||
| stdioServer := mcpserver.NewStdioServer(s) | ||
| return stdioServer.Listen(cmd.Context(), os.Stdin, os.Stdout) | ||
| case "http": | ||
| addr := fmt.Sprintf("%s:%d", serveAgentsHost, serveAgentsPort) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps some logging to indicate the server is running successfully like "MCP server listening on xxx" |
||
| cmd.PrintErrf("MCP server listening on http://%s/mcp\n", addr) | ||
| httpServer := mcpserver.NewStreamableHTTPServer(s) | ||
| go func() { | ||
| <-cmd.Context().Done() | ||
| shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
| defer cancel() | ||
| _ = httpServer.Shutdown(shutdownCtx) | ||
| }() | ||
| if err := httpServer.Start(addr); err != nil && err != http.ErrServerClosed { | ||
| return err | ||
| } | ||
| return nil | ||
| default: | ||
| return fmt.Errorf("invalid transport %q (expected stdio or http)", serveAgentsTransport) | ||
| } | ||
| }, | ||
| } | ||
|
|
||
| func init() { | ||
| ServeAgentsCmd.Flags().StringVar(&serveAgentsTransport, "transport", "stdio", "Transport mode (stdio or http)") | ||
| ServeAgentsCmd.Flags().StringVar(&serveAgentsHost, "host", "127.0.0.1", "HTTP host to bind (when --transport http)") | ||
| ServeAgentsCmd.Flags().IntVar(&serveAgentsPort, "port", 3000, "HTTP port to bind (when --transport http)") | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| package e2e_test | ||
|
|
||
| import ( | ||
| "bufio" | ||
| "bytes" | ||
| "context" | ||
| "encoding/json" | ||
| "fmt" | ||
| "os" | ||
| "os/exec" | ||
| "path/filepath" | ||
| "runtime" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestE2EInvokeAgentThroughMCPServeAgents(t *testing.T) { | ||
| // Setup mock server (so agent responses are deterministic and don't hit real LLMs) | ||
| baseURL, stopServer := setupMockServer(t, "mocks/invoke_mcp_serve_agents.json") | ||
| defer stopServer() | ||
|
|
||
| // Setup Kubernetes resources for a known-good agent | ||
| cli := setupK8sClient(t, false) | ||
| modelCfg := setupModelConfig(t, cli, baseURL) | ||
| agent := setupAgentWithOptions(t, cli, modelCfg.Name, nil, AgentOptions{ | ||
| Name: "kebab-agent", | ||
| }) | ||
|
|
||
| kagentURL := os.Getenv("KAGENT_URL") | ||
| if kagentURL == "" { | ||
| kagentURL = "http://localhost:8083" | ||
| } | ||
|
|
||
| _, testFile, _, ok := runtime.Caller(0) | ||
| require.True(t, ok) | ||
| goModuleRoot := filepath.Clean(filepath.Join(filepath.Dir(testFile), "../..")) | ||
|
|
||
| kagentBin := filepath.Join(t.TempDir(), "kagent") | ||
| build := exec.Command("go", "build", "-o", kagentBin, "./cli/cmd/kagent") | ||
| build.Dir = goModuleRoot | ||
| buildOutput, err := build.CombinedOutput() | ||
| require.NoError(t, err, string(buildOutput)) | ||
|
|
||
| homeDir := t.TempDir() | ||
| cfgDir := filepath.Join(homeDir, ".kagent") | ||
| require.NoError(t, os.MkdirAll(cfgDir, 0755)) | ||
| cfgPath := filepath.Join(cfgDir, "config.yaml") | ||
| require.NoError(t, os.WriteFile(cfgPath, []byte(fmt.Sprintf("kagent_url: %s\nnamespace: kagent\ntimeout: 300s\n", kagentURL)), 0644)) | ||
|
|
||
| ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) | ||
| defer cancel() | ||
|
|
||
| cmd := exec.CommandContext(ctx, kagentBin, "mcp", "serve-agents") | ||
| cmd.Env = append(os.Environ(), "HOME="+homeDir) | ||
| stdout, err := cmd.StdoutPipe() | ||
| require.NoError(t, err) | ||
| stdin, err := cmd.StdinPipe() | ||
| require.NoError(t, err) | ||
| var stderr bytes.Buffer | ||
| cmd.Stderr = &stderr | ||
| require.NoError(t, cmd.Start()) | ||
| t.Cleanup(func() { | ||
| _ = stdin.Close() | ||
| _ = cmd.Process.Kill() | ||
| _ = cmd.Wait() | ||
| }) | ||
|
|
||
| lines := make(chan string, 32) | ||
| go func() { | ||
| scanner := bufio.NewScanner(stdout) | ||
| for scanner.Scan() { | ||
| lines <- scanner.Text() | ||
| } | ||
| close(lines) | ||
| }() | ||
|
|
||
| writeLine := func(line string) { | ||
| _, _ = fmt.Fprintln(stdin, line) | ||
| } | ||
|
|
||
| readResponse := func(wantID int) json.RawMessage { | ||
| deadline := time.NewTimer(15 * time.Second) | ||
| defer deadline.Stop() | ||
| for { | ||
| select { | ||
| case line, ok := <-lines: | ||
| require.True(t, ok, stderr.String()) | ||
| var msg struct { | ||
| ID int `json:"id"` | ||
| Result json.RawMessage `json:"result,omitempty"` | ||
| Error json.RawMessage `json:"error,omitempty"` | ||
| } | ||
| require.NoError(t, json.Unmarshal([]byte(line), &msg), line) | ||
| if msg.ID != wantID { | ||
| continue | ||
| } | ||
| require.Nil(t, msg.Error, line) | ||
| return msg.Result | ||
| case <-deadline.C: | ||
| t.Fatalf("timed out waiting for id=%d; stderr=%s", wantID, stderr.String()) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| writeLine(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"e2e","version":"0.0.0"}}}`) | ||
| _ = readResponse(1) | ||
| writeLine(`{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}`) | ||
|
|
||
| writeLine(`{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`) | ||
| toolsList := readResponse(2) | ||
| var listResult struct { | ||
| Tools []struct { | ||
| Name string `json:"name"` | ||
| } `json:"tools"` | ||
| } | ||
| require.NoError(t, json.Unmarshal(toolsList, &listResult), string(toolsList)) | ||
| require.GreaterOrEqual(t, len(listResult.Tools), 2) | ||
| toolNames := make([]string, 0, len(listResult.Tools)) | ||
| for _, tool := range listResult.Tools { | ||
| toolNames = append(toolNames, tool.Name) | ||
| } | ||
| require.Contains(t, toolNames, "list_agents") | ||
| require.Contains(t, toolNames, "invoke_agent") | ||
|
|
||
| writeLine(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"list_agents"}}`) | ||
| agentsResult := readResponse(3) | ||
| var callResult struct { | ||
| Content []struct { | ||
| Type string `json:"type"` | ||
| Text string `json:"text"` | ||
| } `json:"content"` | ||
| } | ||
| require.NoError(t, json.Unmarshal(agentsResult, &callResult), string(agentsResult)) | ||
| require.NotEmpty(t, callResult.Content) | ||
| require.Contains(t, callResult.Content[0].Text, agent.Namespace+"/"+agent.Name) | ||
|
|
||
| writeLine(fmt.Sprintf(`{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"invoke_agent","arguments":{"agent":%q,"task":"What can you do?"}}}`, agent.Name)) | ||
| invokeResult := readResponse(4) | ||
| require.NoError(t, json.Unmarshal(invokeResult, &callResult), string(invokeResult)) | ||
| require.NotEmpty(t, callResult.Content) | ||
| require.Contains(t, callResult.Content[0].Text, "kebab") | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This map never cleans up old session contexts. This might be an issue for HTTP server
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I quickly looked at the docs for mcp-go and seems like this hook will help:
hope it helps