Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 38 additions & 57 deletions pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,15 +516,18 @@ func (s *SouinBaseHandler) Upstream(
}

err := s.Store(customWriter, rq, requestCc, cachedKey, uri)
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
defer customWriter.resetBuffer()

// Create a copy of the buffer to prevent memory retention
// when the buffer is returned to the pool
bodyCopy := make([]byte, customWriter.Buf.Len())
copy(bodyCopy, customWriter.Buf.Bytes())

return singleflightValue{
body: customWriter.Buf.Bytes(),
headers: customWriter.Header().Clone(),
requestHeaders: rq.Header,
code: statusCode,
body: bodyCopy,
headers: customWriter.Header().Clone(),
requestHeaders: rq.Header,
code: statusCode,
disableCoalescing: strings.Contains(cacheControl, "private") || customWriter.Header().Get("Set-Cookie") != "",
}, err
})
Expand Down Expand Up @@ -581,19 +584,15 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
statusCode := customWriter.GetStatusCode()
if err == nil {
if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified {
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
customWriter.resetBuffer()
customWriter.Rw.WriteHeader(http.StatusPreconditionFailed)

return nil, errors.New("")
}

if validator.IfModifiedSincePresent {
if lastModified, err := time.Parse(time.RFC1123, customWriter.Header().Get("Last-Modified")); err == nil && validator.IfModifiedSince.Sub(lastModified) > 0 {
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
customWriter.resetBuffer()
customWriter.Rw.WriteHeader(http.StatusNotModified)

return nil, errors.New("")
Expand All @@ -615,11 +614,15 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
),
)

defer customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
defer customWriter.resetBuffer()

// Create a copy of the buffer to prevent memory retention
// when the buffer is returned to the pool
bodyCopy := make([]byte, customWriter.Buf.Len())
copy(bodyCopy, customWriter.Buf.Bytes())

return singleflightValue{
body: customWriter.Buf.Bytes(),
body: bodyCopy,
headers: customWriter.Header().Clone(),
code: statusCode,
}, err
Expand Down Expand Up @@ -649,11 +652,13 @@ func (s *SouinBaseHandler) HandleInternally(r *http.Request) (bool, http.Handler
return false, nil
}

type handlerFunc = func(http.ResponseWriter, *http.Request) error
type statusCodeLogger struct {
http.ResponseWriter
statusCode int
}
type (
handlerFunc = func(http.ResponseWriter, *http.Request) error
statusCodeLogger struct {
http.ResponseWriter
statusCode int
}
)

func (s *statusCodeLogger) WriteHeader(code int) {
s.statusCode = code
Expand Down Expand Up @@ -838,18 +843,14 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
if validator.NotModified {
customWriter.WriteHeader(http.StatusNotModified)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
customWriter.resetBuffer()
_, _ = customWriter.Send()

return nil
}

customWriter.WriteHeader(response.StatusCode)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.copyToBuffer(response.Body)
_, _ = customWriter.Send()

return nil
Expand All @@ -875,9 +876,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
customWriter.WriteHeader(response.StatusCode)
s.Configuration.GetLogger().Debugf("Serve from cache %+v", req)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.copyToBuffer(response.Body)
_, err := customWriter.Send()
prometheus.Increment(prometheus.CachedResponseCounter)

Expand All @@ -897,9 +896,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
customWriter.WriteHeader(response.StatusCode)
rfc.HitStaleCache(&response.Header)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.copyToBuffer(response.Body)
_, err := customWriter.Send()
customWriter = NewCustomWriter(req, rw, bufPool)
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
Expand All @@ -923,18 +920,13 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code)
maps.Copy(customWriter.Header(), response.Header)
customWriter.WriteHeader(response.StatusCode)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.resetAndCopyToBuffer(response.Body)
_, err := customWriter.Send()

return err
}
rw.WriteHeader(http.StatusGatewayTimeout)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
customWriter.resetBuffer()
_, err := customWriter.Send()

return err
Expand All @@ -945,9 +937,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
rfc.SetCacheStatusHeader(response, storerName)
customWriter.WriteHeader(response.StatusCode)
maps.Copy(customWriter.Header(), response.Header)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.copyToBuffer(response.Body)
_, _ = customWriter.Send()

return err
Expand All @@ -956,9 +946,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n

if statusCode != http.StatusNotModified && validator.Matched {
customWriter.WriteHeader(http.StatusNotModified)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
customWriter.resetBuffer()
_, _ = customWriter.Send()

return err
Expand All @@ -973,9 +961,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
customWriter.WriteHeader(response.StatusCode)
rfc.HitStaleCache(&response.Header)
maps.Copy(customWriter.Header(), response.Header)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.copyToBuffer(response.Body)
_, err := customWriter.Send()

return err
Expand All @@ -989,9 +975,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
customWriter.WriteHeader(response.StatusCode)
rfc.HitStaleCache(&response.Header)
maps.Copy(customWriter.Header(), response.Header)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.copyToBuffer(response.Body)
_, err := customWriter.Send()

return err
Expand All @@ -1015,10 +999,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code)
maps.Copy(customWriter.Header(), response.Header)
customWriter.WriteHeader(response.StatusCode)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.resetAndCopyToBuffer(response.Body)
_, err := customWriter.Send()

return err
Expand Down
30 changes: 21 additions & 9 deletions pkg/middleware/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"bytes"
"fmt"
"io"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -41,12 +42,25 @@ type CustomWriter struct {
statusCode int
}

func (r *CustomWriter) handleBuffer(callback func(*bytes.Buffer)) {
func (r *CustomWriter) resetBuffer() {
r.mutex.Lock()
callback(r.Buf)
r.Buf.Reset()
r.mutex.Unlock()
}

func (r *CustomWriter) copyToBuffer(src io.Reader) (int64, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
return io.Copy(r.Buf, src)
}

func (r *CustomWriter) resetAndCopyToBuffer(src io.Reader) (int64, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.Buf.Reset()
return io.Copy(r.Buf, src)
}

// Header will write the response headers
func (r *CustomWriter) Header() http.Header {
r.mutex.Lock()
Expand Down Expand Up @@ -79,10 +93,10 @@ func (r *CustomWriter) WriteHeader(code int) {

// Write will write the response body
func (r *CustomWriter) Write(b []byte) (int, error) {
r.handleBuffer(func(actual *bytes.Buffer) {
actual.Grow(len(b))
_, _ = actual.Write(b)
})
r.mutex.Lock()
r.Buf.Grow(len(b))
_, _ = r.Buf.Write(b)
r.mutex.Unlock()

return len(b), nil
}
Expand Down Expand Up @@ -142,9 +156,7 @@ func parseRange(rangeHeaders []string, contentRange string) ([]rangeValue, range

// Send delays the response to handle Cache-Status
func (r *CustomWriter) Send() (int, error) {
defer r.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
defer r.resetBuffer()
storedLength := r.Header().Get(rfc.StoredLengthHeader)
if storedLength != "" {
r.Header().Set("Content-Length", storedLength)
Expand Down
15 changes: 14 additions & 1 deletion plugins/traefik/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package traefik

import (
"github.com/darkweak/souin/configurationtypes"
"github.com/darkweak/storages/core"
)

// Configuration holder
Expand All @@ -13,7 +14,19 @@ type Configuration struct {
LogLevel string `json:"log_level" yaml:"log_level"`
Ykeys map[string]configurationtypes.SurrogateKeys `json:"ykeys" yaml:"ykeys"`
SurrogateKeys map[string]configurationtypes.SurrogateKeys `json:"surrogate_keys" yaml:"surrogate_keys"`
SurrogateKeyDisabled bool `json:"disable_surrogate_key" yaml:"disable_surrogate_key"`
SurrogateKeyDisabled bool

logger core.Logger
}

// GetLogger implements configurationtypes.AbstractConfigurationInterface.
func (c *Configuration) GetLogger() core.Logger {
return c.logger
}

// SetLogger implements configurationtypes.AbstractConfigurationInterface.
func (c *Configuration) SetLogger(logger core.Logger) {
c.logger = logger
}

// GetUrls get the urls list in the configuration
Expand Down
Loading
Loading