diff --git a/config/config.go b/config/config.go index 8a8af2f6e..ea9d51f57 100644 --- a/config/config.go +++ b/config/config.go @@ -61,19 +61,14 @@ type DiceDBConfig struct { Engine string `mapstructure:"engine" default:"ironhawk" description:"the engine to use, values: ironhawk"` - EnableWAL bool `mapstructure:"enable-wal" default:"false" description:"enable write-ahead logging"` - WALDir string `mapstructure:"wal-dir" default:"logs" description:"the directory to store WAL segments"` - WALMode string `mapstructure:"wal-mode" default:"buffered" description:"wal mode to use, values: buffered, unbuffered"` - WALWriteMode string `mapstructure:"wal-write-mode" default:"default" description:"wal file write mode to use, values: default, fsync"` - WALBufferSizeMB int `mapstructure:"wal-buffer-size-mb" default:"1" description:"the size of the wal write buffer in megabytes"` - WALRotationMode string `mapstructure:"wal-rotation-mode" default:"segment-size" description:"wal rotation mode to use, values: segment-size, time"` - WALMaxSegmentSizeMB int `mapstructure:"wal-max-segment-size-mb" default:"16" description:"the maximum size of a wal segment file in megabytes before rotation"` - WALMaxSegmentRotationTimeSec int `mapstructure:"wal-max-segment-rotation-time-sec" default:"60" description:"the time interval (in seconds) after which wal a segment is rotated"` - WALBufferSyncIntervalMillis int `mapstructure:"wal-buffer-sync-interval-ms" default:"200" description:"the interval (in milliseconds) at which the wal write buffer is synced to disk"` - WALRetentionMode string `mapstructure:"wal-retention-mode" default:"num-segments" description:"the new horizon for wal segment post cleanup. values: num-segments, time, checkpoint"` - WALMaxSegmentCount int `mapstructure:"wal-max-segment-count" default:"10" description:"the maximum number of segments to retain, if the retention mode is 'num-segments'"` - WALMaxSegmentRetentionDurationSec int `mapstructure:"wal-max-segment-retention-duration-sec" default:"600" description:"the maximum duration (in seconds) for wal segments retention"` - WALRecoveryMode string `mapstructure:"wal-recovery-mode" default:"strict" description:"wal recovery mode in case of a corruption, values: strict, truncate, ignore"` + EnableWAL bool `mapstructure:"enable-wal" default:"false" description:"enable write-ahead logging"` + WALVariant string `mapstructure:"wal-variant" default:"forge" description:"wal variant to use, values: forge"` + WALDir string `mapstructure:"wal-dir" default:"logs" description:"the directory to store WAL segments"` + WALBufferSizeMB int `mapstructure:"wal-buffer-size-mb" default:"1" description:"the size of the wal write buffer in megabytes"` + WALRotationMode string `mapstructure:"wal-rotation-mode" default:"time" description:"wal rotation mode to use, values: segment-size, time"` + WALMaxSegmentSizeMB int `mapstructure:"wal-max-segment-size-mb" default:"16" description:"the maximum size of a wal segment file in megabytes before rotation"` + WALSegmentRotationTimeSec int `mapstructure:"wal-max-segment-rotation-time-sec" default:"60" description:"the time interval (in seconds) after which wal a segment is rotated"` + WALBufferSyncIntervalMillis int `mapstructure:"wal-buffer-sync-interval-ms" default:"200" description:"the interval (in milliseconds) at which the wal write buffer is synced to disk"` } func Load(flags *pflag.FlagSet) { diff --git a/internal/server/ironhawk/iothread.go b/internal/server/ironhawk/iothread.go index d3cd98222..f130821e2 100644 --- a/internal/server/ironhawk/iothread.go +++ b/internal/server/ironhawk/iothread.go @@ -60,7 +60,7 @@ func (t *IOThread) Start(ctx context.Context, shardManager *shardmanager.ShardMa select { case <-ctx.Done(): - slog.Debug("io-thread context cancelled, shutting down receive loop") + slog.Debug("io-thread context canceled, shutting down receive loop") return ctx.Err() case err := <-errCh: return err @@ -95,9 +95,8 @@ func (t *IOThread) Start(ctx context.Context, shardManager *shardmanager.ShardMa } // Log command to WAL if enabled and not a replay - if err == nil && wal.GetWAL() != nil && !_c.IsReplay { - // Create WAL entry using protobuf message - if err := wal.GetWAL().LogCommand(_c.C); err != nil { + if wal.DefaultWAL != nil && !_c.IsReplay { + if err := wal.DefaultWAL.LogCommand(_c.C); err != nil { slog.Error("failed to log command to WAL", slog.Any("error", err)) } } diff --git a/internal/server/ironhawk/main.go b/internal/server/ironhawk/main.go index 4f2ae9b7c..a2df3ef6f 100644 --- a/internal/server/ironhawk/main.go +++ b/internal/server/ironhawk/main.go @@ -15,7 +15,6 @@ import ( "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/shardmanager" - "github.com/dicedb/dice/internal/wal" ) type Server struct { @@ -28,7 +27,7 @@ type Server struct { ioThreadManager *IOThreadManager } -func NewServer(shardManager *shardmanager.ShardManager, ioThreadManager *IOThreadManager, watchManager *WatchManager, wl wal.WAL) *Server { +func NewServer(shardManager *shardmanager.ShardManager, ioThreadManager *IOThreadManager, watchManager *WatchManager) *Server { return &Server{ Host: config.Config.Host, Port: config.Config.Port, diff --git a/internal/wal/wal.go b/internal/wal/wal.go index 16252e771..6b26ab702 100644 --- a/internal/wal/wal.go +++ b/internal/wal/wal.go @@ -5,76 +5,51 @@ package wal import ( "log/slog" - "sync" - "time" - w "github.com/dicedb/dicedb-go/wal" + "github.com/dicedb/dice/config" "github.com/dicedb/dicedb-go/wire" ) type WAL interface { - Init(t time.Time) error + // Init initializes the WAL. + // The WAL implementation should start all the background jobs and initialize the WAL. + Init() error + // Stop stops the WAL. + // The WAL implementation should stop all the background jobs and close the WAL. + Stop() + // LogCommand logs a command to the WAL. LogCommand(c *wire.Command) error - Close() error - Replay(c func(*w.Element) error) error - Iterate(e *w.Element, c func(*w.Element) error) error + // Replay replays the command from the WAL. + ReplayCommand(cb func(c *wire.Command) error) error } +var DefaultWAL WAL var ( - ticker *time.Ticker stopCh chan struct{} - mu sync.Mutex - wl WAL ) -// GetWAL returns the global WAL instance -func GetWAL() WAL { - mu.Lock() - defer mu.Unlock() - return wl -} - -// SetGlobalWAL sets the global WAL instance -func SetWAL(_wl WAL) { - mu.Lock() - defer mu.Unlock() - wl = _wl -} - func init() { - ticker = time.NewTicker(10 * time.Second) stopCh = make(chan struct{}) } -func rotateWAL(wl WAL) { - mu.Lock() - defer mu.Unlock() - - if err := wl.Close(); err != nil { - slog.Warn("error closing the WAL", slog.Any("error", err)) - } - - if err := wl.Init(time.Now()); err != nil { - slog.Warn("error creating a new WAL", slog.Any("error", err)) - } +// TeardownWAL stops the WAL and closes the WAL instance. +func TeardownWAL() { + close(stopCh) } -func periodicRotate(wl WAL) { - for { - select { - case <-ticker.C: - rotateWAL(wl) - case <-stopCh: - return - } +// SetupWAL initializes the WAL based on the configuration. +// It creates a new WAL instance based on the WAL variant and initializes it. +// If the initialization fails, it panics. +func SetupWAL() { + switch config.Config.WALVariant { + case "forge": + DefaultWAL = newWalForge() + default: + return } -} -func InitBG(wl WAL) { - go periodicRotate(wl) -} - -func ShutdownBG() { - close(stopCh) - ticker.Stop() + if err := DefaultWAL.Init(); err != nil { + slog.Error("could not initialize WAL", slog.Any("error", err)) + panic(err) + } } diff --git a/internal/wal/wal_aof.go b/internal/wal/wal_aof.go deleted file mode 100644 index 62203436a..000000000 --- a/internal/wal/wal_aof.go +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright (c) 2022-present, DiceDB contributors -// All rights reserved. Licensed under the BSD 3-Clause License. See LICENSE file in the project root for full license information. - -package wal - -import ( - "bufio" - "context" - "encoding/binary" - "fmt" - "hash/crc32" - "io" - "log" - "log/slog" - "os" - "path/filepath" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/dicedb/dice/config" - "google.golang.org/protobuf/proto" - - w "github.com/dicedb/dicedb-go/wal" - "github.com/dicedb/dicedb-go/wire" -) - -const ( - segmentPrefix = "seg-" - segmentSuffix = ".wal" - RotationModeTime = "time" - RetentionModeTime = "time" - WALModeUnbuffered = "unbuffered" -) - -var bb []byte - -func init() { - // TODO: Pre-allocate a buffer to avoid re-allocating it - // This will hold one WAL AOF Entry Before it is written to the buffer - bb = make([]byte, 10*1024) -} - -type WALAOFEntry struct { - Len uint32 - Crc32 uint32 - Payload []byte -} - -type WALAOF struct { - logDir string - currentSegmentFile *os.File - walMode string - writeMode string - maxSegmentSize uint32 - maxSegmentCount int - currentSegmentIndex int - currentSegmentSize uint32 - oldestSegmentIndex int - bufferSize int - retentionMode string - recoveryMode string - rotationMode string - lastSequenceNo uint64 - bufWriter *bufio.Writer - bufferSyncTicker *time.Ticker - segmentRotationTicker *time.Ticker - segmentRetentionTicker *time.Ticker - mu sync.Mutex - ctx context.Context - cancel context.CancelFunc -} - -func NewAOFWAL(directory string) (*WALAOF, error) { - ctx, cancel := context.WithCancel(context.Background()) - return &WALAOF{ - logDir: directory, - walMode: config.Config.WALMode, - bufferSyncTicker: time.NewTicker(time.Duration(config.Config.WALBufferSyncIntervalMillis) * time.Millisecond), - segmentRotationTicker: time.NewTicker(time.Duration(config.Config.WALMaxSegmentRotationTimeSec) * time.Second), - segmentRetentionTicker: time.NewTicker(time.Duration(config.Config.WALMaxSegmentRetentionDurationSec) * time.Second), - writeMode: config.Config.WALWriteMode, - maxSegmentSize: uint32(config.Config.WALMaxSegmentSizeMB) * 1024 * 1024, - maxSegmentCount: config.Config.WALMaxSegmentCount, - bufferSize: config.Config.WALBufferSizeMB * 1024 * 1024, - retentionMode: config.Config.WALRetentionMode, - recoveryMode: config.Config.WALRecoveryMode, - rotationMode: config.Config.WALRotationMode, - ctx: ctx, - cancel: cancel, - }, nil -} - -func (wl *WALAOF) Init(t time.Time) error { - // TODO - Restore existing checkpoints to memory - - // Create the directory if it doesn't exist - if err := os.MkdirAll(wl.logDir, 0755); err != nil { - return err - } - - // Get the list of log segment files in the directory - files, err := filepath.Glob(filepath.Join(wl.logDir, segmentPrefix+"*"+segmentSuffix)) - if err != nil { - return err - } - - if len(files) > 0 { - slog.Debug("Found existing log segments", slog.Any("total_files", len(files))) - // TODO - Check if we have newer WAL entries after the last checkpoint and simultaneously replay and checkpoint them - } - - sf, err := os.OpenFile( - filepath.Join(wl.logDir, segmentPrefix+"0"+segmentSuffix), - os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return err - } - - wl.currentSegmentFile = sf - wl.bufWriter = bufio.NewWriterSize(wl.currentSegmentFile, wl.bufferSize) - - go wl.keepSyncingBuffer() - switch wl.rotationMode { - case RotationModeTime: - go wl.rotateSegmentPeriodically() - go wl.deleteSegmentPeriodically() - default: - return nil - } - return nil -} - -// Log writes a command to the WAL with a monotonically increasing sequence number. -// The sequence number is assigned atomically and the command is written to the wl. -func (wl *WALAOF) LogCommand(c *wire.Command) error { - // Lock once for the entire sequence number operation - wl.mu.Lock() - defer wl.mu.Unlock() - - b, err := proto.Marshal(c) - if err != nil { - return err - } - - wl.lastSequenceNo += 1 - el := &w.Element{ - Lsn: wl.lastSequenceNo, - Timestamp: time.Now().UnixNano(), - ElementType: w.ElementType_ELEMENT_TYPE_COMMAND, - Payload: b, - } - - b, err = proto.Marshal(el) - if err != nil { - return err - } - - entrySize := uint32(4 + 4 + len(b)) - if err := wl.rotateLogIfNeeded(entrySize); err != nil { - return err - } - - // If the entry size is greater than the buffer size, we need to - // create a new buffer. - if entrySize > uint32(cap(bb)) { - // TODO: In this case, we can do a one time creation - // of a new buffer and proceed rather than using the - // existing buffer. - panic(fmt.Errorf("buffer too small, %d > %d", entrySize, len(bb))) - } - - bb = bb[:8+len(b)] - // Calculate CRC32 only on the payload - chk := crc32.ChecksumIEEE(b) - - // Write header and payload - binary.LittleEndian.PutUint32(bb[0:4], chk) - binary.LittleEndian.PutUint32(bb[4:8], uint32(len(b))) - copy(bb[8:], b) - - _, _ = wl.bufWriter.Write(bb) - - wl.currentSegmentSize += entrySize - - // if wal-mode unbuffered immediately sync to disk - if wl.walMode == WALModeUnbuffered { - if err := wl.Sync(); err != nil { - return err - } - } - - return nil -} - -// rotateLogIfNeeded is not thread safe -func (wl *WALAOF) rotateLogIfNeeded(entrySize uint32) error { - if wl.currentSegmentSize+entrySize > wl.maxSegmentSize { - if err := wl.rotateLog(); err != nil { - return err - } - } - return nil -} - -// rotateLog is not thread safe -func (wl *WALAOF) rotateLog() error { - if err := wl.Sync(); err != nil { - return err - } - - if err := wl.currentSegmentFile.Close(); err != nil { - return err - } - - wl.currentSegmentIndex++ - if wl.currentSegmentIndex-wl.oldestSegmentIndex+1 > wl.maxSegmentCount { - if err := wl.deleteOldestSegment(); err != nil { - return err - } - wl.oldestSegmentIndex++ - } - - sf, err := os.OpenFile(filepath.Join(wl.logDir, segmentPrefix+fmt.Sprintf("%d", wl.currentSegmentIndex)+segmentSuffix), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatalf("failed opening file: %s", err) - } - - wl.currentSegmentSize = 0 - wl.currentSegmentFile = sf - wl.bufWriter = bufio.NewWriter(sf) - - return nil -} - -func (wl *WALAOF) deleteOldestSegment() error { - oldestSegmentFilePath := filepath.Join(wl.logDir, segmentPrefix+fmt.Sprintf("%d", wl.oldestSegmentIndex)+segmentSuffix) - - // TODO: checkpoint before deleting the file - if err := os.Remove(oldestSegmentFilePath); err != nil { - return err - } - wl.oldestSegmentIndex++ - return nil -} - -// Close the WAL file. It also calls Sync() on the wl. -func (wl *WALAOF) Close() error { - wl.cancel() - if err := wl.Sync(); err != nil { - return err - } - return wl.currentSegmentFile.Close() -} - -// Writes out any data in the WAL's in-memory buffer to the segment file. If -// fsync is enabled, it also calls fsync on the segment file. -func (wl *WALAOF) Sync() error { - if err := wl.bufWriter.Flush(); err != nil { - return err - } - if wl.writeMode == "fsync" { - if err := wl.currentSegmentFile.Sync(); err != nil { - return err - } - } - return nil -} - -func (wl *WALAOF) keepSyncingBuffer() { - for { - select { - case <-wl.bufferSyncTicker.C: - wl.mu.Lock() - err := wl.Sync() - wl.mu.Unlock() - - if err != nil { - slog.Error("failed to sync buffer", slog.String("error", err.Error())) - } - - case <-wl.ctx.Done(): - return - } - } -} - -func (wl *WALAOF) rotateSegmentPeriodically() { - for { - select { - case <-wl.segmentRotationTicker.C: - wl.mu.Lock() - err := wl.rotateLog() - wl.mu.Unlock() - if err != nil { - slog.Error("failed to rotate segment", slog.String("error", err.Error())) - } - - case <-wl.ctx.Done(): - return - } - } -} - -func (wl *WALAOF) deleteSegmentPeriodically() { - for { - select { - case <-wl.segmentRetentionTicker.C: - wl.mu.Lock() - err := wl.deleteOldestSegment() - wl.mu.Unlock() - if err != nil { - slog.Error("failed to delete segment", slog.String("error", err.Error())) - } - case <-wl.ctx.Done(): - return - } - } -} - -func (wl *WALAOF) segmentFiles() ([]string, error) { - // Get all segment files matching the pattern - files, err := filepath.Glob(filepath.Join(wl.logDir, segmentPrefix+"*"+segmentSuffix)) - if err != nil { - return nil, err - } - - // Sort files by numeric suffix - sort.Slice(files, func(i, j int) bool { - parseSuffix := func(name string) int64 { - num, _ := strconv.ParseInt( - strings.TrimPrefix(strings.TrimSuffix(filepath.Base(name), segmentSuffix), segmentPrefix), 10, 64) - return num - } - return parseSuffix(files[i]) < parseSuffix(files[j]) - }) - - return files, nil -} - -func (wl *WALAOF) Replay(callback func(*w.Element) error) error { - var crc uint32 - var entrySize uint32 - var el w.Element - bb1h := make([]byte, 8) - bb1ElementBytes := make([]byte, 10*1024) - - // Get list of segment files sorted by timestamp - segments, err := wl.segmentFiles() - if err != nil { - return fmt.Errorf("error getting wal-segment files: %w", err) - } - - // Process each segment file in order - for _, segment := range segments { - file, err := os.Open(segment) - if err != nil { - return fmt.Errorf("error opening wal-segment file %s: %w", segment, err) - } - - reader := bufio.NewReader(file) - // Format: CRC32 (4 bytes) | Size of WAL entry (4 bytes) | WAL data - for { - // Read CRC32 (4 bytes) + entrySize (4 bytes) - if _, err := io.ReadFull(reader, bb1h); err != nil { - if err == io.EOF { - break - } - file.Close() - return fmt.Errorf("error reading CRC32: %w", err) - } - crc = binary.LittleEndian.Uint32(bb1h[0:4]) - entrySize = binary.LittleEndian.Uint32(bb1h[4:8]) - - if _, err := io.ReadFull(reader, bb1ElementBytes[:entrySize]); err != nil { - file.Close() - return fmt.Errorf("error reading WAL data: %w", err) - } - - // Calculate CRC32 only on the payload - expectedCRC := crc32.ChecksumIEEE(bb1ElementBytes[:entrySize]) - if crc != expectedCRC { - file.Close() - return fmt.Errorf("CRC32 mismatch: expected %d, got %d", crc, expectedCRC) - } - - // Unmarshal the WAL entry to get the payload - if err := proto.Unmarshal(bb1ElementBytes[:entrySize], &el); err != nil { - file.Close() - return fmt.Errorf("error unmarshaling WAL entry: %w", err) - } - - // Call provided replay function with parsed command - if err := callback(&el); err != nil { - file.Close() - return fmt.Errorf("error replaying command: %w", err) - } - } - file.Close() - } - - return nil -} - -func (wl *WALAOF) Iterate(e *w.Element, c func(*w.Element) error) error { - return c(e) -} diff --git a/internal/wal/wal_forge.go b/internal/wal/wal_forge.go new file mode 100644 index 000000000..d9a96ceb1 --- /dev/null +++ b/internal/wal/wal_forge.go @@ -0,0 +1,421 @@ +// Copyright (c) 2022-present, DiceDB contributors +// All rights reserved. Licensed under the BSD 3-Clause License. See LICENSE file in the project root for full license information. + +package wal + +import ( + "bufio" + "context" + "encoding/binary" + "fmt" + "hash/crc32" + "io" + "log/slog" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/dicedb/dice/config" + "google.golang.org/protobuf/proto" + + w "github.com/dicedb/dicedb-go/wal" + "github.com/dicedb/dicedb-go/wire" +) + +const ( + segmentPrefix = "seg-" +) + +var bb []byte + +func init() { + // Pre-allocate a buffer to avoid re-allocating it + // This will hold one WAL Forge Entry Before it is written to the buffer + bb = make([]byte, 10*1024) +} + +type walForge struct { + // Current Segment File and its writer + csf *os.File + csWriter *bufio.Writer + csIdx int + csSize uint32 + + // TODO: Come up with a way to generate a LSN that is + // monotonically increasing and even after restart it + // resumes from the last LSN and not start from 0. + lsn uint64 + + maxSegmentSizeBytes uint32 + + bufferSyncTicker *time.Ticker + segmentRotationTicker *time.Ticker + + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc +} + +func newWalForge() *walForge { + ctx, cancel := context.WithCancel(context.Background()) + return &walForge{ + ctx: ctx, + cancel: cancel, + + bufferSyncTicker: time.NewTicker(time.Duration(config.Config.WALBufferSyncIntervalMillis) * time.Millisecond), + segmentRotationTicker: time.NewTicker(time.Duration(config.Config.WALSegmentRotationTimeSec) * time.Second), + + maxSegmentSizeBytes: uint32(config.Config.WALMaxSegmentSizeMB) * 1024 * 1024, + } +} + +func (wl *walForge) Init() error { + // TODO: Once the checkpoint is implemented + // Load the initial state of the database from this checkpoint + // and then reply the WAL files that happened after this checkpoint. + + // Make sure the WAL directory exists + if err := os.MkdirAll(config.Config.WALDir, 0755); err != nil { + return err + } + + // Get the list of log segment files in the WAL directory + sfs, err := wl.segments() + if err != nil { + return err + } + slog.Debug("Loading WAL segments", slog.Any("total_segments", len(sfs))) + + // TODO: Do not assume that the first segment is always 0 + // Pick the one with the least value of the segment index + // Maintain a metadatafile that holds the latest segment index used + // and during rotation, it increments the segment index and uses it + sf, err := os.OpenFile( + filepath.Join(config.Config.WALDir, segmentPrefix+"0"+".wal"), + os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + + wl.csf = sf + wl.csWriter = bufio.NewWriterSize(wl.csf, config.Config.WALBufferSizeMB*1024*1024) + + go wl.periodicSyncBuffer() + + switch config.Config.WALRotationMode { + case "time": + go wl.periodicRotateSegment() + default: + return nil + } + return nil +} + +// LogCommand writes a command to the WAL with a monotonically increasing LSN. +func (wl *walForge) LogCommand(c *wire.Command) error { + // Lock once for the entire LSN operation + wl.mu.Lock() + defer wl.mu.Unlock() + + // marshal the command to bytes + b, err := proto.Marshal(c) + if err != nil { + return err + } + + // TODO: This logic changes as we change the LSN format + wl.lsn += 1 + el := &w.Element{ + Lsn: wl.lsn, + Timestamp: time.Now().UnixNano(), + ElementType: w.ElementType_ELEMENT_TYPE_COMMAND, + Payload: b, + } + + // marshal the WAL Element to bytes + b, err = proto.Marshal(el) + if err != nil { + return err + } + + // Wrap the element with Checksum and Size + // and keep it ready to be written to the segment file through the buffer + // We call this WAL Entry. + entrySize := uint32(4 + 4 + len(b)) + if err := wl.rotateLogIfNeeded(entrySize); err != nil { + return err + } + + // If the entry size is greater than the buffer size, we need to + // create a new buffer. + if entrySize > uint32(cap(bb)) { + // TODO: In this case, we can do a one time creation of a new buffer + // and proceed rather than using the existing buffer. + panic(fmt.Errorf("buffer too small, %d > %d", entrySize, len(bb))) + } + + bb = bb[:8+len(b)] + chk := crc32.ChecksumIEEE(b) + + // Write header and payload + binary.LittleEndian.PutUint32(bb[0:4], chk) + binary.LittleEndian.PutUint32(bb[4:8], uint32(len(b))) + copy(bb[8:], b) + + // TODO: Check if we need to handle the error here, + // from my initial understanding, we should not be + // handling the error here because it would never happen. + // Have not tested this yet. + _, _ = wl.csWriter.Write(bb) + + wl.csSize += entrySize + return nil +} + +// rotateLogIfNeeded checks if the current segment size + the entry size is +// greater than the max segment size, and if so, it rotates the log. +// This method is not thread safe and hence should be called with the lock held. +func (wl *walForge) rotateLogIfNeeded(entrySize uint32) error { + // If the current segment size + the entry size is greater than the max segment size, + // we need to rotate the log. + if wl.csSize+entrySize > wl.maxSegmentSizeBytes { + if err := wl.rotateLog(); err != nil { + return err + } + } + return nil +} + +// rotateLog rotates the log by closing the current segment file, +// incrementing the current segment index, and opening a new segment file. +// This method is thread safe. +func (wl *walForge) rotateLog() error { + fmt.Println("rotating log") + wl.mu.Lock() + defer wl.mu.Unlock() + + // TODO: Ideally this function should not return any error + // Check for the conditions where it can return an error + // and handle them gracefully. + // I fear that we will need to do some cleanup operations in case of errors. + + // Sync the current segment file to disk + if err := wl.sync(); err != nil { + return err + } + + // Close the current segment file + if err := wl.csf.Close(); err != nil { + return err + } + + // Increment the current segment index + wl.csIdx++ + + // Open a new segment file + sf, err := os.OpenFile( + filepath.Join(config.Config.WALDir, fmt.Sprintf("%s%d.wal", segmentPrefix, wl.csIdx)), + os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + // TODO: We are panicking here because we are not handling the error + // and we want to make sure that the WAL is not corrupted. + // We need to handle this error gracefully. + panic(fmt.Errorf("failed opening file: %w", err)) + } + + // Reset the trackers + wl.csf = sf + wl.csSize = 0 + wl.csWriter = bufio.NewWriter(sf) + + return nil +} + +// Writes out any data in the WAL's in-memory buffer to the segment file. +// and syncs the segment file to disk. +// This method is thread safe. +func (wl *walForge) sync() error { + wl.mu.Lock() + defer wl.mu.Unlock() + + // Flush the buffer to the segment file + if err := wl.csWriter.Flush(); err != nil { + return err + } + + // Sync the segment file to disk to make sure + // it is written to disk. + if err := wl.csf.Sync(); err != nil { + return err + } + + // TODO: Evaluate if DIRECT_IO is needed here. + // If we are using a file system that supports direct IO, + // we can use it to improve the performance. + // If we are using a file system that does not support direct IO, + // we can use the buffer to improve the performance. + return nil +} + +func (wl *walForge) periodicSyncBuffer() { + for { + select { + case <-wl.bufferSyncTicker.C: + err := wl.sync() + if err != nil { + slog.Error("failed to sync buffer", slog.String("error", err.Error())) + } + case <-wl.ctx.Done(): + return + } + } +} + +func (wl *walForge) periodicRotateSegment() { + fmt.Println("rotating segment") + for { + select { + case <-wl.segmentRotationTicker.C: + // TODO: Remove this error handling once we clean up the error handling in the rotateLog function. + if err := wl.rotateLog(); err != nil { + slog.Error("failed to rotate segment", slog.String("error", err.Error())) + } + case <-wl.ctx.Done(): + return + } + } +} + +func (wl *walForge) segments() ([]string, error) { + // Get all segment files matching the pattern + files, err := filepath.Glob(filepath.Join(config.Config.WALDir, segmentPrefix+"*"+".wal")) + if err != nil { + return nil, err + } + + sort.Slice(files, func(i, j int) bool { + s1, _ := strconv.Atoi(strings.Split(strings.TrimPrefix(files[i], segmentPrefix), ".")[0]) + s2, _ := strconv.Atoi(strings.Split(strings.TrimPrefix(files[i], segmentPrefix), ".")[0]) + return s1 < s2 + }) + + // TODO: Check that the segment files are returned in the correct order + // The order has to be in ascending order of the segment index. + return files, nil +} + +// ReplayCommand replays the commands from the WAL files. +// This method is thread safe. +func (wl *walForge) ReplayCommand(cb func(*wire.Command) error) error { + var crc, entrySize uint32 + var el w.Element + + // Buffers to hold the header and the element bytes + bb1h := make([]byte, 8) + bb1ElementBytes := make([]byte, 10*1024) + + // Get list of segment files ordered by timestamp in ascending order + segments, err := wl.segments() + if err != nil { + return fmt.Errorf("error getting wal-segment files: %w", err) + } + + // Process each segment file in order + for _, segment := range segments { + file, err := os.Open(segment) + if err != nil { + return fmt.Errorf("error opening wal-segment file %s: %w", segment, err) + } + + reader := bufio.NewReader(file) + // Format: CRC32 (4 bytes) | Size of WAL entry (4 bytes) | WAL data + + // TODO: Replace this infinite loop with a more elegant solution + for { + // Read CRC32 (4 bytes) + entrySize (4 bytes) + if _, err := io.ReadFull(reader, bb1h); err != nil { + // TODO: this terminating connection should be handled in a better way + // and the loop should not be infinite. + // Edge case: this EOF error can happen even in the next step when + // we are reading the WAL element from the file. + if err == io.EOF { + break + } + file.Close() + return fmt.Errorf("error reading WAL: %w", err) + } + crc = binary.LittleEndian.Uint32(bb1h[0:4]) + entrySize = binary.LittleEndian.Uint32(bb1h[4:8]) + + if _, err := io.ReadFull(reader, bb1ElementBytes[:entrySize]); err != nil { + file.Close() + return fmt.Errorf("error reading WAL data: %w", err) + } + + // Calculate CRC32 only on the payload + expectedCRC := crc32.ChecksumIEEE(bb1ElementBytes[:entrySize]) + if crc != expectedCRC { + // TODO: We are reprtitively closing the file here + // A better solution would be to move this logic to a function + // and use defer to close the file. + // The function. thus, in a way processes (replays) one segment at a time. + file.Close() + + // TODO: THis is where we should trigger the WAL recovery + // Recovery process is all about truncating the segment file + // till this point and ignoring the rest. + // Log appropriate messages when this happens. + // Evaluate if this recovery mode should be a command line flag + // that would suggest if we should truncate, ignore, or stop the boot process. + return fmt.Errorf("CRC32 mismatch: expected %d, got %d", crc, expectedCRC) + } + + // Unmarshal the WAL entry to get the payload + if err := proto.Unmarshal(bb1ElementBytes[:entrySize], &el); err != nil { + file.Close() + return fmt.Errorf("error unmarshaling WAL entry: %w", err) + } + + var c wire.Command + if err := proto.Unmarshal(el.Payload, &c); err != nil { + file.Close() + return fmt.Errorf("error unmarshaling command: %w", err) + } + + // Call provided replay function with parsed command + if err := cb(&c); err != nil { + file.Close() + return fmt.Errorf("error replaying command: %w", err) + } + } + } + + return nil +} + +// Stop stops the WAL Forge. +// This method is thread safe. +func (wl *walForge) Stop() { + wl.mu.Lock() + defer wl.mu.Unlock() + + // Stop the tickers + wl.bufferSyncTicker.Stop() + wl.segmentRotationTicker.Stop() + + // Cancel the context + wl.cancel() + + // Sync the current segment file to disk + if err := wl.sync(); err != nil { + slog.Error("failed to sync current segment file", slog.String("error", err.Error())) + } + + wl.csf.Close() + + // TODO: See if we are missing any other cleanup operations. +} diff --git a/internal/wal/wal_null.go b/internal/wal/wal_null.go deleted file mode 100644 index 46e030810..000000000 --- a/internal/wal/wal_null.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2022-present, DiceDB contributors -// All rights reserved. Licensed under the BSD 3-Clause License. See LICENSE file in the project root for full license information. - -package wal - -import ( - "time" - - w "github.com/dicedb/dicedb-go/wal" - "github.com/dicedb/dicedb-go/wire" -) - -type WALNull struct { -} - -func NewNullWAL() (*WALNull, error) { - return &WALNull{}, nil -} - -func (w *WALNull) Init(t time.Time) error { - return nil -} - -func (w *WALNull) LogCommand(c *wire.Command) error { - return nil -} - -func (w *WALNull) Close() error { - return nil -} - -func (w *WALNull) Replay(callback func(*w.Element) error) error { - return nil -} - -func (w *WALNull) Iterate(entry *w.Element, callback func(*w.Element) error) error { - return nil -} diff --git a/server/main.go b/server/main.go index a537b6276..35b798aae 100644 --- a/server/main.go +++ b/server/main.go @@ -16,15 +16,12 @@ import ( "runtime/trace" "sync" "syscall" - "time" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/cmd" "github.com/dicedb/dice/internal/server/ironhawk" "github.com/dicedb/dice/internal/shardmanager" - w "github.com/dicedb/dicedb-go/wal" "github.com/dicedb/dicedb-go/wire" - "google.golang.org/protobuf/proto" "github.com/dicedb/dice/internal/wal" @@ -82,33 +79,11 @@ func Start() { signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) var ( - serverErrCh = make(chan error, 2) - wl wal.WAL - walInitSuccessful = false + serverErrCh = make(chan error, 2) ) if config.Config.EnableWAL { - _wl, err := wal.NewAOFWAL(config.Config.WALDir) - if err != nil { - slog.Warn("could not create WAL at", slog.String("wal-dir", config.Config.WALDir), slog.Any("error", err)) - sigs <- syscall.SIGKILL - cancel() - return - } - wl = _wl - wal.SetWAL(wl) // Set the global WAL instance - - if err := wl.Init(time.Now()); err != nil { - slog.Warn("could not initialize WAL", slog.Any("error", err)) - slog.Warn("disabling WAL and continuing") - // TODO: Make sure that the WAL is disabled - // We should not incurring any additional cost of making LogCommand - // invocations. - } else { - go wal.InitBG(wl) - slog.Debug("WAL initialization complete") - walInitSuccessful = true - } + wal.SetupWAL() } // Get the number of available CPU cores on the machine using runtime.NumCPU(). @@ -150,21 +125,14 @@ func Start() { } ioThreadManager := ironhawk.NewIOThreadManager() - ironhawkServer := ironhawk.NewServer(shardManager, ioThreadManager, watchManager, wl) + ironhawkServer := ironhawk.NewServer(shardManager, ioThreadManager, watchManager) - serverWg.Add(1) - go runServer(ctx, &serverWg, ironhawkServer, serverErrCh) - - // Recovery from WAL logs - if config.Config.EnableWAL && walInitSuccessful { + // Restore the database from WAL logs + if config.Config.EnableWAL { slog.Info("restoring database from WAL") - callback := func(el *w.Element) error { - var cd wire.Command - if err := proto.Unmarshal(el.Payload, &cd); err != nil { - return fmt.Errorf("failed to unmarshal command: %w", err) - } + callback := func(cd *wire.Command) error { cmdTemp := cmd.Cmd{ - C: &cd, + C: cd, IsReplay: true, } _, err := cmdTemp.Execute(shardManager) @@ -173,12 +141,16 @@ func Start() { } return nil } - if err := wl.Replay(callback); err != nil { + if err := wal.DefaultWAL.ReplayCommand(callback); err != nil { slog.Error("error restoring from WAL", slog.Any("error", err)) } slog.Info("database restored from WAL") } + slog.Info("ready to accept connections") + serverWg.Add(1) + go runServer(ctx, &serverWg, ironhawkServer, serverErrCh) + wg.Add(1) go func() { defer wg.Done() @@ -202,11 +174,10 @@ func Start() { close(sigs) if config.Config.EnableWAL { - wal.ShutdownBG() + wal.TeardownWAL() } cancel() - wg.Wait() } diff --git a/tests/commands/ironhawk/setup.go b/tests/commands/ironhawk/setup.go index f0b9786b5..e1ee0a29b 100644 --- a/tests/commands/ironhawk/setup.go +++ b/tests/commands/ironhawk/setup.go @@ -112,9 +112,9 @@ func RunTestServer(wg *sync.WaitGroup) { shardManager := shardmanager.NewShardManager(1, gec) ioThreadManager := ironhawk.NewIOThreadManager() watchManager := &ironhawk.WatchManager{} - wl, _ := wal.NewAOFWAL(config.Config.WALDir) + wal.SetupWAL() - testServer := ironhawk.NewServer(shardManager, ioThreadManager, watchManager, wl) + testServer := ironhawk.NewServer(shardManager, ioThreadManager, watchManager) ctx, cancel := context.WithCancel(context.Background()) fmt.Println("Starting the test server on port", config.Config.Port)