Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"bytes"
baseCtx "context"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -631,6 +632,60 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
return err
}

func (s *SouinBaseHandler) GetEarlyHints(cachedKey string) map[string][]string {
var wg sync.WaitGroup

wg.Add(s.storersLen)
cachedKey = fmt.Sprintf("%s_103_", cachedKey)

for _, storer := range s.Storers {
result := storer.MapKeys(cachedKey)

if len(result) != 0 {
earlyHintLinks := map[string][]string{}

for k, h := range result {
var res http.Header
_ = json.Unmarshal([]byte(h), &res)

res.Values("Link")
earlyHintLinks[k] = res.Values("Link")
}

s.Configuration.GetLogger().Debugf("Found early_hints %#v for the cachedKey %s", earlyHintLinks, cachedKey)

return earlyHintLinks
}
}

return nil
}

func (s *SouinBaseHandler) StoreEarlyHint(cachedKey string, h http.Header, iteration int) {
var wg sync.WaitGroup

wg.Add(s.storersLen)

cachedKey = fmt.Sprintf("%s_103_%d", cachedKey, iteration)

byteHeaders, _ := json.Marshal(h)

for _, storer := range s.Storers {
go func(currentStorer types.Storer, currentKey string, byteHeaders []byte) {
defer wg.Done()
if currentStorer.Set(currentKey, byteHeaders, s.DefaultMatchedUrl.TTL.Duration) == nil {
s.Configuration.GetLogger().Debugf("Stored the early_hint key %s in the %s provider for %v duration", currentKey, currentStorer.Name())
} else {
s.Configuration.GetLogger().Debugf(
"Cannot store the key early_hint key %s in the %s provider",
currentKey,
currentStorer.Name(),
)
}
}(storer, cachedKey, byteHeaders)
}
}

func (s *SouinBaseHandler) HandleInternally(r *http.Request) (bool, http.HandlerFunc) {
if s.InternalEndpointHandlers != nil {
for k, handler := range *s.InternalEndpointHandlers.Handlers {
Expand Down Expand Up @@ -783,9 +838,28 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
bufPool.Reset()
defer s.bufPool.Put(bufPool)

customWriter := NewCustomWriter(req, rw, bufPool, int(s.Configuration.GetDefaultCache().GetMaxBodyBytes()))
earlyHintIteration := 0
customWriter := NewCustomWriter(
req,
rw,
bufPool,
int(s.Configuration.GetDefaultCache().GetMaxBodyBytes()),
func(h http.Header) {
s.StoreEarlyHint(cachedKey, h, earlyHintIteration)
earlyHintIteration++
})
customWriter.Headers.Add("Range", req.Header.Get("Range"))
// req.Header.Del("Range")

earlyHints := s.GetEarlyHints(cachedKey)
for _, links := range earlyHints {
for _, link := range links {
rw.Header().Add("Link", link)
}

rw.WriteHeader(http.StatusEarlyHints)

rw.Header().Del("Link")
}

go func(req *http.Request, crw *CustomWriter) {
<-req.Context().Done()
Expand Down Expand Up @@ -896,7 +970,15 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
rfc.HitStaleCache(&response.Header)
_, _ = customWriter.copyToBuffer(response.Body)
_, err := customWriter.Send()
customWriter = NewCustomWriter(req, rw, bufPool, int(s.Configuration.GetDefaultCache().GetMaxBodyBytes()))
customWriter = NewCustomWriter(
req,
rw,
bufPool,
int(s.Configuration.GetDefaultCache().GetMaxBodyBytes()),
func(h http.Header) {
s.StoreEarlyHint(cachedKey, h, earlyHintIteration)
earlyHintIteration++
})
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
_ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk, goUri)
}(validator, customWriter, req, next, requestCc, cachedKey, uri)
Expand Down
37 changes: 34 additions & 3 deletions pkg/middleware/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ type SouinWriterInterface interface {

var _ SouinWriterInterface = (*CustomWriter)(nil)

func NewCustomWriter(rq *http.Request, rw http.ResponseWriter, b *bytes.Buffer, maxSize int) *CustomWriter {
func NewCustomWriter(
rq *http.Request,
rw http.ResponseWriter,
b *bytes.Buffer,
maxSize int,
earlyHintStore func(http.Header),
) *CustomWriter {
return &CustomWriter{
statusCode: 200,
Buf: b,
Expand All @@ -31,6 +37,7 @@ func NewCustomWriter(rq *http.Request, rw http.ResponseWriter, b *bytes.Buffer,
mutex: sync.Mutex{},
maxSize: maxSize,
maxSizeReached: false,
earlyHintStore: earlyHintStore,
}
}

Expand All @@ -45,6 +52,8 @@ type CustomWriter struct {
statusCode int
maxSize int
maxSizeReached bool

earlyHintStore func(http.Header)
}

func (r *CustomWriter) resetBuffer() {
Expand Down Expand Up @@ -92,12 +101,24 @@ func (r *CustomWriter) GetStatusCode() int {

// WriteHeader will write the response headers
func (r *CustomWriter) WriteHeader(code int) {
r.mutex.Lock()
defer r.mutex.Unlock()
defer func(h http.Header) {
r.mutex.Unlock()

if code == http.StatusEarlyHints {
r.earlyHintStore(h)
}
}(r.Header())

if r.headersSent {
return
}

r.mutex.Lock()

r.statusCode = code
if code == http.StatusEarlyHints {
r.Rw.WriteHeader(code)
}
}

// Write will write the response body
Expand Down Expand Up @@ -187,6 +208,16 @@ func parseRange(rangeHeaders []string, contentRange string) ([]rangeValue, range
return values, crv, total + 1
}

// Push implements http.Pusher
func (r *CustomWriter) Push(target string, opts *http.PushOptions) error {
pusher, ok := r.Rw.(http.Pusher)
if !ok {
return fmt.Errorf("ResponseWriter does not implement http.Pusher")
}

return pusher.Push(target, opts)
}

// Send delays the response to handle Cache-Status
func (r *CustomWriter) Send() (int, error) {
defer r.resetBuffer()
Expand Down
Loading