diff --git a/internal/broker/broker.go b/internal/broker/broker.go index a4f38528..44637f94 100644 --- a/internal/broker/broker.go +++ b/internal/broker/broker.go @@ -214,6 +214,10 @@ func (m *mcpBrokerImpl) GetVirtualSeverByHeader(namespaceName string) (config.Vi } func (m *mcpBrokerImpl) ToolAnnotations(serverID config.UpstreamMCPID, tool string) (mcp.ToolAnnotation, bool) { + // Avoid race with OnConfigChange() + m.mcpLock.RLock() + defer m.mcpLock.RUnlock() + upstream, ok := m.mcpServers[serverID] if !ok { return mcp.ToolAnnotation{}, false @@ -227,6 +231,10 @@ func (m *mcpBrokerImpl) ToolAnnotations(serverID config.UpstreamMCPID, tool stri // GetServerInfo implements MCPBroker by providing a lookup of the server that implements a tool. func (m *mcpBrokerImpl) GetServerInfo(tool string) (*config.MCPServer, error) { + // Avoid race with OnConfigChange() + m.mcpLock.RLock() + defer m.mcpLock.RUnlock() + for _, upstream := range m.mcpServers { t := upstream.GetServedManagedTool(tool) if t != nil { @@ -243,6 +251,10 @@ func (m *mcpBrokerImpl) GetServerInfo(tool string) (*config.MCPServer, error) { } func (m *mcpBrokerImpl) Shutdown(_ context.Context) error { + // Avoid race with OnConfigChange() + m.mcpLock.RLock() + defer m.mcpLock.RUnlock() + // Close the long-running notification channel for _, mcpServer := range m.mcpServers { if mcpServer != nil { @@ -265,6 +277,10 @@ func (m *mcpBrokerImpl) HandleStatusRequest(w http.ResponseWriter, r *http.Reque // ValidateAllServers performs comprehensive validation of all registered servers and returns status func (m *mcpBrokerImpl) ValidateAllServers() StatusResponse { + // The race is with len(m.mcpServers), which is not thread-safe in Go + m.mcpLock.RLock() + defer m.mcpLock.RUnlock() + response := StatusResponse{ Servers: make([]upstream.ServerValidationStatus, 0), OverallValid: true, diff --git a/internal/config/mcpservers.go b/internal/config/mcpservers.go index b4797aec..e674c39e 100644 --- a/internal/config/mcpservers.go +++ b/internal/config/mcpservers.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net/url" + "sync" ) // UpstreamMCPID is used as type for identifying individual upstreams @@ -12,6 +13,8 @@ type UpstreamMCPID string // MCPServersConfig holds server configuration type MCPServersConfig struct { + lock sync.RWMutex + Servers []*MCPServer VirtualServers []*VirtualServer observers []Observer @@ -23,24 +26,33 @@ type MCPServersConfig struct { // RegisterObserver registers an observer to be notified of changes to the config func (config *MCPServersConfig) RegisterObserver(obs Observer) { + config.lock.Lock() + defer config.lock.Unlock() + config.observers = append(config.observers, obs) } // Notify notifies registered observers of config changes func (config *MCPServersConfig) Notify(ctx context.Context) { + config.lock.RLock() + defer config.lock.RUnlock() + for _, observer := range config.observers { go observer.OnConfigChange(ctx, config) } } // GetServerConfigByName get the routing config by server name -func (config *MCPServersConfig) GetServerConfigByName(serverName string) *MCPServer { +func (config *MCPServersConfig) GetServerConfigByName(serverName string) (*MCPServer, error) { + config.lock.RLock() + defer config.lock.RUnlock() + for _, server := range config.Servers { if server.Name == serverName { - return server + return server, nil } } - return nil + return nil, fmt.Errorf("unknown server") } // MCPServer represents a server diff --git a/internal/mcp-router/request_handlers.go b/internal/mcp-router/request_handlers.go index 7b04756a..6794653f 100644 --- a/internal/mcp-router/request_handlers.go +++ b/internal/mcp-router/request_handlers.go @@ -298,7 +298,10 @@ data: {"result":{"content":[{"type":"text","text":"MCP error -32602: Tool not fo // This connection is kept open for the life of the gateway session. // TODO when we receive a 404 from a backend MCP Server we should have a way to close the connection at that point also currently when we receive a 404 we remove the session from cache and will open a new connection. They will all be closed once the gateway session expires or the client sends a delete but it is a source of potential leaks func (s *ExtProcServer) initializeMCPSeverSession(ctx context.Context, mcpReq *MCPRequest) (string, error) { - mcpServerConfig := s.RoutingConfig.GetServerConfigByName(mcpReq.serverName) + mcpServerConfig, err := s.RoutingConfig.GetServerConfigByName(mcpReq.serverName) + if err != nil { + return "", NewRouterErrorf(500, "failed check for server: %w", err) + } exists, err := s.SessionCache.GetSession(ctx, mcpReq.GetSessionID()) if err != nil { return "", NewRouterErrorf(500, "failed to check for existing session: %w", err)