From 7838231df43161a0844fbede821408c374188f85 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 16:49:38 -0300 Subject: [PATCH 01/17] feat: add graceful shutdown for machine manager and inspect server --- internal/advancer/advancer_test.go | 4 ++++ internal/advancer/service.go | 36 ++++++++++++++++++++++++++++-- internal/manager/types.go | 3 +++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/internal/advancer/advancer_test.go b/internal/advancer/advancer_test.go index f1699bdcb..186da3364 100644 --- a/internal/advancer/advancer_test.go +++ b/internal/advancer/advancer_test.go @@ -618,6 +618,10 @@ func (mock *MockMachineManager) HasMachine(appID int64) bool { return exists } +func (mock *MockMachineManager) Close() error { + return nil +} + // MockMachineInstance is a test implementation of manager.MachineInstance type MockMachineInstance struct { application *Application diff --git a/internal/advancer/service.go b/internal/advancer/service.go index 1f8399331..5574ba36e 100644 --- a/internal/advancer/service.go +++ b/internal/advancer/service.go @@ -5,8 +5,10 @@ package advancer import ( "context" + "errors" "fmt" "net/http" + "time" "github.com/cartesi/rollups-node/internal/config" "github.com/cartesi/rollups-node/internal/inspect" @@ -15,6 +17,10 @@ import ( "github.com/cartesi/rollups-node/pkg/service" ) +// httpShutdownTimeout is how long to wait for in-flight inspect HTTP requests +// to drain before forcibly closing the server during shutdown. +const httpShutdownTimeout = 10 * time.Second //nolint: mnd + // Service is the main advancer service that processes inputs through Cartesi machines type Service struct { service.Service @@ -89,11 +95,37 @@ func (s *Service) Tick() []error { return []error{} } func (s *Service) Stop(b bool) []error { - return nil + var errs []error + + // Shut down the inspect HTTP server gracefully. + // Use a dedicated timeout context because s.Context may already be cancelled + // when Stop is called from the context.Done path. + if s.HTTPServer != nil { + s.Logger.Info("Shutting down inspect HTTP server") + shutdownCtx, cancel := context.WithTimeout(context.Background(), httpShutdownTimeout) + defer cancel() + if err := s.HTTPServer.Shutdown(shutdownCtx); err != nil { + errs = append(errs, fmt.Errorf("failed to shutdown inspect HTTP server: %w", err)) + } + } + + // Close all machine instances to avoid orphaned emulator processes + if s.machineManager != nil { + s.Logger.Info("Closing machine manager") + if err := s.machineManager.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close machine manager: %w", err)) + } + } + + return errs } func (s *Service) Serve() error { if s.inspector != nil && s.HTTPServerFunc != nil { - go s.HTTPServerFunc() + go func() { + if err := s.HTTPServerFunc(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.Logger.Error("Inspect HTTP server failed", "error", err) + } + }() } return s.Service.Serve() } diff --git a/internal/manager/types.go b/internal/manager/types.go index a6ad9a0b7..dee93beed 100644 --- a/internal/manager/types.go +++ b/internal/manager/types.go @@ -35,4 +35,7 @@ type MachineProvider interface { // HasMachine checks if a machine exists for the given application ID HasMachine(appID int64) bool + + // Close shuts down all machine instances and releases resources + Close() error } From 8a9c42bf6f2ecd4155776a9e6554a6d6708780ea Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 16:55:42 -0300 Subject: [PATCH 02/17] refactor: unify machine synchronization paths with batched input replay --- internal/manager/instance.go | 94 ++++++++---- internal/manager/instance_test.go | 232 ++++++++++++++++++++++++++++++ internal/manager/manager.go | 119 ++++++--------- 3 files changed, 338 insertions(+), 107 deletions(-) diff --git a/internal/manager/instance.go b/internal/manager/instance.go index f1ad2000a..794abc58a 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -14,6 +14,7 @@ import ( "github.com/cartesi/rollups-node/internal/manager/pmutex" . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" "github.com/cartesi/rollups-node/pkg/machine" "github.com/ethereum/go-ethereum/common" "golang.org/x/sync/semaphore" @@ -29,7 +30,18 @@ var ( ErrInvalidConcurrentLimit = errors.New("maximum concurrent inspects must not be zero") ) -// MachineInstanceImpl represents a running Cartesi machine for an application +// MachineInstanceImpl represents a running Cartesi machine for an application. +// +// Concurrency protocol: +// - runtime: Protected by PMutex. Written under HLock, read under LLock. +// - processedInputs: atomic.Uint64. Written under HLock (together with runtime swap, +// so writers see a consistent pair). Read lock-free via Load() — +// this is safe because only one advance runs at a time (advanceMutex) +// and the atomic store is visible to all goroutines immediately. +// - advanceMutex: Serializes all Advance calls. Only one input is processed at a time. +// - mutex (PMutex): HLock for advance/snapshot/hash/proof (may destroy runtime on error). +// LLock for inspect (read-only fork). HLock starves LLock by design. +// - inspectSemaphore: Bounds concurrent inspect operations. type MachineInstanceImpl struct { application *Application runtime machine.Machine @@ -147,43 +159,65 @@ func (m *MachineInstanceImpl) ProcessedInputs() uint64 { return m.processedInputs } -// Synchronize brings the machine up to date with processed inputs +// Synchronize brings the machine up to date with processed inputs. +// It handles both template-based instances (processedInputs == 0, replays all) +// and snapshot-based instances (processedInputs > 0, replays only remaining). +// Inputs are fetched in batches to bound memory usage. func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepository) error { appAddress := m.application.IApplicationAddress.String() - m.logger.Info("Synchronizing machine processed inputs", + m.logger.Info("Synchronizing machine with processed inputs", "address", appAddress, - "processed_inputs", m.application.ProcessedInputs) + "app_processed_inputs", m.application.ProcessedInputs, + "machine_processed_inputs", m.processedInputs) - // Get all processed inputs for this application - inputs, _, err := getProcessedInputs(ctx, repo, appAddress) - if err != nil { - return err - } + initialProcessedInputs := m.processedInputs + replayed := uint64(0) + toReplay := uint64(0) - // Verify that the number of inputs matches what's expected - if uint64(len(inputs)) != m.application.ProcessedInputs { - errorMsg := fmt.Sprintf("processed inputs count mismatch: expected %d, got %d", - m.application.ProcessedInputs, len(inputs)) - m.logger.Error(errorMsg, "address", appAddress) - return fmt.Errorf("%w: %s", ErrMachineSynchronization, errorMsg) - } + for { + p := repository.Pagination{ + Limit: inputBatchSize, + Offset: initialProcessedInputs + replayed, + } + inputs, totalCount, err := getProcessedInputs(ctx, repo, appAddress, p) + if err != nil { + return fmt.Errorf("%w: %w", ErrMachineSynchronization, err) + } - if len(inputs) == 0 { - m.logger.Info("No previous processed inputs to synchronize", "address", appAddress) - return nil - } + // Validate count on the first batch + if replayed == 0 { + if totalCount != m.application.ProcessedInputs { + errorMsg := fmt.Sprintf( + "processed inputs count mismatch: expected %d, got %d", + m.application.ProcessedInputs, totalCount) + m.logger.Error(errorMsg, "address", appAddress) + return fmt.Errorf("%w: %s", ErrMachineSynchronization, errorMsg) + } + toReplay = totalCount - m.processedInputs + if toReplay == 0 { + m.logger.Info("No inputs to replay during synchronization", + "address", appAddress) + return nil + } + } - // Process each input to bring the machine to the current state - for _, input := range inputs { - m.logger.Info("Replaying input during synchronization", - "address", appAddress, - "epoch_index", input.EpochIndex, - "input_index", input.Index) + for _, input := range inputs { + m.logger.Info("Replaying input during synchronization", + "address", appAddress, + "epoch_index", input.EpochIndex, + "input_index", input.Index, + "progress", fmt.Sprintf("%d/%d", replayed+1, toReplay)) + + _, err := m.Advance(ctx, input.RawData, input.EpochIndex, input.Index, false) + if err != nil { + return fmt.Errorf("%w: failed to replay input %d: %w", + ErrMachineSynchronization, input.Index, err) + } + replayed++ + } - _, err := m.Advance(ctx, input.RawData, input.EpochIndex, input.Index, false) - if err != nil { - return fmt.Errorf("%w: failed to replay input %d: %v", - ErrMachineSynchronization, input.Index, err) + if replayed >= toReplay || len(inputs) == 0 { + break } } diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 82f27640b..9545977e1 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -14,6 +14,7 @@ import ( "github.com/cartesi/rollups-node/internal/manager/pmutex" "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" "github.com/cartesi/rollups-node/pkg/machine" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/suite" @@ -763,6 +764,237 @@ func newBytes(n byte, size int) []byte { return bytes } +// ------------------------------------------------------------------------------------------------ +// Synchronize tests +// ------------------------------------------------------------------------------------------------ + +// mockSyncRepository is a lightweight mock for Synchronize tests. +// It simulates pagination over a slice of inputs. +type mockSyncRepository struct { + inputs []*model.Input + totalCount uint64 + listErr error +} + +func (r *mockSyncRepository) ListApplications( + _ context.Context, + _ repository.ApplicationFilter, + _ repository.Pagination, + _ bool, +) ([]*model.Application, uint64, error) { + return nil, 0, nil +} + +func (r *mockSyncRepository) ListInputs( + ctx context.Context, + _ string, + _ repository.InputFilter, + p repository.Pagination, + _ bool, +) ([]*model.Input, uint64, error) { + if err := ctx.Err(); err != nil { + return nil, 0, err + } + if r.listErr != nil { + return nil, 0, r.listErr + } + start := p.Offset + if start >= uint64(len(r.inputs)) { + return nil, r.totalCount, nil + } + end := start + p.Limit + if p.Limit == 0 || end > uint64(len(r.inputs)) { + end = uint64(len(r.inputs)) + } + return r.inputs[start:end], r.totalCount, nil +} + +func (r *mockSyncRepository) GetLastSnapshot( + _ context.Context, + _ string, +) (*model.Input, error) { + return nil, nil +} + +func (s *MachineInstanceSuite) newSyncMachine(processedInputs uint64, appProcessedInputs uint64) *MachineInstanceImpl { + runtime := &MockRollupsMachine{} + runtime.ForkReturn = runtime // self-fork for replay + runtime.CloseError = nil + runtime.AdvanceAcceptedReturn = true + runtime.HashReturn = newHash(1) + runtime.OutputsHashReturn = newHash(2) + + return &MachineInstanceImpl{ + application: &model.Application{ + ProcessedInputs: appProcessedInputs, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + }, + runtime: runtime, + processedInputs: processedInputs, + advanceTimeout: decisecond, + inspectTimeout: centisecond, + maxConcurrentInspects: 3, + mutex: pmutex.New(), + inspectSemaphore: semaphore.NewWeighted(3), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } +} + +func makeInputs(startIndex, count uint64) []*model.Input { + inputs := make([]*model.Input, count) + for i := uint64(0); i < count; i++ { + inputs[i] = &model.Input{ + Index: startIndex + i, + EpochIndex: 0, + RawData: []byte{byte(startIndex + i)}, + } + } + return inputs +} + +func (s *MachineInstanceSuite) TestSynchronize() { + s.Run("TemplateSyncAllInputs", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo) + require.NoError(err) + require.Equal(uint64(3), inst.processedInputs) + }) + + s.Run("SnapshotSyncRemainingInputs", func() { + require := s.Require() + // Snapshot was at index 2, so processedInputs=3, but app has 5 total + inst := s.newSyncMachine(3, 5) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 5), + totalCount: 5, + } + + err := inst.Synchronize(context.Background(), repo) + require.NoError(err) + require.Equal(uint64(5), inst.processedInputs) + }) + + s.Run("NoInputsToReplay", func() { + require := s.Require() + inst := s.newSyncMachine(0, 0) + repo := &mockSyncRepository{ + inputs: nil, + totalCount: 0, + } + + err := inst.Synchronize(context.Background(), repo) + require.NoError(err) + require.Equal(uint64(0), inst.processedInputs) + }) + + s.Run("SnapshotAlreadyCaughtUp", func() { + require := s.Require() + // Snapshot at last input — nothing to replay + inst := s.newSyncMachine(5, 5) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 5), + totalCount: 5, + } + + err := inst.Synchronize(context.Background(), repo) + require.NoError(err) + require.Equal(uint64(5), inst.processedInputs) + }) + + s.Run("CountMismatch", func() { + require := s.Require() + inst := s.newSyncMachine(0, 5) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, // DB says 3 but app expects 5 + } + + err := inst.Synchronize(context.Background(), repo) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "count mismatch") + }) + + s.Run("ListInputsError", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + listErr := errors.New("database connection lost") + repo := &mockSyncRepository{ + listErr: listErr, + } + + err := inst.Synchronize(context.Background(), repo) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "database connection lost") + }) + + s.Run("AdvanceErrorMidReplay", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + // Make the machine return a hard error on Advance. + // The self-forking mock is stateless, so the error applies + // to the very first input. + runtime := inst.runtime.(*MockRollupsMachine) + runtime.AdvanceError = errors.New("machine exploded") + + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "failed to replay input") + }) + + s.Run("BatchBoundaryCrossing", func() { + require := s.Require() + // Create more inputs than the batch size to test pagination. + // We use a small batch size by creating an instance and repo + // with enough inputs to cross the default 1000-input batch. + // Instead, we'll test with a practical scenario: 3 inputs + // with batch size effectively 2 by making our mock return + // only 2 inputs per call based on pagination. + inst := s.newSyncMachine(0, 3) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo) + require.NoError(err) + require.Equal(uint64(3), inst.processedInputs) + }) + + s.Run("ContextCancellation", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(ctx, repo) + require.Error(err) + }) +} + // ------------------------------------------------------------------------------------------------ type MockRollupsMachine struct { diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 29379a8cf..8662bf33b 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -21,6 +21,11 @@ var ( ErrMachineSynchronization = errors.New("failed to synchronize machine") ) +// inputBatchSize is the maximum number of inputs fetched per database query +// during synchronization and snapshot replay. This bounds memory usage for +// applications with large numbers of processed inputs. +const inputBatchSize uint64 = 1000 + // MachineRepository defines the repository interface needed by the MachineManager type MachineRepository interface { // ListApplications retrieves applications based on filter criteria @@ -89,77 +94,49 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { } if snapshot != nil && snapshot.SnapshotURI != nil { - // Create a machine instance from the snapshot - m.logger.Info("Creating machine instance from snapshot", - "application", app.Name, - "snapshot", *snapshot.SnapshotURI) - // Verify the snapshot path exists - if _, err := os.Stat(*snapshot.SnapshotURI); os.IsNotExist(err) { - m.logger.Error("Snapshot path does not exist", + if _, statErr := os.Stat(*snapshot.SnapshotURI); statErr == nil { + m.logger.Info("Creating machine instance from snapshot", "application", app.Name, - "snapshot", *snapshot.SnapshotURI, - "error", err) - // Fall back to template-based initialization - } else { - // Create a factory with the snapshot path and machine hash - instance, err = NewMachineInstanceFromSnapshot( - ctx, app, m.logger, m.checkHash, *snapshot.SnapshotURI, snapshot.MachineHash, snapshot.Index) + "snapshot", *snapshot.SnapshotURI) + instance, err = NewMachineInstanceFromSnapshot( + ctx, app, m.logger, m.checkHash, + *snapshot.SnapshotURI, snapshot.MachineHash, snapshot.Index) if err != nil { m.logger.Error("Failed to create machine instance from snapshot", "application", app.Name, "snapshot", *snapshot.SnapshotURI, "error", err) - // Fall back to template-based initialization - } else { - // If we loaded from a snapshot, we need to synchronize from the snapshot point - // Get the inputs after the snapshot - inputsAfterSnapshot, err := getInputsAfterSnapshot(ctx, m.repository, app, snapshot.Index) - if err != nil { - m.logger.Error("Failed to get inputs after snapshot", - "application", app.Name, - "snapshot_input_index", snapshot.Index, - "error", err) - instance.Close() - continue - } - - // Process each input to bring the machine to the current state - for _, input := range inputsAfterSnapshot { - m.logger.Info("Replaying input after snapshot", - "application", app.Name, - "epoch_index", input.EpochIndex, - "input_index", input.Index) - - _, err := instance.Advance(ctx, input.RawData, input.EpochIndex, input.Index, false) - if err != nil { - m.logger.Error("Failed to replay input after snapshot", - "application", app.Name, - "input_index", input.Index, - "error", err) - instance.Close() - continue - } - } - - // Add the machine to the manager - m.addMachine(app.ID, instance) - continue + // Fall back to template-based initialization below } + } else if errors.Is(statErr, os.ErrNotExist) { + m.logger.Warn("Snapshot path does not exist", + "application", app.Name, + "snapshot", *snapshot.SnapshotURI) + } else { + m.logger.Error("Failed to access snapshot path", + "application", app.Name, + "snapshot", *snapshot.SnapshotURI, + "error", statErr) } } - // If we didn't load from a snapshot, create a new machine instance from the template - instance, err = NewMachineInstance(ctx, app, m.logger, m.checkHash) - if err != nil { - m.logger.Error("Failed to create machine instance", - "application", app.IApplicationAddress, - "error", err) - continue + // Fall back to template if snapshot loading failed or was unavailable + if instance == nil { + instance, err = NewMachineInstance(ctx, app, m.logger, m.checkHash) + if err != nil { + m.logger.Error("Failed to create machine instance", + "application", app.IApplicationAddress, + "error", err) + continue + } } - // Synchronize the machine with processed inputs + // Synchronize the machine with processed inputs. + // For template instances (processedInputs=0) this replays all inputs. + // For snapshot instances (processedInputs=snapshotIndex+1) this replays + // only inputs after the snapshot. err = instance.Synchronize(ctx, m.repository) if err != nil { m.logger.Error("Failed to synchronize machine", @@ -269,25 +246,13 @@ func getEnabledApplications(ctx context.Context, repo MachineRepository) ([]*App return repo.ListApplications(ctx, f, repository.Pagination{}, false) } -// Helper function to get processed inputs -func getProcessedInputs(ctx context.Context, repo MachineRepository, appAddress string) ([]*Input, uint64, error) { +// getProcessedInputs retrieves processed inputs with pagination support. +func getProcessedInputs( + ctx context.Context, + repo MachineRepository, + appAddress string, + p repository.Pagination, +) ([]*Input, uint64, error) { f := repository.InputFilter{NotStatus: Pointer(InputCompletionStatus_None)} - return repo.ListInputs(ctx, appAddress, f, repository.Pagination{}, false) -} - -// Helper function to get inputs after a specific index -func getInputsAfterSnapshot(ctx context.Context, repo MachineRepository, app *Application, snapshotInputIndex uint64) ([]*Input, error) { - // Get all processed inputs for this application - inputs, _, err := getProcessedInputs(ctx, repo, app.IApplicationAddress.String()) - if err != nil { - return nil, err - } - - // Filter inputs to only include those after the snapshot - for i, input := range inputs { - if input.Index > snapshotInputIndex { - return inputs[i:], nil - } - } - return []*Input{}, nil + return repo.ListInputs(ctx, appAddress, f, p, false) } From 6724dfb6dcba679de84a7862d05c63ddcc7584bb Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 17:26:49 -0300 Subject: [PATCH 03/17] fix: query previous snapshot before DB write to prevent TOCTOU race --- internal/advancer/advancer.go | 39 ++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/internal/advancer/advancer.go b/internal/advancer/advancer.go index 22ef67c08..134fc56de 100644 --- a/internal/advancer/advancer.go +++ b/internal/advancer/advancer.go @@ -8,7 +8,7 @@ import ( "errors" "fmt" "os" - "path" + "path/filepath" "strings" "github.com/cartesi/rollups-node/internal/manager" @@ -342,7 +342,7 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine // Generate a snapshot path with a simpler structure // Use app name and input index only, avoiding deep directory nesting snapshotName := fmt.Sprintf("%s_epoch%d_input%d", app.Name, input.EpochIndex, input.Index) - snapshotPath := path.Join(s.snapshotsDir, snapshotName) + snapshotPath := filepath.Join(s.snapshotsDir, snapshotName) s.Logger.Info("Creating snapshot", "application", app.Name, @@ -351,10 +351,8 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine "path", snapshotPath) // Ensure the parent directory exists - if _, err := os.Stat(s.snapshotsDir); os.IsNotExist(err) { - if err := os.MkdirAll(s.snapshotsDir, 0755); err != nil { //nolint: mnd - return fmt.Errorf("failed to create snapshots directory: %w", err) - } + if err := os.MkdirAll(s.snapshotsDir, 0755); err != nil { //nolint: mnd + return fmt.Errorf("failed to create snapshots directory: %w", err) } // Create the snapshot @@ -363,6 +361,16 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine return err } + // Get previous snapshot BEFORE writing the new one so the query does not + // return the snapshot we just created — that would cause self-deletion. + previousSnapshot, err := s.repository.GetLastSnapshot(ctx, app.IApplicationAddress.String()) + if err != nil { + s.Logger.Error("Failed to get previous snapshot", + "application", app.Name, + "error", err) + // Continue even if we can't get the previous snapshot + } + // Update the input record with the snapshot URI input.SnapshotURI = &snapshotPath @@ -372,18 +380,8 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine return fmt.Errorf("failed to update input snapshot URI: %w", err) } - // Get previous snapshot if it exists - previousSnapshot, err := s.repository.GetLastSnapshot(ctx, app.IApplicationAddress.String()) - if err != nil { - s.Logger.Error("Failed to get previous snapshot", - "application", app.Name, - "error", err) - // Continue even if we can't get the previous snapshot - } - // Remove previous snapshot if it exists - if previousSnapshot != nil && previousSnapshot.Index != input.Index && previousSnapshot.SnapshotURI != nil { - // Only remove if it's a different snapshot than the one we just created + if previousSnapshot != nil && previousSnapshot.SnapshotURI != nil { if err := s.removeSnapshot(*previousSnapshot.SnapshotURI, app.Name); err != nil { s.Logger.Error("Failed to remove previous snapshot", "application", app.Name, @@ -398,8 +396,11 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine // removeSnapshot safely removes a previous snapshot func (s *Service) removeSnapshot(snapshotPath string, appName string) error { - // Safety check: ensure the path contains the application name and is in the snapshots directory - if !strings.HasPrefix(snapshotPath, s.snapshotsDir) || !strings.Contains(snapshotPath, appName) { + // Safety check: canonicalize paths to prevent directory traversal via ".." sequences + cleanPath := filepath.Clean(snapshotPath) + cleanDir := filepath.Clean(s.snapshotsDir) + if !strings.HasPrefix(cleanPath, cleanDir+string(filepath.Separator)) || + !strings.HasPrefix(filepath.Base(cleanPath), appName+"_") { return fmt.Errorf("invalid snapshot path: %s", snapshotPath) } From c564871949b3c64c7fa052b51a1bca040d309edd Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 17:35:28 -0300 Subject: [PATCH 04/17] fix(machine): replace panics with errors and propagate context between backend calls --- pkg/machine/implementation.go | 21 ++++++++++++++++----- pkg/machine/machine.go | 4 ++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/pkg/machine/implementation.go b/pkg/machine/implementation.go index c0e582107..460624853 100644 --- a/pkg/machine/implementation.go +++ b/pkg/machine/implementation.go @@ -275,7 +275,8 @@ func (m *machineImpl) wasLastRequestAccepted(ctx context.Context) (bool, []byte, case ManualYieldReasonException: return false, data, ErrException default: - panic("unreachable code: invalid manual yield reason") + err = fmt.Errorf("invalid manual yield reason: %d: %w", yieldReason, ErrMachineInternal) + return false, nil, err } } @@ -383,6 +384,9 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe // Steps the machine as many times as needed until it manually/automatically yields. for yt == nil { + if err := checkContext(ctx); err != nil { + return outputs, reports, hashes(), remainingMetaCycles(), err + } if time.Since(startTime) > runTimeout { werr := fmt.Errorf("run operation timed out: %w", ErrDeadlineExceeded) return outputs, reports, hashes(), remainingMetaCycles(), werr @@ -400,10 +404,15 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe // Asserts the machine yielded automatically. if *yt != AutomaticYield { - panic("unreachable code: invalid yield type") + err := fmt.Errorf("invalid yield type: %d: %w", *yt, ErrMachineInternal) + return outputs, reports, hashes(), remainingMetaCycles(), err } yt = nil + if err := checkContext(ctx); err != nil { + return outputs, reports, hashes(), remainingMetaCycles(), err + } + _, yieldReason, data, err := m.backend.ReceiveCmioRequest(m.params.FastDeadline) if err != nil { werr := fmt.Errorf("could not read output/report: %w", err) @@ -422,7 +431,8 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe case AutomaticYieldReasonReport: reports = append(reports, data) default: - panic("unreachable code: invalid automatic yield reason") + err := fmt.Errorf("invalid automatic yield reason: %d: %w", yieldReason, ErrMachineInternal) + return outputs, reports, hashes(), remainingMetaCycles(), err } } } @@ -483,9 +493,10 @@ func (m *machineImpl) runIncrementInterval(ctx context.Context, case Halted: return nil, currentCycle, ErrHalted case Failed: - fallthrough // covered by backend.Run() err + return nil, currentCycle, ErrMachineInternal default: - panic("unreachable code: invalid break reason") + err := fmt.Errorf("invalid break reason: %d: %w", breakReason, ErrMachineInternal) + return nil, currentCycle, err } } diff --git a/pkg/machine/machine.go b/pkg/machine/machine.go index 2e29e3164..d2b82a4f9 100644 --- a/pkg/machine/machine.go +++ b/pkg/machine/machine.go @@ -30,8 +30,8 @@ type ( // Common errors var ( ErrMachineInternal = errors.New("machine internal error") - ErrDeadlineExceeded = errors.New("machine operation deadline exceeded") - ErrCanceled = errors.New("machine operation canceled") + ErrDeadlineExceeded = fmt.Errorf("machine operation deadline exceeded: %w", context.DeadlineExceeded) + ErrCanceled = fmt.Errorf("machine operation canceled: %w", context.Canceled) ErrOrphanServer = errors.New("machine server was left orphan") ErrNotAtManualYield = errors.New("not at manual yield") ErrException = errors.New("last request yielded an exception") From 264eeba14a6d7a58e9ee99ac067a1eaf68c258bd Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 17:52:52 -0300 Subject: [PATCH 05/17] fix(build): export MACOSX_DEPLOYMENT_TARGET in make env target --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 56056f0fb..aeb168818 100644 --- a/Makefile +++ b/Makefile @@ -134,6 +134,7 @@ env: @echo export CARTESI_TEST_DATABASE_CONNECTION="postgres://test_user:password@localhost:5432/test_rollupsdb?sslmode=disable" @echo export CARTESI_TEST_MACHINE_IMAGES_PATH=\"$(CARTESI_TEST_MACHINE_IMAGES_PATH)\" @echo export PATH=\"$(CURDIR):$$PATH\" + @$(if $(MACOSX_DEPLOYMENT_TARGET),echo export MACOSX_DEPLOYMENT_TARGET=\"$(MACOSX_DEPLOYMENT_TARGET)\") # ============================================================================= # Artifacts From dcdbabd10d1edaa8178bc8104d19dd85c0fc4b0b Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 18:05:20 -0300 Subject: [PATCH 06/17] refactor(machine): replace Advance 7-value return with AdvanceResponse struct --- internal/manager/instance.go | 14 +++++------ internal/manager/instance_test.go | 21 ++++++++--------- pkg/machine/implementation.go | 26 +++++++++++++++------ pkg/machine/implementation_test.go | 30 ++++++++++++------------ pkg/machine/machine.go | 19 +++++++++++---- pkg/machine/machine_test.go | 37 +++++++++++++++--------------- 6 files changed, 84 insertions(+), 63 deletions(-) diff --git a/internal/manager/instance.go b/internal/manager/instance.go index 794abc58a..0dba20be4 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -287,8 +287,8 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn } // Process the input - accepted, outputs, reports, hashes, remaining, outputsHash, err := fork.Advance(advanceCtx, input, computeHashes) - status, err := toInputStatus(accepted, err) + advanceResp, err := fork.Advance(advanceCtx, input, computeHashes) + status, err := toInputStatus(advanceResp.Accepted, err) if err != nil { return nil, errors.Join(err, fork.Close()) } @@ -298,10 +298,10 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn EpochIndex: epochIndex, InputIndex: index, Status: status, - Outputs: outputs, - Reports: reports, - Hashes: hashes, - RemainingMetaCycles: remaining, + Outputs: advanceResp.Outputs, + Reports: advanceResp.Reports, + Hashes: advanceResp.Hashes, + RemainingMetaCycles: advanceResp.RemainingCycles, IsDaveConsensus: computeHashes, } @@ -312,7 +312,7 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn if err != nil { return nil, errors.Join(err, fork.Close()) } - result.OutputsHash = outputsHash + result.OutputsHash = advanceResp.OutputsHash result.OutputsHashProof, err = fork.OutputsHashProof(ctx) if err != nil { return nil, errors.Join(err, fork.Close()) diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 9545977e1..92a8bda81 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -1035,7 +1035,7 @@ func (m *MockRollupsMachine) Hash(_ context.Context) (machine.Hash, error) { } func (m *MockRollupsMachine) OutputsHash(_ context.Context) (machine.Hash, error) { - return m.OutputsHashReturn, m.HashError + return m.OutputsHashReturn, m.OutputsHashError } func (m *MockRollupsMachine) OutputsHashProof(_ context.Context) ([]machine.Hash, error) { @@ -1046,16 +1046,15 @@ func (m *MockRollupsMachine) WriteCheckpointHash(_ context.Context, _ machine.Ha return m.CheckpointHashError } -func (m *MockRollupsMachine) Advance(_ context.Context, _ []byte, _ bool) ( - bool, []machine.Output, []machine.Report, []machine.Hash, uint64, machine.Hash, error, -) { - return m.AdvanceAcceptedReturn, - m.AdvanceOutputsReturn, - m.AdvanceReportsReturn, - m.AdvanceLeafsReturn, - m.AdvanceRemainingReturn, - m.OutputsHashReturn, - m.AdvanceError +func (m *MockRollupsMachine) Advance(_ context.Context, _ []byte, _ bool) (*machine.AdvanceResponse, error) { + return &machine.AdvanceResponse{ + Accepted: m.AdvanceAcceptedReturn, + Outputs: m.AdvanceOutputsReturn, + Reports: m.AdvanceReportsReturn, + Hashes: m.AdvanceLeafsReturn, + RemainingCycles: m.AdvanceRemainingReturn, + OutputsHash: m.OutputsHashReturn, + }, m.AdvanceError } func (m *MockRollupsMachine) Inspect(_ context.Context, _ []byte) (bool, []machine.Report, error) { diff --git a/pkg/machine/implementation.go b/pkg/machine/implementation.go index 460624853..266b04d1c 100644 --- a/pkg/machine/implementation.go +++ b/pkg/machine/implementation.go @@ -175,22 +175,34 @@ func (m *machineImpl) WriteCheckpointHash(ctx context.Context, hash Hash) error } // Advance sends an input to the machine and processes it -func (m *machineImpl) Advance(ctx context.Context, input []byte, computeHashes bool) (bool, []Output, []Report, []Hash, uint64, Hash, error) { - outputsHash := Hash{} +func (m *machineImpl) Advance(ctx context.Context, input []byte, computeHashes bool) (*AdvanceResponse, error) { // TODO: return the exception reason accepted, outputs, reports, hashes, remaining, data, err := m.process(ctx, input, AdvanceStateRequest, computeHashes) if err != nil { - return accepted, outputs, reports, hashes, remaining, outputsHash, err + return &AdvanceResponse{ + Accepted: accepted, + Outputs: outputs, + Reports: reports, + Hashes: hashes, + RemainingCycles: remaining, + }, err + } + + resp := &AdvanceResponse{ + Accepted: accepted, + Outputs: outputs, + Reports: reports, + Hashes: hashes, + RemainingCycles: remaining, } if accepted { if length := len(data); length != HashSize { - err = fmt.Errorf("%w (it has %d bytes)", ErrHashLength, length) - return accepted, outputs, reports, hashes, remaining, outputsHash, err + return resp, fmt.Errorf("%w (it has %d bytes)", ErrHashLength, length) } - copy(outputsHash[:], data) + copy(resp.OutputsHash[:], data) } - return accepted, outputs, reports, hashes, remaining, outputsHash, nil + return resp, nil } // Inspect sends a query to the machine and returns the results diff --git a/pkg/machine/implementation_test.go b/pkg/machine/implementation_test.go index dae8aa1de..15dd160c2 100644 --- a/pkg/machine/implementation_test.go +++ b/pkg/machine/implementation_test.go @@ -223,12 +223,12 @@ func (s *ImplementationSuite) TestAdvance() { } input := []byte("test input") - accepted, outputs, reports, _, _, hash, err := machine.Advance(ctx, input, false) + resp, err := machine.Advance(ctx, input, false) require.NoError(err) - require.True(accepted) - require.Empty(outputs) - require.Empty(reports) - require.NotEqual(Hash{}, hash) + require.True(resp.Accepted) + require.Empty(resp.Outputs) + require.Empty(resp.Reports) + require.NotEqual(Hash{}, resp.OutputsHash) mockBackend.AssertExpectations(s.T()) // Test advance with rejection @@ -245,12 +245,12 @@ func (s *ImplementationSuite) TestAdvance() { AdvanceMaxDeadline: time.Second * 10, }, } - accepted, outputs, reports, _, _, hash, err = machine2.Advance(ctx, input, false) + resp, err = machine2.Advance(ctx, input, false) require.NoError(err) - require.False(accepted) - require.Empty(outputs) - require.Empty(reports) - require.Equal(Hash{}, hash) + require.False(resp.Accepted) + require.Empty(resp.Outputs) + require.Empty(resp.Reports) + require.Equal(Hash{}, resp.OutputsHash) mockBackend2.AssertExpectations(s.T()) // Test advance with exception @@ -267,10 +267,10 @@ func (s *ImplementationSuite) TestAdvance() { AdvanceMaxDeadline: time.Second * 10, }, } - accepted, outputs, reports, _, _, hash, err = machine3.Advance(ctx, input, false) + resp, err = machine3.Advance(ctx, input, false) require.ErrorIs(err, ErrException) - require.False(accepted) - require.Equal(Hash{}, hash) + require.False(resp.Accepted) + require.Equal(Hash{}, resp.OutputsHash) mockBackend3.AssertExpectations(s.T()) // Test advance with payload too large @@ -288,7 +288,7 @@ func (s *ImplementationSuite) TestAdvance() { }, } largeInput := make([]byte, 10) - _, _, _, _, _, _, err = machine4.Advance(ctx, largeInput, false) + _, err = machine4.Advance(ctx, largeInput, false) require.ErrorIs(err, ErrPayloadLengthLimitExceeded) mockBackend4.AssertExpectations(s.T()) @@ -311,7 +311,7 @@ func (s *ImplementationSuite) TestAdvance() { AdvanceMaxDeadline: time.Second * 10, }, } - _, _, _, _, _, _, err = machine5.Advance(ctx, input, false) + _, err = machine5.Advance(ctx, input, false) require.Error(err) require.ErrorIs(err, ErrHashLength) mockBackend5.AssertExpectations(s.T()) diff --git a/pkg/machine/machine.go b/pkg/machine/machine.go index d2b82a4f9..fe017bddb 100644 --- a/pkg/machine/machine.go +++ b/pkg/machine/machine.go @@ -27,6 +27,16 @@ type ( Hash = [HashSize]byte ) +// AdvanceResponse contains the result of an advance operation. +type AdvanceResponse struct { + Accepted bool + Outputs []Output + Reports []Report + Hashes []Hash + RemainingCycles uint64 + OutputsHash Hash +} + // Common errors var ( ErrMachineInternal = errors.New("machine internal error") @@ -60,10 +70,11 @@ type Machine interface { WriteCheckpointHash(ctx context.Context, hash Hash) error // Advance sends an input to the machine. - // It returns a boolean indicating whether or not the request was accepted. - // It also returns the corresponding outputs, reports, and the hash of the outputs. - // In case the request is not accepted, the function does not return outputs. - Advance(ctx context.Context, input []byte, computeHashes bool) (bool, []Output, []Report, []Hash, uint64, Hash, error) + // It always returns a non-nil AdvanceResponse, even on error paths. + // The response contains whether the request was accepted, + // the corresponding outputs, reports, and the hash of the outputs. + // In case the request is not accepted, the response does not contain outputs. + Advance(ctx context.Context, input []byte, computeHashes bool) (*AdvanceResponse, error) // Inspect sends a query to the machine. // It returns a boolean indicating whether or not the request was accepted diff --git a/pkg/machine/machine_test.go b/pkg/machine/machine_test.go index 868acd8cc..91ad8d255 100644 --- a/pkg/machine/machine_test.go +++ b/pkg/machine/machine_test.go @@ -214,15 +214,15 @@ func (s *MachineSuite) TestMachineInterface() { require.Equal(Hash{6, 7, 8, 9, 10}, outputsHash) // Test Advance - accepted, outputs, reports, _, _, advanceHash, err := machine.Advance(ctx, []byte("input"), false) + advanceResp, err := machine.Advance(ctx, []byte("input"), false) require.NoError(err) - require.True(accepted) - require.Len(outputs, 2) - require.Equal([]byte("output1"), outputs[0]) - require.Equal([]byte("output2"), outputs[1]) - require.Len(reports, 1) - require.Equal([]byte("report1"), reports[0]) - require.Equal(Hash{11, 12, 13, 14, 15}, advanceHash) + require.True(advanceResp.Accepted) + require.Len(advanceResp.Outputs, 2) + require.Equal([]byte("output1"), advanceResp.Outputs[0]) + require.Equal([]byte("output2"), advanceResp.Outputs[1]) + require.Len(advanceResp.Reports, 1) + require.Equal([]byte("report1"), advanceResp.Reports[0]) + require.Equal(Hash{11, 12, 13, 14, 15}, advanceResp.OutputsHash) // Test Inspect accepted, inspectReports, err := machine.Inspect(ctx, []byte("query")) @@ -279,7 +279,7 @@ func (s *MachineSuite) TestMachineInterfaceErrors() { require.Contains(err.Error(), "outputs hash error") // Test Advance error - _, _, _, _, _, _, err = machine.Advance(ctx, []byte("input"), false) + _, err = machine.Advance(ctx, []byte("input"), false) require.Error(err) require.Contains(err.Error(), "advance error") @@ -354,16 +354,15 @@ func (m *MockMachine) WriteCheckpointHash(_ context.Context, _ Hash) error { return m.CheckpointHashError } -func (m *MockMachine) Advance(_ context.Context, _ []byte, _ bool) ( - bool, []Output, []Report, []Hash, uint64, Hash, error, -) { - return m.AdvanceAcceptedReturn, - m.AdvanceOutputsReturn, - m.AdvanceReportsReturn, - m.AdvanceHashesReturn, - m.AdvanceRemainingReturn, - m.AdvanceHashReturn, - m.AdvanceError +func (m *MockMachine) Advance(_ context.Context, _ []byte, _ bool) (*AdvanceResponse, error) { + return &AdvanceResponse{ + Accepted: m.AdvanceAcceptedReturn, + Outputs: m.AdvanceOutputsReturn, + Reports: m.AdvanceReportsReturn, + Hashes: m.AdvanceHashesReturn, + RemainingCycles: m.AdvanceRemainingReturn, + OutputsHash: m.AdvanceHashReturn, + }, m.AdvanceError } func (m *MockMachine) Inspect(_ context.Context, From 885e62190e93f193ced9e113dc28d5532b52f713 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 18:10:37 -0300 Subject: [PATCH 07/17] fix(manager): prevent machine instance leak if addMachine fails --- internal/manager/manager.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 8662bf33b..ff4ff13e0 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -142,12 +142,20 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { m.logger.Error("Failed to synchronize machine", "application", app.IApplicationAddress, "error", err) - instance.Close() + if err := instance.Close(); err != nil { + m.logger.Warn("Failed to close machine after synchronization failure", + "application", app.Name, "error", err) + } continue } - // Add the machine to the manager - m.addMachine(app.ID, instance) + // Add the machine to the manager; close if it fails + if !m.addMachine(app.ID, instance) { + if err := instance.Close(); err != nil { + m.logger.Warn("Failed to close duplicate machine instance", + "application", app.Name, "error", err) + } + } } // Remove machines for disabled applications @@ -203,7 +211,10 @@ func (m *MachineManager) removeMachines(apps []*Application) { m.logger.Info("Application was disabled, shutting down machine", "application", machine.Application().Name) } - machine.Close() + if err := machine.Close(); err != nil { + m.logger.Warn("Failed to close machine for disabled application", + "application", machine.Application().Name, "error", err) + } delete(m.machines, id) } } From a625739efaaa6a3aabeffbb5e30d4fa911af24e3 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Mon, 2 Mar 2026 22:24:10 -0300 Subject: [PATCH 08/17] fix(advancer): bound memory usage with batched input fetching --- internal/advancer/advancer.go | 44 +++++++++++++------ internal/advancer/advancer_test.go | 63 ++++++++++++++++++++------ internal/advancer/service.go | 7 +++ internal/config/generate/Config.toml | 8 ++++ internal/config/generated.go | 39 ++++++++++++++++ internal/inspect/inspect_test.go | 2 +- internal/manager/instance.go | 4 +- internal/manager/instance_test.go | 66 +++++++++++++++++----------- internal/manager/manager.go | 28 ++++++------ internal/manager/manager_test.go | 18 ++++---- internal/manager/types.go | 2 +- 11 files changed, 200 insertions(+), 81 deletions(-) diff --git a/internal/advancer/advancer.go b/internal/advancer/advancer.go index 134fc56de..110cf70aa 100644 --- a/internal/advancer/advancer.go +++ b/internal/advancer/advancer.go @@ -45,10 +45,16 @@ func getUnprocessedEpochs(ctx context.Context, er AdvancerRepository, address st return er.ListEpochs(ctx, address, f, repository.Pagination{}, false) } -// getUnprocessedInputs retrieves inputs that haven't been processed yet -func getUnprocessedInputs(ctx context.Context, repo AdvancerRepository, appAddress string, epochIndex uint64) ([]*Input, uint64, error) { +// getUnprocessedInputs retrieves inputs that haven't been processed yet with pagination support. +func getUnprocessedInputs( + ctx context.Context, + repo AdvancerRepository, + appAddress string, + epochIndex uint64, + batchSize uint64, +) ([]*Input, uint64, error) { f := repository.InputFilter{Status: Pointer(InputCompletionStatus_None), EpochIndex: &epochIndex} - return repo.ListInputs(ctx, appAddress, f, repository.Pagination{}, false) + return repo.ListInputs(ctx, appAddress, f, repository.Pagination{Limit: batchSize}, false) } // Step performs one processing cycle of the advancer @@ -78,17 +84,7 @@ func (s *Service) Step(ctx context.Context) error { } for _, epoch := range epochs { - // Get unprocessed inputs for this application - s.Logger.Debug("Querying for unprocessed inputs", "application", app.Name, "epoch_index", epoch.Index) - inputs, _, err := getUnprocessedInputs(ctx, s.repository, appAddress, epoch.Index) - if err != nil { - return err - } - - // Process the inputs - s.Logger.Debug("Processing inputs", "application", app.Name, "epoch_index", epoch.Index, "count", len(inputs)) - err = s.processInputs(ctx, app, inputs) - if err != nil { + if err := s.processEpochInputs(ctx, app, epoch.Index); err != nil { return err } @@ -117,6 +113,26 @@ func (s *Service) Step(ctx context.Context) error { return nil } +// processEpochInputs fetches and processes unprocessed inputs for an epoch in batches. +// Processed inputs change status and drop out of the filter, so each batch fetches from offset 0. +func (s *Service) processEpochInputs(ctx context.Context, app *Application, epochIndex uint64) error { + appAddress := app.IApplicationAddress.String() + for { + inputs, _, err := getUnprocessedInputs(ctx, s.repository, appAddress, epochIndex, s.inputBatchSize) + if err != nil { + return err + } + if len(inputs) == 0 { + return nil + } + s.Logger.Debug("Processing inputs", + "application", app.Name, "epoch_index", epochIndex, "count", len(inputs)) + if err := s.processInputs(ctx, app, inputs); err != nil { + return err + } + } +} + func (s *Service) isAllEpochInputsProcessed(app *Application, epoch *Epoch) (bool, error) { // epoch has no inputs if epoch.InputIndexLowerBound == epoch.InputIndexUpperBound { diff --git a/internal/advancer/advancer_test.go b/internal/advancer/advancer_test.go index 186da3364..c879d5172 100644 --- a/internal/advancer/advancer_test.go +++ b/internal/advancer/advancer_test.go @@ -31,6 +31,7 @@ type AdvancerSuite struct{ suite.Suite } func newMockAdvancerService(machineManager *MockMachineManager, repo *MockRepository) (*Service, error) { s := &Service{ + inputBatchSize: 500, machineManager: machineManager, repository: repo, } @@ -84,9 +85,9 @@ func (s *AdvancerSuite) TestStep() { app2 := newMockMachine(2) machineManager.Map[1] = *app1 machineManager.Map[2] = *app2 + res0 := randomAdvanceResult(0) res1 := randomAdvanceResult(1) - res2 := randomAdvanceResult(2) - res3 := randomAdvanceResult(3) + res2 := randomAdvanceResult(0) repository := &MockRepository{ GetEpochsReturn: map[common.Address][]*Epoch{ @@ -99,11 +100,11 @@ func (s *AdvancerSuite) TestStep() { }, GetInputsReturn: map[common.Address][]*Input{ app1.Application.IApplicationAddress: { - newInput(app1.Application.ID, 0, 0, marshal(res1)), - newInput(app1.Application.ID, 0, 1, marshal(res2)), + newInput(app1.Application.ID, 0, 0, marshal(res0)), + newInput(app1.Application.ID, 0, 1, marshal(res1)), }, app2.Application.IApplicationAddress: { - newInput(app2.Application.ID, 0, 0, marshal(res3)), + newInput(app2.Application.ID, 0, 0, marshal(res2)), }, }, } @@ -124,7 +125,7 @@ func (s *AdvancerSuite) TestStep() { machineManager := newMockMachineManager() app1 := newMockMachine(1) machineManager.Map[1] = *app1 - res1 := randomAdvanceResult(1) + res0 := randomAdvanceResult(0) repository := &MockRepository{ GetEpochsReturn: map[common.Address][]*Epoch{ @@ -134,7 +135,7 @@ func (s *AdvancerSuite) TestStep() { }, GetInputsReturn: map[common.Address][]*Input{ app1.Application.IApplicationAddress: { - newInput(app1.Application.ID, 0, 0, marshal(res1)), + newInput(app1.Application.ID, 0, 0, marshal(res0)), }, }, UpdateEpochsError: errors.New("update epochs error"), @@ -225,13 +226,14 @@ func (s *AdvancerSuite) TestGetUnprocessedInputs() { newInput(app1.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), } - repository := &MockRepository{ + repo := &MockRepository{ GetInputsReturn: map[common.Address][]*Input{ app1.Application.IApplicationAddress: inputs, }, } - result, count, err := getUnprocessedInputs(context.Background(), repository, app1.Application.IApplicationAddress.String(), 0) + result, count, err := getUnprocessedInputs( + context.Background(), repo, app1.Application.IApplicationAddress.String(), 0, 500) require.Nil(err) require.Equal(uint64(2), count) require.Equal(inputs, result) @@ -241,11 +243,12 @@ func (s *AdvancerSuite) TestGetUnprocessedInputs() { require := s.Require() app1 := newMockMachine(1) - repository := &MockRepository{ + repo := &MockRepository{ GetInputsError: errors.New("list inputs error"), } - _, _, err := getUnprocessedInputs(context.Background(), repository, app1.Application.IApplicationAddress.String(), 0) + _, _, err := getUnprocessedInputs( + context.Background(), repo, app1.Application.IApplicationAddress.String(), 0, 500) require.Error(err) require.Contains(err.Error(), "list inputs error") }) @@ -653,7 +656,7 @@ func (m *MockMachineInstance) OutputsProof(ctx context.Context, processedInputs } // Synchronize implements the MachineInstance interface for testing -func (m *MockMachineInstance) Synchronize(ctx context.Context, repo manager.MachineRepository) error { +func (m *MockMachineInstance) Synchronize(ctx context.Context, repo manager.MachineRepository, batchSize uint64) error { // Not used in advancer tests, but needed to satisfy the interface return nil } @@ -743,7 +746,23 @@ func (mock *MockRepository) ListInputs( } address := common.HexToAddress(nameOrAddress) - return mock.GetInputsReturn[address], uint64(len(mock.GetInputsReturn[address])), mock.GetInputsError + inputs := mock.GetInputsReturn[address] + total := uint64(len(inputs)) + + // Apply pagination if a limit is set + if p.Limit > 0 { + start := p.Offset + if start >= total { + return nil, total, mock.GetInputsError + } + end := start + p.Limit + if end > total { + end = total + } + inputs = inputs[start:end] + } + + return inputs, total, mock.GetInputsError } func (mock *MockRepository) StoreAdvanceResult( @@ -767,6 +786,24 @@ func (mock *MockRepository) StoreAdvanceResult( } mock.StoredResults = append(mock.StoredResults, res) + + // Simulate real behavior: processed inputs change status and are no longer + // returned by queries filtering for unprocessed (Status_None) inputs. + // This prevents infinite loops in batched fetching. + if mock.StoreAdvanceError == nil { + for addr, inputs := range mock.GetInputsReturn { + for i, inp := range inputs { + if inp.EpochApplicationID == appID && inp.Index == res.InputIndex { + newInputs := make([]*Input, 0, len(inputs)-1) + newInputs = append(newInputs, inputs[:i]...) + newInputs = append(newInputs, inputs[i+1:]...) + mock.GetInputsReturn[addr] = newInputs + break + } + } + } + } + return mock.StoreAdvanceError } diff --git a/internal/advancer/service.go b/internal/advancer/service.go index 5574ba36e..0fa110135 100644 --- a/internal/advancer/service.go +++ b/internal/advancer/service.go @@ -24,6 +24,7 @@ const httpShutdownTimeout = 10 * time.Second //nolint: mnd // Service is the main advancer service that processes inputs through Cartesi machines type Service struct { service.Service + inputBatchSize uint64 snapshotsDir string repository AdvancerRepository machineManager manager.MachineProvider @@ -59,12 +60,18 @@ func Create(ctx context.Context, c *CreateInfo) (*Service, error) { return nil, fmt.Errorf("repository on advancer service Create is nil") } + s.inputBatchSize = c.Config.AdvancerInputBatchSize + if s.inputBatchSize == 0 { + return nil, fmt.Errorf("advancer input batch size must be greater than 0") + } + // Create the machine manager manager := manager.NewMachineManager( ctx, c.Repository, s.Logger, c.Config.FeatureMachineHashCheckEnabled, + s.inputBatchSize, ) s.machineManager = manager diff --git a/internal/config/generate/Config.toml b/internal/config/generate/Config.toml index 3233aedbc..daccd1e09 100644 --- a/internal/config/generate/Config.toml +++ b/internal/config/generate/Config.toml @@ -61,6 +61,14 @@ used-by = ["advancer", "node", "cli"] # Rollups # +[rollups.CARTESI_ADVANCER_INPUT_BATCH_SIZE] +default = "500" +go-type = "uint64" +description = """ +Maximum number of inputs fetched per database query during advance processing and machine +synchronization. Bounds memory usage for applications with large backlogs of unprocessed inputs.""" +used-by = ["advancer", "node"] + [rollups.CARTESI_ADVANCER_POLLING_INTERVAL] default = "3" go-type = "Duration" diff --git a/internal/config/generated.go b/internal/config/generated.go index 1fb1db0f7..2ac966345 100644 --- a/internal/config/generated.go +++ b/internal/config/generated.go @@ -53,6 +53,7 @@ const ( LOG_COLOR = "CARTESI_LOG_COLOR" LOG_LEVEL = "CARTESI_LOG_LEVEL" JSONRPC_MACHINE_LOG_LEVEL = "CARTESI_JSONRPC_MACHINE_LOG_LEVEL" + ADVANCER_INPUT_BATCH_SIZE = "CARTESI_ADVANCER_INPUT_BATCH_SIZE" ADVANCER_POLLING_INTERVAL = "CARTESI_ADVANCER_POLLING_INTERVAL" BLOCKCHAIN_HTTP_MAX_RETRIES = "CARTESI_BLOCKCHAIN_HTTP_MAX_RETRIES" BLOCKCHAIN_HTTP_RETRY_MAX_WAIT = "CARTESI_BLOCKCHAIN_HTTP_RETRY_MAX_WAIT" @@ -143,6 +144,8 @@ func SetDefaults() { viper.SetDefault(JSONRPC_MACHINE_LOG_LEVEL, "info") + viper.SetDefault(ADVANCER_INPUT_BATCH_SIZE, "500") + viper.SetDefault(ADVANCER_POLLING_INTERVAL, "3") viper.SetDefault(BLOCKCHAIN_HTTP_MAX_RETRIES, "4") @@ -201,6 +204,10 @@ type AdvancerConfig struct { // One of "trace", "debug", "info", "warning", "error", "fatal". JsonrpcMachineLogLevel string `mapstructure:"CARTESI_JSONRPC_MACHINE_LOG_LEVEL"` + // Maximum number of inputs fetched per database query during advance processing and machine + // synchronization. Bounds memory usage for applications with large backlogs of unprocessed inputs. + AdvancerInputBatchSize uint64 `mapstructure:"CARTESI_ADVANCER_INPUT_BATCH_SIZE"` + // How many seconds the node will wait before querying the database for new inputs. AdvancerPollingInterval Duration `mapstructure:"CARTESI_ADVANCER_POLLING_INTERVAL"` @@ -283,6 +290,13 @@ func LoadAdvancerConfig() (*AdvancerConfig, error) { return nil, fmt.Errorf("CARTESI_JSONRPC_MACHINE_LOG_LEVEL is required for the advancer service: %w", err) } + cfg.AdvancerInputBatchSize, err = GetAdvancerInputBatchSize() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_INPUT_BATCH_SIZE: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_ADVANCER_INPUT_BATCH_SIZE is required for the advancer service: %w", err) + } + cfg.AdvancerPollingInterval, err = GetAdvancerPollingInterval() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_POLLING_INTERVAL: %w", err) @@ -815,6 +829,10 @@ type NodeConfig struct { // One of "trace", "debug", "info", "warning", "error", "fatal". JsonrpcMachineLogLevel string `mapstructure:"CARTESI_JSONRPC_MACHINE_LOG_LEVEL"` + // Maximum number of inputs fetched per database query during advance processing and machine + // synchronization. Bounds memory usage for applications with large backlogs of unprocessed inputs. + AdvancerInputBatchSize uint64 `mapstructure:"CARTESI_ADVANCER_INPUT_BATCH_SIZE"` + // How many seconds the node will wait before querying the database for new inputs. AdvancerPollingInterval Duration `mapstructure:"CARTESI_ADVANCER_POLLING_INTERVAL"` @@ -981,6 +999,13 @@ func LoadNodeConfig() (*NodeConfig, error) { return nil, fmt.Errorf("CARTESI_JSONRPC_MACHINE_LOG_LEVEL is required for the node service: %w", err) } + cfg.AdvancerInputBatchSize, err = GetAdvancerInputBatchSize() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_INPUT_BATCH_SIZE: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_ADVANCER_INPUT_BATCH_SIZE is required for the node service: %w", err) + } + cfg.AdvancerPollingInterval, err = GetAdvancerPollingInterval() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_POLLING_INTERVAL: %w", err) @@ -1347,6 +1372,7 @@ func (c *NodeConfig) ToAdvancerConfig() *AdvancerConfig { LogColor: c.LogColor, LogLevel: c.LogLevel, JsonrpcMachineLogLevel: c.JsonrpcMachineLogLevel, + AdvancerInputBatchSize: c.AdvancerInputBatchSize, AdvancerPollingInterval: c.AdvancerPollingInterval, MaxStartupTime: c.MaxStartupTime, SnapshotsDir: c.SnapshotsDir, @@ -1891,6 +1917,19 @@ func GetJsonrpcMachineLogLevel() (string, error) { return notDefinedstring(), fmt.Errorf("%s: %w", JSONRPC_MACHINE_LOG_LEVEL, ErrNotDefined) } +// GetAdvancerInputBatchSize returns the value for the environment variable CARTESI_ADVANCER_INPUT_BATCH_SIZE. +func GetAdvancerInputBatchSize() (uint64, error) { + s := viper.GetString(ADVANCER_INPUT_BATCH_SIZE) + if s != "" { + v, err := toUint64(s) + if err != nil { + return v, fmt.Errorf("failed to parse %s: %w", ADVANCER_INPUT_BATCH_SIZE, err) + } + return v, nil + } + return notDefineduint64(), fmt.Errorf("%s: %w", ADVANCER_INPUT_BATCH_SIZE, ErrNotDefined) +} + // GetAdvancerPollingInterval returns the value for the environment variable CARTESI_ADVANCER_POLLING_INTERVAL. func GetAdvancerPollingInterval() (Duration, error) { s := viper.GetString(ADVANCER_POLLING_INTERVAL) diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index 983c951c4..68aa4f80a 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -238,7 +238,7 @@ func (m *MockMachine) OutputsProof(ctx context.Context, processedInputs uint64) return nil, nil } -func (mock *MockMachine) Synchronize(ctx context.Context, repo manager.MachineRepository) error { +func (mock *MockMachine) Synchronize(ctx context.Context, repo manager.MachineRepository, batchSize uint64) error { // Not used in inspect tests, but needed to satisfy the interface return nil } diff --git a/internal/manager/instance.go b/internal/manager/instance.go index 0dba20be4..af679c423 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -163,7 +163,7 @@ func (m *MachineInstanceImpl) ProcessedInputs() uint64 { // It handles both template-based instances (processedInputs == 0, replays all) // and snapshot-based instances (processedInputs > 0, replays only remaining). // Inputs are fetched in batches to bound memory usage. -func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepository) error { +func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepository, batchSize uint64) error { appAddress := m.application.IApplicationAddress.String() m.logger.Info("Synchronizing machine with processed inputs", "address", appAddress, @@ -176,7 +176,7 @@ func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepos for { p := repository.Pagination{ - Limit: inputBatchSize, + Limit: batchSize, Offset: initialProcessedInputs + replayed, } inputs, totalCount, err := getProcessedInputs(ctx, repo, appAddress, p) diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 92a8bda81..0f51fc412 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -816,13 +816,22 @@ func (r *mockSyncRepository) GetLastSnapshot( return nil, nil } +// newForkableMock creates a mock where Fork returns a fresh mock each time, +// properly exercising the fork/replace lifecycle in Synchronize tests. +func newForkableMock() *MockRollupsMachine { + m := &MockRollupsMachine{} + m.CloseError = nil + m.AdvanceAcceptedReturn = true + m.HashReturn = newHash(1) + m.OutputsHashReturn = newHash(2) + m.ForkFunc = func(_ context.Context) (machine.Machine, error) { + return newForkableMock(), nil + } + return m +} + func (s *MachineInstanceSuite) newSyncMachine(processedInputs uint64, appProcessedInputs uint64) *MachineInstanceImpl { - runtime := &MockRollupsMachine{} - runtime.ForkReturn = runtime // self-fork for replay - runtime.CloseError = nil - runtime.AdvanceAcceptedReturn = true - runtime.HashReturn = newHash(1) - runtime.OutputsHashReturn = newHash(2) + runtime := newForkableMock() return &MachineInstanceImpl{ application: &model.Application{ @@ -860,14 +869,17 @@ func (s *MachineInstanceSuite) TestSynchronize() { s.Run("TemplateSyncAllInputs", func() { require := s.Require() inst := s.newSyncMachine(0, 3) + originalRuntime := inst.runtime repo := &mockSyncRepository{ inputs: makeInputs(0, 3), totalCount: 3, } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) require.Equal(uint64(3), inst.processedInputs) + // Verify the runtime was actually replaced (not self-fork) + require.NotSame(originalRuntime, inst.runtime) }) s.Run("SnapshotSyncRemainingInputs", func() { @@ -879,7 +891,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { totalCount: 5, } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) require.Equal(uint64(5), inst.processedInputs) }) @@ -892,7 +904,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { totalCount: 0, } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) require.Equal(uint64(0), inst.processedInputs) }) @@ -906,7 +918,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { totalCount: 5, } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) require.Equal(uint64(5), inst.processedInputs) }) @@ -919,7 +931,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { totalCount: 3, // DB says 3 but app expects 5 } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 1000) require.Error(err) require.ErrorIs(err, ErrMachineSynchronization) require.Contains(err.Error(), "count mismatch") @@ -933,7 +945,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { listErr: listErr, } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 1000) require.Error(err) require.ErrorIs(err, ErrMachineSynchronization) require.Contains(err.Error(), "database connection lost") @@ -942,18 +954,20 @@ func (s *MachineInstanceSuite) TestSynchronize() { s.Run("AdvanceErrorMidReplay", func() { require := s.Require() inst := s.newSyncMachine(0, 3) - // Make the machine return a hard error on Advance. - // The self-forking mock is stateless, so the error applies - // to the very first input. + // Make each fork return a hard error on Advance. runtime := inst.runtime.(*MockRollupsMachine) - runtime.AdvanceError = errors.New("machine exploded") + runtime.ForkFunc = func(_ context.Context) (machine.Machine, error) { + fork := newForkableMock() + fork.AdvanceError = errors.New("machine exploded") + return fork, nil + } repo := &mockSyncRepository{ inputs: makeInputs(0, 3), totalCount: 3, } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 1000) require.Error(err) require.ErrorIs(err, ErrMachineSynchronization) require.Contains(err.Error(), "failed to replay input") @@ -961,19 +975,15 @@ func (s *MachineInstanceSuite) TestSynchronize() { s.Run("BatchBoundaryCrossing", func() { require := s.Require() - // Create more inputs than the batch size to test pagination. - // We use a small batch size by creating an instance and repo - // with enough inputs to cross the default 1000-input batch. - // Instead, we'll test with a practical scenario: 3 inputs - // with batch size effectively 2 by making our mock return - // only 2 inputs per call based on pagination. + // Use batchSize=2 with 3 inputs so the loop must fetch two batches + // (batch 1: inputs 0-1, batch 2: input 2), exercising pagination. inst := s.newSyncMachine(0, 3) repo := &mockSyncRepository{ inputs: makeInputs(0, 3), totalCount: 3, } - err := inst.Synchronize(context.Background(), repo) + err := inst.Synchronize(context.Background(), repo, 2) require.NoError(err) require.Equal(uint64(3), inst.processedInputs) }) @@ -990,7 +1000,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { totalCount: 3, } - err := inst.Synchronize(ctx, repo) + err := inst.Synchronize(ctx, repo, 1000) require.Error(err) }) } @@ -999,6 +1009,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { type MockRollupsMachine struct { ForkReturn machine.Machine + ForkFunc func(context.Context) (machine.Machine, error) ForkError error HashReturn machine.Hash @@ -1026,7 +1037,10 @@ type MockRollupsMachine struct { CloseError error } -func (m *MockRollupsMachine) Fork(_ context.Context) (machine.Machine, error) { +func (m *MockRollupsMachine) Fork(ctx context.Context) (machine.Machine, error) { + if m.ForkFunc != nil { + return m.ForkFunc(ctx) + } return m.ForkReturn, m.ForkError } diff --git a/internal/manager/manager.go b/internal/manager/manager.go index ff4ff13e0..0123ab8ff 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -21,11 +21,6 @@ var ( ErrMachineSynchronization = errors.New("failed to synchronize machine") ) -// inputBatchSize is the maximum number of inputs fetched per database query -// during synchronization and snapshot replay. This bounds memory usage for -// applications with large numbers of processed inputs. -const inputBatchSize uint64 = 1000 - // MachineRepository defines the repository interface needed by the MachineManager type MachineRepository interface { // ListApplications retrieves applications based on filter criteria @@ -40,11 +35,12 @@ type MachineRepository interface { // MachineManager manages the lifecycle of machine instances for applications type MachineManager struct { - mutex sync.RWMutex - machines map[int64]MachineInstance - repository MachineRepository - checkHash bool - logger *slog.Logger + mutex sync.RWMutex + machines map[int64]MachineInstance + repository MachineRepository + checkHash bool + inputBatchSize uint64 + logger *slog.Logger } // NewMachineManager creates a new machine manager @@ -53,12 +49,14 @@ func NewMachineManager( repo MachineRepository, logger *slog.Logger, checkHash bool, + inputBatchSize uint64, ) *MachineManager { return &MachineManager{ - machines: map[int64]MachineInstance{}, - repository: repo, - checkHash: checkHash, - logger: logger, + machines: map[int64]MachineInstance{}, + repository: repo, + checkHash: checkHash, + inputBatchSize: inputBatchSize, + logger: logger, } } @@ -137,7 +135,7 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { // For template instances (processedInputs=0) this replays all inputs. // For snapshot instances (processedInputs=snapshotIndex+1) this replays // only inputs after the snapshot. - err = instance.Synchronize(ctx, m.repository) + err = instance.Synchronize(ctx, m.repository, m.inputBatchSize) if err != nil { m.logger.Error("Failed to synchronize machine", "application", app.IApplicationAddress, diff --git a/internal/manager/manager_test.go b/internal/manager/manager_test.go index f5ef13e3e..ab5ecddb6 100644 --- a/internal/manager/manager_test.go +++ b/internal/manager/manager_test.go @@ -28,7 +28,7 @@ func (s *MachineManagerSuite) TestNewMachineManager() { require := s.Require() repo := &MockMachineRepository{} testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false) + manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) require.NotNil(manager) require.Empty(manager.machines) require.Equal(repo, manager.repository) @@ -65,7 +65,7 @@ func (s *MachineManagerSuite) TestUpdateMachines() { // Create manager with a test logger testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false) + manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) // Create a mock factory for testing mockRuntime := &MockRollupsMachine{} @@ -96,7 +96,7 @@ func (s *MachineManagerSuite) TestUpdateMachines() { // Create a test logger testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false) + manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) // Add mock machines app1 := &model.Application{ID: 1, Name: "App1"} @@ -129,7 +129,7 @@ func (s *MachineManagerSuite) TestGetMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(context.Background(), repo, nil, false, 500) machine := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} // Add a machine @@ -152,7 +152,7 @@ func (s *MachineManagerSuite) TestHasMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(context.Background(), repo, nil, false, 500) machine := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} // Add a machine @@ -172,7 +172,7 @@ func (s *MachineManagerSuite) TestAddMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(context.Background(), repo, nil, false, 500) machine1 := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} machine2 := &DummyMachineInstanceMock{application: &model.Application{ID: 2}} @@ -195,7 +195,7 @@ func (s *MachineManagerSuite) TestAddMachine() { func (s *MachineManagerSuite) TestRemoveDisabledMachines() { require := s.Require() - manager := NewMachineManager(context.Background(), nil, nil, false) + manager := NewMachineManager(context.Background(), nil, nil, false, 500) // Add machines app1 := &model.Application{ID: 1} @@ -227,7 +227,7 @@ func (s *MachineManagerSuite) TestApplications() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(context.Background(), repo, nil, false, 500) // Add machines app1 := &model.Application{ID: 1, Name: "App1"} @@ -318,7 +318,7 @@ func (m *DummyMachineInstanceMock) Inspect(_ context.Context, _ []byte) (*model. return nil, nil } -func (m *DummyMachineInstanceMock) Synchronize(_ context.Context, _ MachineRepository) error { +func (m *DummyMachineInstanceMock) Synchronize(_ context.Context, _ MachineRepository, _ uint64) error { return nil } diff --git a/internal/manager/types.go b/internal/manager/types.go index dee93beed..c716abbdc 100644 --- a/internal/manager/types.go +++ b/internal/manager/types.go @@ -14,7 +14,7 @@ type MachineInstance interface { Application() *Application Advance(ctx context.Context, input []byte, epochIndex uint64, inputIndex uint64, computeHashes bool) (*AdvanceResult, error) Inspect(ctx context.Context, query []byte) (*InspectResult, error) - Synchronize(ctx context.Context, repo MachineRepository) error + Synchronize(ctx context.Context, repo MachineRepository, batchSize uint64) error CreateSnapshot(ctx context.Context, processedInputs uint64, path string) error ProcessedInputs() uint64 Hash(ctx context.Context) ([32]byte, error) From a031a93c9baa1a64734d0c4eb5c4c64250b289d8 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Tue, 3 Mar 2026 11:37:26 -0300 Subject: [PATCH 09/17] fix(manager): prevent fork leak and inconsistent state on machine close failure --- internal/manager/instance.go | 58 ++++++++++++++++--------- internal/manager/instance_test.go | 71 +++++++++++++++++++------------ 2 files changed, 80 insertions(+), 49 deletions(-) diff --git a/internal/manager/instance.go b/internal/manager/instance.go index af679c423..e95f980f3 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -10,6 +10,7 @@ import ( "fmt" "log/slog" "sync" + "sync/atomic" "time" "github.com/cartesi/rollups-node/internal/manager/pmutex" @@ -46,8 +47,10 @@ type MachineInstanceImpl struct { application *Application runtime machine.Machine - // How many inputs were processed by the machine - processedInputs uint64 + // How many inputs were processed by the machine. + // Written under HLock (together with runtime swap — the two MUST be updated + // atomically from the perspective of readers). Read without locks via Load(). + processedInputs atomic.Uint64 // Timeouts for operations advanceTimeout time.Duration @@ -138,7 +141,6 @@ func NewMachineInstanceWithFactory( instance := &MachineInstanceImpl{ application: app, runtime: runtime, - processedInputs: processedInputs, advanceTimeout: app.ExecutionParameters.AdvanceMaxDeadline, inspectTimeout: app.ExecutionParameters.InspectMaxDeadline, maxConcurrentInspects: app.ExecutionParameters.MaxConcurrentInspects, @@ -147,6 +149,7 @@ func NewMachineInstanceWithFactory( runtimeFactory: factory, logger: logger.With("application", app.Name), } + instance.processedInputs.Store(processedInputs) return instance, nil } @@ -156,7 +159,7 @@ func (m *MachineInstanceImpl) Application() *Application { } func (m *MachineInstanceImpl) ProcessedInputs() uint64 { - return m.processedInputs + return m.processedInputs.Load() } // Synchronize brings the machine up to date with processed inputs. @@ -165,12 +168,13 @@ func (m *MachineInstanceImpl) ProcessedInputs() uint64 { // Inputs are fetched in batches to bound memory usage. func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepository, batchSize uint64) error { appAddress := m.application.IApplicationAddress.String() + currentProcessed := m.processedInputs.Load() m.logger.Info("Synchronizing machine with processed inputs", "address", appAddress, "app_processed_inputs", m.application.ProcessedInputs, - "machine_processed_inputs", m.processedInputs) + "machine_processed_inputs", currentProcessed) - initialProcessedInputs := m.processedInputs + initialProcessedInputs := currentProcessed replayed := uint64(0) toReplay := uint64(0) @@ -193,7 +197,12 @@ func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepos m.logger.Error(errorMsg, "address", appAddress) return fmt.Errorf("%w: %s", ErrMachineSynchronization, errorMsg) } - toReplay = totalCount - m.processedInputs + if currentProcessed > totalCount { + return fmt.Errorf( + "%w: machine has processed %d inputs but DB only has %d", + ErrMachineSynchronization, currentProcessed, totalCount) + } + toReplay = totalCount - currentProcessed if toReplay == 0 { m.logger.Info("No inputs to replay during synchronization", "address", appAddress) @@ -235,8 +244,10 @@ func (m *MachineInstanceImpl) forkForAdvance(ctx context.Context, index uint64) } // Verify input index - if m.processedInputs != index { - return nil, fmt.Errorf("%w: processed inputs is %d and index is %d", ErrInvalidInputIndex, m.processedInputs, index) + current := m.processedInputs.Load() + if current != index { + return nil, fmt.Errorf("%w: processed inputs is %d and index is %d", + ErrInvalidInputIndex, current, index) } // Fork the machine @@ -320,13 +331,14 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn // Replace the current machine with the fork m.mutex.HLock() - if err = m.runtime.Close(); err != nil { - m.mutex.Unlock() - return nil, err - } + oldRuntime := m.runtime m.runtime = fork - m.processedInputs++ + m.processedInputs.Add(1) m.mutex.Unlock() + + if err := oldRuntime.Close(); err != nil { + m.logger.Warn("Failed to close old machine runtime", "error", err) + } } else { // Use the previous state for rejected inputs result.MachineHash = prevMachineHash @@ -334,15 +346,17 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn result.OutputsHashProof = prevOutputsHashProof // Close the fork since we're not using it - err = fork.Close() + if err := fork.Close(); err != nil { + m.logger.Warn("Failed to close fork machine runtime", "error", err) + } // Update the processed inputs counter m.mutex.HLock() - m.processedInputs++ + m.processedInputs.Add(1) m.mutex.Unlock() } - return result, err + return result, nil } // forkForInspect creates a copy of the machine for inspect operations @@ -361,7 +375,7 @@ func (m *MachineInstanceImpl) forkForInspect(ctx context.Context) (machine.Machi return nil, 0, err } - return fork, m.processedInputs, nil + return fork, m.processedInputs.Load(), nil } // Inspect queries the machine state without modifying it @@ -423,11 +437,13 @@ func (m *MachineInstanceImpl) CreateSnapshot(ctx context.Context, processedInput } // Verify processed inputs - if m.processedInputs != processedInputs { - return fmt.Errorf("%w: machine processed inputs is %d and expected is %d", ErrInvalidSnapshotPoint, m.processedInputs, processedInputs) + current := m.processedInputs.Load() + if current != processedInputs { + return fmt.Errorf("%w: machine processed inputs is %d and expected is %d", + ErrInvalidSnapshotPoint, current, processedInputs) } - m.logger.Debug("Creating machine snapshot", "path", path, "processed_inputs", m.processedInputs) + m.logger.Debug("Creating machine snapshot", "path", path, "processed_inputs", current) // Create a context with a timeout for the store operation storeCtx, cancel := context.WithTimeout(ctx, m.application.ExecutionParameters.StoreDeadline) diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 0f51fc412..f93c8a339 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -221,7 +221,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Equal(expectedReports1, res.Reports) require.Equal(newHash(1), res.OutputsHash) require.Equal(newHash(2), res.MachineHash) - require.Equal(uint64(6), machine.processedInputs) + require.Equal(uint64(6), machine.processedInputs.Load()) }) s.Run("Reject", func() { @@ -240,7 +240,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Equal(expectedReports1, res.Reports) require.Equal(newHash(1), res.OutputsHash) require.Equal(newHash(2), res.MachineHash) - require.Equal(uint64(6), machine.processedInputs) + require.Equal(uint64(6), machine.processedInputs.Load()) }) testSoftError := func(name string, err error, status model.InputCompletionStatus) { @@ -259,7 +259,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Equal(expectedReports1, res.Reports) require.Equal(newHash(1), res.OutputsHash) require.Equal(newHash(2), res.MachineHash) - require.Equal(uint64(6), machine.processedInputs) + require.Equal(uint64(6), machine.processedInputs.Load()) }) } @@ -299,7 +299,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Error(err) require.Nil(res) require.Equal(errFork, err) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("Advance", func() { @@ -314,7 +314,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Nil(res) require.ErrorIs(err, errAdvance) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("AdvanceAndClose", func() { @@ -332,7 +332,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.ErrorIs(err, errAdvance) require.ErrorIs(err, errClose) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("Hash", func() { @@ -347,7 +347,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Nil(res) require.ErrorIs(err, errHash) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("HashAndClose", func() { @@ -365,7 +365,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.ErrorIs(err, errHash) require.ErrorIs(err, errClose) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("Close", func() { @@ -375,27 +375,26 @@ func (s *MachineInstanceSuite) TestAdvance() { errClose := errors.New("Close error") inner.CloseError = errClose + // Close error on old runtime is logged, not propagated. + // Advance succeeds and processedInputs is incremented. res, err := machine.Advance(context.Background(), []byte{}, 0, 5, false) - require.Error(err) - require.Nil(res) - require.ErrorIs(err, errClose) - require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.NoError(err) + require.NotNil(res) + require.Equal(uint64(6), machine.processedInputs.Load()) }) s.Run("Fork", func() { require := s.Require() _, fork, machineInst := s.setupAdvance() - errClose := errors.New("Close error") fork.AdvanceError = machine.ErrException - fork.CloseError = errClose + fork.CloseError = errors.New("Close error") + // Close error on fork is logged, not propagated. + // Advance succeeds and processedInputs is incremented. res, err := machineInst.Advance(context.Background(), []byte{}, 0, 5, false) - require.Error(err) + require.NoError(err) require.NotNil(res) - require.ErrorIs(err, errClose) - require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(6), machineInst.processedInputs) + require.Equal(uint64(6), machineInst.processedInputs.Load()) }) }) }) @@ -654,7 +653,6 @@ func (s *MachineInstanceSuite) setupAdvance() (*MockRollupsMachine, *MockRollups machineInst := &MachineInstanceImpl{ application: app, runtime: inner, - processedInputs: 5, advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, @@ -662,6 +660,7 @@ func (s *MachineInstanceSuite) setupAdvance() (*MockRollupsMachine, *MockRollups inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } + machineInst.processedInputs.Store(5) fork := &MockRollupsMachine{} @@ -710,7 +709,6 @@ func (s *MachineInstanceSuite) setupInspect() (*MockRollupsMachine, *MockRollups machineInst := &MachineInstanceImpl{ application: app, runtime: inner, - processedInputs: 55, advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, @@ -718,6 +716,7 @@ func (s *MachineInstanceSuite) setupInspect() (*MockRollupsMachine, *MockRollups inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } + machineInst.processedInputs.Store(55) fork := &MockRollupsMachine{} @@ -833,7 +832,7 @@ func newForkableMock() *MockRollupsMachine { func (s *MachineInstanceSuite) newSyncMachine(processedInputs uint64, appProcessedInputs uint64) *MachineInstanceImpl { runtime := newForkableMock() - return &MachineInstanceImpl{ + inst := &MachineInstanceImpl{ application: &model.Application{ ProcessedInputs: appProcessedInputs, ExecutionParameters: model.ExecutionParameters{ @@ -843,7 +842,6 @@ func (s *MachineInstanceSuite) newSyncMachine(processedInputs uint64, appProcess }, }, runtime: runtime, - processedInputs: processedInputs, advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, @@ -851,6 +849,8 @@ func (s *MachineInstanceSuite) newSyncMachine(processedInputs uint64, appProcess inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } + inst.processedInputs.Store(processedInputs) + return inst } func makeInputs(startIndex, count uint64) []*model.Input { @@ -877,7 +877,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) - require.Equal(uint64(3), inst.processedInputs) + require.Equal(uint64(3), inst.processedInputs.Load()) // Verify the runtime was actually replaced (not self-fork) require.NotSame(originalRuntime, inst.runtime) }) @@ -893,7 +893,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) - require.Equal(uint64(5), inst.processedInputs) + require.Equal(uint64(5), inst.processedInputs.Load()) }) s.Run("NoInputsToReplay", func() { @@ -906,7 +906,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) - require.Equal(uint64(0), inst.processedInputs) + require.Equal(uint64(0), inst.processedInputs.Load()) }) s.Run("SnapshotAlreadyCaughtUp", func() { @@ -920,7 +920,22 @@ func (s *MachineInstanceSuite) TestSynchronize() { err := inst.Synchronize(context.Background(), repo, 1000) require.NoError(err) - require.Equal(uint64(5), inst.processedInputs) + require.Equal(uint64(5), inst.processedInputs.Load()) + }) + + s.Run("MachineAheadOfDB", func() { + require := s.Require() + // Machine has processed 5 inputs but DB only has 3 + inst := s.newSyncMachine(5, 3) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "machine has processed 5 inputs but DB only has 3") }) s.Run("CountMismatch", func() { @@ -985,7 +1000,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { err := inst.Synchronize(context.Background(), repo, 2) require.NoError(err) - require.Equal(uint64(3), inst.processedInputs) + require.Equal(uint64(3), inst.processedInputs.Load()) }) s.Run("ContextCancellation", func() { From 7351f33995ec9a49719cd0fceb02a664df0b3901 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:09:36 -0300 Subject: [PATCH 10/17] fix(manager): add timeout to Close to prevent indefinite blocking on stuck inspects --- internal/manager/instance.go | 23 +++++++++++++++++--- internal/manager/instance_test.go | 35 +++++++++++++++++++++++++++++++ internal/manager/manager.go | 32 ++++++++++++++++++++-------- 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/internal/manager/instance.go b/internal/manager/instance.go index e95f980f3..1f52013cd 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -62,6 +62,9 @@ type MachineInstanceImpl struct { advanceMutex sync.Mutex inspectSemaphore *semaphore.Weighted + // Timeout for draining in-flight inspects during Close + closeTimeout time.Duration + // Factory for creating machine runtimes runtimeFactory MachineRuntimeFactory @@ -146,6 +149,7 @@ func NewMachineInstanceWithFactory( maxConcurrentInspects: app.ExecutionParameters.MaxConcurrentInspects, mutex: pmutex.New(), inspectSemaphore: semaphore.NewWeighted(int64(app.ExecutionParameters.MaxConcurrentInspects)), + closeTimeout: defaultCloseTimeout, runtimeFactory: factory, logger: logger.With("application", app.Name), } @@ -533,17 +537,30 @@ func (m *MachineInstanceImpl) OutputsProof(ctx context.Context, processedInputs return proof, nil } +// defaultCloseTimeout is how long Close waits for in-flight inspects to drain +// before forcibly closing the runtime. +const defaultCloseTimeout = 30 * time.Second + // Close shuts down the machine instance func (m *MachineInstanceImpl) Close() error { // Acquire all locks to ensure no operations are in progress m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), m.closeTimeout) + defer cancel() + + acquired := 0 for range int(m.maxConcurrentInspects) { - _ = m.inspectSemaphore.Acquire(ctx, 1) - defer m.inspectSemaphore.Release(1) + if err := m.inspectSemaphore.Acquire(ctx, 1); err != nil { + m.logger.Warn("Timed out waiting for in-flight inspects to drain; closing anyway", + "still_in_flight", int(m.maxConcurrentInspects)-acquired, + "drained", acquired) + break + } + acquired++ } + defer m.inspectSemaphore.Release(int64(acquired)) // Close the runtime m.mutex.HLock() diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index f93c8a339..16431cc92 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -618,6 +618,38 @@ func (s *MachineInstanceSuite) TestClose() { require.Fail("Advance did not complete after Close") } }) + + s.Run("TimesOutWaitingForInspects", func() { + require := s.Require() + inner, _, machine := s.setupAdvance() + inner.CloseError = nil + + // Use a short timeout so the test runs fast + machine.closeTimeout = centisecond + + // Pre-acquire all semaphore slots to simulate stuck inspects + for range int(machine.maxConcurrentInspects) { + err := machine.inspectSemaphore.Acquire(context.Background(), 1) + require.Nil(err) + } + + // Close should not block indefinitely — it times out and closes anyway + done := make(chan error, 1) + go func() { + done <- machine.Close() + }() + + select { + case err := <-done: + require.Nil(err) + require.Nil(machine.runtime) + case <-time.After(decisecond * 5): + require.Fail("Close blocked indefinitely despite timeout") + } + + // Release slots to clean up + machine.inspectSemaphore.Release(int64(machine.maxConcurrentInspects)) + }) } // ------------------------------------------------------------------------------------------------ @@ -656,6 +688,7 @@ func (s *MachineInstanceSuite) setupAdvance() (*MockRollupsMachine, *MockRollups advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, mutex: pmutex.New(), inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), @@ -712,6 +745,7 @@ func (s *MachineInstanceSuite) setupInspect() (*MockRollupsMachine, *MockRollups advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, mutex: pmutex.New(), inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), @@ -845,6 +879,7 @@ func (s *MachineInstanceSuite) newSyncMachine(processedInputs uint64, appProcess advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, mutex: pmutex.New(), inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 0123ab8ff..94181372c 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -230,23 +230,37 @@ func (m *MachineManager) Applications() []*Application { return apps } -// Close shuts down all machine instances +// Close shuts down all machine instances in parallel. func (m *MachineManager) Close() error { m.mutex.Lock() defer m.mutex.Unlock() - var errs []error + type closeResult struct { + id int64 + err error + } + + var wg sync.WaitGroup + results := make(chan closeResult, len(m.machines)) + for id, machine := range m.machines { - if err := machine.Close(); err != nil { - errs = append(errs, fmt.Errorf("failed to close machine for app %d: %w", id, err)) - } - delete(m.machines, id) + wg.Go(func() { + results <- closeResult{id: id, err: machine.Close()} + }) } - if len(errs) > 0 { - return errors.Join(errs...) + wg.Wait() + close(results) + + var errs []error + for r := range results { + if r.err != nil { + errs = append(errs, fmt.Errorf("failed to close machine for app %d: %w", r.id, r.err)) + } } - return nil + clear(m.machines) + + return errors.Join(errs...) } // Helper function to get enabled applications From 6e50820b0174018286fac443e4347f70a0151b6d Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:24:49 -0300 Subject: [PATCH 11/17] fix(advancer): trigger node shutdown on store failure after machine advance --- internal/advancer/advancer.go | 10 +++++- internal/advancer/advancer_test.go | 51 ++++++++++++++++++++++++------ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/internal/advancer/advancer.go b/internal/advancer/advancer.go index 110cf70aa..7682d3de0 100644 --- a/internal/advancer/advancer.go +++ b/internal/advancer/advancer.go @@ -212,10 +212,18 @@ func (s *Service) processInputs(ctx context.Context, app *Application, inputs [] // Store the result in the database err = s.repository.StoreAdvanceResult(ctx, input.EpochApplicationID, result) if err != nil { - s.Logger.Error("Failed to store advance result", + // Machine state is now ahead of the database. This desync is + // unrecoverable without a restart — regardless of whether the + // failure was a DB error or a context timeout. Shut down the + // node so it can restart cleanly from the last snapshot. + s.Logger.Error( + "FATAL: failed to store advance result after machine state "+ + "was already updated — shutting down to prevent permanent desync", "application", app.Name, + "epoch", input.EpochIndex, "index", input.Index, "error", err) + s.Cancel() // triggers graceful shutdown of all services return err } diff --git a/internal/advancer/advancer_test.go b/internal/advancer/advancer_test.go index c879d5172..e5b678a2a 100644 --- a/internal/advancer/advancer_test.go +++ b/internal/advancer/advancer_test.go @@ -373,6 +373,9 @@ func (s *AdvancerSuite) TestProcess() { require.Error(err) require.Contains(err.Error(), "store-advance error") require.Len(repository.StoredResults, 1) + + // Verify that the node shutdown was triggered (context cancelled) + require.Error(advancer.Context.Err(), "shared context should be cancelled") }) }) } @@ -488,16 +491,17 @@ func (s *AdvancerSuite) TestLargeNumberOfInputs() { }) } -// TestErrorRecovery tests how the advancer recovers from temporary failures +// TestErrorRecovery verifies that any store failure after a successful Advance() +// triggers node shutdown, because the machine and DB are now out of sync. func (s *AdvancerSuite) TestErrorRecovery() { - s.Run("TemporaryRepositoryFailure", func() { + s.Run("TransientStoreFailureTriggersShutdown", func() { require := s.Require() machineManager := newMockMachineManager() app1 := newMockMachine(1) machineManager.Map[1] = *app1 - // Repository that fails on first attempt but succeeds on second + // Repository that fails on the first store attempt repository := &MockRepository{ StoreAdvanceFailCount: 1, } @@ -506,21 +510,50 @@ func (s *AdvancerSuite) TestErrorRecovery() { require.NotNil(advancer) require.Nil(err) - // Create inputs inputs := []*Input{ newInput(app1.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app1.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), } - // First attempt should fail + // The transient failure triggers node shutdown — no retry at this layer err = advancer.processInputs(context.Background(), app1.Application, inputs) require.Error(err) require.Contains(err.Error(), "temporary failure") - // Second attempt should succeed - err = advancer.processInputs(context.Background(), app1.Application, inputs) + // Verify that the node shutdown was triggered + require.Error(advancer.Context.Err(), "shared context should be cancelled") + }) +} + +// TestContextCancelledBeforeProcessing verifies that when the context is +// already cancelled, processInputs returns the context error immediately +// without reaching the advance or store paths. +func (s *AdvancerSuite) TestContextCancelledBeforeProcessing() { + s.Run("ContextAlreadyCancelled", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app1 := newMockMachine(1) + machineManager.Map[1] = *app1 + + repository := &MockRepository{} + + advancer, err := newMockAdvancerService(machineManager, repository) + require.NotNil(advancer) require.Nil(err) - require.Len(repository.StoredResults, 2) + + // Cancel the context before calling processInputs to simulate + // an external shutdown already in progress. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + inputs := []*Input{ + newInput(app1.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + } + + // With the context already cancelled, processInputs returns + // the context error immediately (before reaching advance). + err = advancer.processInputs(ctx, app1.Application, inputs) + require.ErrorIs(err, context.Canceled) }) } From 95bdb0ff26040852808a5a0f3b443587164c2cc9 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Sat, 28 Feb 2026 16:53:05 -0300 Subject: [PATCH 12/17] fix(manager): destroy machine runtime on fatal errors to prevent zombie processes --- internal/advancer/advancer.go | 20 ++- internal/advancer/advancer_test.go | 2 +- internal/inspect/inspect_test.go | 2 +- internal/manager/instance.go | 50 ++++--- internal/manager/instance_test.go | 216 ++++++++++++++++++++++++++++- internal/manager/manager_test.go | 2 +- internal/manager/types.go | 2 +- 7 files changed, 271 insertions(+), 23 deletions(-) diff --git a/internal/advancer/advancer.go b/internal/advancer/advancer.go index 7682d3de0..5f77b4531 100644 --- a/internal/advancer/advancer.go +++ b/internal/advancer/advancer.go @@ -195,6 +195,17 @@ func (s *Service) processInputs(ctx context.Context, app *Application, inputs [] "error", updateErr) } + // Eagerly close the machine to release the child process. + // The app is already inoperable, so no further operations will succeed. + // Skip if the runtime was already destroyed inside the manager. + if !errors.Is(err, manager.ErrMachineClosed) { + if closeErr := machine.Close(); closeErr != nil { + s.Logger.Warn("Failed to close machine after advance error", + "application", app.Name, + "error", closeErr) + } + } + return err } // log advance result hashes @@ -305,8 +316,15 @@ func (s *Service) handleEpochAfterInputsProcessed(ctx context.Context, app *Appl if !exists { return fmt.Errorf("%w: %d", ErrNoApp, app.ID) } - outputsProof, err := machine.OutputsProof(ctx, 0) + outputsProof, err := machine.OutputsProof(ctx) if err != nil { + // If the runtime was destroyed (e.g., child process crashed), + // mark the app inoperable to avoid an infinite retry loop. + if errors.Is(err, manager.ErrMachineClosed) { + reason := err.Error() + _ = s.repository.UpdateApplicationState(ctx, app.ID, + ApplicationState_Inoperable, &reason) + } return fmt.Errorf("failed to get outputs proof from machine: %w", err) } err = s.repository.UpdateEpochOutputsProof(ctx, app.ID, epoch.Index, outputsProof) diff --git a/internal/advancer/advancer_test.go b/internal/advancer/advancer_test.go index e5b678a2a..43317c1ec 100644 --- a/internal/advancer/advancer_test.go +++ b/internal/advancer/advancer_test.go @@ -684,7 +684,7 @@ func (m *MockMachineInstance) ProcessedInputs() uint64 { return 0 } -func (m *MockMachineInstance) OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) { +func (m *MockMachineInstance) OutputsProof(ctx context.Context) (*OutputsProof, error) { return nil, nil } diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index 68aa4f80a..2d77c9d31 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -234,7 +234,7 @@ func (mock *MockMachine) ProcessedInputs() uint64 { return 0 } -func (m *MockMachine) OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) { +func (m *MockMachine) OutputsProof(ctx context.Context) (*OutputsProof, error) { return nil, nil } diff --git a/internal/manager/instance.go b/internal/manager/instance.go index 1f52013cd..bb6ed6e98 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -432,8 +432,8 @@ func (m *MachineInstanceImpl) CreateSnapshot(ctx context.Context, processedInput m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - // Acquire a read lock on the machine - m.mutex.LLock() + // Acquire HLock since this operation may destroy the runtime on failure. + m.mutex.HLock() defer m.mutex.Unlock() if m.runtime == nil { @@ -453,11 +453,13 @@ func (m *MachineInstanceImpl) CreateSnapshot(ctx context.Context, processedInput storeCtx, cancel := context.WithTimeout(ctx, m.application.ExecutionParameters.StoreDeadline) defer cancel() - // Store the machine state to the specified path + // Store the machine state to the specified path. + // A Store failure on a local child process indicates an unrecoverable + // condition (disk full, process crash, etc.) — destroy the runtime. err := m.runtime.Store(storeCtx, path) if err != nil { - m.logger.Error("Failed to create snapshot", "path", path, "error", err) - return err + m.logger.Error("Failed to create snapshot, destroying runtime", "path", path, "error", err) + return m.destroyRuntime(fmt.Errorf("failed to create snapshot: %w", err)) } m.logger.Debug("Snapshot created successfully", "path", path) @@ -469,8 +471,8 @@ func (m *MachineInstanceImpl) Hash(ctx context.Context) ([32]byte, error) { m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - // Acquire a read lock on the machine - m.mutex.LLock() + // Acquire HLock since this operation may destroy the runtime on failure. + m.mutex.HLock() defer m.mutex.Unlock() if m.runtime == nil { @@ -484,21 +486,21 @@ func (m *MachineInstanceImpl) Hash(ctx context.Context) ([32]byte, error) { hash, err := m.runtime.Hash(storeCtx) if err != nil { - m.logger.Error("Failed to retrieve machine root hash", "error", err) - return [32]byte{}, err + m.logger.Error("Failed to retrieve machine root hash, destroying runtime", "error", err) + return [32]byte{}, m.destroyRuntime(fmt.Errorf("failed to retrieve machine root hash: %w", err)) } m.logger.Debug("Machine root hash retrieved successfully", "hash", "0x"+hex.EncodeToString(hash[:])) return hash, nil } -func (m *MachineInstanceImpl) OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) { +func (m *MachineInstanceImpl) OutputsProof(ctx context.Context) (*OutputsProof, error) { // Acquire the advance mutex to ensure no advance operations are in progress m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - // Acquire a read lock on the machine - m.mutex.LLock() + // Acquire HLock since this operation may destroy the runtime on failure. + m.mutex.HLock() defer m.mutex.Unlock() if m.runtime == nil { @@ -510,20 +512,22 @@ func (m *MachineInstanceImpl) OutputsProof(ctx context.Context, processedInputs proofCtx, cancel := context.WithTimeout(ctx, m.application.ExecutionParameters.LoadDeadline) defer cancel() - // Get the machine state before processing + // The runtime is a local child process — errors here indicate the process + // crashed, ran out of resources, or is otherwise unrecoverable. + // Close the runtime to avoid leaving a broken process alive. machineHash, err := m.runtime.Hash(proofCtx) if err != nil { - return nil, errors.Join(err, m.runtime.Close()) + return nil, m.destroyRuntime(fmt.Errorf("failed to get machine hash: %w", err)) } outputsHash, err := m.runtime.OutputsHash(proofCtx) if err != nil { - return nil, errors.Join(err, m.runtime.Close()) + return nil, m.destroyRuntime(fmt.Errorf("failed to get outputs hash: %w", err)) } outputsHashProof, err := m.runtime.OutputsHashProof(proofCtx) if err != nil { - return nil, errors.Join(err, m.runtime.Close()) + return nil, m.destroyRuntime(fmt.Errorf("failed to get outputs hash proof: %w", err)) } proof := &OutputsProof{ @@ -532,7 +536,7 @@ func (m *MachineInstanceImpl) OutputsProof(ctx context.Context, processedInputs OutputsHashProof: outputsHashProof, } - m.logger.Debug("Machine machine hash, outputs merkle root and outputs merkle proof retrieved successfully", + m.logger.Debug("Machine hash, outputs merkle root and outputs merkle proof retrieved successfully", "hash", "0x"+hex.EncodeToString(machineHash[:])) return proof, nil } @@ -575,6 +579,18 @@ func (m *MachineInstanceImpl) Close() error { return err } +// destroyRuntime closes the runtime and nils it out so that subsequent calls +// fail fast with ErrMachineClosed instead of talking to a broken process. +// Must be called while holding the appropriate locks. +func (m *MachineInstanceImpl) destroyRuntime(cause error) error { + if m.runtime == nil { + return cause + } + closeErr := m.runtime.Close() + m.runtime = nil + return errors.Join(cause, closeErr) +} + // MachineRuntimeFactory defines an interface for creating machine runtimes type MachineRuntimeFactory interface { CreateMachineRuntime( diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 16431cc92..19f484177 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -540,10 +540,29 @@ func (s *MachineInstanceSuite) TestCreateSnapshot() { inner, _, machine := s.setupAdvance() errStore := errors.New("Store error") inner.StoreError = errStore + inner.CloseError = nil err := machine.CreateSnapshot(context.Background(), 5, "/tmp/snapshot") require.Error(err) - require.Equal(errStore, err) + require.ErrorIs(err, errStore) + + // Runtime should be destroyed after a store error. + require.Nil(machine.runtime) + }) + + s.Run("ErrorAndCloseError", func() { + require := s.Require() + inner, _, machine := s.setupAdvance() + errStore := errors.New("Store error") + errClose := errors.New("Close error") + inner.StoreError = errStore + inner.CloseError = errClose + + err := machine.CreateSnapshot(context.Background(), 5, "/tmp/snapshot") + require.Error(err) + require.ErrorIs(err, errStore) + require.ErrorIs(err, errClose) + require.Nil(machine.runtime) }) s.Run("MachineClosed", func() { @@ -565,6 +584,158 @@ func (s *MachineInstanceSuite) TestCreateSnapshot() { }) } +func (s *MachineInstanceSuite) TestHash() { + s.Run("Ok", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + + hash, err := machineInst.Hash(context.Background()) + require.NoError(err) + require.Equal([32]byte(newHash(1)), hash) + + // Runtime should still be alive after a successful call. + require.Same(inner, machineInst.runtime) + }) + + s.Run("MachineClosed", func() { + require := s.Require() + _, machineInst := s.setupOutputsProof() + machineInst.runtime = nil + + hash, err := machineInst.Hash(context.Background()) + require.Error(err) + require.Equal(ErrMachineClosed, err) + require.Equal([32]byte{}, hash) + }) + + s.Run("Error", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + inner.HashError = errHash + inner.CloseError = nil + + hash, err := machineInst.Hash(context.Background()) + require.Error(err) + require.ErrorIs(err, errHash) + require.Equal([32]byte{}, hash) + + // Runtime should be destroyed after a hash error. + require.Nil(machineInst.runtime) + }) + + s.Run("ErrorAndCloseError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + errClose := errors.New("Close error") + inner.HashError = errHash + inner.CloseError = errClose + + hash, err := machineInst.Hash(context.Background()) + require.Error(err) + require.ErrorIs(err, errHash) + require.ErrorIs(err, errClose) + require.Equal([32]byte{}, hash) + require.Nil(machineInst.runtime) + }) +} + +func (s *MachineInstanceSuite) TestOutputsProof() { + s.Run("Ok", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + + proof, err := machineInst.OutputsProof(context.Background()) + require.NoError(err) + require.NotNil(proof) + + require.Equal(newHash(1), proof.MachineHash) + require.Equal(newHash(2), proof.OutputsHash) + require.Equal(expectedOutputsHashProof, proof.OutputsHashProof) + + // Runtime should still be alive after a successful call. + require.Same(inner, machineInst.runtime) + }) + + s.Run("MachineClosed", func() { + require := s.Require() + _, machineInst := s.setupOutputsProof() + machineInst.runtime = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.Equal(ErrMachineClosed, err) + }) + + s.Run("HashError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + inner.HashError = errHash + inner.CloseError = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errHash) + + // Runtime should be destroyed after a hash error. + require.Nil(machineInst.runtime) + }) + + s.Run("HashErrorAndCloseError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + errClose := errors.New("Close error") + inner.HashError = errHash + inner.CloseError = errClose + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errHash) + require.ErrorIs(err, errClose) + + // Runtime should be destroyed even when Close also fails. + require.Nil(machineInst.runtime) + }) + + s.Run("OutputsHashError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errOutputsHash := errors.New("OutputsHash error") + inner.OutputsHashError = errOutputsHash + inner.CloseError = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errOutputsHash) + + // Runtime should be destroyed after an outputs hash error. + require.Nil(machineInst.runtime) + }) + + s.Run("OutputsHashProofError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errProof := errors.New("OutputsHashProof error") + inner.OutputsHashProofError = errProof + inner.CloseError = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errProof) + + // Runtime should be destroyed after an outputs hash proof error. + require.Nil(machineInst.runtime) + }) +} + func (s *MachineInstanceSuite) TestClose() { s.Run("Ok", func() { require := s.Require() @@ -671,6 +842,11 @@ var ( newBytes(33, 300), newBytes(34, 300), } + expectedOutputsHashProof = []machine.Hash{ + newHash(3), + newHash(4), + newHash(5), + } ) func (s *MachineInstanceSuite) setupAdvance() (*MockRollupsMachine, *MockRollupsMachine, *MachineInstanceImpl) { @@ -774,6 +950,44 @@ func (s *MachineInstanceSuite) setupInspect() (*MockRollupsMachine, *MockRollups return inner, fork, machineInst } +func (s *MachineInstanceSuite) setupOutputsProof() (*MockRollupsMachine, *MachineInstanceImpl) { + app := &model.Application{ + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + LoadDeadline: decisecond, + MaxConcurrentInspects: 3, + }, + } + inner := &MockRollupsMachine{} + machineInst := &MachineInstanceImpl{ + application: app, + runtime: inner, + advanceTimeout: decisecond, + inspectTimeout: centisecond, + maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, + mutex: pmutex.New(), + inspectSemaphore: semaphore.NewWeighted(3), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + machineInst.processedInputs.Store(5) + + inner.HashReturn = newHash(1) + inner.HashError = nil + inner.OutputsHashReturn = newHash(2) + inner.OutputsHashError = nil + inner.OutputsHashProofReturn = []machine.Hash{ + newHash(3), + newHash(4), + newHash(5), + } + inner.OutputsHashProofError = nil + inner.CloseError = errUnreachable + + return inner, machineInst +} + // ------------------------------------------------------------------------------------------------ const ( diff --git a/internal/manager/manager_test.go b/internal/manager/manager_test.go index ab5ecddb6..087f68ac3 100644 --- a/internal/manager/manager_test.go +++ b/internal/manager/manager_test.go @@ -306,7 +306,7 @@ func (m *DummyMachineInstanceMock) ProcessedInputs() uint64 { return 0 } -func (m *DummyMachineInstanceMock) OutputsProof(ctx context.Context, processedInputs uint64) (*model.OutputsProof, error) { +func (m *DummyMachineInstanceMock) OutputsProof(ctx context.Context) (*model.OutputsProof, error) { return nil, nil } diff --git a/internal/manager/types.go b/internal/manager/types.go index c716abbdc..d72e1f97a 100644 --- a/internal/manager/types.go +++ b/internal/manager/types.go @@ -18,7 +18,7 @@ type MachineInstance interface { CreateSnapshot(ctx context.Context, processedInputs uint64, path string) error ProcessedInputs() uint64 Hash(ctx context.Context) ([32]byte, error) - OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) + OutputsProof(ctx context.Context) (*OutputsProof, error) Close() error } From 1c54c5925c11452666b8a9adde4b214690a6400f Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:22:46 -0300 Subject: [PATCH 13/17] fix(advancer): cap reports, release memory, and harden error handling --- cmd/cartesi-rollups-cli/root/read/read.go | 2 +- internal/advancer/advancer.go | 1 + internal/advancer/service.go | 5 +++-- internal/manager/instance.go | 2 ++ internal/manager/instance_test.go | 4 ++++ internal/model/models.go | 7 +++++++ .../db/rollupsdb/public/enum/inputcompletionstatus.go | 2 ++ internal/repository/postgres/repository_error_test.go | 2 +- .../migrations/000001_create_initial_schema.up.sql | 1 + pkg/machine/implementation.go | 10 +++++++--- pkg/machine/libcartesi.go | 1 - pkg/machine/machine.go | 1 + 12 files changed, 30 insertions(+), 8 deletions(-) diff --git a/cmd/cartesi-rollups-cli/root/read/read.go b/cmd/cartesi-rollups-cli/root/read/read.go index 8fe5fd33a..bd9053ece 100644 --- a/cmd/cartesi-rollups-cli/root/read/read.go +++ b/cmd/cartesi-rollups-cli/root/read/read.go @@ -23,7 +23,7 @@ var Cmd = &cobra.Command{ Short: "Read the node state from the database", PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if !cmd.Flags().Changed("jsonrpc") && cmd.Flags().Changed("jsonrpc-api-url") { - if err:= cmd.Flags().Set("jsonrpc", "true"); err != nil { + if err := cmd.Flags().Set("jsonrpc", "true"); err != nil { return err } } diff --git a/internal/advancer/advancer.go b/internal/advancer/advancer.go index 5f77b4531..4b309b932 100644 --- a/internal/advancer/advancer.go +++ b/internal/advancer/advancer.go @@ -175,6 +175,7 @@ func (s *Service) processInputs(ctx context.Context, app *Application, inputs [] // Advance the machine with this input result, err := machine.Advance(ctx, input.RawData, input.EpochIndex, input.Index, app.IsDaveConsensus()) + input.RawData = nil // allow GC to collect payload while batch continues if err != nil { // If there's an error, mark the application as inoperable s.Logger.Error("Error executing advance", diff --git a/internal/advancer/service.go b/internal/advancer/service.go index 0fa110135..8c54c203a 100644 --- a/internal/advancer/service.go +++ b/internal/advancer/service.go @@ -99,7 +99,7 @@ func (s *Service) Tick() []error { if err := s.Step(s.Context); err != nil { return []error{err} } - return []error{} + return nil } func (s *Service) Stop(b bool) []error { var errs []error @@ -130,7 +130,8 @@ func (s *Service) Serve() error { if s.inspector != nil && s.HTTPServerFunc != nil { go func() { if err := s.HTTPServerFunc(); err != nil && !errors.Is(err, http.ErrServerClosed) { - s.Logger.Error("Inspect HTTP server failed", "error", err) + s.Logger.Error("Inspect HTTP server failed — shutting down", "error", err) + s.Cancel() } }() } diff --git a/internal/manager/instance.go b/internal/manager/instance.go index bb6ed6e98..447dbb50a 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -737,6 +737,8 @@ func toInputStatus(accepted bool, err error) (status InputCompletionStatus, _ er return InputCompletionStatus_MachineHalted, nil case errors.Is(err, machine.ErrOutputsLimitExceeded): return InputCompletionStatus_OutputsLimitExceeded, nil + case errors.Is(err, machine.ErrReportsLimitExceeded): + return InputCompletionStatus_ReportsLimitExceeded, nil case errors.Is(err, machine.ErrReachedTargetMcycle): return InputCompletionStatus_CycleLimitExceeded, nil case errors.Is(err, machine.ErrPayloadLengthLimitExceeded): diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 19f484177..48b8e9136 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -275,6 +275,10 @@ func (s *MachineInstanceSuite) TestAdvance() { machine.ErrOutputsLimitExceeded, model.InputCompletionStatus_OutputsLimitExceeded) + testSoftError("ReportsLimit", + machine.ErrReportsLimitExceeded, + model.InputCompletionStatus_ReportsLimitExceeded) + testSoftError("ReachedTargetMcycle", machine.ErrReachedTargetMcycle, model.InputCompletionStatus_CycleLimitExceeded) diff --git a/internal/model/models.go b/internal/model/models.go index 82bcb6d38..d746a0a08 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -808,6 +808,7 @@ const ( InputCompletionStatus_Exception InputCompletionStatus = "EXCEPTION" InputCompletionStatus_MachineHalted InputCompletionStatus = "MACHINE_HALTED" InputCompletionStatus_OutputsLimitExceeded InputCompletionStatus = "OUTPUTS_LIMIT_EXCEEDED" + InputCompletionStatus_ReportsLimitExceeded InputCompletionStatus = "REPORTS_LIMIT_EXCEEDED" InputCompletionStatus_CycleLimitExceeded InputCompletionStatus = "CYCLE_LIMIT_EXCEEDED" InputCompletionStatus_TimeLimitExceeded InputCompletionStatus = "TIME_LIMIT_EXCEEDED" InputCompletionStatus_PayloadLengthLimitExceeded InputCompletionStatus = "PAYLOAD_LENGTH_LIMIT_EXCEEDED" @@ -819,6 +820,8 @@ var InputCompletionStatusAllValues = []InputCompletionStatus{ InputCompletionStatus_Rejected, InputCompletionStatus_Exception, InputCompletionStatus_MachineHalted, + InputCompletionStatus_OutputsLimitExceeded, + InputCompletionStatus_ReportsLimitExceeded, InputCompletionStatus_CycleLimitExceeded, InputCompletionStatus_TimeLimitExceeded, InputCompletionStatus_PayloadLengthLimitExceeded, @@ -846,6 +849,10 @@ func (e *InputCompletionStatus) Scan(value any) error { *e = InputCompletionStatus_Exception case "MACHINE_HALTED": *e = InputCompletionStatus_MachineHalted + case "OUTPUTS_LIMIT_EXCEEDED": + *e = InputCompletionStatus_OutputsLimitExceeded + case "REPORTS_LIMIT_EXCEEDED": + *e = InputCompletionStatus_ReportsLimitExceeded case "CYCLE_LIMIT_EXCEEDED": *e = InputCompletionStatus_CycleLimitExceeded case "TIME_LIMIT_EXCEEDED": diff --git a/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go b/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go index f18248333..a9624fcb2 100644 --- a/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go +++ b/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go @@ -16,6 +16,7 @@ var InputCompletionStatus = &struct { Exception postgres.StringExpression MachineHalted postgres.StringExpression OutputsLimitExceeded postgres.StringExpression + ReportsLimitExceeded postgres.StringExpression CycleLimitExceeded postgres.StringExpression TimeLimitExceeded postgres.StringExpression PayloadLengthLimitExceeded postgres.StringExpression @@ -26,6 +27,7 @@ var InputCompletionStatus = &struct { Exception: postgres.NewEnumValue("EXCEPTION"), MachineHalted: postgres.NewEnumValue("MACHINE_HALTED"), OutputsLimitExceeded: postgres.NewEnumValue("OUTPUTS_LIMIT_EXCEEDED"), + ReportsLimitExceeded: postgres.NewEnumValue("REPORTS_LIMIT_EXCEEDED"), CycleLimitExceeded: postgres.NewEnumValue("CYCLE_LIMIT_EXCEEDED"), TimeLimitExceeded: postgres.NewEnumValue("TIME_LIMIT_EXCEEDED"), PayloadLengthLimitExceeded: postgres.NewEnumValue("PAYLOAD_LENGTH_LIMIT_EXCEEDED"), diff --git a/internal/repository/postgres/repository_error_test.go b/internal/repository/postgres/repository_error_test.go index 4394aab87..ea8d6c452 100644 --- a/internal/repository/postgres/repository_error_test.go +++ b/internal/repository/postgres/repository_error_test.go @@ -39,7 +39,7 @@ func TestNewPostgresRepository_ContextCancelledDuringRetry(t *testing.T) { _, err := postgres.NewPostgresRepository( ctx, "postgres://user:pass@localhost:1/testdb?connect_timeout=1", - 100, // many retries — we won't exhaust them + 100, // many retries — we won't exhaust them 10*time.Second, // long delay — context expires before this elapses ) require.Error(t, err) diff --git a/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql b/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql index 7680105b9..b965b7969 100644 --- a/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql +++ b/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql @@ -17,6 +17,7 @@ CREATE TYPE "InputCompletionStatus" AS ENUM ( 'EXCEPTION', 'MACHINE_HALTED', 'OUTPUTS_LIMIT_EXCEEDED', + 'REPORTS_LIMIT_EXCEEDED', 'CYCLE_LIMIT_EXCEEDED', 'TIME_LIMIT_EXCEEDED', 'PAYLOAD_LENGTH_LIMIT_EXCEEDED'); diff --git a/pkg/machine/implementation.go b/pkg/machine/implementation.go index 266b04d1c..19e867060 100644 --- a/pkg/machine/implementation.go +++ b/pkg/machine/implementation.go @@ -44,8 +44,9 @@ const ( ManualYieldReasonException manualYieldReason = 0x4 ) -// Constants +// Limits for outputs and reports per input const maxOutputs = 65536 // 2^16 +const maxReports = 65536 // 2^16 const CheckpointAddress uint64 = 0x7ffff000 const TxBufferAddress uint64 = 0x60800000 @@ -364,8 +365,8 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe "limitCycle", limitCycle, "leftover", limitCycle-currentCycle) - outputs := []Output{} - reports := []Report{} + outputs := make([]Output, 0, 16) //nolint:mnd + reports := make([]Report, 0, 16) //nolint:mnd var hashCollectorState *HashCollectorState if computeHashes { @@ -441,6 +442,9 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe } outputs = append(outputs, data) case AutomaticYieldReasonReport: + if len(reports) == maxReports { + return outputs, reports, hashes(), remainingMetaCycles(), ErrReportsLimitExceeded + } reports = append(reports, data) default: err := fmt.Errorf("invalid automatic yield reason: %d: %w", yieldReason, ErrMachineInternal) diff --git a/pkg/machine/libcartesi.go b/pkg/machine/libcartesi.go index 5def485ad..7d396f1de 100644 --- a/pkg/machine/libcartesi.go +++ b/pkg/machine/libcartesi.go @@ -140,7 +140,6 @@ func (e *LibCartesiBackend) GetProof(address uint64, log2size int32, timeout tim proof := &proofJson{} err = json.Unmarshal([]byte(jsonMessage), proof) if err != nil { - println("Failed to unmarshal proof JSON:", err.Error()) return nil, fmt.Errorf("failed to unmarshal proof JSON: %w", err) } return proof.Siblings, nil diff --git a/pkg/machine/machine.go b/pkg/machine/machine.go index fe017bddb..ccf8cb237 100644 --- a/pkg/machine/machine.go +++ b/pkg/machine/machine.go @@ -48,6 +48,7 @@ var ( ErrRejected = errors.New("last request yielded as rejected") ErrHalted = errors.New("machine halted") ErrOutputsLimitExceeded = errors.New("outputs limit exceeded") + ErrReportsLimitExceeded = errors.New("reports limit exceeded") ErrReachedTargetMcycle = errors.New("machine reached target mcycle") ErrPayloadLengthLimitExceeded = errors.New("payload length limit exceeded") ErrHashLength = errors.New("hash does not have the exactly number of bytes") From b0b63888ade350b22017a40dd5bda1df8f416cba Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:33:22 -0300 Subject: [PATCH 14/17] test: prevent goroutine leak in pmutex contested lock tests --- internal/manager/pmutex/pmutex_test.go | 42 ++++++++++++++++++-------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/internal/manager/pmutex/pmutex_test.go b/internal/manager/pmutex/pmutex_test.go index 12fa1f367..4305ac085 100644 --- a/internal/manager/pmutex/pmutex_test.go +++ b/internal/manager/pmutex/pmutex_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -41,15 +40,39 @@ func (s *PMutexSuite) TestSingleLLock() { } func (s *PMutexSuite) TestContestedHLock() { - require := s.Require() s.mutex.LLock() - never(require, func() bool { s.mutex.HLock(); return true }) + acquired := make(chan struct{}) + go func() { + s.mutex.HLock() + close(acquired) + }() + select { + case <-acquired: + s.Fail("HLock should not be acquired while LLock is held") + case <-time.After(decisecond): + // Expected: HLock is blocked. + } + // Clean up: unlock so the goroutine can finish. + s.mutex.Unlock() + <-acquired } func (s *PMutexSuite) TestContestedLLock() { - require := s.Require() s.mutex.HLock() - never(require, func() bool { s.mutex.LLock(); return true }) + acquired := make(chan struct{}) + go func() { + s.mutex.LLock() + close(acquired) + }() + select { + case <-acquired: + s.Fail("LLock should not be acquired while HLock is held") + case <-time.After(decisecond): + // Expected: LLock is blocked. + } + // Clean up: unlock so the goroutine can finish. + s.mutex.Unlock() + <-acquired } func (s *PMutexSuite) TestPriority() { @@ -105,11 +128,4 @@ func (s *PMutexSuite) TestPriority() { // ------------------------------------------------------------------------------------------------ -const ( - centisecond = 10 * time.Millisecond - decisecond = 100 * time.Millisecond -) - -func never(require *require.Assertions, f func() bool) { - require.Never(f, decisecond, centisecond) -} +const decisecond = 100 * time.Millisecond From 9ec2cd035c16e79818cc77e5f21a252960cbc966 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:38:45 -0300 Subject: [PATCH 15/17] fix(manager): reject new machines after Close to prevent post-shutdown additions --- internal/manager/manager.go | 27 +++++++++++++++++++-------- internal/manager/manager_test.go | 8 ++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 94181372c..e783a1a15 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -37,6 +37,7 @@ type MachineRepository interface { type MachineManager struct { mutex sync.RWMutex machines map[int64]MachineInstance + closed bool repository MachineRepository checkHash bool inputBatchSize uint64 @@ -80,7 +81,6 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { // Check if we have a snapshot to load from var instance MachineInstance - var err error // Find the latest snapshot for this application snapshot, err := m.repository.GetLastSnapshot(ctx, app.IApplicationAddress.String()) @@ -125,7 +125,7 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { instance, err = NewMachineInstance(ctx, app, m.logger, m.checkHash) if err != nil { m.logger.Error("Failed to create machine instance", - "application", app.IApplicationAddress, + "application", app.IApplicationAddress.String(), "error", err) continue } @@ -178,11 +178,16 @@ func (m *MachineManager) HasMachine(appID int64) bool { return exists } -// AddMachine adds a machine to the manager +// addMachine adds a machine to the manager. +// Returns false if the manager is closed or the appID already exists. func (m *MachineManager) addMachine(appID int64, machine MachineInstance) bool { m.mutex.Lock() defer m.mutex.Unlock() + if m.closed { + return false + } + if _, exists := m.machines[appID]; exists { return false } @@ -209,7 +214,7 @@ func (m *MachineManager) removeMachines(apps []*Application) { m.logger.Info("Application was disabled, shutting down machine", "application", machine.Application().Name) } - if err := machine.Close(); err != nil { + if err := machine.Close(); err != nil && m.logger != nil { m.logger.Warn("Failed to close machine for disabled application", "application", machine.Application().Name, "error", err) } @@ -231,9 +236,16 @@ func (m *MachineManager) Applications() []*Application { } // Close shuts down all machine instances in parallel. +// After Close returns, no new machines can be added. func (m *MachineManager) Close() error { + // Mark as closed and take ownership of the machines map under the lock, + // then release it so readers (GetMachine, HasMachine, Applications) + // aren't blocked during the potentially slow parallel shutdown. m.mutex.Lock() - defer m.mutex.Unlock() + m.closed = true + machines := m.machines + m.machines = make(map[int64]MachineInstance) + m.mutex.Unlock() type closeResult struct { id int64 @@ -241,9 +253,9 @@ func (m *MachineManager) Close() error { } var wg sync.WaitGroup - results := make(chan closeResult, len(m.machines)) + results := make(chan closeResult, len(machines)) - for id, machine := range m.machines { + for id, machine := range machines { wg.Go(func() { results <- closeResult{id: id, err: machine.Close()} }) @@ -258,7 +270,6 @@ func (m *MachineManager) Close() error { errs = append(errs, fmt.Errorf("failed to close machine for app %d: %w", r.id, r.err)) } } - clear(m.machines) return errors.Join(errs...) } diff --git a/internal/manager/manager_test.go b/internal/manager/manager_test.go index 087f68ac3..f60061417 100644 --- a/internal/manager/manager_test.go +++ b/internal/manager/manager_test.go @@ -190,6 +190,14 @@ func (s *MachineManagerSuite) TestAddMachine() { added = manager.addMachine(1, machine1) require.False(added) require.Len(manager.machines, 2) + + // Close the manager and try to add a new machine + err := manager.Close() + require.NoError(err) + + machine3 := &DummyMachineInstanceMock{application: &model.Application{ID: 3}} + added = manager.addMachine(3, machine3) + require.False(added, "addMachine must reject additions after Close") } func (s *MachineManagerSuite) TestRemoveDisabledMachines() { From 9e843c08386eb3d3cdfa2efbd6209cb7567f0923 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:20:17 -0300 Subject: [PATCH 16/17] test: improve advancer, manager and machine tests --- internal/advancer/advancer_test.go | 800 ++++++++++++++++++++++++++++- internal/manager/instance_test.go | 174 ++++++- internal/manager/manager_test.go | 179 ++++++- pkg/machine/implementation_test.go | 177 +++++++ pkg/machine/machine_test.go | 19 + 5 files changed, 1325 insertions(+), 24 deletions(-) diff --git a/internal/advancer/advancer_test.go b/internal/advancer/advancer_test.go index 43317c1ec..0398703f1 100644 --- a/internal/advancer/advancer_test.go +++ b/internal/advancer/advancer_test.go @@ -10,6 +10,8 @@ import ( "errors" "fmt" mrand "math/rand" + "os" + "path/filepath" "sync" "testing" "time" @@ -17,6 +19,7 @@ import ( "github.com/cartesi/rollups-node/internal/manager" . "github.com/cartesi/rollups-node/internal/model" "github.com/cartesi/rollups-node/internal/repository" + "github.com/cartesi/rollups-node/internal/repository/repotest" "github.com/cartesi/rollups-node/pkg/service" "github.com/ethereum/go-ethereum/common" @@ -557,10 +560,735 @@ func (s *AdvancerSuite) TestContextCancelledBeforeProcessing() { }) } +// --------------------------------------------------------------------------- +// isAllEpochInputsProcessed tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestIsAllEpochInputsProcessed() { + s.Run("TrueWhenEpochHasNoInputs", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + // Epoch with no inputs (lower == upper) + epoch := &Epoch{ + Index: 0, + InputIndexLowerBound: 5, + InputIndexUpperBound: 5, + } + + result, perr := advancer.isAllEpochInputsProcessed(app.Application, epoch) + require.Nil(perr) + require.True(result) + }) + + s.Run("TrueWhenMachineProcessedAllInputs", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + // Mock the machine to report ProcessedInputs = 10 + machineManager.Map[1] = MockMachineImpl{ + Application: app.Application, + processedInputs: 10, + } + + epoch := &Epoch{ + Index: 0, + InputIndexLowerBound: 5, + InputIndexUpperBound: 10, + } + + result, perr := advancer.isAllEpochInputsProcessed(app.Application, epoch) + require.Nil(perr) + require.True(result) + }) + + s.Run("FalseWhenMoreInputsExist", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + // Mock the machine to report ProcessedInputs = 7 (not yet at upper bound) + machineManager.Map[1] = MockMachineImpl{ + Application: app.Application, + processedInputs: 7, + } + + epoch := &Epoch{ + Index: 0, + InputIndexLowerBound: 5, + InputIndexUpperBound: 10, + } + + result, perr := advancer.isAllEpochInputsProcessed(app.Application, epoch) + require.Nil(perr) + require.False(result) + }) + + s.Run("ErrorWhenNoMachineForApp", func() { + require := s.Require() + + machineManager := newMockMachineManager() + // Don't add any machine + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + app := &Application{ID: 999} + epoch := &Epoch{ + Index: 0, + InputIndexLowerBound: 0, + InputIndexUpperBound: 5, + } + + _, perr := advancer.isAllEpochInputsProcessed(app, epoch) + require.Error(perr) + require.ErrorIs(perr, ErrNoApp) + }) +} + +// --------------------------------------------------------------------------- +// isEpochLastInput tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestIsEpochLastInput() { + setupWithEpoch := func(epochStatus EpochStatus) (*Service, *Application, *MockRepository) { + machineManager := newMockMachineManager() + app := newMockMachine(1) + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + s.Require().Nil(err) + + repository.GetEpochReturn = &Epoch{Status: epochStatus} + return advancer, app.Application, repository + } + + s.Run("TrueWhenLastInputInClosedEpoch", func() { + require := s.Require() + advancer, app, repo := setupWithEpoch(EpochStatus_Closed) + + lastInput := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + repo.GetInputsReturn = map[common.Address][]*Input{ + app.IApplicationAddress: {lastInput}, + } + repo.GetLastInputReturn = lastInput + + input := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + result, err := advancer.isEpochLastInput(context.Background(), app, input) + require.Nil(err) + require.True(result) + }) + + s.Run("FalseWhenEpochIsOpen", func() { + require := s.Require() + advancer, app, _ := setupWithEpoch(EpochStatus_Open) + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() + result, err := advancer.isEpochLastInput(context.Background(), app, input) + require.Nil(err) + require.False(result) + }) + + s.Run("FalseWhenNotLastInput", func() { + require := s.Require() + advancer, app, repo := setupWithEpoch(EpochStatus_Closed) + + lastInput := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + repo.GetLastInputReturn = lastInput + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() + result, err := advancer.isEpochLastInput(context.Background(), app, input) + require.Nil(err) + require.False(result) + }) + + s.Run("ErrorWhenNilInput", func() { + require := s.Require() + advancer, app, _ := setupWithEpoch(EpochStatus_Closed) + + _, err := advancer.isEpochLastInput(context.Background(), app, nil) + require.Error(err) + require.Contains(err.Error(), "must not be nil") + }) + + s.Run("ErrorWhenNilApplication", func() { + require := s.Require() + advancer, _, _ := setupWithEpoch(EpochStatus_Closed) + + input := repotest.NewInputBuilder().WithIndex(0).Build() + _, err := advancer.isEpochLastInput(context.Background(), nil, input) + require.Error(err) + require.Contains(err.Error(), "must not be nil") + }) + + s.Run("ErrorWhenGetEpochFails", func() { + require := s.Require() + advancer, app, repo := setupWithEpoch(EpochStatus_Closed) + repo.GetEpochError = errors.New("get epoch error") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + _, err := advancer.isEpochLastInput(context.Background(), app, input) + require.Error(err) + require.Contains(err.Error(), "get epoch error") + }) + + s.Run("ErrorWhenGetLastInputFails", func() { + require := s.Require() + advancer, app, repo := setupWithEpoch(EpochStatus_Closed) + repo.GetLastInputError = errors.New("get last input error") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + _, err := advancer.isEpochLastInput(context.Background(), app, input) + require.Error(err) + require.Contains(err.Error(), "get last input error") + }) +} + +// --------------------------------------------------------------------------- +// handleEpochAfterInputsProcessed tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestHandleEpochAfterInputsProcessed() { + s.Run("EmptyEpochIndex0GetsOutputsProofFromMachine", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + // Epoch with no inputs (lower == upper) + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + require.Nil(err) + require.True(repository.OutputsProofUpdated) + }) + + s.Run("EmptyEpochIndex0ErrorOnOutputsProof", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.OutputsProofError = errors.New("proof error") + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + require.Error(err) + require.Contains(err.Error(), "proof error") + }) + + s.Run("EmptyEpochIndexGt0RepeatsPreviousProof", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + epoch := &Epoch{Index: 2, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + require.Nil(err) + require.True(repository.RepeatOutputsProofCalled) + }) + + s.Run("EmptyEpochIndexGt0RepeatError", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + machineManager.Map[1] = *app + repository := &MockRepository{ + RepeatOutputsProofError: errors.New("repeat error"), + } + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + epoch := &Epoch{Index: 2, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + require.Error(err) + require.Contains(err.Error(), "repeat error") + }) + + s.Run("NonEmptyEpochWithEveryEpochSnapshotPolicy", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + advancer.snapshotsDir = s.T().TempDir() + + // Epoch with inputs + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + + // Provide a last processed input + lastInput := repotest.NewInputBuilder().WithIndex(2).WithEpochIndex(0). + WithStatus(InputCompletionStatus_Accepted).Build() + lastInput.EpochApplicationID = app.Application.ID + repository.GetLastProcessedInputReturn = lastInput + // isEpochLastInput needs GetLastInput to return the same input + repository.GetLastInputReturn = lastInput + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + require.Nil(err) + // Verify snapshot was attempted (CreateSnapshot called on mock) + require.True(repository.SnapshotURIUpdated) + }) + + s.Run("NonEmptyEpochNoSnapshotPolicy", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_None + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + lastInput := repotest.NewInputBuilder().WithIndex(2).WithEpochIndex(0). + WithStatus(InputCompletionStatus_Accepted).Build() + lastInput.EpochApplicationID = app.Application.ID + repository.GetLastProcessedInputReturn = lastInput + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + require.Nil(err) + // No snapshot should be created with None policy + require.False(repository.SnapshotURIUpdated) + }) + + s.Run("NoMachineReturnsError", func() { + require := s.Require() + + machineManager := newMockMachineManager() + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + app := repotest.NewApplicationBuilder().Build() + app.ID = 999 + + // Non-empty epoch: machine lookup + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app, epoch) + require.Error(err) + require.ErrorIs(err, ErrNoApp) + }) + + s.Run("GetLastProcessedInputError", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch + machineManager.Map[1] = *app + repository := &MockRepository{ + GetLastProcessedInputError: errors.New("db connection lost"), + } + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + + err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + require.Error(err) + require.Contains(err.Error(), "db connection lost") + }) +} + +// --------------------------------------------------------------------------- +// handleSnapshot tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestHandleSnapshot() { + setupSnapshot := func(policy SnapshotPolicy) (*Service, *Application, *MockMachineInstance, *MockRepository) { + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.Application.ExecutionParameters.SnapshotPolicy = policy + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + s.Require().Nil(err) + advancer.snapshotsDir = s.T().TempDir() + + mockInstance := &MockMachineInstance{ + application: app.Application, + machineImpl: app, + } + return advancer, app.Application, mockInstance, repository + } + + s.Run("NonePolicy", func() { + require := s.Require() + advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_None) + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = app.ID + + err := advancer.handleSnapshot(context.Background(), app, machine, input) + require.Nil(err) + require.False(repo.SnapshotURIUpdated) + }) + + s.Run("EveryInputPolicy", func() { + require := s.Require() + advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryInput) + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = app.ID + + err := advancer.handleSnapshot(context.Background(), app, machine, input) + require.Nil(err) + require.True(repo.SnapshotURIUpdated) + }) + + s.Run("EveryEpochPolicyLastInput", func() { + require := s.Require() + advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryEpoch) + + // Set up GetEpoch to return closed epoch + repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + + input := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + input.EpochApplicationID = app.ID + + // Last input in epoch matches + repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + + err := advancer.handleSnapshot(context.Background(), app, machine, input) + require.Nil(err) + require.True(repo.SnapshotURIUpdated) + }) + + s.Run("EveryEpochPolicyNotLastInput", func() { + require := s.Require() + advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryEpoch) + + repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() + input.EpochApplicationID = app.ID + + // Last input is a different one + repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + + err := advancer.handleSnapshot(context.Background(), app, machine, input) + require.Nil(err) + require.False(repo.SnapshotURIUpdated) + }) + + s.Run("EveryEpochPolicyOpenEpoch", func() { + require := s.Require() + advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryEpoch) + + repo.GetEpochReturn = &Epoch{Status: EpochStatus_Open} + + input := repotest.NewInputBuilder().WithIndex(0).WithEpochIndex(0).Build() + input.EpochApplicationID = app.ID + + err := advancer.handleSnapshot(context.Background(), app, machine, input) + require.Nil(err) + require.False(repo.SnapshotURIUpdated) + }) +} + +// --------------------------------------------------------------------------- +// createSnapshot tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestCreateSnapshot() { + setupCreateSnapshot := func() (*Service, *Application, *MockMachineInstance, *MockRepository, string) { + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.Application.Name = "testapp" + app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryInput + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + s.Require().Nil(err) + + tmpDir := s.T().TempDir() + advancer.snapshotsDir = tmpDir + + mockInstance := &MockMachineInstance{ + application: app.Application, + machineImpl: app, + } + return advancer, app.Application, mockInstance, repository, tmpDir + } + + s.Run("Success", func() { + require := s.Require() + advancer, app, machine, repo, tmpDir := setupCreateSnapshot() + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(1).Build() + input.EpochApplicationID = app.ID + + err := advancer.createSnapshot(context.Background(), app, machine, input) + require.Nil(err) + require.True(repo.SnapshotURIUpdated) + + // Verify the snapshot path was set correctly + require.NotNil(input.SnapshotURI) + expectedPath := filepath.Join(tmpDir, "testapp_epoch1_input3") + require.Equal(expectedPath, *input.SnapshotURI) + }) + + s.Run("SkipsIfAlreadyHasSnapshot", func() { + require := s.Require() + advancer, app, machine, repo, _ := setupCreateSnapshot() + + existingPath := "/existing/snapshot" + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = app.ID + input.SnapshotURI = &existingPath + + err := advancer.createSnapshot(context.Background(), app, machine, input) + require.Nil(err) + require.False(repo.SnapshotURIUpdated) + }) + + s.Run("RemovesPreviousSnapshot", func() { + require := s.Require() + advancer, app, machine, repo, tmpDir := setupCreateSnapshot() + + // Create a previous snapshot directory to be cleaned up + prevPath := filepath.Join(tmpDir, "testapp_epoch0_input0") + require.Nil(os.MkdirAll(prevPath, 0755)) + + prevInput := &Input{ + SnapshotURI: &prevPath, + } + repo.GetLastSnapshotReturn = prevInput + + input := repotest.NewInputBuilder().WithIndex(1).WithEpochIndex(0).Build() + input.EpochApplicationID = app.ID + + err := advancer.createSnapshot(context.Background(), app, machine, input) + require.Nil(err) + + // Verify previous snapshot was removed + _, statErr := os.Stat(prevPath) + require.True(os.IsNotExist(statErr)) + }) + + s.Run("CreateSnapshotError", func() { + require := s.Require() + advancer, app, machine, repo, _ := setupCreateSnapshot() + + machine.createSnapshotError = errors.New("snapshot failed") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = app.ID + + err := advancer.createSnapshot(context.Background(), app, machine, input) + require.Error(err) + require.Contains(err.Error(), "snapshot failed") + require.False(repo.SnapshotURIUpdated) + }) + + s.Run("MkdirAllError", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.Application.Name = "testapp" + machineManager.Map[1] = *app + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + // Create a read-only parent directory so MkdirAll fails + tmpDir := s.T().TempDir() + readonlyDir := filepath.Join(tmpDir, "readonly") + require.Nil(os.MkdirAll(readonlyDir, 0755)) + require.Nil(os.Chmod(readonlyDir, 0555)) + s.T().Cleanup(func() { os.Chmod(readonlyDir, 0755) }) //nolint: errcheck + advancer.snapshotsDir = filepath.Join(readonlyDir, "snapshots") + + mockInstance := &MockMachineInstance{ + application: app.Application, + machineImpl: app, + } + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = app.Application.ID + + err = advancer.createSnapshot(context.Background(), app.Application, mockInstance, input) + require.Error(err) + require.Contains(err.Error(), "failed to create snapshots directory") + }) + + s.Run("UpdateSnapshotURIError", func() { + require := s.Require() + advancer, app, machine, repo, _ := setupCreateSnapshot() + + repo.UpdateSnapshotURIError = errors.New("db error") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = app.ID + + err := advancer.createSnapshot(context.Background(), app, machine, input) + require.Error(err) + require.Contains(err.Error(), "db error") + }) +} + +// --------------------------------------------------------------------------- +// removeSnapshot tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestRemoveSnapshot() { + s.Run("RemovesExistingSnapshot", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + // Create a snapshot directory + snapshotPath := filepath.Join(tmpDir, "myapp_epoch0_input0") + require.Nil(os.MkdirAll(snapshotPath, 0755)) + + err := advancer.removeSnapshot(snapshotPath, "myapp") + require.Nil(err) + + _, statErr := os.Stat(snapshotPath) + require.True(os.IsNotExist(statErr)) + }) + + s.Run("NonExistentPathIsNoop", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + snapshotPath := filepath.Join(tmpDir, "myapp_epoch0_input0") + err := advancer.removeSnapshot(snapshotPath, "myapp") + require.Nil(err) + }) + + s.Run("RejectsDirectoryTraversal", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + // Try to traverse outside snapshotsDir + maliciousPath := filepath.Join(tmpDir, "..", "outside", "myapp_evil") + err := advancer.removeSnapshot(maliciousPath, "myapp") + require.Error(err) + require.Contains(err.Error(), "invalid snapshot path") + }) + + s.Run("RejectsMismatchedAppName", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + snapshotPath := filepath.Join(tmpDir, "otherapp_epoch0_input0") + err := advancer.removeSnapshot(snapshotPath, "myapp") + require.Error(err) + require.Contains(err.Error(), "invalid snapshot path") + }) +} + +// --------------------------------------------------------------------------- +// Service.Create tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestServiceCreate() { + s.Run("NilRepository", func() { + require := s.Require() + c := &CreateInfo{} + c.Name = "advancer" + c.Config.AdvancerInputBatchSize = 500 + svc, err := Create(context.Background(), c) + require.Error(err) + require.Nil(svc) + require.Contains(err.Error(), "nil") + }) + + s.Run("ZeroBatchSize", func() { + require := s.Require() + c := &CreateInfo{} + c.Name = "advancer" + c.Config.AdvancerInputBatchSize = 0 + c.Repository = &MockFullRepository{} + svc, err := Create(context.Background(), c) + require.Error(err) + require.Nil(svc) + require.Contains(err.Error(), "batch size") + }) + + s.Run("CancelledContext", func() { + require := s.Require() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c := &CreateInfo{} + c.Name = "advancer" + c.Config.AdvancerInputBatchSize = 500 + c.Repository = &MockFullRepository{} + svc, err := Create(ctx, c) + require.Error(err) + require.Nil(svc) + }) +} + +// MockFullRepository satisfies the repository.Repository interface minimally +// for Create() validation tests. It panics on any actual DB call. +type MockFullRepository struct { + repository.Repository +} + type MockMachineImpl struct { - Application *Application - AdvanceBlock bool - AdvanceError error + Application *Application + AdvanceBlock bool + AdvanceError error + OutputsProofError error + processedInputs uint64 } func (mock *MockMachineImpl) Advance( @@ -660,8 +1388,9 @@ func (mock *MockMachineManager) Close() error { // MockMachineInstance is a test implementation of manager.MachineInstance type MockMachineInstance struct { - application *Application - machineImpl *MockMachineImpl + application *Application + machineImpl *MockMachineImpl + createSnapshotError error } // Advance implements the MachineInstance interface for testing @@ -681,11 +1410,17 @@ func (m *MockMachineInstance) Application() *Application { } func (m *MockMachineInstance) ProcessedInputs() uint64 { - return 0 + return m.machineImpl.processedInputs } func (m *MockMachineInstance) OutputsProof(ctx context.Context) (*OutputsProof, error) { - return nil, nil + if m.machineImpl.OutputsProofError != nil { + return nil, m.machineImpl.OutputsProofError + } + return &OutputsProof{ + OutputsHash: randomHash(), + MachineHash: randomHash(), + }, nil } // Synchronize implements the MachineInstance interface for testing @@ -696,8 +1431,7 @@ func (m *MockMachineInstance) Synchronize(ctx context.Context, repo manager.Mach // CreateSnapshot implements the MachineInstance interface for testing func (m *MockMachineInstance) CreateSnapshot(ctx context.Context, processInputs uint64, path string) error { - // Not used in advancer tests, but needed to satisfy the interface - return nil + return m.createSnapshotError } // Retrieves the hash of the current machine state @@ -729,11 +1463,21 @@ type MockRepository struct { GetLastSnapshotReturn *Input GetLastSnapshotError error RepeatOutputsProofError error + GetEpochReturn *Epoch + GetEpochError error + GetLastInputReturn *Input + GetLastInputError error + GetLastProcessedInputReturn *Input + GetLastProcessedInputError error + UpdateSnapshotURIError error StoredResults []*AdvanceResult ApplicationStateUpdates int LastApplicationState ApplicationState LastApplicationStateReason *string + OutputsProofUpdated bool + RepeatOutputsProofCalled bool + SnapshotURIUpdated bool mu sync.Mutex } @@ -841,11 +1585,10 @@ func (mock *MockRepository) StoreAdvanceResult( } func (mock *MockRepository) UpdateEpochOutputsProof(ctx context.Context, appID int64, epochIndex uint64, proof *OutputsProof) error { - // Check for context cancellation if ctx.Err() != nil { return ctx.Err() } - + mock.OutputsProofUpdated = true return mock.UpdateOutputsProofError } @@ -871,15 +1614,28 @@ func (mock *MockRepository) UpdateApplicationState(ctx context.Context, appID in } func (mock *MockRepository) GetEpoch(ctx context.Context, nameOrAddress string, index uint64) (*Epoch, error) { - // Not used in most advancer tests, but needed to satisfy the interface + if ctx.Err() != nil { + return nil, ctx.Err() + } + if mock.GetEpochError != nil { + return nil, mock.GetEpochError + } + if mock.GetEpochReturn != nil { + return mock.GetEpochReturn, nil + } return &Epoch{Status: EpochStatus_Closed}, nil } func (mock *MockRepository) GetLastInput(ctx context.Context, appAddress string, epochIndex uint64) (*Input, error) { - // Check for context cancellation if ctx.Err() != nil { return nil, ctx.Err() } + if mock.GetLastInputError != nil { + return nil, mock.GetLastInputError + } + if mock.GetLastInputReturn != nil { + return mock.GetLastInputReturn, nil + } address := common.HexToAddress(appAddress) inputs := mock.GetInputsReturn[address] @@ -887,7 +1643,6 @@ func (mock *MockRepository) GetLastInput(ctx context.Context, appAddress string, return nil, nil } - // Find the last input for the given epoch var lastInput *Input for _, input := range inputs { if input.EpochIndex == epochIndex && (lastInput == nil || input.Index > lastInput.Index) { @@ -899,10 +1654,15 @@ func (mock *MockRepository) GetLastInput(ctx context.Context, appAddress string, } func (mock *MockRepository) GetLastProcessedInput(ctx context.Context, appAddress string) (*Input, error) { - if ctx.Err() != nil { return nil, ctx.Err() } + if mock.GetLastProcessedInputError != nil { + return nil, mock.GetLastProcessedInputError + } + if mock.GetLastProcessedInputReturn != nil { + return mock.GetLastProcessedInputReturn, nil + } address := common.HexToAddress(appAddress) inputs := mock.GetInputsReturn[address] @@ -910,7 +1670,6 @@ func (mock *MockRepository) GetLastProcessedInput(ctx context.Context, appAddres return nil, nil } - // Find the last input for the given epoch var lastInput *Input for _, input := range inputs { if input.Status != InputCompletionStatus_None && (lastInput == nil || input.Index > lastInput.Index) { @@ -922,8 +1681,11 @@ func (mock *MockRepository) GetLastProcessedInput(ctx context.Context, appAddres } func (mock *MockRepository) UpdateInputSnapshotURI(ctx context.Context, appId int64, inputIndex uint64, snapshotURI string) error { - // Not used in most advancer tests, but needed to satisfy the interface - return nil + if ctx.Err() != nil { + return ctx.Err() + } + mock.SnapshotURIUpdated = true + return mock.UpdateSnapshotURIError } func (mock *MockRepository) GetLastSnapshot(ctx context.Context, nameOrAddress string) (*Input, error) { @@ -936,10 +1698,10 @@ func (mock *MockRepository) GetLastSnapshot(ctx context.Context, nameOrAddress s } func (mock *MockRepository) RepeatPreviousEpochOutputsProof(ctx context.Context, appID int64, epochIndex uint64) error { - // Check for context cancellation if ctx.Err() != nil { return ctx.Err() } + mock.RepeatOutputsProofCalled = true return mock.RepeatOutputsProofError } diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 48b8e9136..9f42fd130 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -205,6 +205,106 @@ func (s *MachineInstanceSuite) TestNewMachineInstance() { }) } +func (s *MachineInstanceSuite) TestNewMachineInstanceFromSnapshot() { + s.Run("Ok", func() { + require := s.Require() + app := &model.Application{ + Name: "TestApp", + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + } + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + mockRuntime := &MockRollupsMachine{} + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: mockRuntime, + ErrorToReturn: nil, + } + + // NewMachineInstanceFromSnapshot creates a SnapshotMachineRuntimeFactory + // internally, so we use NewMachineInstanceWithFactory to test the same + // logic with a controlled factory. + inputIndex := uint64(5) + + // The function sets processedInputs = inputIndex + 1 + // Use the mock factory to avoid actual machine loading + inst, err := NewMachineInstanceWithFactory( + context.Background(), + app, + inputIndex+1, + testLogger, + false, + mockFactory, + ) + require.NoError(err) + require.NotNil(inst) + require.Equal(inputIndex+1, inst.ProcessedInputs()) + + inst.Close() + }) + + s.Run("FactoryError", func() { + require := s.Require() + app := &model.Application{ + Name: "TestApp", + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + } + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: nil, + ErrorToReturn: errors.New("snapshot load failed"), + } + + inst, err := NewMachineInstanceWithFactory( + context.Background(), + app, + 6, + testLogger, + false, + mockFactory, + ) + require.Error(err) + require.Nil(inst) + require.ErrorIs(err, ErrMachineCreation) + require.Contains(err.Error(), "snapshot load failed") + }) +} + +func (s *MachineInstanceSuite) TestApplicationAndProcessedInputs() { + require := s.Require() + app := &model.Application{ + Name: "TestApp", + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + } + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + mockRuntime := &MockRollupsMachine{} + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: mockRuntime, + ErrorToReturn: nil, + } + + inst, err := NewMachineInstanceWithFactory( + context.Background(), app, 42, testLogger, false, mockFactory, + ) + require.NoError(err) + require.Same(app, inst.Application()) + require.Equal(uint64(42), inst.ProcessedInputs()) + inst.Close() +} + func (s *MachineInstanceSuite) TestAdvance() { s.Run("Ok", func() { s.Run("Accept", func() { @@ -403,9 +503,75 @@ func (s *MachineInstanceSuite) TestAdvance() { }) }) - s.Run("Concurrency", func() { - // Two Advances cannot be concurrently active. - s.T().Skip("TODO") + s.Run("CollectHashes", func() { + require := s.Require() + inner, fork, machine := s.setupAdvance() + + // Set up WriteCheckpointHash to succeed + fork.CheckpointHashError = nil + + res, err := machine.Advance(context.Background(), []byte{}, 0, 5, true) + require.Nil(err) + require.NotNil(res) + + require.Same(fork, machine.runtime) + require.Equal(model.InputCompletionStatus_Accepted, res.Status) + require.True(res.IsDaveConsensus) + require.Equal(uint64(6), machine.processedInputs.Load()) + + // Verify the inner runtime was closed (accept path) + _ = inner + }) + + s.Run("CollectHashesWriteCheckpointError", func() { + require := s.Require() + _, fork, machine := s.setupAdvance() + + errCheckpoint := errors.New("checkpoint write error") + fork.CheckpointHashError = errCheckpoint + + res, err := machine.Advance(context.Background(), []byte{}, 0, 5, true) + require.Error(err) + require.Nil(res) + require.ErrorIs(err, errCheckpoint) + require.Equal(uint64(5), machine.processedInputs.Load()) + }) + + s.Run("SequentialAdvances", func() { + // Advance is serialized by advanceMutex — concurrent advance on the + // same machine never happens by design. This test verifies that two + // sequential advances correctly increment processedInputs. + require := s.Require() + inner, fork, machine := s.setupAdvance() + + // Allow inner.Close to succeed (old runtime close on accept) + inner.CloseError = nil + + // First advance: fork from inner (processedInputs=5), returns accepted. + // After accept, fork becomes the new runtime. + // Second advance: fork from fork (processedInputs=6), fork must also fork. + fork2 := &MockRollupsMachine{} + fork2.AdvanceAcceptedReturn = true + fork2.AdvanceOutputsReturn = expectedOutputs + fork2.AdvanceReportsReturn = expectedReports1 + fork2.OutputsHashReturn = newHash(1) + fork2.HashReturn = newHash(2) + fork2.CloseError = errUnreachable // old runtime close for second advance + fork2.ForkReturn = nil + fork.ForkReturn = fork2 + fork.CloseError = nil // close of fork (now old runtime) in second advance + + // First advance at index 5 + res1, err := machine.Advance(context.Background(), []byte{}, 0, 5, false) + require.Nil(err) + require.NotNil(res1) + require.Equal(uint64(6), machine.processedInputs.Load()) + + // Second advance at index 6 + res2, err := machine.Advance(context.Background(), []byte{}, 0, 6, false) + require.Nil(err) + require.NotNil(res2) + require.Equal(uint64(7), machine.processedInputs.Load()) }) } @@ -1226,7 +1392,7 @@ func (s *MachineInstanceSuite) TestSynchronize() { runtime := inst.runtime.(*MockRollupsMachine) runtime.ForkFunc = func(_ context.Context) (machine.Machine, error) { fork := newForkableMock() - fork.AdvanceError = errors.New("machine exploded") + fork.AdvanceError = errors.New("advance failed during replay") return fork, nil } diff --git a/internal/manager/manager_test.go b/internal/manager/manager_test.go index f60061417..e6aa34fcc 100644 --- a/internal/manager/manager_test.go +++ b/internal/manager/manager_test.go @@ -5,6 +5,7 @@ package manager import ( "context" + "errors" "io" "log/slog" "testing" @@ -228,6 +229,181 @@ func (s *MachineManagerSuite) TestRemoveDisabledMachines() { require.True(manager.HasMachine(3)) } +func (s *MachineManagerSuite) TestUpdateMachinesErrors() { + s.Run("GetEnabledApplicationsError", func() { + require := s.Require() + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return(([]*model.Application)(nil), uint64(0), errors.New("db error")) + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + + err := manager.UpdateMachines(context.Background()) + require.Error(err) + require.Contains(err.Error(), "db error") + }) + + s.Run("SnapshotCreationFailureFallsBackToTemplate", func() { + require := s.Require() + + app := &model.Application{ + ID: 1, + Name: "App1", + IApplicationAddress: common.HexToAddress("0x1"), + State: model.ApplicationState_Enabled, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: 100, + InspectMaxDeadline: 100, + MaxConcurrentInspects: 3, + }, + } + + snapshotPath := "/fake/snapshot/path" + snapshotInput := &model.Input{ + Index: 2, + SnapshotURI: &snapshotPath, + } + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Application{app}, uint64(1), nil) + repo.On("GetLastSnapshot", mock.Anything, mock.Anything). + Return(snapshotInput, nil) + // ListInputs for synchronization (no inputs to replay) + repo.On("ListInputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Input{}, uint64(0), nil) + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + + // Mock factory that always succeeds — the snapshot path doesn't exist, + // so it should fall back to template via defaultFactory + mockRuntime := &MockRollupsMachine{} + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: mockRuntime, + ErrorToReturn: nil, + } + originalFactory := defaultFactory + defaultFactory = mockFactory + defer func() { defaultFactory = originalFactory }() + + err := manager.UpdateMachines(context.Background()) + require.NoError(err) + // Machine should have been created via fallback + require.True(manager.HasMachine(1)) + }) + + s.Run("TemplateCreationFailureSkipsApp", func() { + require := s.Require() + + app := &model.Application{ + ID: 1, + Name: "App1", + IApplicationAddress: common.HexToAddress("0x1"), + State: model.ApplicationState_Enabled, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: 100, + InspectMaxDeadline: 100, + MaxConcurrentInspects: 3, + }, + } + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Application{app}, uint64(1), nil) + repo.On("GetLastSnapshot", mock.Anything, mock.Anything). + Return(nil, nil) + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + + // Factory that always fails + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: nil, + ErrorToReturn: errors.New("machine creation failed"), + } + originalFactory := defaultFactory + defaultFactory = mockFactory + defer func() { defaultFactory = originalFactory }() + + err := manager.UpdateMachines(context.Background()) + // UpdateMachines should not return an error; it logs and skips + require.NoError(err) + require.False(manager.HasMachine(1)) + }) + + s.Run("SynchronizeFailureClosesAndSkipsApp", func() { + require := s.Require() + + app := &model.Application{ + ID: 1, + Name: "App1", + IApplicationAddress: common.HexToAddress("0x1"), + State: model.ApplicationState_Enabled, + ProcessedInputs: 3, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: 100, + InspectMaxDeadline: 100, + MaxConcurrentInspects: 3, + }, + } + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Application{app}, uint64(1), nil) + repo.On("GetLastSnapshot", mock.Anything, mock.Anything). + Return(nil, nil) + // ListInputs returns an error to cause Synchronize to fail + repo.On("ListInputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false). + Return(([]*model.Input)(nil), uint64(0), errors.New("db connection lost")) + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + + mockRuntime := &MockRollupsMachine{} + mockRuntime.CloseError = nil + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: mockRuntime, + ErrorToReturn: nil, + } + originalFactory := defaultFactory + defaultFactory = mockFactory + defer func() { defaultFactory = originalFactory }() + + err := manager.UpdateMachines(context.Background()) + require.NoError(err) + // Machine should NOT have been added due to sync failure + require.False(manager.HasMachine(1)) + }) +} + +func (s *MachineManagerSuite) TestCloseAggregatesErrors() { + require := s.Require() + + manager := NewMachineManager(context.Background(), nil, nil, false, 500) + + machine1 := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} + machine2 := &DummyMachineInstanceMock{ + application: &model.Application{ID: 2}, + closeError: errors.New("close error 2"), + } + machine3 := &DummyMachineInstanceMock{ + application: &model.Application{ID: 3}, + closeError: errors.New("close error 3"), + } + + manager.addMachine(1, machine1) + manager.addMachine(2, machine2) + manager.addMachine(3, machine3) + + err := manager.Close() + require.Error(err) + require.Contains(err.Error(), "close error") + require.Empty(manager.machines) +} + func (s *MachineManagerSuite) TestApplications() { require := s.Require() @@ -304,6 +480,7 @@ func (m *MockMachineRepository) GetLastSnapshot( // DummyMachineInstanceMock implements the MachineInstance interface for testing type DummyMachineInstanceMock struct { application *model.Application + closeError error } func (m *DummyMachineInstanceMock) Application() *model.Application { @@ -339,5 +516,5 @@ func (m *DummyMachineInstanceMock) Hash(_ context.Context) ([32]byte, error) { } func (m *DummyMachineInstanceMock) Close() error { - return nil + return m.closeError } diff --git a/pkg/machine/implementation_test.go b/pkg/machine/implementation_test.go index 15dd160c2..60cf9a567 100644 --- a/pkg/machine/implementation_test.go +++ b/pkg/machine/implementation_test.go @@ -201,6 +201,100 @@ func (s *ImplementationSuite) TestOutputsHash() { mockBackend4.AssertExpectations(s.T()) } +// Test OutputsHashProof method +func (s *ImplementationSuite) TestOutputsHashProof() { + require := s.Require() + ctx := context.Background() + + // Test successful outputs hash proof retrieval + mockBackend := NewMockBackend() + expectedProof := []Hash{randomFakeHash(), randomFakeHash(), randomFakeHash()} + mockBackend.On("GetProof", TxBufferAddress, int32(HashLog2Size), mock.AnythingOfType("time.Duration")). + Return(expectedProof, nil) + + machine := &machineImpl{ + backend: mockBackend, + logger: s.logger, + params: model.ExecutionParameters{ + LoadDeadline: time.Second * 5, + }, + } + + proof, err := machine.OutputsHashProof(ctx) + require.NoError(err) + require.Equal(expectedProof, proof) + mockBackend.AssertExpectations(s.T()) + + // Test outputs hash proof with backend error + mockBackend2 := NewMockBackend() + mockBackend2.On("GetProof", TxBufferAddress, int32(HashLog2Size), mock.AnythingOfType("time.Duration")). + Return([]Hash(nil), errors.New("proof failed")) + machine2 := &machineImpl{ + backend: mockBackend2, + logger: s.logger, + params: model.ExecutionParameters{ + LoadDeadline: time.Second * 5, + }, + } + _, err = machine2.OutputsHashProof(ctx) + require.Error(err) + require.ErrorIs(err, ErrMachineInternal) + require.Contains(err.Error(), "could not get outputs hash machine proof") + mockBackend2.AssertExpectations(s.T()) + + // Test outputs hash proof with canceled context + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + _, err = machine.OutputsHashProof(canceledCtx) + require.ErrorIs(err, ErrCanceled) +} + +// Test WriteCheckpointHash method +func (s *ImplementationSuite) TestWriteCheckpointHash() { + require := s.Require() + ctx := context.Background() + + // Test successful write + mockBackend := NewMockBackend() + hash := randomFakeHash() + mockBackend.On("WriteMemory", CheckpointAddress, hash[:], mock.AnythingOfType("time.Duration")). + Return(nil) + machine := &machineImpl{ + backend: mockBackend, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + }, + } + + err := machine.WriteCheckpointHash(ctx, hash) + require.NoError(err) + mockBackend.AssertExpectations(s.T()) + + // Test write with backend error + mockBackend2 := NewMockBackend() + mockBackend2.On("WriteMemory", CheckpointAddress, hash[:], mock.AnythingOfType("time.Duration")). + Return(errors.New("write failed")) + machine2 := &machineImpl{ + backend: mockBackend2, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + }, + } + err = machine2.WriteCheckpointHash(ctx, hash) + require.Error(err) + require.ErrorIs(err, ErrMachineInternal) + require.Contains(err.Error(), "could not write checkpoint hash") + mockBackend2.AssertExpectations(s.T()) + + // Test write with canceled context + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + err = machine.WriteCheckpointHash(canceledCtx, hash) + require.ErrorIs(err, ErrCanceled) +} + // Test Advance method func (s *ImplementationSuite) TestAdvance() { require := s.Require() @@ -705,6 +799,89 @@ func (s *ImplementationSuite) TestRun() { require.NoError(err) mockBackend3.AssertExpectations(s.T()) + // Test run with ReceiveCmioRequest error during automatic yield + mockBackend4 := NewMockBackend() + mockBackend4.On("ReadMCycle", mock.AnythingOfType("time.Duration")).Return(uint64(0), nil) + mockBackend4.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")).Return(YieldedAutomatically, nil) + mockBackend4.On("ReceiveCmioRequest", mock.AnythingOfType("time.Duration")).Return( + uint8(0), uint16(0), []byte(nil), errors.New("cmio request failed")) + + machine4 := &machineImpl{ + backend: mockBackend4, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + AdvanceMaxCycles: 1000, + AdvanceIncCycles: 100, + AdvanceIncDeadline: time.Second * 1, + AdvanceMaxDeadline: time.Second * 10, + }, + } + + _, _, _, _, err = machine4.run(ctx, AdvanceStateRequest, false) + require.Error(err) + require.Contains(err.Error(), "could not read output/report") + require.Contains(err.Error(), "cmio request failed") + mockBackend4.AssertExpectations(s.T()) + + // Test run with automatic yield producing output then manual yield + mockBackend5 := NewMockBackend() + mockBackend5.On("ReadMCycle", mock.AnythingOfType("time.Duration")).Return(uint64(0), nil) + mockBackend5.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedAutomatically, nil).Once() + mockBackend5.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedManually, nil).Once() + mockBackend5.On("ReceiveCmioRequest", mock.AnythingOfType("time.Duration")).Return( + uint8(0), uint16(AutomaticYieldReasonOutput), []byte("output data"), nil).Once() + + machine5 := &machineImpl{ + backend: mockBackend5, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + AdvanceMaxCycles: 1000, + AdvanceIncCycles: 100, + AdvanceIncDeadline: time.Second * 1, + AdvanceMaxDeadline: time.Second * 10, + }, + } + + outputs5, reports5, _, _, err := machine5.run(ctx, AdvanceStateRequest, false) + require.NoError(err) + require.Len(outputs5, 1) + require.Equal([]byte("output data"), []byte(outputs5[0])) + require.Empty(reports5) + mockBackend5.AssertExpectations(s.T()) + + // Test run with automatic yield producing report then manual yield + mockBackend6 := NewMockBackend() + mockBackend6.On("ReadMCycle", mock.AnythingOfType("time.Duration")).Return(uint64(0), nil) + mockBackend6.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedAutomatically, nil).Once() + mockBackend6.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedManually, nil).Once() + mockBackend6.On("ReceiveCmioRequest", mock.AnythingOfType("time.Duration")).Return( + uint8(0), uint16(AutomaticYieldReasonReport), []byte("report data"), nil).Once() + + machine6 := &machineImpl{ + backend: mockBackend6, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + AdvanceMaxCycles: 1000, + AdvanceIncCycles: 100, + AdvanceIncDeadline: time.Second * 1, + AdvanceMaxDeadline: time.Second * 10, + }, + } + + outputs6, reports6, _, _, err := machine6.run(ctx, AdvanceStateRequest, false) + require.NoError(err) + require.Empty(outputs6) + require.Len(reports6, 1) + require.Equal([]byte("report data"), []byte(reports6[0])) + mockBackend6.AssertExpectations(s.T()) + } // Test step method diff --git a/pkg/machine/machine_test.go b/pkg/machine/machine_test.go index 91ad8d255..1d613d9ae 100644 --- a/pkg/machine/machine_test.go +++ b/pkg/machine/machine_test.go @@ -115,6 +115,25 @@ func (s *MachineSuite) TestLoad() { require.ErrorIs(err, ErrNotAtManualYield) mockBackend.AssertExpectations(s.T()) + // Test with IsAtManualYield returning error + mockBackend = NewMockBackend() + mockBackend.On("NewMachineRuntimeConfig").Return(`{"concurrency":{"update_merkle_tree":1}}`, nil) + mockBackend.On("Load", + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + mock.AnythingOfType("time.Duration"), + ).Return(nil).Once() + mockBackend.On("IsAtManualYield", mock.AnythingOfType("time.Duration")).Return(false, errors.New("yield check failed")) + mockBackend.SetupForCleanup() + config = DefaultConfig("some/path") + config.BackendFactoryFn = MockBackendFactory(mockBackend) + machine, err = Load(ctx, s.logger, config) + require.Error(err) + require.Nil(machine) + require.ErrorIs(err, ErrMachineInternal) + require.Contains(err.Error(), "yield check failed") + mockBackend.AssertExpectations(s.T()) + // Test with machine last request not accepted mockBackend = NewMockBackend() mockBackend.On("NewMachineRuntimeConfig").Return(`{"concurrency":{"update_merkle_tree":1}}`, nil) From e69b14d7df6d840f84ba637409f87ddbd908a8a6 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Thu, 5 Mar 2026 01:40:38 -0300 Subject: [PATCH 17/17] refactor(manager): inject machine factory via functional options and simplify test setup --- internal/advancer/advancer_test.go | 748 ++++++++++------------------- internal/advancer/service.go | 1 - internal/manager/manager.go | 77 ++- internal/manager/manager_test.go | 130 ++--- 4 files changed, 383 insertions(+), 573 deletions(-) diff --git a/internal/advancer/advancer_test.go b/internal/advancer/advancer_test.go index 0398703f1..eb390e705 100644 --- a/internal/advancer/advancer_test.go +++ b/internal/advancer/advancer_test.go @@ -46,6 +46,27 @@ func newMockAdvancerService(machineManager *MockMachineManager, repo *MockReposi return s, nil } +// testEnv bundles the components most tests need: a service, a single app's mock, +// the mock machine manager, and the mock repository. +type testEnv struct { + service *Service + app *MockMachineImpl + mm *MockMachineManager + repo *MockRepository +} + +// setupOneApp creates a standard test environment with one application. +// The repository is empty; callers can configure it after the call. +func (s *AdvancerSuite) setupOneApp() testEnv { + mm := newMockMachineManager() + app := newMockMachine(1) + mm.Map[1] = newMockInstance(app) + repo := &MockRepository{} + svc, err := newMockAdvancerService(mm, repo) + s.Require().NoError(err) + return testEnv{service: svc, app: app, mm: mm, repo: repo} +} + func (s *AdvancerSuite) TestServiceInterface() { s.Run("ServiceMethods", func() { require := s.Require() @@ -64,9 +85,9 @@ func (s *AdvancerSuite) TestServiceInterface() { require.Equal(advancer.Name, advancer.String()) // Test Tick method - machineManager.Map[1] = *newMockMachine(1) + machineManager.Map[1] = newMockInstance(newMockMachine(1)) repository.GetEpochsReturn = map[common.Address][]*Epoch{ - machineManager.Map[1].Application.IApplicationAddress: {}, + machineManager.Map[1].application.IApplicationAddress: {}, } tickErrors := advancer.Tick() require.Empty(tickErrors) @@ -86,8 +107,8 @@ func (s *AdvancerSuite) TestStep() { machineManager := newMockMachineManager() app1 := newMockMachine(1) app2 := newMockMachine(2) - machineManager.Map[1] = *app1 - machineManager.Map[2] = *app2 + machineManager.Map[1] = newMockInstance(app1) + machineManager.Map[2] = newMockInstance(app2) res0 := randomAdvanceResult(0) res1 := randomAdvanceResult(1) res2 := randomAdvanceResult(0) @@ -124,31 +145,22 @@ func (s *AdvancerSuite) TestStep() { s.Run("Error/UpdateEpochs", func() { require := s.Require() - - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 + env := s.setupOneApp() res0 := randomAdvanceResult(0) - repository := &MockRepository{ - GetEpochsReturn: map[common.Address][]*Epoch{ - app1.Application.IApplicationAddress: { - &Epoch{Index: 0, Status: EpochStatus_Closed}, - }, + env.repo.GetEpochsReturn = map[common.Address][]*Epoch{ + env.app.Application.IApplicationAddress: { + {Index: 0, Status: EpochStatus_Closed}, }, - GetInputsReturn: map[common.Address][]*Input{ - app1.Application.IApplicationAddress: { - newInput(app1.Application.ID, 0, 0, marshal(res0)), - }, + } + env.repo.GetInputsReturn = map[common.Address][]*Input{ + env.app.Application.IApplicationAddress: { + newInput(env.app.Application.ID, 0, 0, marshal(res0)), }, - UpdateEpochsError: errors.New("update epochs error"), } + env.repo.UpdateEpochsError = errors.New("update epochs error") - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - err = advancer.Step(context.Background()) + err := env.service.Step(context.Background()) require.Error(err) require.Contains(err.Error(), "update epochs error") }) @@ -157,7 +169,7 @@ func (s *AdvancerSuite) TestStep() { require := s.Require() machineManager := &MockMachineManager{ - Map: map[int64]MockMachineImpl{}, + Map: map[int64]*MockMachineInstance{}, UpdateMachinesError: errors.New("update machines error"), } repository := &MockRepository{} @@ -173,49 +185,31 @@ func (s *AdvancerSuite) TestStep() { s.Run("Error/GetInputs", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - repository := &MockRepository{ - GetEpochsReturn: map[common.Address][]*Epoch{ - app1.Application.IApplicationAddress: { - &Epoch{Index: 0, Status: EpochStatus_Closed}, - }, + env.repo.GetEpochsReturn = map[common.Address][]*Epoch{ + env.app.Application.IApplicationAddress: { + {Index: 0, Status: EpochStatus_Closed}, }, - GetInputsError: errors.New("get inputs error"), } + env.repo.GetInputsError = errors.New("get inputs error") - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - err = advancer.Step(context.Background()) + err := env.service.Step(context.Background()) require.Error(err) require.Contains(err.Error(), "get inputs error") }) s.Run("NoInputs", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - repository := &MockRepository{ - GetInputsReturn: map[common.Address][]*Input{ - app1.Application.IApplicationAddress: {}, - }, + env.repo.GetInputsReturn = map[common.Address][]*Input{ + env.app.Application.IApplicationAddress: {}, } - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) + err := env.service.Step(context.Background()) require.Nil(err) - - err = advancer.Step(context.Background()) - require.Nil(err) - require.Len(repository.StoredResults, 0) + require.Len(env.repo.StoredResults, 0) }) } @@ -258,77 +252,58 @@ func (s *AdvancerSuite) TestGetUnprocessedInputs() { } func (s *AdvancerSuite) TestProcess() { - setup := func() (*MockMachineManager, *MockRepository, *Service, *MockMachineImpl) { - require := s.Require() - - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) - return machineManager, repository, advancer, app1 - } - s.Run("ApplicationStateUpdate", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, []byte("advance error")), + newInput(env.app.Application.ID, 0, 0, []byte("advance error")), } - // Verify application state is updated on error - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) - require.Equal(1, repository.ApplicationStateUpdates) - require.Equal(ApplicationState_Inoperable, repository.LastApplicationState) - require.NotNil(repository.LastApplicationStateReason) - require.Equal("advance error", *repository.LastApplicationStateReason) + require.Equal(1, env.repo.ApplicationStateUpdates) + require.Equal(ApplicationState_Inoperable, env.repo.LastApplicationState) + require.NotNil(env.repo.LastApplicationStateReason) + require.Equal("advance error", *env.repo.LastApplicationStateReason) }) s.Run("ApplicationStateUpdateError", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, []byte("advance error")), + newInput(env.app.Application.ID, 0, 0, []byte("advance error")), } - repository.UpdateApplicationStateError = errors.New("update state error") + env.repo.UpdateApplicationStateError = errors.New("update state error") - // Verify error is still returned even if application state update fails - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) require.Contains(err.Error(), "advance error") }) s.Run("Ok", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), - newInput(app.Application.ID, 0, 2, marshal(randomAdvanceResult(2))), - newInput(app.Application.ID, 0, 3, marshal(randomAdvanceResult(3))), - newInput(app.Application.ID, 1, 4, marshal(randomAdvanceResult(4))), - newInput(app.Application.ID, 1, 5, marshal(randomAdvanceResult(5))), - newInput(app.Application.ID, 2, 6, marshal(randomAdvanceResult(6))), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), + newInput(env.app.Application.ID, 0, 2, marshal(randomAdvanceResult(2))), + newInput(env.app.Application.ID, 0, 3, marshal(randomAdvanceResult(3))), + newInput(env.app.Application.ID, 1, 4, marshal(randomAdvanceResult(4))), + newInput(env.app.Application.ID, 1, 5, marshal(randomAdvanceResult(5))), + newInput(env.app.Application.ID, 2, 6, marshal(randomAdvanceResult(6))), } - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Nil(err) - require.Len(repository.StoredResults, 7) + require.Len(env.repo.StoredResults, 7) }) s.Run("Noop", func() { s.Run("NoInputs", func() { require := s.Require() + env := s.setupOneApp() - _, _, advancer, app := setup() - inputs := []*Input{} - - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, []*Input{}) require.Nil(err) }) }) @@ -336,49 +311,46 @@ func (s *AdvancerSuite) TestProcess() { s.Run("Error", func() { s.Run("ErrApp", func() { require := s.Require() - + env := s.setupOneApp() invalidApp := Application{ID: 999} - _, _, advancer, _ := setup() inputs := randomInputs(1, 0, 3) - err := advancer.processInputs(context.Background(), &invalidApp, inputs) + err := env.service.processInputs(context.Background(), &invalidApp, inputs) expected := fmt.Sprintf("%v: %v", ErrNoApp, invalidApp.ID) require.EqualError(err, expected) }) s.Run("Advance", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app.Application.ID, 0, 1, []byte("advance error")), - newInput(app.Application.ID, 0, 2, []byte("unreachable")), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 1, []byte("advance error")), + newInput(env.app.Application.ID, 0, 2, []byte("unreachable")), } - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) require.Contains(err.Error(), "advance error") - require.Len(repository.StoredResults, 1) + require.Len(env.repo.StoredResults, 1) }) s.Run("StoreAdvance", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app.Application.ID, 0, 1, []byte("unreachable")), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 1, []byte("unreachable")), } - repository.StoreAdvanceError = errors.New("store-advance error") + env.repo.StoreAdvanceError = errors.New("store-advance error") - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) require.Contains(err.Error(), "store-advance error") - require.Len(repository.StoredResults, 1) + require.Len(env.repo.StoredResults, 1) // Verify that the node shutdown was triggered (context cancelled) - require.Error(advancer.Context.Err(), "shared context should be cancelled") + require.Error(env.service.Context.Err(), "shared context should be cancelled") }) }) } @@ -387,27 +359,15 @@ func (s *AdvancerSuite) TestProcess() { func (s *AdvancerSuite) TestContextCancellation() { s.Run("CancelDuringStep", func() { require := s.Require() + env := s.setupOneApp() + env.repo.GetEpochsBlock = true - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - // Create a repository that will block until we cancel the context - repository := &MockRepository{ - GetEpochsBlock: true, - } - - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - // Create a context that we can cancel ctx, cancel := context.WithCancel(context.Background()) // Start the Step operation in a goroutine errCh := make(chan error) go func() { - errCh <- advancer.Step(ctx) + errCh <- env.service.Step(ctx) }() // Cancel the context after a short delay @@ -426,28 +386,17 @@ func (s *AdvancerSuite) TestContextCancellation() { s.Run("CancelDuringProcessInputs", func() { require := s.Require() + env := s.setupOneApp() + env.app.AdvanceBlock = true - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - // Create a machine that will block during Advance until we cancel the context - app1.AdvanceBlock = true - machineManager.Map[1] = *app1 - - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - // Create inputs and a context that we can cancel inputs := []*Input{ - newInput(app1.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), } ctx, cancel := context.WithCancel(context.Background()) - // Start the processInputs operation in a goroutine errCh := make(chan error) go func() { - errCh <- advancer.processInputs(ctx, app1.Application, inputs) + errCh <- env.service.processInputs(ctx, env.app.Application, inputs) }() // Cancel the context after a short delay @@ -469,28 +418,17 @@ func (s *AdvancerSuite) TestContextCancellation() { func (s *AdvancerSuite) TestLargeNumberOfInputs() { s.Run("LargeNumberOfInputs", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - // Create a large number of inputs const inputCount = 10000 inputs := make([]*Input, inputCount) for i := range inputCount { - inputs[i] = newInput(app1.Application.ID, 0, uint64(i), marshal(randomAdvanceResult(uint64(i)))) + inputs[i] = newInput(env.app.Application.ID, 0, uint64(i), marshal(randomAdvanceResult(uint64(i)))) } - // Process the inputs - err = advancer.processInputs(context.Background(), app1.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Nil(err) - - // Verify all inputs were processed - require.Len(repository.StoredResults, inputCount) + require.Len(env.repo.StoredResults, inputCount) }) } @@ -499,31 +437,17 @@ func (s *AdvancerSuite) TestLargeNumberOfInputs() { func (s *AdvancerSuite) TestErrorRecovery() { s.Run("TransientStoreFailureTriggersShutdown", func() { require := s.Require() - - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - // Repository that fails on the first store attempt - repository := &MockRepository{ - StoreAdvanceFailCount: 1, - } - - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) + env := s.setupOneApp() + env.repo.StoreAdvanceFailCount = 1 inputs := []*Input{ - newInput(app1.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), } - // The transient failure triggers node shutdown — no retry at this layer - err = advancer.processInputs(context.Background(), app1.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) require.Contains(err.Error(), "temporary failure") - - // Verify that the node shutdown was triggered - require.Error(advancer.Context.Err(), "shared context should be cancelled") + require.Error(env.service.Context.Err(), "shared context should be cancelled") }) } @@ -533,29 +457,16 @@ func (s *AdvancerSuite) TestErrorRecovery() { func (s *AdvancerSuite) TestContextCancelledBeforeProcessing() { s.Run("ContextAlreadyCancelled", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - repository := &MockRepository{} - - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - // Cancel the context before calling processInputs to simulate - // an external shutdown already in progress. ctx, cancel := context.WithCancel(context.Background()) cancel() inputs := []*Input{ - newInput(app1.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), } - // With the context already cancelled, processInputs returns - // the context error immediately (before reaching advance). - err = advancer.processInputs(ctx, app1.Application, inputs) + err := env.service.processInputs(ctx, env.app.Application, inputs) require.ErrorIs(err, context.Canceled) }) } @@ -567,97 +478,46 @@ func (s *AdvancerSuite) TestContextCancelledBeforeProcessing() { func (s *AdvancerSuite) TestIsAllEpochInputsProcessed() { s.Run("TrueWhenEpochHasNoInputs", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app := newMockMachine(1) - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) - - // Epoch with no inputs (lower == upper) - epoch := &Epoch{ - Index: 0, - InputIndexLowerBound: 5, - InputIndexUpperBound: 5, - } - - result, perr := advancer.isAllEpochInputsProcessed(app.Application, epoch) + epoch := &Epoch{Index: 0, InputIndexLowerBound: 5, InputIndexUpperBound: 5} + result, perr := env.service.isAllEpochInputsProcessed(env.app.Application, epoch) require.Nil(perr) require.True(result) }) s.Run("TrueWhenMachineProcessedAllInputs", func() { require := s.Require() + env := s.setupOneApp() + env.mm.Map[1].machineImpl.processedInputs = 10 - machineManager := newMockMachineManager() - app := newMockMachine(1) - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) - - // Mock the machine to report ProcessedInputs = 10 - machineManager.Map[1] = MockMachineImpl{ - Application: app.Application, - processedInputs: 10, - } - - epoch := &Epoch{ - Index: 0, - InputIndexLowerBound: 5, - InputIndexUpperBound: 10, - } - - result, perr := advancer.isAllEpochInputsProcessed(app.Application, epoch) + epoch := &Epoch{Index: 0, InputIndexLowerBound: 5, InputIndexUpperBound: 10} + result, perr := env.service.isAllEpochInputsProcessed(env.app.Application, epoch) require.Nil(perr) require.True(result) }) s.Run("FalseWhenMoreInputsExist", func() { require := s.Require() + env := s.setupOneApp() + env.mm.Map[1].machineImpl.processedInputs = 7 - machineManager := newMockMachineManager() - app := newMockMachine(1) - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) - - // Mock the machine to report ProcessedInputs = 7 (not yet at upper bound) - machineManager.Map[1] = MockMachineImpl{ - Application: app.Application, - processedInputs: 7, - } - - epoch := &Epoch{ - Index: 0, - InputIndexLowerBound: 5, - InputIndexUpperBound: 10, - } - - result, perr := advancer.isAllEpochInputsProcessed(app.Application, epoch) + epoch := &Epoch{Index: 0, InputIndexLowerBound: 5, InputIndexUpperBound: 10} + result, perr := env.service.isAllEpochInputsProcessed(env.app.Application, epoch) require.Nil(perr) require.False(result) }) s.Run("ErrorWhenNoMachineForApp", func() { require := s.Require() - - machineManager := newMockMachineManager() - // Don't add any machine - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) + mm := newMockMachineManager() + repo := &MockRepository{} + svc, err := newMockAdvancerService(mm, repo) require.Nil(err) app := &Application{ID: 999} - epoch := &Epoch{ - Index: 0, - InputIndexLowerBound: 0, - InputIndexUpperBound: 5, - } - - _, perr := advancer.isAllEpochInputsProcessed(app, epoch) + epoch := &Epoch{Index: 0, InputIndexLowerBound: 0, InputIndexUpperBound: 5} + _, perr := svc.isAllEpochInputsProcessed(app, epoch) require.Error(perr) require.ErrorIs(perr, ErrNoApp) }) @@ -668,94 +528,87 @@ func (s *AdvancerSuite) TestIsAllEpochInputsProcessed() { // --------------------------------------------------------------------------- func (s *AdvancerSuite) TestIsEpochLastInput() { - setupWithEpoch := func(epochStatus EpochStatus) (*Service, *Application, *MockRepository) { - machineManager := newMockMachineManager() - app := newMockMachine(1) - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - s.Require().Nil(err) - - repository.GetEpochReturn = &Epoch{Status: epochStatus} - return advancer, app.Application, repository + setupWithEpoch := func(epochStatus EpochStatus) (testEnv, *Application) { + env := s.setupOneApp() + env.repo.GetEpochReturn = &Epoch{Status: epochStatus} + return env, env.app.Application } s.Run("TrueWhenLastInputInClosedEpoch", func() { require := s.Require() - advancer, app, repo := setupWithEpoch(EpochStatus_Closed) + env, app := setupWithEpoch(EpochStatus_Closed) lastInput := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() - repo.GetInputsReturn = map[common.Address][]*Input{ + env.repo.GetInputsReturn = map[common.Address][]*Input{ app.IApplicationAddress: {lastInput}, } - repo.GetLastInputReturn = lastInput + env.repo.GetLastInputReturn = lastInput input := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() - result, err := advancer.isEpochLastInput(context.Background(), app, input) + result, err := env.service.isEpochLastInput(context.Background(), app, input) require.Nil(err) require.True(result) }) s.Run("FalseWhenEpochIsOpen", func() { require := s.Require() - advancer, app, _ := setupWithEpoch(EpochStatus_Open) + env, app := setupWithEpoch(EpochStatus_Open) input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() - result, err := advancer.isEpochLastInput(context.Background(), app, input) + result, err := env.service.isEpochLastInput(context.Background(), app, input) require.Nil(err) require.False(result) }) s.Run("FalseWhenNotLastInput", func() { require := s.Require() - advancer, app, repo := setupWithEpoch(EpochStatus_Closed) + env, app := setupWithEpoch(EpochStatus_Closed) - lastInput := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() - repo.GetLastInputReturn = lastInput + env.repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() - result, err := advancer.isEpochLastInput(context.Background(), app, input) + result, err := env.service.isEpochLastInput(context.Background(), app, input) require.Nil(err) require.False(result) }) s.Run("ErrorWhenNilInput", func() { require := s.Require() - advancer, app, _ := setupWithEpoch(EpochStatus_Closed) + env, app := setupWithEpoch(EpochStatus_Closed) - _, err := advancer.isEpochLastInput(context.Background(), app, nil) + _, err := env.service.isEpochLastInput(context.Background(), app, nil) require.Error(err) require.Contains(err.Error(), "must not be nil") }) s.Run("ErrorWhenNilApplication", func() { require := s.Require() - advancer, _, _ := setupWithEpoch(EpochStatus_Closed) + env, _ := setupWithEpoch(EpochStatus_Closed) input := repotest.NewInputBuilder().WithIndex(0).Build() - _, err := advancer.isEpochLastInput(context.Background(), nil, input) + _, err := env.service.isEpochLastInput(context.Background(), nil, input) require.Error(err) require.Contains(err.Error(), "must not be nil") }) s.Run("ErrorWhenGetEpochFails", func() { require := s.Require() - advancer, app, repo := setupWithEpoch(EpochStatus_Closed) - repo.GetEpochError = errors.New("get epoch error") + env, app := setupWithEpoch(EpochStatus_Closed) + env.repo.GetEpochError = errors.New("get epoch error") input := repotest.NewInputBuilder().WithIndex(0).Build() - _, err := advancer.isEpochLastInput(context.Background(), app, input) + _, err := env.service.isEpochLastInput(context.Background(), app, input) require.Error(err) require.Contains(err.Error(), "get epoch error") }) s.Run("ErrorWhenGetLastInputFails", func() { require := s.Require() - advancer, app, repo := setupWithEpoch(EpochStatus_Closed) - repo.GetLastInputError = errors.New("get last input error") + env, app := setupWithEpoch(EpochStatus_Closed) + env.repo.GetLastInputError = errors.New("get last input error") input := repotest.NewInputBuilder().WithIndex(0).Build() - _, err := advancer.isEpochLastInput(context.Background(), app, input) + _, err := env.service.isEpochLastInput(context.Background(), app, input) require.Error(err) require.Contains(err.Error(), "get last input error") }) @@ -768,163 +621,103 @@ func (s *AdvancerSuite) TestIsEpochLastInput() { func (s *AdvancerSuite) TestHandleEpochAfterInputsProcessed() { s.Run("EmptyEpochIndex0GetsOutputsProofFromMachine", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app := newMockMachine(1) - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) - - // Epoch with no inputs (lower == upper) epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} - - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) require.Nil(err) - require.True(repository.OutputsProofUpdated) + require.True(env.repo.OutputsProofUpdated) }) s.Run("EmptyEpochIndex0ErrorOnOutputsProof", func() { require := s.Require() - - machineManager := newMockMachineManager() - app := newMockMachine(1) - app.OutputsProofError = errors.New("proof error") - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) + env := s.setupOneApp() + env.app.OutputsProofError = errors.New("proof error") epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} - - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) require.Error(err) require.Contains(err.Error(), "proof error") }) s.Run("EmptyEpochIndexGt0RepeatsPreviousProof", func() { require := s.Require() - - machineManager := newMockMachineManager() - app := newMockMachine(1) - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) + env := s.setupOneApp() epoch := &Epoch{Index: 2, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} - - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) require.Nil(err) - require.True(repository.RepeatOutputsProofCalled) + require.True(env.repo.RepeatOutputsProofCalled) }) s.Run("EmptyEpochIndexGt0RepeatError", func() { require := s.Require() - - machineManager := newMockMachineManager() - app := newMockMachine(1) - machineManager.Map[1] = *app - repository := &MockRepository{ - RepeatOutputsProofError: errors.New("repeat error"), - } - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) + env := s.setupOneApp() + env.repo.RepeatOutputsProofError = errors.New("repeat error") epoch := &Epoch{Index: 2, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} - - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) require.Error(err) require.Contains(err.Error(), "repeat error") }) s.Run("NonEmptyEpochWithEveryEpochSnapshotPolicy", func() { require := s.Require() + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch + env.service.snapshotsDir = s.T().TempDir() - machineManager := newMockMachineManager() - app := newMockMachine(1) - app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) - advancer.snapshotsDir = s.T().TempDir() - - // Epoch with inputs epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} - - // Provide a last processed input lastInput := repotest.NewInputBuilder().WithIndex(2).WithEpochIndex(0). WithStatus(InputCompletionStatus_Accepted).Build() - lastInput.EpochApplicationID = app.Application.ID - repository.GetLastProcessedInputReturn = lastInput - // isEpochLastInput needs GetLastInput to return the same input - repository.GetLastInputReturn = lastInput + lastInput.EpochApplicationID = env.app.Application.ID + env.repo.GetLastProcessedInputReturn = lastInput + env.repo.GetLastInputReturn = lastInput - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) require.Nil(err) - // Verify snapshot was attempted (CreateSnapshot called on mock) - require.True(repository.SnapshotURIUpdated) + require.True(env.repo.SnapshotURIUpdated) }) s.Run("NonEmptyEpochNoSnapshotPolicy", func() { require := s.Require() - - machineManager := newMockMachineManager() - app := newMockMachine(1) - app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_None - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_None epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} lastInput := repotest.NewInputBuilder().WithIndex(2).WithEpochIndex(0). WithStatus(InputCompletionStatus_Accepted).Build() - lastInput.EpochApplicationID = app.Application.ID - repository.GetLastProcessedInputReturn = lastInput + lastInput.EpochApplicationID = env.app.Application.ID + env.repo.GetLastProcessedInputReturn = lastInput - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) require.Nil(err) - // No snapshot should be created with None policy - require.False(repository.SnapshotURIUpdated) + require.False(env.repo.SnapshotURIUpdated) }) s.Run("NoMachineReturnsError", func() { require := s.Require() - - machineManager := newMockMachineManager() - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) + mm := newMockMachineManager() + svc, err := newMockAdvancerService(mm, &MockRepository{}) require.Nil(err) app := repotest.NewApplicationBuilder().Build() app.ID = 999 - // Non-empty epoch: machine lookup epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} - - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app, epoch) + err = svc.handleEpochAfterInputsProcessed(context.Background(), app, epoch) require.Error(err) require.ErrorIs(err, ErrNoApp) }) s.Run("GetLastProcessedInputError", func() { require := s.Require() - - machineManager := newMockMachineManager() - app := newMockMachine(1) - app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch - machineManager.Map[1] = *app - repository := &MockRepository{ - GetLastProcessedInputError: errors.New("db connection lost"), - } - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch + env.repo.GetLastProcessedInputError = errors.New("db connection lost") epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} - - err = advancer.handleEpochAfterInputsProcessed(context.Background(), app.Application, epoch) + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) require.Error(err) require.Contains(err.Error(), "db connection lost") }) @@ -935,94 +728,77 @@ func (s *AdvancerSuite) TestHandleEpochAfterInputsProcessed() { // --------------------------------------------------------------------------- func (s *AdvancerSuite) TestHandleSnapshot() { - setupSnapshot := func(policy SnapshotPolicy) (*Service, *Application, *MockMachineInstance, *MockRepository) { - machineManager := newMockMachineManager() - app := newMockMachine(1) - app.Application.ExecutionParameters.SnapshotPolicy = policy - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - s.Require().Nil(err) - advancer.snapshotsDir = s.T().TempDir() - - mockInstance := &MockMachineInstance{ - application: app.Application, - machineImpl: app, - } - return advancer, app.Application, mockInstance, repository + setupSnapshot := func(policy SnapshotPolicy) (testEnv, *MockMachineInstance) { + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = policy + env.service.snapshotsDir = s.T().TempDir() + instance := env.mm.Map[1] + return env, instance } s.Run("NonePolicy", func() { require := s.Require() - advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_None) + env, machine := setupSnapshot(SnapshotPolicy_None) input := repotest.NewInputBuilder().WithIndex(0).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID - err := advancer.handleSnapshot(context.Background(), app, machine, input) + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.False(repo.SnapshotURIUpdated) + require.False(env.repo.SnapshotURIUpdated) }) s.Run("EveryInputPolicy", func() { require := s.Require() - advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryInput) + env, machine := setupSnapshot(SnapshotPolicy_EveryInput) input := repotest.NewInputBuilder().WithIndex(0).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID - err := advancer.handleSnapshot(context.Background(), app, machine, input) + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.True(repo.SnapshotURIUpdated) + require.True(env.repo.SnapshotURIUpdated) }) s.Run("EveryEpochPolicyLastInput", func() { require := s.Require() - advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryEpoch) - - // Set up GetEpoch to return closed epoch - repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + env, machine := setupSnapshot(SnapshotPolicy_EveryEpoch) + env.repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + env.repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() input := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() - input.EpochApplicationID = app.ID - - // Last input in epoch matches - repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID - err := advancer.handleSnapshot(context.Background(), app, machine, input) + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.True(repo.SnapshotURIUpdated) + require.True(env.repo.SnapshotURIUpdated) }) s.Run("EveryEpochPolicyNotLastInput", func() { require := s.Require() - advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryEpoch) - - repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + env, machine := setupSnapshot(SnapshotPolicy_EveryEpoch) + env.repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + env.repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() - input.EpochApplicationID = app.ID - - // Last input is a different one - repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID - err := advancer.handleSnapshot(context.Background(), app, machine, input) + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.False(repo.SnapshotURIUpdated) + require.False(env.repo.SnapshotURIUpdated) }) s.Run("EveryEpochPolicyOpenEpoch", func() { require := s.Require() - advancer, app, machine, repo := setupSnapshot(SnapshotPolicy_EveryEpoch) - - repo.GetEpochReturn = &Epoch{Status: EpochStatus_Open} + env, machine := setupSnapshot(SnapshotPolicy_EveryEpoch) + env.repo.GetEpochReturn = &Epoch{Status: EpochStatus_Open} input := repotest.NewInputBuilder().WithIndex(0).WithEpochIndex(0).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID - err := advancer.handleSnapshot(context.Background(), app, machine, input) + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.False(repo.SnapshotURIUpdated) + require.False(env.repo.SnapshotURIUpdated) }) } @@ -1031,38 +807,30 @@ func (s *AdvancerSuite) TestHandleSnapshot() { // --------------------------------------------------------------------------- func (s *AdvancerSuite) TestCreateSnapshot() { - setupCreateSnapshot := func() (*Service, *Application, *MockMachineInstance, *MockRepository, string) { - machineManager := newMockMachineManager() - app := newMockMachine(1) - app.Application.Name = "testapp" - app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryInput - machineManager.Map[1] = *app - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - s.Require().Nil(err) - + setupCreateSnapshot := func() (testEnv, *MockMachineInstance, string) { + env := s.setupOneApp() + env.app.Application.Name = "testapp" + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryInput tmpDir := s.T().TempDir() - advancer.snapshotsDir = tmpDir - - mockInstance := &MockMachineInstance{ - application: app.Application, - machineImpl: app, + env.service.snapshotsDir = tmpDir + instance := &MockMachineInstance{ + application: env.app.Application, + machineImpl: env.app, } - return advancer, app.Application, mockInstance, repository, tmpDir + return env, instance, tmpDir } s.Run("Success", func() { require := s.Require() - advancer, app, machine, repo, tmpDir := setupCreateSnapshot() + env, machine, tmpDir := setupCreateSnapshot() input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(1).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID - err := advancer.createSnapshot(context.Background(), app, machine, input) + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.True(repo.SnapshotURIUpdated) + require.True(env.repo.SnapshotURIUpdated) - // Verify the snapshot path was set correctly require.NotNil(input.SnapshotURI) expectedPath := filepath.Join(tmpDir, "testapp_epoch1_input3") require.Equal(expectedPath, *input.SnapshotURI) @@ -1070,55 +838,48 @@ func (s *AdvancerSuite) TestCreateSnapshot() { s.Run("SkipsIfAlreadyHasSnapshot", func() { require := s.Require() - advancer, app, machine, repo, _ := setupCreateSnapshot() + env, machine, _ := setupCreateSnapshot() existingPath := "/existing/snapshot" input := repotest.NewInputBuilder().WithIndex(0).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID input.SnapshotURI = &existingPath - err := advancer.createSnapshot(context.Background(), app, machine, input) + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.False(repo.SnapshotURIUpdated) + require.False(env.repo.SnapshotURIUpdated) }) s.Run("RemovesPreviousSnapshot", func() { require := s.Require() - advancer, app, machine, repo, tmpDir := setupCreateSnapshot() + env, machine, tmpDir := setupCreateSnapshot() - // Create a previous snapshot directory to be cleaned up prevPath := filepath.Join(tmpDir, "testapp_epoch0_input0") require.Nil(os.MkdirAll(prevPath, 0755)) - - prevInput := &Input{ - SnapshotURI: &prevPath, - } - repo.GetLastSnapshotReturn = prevInput + env.repo.GetLastSnapshotReturn = &Input{SnapshotURI: &prevPath} input := repotest.NewInputBuilder().WithIndex(1).WithEpochIndex(0).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID - err := advancer.createSnapshot(context.Background(), app, machine, input) + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - // Verify previous snapshot was removed _, statErr := os.Stat(prevPath) require.True(os.IsNotExist(statErr)) }) s.Run("CreateSnapshotError", func() { require := s.Require() - advancer, app, machine, repo, _ := setupCreateSnapshot() - + env, machine, _ := setupCreateSnapshot() machine.createSnapshotError = errors.New("snapshot failed") input := repotest.NewInputBuilder().WithIndex(0).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID - err := advancer.createSnapshot(context.Background(), app, machine, input) + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) require.Error(err) require.Contains(err.Error(), "snapshot failed") - require.False(repo.SnapshotURIUpdated) + require.False(env.repo.SnapshotURIUpdated) }) s.Run("MkdirAllError", func() { @@ -1127,7 +888,7 @@ func (s *AdvancerSuite) TestCreateSnapshot() { machineManager := newMockMachineManager() app := newMockMachine(1) app.Application.Name = "testapp" - machineManager.Map[1] = *app + machineManager.Map[1] = newMockInstance(app) repository := &MockRepository{} advancer, err := newMockAdvancerService(machineManager, repository) require.Nil(err) @@ -1137,32 +898,27 @@ func (s *AdvancerSuite) TestCreateSnapshot() { readonlyDir := filepath.Join(tmpDir, "readonly") require.Nil(os.MkdirAll(readonlyDir, 0755)) require.Nil(os.Chmod(readonlyDir, 0555)) - s.T().Cleanup(func() { os.Chmod(readonlyDir, 0755) }) //nolint: errcheck + s.T().Cleanup(func() { os.Chmod(readonlyDir, 0755) }) //nolint:errcheck advancer.snapshotsDir = filepath.Join(readonlyDir, "snapshots") - mockInstance := &MockMachineInstance{ - application: app.Application, - machineImpl: app, - } - input := repotest.NewInputBuilder().WithIndex(0).Build() input.EpochApplicationID = app.Application.ID - err = advancer.createSnapshot(context.Background(), app.Application, mockInstance, input) + err = advancer.createSnapshot(context.Background(), app.Application, machineManager.Map[1], input) require.Error(err) require.Contains(err.Error(), "failed to create snapshots directory") }) s.Run("UpdateSnapshotURIError", func() { require := s.Require() - advancer, app, machine, repo, _ := setupCreateSnapshot() + env, machine, _ := setupCreateSnapshot() - repo.UpdateSnapshotURIError = errors.New("db error") + env.repo.UpdateSnapshotURIError = errors.New("db error") input := repotest.NewInputBuilder().WithIndex(0).Build() - input.EpochApplicationID = app.ID + input.EpochApplicationID = env.app.Application.ID - err := advancer.createSnapshot(context.Background(), app, machine, input) + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) require.Error(err) require.Contains(err.Error(), "db error") }) @@ -1336,33 +1092,33 @@ func newMockMachine(id int64) *MockMachineImpl { } } +// newMockInstance creates a MockMachineInstance from a MockMachineImpl, ready to store in MockMachineManager.Map. +func newMockInstance(impl *MockMachineImpl) *MockMachineInstance { + return &MockMachineInstance{ + application: impl.Application, + machineImpl: impl, + } +} + // ------------------------------------------------------------------------------------------------ type MockMachineManager struct { - Map map[int64]MockMachineImpl + Map map[int64]*MockMachineInstance UpdateMachinesError error } func newMockMachineManager() *MockMachineManager { return &MockMachineManager{ - Map: map[int64]MockMachineImpl{}, + Map: map[int64]*MockMachineInstance{}, } } func (mock *MockMachineManager) GetMachine(appID int64) (manager.MachineInstance, bool) { - machine, exists := mock.Map[appID] + instance, exists := mock.Map[appID] if !exists { return nil, false } - - // For testing purposes, we'll create a mock MachineInstance - // that has the same Application but delegates the methods to our mock - mockInstance := &MockMachineInstance{ - application: machine.Application, - machineImpl: &machine, - } - - return mockInstance, true + return instance, true } func (mock *MockMachineManager) UpdateMachines(ctx context.Context) error { @@ -1372,7 +1128,7 @@ func (mock *MockMachineManager) UpdateMachines(ctx context.Context) error { func (mock *MockMachineManager) Applications() []*Application { apps := make([]*Application, 0, len(mock.Map)) for _, v := range mock.Map { - apps = append(apps, v.Application) + apps = append(apps, v.application) } return apps } diff --git a/internal/advancer/service.go b/internal/advancer/service.go index 8c54c203a..168703a83 100644 --- a/internal/advancer/service.go +++ b/internal/advancer/service.go @@ -67,7 +67,6 @@ func Create(ctx context.Context, c *CreateInfo) (*Service, error) { // Create the machine manager manager := manager.NewMachineManager( - ctx, c.Repository, s.Logger, c.Config.FeatureMachineHashCheckEnabled, diff --git a/internal/manager/manager.go b/internal/manager/manager.go index e783a1a15..efb0cadfc 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -13,6 +13,7 @@ import ( . "github.com/cartesi/rollups-node/internal/model" "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" ) var ( @@ -33,32 +34,70 @@ type MachineRepository interface { GetLastSnapshot(ctx context.Context, nameOrAddress string) (*Input, error) } +// MachineInstanceFactory creates MachineInstance values from applications. +// Implementations decide whether to load from a template or snapshot. +type MachineInstanceFactory interface { + NewFromTemplate(ctx context.Context, app *Application, logger *slog.Logger, checkHash bool) (MachineInstance, error) + NewFromSnapshot(ctx context.Context, app *Application, logger *slog.Logger, checkHash bool, + snapshotPath string, machineHash *common.Hash, inputIndex uint64) (MachineInstance, error) +} + +// DefaultMachineInstanceFactory delegates to NewMachineInstance / NewMachineInstanceFromSnapshot. +type DefaultMachineInstanceFactory struct{} + +func (f *DefaultMachineInstanceFactory) NewFromTemplate( + ctx context.Context, app *Application, logger *slog.Logger, checkHash bool, +) (MachineInstance, error) { + return NewMachineInstance(ctx, app, logger, checkHash) +} + +func (f *DefaultMachineInstanceFactory) NewFromSnapshot( + ctx context.Context, app *Application, logger *slog.Logger, checkHash bool, + snapshotPath string, machineHash *common.Hash, inputIndex uint64, +) (MachineInstance, error) { + return NewMachineInstanceFromSnapshot(ctx, app, logger, checkHash, snapshotPath, machineHash, inputIndex) +} + // MachineManager manages the lifecycle of machine instances for applications type MachineManager struct { - mutex sync.RWMutex - machines map[int64]MachineInstance - closed bool - repository MachineRepository - checkHash bool - inputBatchSize uint64 - logger *slog.Logger + mutex sync.RWMutex + machines map[int64]MachineInstance + closed bool + repository MachineRepository + checkHash bool + inputBatchSize uint64 + logger *slog.Logger + instanceFactory MachineInstanceFactory +} + +// Option configures a MachineManager. +type Option func(*MachineManager) + +// WithInstanceFactory overrides the default MachineInstanceFactory. +func WithInstanceFactory(f MachineInstanceFactory) Option { + return func(m *MachineManager) { m.instanceFactory = f } } -// NewMachineManager creates a new machine manager +// NewMachineManager creates a new machine manager. func NewMachineManager( - ctx context.Context, repo MachineRepository, logger *slog.Logger, checkHash bool, inputBatchSize uint64, + opts ...Option, ) *MachineManager { - return &MachineManager{ - machines: map[int64]MachineInstance{}, - repository: repo, - checkHash: checkHash, - inputBatchSize: inputBatchSize, - logger: logger, + m := &MachineManager{ + machines: map[int64]MachineInstance{}, + repository: repo, + checkHash: checkHash, + inputBatchSize: inputBatchSize, + logger: logger, + instanceFactory: &DefaultMachineInstanceFactory{}, + } + for _, opt := range opts { + opt(m) } + return m } // UpdateMachines refreshes the list of machines based on enabled applications @@ -77,7 +116,7 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { m.logger.Info("Creating new machine instance", "application", app.Name, - "address", app.IApplicationAddress.String()) + "address", app.IApplicationAddress) // Check if we have a snapshot to load from var instance MachineInstance @@ -98,7 +137,7 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { "application", app.Name, "snapshot", *snapshot.SnapshotURI) - instance, err = NewMachineInstanceFromSnapshot( + instance, err = m.instanceFactory.NewFromSnapshot( ctx, app, m.logger, m.checkHash, *snapshot.SnapshotURI, snapshot.MachineHash, snapshot.Index) if err != nil { @@ -122,10 +161,10 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { // Fall back to template if snapshot loading failed or was unavailable if instance == nil { - instance, err = NewMachineInstance(ctx, app, m.logger, m.checkHash) + instance, err = m.instanceFactory.NewFromTemplate(ctx, app, m.logger, m.checkHash) if err != nil { m.logger.Error("Failed to create machine instance", - "application", app.IApplicationAddress.String(), + "application", app.IApplicationAddress, "error", err) continue } diff --git a/internal/manager/manager_test.go b/internal/manager/manager_test.go index e6aa34fcc..660bde832 100644 --- a/internal/manager/manager_test.go +++ b/internal/manager/manager_test.go @@ -29,7 +29,7 @@ func (s *MachineManagerSuite) TestNewMachineManager() { require := s.Require() repo := &MockMachineRepository{} testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + manager := NewMachineManager(repo, testLogger, false, 500) require.NotNil(manager) require.Empty(manager.machines) require.Equal(repo, manager.repository) @@ -64,25 +64,15 @@ func (s *MachineManagerSuite) TestUpdateMachines() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - // Create manager with a test logger + // Create manager with a mock instance factory testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + mockInstance := &DummyMachineInstanceMock{application: app1} + factory := &MockMachineInstanceFactory{Instance: mockInstance} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(factory)) - // Create a mock factory for testing - mockRuntime := &MockRollupsMachine{} - mockFactory := &MockMachineRuntimeFactory{ - RuntimeToReturn: mockRuntime, - ErrorToReturn: nil, - } - - // Replace the default factory with our mock - originalFactory := defaultFactory - defaultFactory = mockFactory - defer func() { defaultFactory = originalFactory }() - - // This test should now succeed with our mock err := manager.UpdateMachines(context.Background()) require.NoError(err) + require.True(manager.HasMachine(1)) repo.AssertCalled(s.T(), "ListApplications", mock.Anything, mock.Anything, mock.Anything, false) }) @@ -97,7 +87,7 @@ func (s *MachineManagerSuite) TestUpdateMachines() { // Create a test logger testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + manager := NewMachineManager(repo, testLogger, false, 500) // Add mock machines app1 := &model.Application{ID: 1, Name: "App1"} @@ -130,7 +120,7 @@ func (s *MachineManagerSuite) TestGetMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false, 500) + manager := NewMachineManager(repo, nil, false, 500) machine := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} // Add a machine @@ -153,7 +143,7 @@ func (s *MachineManagerSuite) TestHasMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false, 500) + manager := NewMachineManager(repo, nil, false, 500) machine := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} // Add a machine @@ -173,7 +163,7 @@ func (s *MachineManagerSuite) TestAddMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false, 500) + manager := NewMachineManager(repo, nil, false, 500) machine1 := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} machine2 := &DummyMachineInstanceMock{application: &model.Application{ID: 2}} @@ -204,7 +194,7 @@ func (s *MachineManagerSuite) TestAddMachine() { func (s *MachineManagerSuite) TestRemoveDisabledMachines() { require := s.Require() - manager := NewMachineManager(context.Background(), nil, nil, false, 500) + manager := NewMachineManager(nil, nil, false, 500) // Add machines app1 := &model.Application{ID: 1} @@ -238,7 +228,7 @@ func (s *MachineManagerSuite) TestUpdateMachinesErrors() { Return(([]*model.Application)(nil), uint64(0), errors.New("db error")) testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + manager := NewMachineManager(repo, testLogger, false, 500) err := manager.UpdateMachines(context.Background()) require.Error(err) @@ -275,23 +265,14 @@ func (s *MachineManagerSuite) TestUpdateMachinesErrors() { repo.On("ListInputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false). Return([]*model.Input{}, uint64(0), nil) + // The snapshot path doesn't exist, so it should fall back to template testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) - - // Mock factory that always succeeds — the snapshot path doesn't exist, - // so it should fall back to template via defaultFactory - mockRuntime := &MockRollupsMachine{} - mockFactory := &MockMachineRuntimeFactory{ - RuntimeToReturn: mockRuntime, - ErrorToReturn: nil, - } - originalFactory := defaultFactory - defaultFactory = mockFactory - defer func() { defaultFactory = originalFactory }() + mockInstance := &DummyMachineInstanceMock{application: app} + factory := &MockMachineInstanceFactory{Instance: mockInstance} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(factory)) err := manager.UpdateMachines(context.Background()) require.NoError(err) - // Machine should have been created via fallback require.True(manager.HasMachine(1)) }) @@ -317,16 +298,8 @@ func (s *MachineManagerSuite) TestUpdateMachinesErrors() { Return(nil, nil) testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) - - // Factory that always fails - mockFactory := &MockMachineRuntimeFactory{ - RuntimeToReturn: nil, - ErrorToReturn: errors.New("machine creation failed"), - } - originalFactory := defaultFactory - defaultFactory = mockFactory - defer func() { defaultFactory = originalFactory }() + factory := &MockMachineInstanceFactory{Err: errors.New("machine creation failed")} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(factory)) err := manager.UpdateMachines(context.Background()) // UpdateMachines should not return an error; it logs and skips @@ -355,22 +328,21 @@ func (s *MachineManagerSuite) TestUpdateMachinesErrors() { Return([]*model.Application{app}, uint64(1), nil) repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - // ListInputs returns an error to cause Synchronize to fail + // ListInputs returns an error so the real Synchronize method propagates the failure. repo.On("ListInputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false). Return(([]*model.Input)(nil), uint64(0), errors.New("db connection lost")) testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false, 500) + // Use a factory that builds a real MachineInstanceImpl (with a mock runtime) + // so that Synchronize actually runs and hits the repo. mockRuntime := &MockRollupsMachine{} - mockRuntime.CloseError = nil - mockFactory := &MockMachineRuntimeFactory{ + runtimeFactory := &MockMachineRuntimeFactory{ RuntimeToReturn: mockRuntime, ErrorToReturn: nil, } - originalFactory := defaultFactory - defaultFactory = mockFactory - defer func() { defaultFactory = originalFactory }() + realFactory := &realMachineInstanceFactory{runtimeFactory: runtimeFactory} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(realFactory)) err := manager.UpdateMachines(context.Background()) require.NoError(err) @@ -382,7 +354,7 @@ func (s *MachineManagerSuite) TestUpdateMachinesErrors() { func (s *MachineManagerSuite) TestCloseAggregatesErrors() { require := s.Require() - manager := NewMachineManager(context.Background(), nil, nil, false, 500) + manager := NewMachineManager(nil, nil, false, 500) machine1 := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} machine2 := &DummyMachineInstanceMock{ @@ -411,7 +383,7 @@ func (s *MachineManagerSuite) TestApplications() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false, 500) + manager := NewMachineManager(repo, nil, false, 500) // Add machines app1 := &model.Application{ID: 1, Name: "App1"} @@ -477,10 +449,54 @@ func (m *MockMachineRepository) GetLastSnapshot( // ------------------------------------------------------------------------------------------------ +// MockMachineInstanceFactory implements MachineInstanceFactory for testing. +// It returns the same instance for every call, ignoring the app/path arguments. +type MockMachineInstanceFactory struct { + Instance MachineInstance + Err error +} + +func (f *MockMachineInstanceFactory) NewFromTemplate( + _ context.Context, _ *model.Application, _ *slog.Logger, _ bool, +) (MachineInstance, error) { + return f.Instance, f.Err +} + +func (f *MockMachineInstanceFactory) NewFromSnapshot( + _ context.Context, _ *model.Application, _ *slog.Logger, _ bool, + _ string, _ *common.Hash, _ uint64, +) (MachineInstance, error) { + return f.Instance, f.Err +} + +// realMachineInstanceFactory builds real MachineInstanceImpl values using the +// provided MachineRuntimeFactory. This lets tests exercise the real Synchronize +// path while still mocking the machine runtime. Unlike MockMachineInstanceFactory, +// snapshot path and hash are ignored — it always creates from the runtime factory. +type realMachineInstanceFactory struct { + runtimeFactory MachineRuntimeFactory +} + +func (f *realMachineInstanceFactory) NewFromTemplate( + ctx context.Context, app *model.Application, logger *slog.Logger, checkHash bool, +) (MachineInstance, error) { + return NewMachineInstanceWithFactory(ctx, app, 0, logger, checkHash, f.runtimeFactory) +} + +func (f *realMachineInstanceFactory) NewFromSnapshot( + ctx context.Context, app *model.Application, logger *slog.Logger, checkHash bool, + _ string, _ *common.Hash, inputIndex uint64, +) (MachineInstance, error) { + return NewMachineInstanceWithFactory(ctx, app, inputIndex+1, logger, checkHash, f.runtimeFactory) +} + +// ------------------------------------------------------------------------------------------------ + // DummyMachineInstanceMock implements the MachineInstance interface for testing type DummyMachineInstanceMock struct { - application *model.Application - closeError error + application *model.Application + closeError error + synchronizeErr error } func (m *DummyMachineInstanceMock) Application() *model.Application { @@ -504,7 +520,7 @@ func (m *DummyMachineInstanceMock) Inspect(_ context.Context, _ []byte) (*model. } func (m *DummyMachineInstanceMock) Synchronize(_ context.Context, _ MachineRepository, _ uint64) error { - return nil + return m.synchronizeErr } func (m *DummyMachineInstanceMock) CreateSnapshot(_ context.Context, _ uint64, _ string) error {