Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,20 @@ type AWSBedrock struct {
BaseURL string
}

// OpenAI contains provider-level configuration for the OpenAI provider.
type OpenAI struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
SendActorHeaders bool
}

// OpenAIInterceptor contains configuration for interceptors that speak the
// OpenAI wire format. Used by any provider that uses OpenAI-compatible APIs.
type OpenAIInterceptor struct {
Key string
SendActorHeaders bool
ExtraHeaders map[string]string
}

Expand Down
8 changes: 5 additions & 3 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ import (
type interceptionBase struct {
id uuid.UUID
providerName string
baseURL string
apiDumpDir string
req *ChatCompletionNewParamsWrapper
cfg config.OpenAI
cfg config.OpenAIInterceptor

// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
Expand All @@ -43,7 +45,7 @@ type interceptionBase struct {
}

func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)}
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.baseURL)}

// Add extra headers if configured.
// Some providers require additional headers that are not added by the SDK.
Expand All @@ -63,7 +65,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService
}

// Add API dump middleware if configured
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
if mw := apidump.NewBridgeMiddleware(i.apiDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
opts = append(opts, option.WithMiddleware(mw))
}

Expand Down
6 changes: 5 additions & 1 deletion intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,18 @@ func NewBlockingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
cfg config.OpenAI,
baseURL string,
apiDumpDir string,
cfg config.OpenAIInterceptor,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
baseURL: baseURL,
apiDumpDir: apiDumpDir,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
6 changes: 5 additions & 1 deletion intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@ func NewStreamingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
cfg config.OpenAI,
baseURL string,
apiDumpDir string,
cfg config.OpenAIInterceptor,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
baseURL: baseURL,
apiDumpDir: apiDumpDir,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
7 changes: 3 additions & 4 deletions intercept/chatcompletions/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
t.Cleanup(mockServer.Close)

// Create interceptor with mock server URL
cfg := config.OpenAI{
BaseURL: mockServer.URL,
Key: "test-key",
cfg := config.OpenAIInterceptor{
Key: "test-key",
}

req := &ChatCompletionNewParamsWrapper{
Expand All @@ -86,7 +85,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer)
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, mockServer.URL, "", cfg, httpReq.Header, "Authorization", tracer)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
Expand Down
8 changes: 5 additions & 3 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ const (
type responsesInterceptionBase struct {
id uuid.UUID
providerName string
baseURL string
apiDumpDir string
// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
authHeaderName string
reqPayload ResponsesRequestPayload

cfg config.OpenAI
cfg config.OpenAIInterceptor
recorder recorder.Recorder
mcpProxy mcp.ServerProxier

Expand All @@ -52,7 +54,7 @@ type responsesInterceptionBase struct {
}

func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)}
opts := []option.RequestOption{option.WithBaseURL(i.baseURL), option.WithAPIKey(i.cfg.Key)}

// Add extra headers if configured.
// Some providers require additional headers that are not added by the SDK.
Expand All @@ -72,7 +74,7 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ
}

// Add API dump middleware if configured
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
if mw := apidump.NewBridgeMiddleware(i.apiDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
opts = append(opts, option.WithMiddleware(mw))
}

Expand Down
6 changes: 5 additions & 1 deletion intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ func NewBlockingInterceptor(
id uuid.UUID,
reqPayload ResponsesRequestPayload,
providerName string,
cfg config.OpenAI,
baseURL string,
apiDumpDir string,
cfg config.OpenAIInterceptor,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
Expand All @@ -38,6 +40,8 @@ func NewBlockingInterceptor(
responsesInterceptionBase: responsesInterceptionBase{
id: id,
providerName: providerName,
baseURL: baseURL,
apiDumpDir: apiDumpDir,
reqPayload: reqPayload,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
6 changes: 5 additions & 1 deletion intercept/responses/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ func NewStreamingInterceptor(
id uuid.UUID,
reqPayload ResponsesRequestPayload,
providerName string,
cfg config.OpenAI,
baseURL string,
apiDumpDir string,
cfg config.OpenAIInterceptor,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
Expand All @@ -45,6 +47,8 @@ func NewStreamingInterceptor(
responsesInterceptionBase: responsesInterceptionBase{
id: id,
providerName: providerName,
baseURL: baseURL,
apiDumpDir: apiDumpDir,
reqPayload: reqPayload,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
22 changes: 9 additions & 13 deletions provider/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,11 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac

id := uuid.New()

// Build config for the interceptor using the per-request key.
// Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors
// that require a config.OpenAI.
cfg := config.OpenAI{
BaseURL: p.cfg.BaseURL,
Key: key,
APIDumpDir: p.cfg.APIDumpDir,
CircuitBreaker: p.cfg.CircuitBreaker,
ExtraHeaders: extractCopilotHeaders(r),
// Build interceptor config using the per-request key.
// Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors.
interceptorCfg := config.OpenAIInterceptor{
Key: key,
ExtraHeaders: extractCopilotHeaders(r),
}

var interceptor intercept.Interceptor
Expand All @@ -149,9 +145,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if req.Stream {
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
}

case routeCopilotResponses:
Expand All @@ -165,9 +161,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if reqPayload.Stream() {
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
}

default:
Expand Down
17 changes: 11 additions & 6 deletions provider/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace

var interceptor intercept.Interceptor

cfg := p.cfg
// At this point the request contains only LLM provider headers. Any
// Coder-specific authentication has already been stripped.
//
Expand All @@ -106,8 +105,14 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
//
// In BYOK mode the user's credential is in Authorization. Replace
// the centralized key with it so it is forwarded upstream.
key := p.cfg.Key
if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
cfg.Key = token
key = token
}

interceptorCfg := config.OpenAIInterceptor{
Key: key,
SendActorHeaders: p.cfg.SendActorHeaders,
}

path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix())
Expand All @@ -119,9 +124,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
}

if req.Stream {
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
}

case routeResponses:
Expand All @@ -134,9 +139,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
return nil, fmt.Errorf("unmarshal request body: %w", err)
}
if reqPayload.Stream() {
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer)
}

default:
Expand Down
Loading