Skip to content

Commit 3651c3c

Browse files
authored
Update http client (#142)
1 parent 372e27f commit 3651c3c

File tree

2 files changed

+100
-101
lines changed

2 files changed

+100
-101
lines changed

pkg/api/http_client.go

Lines changed: 69 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/http"
77
"regexp"
88
"strings"
9+
"sync"
910
"time"
1011

1112
"wpm/pkg/asciisanitizer"
@@ -29,13 +30,23 @@ const (
2930
jsonContentType = "application/json; charset=utf-8"
3031
)
3132

32-
var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
33+
var (
34+
jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
35+
zstdDecoderPool = sync.Pool{
36+
New: func() any {
37+
d, err := zstd.NewReader(nil)
38+
if err != nil {
39+
panic(fmt.Sprintf("failed to create zstd reader: %v", err))
40+
}
41+
return d
42+
},
43+
}
44+
)
3345

3446
func DefaultHTTPClient() (*http.Client, error) {
3547
return NewHTTPClient(ClientOptions{})
3648
}
3749

38-
// NewHTTPClient creates a new HTTP client with the provided options.
3950
func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
4051
if optionsNeedResolution(opts) {
4152
var err error
@@ -45,22 +56,28 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
4556
}
4657
}
4758

48-
transport := http.DefaultTransport
59+
transport := &http.Transport{
60+
MaxIdleConns: 100,
61+
MaxIdleConnsPerHost: 100,
62+
IdleConnTimeout: 90 * time.Second,
63+
ForceAttemptHTTP2: true,
64+
DisableCompression: true,
65+
}
66+
67+
var rt http.RoundTripper = transport
4968

5069
if opts.CacheDir == "" {
5170
opts.CacheDir = config.CacheDir()
5271
}
5372

5473
if opts.EnableCache && opts.CacheTTL == 0 {
5574
opts.CacheTTL = time.Hour * 24
56-
5775
c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
58-
transport = c.RoundTripper(transport)
76+
rt = c.RoundTripper(rt)
5977
}
6078

6179
if opts.Log != nil && logrus.GetLevel() == logrus.DebugLevel {
6280
opts.LogVerboseHTTP = true
63-
6481
logger := &httpretty.Logger{
6582
Time: true,
6683
TLS: false,
@@ -76,7 +93,7 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
7693
logger.SetBodyFilter(func(h http.Header) (skip bool, err error) {
7794
return !inspectableMIMEType(h.Get(contentType)), nil
7895
})
79-
transport = logger.RoundTripper(transport)
96+
rt = logger.RoundTripper(rt)
8097
}
8198

8299
if opts.Headers == nil {
@@ -87,11 +104,11 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
87104
resolveHeaders(opts.Headers)
88105
}
89106

90-
transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)
91-
transport = newDecompressingRoundTripper(transport)
92-
transport = newSanitizerRoundTripper(transport)
107+
rt = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, rt)
108+
rt = newDecompressingRoundTripper(rt)
109+
rt = newSanitizerRoundTripper(rt)
93110

94-
return &http.Client{Transport: transport, Timeout: opts.Timeout}, nil
111+
return &http.Client{Transport: rt, Timeout: opts.Timeout}, nil
95112
}
96113

97114
func inspectableMIMEType(t string) bool {
@@ -114,18 +131,14 @@ func resolveHeaders(headers map[string]string) {
114131
if _, ok := headers[contentType]; !ok {
115132
headers[contentType] = jsonContentType
116133
}
117-
118134
if _, ok := headers[userAgent]; !ok {
119135
headers[userAgent] = "wpm-cli"
120136
}
121-
122137
if _, ok := headers[timeZone]; !ok {
123-
tz := currentTimeZone()
124-
if tz != "" {
138+
if tz, err := tzlocal.RuntimeTZ(); err == nil && tz != "" {
125139
headers[timeZone] = tz
126140
}
127141
}
128-
129142
if _, ok := headers[accept]; !ok {
130143
headers[accept] = "application/json"
131144
}
@@ -136,30 +149,25 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
136149
headers[authorization] = fmt.Sprintf("Bearer %s", authToken)
137150
}
138151
if len(headers) == 0 {
139-
return rt
152+
return headerRoundTripper{host: host, headers: nil, rt: rt}
140153
}
141154
return headerRoundTripper{host: host, headers: headers, rt: rt}
142155
}
143156

144157
func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
145-
// In wpm, we always request zstd compressed responses.
146-
req.Header.Set("Accept-Encoding", "zstd")
158+
reqCopy := req.Clone(req.Context())
159+
reqCopy.Header.Set("Accept-Encoding", "zstd")
147160

148161
for k, v := range hrt.headers {
149-
// If the authorization header has been set and the request
150-
// host is not in the same domain that was specified in the ClientOptions
151-
// then do not add the authorization header to the request.
152-
if k == authorization && !isSameDomain(req.URL.Hostname(), hrt.host) {
162+
if k == authorization && !isSameDomain(reqCopy.URL.Hostname(), hrt.host) {
153163
continue
154164
}
155-
156-
// If the header is already set in the request, don't overwrite it.
157-
if req.Header.Get(k) == "" {
158-
req.Header.Set(k, v)
165+
if reqCopy.Header.Get(k) == "" {
166+
reqCopy.Header.Set(k, v)
159167
}
160168
}
161169

162-
return hrt.rt.RoundTrip(req)
170+
return hrt.rt.RoundTrip(reqCopy)
163171
}
164172

165173
type sanitizerRoundTripper struct {
@@ -172,21 +180,24 @@ func newSanitizerRoundTripper(rt http.RoundTripper) http.RoundTripper {
172180

173181
func (srt sanitizerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
174182
resp, err := srt.rt.RoundTrip(req)
175-
if err != nil || !jsonTypeRE.MatchString(resp.Header.Get(contentType)) {
183+
if err != nil {
176184
return resp, err
177185
}
178-
sanitizedReadCloser := struct {
179-
io.Reader
180-
io.Closer
181-
}{
186+
if !inspectableMIMEType(resp.Header.Get(contentType)) {
187+
return resp, nil
188+
}
189+
resp.Body = &wrappedBody{
182190
Reader: transform.NewReader(resp.Body, &asciisanitizer.Sanitizer{JSON: true}),
183191
Closer: resp.Body,
184192
}
185-
resp.Body = sanitizedReadCloser
186-
return resp, err
193+
return resp, nil
194+
}
195+
196+
type wrappedBody struct {
197+
io.Reader
198+
io.Closer
187199
}
188200

189-
// NEW RoundTripper for decompression
190201
type decompressingRoundTripper struct {
191202
rt http.RoundTripper
192203
}
@@ -201,26 +212,37 @@ func (d decompressingRoundTripper) RoundTrip(req *http.Request) (*http.Response,
201212
return nil, err
202213
}
203214

204-
// support for zstd compressed responses
205215
if resp.Header.Get("Content-Encoding") == "zstd" {
206-
reader, err := zstd.NewReader(resp.Body)
207-
if err != nil {
216+
decoder := zstdDecoderPool.Get().(*zstd.Decoder)
217+
if err := decoder.Reset(resp.Body); err != nil {
208218
resp.Body.Close()
209-
return nil, fmt.Errorf("failed to create zstd reader: %w", err)
219+
zstdDecoderPool.Put(decoder)
220+
return nil, fmt.Errorf("failed to reset zstd reader: %w", err)
210221
}
211222

212-
resp.Body = &readCloser{Reader: reader, Closer: resp.Body}
223+
resp.Body = &zstdReadCloser{
224+
Decoder: decoder,
225+
OriginalBody: resp.Body,
226+
}
213227
resp.Header.Del("Content-Encoding")
214228
resp.Header.Del("Content-Length")
229+
resp.ContentLength = -1
215230
}
216231

217232
return resp, nil
218233
}
219234

220-
func currentTimeZone() string {
221-
tz, err := tzlocal.RuntimeTZ()
222-
if err != nil {
223-
return ""
224-
}
225-
return tz
235+
type zstdReadCloser struct {
236+
Decoder *zstd.Decoder
237+
OriginalBody io.ReadCloser
238+
}
239+
240+
func (z *zstdReadCloser) Read(p []byte) (n int, err error) {
241+
return z.Decoder.Read(p)
242+
}
243+
244+
func (z *zstdReadCloser) Close() error {
245+
err := z.OriginalBody.Close()
246+
zstdDecoderPool.Put(z.Decoder)
247+
return err
226248
}

0 commit comments

Comments
 (0)