diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index e1b9fd78..2d595d88 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -2,6 +2,7 @@ package proxy import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -32,9 +33,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Health check endpoint if rawPath == "/health" || rawPath == "/healthz" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + writeJSONResponse(w, http.StatusOK, map[string]string{"status": "ok"}) return } @@ -66,9 +65,7 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if match == nil { // Unknown GraphQL query — fail closed: deny rather than risk leaking unfiltered data logHandler.Printf("unknown GraphQL query, blocking request: %s", truncateForLog(string(graphQLBody), 500)) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - json.NewEncoder(w).Encode(map[string]interface{}{ + writeJSONResponse(w, http.StatusForbidden, map[string]interface{}{ "errors": []map[string]string{{"message": "access denied: unrecognized GraphQL operation"}}, "data": nil, }) @@ -78,13 +75,10 @@ func (h *proxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if match.ToolName == "graphql_introspection" { logHandler.Printf("GraphQL introspection query, passing through") clientAuth := r.Header.Get("Authorization") - resp, err := h.server.forwardToGitHub(r.Context(), http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth) - if err != nil { - http.Error(w, "upstream request failed", http.StatusBadGateway) + resp, respBody := h.forwardAndReadBody(r.Context(), w, http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth) + if resp == nil { return } - defer resp.Body.Close() - respBody, _ := io.ReadAll(resp.Body) h.writeResponse(w, resp, respBody) return } @@ -146,9 +140,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa } else { // Write blocked logHandler.Printf("[DIFC] Phase 2: BLOCKED %s %s — %s", r.Method, path, evalResult.Reason) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - json.NewEncoder(w).Encode(map[string]string{ + writeJSONResponse(w, http.StatusForbidden, map[string]string{ "message": fmt.Sprintf("DIFC policy violation: %s", evalResult.Reason), }) return @@ -158,22 +150,13 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa // **Phase 3: Forward to upstream GitHub API** clientAuth := r.Header.Get("Authorization") var resp *http.Response + var respBody []byte if graphQLBody != nil { - resp, err = s.forwardToGitHub(ctx, http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth) + resp, respBody = h.forwardAndReadBody(ctx, w, http.MethodPost, "/graphql", bytes.NewReader(graphQLBody), "application/json", clientAuth) } else { - resp, err = s.forwardToGitHub(ctx, r.Method, path, nil, "", clientAuth) - } - if err != nil { - logHandler.Printf("[DIFC] Phase 3 failed: %v", err) - http.Error(w, "upstream request failed", http.StatusBadGateway) - return + resp, respBody = h.forwardAndReadBody(ctx, w, r.Method, path, nil, "", clientAuth) } - defer resp.Body.Close() - - // Read the response body - respBody, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, "failed to read upstream response", http.StatusBadGateway) + if resp == nil { return } @@ -226,9 +209,7 @@ func (h *proxyHandler) handleWithDIFC(w http.ResponseWriter, r *http.Request, pa // Strict mode: block entire response if any item filtered if s.enforcementMode == difc.EnforcementStrict && filtered.GetFilteredCount() > 0 { logHandler.Printf("[DIFC] STRICT: blocking response — %d filtered items", filtered.GetFilteredCount()) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - json.NewEncoder(w).Encode(map[string]string{ + writeJSONResponse(w, http.StatusForbidden, map[string]string{ "message": fmt.Sprintf("DIFC policy violation: %d of %d items not accessible", filtered.GetFilteredCount(), filtered.TotalCount), }) @@ -312,16 +293,8 @@ func (h *proxyHandler) passthrough(w http.ResponseWriter, r *http.Request, path defer r.Body.Close() } - resp, err := h.server.forwardToGitHub(r.Context(), r.Method, path, body, r.Header.Get("Content-Type"), r.Header.Get("Authorization")) - if err != nil { - http.Error(w, "upstream request failed", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, "failed to read upstream response", http.StatusBadGateway) + resp, respBody := h.forwardAndReadBody(r.Context(), w, r.Method, path, body, r.Header.Get("Content-Type"), r.Header.Get("Authorization")) + if resp == nil { return } @@ -376,6 +349,35 @@ func copyResponseHeaders(w http.ResponseWriter, resp *http.Response) { } } +// forwardAndReadBody forwards a request to the upstream GitHub API and reads the full +// response body. On success it returns the response (body already drained and closed) +// and the body bytes. On error it writes an appropriate HTTP error to w and returns +// (nil, nil); callers should return immediately on a nil response. +func (h *proxyHandler) forwardAndReadBody(ctx context.Context, w http.ResponseWriter, method, path string, body io.Reader, contentType, clientAuth string) (*http.Response, []byte) { + resp, err := h.server.forwardToGitHub(ctx, method, path, body, contentType, clientAuth) + if err != nil { + logHandler.Printf("upstream request failed: %s %s: %v", method, path, err) + http.Error(w, "upstream request failed", http.StatusBadGateway) + return nil, nil + } + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) + if err != nil { + logHandler.Printf("failed to read upstream response: %s %s: %v", method, path, err) + http.Error(w, "failed to read upstream response", http.StatusBadGateway) + return nil, nil + } + return resp, respBody +} + +// writeJSONResponse sets the Content-Type header, writes the status code, and encodes +// body as JSON. It centralises the three-line pattern used across HTTP handlers. +func writeJSONResponse(w http.ResponseWriter, statusCode int, body interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(body) //nolint:errcheck +} + func truncateForLog(s string, maxLen int) string { if len(s) <= maxLen { return s