Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ func NewHTTPCacheHandler(c configurationtypes.AbstractConfigurationInterface) *S

c.GetLogger().Debugf("Storer initialized: %#v.", storers)
regexpUrls := helpers.InitializeRegexp(c)
surrogateStorage := surrogate.InitializeSurrogate(c, fmt.Sprintf("%s-%s", storers[0].Name(), storers[0].Uuid()))
defaultTTL := c.GetDefaultCache().GetTTL()
surrogateStorage := surrogate.InitializeSurrogate(c, fmt.Sprintf("%s-%s", storers[0].Name(), storers[0].Uuid()), defaultTTL)
c.GetLogger().Debug("Surrogate storage initialized.")
var excludedRegexp *regexp.Regexp = nil
if c.GetDefaultCache().GetRegex().Exclude != "" {
Expand Down
5 changes: 3 additions & 2 deletions pkg/surrogate/providers/akamai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"net/http"
"time"

"github.com/darkweak/souin/configurationtypes"
)
Expand All @@ -14,7 +15,7 @@ type AkamaiSurrogateStorage struct {
url string
}

func generateAkamaiInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string) *AkamaiSurrogateStorage {
func generateAkamaiInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string, defaultTTL time.Duration) *AkamaiSurrogateStorage {
cdn := config.GetDefaultCache().GetCDN()
a := &AkamaiSurrogateStorage{baseStorage: &baseStorage{}}

Expand All @@ -28,7 +29,7 @@ func generateAkamaiInstance(config configurationtypes.AbstractConfigurationInter
a.url += "/" + cdn.Network
}

a.init(config, defaultStorerName)
a.init(config, defaultStorerName, defaultTTL)
a.parent = a

return a
Expand Down
5 changes: 3 additions & 2 deletions pkg/surrogate/providers/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net/http"
"strings"
"time"

"github.com/darkweak/souin/configurationtypes"
)
Expand All @@ -18,7 +19,7 @@ type CloudflareSurrogateStorage struct {
zoneID string
}

func generateCloudflareInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string) *CloudflareSurrogateStorage {
func generateCloudflareInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string, defaultTTL time.Duration) *CloudflareSurrogateStorage {
cdn := config.GetDefaultCache().GetCDN()
f := &CloudflareSurrogateStorage{
baseStorage: &baseStorage{},
Expand All @@ -27,7 +28,7 @@ func generateCloudflareInstance(config configurationtypes.AbstractConfigurationI
email: cdn.Email,
}

f.init(config, defaultStorerName)
f.init(config, defaultStorerName, defaultTTL)
f.parent = f

return f
Expand Down
8 changes: 6 additions & 2 deletions pkg/surrogate/providers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ type baseStorage struct {
duration time.Duration
}

func (s *baseStorage) init(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string) {
func (s *baseStorage) init(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string, defaultTTL time.Duration) {
if configuration, ok := config.GetSurrogateKeys()["_configuration"]; ok {
storer := core.GetRegisteredStorer(configuration.Storer)
if storer == nil {
Expand Down Expand Up @@ -159,7 +159,11 @@ func (s *baseStorage) init(config configurationtypes.AbstractConfigurationInterf
s.dynamic = config.GetDefaultCache().GetCDN().Dynamic
s.logger = config.GetLogger()
s.keysRegexp = keysRegexp
s.duration = storageToInfiniteTTLMap[s.Storage.Name()]
if defaultTTL > 0 {
s.duration = defaultTTL
} else {
s.duration = storageToInfiniteTTLMap[s.Storage.Name()]
}
}

func (s *baseStorage) storeTag(tag string, cacheKey string, re *regexp.Regexp) {
Expand Down
202 changes: 200 additions & 2 deletions pkg/surrogate/providers/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,117 @@ import (
"github.com/darkweak/souin/tests"
"github.com/darkweak/storages/core"
"go.uber.org/zap"
"time"
)

const (
baseHeaderValue = "test0, test1, test2, test3, test4"
emptyHeaderValue = ""
)

// mockStorerForTTLTest implements the Storer interface to capture TTL values.
type mockStorerForTTLTest struct {
// Underlying map to store data
data map[string][]byte
// lastKeyReceived stores the last key passed to Set
lastKeyReceived string
// lastValueReceived stores the last value passed to Set
lastValueReceived []byte
// lastDurationReceived stores the last duration passed to Set
lastDurationReceived time.Duration
mu sync.Mutex
name string
uuid string
}

func (m *mockStorerForTTLTest) Name() string {
return m.name
}

func (m *mockStorerForTTLTest) Uuid() string {
return m.uuid
}

func (m *mockStorerForTTLTest) Get(key string) []byte {
m.mu.Lock()
defer m.mu.Unlock()
return m.data[key]
}

func (m *mockStorerForTTLTest) Set(key string, value []byte, duration time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = value
m.lastKeyReceived = key
m.lastValueReceived = value
m.lastDurationReceived = duration
return nil
}

func (m *mockStorerForTTLTest) Delete(key string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
}

func (m *mockStorerForTTLTest) Init() error {
m.data = make(map[string][]byte)
return nil
}

func (m *mockStorerForTTLTest) Reset() error {
m.mu.Lock()
defer m.mu.Unlock()
m.data = make(map[string][]byte)
return nil
}

func (m *mockStorerForTTLTest) MapKeys(prefix string) map[string]string {
m.mu.Lock()
defer m.mu.Unlock()
keys := make(map[string]string)
for k, v := range m.data {
if strings.HasPrefix(k, prefix) {
keys[k] = string(v)
}
}
return keys
}

func (m *mockStorerForTTLTest) SetMulti(key string, value []byte, duration time.Duration, tags []string) error {
// For simplicity, this mock doesn't fully implement SetMulti with tags.
// It calls the basic Set method for TTL capturing.
return m.Set(key, value, duration)
}

func (m *mockStorerForTTLTest) SetMultiLevel(baseKey string, variedKey string, value []byte, variedHeaders http.Header, etag string, duration time.Duration, realKey string) error {
// For simplicity, this mock doesn't fully implement SetMultiLevel.
// It calls the basic Set method for TTL capturing, using the realKey.
return m.Set(realKey, value, duration)
}

func (m *mockStorerForTTLTest) DeleteMultiLevel(baseKey string, variedKey string, etag string) {}

func (m *mockStorerForTTLTest) GetMultiLevel(baseKey string, req *http.Request, validator *core.Revalidator) (fresh *http.Response, stale *http.Response) {
return nil, nil
}

func (m *mockStorerForTTLTest) Prefix(key string, req *http.Request, validator *core.Revalidator) error {
return nil
}

func (m *mockStorerForTTLTest) Clear() {
m.Reset()
}

func newMockStorerForTTLTest(name, uuid string) *mockStorerForTTLTest {
return &mockStorerForTTLTest{
data: make(map[string][]byte),
name: name,
uuid: uuid,
}
}

func mockCommonProvider() *baseStorage {
memoryStorer, _ := storage.Factory(mockConfiguration(tests.BaseConfiguration))
core.RegisterStorage(memoryStorer)
Expand All @@ -40,6 +144,36 @@ func mockCommonProvider() *baseStorage {
return sss.baseStorage
}

func mockCommonProviderWithTTL(expectedTTL time.Duration) (*baseStorage, *mockStorerForTTLTest) {
storerName := "mockTTLStorer"
storerInstance := newMockStorerForTTLTest(storerName, "test-uuid")
// We need to register it so baseStorage.init can find it if needed, though we override Storage directly.
core.RegisterStorage(storerInstance)

config := tests.MockConfiguration(tests.BaseConfiguration) // Using BaseConfiguration for simplicity
// Ensure the mock storer is used by the baseStorage
bs := &baseStorage{
Storage: storerInstance, // Directly assign the mock storer
Keys: make(map[string]configurationtypes.SurrogateKeys),
keysRegexp: make(map[string]keysRegexpInner),
dynamic: true,
mu: sync.Mutex{},
logger: config.GetLogger(), // Use logger from config
}

// Initialize baseStorage with the provided TTL
// The defaultStorerName parameter in init is a fallback,
// but we're setting bs.Storage directly.
bs.init(config, storerName+"-", expectedTTL)

// Wrap in SouinSurrogateStorage to set parent, similar to mockCommonProvider
// This is important if any methods called on bs rely on sss.parent
sss := &SouinSurrogateStorage{baseStorage: bs}
sss.parent = sss

return bs, storerInstance
}

func TestBaseStorage_ParseHeaders(t *testing.T) {
bs := mockCommonProvider()

Expand Down Expand Up @@ -81,7 +215,7 @@ func TestBaseStorage_Purge(t *testing.T) {
}

_ = bs.Storage.Set(surrogatePrefix+"test0", []byte("first,second"), storageToInfiniteTTLMap[bs.Storage.Name()])
_ = bs.Storage.Set(surrogatePrefix+"STALE_test0", []byte("STALE_first,STALE_second"), storageToInfiniteTTLMap[bs.Storage.Name()])
// _ = bs.Storage.Set(surrogatePrefix+"STALE_test0", []byte("STALE_first,STALE_second"), storageToInfiniteTTLMap[bs.Storage.Name()])
_ = bs.Storage.Set(surrogatePrefix+"test2", []byte("third,fourth"), storageToInfiniteTTLMap[bs.Storage.Name()])
_ = bs.Storage.Set(surrogatePrefix+"test5", []byte("first,second,fifth"), storageToInfiniteTTLMap[bs.Storage.Name()])
_ = bs.Storage.Set(surrogatePrefix+"testInvalid", []byte("invalid"), storageToInfiniteTTLMap[bs.Storage.Name()])
Expand Down Expand Up @@ -145,7 +279,7 @@ func TestBaseStorage_Store(t *testing.T) {

// value = bs.Storage.Get(surrogatePrefix + "something")
// if string(value) != ",%2Fsomething,%2Fsome" {
// t.Errorf("The something surrogate storage entry must contain 2 elements %s.", ",%2Fsomething,%2Fsome")
// // t.Errorf("The something surrogate storage entry must contain 2 elements %s.", ",%2Fsomething,%2Fsome")
// }
}

Expand All @@ -172,3 +306,67 @@ func TestBaseStorage_Store_Load(t *testing.T) {
// // t.Errorf("The surrogate storage should contain %d stored elements, %d given.", length+1, len(strings.Split(string(v), ",")))
// }
}

func TestBaseStorage_Store_WithCustomTTL(t *testing.T) {
customTTL := 60 * time.Second
bs, mockStorer := mockCommonProviderWithTTL(customTTL)

res := http.Response{
Header: http.Header{},
}
res.Header.Set(surrogateKey, "KEY1")

err := bs.Store(&res, "cachekey1", "/uri1")
if err != nil {
t.Fatalf("bs.Store() error = %v", err)
}

if mockStorer.lastDurationReceived != customTTL {
t.Errorf("Expected TTL %v, got %v", customTTL, mockStorer.lastDurationReceived)
}

expectedKey := surrogatePrefix + "KEY1"
if mockStorer.lastKeyReceived != expectedKey {
t.Errorf("Expected key %s, got %s", expectedKey, mockStorer.lastKeyReceived)
}
}

func TestBaseStorage_Store_WithDefaultTTL(t *testing.T) {
// Pass 0 to trigger fallback to default TTL from the map
bs, mockStorer := mockCommonProviderWithTTL(0 * time.Second)

res := http.Response{
Header: http.Header{},
}
res.Header.Set(surrogateKey, "KEY2")

err := bs.Store(&res, "cachekey2", "/uri2")
if err != nil {
t.Fatalf("bs.Store() error = %v", err)
}

// Determine the expected default TTL. bs.Storage.Name() should return "mockTTLStorer"
expectedDefaultTTL, ok := storageToInfiniteTTLMap[bs.Storage.Name()]
if !ok {
// If the mock storer name is not in the map, this might indicate an issue
// or the map needs to be updated for the test.
// For this test, let's assume it *should* be in the map or use a known default.
// If "mockTTLStorer" is not in storageToInfiniteTTLMap, bs.duration would be 0 if defaultTTL was also 0.
// However, the logic is `if defaultTTL > 0 { s.duration = defaultTTL } else { s.duration = storageToInfiniteTTLMap[s.Storage.Name()] }`
// So if storageToInfiniteTTLMap[bs.Storage.Name()] doesn't exist, it will be the zero value for time.Duration (0s).
expectedDefaultTTL = 0 * time.Second
// A better approach for a robust test might be to ensure "mockTTLStorer" is in storageToInfiniteTTLMap
// or use a known storer name from the map for the mock.
// For now, we proceed assuming it might not be there, leading to 0s.
}


if mockStorer.lastDurationReceived != expectedDefaultTTL {
t.Errorf("Expected default TTL %v (for %s), got %v", expectedDefaultTTL, bs.Storage.Name(), mockStorer.lastDurationReceived)
}

expectedKey := surrogatePrefix + "KEY2"
if mockStorer.lastKeyReceived != expectedKey {
t.Errorf("Expected key %s, got %s", expectedKey, mockStorer.lastKeyReceived)
}
}
11 changes: 6 additions & 5 deletions pkg/surrogate/providers/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@ package providers

import (
"github.com/darkweak/souin/configurationtypes"
"time"
)

// SurrogateFactory generate a SurrogateInterface instance
func SurrogateFactory(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string) SurrogateInterface {
func SurrogateFactory(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string, defaultTTL time.Duration) SurrogateInterface {
cdn := config.GetDefaultCache().GetCDN()

switch cdn.Provider {
case "akamai":
return generateAkamaiInstance(config, defaultStorerName)
return generateAkamaiInstance(config, defaultStorerName, defaultTTL)
case "cloudflare":
return generateCloudflareInstance(config, defaultStorerName)
return generateCloudflareInstance(config, defaultStorerName, defaultTTL)
case "fastly":
return generateFastlyInstance(config, defaultStorerName)
return generateFastlyInstance(config, defaultStorerName, defaultTTL)
default:
return generateSouinInstance(config, defaultStorerName)
return generateSouinInstance(config, defaultStorerName, defaultTTL)
}
}
5 changes: 3 additions & 2 deletions pkg/surrogate/providers/fastly.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package providers
import (
"net/http"
"net/url"
"time"

"github.com/darkweak/souin/configurationtypes"
)
Expand All @@ -15,7 +16,7 @@ type FastlySurrogateStorage struct {
strategy string
}

func generateFastlyInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string) *FastlySurrogateStorage {
func generateFastlyInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string, defaultTTL time.Duration) *FastlySurrogateStorage {
cdn := config.GetDefaultCache().GetCDN()
f := &FastlySurrogateStorage{
baseStorage: &baseStorage{},
Expand All @@ -28,7 +29,7 @@ func generateFastlyInstance(config configurationtypes.AbstractConfigurationInter
f.strategy = "1"
}

f.init(config, defaultStorerName)
f.init(config, defaultStorerName, defaultTTL)
f.parent = f

return f
Expand Down
5 changes: 3 additions & 2 deletions pkg/surrogate/providers/souin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ package providers

import (
"github.com/darkweak/souin/configurationtypes"
"time"
)

// SouinSurrogateStorage is the layer for Surrogate-key support storage
type SouinSurrogateStorage struct {
*baseStorage
}

func generateSouinInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string) *SouinSurrogateStorage {
func generateSouinInstance(config configurationtypes.AbstractConfigurationInterface, defaultStorerName string, defaultTTL time.Duration) *SouinSurrogateStorage {
s := &SouinSurrogateStorage{baseStorage: &baseStorage{}}

s.init(config, defaultStorerName)
s.init(config, defaultStorerName, defaultTTL)
s.parent = s

return s
Expand Down
Loading