diff --git a/sei-db/ledger_db/parquet/wal.go b/sei-db/ledger_db/parquet/wal.go index d0a4ccbd07..6c7fb299e8 100644 --- a/sei-db/ledger_db/parquet/wal.go +++ b/sei-db/ledger_db/parquet/wal.go @@ -1,6 +1,7 @@ package parquet import ( + "context" "encoding/binary" "fmt" "os" @@ -109,6 +110,7 @@ func NewWAL(logger dbLogger.Logger, dir string) (dbwal.GenericWAL[WALEntry], err return nil, err } return dbwal.NewWAL( + context.Background(), encodeWALEntry, decodeWALEntry, logger, diff --git a/sei-db/wal/changelog.go b/sei-db/wal/changelog.go index b9a44e6e58..b3ea68a3e9 100644 --- a/sei-db/wal/changelog.go +++ b/sei-db/wal/changelog.go @@ -1,6 +1,8 @@ package wal import ( + "context" + "github.com/sei-protocol/sei-chain/sei-db/common/logger" "github.com/sei-protocol/sei-chain/sei-db/proto" ) @@ -12,6 +14,7 @@ type ChangelogWAL = GenericWAL[proto.ChangelogEntry] // This is a convenience wrapper that handles serialization automatically. func NewChangelogWAL(logger logger.Logger, dir string, config Config) (ChangelogWAL, error) { return NewWAL( + context.Background(), func(e proto.ChangelogEntry) ([]byte, error) { return e.Marshal() }, func(data []byte) (proto.ChangelogEntry, error) { var e proto.ChangelogEntry diff --git a/sei-db/wal/wal.go b/sei-db/wal/wal.go index a1cd79ff84..0a8cb48dd8 100644 --- a/sei-db/wal/wal.go +++ b/sei-db/wal/wal.go @@ -1,11 +1,11 @@ package wal import ( + "context" "errors" "fmt" "os" "path/filepath" - "sync" "time" "github.com/tidwall/wal" @@ -13,26 +13,77 @@ import ( "github.com/sei-protocol/sei-chain/sei-db/common/logger" ) +// The size of internal channel buffers if the provided buffer size is less than 1. +const defaultBufferSize = 1024 + +// The size of write batches if the provided write batch size is less than 1. +const defaultWriteBatchSize = 64 + // WAL is a generic write-ahead log implementation. type WAL[T any] struct { - dir string - log *wal.Log - config Config - logger logger.Logger - marshal MarshalFn[T] - unmarshal UnmarshalFn[T] - writeChannel chan T - mtx sync.RWMutex // guards WAL state: lazy init/close of writeChannel, isClosed checks - asyncWriteErrCh chan error // buffered=1; async writer reports first error non-blocking - isClosed bool - closeCh chan struct{} // signals shutdown to background goroutines - wg sync.WaitGroup // tracks background goroutines (pruning) + ctx context.Context + cancel context.CancelFunc + + dir string + log *wal.Log + config Config + logger logger.Logger + marshal MarshalFn[T] + unmarshal UnmarshalFn[T] + + // The size of write batches. + writeBatchSize int + asyncWrites bool + + writeChan chan *writeRequest[T] + truncateChan chan *truncateRequest + closeReqChan chan struct{} + closeErrChan chan error } +// A request to truncate the log. +type truncateRequest struct { + // If true, truncate before the provided index. Otherwise, truncate after the provided index. + before bool + // The index to truncate at. + index uint64 + // Errors are returned over this channel, nil is written if completed with no error + errChan chan error +} + +// A request to write to the WAL. +type writeRequest[T any] struct { + // The data to write + entry T + // Errors are returned over this channel, nil is written if completed with no error + errChan chan error +} + +// Configuration for the WAL. type Config struct { + // The number of recent entries to keep in the log. + KeepRecent uint64 + + // The interval at which to prune the log. + PruneInterval time.Duration + + // The size of internal buffers. Also controls whether or not the Write method is asynchronous. + // + // If BufferSize is greater than 0, then the Write method is asynchronous, and the size of internal + // buffers is set to the provided value. If Buffer size is less than 1, then the Write method is synchronous, + // and any internal buffers are set to a default size. WriteBufferSize int - KeepRecent uint64 - PruneInterval time.Duration + + // The size of write batches. If less than or equal to 0, a default of 64 is used. + // If 1, no batching is done. + WriteBatchSize int + + // If true, do an fsync after each write. + FsyncEnabled bool + + // If true, make a deep copy of the data for every write. If false, then it is not safe to modify the data after + // reading/writing it. + DeepCopyEnabled bool } // NewWAL creates a new generic write-ahead log that persists entries. @@ -49,6 +100,7 @@ type Config struct { // logger, dir, config, // ) func NewWAL[T any]( + ctx context.Context, marshal MarshalFn[T], unmarshal UnmarshalFn[T], logger logger.Logger, @@ -56,27 +108,46 @@ func NewWAL[T any]( config Config, ) (*WAL[T], error) { log, err := open(dir, &wal.Options{ - NoSync: true, - NoCopy: true, + NoSync: !config.FsyncEnabled, + NoCopy: !config.DeepCopyEnabled, }) if err != nil { return nil, err } - w := &WAL[T]{ - dir: dir, - log: log, - config: config, - logger: logger, - marshal: marshal, - unmarshal: unmarshal, - closeCh: make(chan struct{}), - asyncWriteErrCh: make(chan error, 1), + + bufferSize := config.WriteBufferSize + if config.WriteBufferSize <= 0 { + bufferSize = defaultBufferSize } - // Start the auto pruning goroutine - if config.KeepRecent > 0 && config.PruneInterval > 0 { - w.startPruning(config.KeepRecent, config.PruneInterval) + asyncWrites := config.WriteBufferSize > 0 + + writeBatchSize := config.WriteBatchSize + if writeBatchSize <= 0 { + writeBatchSize = defaultWriteBatchSize + } + + ctx, cancel := context.WithCancel(ctx) + + w := &WAL[T]{ + ctx: ctx, + cancel: cancel, + dir: dir, + log: log, + config: config, + logger: logger, + marshal: marshal, + unmarshal: unmarshal, + writeBatchSize: writeBatchSize, + asyncWrites: asyncWrites, + closeReqChan: make(chan struct{}), + closeErrChan: make(chan error, 1), + writeChan: make(chan *writeRequest[T], bufferSize), + truncateChan: make(chan *truncateRequest, bufferSize), } + + go w.mainLoop() + return w, nil } @@ -85,99 +156,215 @@ func NewWAL[T any]( // Whether the writes is in blocking or async manner depends on the buffer size. // For async writes, this also checks for any previous async write errors. func (walLog *WAL[T]) Write(entry T) error { - // Never hold walLog.mtx while doing a potentially-blocking send. Close() may run concurrently. - walLog.mtx.Lock() - defer walLog.mtx.Unlock() - if walLog.isClosed { - return errors.New("wal is closed") - } - if err := walLog.getAsyncWriteErrLocked(); err != nil { - return fmt.Errorf("async WAL write failed previously: %w", err) - } - writeBufferSize := walLog.config.WriteBufferSize - if writeBufferSize > 0 { - if walLog.writeChannel == nil { - walLog.writeChannel = make(chan T, writeBufferSize) - walLog.startAsyncWriteGoroutine() - walLog.logger.Info(fmt.Sprintf("WAL async write is enabled with buffer size %d", writeBufferSize)) - } - walLog.writeChannel <- entry + + errChan := make(chan error, 1) + req := &writeRequest[T]{ + entry: entry, + errChan: errChan, + } + + err := interuptablePush(walLog.ctx, walLog.writeChan, req) + if err != nil { + return fmt.Errorf("failed to push write request: %w", err) + } + + if walLog.asyncWrites { + // Do not wait for the write to be durable + return nil + } + + err, pullErr := interuptablePull(walLog.ctx, errChan) + if pullErr != nil { + return fmt.Errorf("failed to pull write error: %w", pullErr) + } + if err != nil { + return fmt.Errorf("failed to write data: %w", err) + } + + return nil +} + +// This method is called asynchronously in response to a call to Write. +func (walLog *WAL[T]) handleWrite(req *writeRequest[T]) { + if walLog.writeBatchSize <= 1 { + walLog.handleUnbatchedWrite(req) } else { - // synchronous write - bz, err := walLog.marshal(entry) - if err != nil { - return err + walLog.handleBatchedWrite(req) + } +} + +// handleUnbatchedWrite is called when no batching is enabled. Processes a single write request. +func (walLog *WAL[T]) handleUnbatchedWrite(req *writeRequest[T]) { + + bz, err := walLog.marshal(req.entry) + if err != nil { + req.errChan <- fmt.Errorf("marshalling error: %w", err) + return + } + lastOffset, err := walLog.log.LastIndex() + if err != nil { + req.errChan <- fmt.Errorf("error fetching last index: %w", err) + return + } + if err := walLog.log.Write(lastOffset+1, bz); err != nil { + req.errChan <- fmt.Errorf("failed to write: %w", err) + return + } + + req.errChan <- nil +} + +// handleBatchedWrite is called when batching is enabled. This method may pop pending writes from the writeChan and +// include them in the batch. +func (walLog *WAL[T]) handleBatchedWrite(req *writeRequest[T]) { + + requests := walLog.gatherRequestsForBatch(req) + + lastOffset, err := walLog.log.LastIndex() + if err != nil { + err = fmt.Errorf("error fetching last index: %w", err) + for _, req := range requests { + req.errChan <- err } - lastOffset, err := walLog.log.LastIndex() - if err != nil { - return err + return + } + + binaryRequests := walLog.marshalRequests(requests) + + batch := &wal.Batch{} + for _, binaryRequest := range binaryRequests { + batch.Write(lastOffset+1, binaryRequest) + lastOffset++ + } + + if err := walLog.log.WriteBatch(batch); err != nil { + err = fmt.Errorf("failed to write batch: %w", err) + for _, r := range requests { + if r.errChan != nil { + r.errChan <- err + } } - if err := walLog.log.Write(lastOffset+1, bz); err != nil { - return err + return + } + + for _, r := range requests { + if r.errChan != nil { + r.errChan <- nil } } - return nil } -// startWriteGoroutine will start a goroutine to write entries to the log. -// This should only be called on initialization if async write is enabled -func (walLog *WAL[T]) startAsyncWriteGoroutine() { - walLog.wg.Add(1) - ch := walLog.writeChannel - go func() { - defer walLog.wg.Done() - for entry := range ch { - bz, err := walLog.marshal(entry) - if err != nil { - walLog.recordAsyncWriteErr(err) - return - } - nextOffset, err := walLog.NextOffset() - if err != nil { - walLog.recordAsyncWriteErr(err) - return - } - err = walLog.log.Write(nextOffset, bz) - if err != nil { - walLog.recordAsyncWriteErr(err) - return - } +// Gather the requests for a batch. When this method is called, we will already have the first request in the batch. +func (walLog *WAL[T]) gatherRequestsForBatch(initialRequest *writeRequest[T]) []*writeRequest[T] { + requests := make([]*writeRequest[T], 0) + requests = append(requests, initialRequest) + + keepLooking := true + for keepLooking && len(requests) < walLog.writeBatchSize { + select { + case next := <-walLog.writeChan: + requests = append(requests, next) + default: + // No more pending writes immediately available, so process the batch we have so far. + keepLooking = false + } + } + return requests +} + +// Marshal the requests for a batch. If a request can't be marshalled, an error is immediately sent +// to that request's caller. +// +// The requests slice passed into this method is modified if some requests +// are not marshalled successfully. Any request that is not marshalled successfully has its errChan +// set to nil to avoid sending more than one response to the caller. +func (walLog *WAL[T]) marshalRequests(requests []*writeRequest[T]) [][]byte { + binaryRequests := make([][]byte, 0, len(requests)) + + for _, req := range requests { + bz, err := walLog.marshal(req.entry) + if err != nil { + err = fmt.Errorf("marshalling error: %w", err) + req.errChan <- err + req.errChan = nil // signal that we have already sent a response to the caller + continue } - }() + binaryRequests = append(binaryRequests, bz) + } + + return binaryRequests } // TruncateAfter will remove all entries that are after the provided `index`. // In other words the entry at `index` becomes the last entry in the log. func (walLog *WAL[T]) TruncateAfter(index uint64) error { - return walLog.log.TruncateBack(index) + return walLog.sendTruncate(false, index) } // TruncateBefore will remove all entries that are before the provided `index`. // In other words the entry at `index` becomes the first entry in the log. -// Need to add write lock because this would change the next write offset func (walLog *WAL[T]) TruncateBefore(index uint64) error { - return walLog.log.TruncateFront(index) + return walLog.sendTruncate(true, index) } -func (walLog *WAL[T]) FirstOffset() (index uint64, err error) { - return walLog.log.FirstIndex() +// sendTruncate sends a truncate request to the main loop and waits for completion. +func (walLog *WAL[T]) sendTruncate(before bool, index uint64) error { + req := &truncateRequest{ + before: before, + index: index, + errChan: make(chan error, 1), + } + + err := interuptablePush(walLog.ctx, walLog.truncateChan, req) + if err != nil { + return fmt.Errorf("failed to push truncate request: %w", err) + } + + err, pullErr := interuptablePull(walLog.ctx, req.errChan) + if pullErr != nil { + return fmt.Errorf("failed to pull truncate error: %w", pullErr) + } + if err != nil { + return fmt.Errorf("failed to truncate: %w", err) + } + + return nil } -// LastOffset returns the last written offset/index of the log -func (walLog *WAL[T]) LastOffset() (index uint64, err error) { - return walLog.log.LastIndex() +// handleTruncate runs on the main loop and performs the truncation. +func (walLog *WAL[T]) handleTruncate(req *truncateRequest) { + var err error + if req.before { + err = walLog.log.TruncateFront(req.index) + } else { + err = walLog.log.TruncateBack(req.index) + } + if err != nil { + req.errChan <- fmt.Errorf("failed to truncate: %w", err) + return + } + req.errChan <- nil } -func (walLog *WAL[T]) NextOffset() (index uint64, err error) { - lastOffset, err := walLog.log.LastIndex() +func (walLog *WAL[T]) FirstOffset() (uint64, error) { + val, err := walLog.log.FirstIndex() + if err != nil { + return 0, fmt.Errorf("failed to get first offset: %w", err) + } + return val, nil +} + +// LastOffset returns the last written offset/index of the log. +func (walLog *WAL[T]) LastOffset() (uint64, error) { + val, err := walLog.log.LastIndex() if err != nil { - return 0, err + return 0, fmt.Errorf("failed to get last offset: %w", err) } - return lastOffset + 1, nil + return val, nil } -// ReadAt will read the log entry at the provided index +// ReadAt will read the log entry at the provided index. func (walLog *WAL[T]) ReadAt(index uint64) (T, error) { var zero T bz, err := walLog.log.Read(index) @@ -186,12 +373,12 @@ func (walLog *WAL[T]) ReadAt(index uint64) (T, error) { } entry, err := walLog.unmarshal(bz) if err != nil { - return zero, fmt.Errorf("unmarshal rlog failed, %w", err) + return zero, fmt.Errorf("unmarshal log failed, %w", err) } return entry, nil } -// Replay will read the replay log and process each log entry with the provided function +// Replay will read the replay log and process each log entry with the provided function. func (walLog *WAL[T]) Replay(start uint64, end uint64, processFn func(index uint64, entry T) error) error { for i := start; i <= end; i++ { bz, err := walLog.log.Read(i) @@ -200,98 +387,76 @@ func (walLog *WAL[T]) Replay(start uint64, end uint64, processFn func(index uint } entry, err := walLog.unmarshal(bz) if err != nil { - return fmt.Errorf("unmarshal rlog failed, %w", err) + return fmt.Errorf("unmarshal log failed, %w", err) + } err = processFn(i, entry) if err != nil { - return err + return fmt.Errorf("process log failed, %w", err) } } return nil } -func (walLog *WAL[T]) startPruning(keepRecent uint64, pruneInterval time.Duration) { - walLog.wg.Add(1) - go func() { - defer walLog.wg.Done() - ticker := time.NewTicker(pruneInterval) - defer ticker.Stop() - for { - select { - case <-walLog.closeCh: - return - case <-ticker.C: - lastIndex, err := walLog.log.LastIndex() - if err != nil { - walLog.logger.Error("failed to get last index for pruning", "err", err) - continue - } - firstIndex, err := walLog.log.FirstIndex() - if err != nil { - walLog.logger.Error("failed to get first index for pruning", "err", err) - continue - } - if lastIndex > keepRecent && (lastIndex-keepRecent) > firstIndex { - prunePos := lastIndex - keepRecent - if err := walLog.TruncateBefore(prunePos); err != nil { - walLog.logger.Error(fmt.Sprintf("failed to prune changelog till index %d", prunePos), "err", err) - } - } - } - } - }() -} - -func (walLog *WAL[T]) Close() error { - walLog.mtx.Lock() - defer walLog.mtx.Unlock() - // Close should only be executed once. - if walLog.isClosed { - return nil - } - // Signal background goroutines to stop. - close(walLog.closeCh) - if walLog.writeChannel != nil { - close(walLog.writeChannel) - walLog.writeChannel = nil +func (walLog *WAL[T]) prune() { + keepRecent := walLog.config.KeepRecent + if keepRecent <= 0 || walLog.config.PruneInterval <= 0 { + // Pruning is disabled. This is a defensive check, since + // this method should only be called if pruning is enabled. + return } - // Wait for all background goroutines (pruning + async write) to finish. - walLog.wg.Wait() - walLog.isClosed = true - return walLog.log.Close() -} -// recordAsyncWriteErr records the first async write error (non-blocking). -func (walLog *WAL[T]) recordAsyncWriteErr(err error) { - if err == nil { + lastIndex, err := walLog.log.LastIndex() + if err != nil { + walLog.logger.Error("failed to get last index for pruning", "err", err) return } - select { - case walLog.asyncWriteErrCh <- err: - default: - // already recorded + firstIndex, err := walLog.log.FirstIndex() + if err != nil { + walLog.logger.Error("failed to get first index for pruning", "err", err) + return + } + + if lastIndex > keepRecent && (lastIndex-keepRecent) > firstIndex { + prunePos := lastIndex - keepRecent + if err := walLog.log.TruncateFront(prunePos); err != nil { + walLog.logger.Error(fmt.Sprintf("failed to prune changelog till index %d", prunePos), "err", err) + } } } -// getAsyncWriteErrLocked returns the async write error if present. -// To keep the error "sticky" without an extra cached field, we implement -// a "peek" by reading once and then non-blocking re-inserting the same -// error back into the buffered channel. -// Caller must hold walLog.mtx (read lock is sufficient). -func (walLog *WAL[T]) getAsyncWriteErrLocked() error { - select { - case err := <-walLog.asyncWriteErrCh: - // Put it back so subsequent callers still observe it. +// drain processes all pending requests so in-flight work completes before shutdown. +func (walLog *WAL[T]) drain() { + for { select { - case walLog.asyncWriteErrCh <- err: + case req := <-walLog.writeChan: + walLog.handleWrite(req) + case req := <-walLog.truncateChan: + walLog.handleTruncate(req) default: + return } - return err - default: - return nil } } +// Shut down the WAL. Sends a close request to the main loop so in-flight writes (and other work) +// can complete before teardown. Idempotent. +func (walLog *WAL[T]) Close() error { + _ = interuptablePush(walLog.ctx, walLog.closeReqChan, struct{}{}) + // If error is non-nil then this is not the first call to Close(), no problem since Close() is idempotent + + err := <-walLog.closeErrChan + + // "reload" error into channel to make Close() idempotent + walLog.closeErrChan <- err + + if err != nil { + return fmt.Errorf("error encountered while shutting down: %w", err) + } + + return nil +} + // open opens the replay log, try to truncate the corrupted tail if there's any func open(dir string, opts *wal.Options) (*wal.Log, error) { if opts == nil { @@ -325,3 +490,66 @@ func open(dir string, opts *wal.Options) (*wal.Log, error) { } return rlog, err } + +// The main loop doing work in the background. +func (walLog *WAL[T]) mainLoop() { + + var pruneChan <-chan time.Time + if walLog.config.PruneInterval > 0 && walLog.config.KeepRecent > 0 { + pruneTicker := time.NewTicker(walLog.config.PruneInterval) + defer pruneTicker.Stop() + pruneChan = pruneTicker.C + } + + running := true + for running { + select { + case <-walLog.ctx.Done(): + running = false + case req := <-walLog.writeChan: + walLog.handleWrite(req) + case req := <-walLog.truncateChan: + walLog.handleTruncate(req) + case <-pruneChan: + walLog.prune() + case <-walLog.closeReqChan: + running = false + } + } + + walLog.cancel() + + // drain pending work, then tear down + walLog.drain() + + err := walLog.log.Close() + if err != nil { + walLog.closeErrChan <- fmt.Errorf("wal returned error during shutdown: %w", err) + } else { + walLog.closeErrChan <- nil + } +} + +// Push to a channel, returning an error if the context is cancelled before the value is pushed. +func interuptablePush[T any](ctx context.Context, ch chan T, value T) error { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled: %w", ctx.Err()) + case ch <- value: + return nil + } +} + +// Pull from a channel, returning an error if the context is cancelled before the value is pulled. +func interuptablePull[T any](ctx context.Context, ch <-chan T) (T, error) { + var zero T + select { + case <-ctx.Done(): + return zero, fmt.Errorf("context cancelled: %w", ctx.Err()) + case value, ok := <-ch: + if !ok { + return zero, fmt.Errorf("channel closed") + } + return value, nil + } +} diff --git a/sei-db/wal/wal_bench_test.go b/sei-db/wal/wal_bench_test.go index b5714b270d..f4e27ed263 100644 --- a/sei-db/wal/wal_bench_test.go +++ b/sei-db/wal/wal_bench_test.go @@ -86,7 +86,7 @@ func BenchmarkWALWrapperWrite(b *testing.B) { b.Run(name, func(b *testing.B) { dir := b.TempDir() - w, err := NewWAL(marshal, unmarshal, logger.NewNopLogger(), dir, Config{ + w, err := NewWAL(b.Context(), marshal, unmarshal, logger.NewNopLogger(), dir, Config{ WriteBufferSize: bufSize, }) if err != nil { diff --git a/sei-db/wal/wal_test.go b/sei-db/wal/wal_test.go index 76eaaf268b..645de5278b 100644 --- a/sei-db/wal/wal_test.go +++ b/sei-db/wal/wal_test.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -114,7 +115,7 @@ func TestRandomRead(t *testing.T) { func prepareTestData(t *testing.T) *WAL[proto.ChangelogEntry] { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) writeTestData(t, changelog) return changelog @@ -139,12 +140,12 @@ func TestSynchronousWrite(t *testing.T) { lastIndex, err := changelog.LastOffset() require.NoError(t, err) require.Equal(t, uint64(3), lastIndex) - } func TestAsyncWrite(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, + Config{WriteBufferSize: 10}) require.NoError(t, err) for _, changes := range ChangeSets { cs := []*proto.NamedChangeSet{ @@ -160,7 +161,8 @@ func TestAsyncWrite(t *testing.T) { } err = changelog.Close() require.NoError(t, err) - changelog, err = NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10}) + changelog, err = NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, + Config{WriteBufferSize: 10}) require.NoError(t, err) lastIndex, err := changelog.LastOffset() require.NoError(t, err) @@ -253,7 +255,7 @@ func TestTruncateBefore(t *testing.T) { func TestCloseSyncMode(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) // Write some data in sync mode @@ -263,11 +265,8 @@ func TestCloseSyncMode(t *testing.T) { err = changelog.Close() require.NoError(t, err) - // Verify isClosed is set - require.True(t, changelog.isClosed) - // Reopen and verify data persisted - changelog2, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog2, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, changelog2.Close()) }) @@ -298,21 +297,21 @@ func TestReplayWithError(t *testing.T) { return nil }) require.Error(t, err) - require.Equal(t, expectedErr, err) + require.True(t, strings.Contains(err.Error(), expectedErr.Error())) } func TestReopenAndContinueWrite(t *testing.T) { dir := t.TempDir() // Create and write initial data - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) writeTestData(t, changelog) err = changelog.Close() require.NoError(t, err) // Reopen and continue writing - changelog2, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog2, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) // Verify nextOffset is correctly set after reopen @@ -343,7 +342,7 @@ func TestReopenAndContinueWrite(t *testing.T) { func TestEmptyLog(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, changelog.Close()) }) @@ -359,7 +358,8 @@ func TestEmptyLog(t *testing.T) { func TestCheckErrorNoError(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, + Config{WriteBufferSize: 10}) require.NoError(t, err) // Write some data to initialize async mode @@ -389,7 +389,8 @@ func TestAsyncWriteReopenAndContinue(t *testing.T) { dir := t.TempDir() // Create with async write and write data - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, + Config{WriteBufferSize: 10}) require.NoError(t, err) for _, changes := range ChangeSets { @@ -403,7 +404,8 @@ func TestAsyncWriteReopenAndContinue(t *testing.T) { require.NoError(t, err) // Reopen with async write and continue - changelog2, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10}) + changelog2, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, + Config{WriteBufferSize: 10}) require.NoError(t, err) // Write more entries @@ -418,7 +420,7 @@ func TestAsyncWriteReopenAndContinue(t *testing.T) { require.NoError(t, err) // Reopen and verify all 6 entries - changelog3, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog3, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, changelog3.Close()) }) @@ -441,9 +443,62 @@ func TestReplaySingleEntry(t *testing.T) { require.Equal(t, 1, count) } +// TestBatchWrite exercises the batch write path by writing many entries quickly so they +// are processed in batches, then verifies all entries were written correctly. +func TestBatchWrite(t *testing.T) { + const ( + batchSize = 8 + numWrites = 32 + ) + dir := t.TempDir() + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, + Config{ + WriteBatchSize: batchSize, + WriteBufferSize: 64, + }) + require.NoError(t, err) + + // Pump writes quickly so the main loop batches them (handleBatchedWrite drains up to batchSize). + for i := 0; i < numWrites; i++ { + entry := &proto.ChangelogEntry{} + entry.Changesets = []*proto.NamedChangeSet{{ + Name: fmt.Sprintf("batch-%d", i), + Changeset: iavl.ChangeSet{Pairs: MockKVPairs(fmt.Sprintf("key-%d", i), fmt.Sprintf("val-%d", i))}, + }} + require.NoError(t, changelog.Write(*entry)) + } + + require.NoError(t, changelog.Close()) + + // Reopen and verify all entries + changelog2, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, changelog2.Close()) }) + + first, err := changelog2.FirstOffset() + require.NoError(t, err) + require.Equal(t, uint64(1), first) + last, err := changelog2.LastOffset() + require.NoError(t, err) + require.Equal(t, uint64(numWrites), last) + + var replayed int + err = changelog2.Replay(1, uint64(numWrites), func(index uint64, entry proto.ChangelogEntry) error { + replayed++ + require.Len(t, entry.Changesets, 1) + require.Equal(t, fmt.Sprintf("batch-%d", index-1), entry.Changesets[0].Name) + require.Len(t, entry.Changesets[0].Changeset.Pairs, 1) + require.Equal(t, []byte(fmt.Sprintf("key-%d", index-1)), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte(fmt.Sprintf("val-%d", index-1)), entry.Changesets[0].Changeset.Pairs[0].Value) + return nil + }) + require.NoError(t, err) + require.Equal(t, numWrites, replayed) +} + func TestWriteMultipleChangesets(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, changelog.Close()) }) @@ -469,7 +524,8 @@ func TestWriteMultipleChangesets(t *testing.T) { func TestConcurrentCloseWithInFlightAsyncWrites(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 8}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, + Config{WriteBufferSize: 8}) require.NoError(t, err) // Intentionally avoid t.Cleanup here: we want Close() to race with in-flight async writes. @@ -535,7 +591,7 @@ func TestConcurrentCloseWithInFlightAsyncWrites(t *testing.T) { func TestConcurrentTruncateBeforeWithAsyncWrites(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{ + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{ WriteBufferSize: 10, KeepRecent: 10, PruneInterval: 1 * time.Millisecond, @@ -606,7 +662,7 @@ func TestConcurrentTruncateBeforeWithAsyncWrites(t *testing.T) { func TestGetLastIndex(t *testing.T) { dir := t.TempDir() - changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + changelog, err := NewWAL(t.Context(), marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) require.NoError(t, err) writeTestData(t, changelog) err = changelog.Close() @@ -622,3 +678,103 @@ func TestLogPath(t *testing.T) { path := LogPath("/some/dir") require.Equal(t, "/some/dir/changelog", path) } + +// batchTestEntry is a simple type for testing batch marshal failures. +type batchTestEntry struct { + value string +} + +func TestBatchWriteWithMarshalFailure(t *testing.T) { + dir := t.TempDir() + + // Marshal fails for entries with value "fail" + marshalBatchTest := func(e batchTestEntry) ([]byte, error) { + if e.value == "fail" { + return nil, fmt.Errorf("mock marshal failure") + } + return []byte(e.value), nil + } + unmarshalBatchTest := func(b []byte) (batchTestEntry, error) { + return batchTestEntry{value: string(b)}, nil + } + + // Use sync writes (WriteBufferSize 0) and batching (WriteBatchSize 4) + // so we can observe per-write errors. The channel buffer allows multiple + // goroutines to push before the handler runs, forming a batch. + config := Config{ + WriteBufferSize: 0, // sync writes + WriteBatchSize: 4, // batch up to 4 + } + + w, err := NewWAL(t.Context(), marshalBatchTest, unmarshalBatchTest, logger.NewNopLogger(), dir, config) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, w.Close()) }) + + // Write 4 entries concurrently so they get batched. The second one will fail to marshal. + entries := []batchTestEntry{ + {value: "ok1"}, + {value: "fail"}, + {value: "ok2"}, + {value: "ok3"}, + } + + var wg sync.WaitGroup + errs := make([]error, 4) + for i := range entries { + wg.Add(1) + go func(idx int) { + defer wg.Done() + errs[idx] = w.Write(entries[idx]) + }(i) + } + wg.Wait() + + // The "fail" entry should have errored + require.Error(t, errs[1]) + require.Contains(t, errs[1].Error(), "mock marshal failure") + + // The successful entries should have no error + require.NoError(t, errs[0]) + require.NoError(t, errs[2]) + require.NoError(t, errs[3]) + + // The WAL should contain exactly 3 entries (the successfully marshalled ones; "fail" is skipped) + lastOffset, err := w.LastOffset() + require.NoError(t, err) + require.Equal(t, uint64(3), lastOffset) + + // Goroutines may push in any order, so we collect the written values and verify we have ok1, ok2, ok3 + written := make(map[string]bool) + for i := uint64(1); i <= 3; i++ { + e, err := w.ReadAt(i) + require.NoError(t, err) + written[e.value] = true + } + require.True(t, written["ok1"], "expected ok1 in WAL") + require.True(t, written["ok2"], "expected ok2 in WAL") + require.True(t, written["ok3"], "expected ok3 in WAL") + require.False(t, written["fail"], "fail should not be in WAL") +} + +func TestMultipleCloseCalls(t *testing.T) { + changelog := prepareTestData(t) + entry, err := changelog.ReadAt(2) + require.NoError(t, err) + require.Equal(t, []byte("hello1"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world1"), entry.Changesets[0].Changeset.Pairs[0].Value) + require.Equal(t, []byte("hello2"), entry.Changesets[0].Changeset.Pairs[1].Key) + require.Equal(t, []byte("world2"), entry.Changesets[0].Changeset.Pairs[1].Value) + entry, err = changelog.ReadAt(1) + require.NoError(t, err) + require.Equal(t, []byte("hello"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world"), entry.Changesets[0].Changeset.Pairs[0].Value) + entry, err = changelog.ReadAt(3) + require.NoError(t, err) + require.Equal(t, []byte("hello3"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world3"), entry.Changesets[0].Changeset.Pairs[0].Value) + + // Calling close lots of times shouldn't cause any problems. + for i := 0; i < 10; i++ { + require.NoError(t, changelog.Close()) + } +}