66 "net/http"
77 "regexp"
88 "strings"
9+ "sync"
910 "time"
1011
1112 "wpm/pkg/asciisanitizer"
@@ -29,13 +30,23 @@ const (
2930 jsonContentType = "application/json; charset=utf-8"
3031)
3132
32- var jsonTypeRE = regexp .MustCompile (`[/+]json($|;)` )
33+ var (
34+ jsonTypeRE = regexp .MustCompile (`[/+]json($|;)` )
35+ zstdDecoderPool = sync.Pool {
36+ New : func () any {
37+ d , err := zstd .NewReader (nil )
38+ if err != nil {
39+ panic (fmt .Sprintf ("failed to create zstd reader: %v" , err ))
40+ }
41+ return d
42+ },
43+ }
44+ )
3345
3446func DefaultHTTPClient () (* http.Client , error ) {
3547 return NewHTTPClient (ClientOptions {})
3648}
3749
38- // NewHTTPClient creates a new HTTP client with the provided options.
3950func NewHTTPClient (opts ClientOptions ) (* http.Client , error ) {
4051 if optionsNeedResolution (opts ) {
4152 var err error
@@ -45,22 +56,28 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
4556 }
4657 }
4758
48- transport := http .DefaultTransport
59+ transport := & http.Transport {
60+ MaxIdleConns : 100 ,
61+ MaxIdleConnsPerHost : 100 ,
62+ IdleConnTimeout : 90 * time .Second ,
63+ ForceAttemptHTTP2 : true ,
64+ DisableCompression : true ,
65+ }
66+
67+ var rt http.RoundTripper = transport
4968
5069 if opts .CacheDir == "" {
5170 opts .CacheDir = config .CacheDir ()
5271 }
5372
5473 if opts .EnableCache && opts .CacheTTL == 0 {
5574 opts .CacheTTL = time .Hour * 24
56-
5775 c := cache {dir : opts .CacheDir , ttl : opts .CacheTTL }
58- transport = c .RoundTripper (transport )
76+ rt = c .RoundTripper (rt )
5977 }
6078
6179 if opts .Log != nil && logrus .GetLevel () == logrus .DebugLevel {
6280 opts .LogVerboseHTTP = true
63-
6481 logger := & httpretty.Logger {
6582 Time : true ,
6683 TLS : false ,
@@ -76,7 +93,7 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
7693 logger .SetBodyFilter (func (h http.Header ) (skip bool , err error ) {
7794 return ! inspectableMIMEType (h .Get (contentType )), nil
7895 })
79- transport = logger .RoundTripper (transport )
96+ rt = logger .RoundTripper (rt )
8097 }
8198
8299 if opts .Headers == nil {
@@ -87,11 +104,11 @@ func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
87104 resolveHeaders (opts .Headers )
88105 }
89106
90- transport = newHeaderRoundTripper (opts .Host , opts .AuthToken , opts .Headers , transport )
91- transport = newDecompressingRoundTripper (transport )
92- transport = newSanitizerRoundTripper (transport )
107+ rt = newHeaderRoundTripper (opts .Host , opts .AuthToken , opts .Headers , rt )
108+ rt = newDecompressingRoundTripper (rt )
109+ rt = newSanitizerRoundTripper (rt )
93110
94- return & http.Client {Transport : transport , Timeout : opts .Timeout }, nil
111+ return & http.Client {Transport : rt , Timeout : opts .Timeout }, nil
95112}
96113
97114func inspectableMIMEType (t string ) bool {
@@ -114,18 +131,14 @@ func resolveHeaders(headers map[string]string) {
114131 if _ , ok := headers [contentType ]; ! ok {
115132 headers [contentType ] = jsonContentType
116133 }
117-
118134 if _ , ok := headers [userAgent ]; ! ok {
119135 headers [userAgent ] = "wpm-cli"
120136 }
121-
122137 if _ , ok := headers [timeZone ]; ! ok {
123- tz := currentTimeZone ()
124- if tz != "" {
138+ if tz , err := tzlocal .RuntimeTZ (); err == nil && tz != "" {
125139 headers [timeZone ] = tz
126140 }
127141 }
128-
129142 if _ , ok := headers [accept ]; ! ok {
130143 headers [accept ] = "application/json"
131144 }
@@ -136,30 +149,25 @@ func newHeaderRoundTripper(host string, authToken string, headers map[string]str
136149 headers [authorization ] = fmt .Sprintf ("Bearer %s" , authToken )
137150 }
138151 if len (headers ) == 0 {
139- return rt
152+ return headerRoundTripper { host : host , headers : nil , rt : rt }
140153 }
141154 return headerRoundTripper {host : host , headers : headers , rt : rt }
142155}
143156
144157func (hrt headerRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
145- // In wpm, we always request zstd compressed responses.
146- req .Header .Set ("Accept-Encoding" , "zstd" )
158+ reqCopy := req . Clone ( req . Context ())
159+ reqCopy .Header .Set ("Accept-Encoding" , "zstd" )
147160
148161 for k , v := range hrt .headers {
149- // If the authorization header has been set and the request
150- // host is not in the same domain that was specified in the ClientOptions
151- // then do not add the authorization header to the request.
152- if k == authorization && ! isSameDomain (req .URL .Hostname (), hrt .host ) {
162+ if k == authorization && ! isSameDomain (reqCopy .URL .Hostname (), hrt .host ) {
153163 continue
154164 }
155-
156- // If the header is already set in the request, don't overwrite it.
157- if req .Header .Get (k ) == "" {
158- req .Header .Set (k , v )
165+ if reqCopy .Header .Get (k ) == "" {
166+ reqCopy .Header .Set (k , v )
159167 }
160168 }
161169
162- return hrt .rt .RoundTrip (req )
170+ return hrt .rt .RoundTrip (reqCopy )
163171}
164172
165173type sanitizerRoundTripper struct {
@@ -172,21 +180,24 @@ func newSanitizerRoundTripper(rt http.RoundTripper) http.RoundTripper {
172180
173181func (srt sanitizerRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
174182 resp , err := srt .rt .RoundTrip (req )
175- if err != nil || ! jsonTypeRE . MatchString ( resp . Header . Get ( contentType )) {
183+ if err != nil {
176184 return resp , err
177185 }
178- sanitizedReadCloser := struct {
179- io. Reader
180- io. Closer
181- } {
186+ if ! inspectableMIMEType ( resp . Header . Get ( contentType )) {
187+ return resp , nil
188+ }
189+ resp . Body = & wrappedBody {
182190 Reader : transform .NewReader (resp .Body , & asciisanitizer.Sanitizer {JSON : true }),
183191 Closer : resp .Body ,
184192 }
185- resp .Body = sanitizedReadCloser
186- return resp , err
193+ return resp , nil
194+ }
195+
196+ type wrappedBody struct {
197+ io.Reader
198+ io.Closer
187199}
188200
189- // NEW RoundTripper for decompression
190201type decompressingRoundTripper struct {
191202 rt http.RoundTripper
192203}
@@ -201,26 +212,37 @@ func (d decompressingRoundTripper) RoundTrip(req *http.Request) (*http.Response,
201212 return nil , err
202213 }
203214
204- // support for zstd compressed responses
205215 if resp .Header .Get ("Content-Encoding" ) == "zstd" {
206- reader , err := zstd . NewReader ( resp . Body )
207- if err != nil {
216+ decoder := zstdDecoderPool . Get ().( * zstd. Decoder )
217+ if err := decoder . Reset ( resp . Body ); err != nil {
208218 resp .Body .Close ()
209- return nil , fmt .Errorf ("failed to create zstd reader: %w" , err )
219+ zstdDecoderPool .Put (decoder )
220+ return nil , fmt .Errorf ("failed to reset zstd reader: %w" , err )
210221 }
211222
212- resp .Body = & readCloser {Reader : reader , Closer : resp .Body }
223+ resp .Body = & zstdReadCloser {
224+ Decoder : decoder ,
225+ OriginalBody : resp .Body ,
226+ }
213227 resp .Header .Del ("Content-Encoding" )
214228 resp .Header .Del ("Content-Length" )
229+ resp .ContentLength = - 1
215230 }
216231
217232 return resp , nil
218233}
219234
220- func currentTimeZone () string {
221- tz , err := tzlocal .RuntimeTZ ()
222- if err != nil {
223- return ""
224- }
225- return tz
235+ type zstdReadCloser struct {
236+ Decoder * zstd.Decoder
237+ OriginalBody io.ReadCloser
238+ }
239+
240+ func (z * zstdReadCloser ) Read (p []byte ) (n int , err error ) {
241+ return z .Decoder .Read (p )
242+ }
243+
244+ func (z * zstdReadCloser ) Close () error {
245+ err := z .OriginalBody .Close ()
246+ zstdDecoderPool .Put (z .Decoder )
247+ return err
226248}
0 commit comments