Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 42 additions & 40 deletions internal/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -32,9 +33,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Health check endpoint
if rawPath == "/health" || rawPath == "/healthz" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
writeJSONResponse(w, http.StatusOK, map[string]string{"status": "ok"})
return
}

Expand Down Expand Up @@ -66,9 +65,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if match == nil {
// Unknown GraphQL query — fail closed: deny rather than risk leaking unfiltered data
logHandler.Printf("unknown GraphQL query, blocking request: %s", truncateForLog(string(graphQLBody), 500))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]interface{}{
writeJSONResponse(w, http.StatusForbidden, map[string]interface{}{
"errors": []map[string]string{{"message": "access denied: unrecognized GraphQL operation"}},
"data": nil,
})
Expand All @@ -78,13 +75,10 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if match.ToolName == "graphql_introspection" {
logHandler.Printf("GraphQL introspection query, passing through")
clientAuth := r.Header.Get("Authorization")
resp, err := h.server.forwardToGitHub(r.Context(), http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth)
if err != nil {
http.Error(w, "upstream request failed", http.StatusBadGateway)
resp, respBody := h.forwardAndReadBody(r.Context(), w, http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth)
if resp == nil {
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
h.writeResponse(w, resp, respBody)
return
}
Expand Down Expand Up @@ -146,9 +140,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
} else {
// Write blocked
logHandler.Printf("[DIFC] Phase 2: BLOCKED %s %s — %s", r.Method, path, evalResult.Reason)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{
writeJSONResponse(w, http.StatusForbidden, map[string]string{
"message": fmt.Sprintf("DIFC policy violation: %s", evalResult.Reason),
})
return
Expand All @@ -158,22 +150,13 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
// **Phase 3: Forward to upstream GitHub API**
clientAuth := r.Header.Get("Authorization")
var resp *http.Response
var respBody []byte
if graphQLBody != nil {
resp, err = s.forwardToGitHub(ctx, http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth)
resp, respBody = h.forwardAndReadBody(ctx, w, http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth)
} else {
resp, err = s.forwardToGitHub(ctx, r.Method, path, nil, "", clientAuth)
}
if err != nil {
logHandler.Printf("[DIFC] Phase 3 failed: %v", err)
http.Error(w, "upstream request failed", http.StatusBadGateway)
return
resp, respBody = h.forwardAndReadBody(ctx, w, r.Method, path, nil, "", clientAuth)
}
defer resp.Body.Close()

// Read the response body
respBody, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "failed to read upstream response", http.StatusBadGateway)
if resp == nil {
return
}

Expand Down Expand Up @@ -226,9 +209,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa
// Strict mode: block entire response if any item filtered
if s.enforcementMode == difc.EnforcementStrict && filtered.GetFilteredCount() > 0 {
logHandler.Printf("[DIFC] STRICT: blocking response — %d filtered items", filtered.GetFilteredCount())
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{
writeJSONResponse(w, http.StatusForbidden, map[string]string{
"message": fmt.Sprintf("DIFC policy violation: %d of %d items not accessible",
filtered.GetFilteredCount(), filtered.TotalCount),
})
Expand Down Expand Up @@ -312,16 +293,8 @@ func (h *proxyHandler) passthrough(w http.ResponseWriter, r *http.Request, path
defer r.Body.Close()
}

resp, err := h.server.forwardToGitHub(r.Context(), r.Method, path, body, r.Header.Get("Content-Type"), r.Header.Get("Authorization"))
if err != nil {
http.Error(w, "upstream request failed", http.StatusBadGateway)
return
}
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, "failed to read upstream response", http.StatusBadGateway)
resp, respBody := h.forwardAndReadBody(r.Context(), w, r.Method, path, body, r.Header.Get("Content-Type"), r.Header.Get("Authorization"))
if resp == nil {
return
}

Expand Down Expand Up @@ -376,6 +349,35 @@ func copyResponseHeaders(w http.ResponseWriter, resp *http.Response) {
}
}

// forwardAndReadBody forwards a request to the upstream GitHub API and reads the full
// response body. On success it returns the response (body already drained and closed)
// and the body bytes. On error it writes an appropriate HTTP error to w and returns
// (nil, nil); callers should return immediately on a nil response.
func (h *proxyHandler) forwardAndReadBody(ctx context.Context, w http.ResponseWriter, method, path string, body io.Reader, contentType, clientAuth string) (*http.Response, []byte) {
resp, err := h.server.forwardToGitHub(ctx, method, path, body, contentType, clientAuth)
if err != nil {
logHandler.Printf("upstream request failed: %s %s: %v", method, path, err)
http.Error(w, "upstream request failed", http.StatusBadGateway)
return nil, nil
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
logHandler.Printf("failed to read upstream response: %s %s: %v", method, path, err)
http.Error(w, "failed to read upstream response", http.StatusBadGateway)
return nil, nil
}
return resp, respBody
}

// writeJSONResponse sets the Content-Type header, writes the status code, and encodes
// body as JSON. It centralises the three-line pattern used across HTTP handlers.
func writeJSONResponse(w http.ResponseWriter, statusCode int, body interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(body) //nolint:errcheck
}

func truncateForLog(s string, maxLen int) string {
if len(s) <= maxLen {
return s
Expand Down
Loading