Skip to content

Commit b835b2d

Browse files
authored
fix: restore req body (#397)
1 parent f4daa21 commit b835b2d

File tree

4 files changed

+100
-9
lines changed

4 files changed

+100
-9
lines changed

main_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,37 @@ func TestServe(t *testing.T) {
212212
},
213213
startTLS,
214214
},
215+
{
216+
"https request body is not empty",
217+
"testdata/https.yml",
218+
func(t *testing.T) {
219+
query := "SELECT SleepTimeout"
220+
buf := bytes.NewBufferString(query)
221+
req, err := http.NewRequest("POST", "https://127.0.0.1:8443", buf)
222+
checkErr(t, err)
223+
req.SetBasicAuth("default", "qwerty")
224+
req.Close = true
225+
226+
resp, err := tlsClient.Do(req)
227+
checkErr(t, err)
228+
if resp.StatusCode != http.StatusGatewayTimeout {
229+
t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusGatewayTimeout)
230+
}
231+
232+
bodyBytes, err := io.ReadAll(resp.Body)
233+
if err != nil {
234+
t.Fatalf("error while reading body from response; err: %q", err)
235+
}
236+
237+
b := string(bodyBytes)
238+
if !strings.Contains(b, query) {
239+
t.Fatalf("expected request body: %q; got: %q", query, b)
240+
}
241+
242+
resp.Body.Close()
243+
},
244+
startTLS,
245+
},
215246
{
216247
"https cache with mix query source",
217248
"testdata/https.cache.yml",
@@ -1019,6 +1050,26 @@ func fakeCHHandler(w http.ResponseWriter, r *http.Request) {
10191050
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
10201051
fmt.Fprint(w, b)
10211052
fmt.Fprint(w, "Ok.\n")
1053+
case strings.Contains(q, "SELECT SleepTimeout"):
1054+
w.WriteHeader(http.StatusGatewayTimeout)
1055+
1056+
bodyBytes, err := io.ReadAll(r.Body)
1057+
if err != nil {
1058+
fmt.Fprintf(w, "query: %s; error while reading body: %s", query, err)
1059+
return
1060+
}
1061+
1062+
b := string(bodyBytes)
1063+
// Ensure the original request body is not empty and remains unchanged
1064+
// after it is processed by getFullQuery.
1065+
if b == "" && b != q {
1066+
fmt.Fprintf(w, "got original req body: <%s>; escaped query: <%s>", b, q)
1067+
return
1068+
}
1069+
1070+
// execute sleep 1.5 sec
1071+
time.Sleep(1500 * time.Millisecond)
1072+
fmt.Fprint(w, b)
10221073
default:
10231074
if strings.Contains(string(query), killQueryPattern) {
10241075
fakeCHState.kill()

proxy.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,23 +207,21 @@ func executeWithRetry(
207207
startTime := time.Now()
208208
var since float64
209209

210-
// keep the request body
211-
body, err := io.ReadAll(req.Body)
212-
req.Body.Close()
210+
// Use readAndRestoreRequestBody to read the entire request body into a byte slice,
211+
// and to restore req.Body so that it can be reused later in the code.
212+
body, err := readAndRestoreRequestBody(req)
213213
if err != nil {
214-
since = time.Since(startTime).Seconds()
215-
214+
since := time.Since(startTime).Seconds()
216215
return since, err
217216
}
218217

219218
numRetry := 0
220219
for {
221-
// update body
222-
req.Body = io.NopCloser(bytes.NewBuffer(body))
223-
req.Body.Close()
224-
225220
rp(rw, req)
226221

222+
// Restore req.Body after it's consumed by 'rp' for potential reuse.
223+
req.Body = io.NopCloser(bytes.NewBuffer(body))
224+
227225
err := ctx.Err()
228226
if err != nil {
229227
since = time.Since(startTime).Seconds()

proxy_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,25 @@ func TestReverseProxy_ServeHTTP2(t *testing.T) {
10121012
t.Fatalf("expected response: %q; got: %q", expected, b)
10131013
}
10141014
})
1015+
1016+
t.Run("request body not empty", func(t *testing.T) {
1017+
proxy, err := getProxy(goodCfg)
1018+
if err != nil {
1019+
t.Fatalf("unexpected error: %s", err)
1020+
}
1021+
body := bytes.NewBufferString("SELECT sleep(1.5)")
1022+
expected := "SELECT sleep(1.5)"
1023+
req := httptest.NewRequest("POST", fakeServer.URL, body)
1024+
1025+
resp := makeCustomRequest(proxy, req)
1026+
b := bbToString(t, resp.Body)
1027+
resp.Body.Close()
1028+
1029+
if !strings.Contains(b, expected) {
1030+
t.Fatalf("expected response: %q; got: %q", expected, b)
1031+
}
1032+
1033+
})
10151034
}
10161035

10171036
func getNetwork(s string) *net.IPNet {

utils.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,17 @@ func getQuerySnippetFromBody(req *http.Request) string {
9898
// 'read' request body, so it traps into to crc.
9999
// Ignore any errors, since getQuerySnippet is called only
100100
// during error reporting.
101+
// Temporary solution: Quick and dirty way to work with the request body.
102+
// TODO: Create an original copy of req.Body and work with the copy to avoid altering the original request.
103+
// This current approach consumes the req.Body content with io.Copy(io.Discard, crc) to reset the internal state of crc.
104+
// However, it is not the most efficient or safest method, as it modifies the original req.Body.
101105
io.Copy(io.Discard, crc) // nolint
102106
data := crc.String()
103107

108+
// Here, we attempt to restore req.Body by wrapping the string data in a ReadCloser.
109+
// This is part of the temporary solution and should be replaced with a more robust method that does not consume the original req.Body.
110+
req.Body = io.NopCloser(strings.NewReader(data))
111+
104112
u := getDecompressor(req)
105113
if u == nil {
106114
return data
@@ -295,3 +303,18 @@ func calcCredentialHash(user string, pwd string) (uint32, error) {
295303
_, err := h.Write([]byte(user + pwd))
296304
return h.Sum32(), err
297305
}
306+
307+
// Function to read the request body and return it as a byte slice.
308+
// It also restores the req.Body to be used again.
309+
func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
310+
// Read the entire request body.
311+
body, err := io.ReadAll(req.Body)
312+
if err != nil {
313+
return nil, err
314+
}
315+
// Restore the req.Body with a new reader for the original content.
316+
req.Body = io.NopCloser(bytes.NewReader(body))
317+
318+
// Return the read body.
319+
return body, nil
320+
}

0 commit comments

Comments
 (0)