diff --git a/Makefile b/Makefile index 56056f0fb..aeb168818 100644 --- a/Makefile +++ b/Makefile @@ -134,6 +134,7 @@ env: @echo export CARTESI_TEST_DATABASE_CONNECTION="postgres://test_user:password@localhost:5432/test_rollupsdb?sslmode=disable" @echo export CARTESI_TEST_MACHINE_IMAGES_PATH=\"$(CARTESI_TEST_MACHINE_IMAGES_PATH)\" @echo export PATH=\"$(CURDIR):$$PATH\" + @$(if $(MACOSX_DEPLOYMENT_TARGET),echo export MACOSX_DEPLOYMENT_TARGET=\"$(MACOSX_DEPLOYMENT_TARGET)\") # ============================================================================= # Artifacts diff --git a/cmd/cartesi-rollups-cli/root/read/read.go b/cmd/cartesi-rollups-cli/root/read/read.go index 8fe5fd33a..bd9053ece 100644 --- a/cmd/cartesi-rollups-cli/root/read/read.go +++ b/cmd/cartesi-rollups-cli/root/read/read.go @@ -23,7 +23,7 @@ var Cmd = &cobra.Command{ Short: "Read the node state from the database", PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if !cmd.Flags().Changed("jsonrpc") && cmd.Flags().Changed("jsonrpc-api-url") { - if err:= cmd.Flags().Set("jsonrpc", "true"); err != nil { + if err := cmd.Flags().Set("jsonrpc", "true"); err != nil { return err } } diff --git a/internal/advancer/advancer.go b/internal/advancer/advancer.go index 22ef67c08..4b309b932 100644 --- a/internal/advancer/advancer.go +++ b/internal/advancer/advancer.go @@ -8,7 +8,7 @@ import ( "errors" "fmt" "os" - "path" + "path/filepath" "strings" "github.com/cartesi/rollups-node/internal/manager" @@ -45,10 +45,16 @@ func getUnprocessedEpochs(ctx context.Context, er AdvancerRepository, address st return er.ListEpochs(ctx, address, f, repository.Pagination{}, false) } -// getUnprocessedInputs retrieves inputs that haven't been processed yet -func getUnprocessedInputs(ctx context.Context, repo AdvancerRepository, appAddress string, epochIndex uint64) ([]*Input, uint64, error) { +// getUnprocessedInputs retrieves inputs that haven't been processed yet with pagination support. +func getUnprocessedInputs( + ctx context.Context, + repo AdvancerRepository, + appAddress string, + epochIndex uint64, + batchSize uint64, +) ([]*Input, uint64, error) { f := repository.InputFilter{Status: Pointer(InputCompletionStatus_None), EpochIndex: &epochIndex} - return repo.ListInputs(ctx, appAddress, f, repository.Pagination{}, false) + return repo.ListInputs(ctx, appAddress, f, repository.Pagination{Limit: batchSize}, false) } // Step performs one processing cycle of the advancer @@ -78,17 +84,7 @@ func (s *Service) Step(ctx context.Context) error { } for _, epoch := range epochs { - // Get unprocessed inputs for this application - s.Logger.Debug("Querying for unprocessed inputs", "application", app.Name, "epoch_index", epoch.Index) - inputs, _, err := getUnprocessedInputs(ctx, s.repository, appAddress, epoch.Index) - if err != nil { - return err - } - - // Process the inputs - s.Logger.Debug("Processing inputs", "application", app.Name, "epoch_index", epoch.Index, "count", len(inputs)) - err = s.processInputs(ctx, app, inputs) - if err != nil { + if err := s.processEpochInputs(ctx, app, epoch.Index); err != nil { return err } @@ -117,6 +113,26 @@ func (s *Service) Step(ctx context.Context) error { return nil } +// processEpochInputs fetches and processes unprocessed inputs for an epoch in batches. +// Processed inputs change status and drop out of the filter, so each batch fetches from offset 0. +func (s *Service) processEpochInputs(ctx context.Context, app *Application, epochIndex uint64) error { + appAddress := app.IApplicationAddress.String() + for { + inputs, _, err := getUnprocessedInputs(ctx, s.repository, appAddress, epochIndex, s.inputBatchSize) + if err != nil { + return err + } + if len(inputs) == 0 { + return nil + } + s.Logger.Debug("Processing inputs", + "application", app.Name, "epoch_index", epochIndex, "count", len(inputs)) + if err := s.processInputs(ctx, app, inputs); err != nil { + return err + } + } +} + func (s *Service) isAllEpochInputsProcessed(app *Application, epoch *Epoch) (bool, error) { // epoch has no inputs if epoch.InputIndexLowerBound == epoch.InputIndexUpperBound { @@ -159,6 +175,7 @@ func (s *Service) processInputs(ctx context.Context, app *Application, inputs [] // Advance the machine with this input result, err := machine.Advance(ctx, input.RawData, input.EpochIndex, input.Index, app.IsDaveConsensus()) + input.RawData = nil // allow GC to collect payload while batch continues if err != nil { // If there's an error, mark the application as inoperable s.Logger.Error("Error executing advance", @@ -179,6 +196,17 @@ func (s *Service) processInputs(ctx context.Context, app *Application, inputs [] "error", updateErr) } + // Eagerly close the machine to release the child process. + // The app is already inoperable, so no further operations will succeed. + // Skip if the runtime was already destroyed inside the manager. + if !errors.Is(err, manager.ErrMachineClosed) { + if closeErr := machine.Close(); closeErr != nil { + s.Logger.Warn("Failed to close machine after advance error", + "application", app.Name, + "error", closeErr) + } + } + return err } // log advance result hashes @@ -196,10 +224,18 @@ func (s *Service) processInputs(ctx context.Context, app *Application, inputs [] // Store the result in the database err = s.repository.StoreAdvanceResult(ctx, input.EpochApplicationID, result) if err != nil { - s.Logger.Error("Failed to store advance result", + // Machine state is now ahead of the database. This desync is + // unrecoverable without a restart — regardless of whether the + // failure was a DB error or a context timeout. Shut down the + // node so it can restart cleanly from the last snapshot. + s.Logger.Error( + "FATAL: failed to store advance result after machine state "+ + "was already updated — shutting down to prevent permanent desync", "application", app.Name, + "epoch", input.EpochIndex, "index", input.Index, "error", err) + s.Cancel() // triggers graceful shutdown of all services return err } @@ -281,8 +317,15 @@ func (s *Service) handleEpochAfterInputsProcessed(ctx context.Context, app *Appl if !exists { return fmt.Errorf("%w: %d", ErrNoApp, app.ID) } - outputsProof, err := machine.OutputsProof(ctx, 0) + outputsProof, err := machine.OutputsProof(ctx) if err != nil { + // If the runtime was destroyed (e.g., child process crashed), + // mark the app inoperable to avoid an infinite retry loop. + if errors.Is(err, manager.ErrMachineClosed) { + reason := err.Error() + _ = s.repository.UpdateApplicationState(ctx, app.ID, + ApplicationState_Inoperable, &reason) + } return fmt.Errorf("failed to get outputs proof from machine: %w", err) } err = s.repository.UpdateEpochOutputsProof(ctx, app.ID, epoch.Index, outputsProof) @@ -342,7 +385,7 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine // Generate a snapshot path with a simpler structure // Use app name and input index only, avoiding deep directory nesting snapshotName := fmt.Sprintf("%s_epoch%d_input%d", app.Name, input.EpochIndex, input.Index) - snapshotPath := path.Join(s.snapshotsDir, snapshotName) + snapshotPath := filepath.Join(s.snapshotsDir, snapshotName) s.Logger.Info("Creating snapshot", "application", app.Name, @@ -351,10 +394,8 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine "path", snapshotPath) // Ensure the parent directory exists - if _, err := os.Stat(s.snapshotsDir); os.IsNotExist(err) { - if err := os.MkdirAll(s.snapshotsDir, 0755); err != nil { //nolint: mnd - return fmt.Errorf("failed to create snapshots directory: %w", err) - } + if err := os.MkdirAll(s.snapshotsDir, 0755); err != nil { //nolint: mnd + return fmt.Errorf("failed to create snapshots directory: %w", err) } // Create the snapshot @@ -363,6 +404,16 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine return err } + // Get previous snapshot BEFORE writing the new one so the query does not + // return the snapshot we just created — that would cause self-deletion. + previousSnapshot, err := s.repository.GetLastSnapshot(ctx, app.IApplicationAddress.String()) + if err != nil { + s.Logger.Error("Failed to get previous snapshot", + "application", app.Name, + "error", err) + // Continue even if we can't get the previous snapshot + } + // Update the input record with the snapshot URI input.SnapshotURI = &snapshotPath @@ -372,18 +423,8 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine return fmt.Errorf("failed to update input snapshot URI: %w", err) } - // Get previous snapshot if it exists - previousSnapshot, err := s.repository.GetLastSnapshot(ctx, app.IApplicationAddress.String()) - if err != nil { - s.Logger.Error("Failed to get previous snapshot", - "application", app.Name, - "error", err) - // Continue even if we can't get the previous snapshot - } - // Remove previous snapshot if it exists - if previousSnapshot != nil && previousSnapshot.Index != input.Index && previousSnapshot.SnapshotURI != nil { - // Only remove if it's a different snapshot than the one we just created + if previousSnapshot != nil && previousSnapshot.SnapshotURI != nil { if err := s.removeSnapshot(*previousSnapshot.SnapshotURI, app.Name); err != nil { s.Logger.Error("Failed to remove previous snapshot", "application", app.Name, @@ -398,8 +439,11 @@ func (s *Service) createSnapshot(ctx context.Context, app *Application, machine // removeSnapshot safely removes a previous snapshot func (s *Service) removeSnapshot(snapshotPath string, appName string) error { - // Safety check: ensure the path contains the application name and is in the snapshots directory - if !strings.HasPrefix(snapshotPath, s.snapshotsDir) || !strings.Contains(snapshotPath, appName) { + // Safety check: canonicalize paths to prevent directory traversal via ".." sequences + cleanPath := filepath.Clean(snapshotPath) + cleanDir := filepath.Clean(s.snapshotsDir) + if !strings.HasPrefix(cleanPath, cleanDir+string(filepath.Separator)) || + !strings.HasPrefix(filepath.Base(cleanPath), appName+"_") { return fmt.Errorf("invalid snapshot path: %s", snapshotPath) } diff --git a/internal/advancer/advancer_test.go b/internal/advancer/advancer_test.go index f1699bdcb..eb390e705 100644 --- a/internal/advancer/advancer_test.go +++ b/internal/advancer/advancer_test.go @@ -10,6 +10,8 @@ import ( "errors" "fmt" mrand "math/rand" + "os" + "path/filepath" "sync" "testing" "time" @@ -17,6 +19,7 @@ import ( "github.com/cartesi/rollups-node/internal/manager" . "github.com/cartesi/rollups-node/internal/model" "github.com/cartesi/rollups-node/internal/repository" + "github.com/cartesi/rollups-node/internal/repository/repotest" "github.com/cartesi/rollups-node/pkg/service" "github.com/ethereum/go-ethereum/common" @@ -31,6 +34,7 @@ type AdvancerSuite struct{ suite.Suite } func newMockAdvancerService(machineManager *MockMachineManager, repo *MockRepository) (*Service, error) { s := &Service{ + inputBatchSize: 500, machineManager: machineManager, repository: repo, } @@ -42,6 +46,27 @@ func newMockAdvancerService(machineManager *MockMachineManager, repo *MockReposi return s, nil } +// testEnv bundles the components most tests need: a service, a single app's mock, +// the mock machine manager, and the mock repository. +type testEnv struct { + service *Service + app *MockMachineImpl + mm *MockMachineManager + repo *MockRepository +} + +// setupOneApp creates a standard test environment with one application. +// The repository is empty; callers can configure it after the call. +func (s *AdvancerSuite) setupOneApp() testEnv { + mm := newMockMachineManager() + app := newMockMachine(1) + mm.Map[1] = newMockInstance(app) + repo := &MockRepository{} + svc, err := newMockAdvancerService(mm, repo) + s.Require().NoError(err) + return testEnv{service: svc, app: app, mm: mm, repo: repo} +} + func (s *AdvancerSuite) TestServiceInterface() { s.Run("ServiceMethods", func() { require := s.Require() @@ -60,9 +85,9 @@ func (s *AdvancerSuite) TestServiceInterface() { require.Equal(advancer.Name, advancer.String()) // Test Tick method - machineManager.Map[1] = *newMockMachine(1) + machineManager.Map[1] = newMockInstance(newMockMachine(1)) repository.GetEpochsReturn = map[common.Address][]*Epoch{ - machineManager.Map[1].Application.IApplicationAddress: {}, + machineManager.Map[1].application.IApplicationAddress: {}, } tickErrors := advancer.Tick() require.Empty(tickErrors) @@ -82,11 +107,11 @@ func (s *AdvancerSuite) TestStep() { machineManager := newMockMachineManager() app1 := newMockMachine(1) app2 := newMockMachine(2) - machineManager.Map[1] = *app1 - machineManager.Map[2] = *app2 + machineManager.Map[1] = newMockInstance(app1) + machineManager.Map[2] = newMockInstance(app2) + res0 := randomAdvanceResult(0) res1 := randomAdvanceResult(1) - res2 := randomAdvanceResult(2) - res3 := randomAdvanceResult(3) + res2 := randomAdvanceResult(0) repository := &MockRepository{ GetEpochsReturn: map[common.Address][]*Epoch{ @@ -99,11 +124,11 @@ func (s *AdvancerSuite) TestStep() { }, GetInputsReturn: map[common.Address][]*Input{ app1.Application.IApplicationAddress: { - newInput(app1.Application.ID, 0, 0, marshal(res1)), - newInput(app1.Application.ID, 0, 1, marshal(res2)), + newInput(app1.Application.ID, 0, 0, marshal(res0)), + newInput(app1.Application.ID, 0, 1, marshal(res1)), }, app2.Application.IApplicationAddress: { - newInput(app2.Application.ID, 0, 0, marshal(res3)), + newInput(app2.Application.ID, 0, 0, marshal(res2)), }, }, } @@ -120,31 +145,22 @@ func (s *AdvancerSuite) TestStep() { s.Run("Error/UpdateEpochs", func() { require := s.Require() + env := s.setupOneApp() + res0 := randomAdvanceResult(0) - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - res1 := randomAdvanceResult(1) - - repository := &MockRepository{ - GetEpochsReturn: map[common.Address][]*Epoch{ - app1.Application.IApplicationAddress: { - &Epoch{Index: 0, Status: EpochStatus_Closed}, - }, + env.repo.GetEpochsReturn = map[common.Address][]*Epoch{ + env.app.Application.IApplicationAddress: { + {Index: 0, Status: EpochStatus_Closed}, }, - GetInputsReturn: map[common.Address][]*Input{ - app1.Application.IApplicationAddress: { - newInput(app1.Application.ID, 0, 0, marshal(res1)), - }, + } + env.repo.GetInputsReturn = map[common.Address][]*Input{ + env.app.Application.IApplicationAddress: { + newInput(env.app.Application.ID, 0, 0, marshal(res0)), }, - UpdateEpochsError: errors.New("update epochs error"), } + env.repo.UpdateEpochsError = errors.New("update epochs error") - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - err = advancer.Step(context.Background()) + err := env.service.Step(context.Background()) require.Error(err) require.Contains(err.Error(), "update epochs error") }) @@ -153,7 +169,7 @@ func (s *AdvancerSuite) TestStep() { require := s.Require() machineManager := &MockMachineManager{ - Map: map[int64]MockMachineImpl{}, + Map: map[int64]*MockMachineInstance{}, UpdateMachinesError: errors.New("update machines error"), } repository := &MockRepository{} @@ -169,49 +185,31 @@ func (s *AdvancerSuite) TestStep() { s.Run("Error/GetInputs", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - repository := &MockRepository{ - GetEpochsReturn: map[common.Address][]*Epoch{ - app1.Application.IApplicationAddress: { - &Epoch{Index: 0, Status: EpochStatus_Closed}, - }, + env.repo.GetEpochsReturn = map[common.Address][]*Epoch{ + env.app.Application.IApplicationAddress: { + {Index: 0, Status: EpochStatus_Closed}, }, - GetInputsError: errors.New("get inputs error"), } + env.repo.GetInputsError = errors.New("get inputs error") - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - err = advancer.Step(context.Background()) + err := env.service.Step(context.Background()) require.Error(err) require.Contains(err.Error(), "get inputs error") }) s.Run("NoInputs", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - repository := &MockRepository{ - GetInputsReturn: map[common.Address][]*Input{ - app1.Application.IApplicationAddress: {}, - }, + env.repo.GetInputsReturn = map[common.Address][]*Input{ + env.app.Application.IApplicationAddress: {}, } - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - err = advancer.Step(context.Background()) + err := env.service.Step(context.Background()) require.Nil(err) - require.Len(repository.StoredResults, 0) + require.Len(env.repo.StoredResults, 0) }) } @@ -225,13 +223,14 @@ func (s *AdvancerSuite) TestGetUnprocessedInputs() { newInput(app1.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), } - repository := &MockRepository{ + repo := &MockRepository{ GetInputsReturn: map[common.Address][]*Input{ app1.Application.IApplicationAddress: inputs, }, } - result, count, err := getUnprocessedInputs(context.Background(), repository, app1.Application.IApplicationAddress.String(), 0) + result, count, err := getUnprocessedInputs( + context.Background(), repo, app1.Application.IApplicationAddress.String(), 0, 500) require.Nil(err) require.Equal(uint64(2), count) require.Equal(inputs, result) @@ -241,88 +240,70 @@ func (s *AdvancerSuite) TestGetUnprocessedInputs() { require := s.Require() app1 := newMockMachine(1) - repository := &MockRepository{ + repo := &MockRepository{ GetInputsError: errors.New("list inputs error"), } - _, _, err := getUnprocessedInputs(context.Background(), repository, app1.Application.IApplicationAddress.String(), 0) + _, _, err := getUnprocessedInputs( + context.Background(), repo, app1.Application.IApplicationAddress.String(), 0, 500) require.Error(err) require.Contains(err.Error(), "list inputs error") }) } func (s *AdvancerSuite) TestProcess() { - setup := func() (*MockMachineManager, *MockRepository, *Service, *MockMachineImpl) { - require := s.Require() - - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.Nil(err) - return machineManager, repository, advancer, app1 - } - s.Run("ApplicationStateUpdate", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, []byte("advance error")), + newInput(env.app.Application.ID, 0, 0, []byte("advance error")), } - // Verify application state is updated on error - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) - require.Equal(1, repository.ApplicationStateUpdates) - require.Equal(ApplicationState_Inoperable, repository.LastApplicationState) - require.NotNil(repository.LastApplicationStateReason) - require.Equal("advance error", *repository.LastApplicationStateReason) + require.Equal(1, env.repo.ApplicationStateUpdates) + require.Equal(ApplicationState_Inoperable, env.repo.LastApplicationState) + require.NotNil(env.repo.LastApplicationStateReason) + require.Equal("advance error", *env.repo.LastApplicationStateReason) }) s.Run("ApplicationStateUpdateError", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, []byte("advance error")), + newInput(env.app.Application.ID, 0, 0, []byte("advance error")), } - repository.UpdateApplicationStateError = errors.New("update state error") + env.repo.UpdateApplicationStateError = errors.New("update state error") - // Verify error is still returned even if application state update fails - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) require.Contains(err.Error(), "advance error") }) s.Run("Ok", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), - newInput(app.Application.ID, 0, 2, marshal(randomAdvanceResult(2))), - newInput(app.Application.ID, 0, 3, marshal(randomAdvanceResult(3))), - newInput(app.Application.ID, 1, 4, marshal(randomAdvanceResult(4))), - newInput(app.Application.ID, 1, 5, marshal(randomAdvanceResult(5))), - newInput(app.Application.ID, 2, 6, marshal(randomAdvanceResult(6))), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), + newInput(env.app.Application.ID, 0, 2, marshal(randomAdvanceResult(2))), + newInput(env.app.Application.ID, 0, 3, marshal(randomAdvanceResult(3))), + newInput(env.app.Application.ID, 1, 4, marshal(randomAdvanceResult(4))), + newInput(env.app.Application.ID, 1, 5, marshal(randomAdvanceResult(5))), + newInput(env.app.Application.ID, 2, 6, marshal(randomAdvanceResult(6))), } - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Nil(err) - require.Len(repository.StoredResults, 7) + require.Len(env.repo.StoredResults, 7) }) s.Run("Noop", func() { s.Run("NoInputs", func() { require := s.Require() + env := s.setupOneApp() - _, _, advancer, app := setup() - inputs := []*Input{} - - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, []*Input{}) require.Nil(err) }) }) @@ -330,46 +311,46 @@ func (s *AdvancerSuite) TestProcess() { s.Run("Error", func() { s.Run("ErrApp", func() { require := s.Require() - + env := s.setupOneApp() invalidApp := Application{ID: 999} - _, _, advancer, _ := setup() inputs := randomInputs(1, 0, 3) - err := advancer.processInputs(context.Background(), &invalidApp, inputs) + err := env.service.processInputs(context.Background(), &invalidApp, inputs) expected := fmt.Sprintf("%v: %v", ErrNoApp, invalidApp.ID) require.EqualError(err, expected) }) s.Run("Advance", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app.Application.ID, 0, 1, []byte("advance error")), - newInput(app.Application.ID, 0, 2, []byte("unreachable")), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 1, []byte("advance error")), + newInput(env.app.Application.ID, 0, 2, []byte("unreachable")), } - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) require.Contains(err.Error(), "advance error") - require.Len(repository.StoredResults, 1) + require.Len(env.repo.StoredResults, 1) }) s.Run("StoreAdvance", func() { require := s.Require() - - _, repository, advancer, app := setup() + env := s.setupOneApp() inputs := []*Input{ - newInput(app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app.Application.ID, 0, 1, []byte("unreachable")), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 1, []byte("unreachable")), } - repository.StoreAdvanceError = errors.New("store-advance error") + env.repo.StoreAdvanceError = errors.New("store-advance error") - err := advancer.processInputs(context.Background(), app.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Error(err) require.Contains(err.Error(), "store-advance error") - require.Len(repository.StoredResults, 1) + require.Len(env.repo.StoredResults, 1) + + // Verify that the node shutdown was triggered (context cancelled) + require.Error(env.service.Context.Err(), "shared context should be cancelled") }) }) } @@ -378,27 +359,15 @@ func (s *AdvancerSuite) TestProcess() { func (s *AdvancerSuite) TestContextCancellation() { s.Run("CancelDuringStep", func() { require := s.Require() + env := s.setupOneApp() + env.repo.GetEpochsBlock = true - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - - // Create a repository that will block until we cancel the context - repository := &MockRepository{ - GetEpochsBlock: true, - } - - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - // Create a context that we can cancel ctx, cancel := context.WithCancel(context.Background()) // Start the Step operation in a goroutine errCh := make(chan error) go func() { - errCh <- advancer.Step(ctx) + errCh <- env.service.Step(ctx) }() // Cancel the context after a short delay @@ -417,28 +386,17 @@ func (s *AdvancerSuite) TestContextCancellation() { s.Run("CancelDuringProcessInputs", func() { require := s.Require() + env := s.setupOneApp() + env.app.AdvanceBlock = true - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - // Create a machine that will block during Advance until we cancel the context - app1.AdvanceBlock = true - machineManager.Map[1] = *app1 - - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - // Create inputs and a context that we can cancel inputs := []*Input{ - newInput(app1.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), } ctx, cancel := context.WithCancel(context.Background()) - // Start the processInputs operation in a goroutine errCh := make(chan error) go func() { - errCh <- advancer.processInputs(ctx, app1.Application, inputs) + errCh <- env.service.processInputs(ctx, env.app.Application, inputs) }() // Cancel the context after a short delay @@ -460,71 +418,633 @@ func (s *AdvancerSuite) TestContextCancellation() { func (s *AdvancerSuite) TestLargeNumberOfInputs() { s.Run("LargeNumberOfInputs", func() { require := s.Require() + env := s.setupOneApp() - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 - repository := &MockRepository{} - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) - require.Nil(err) - - // Create a large number of inputs const inputCount = 10000 inputs := make([]*Input, inputCount) for i := range inputCount { - inputs[i] = newInput(app1.Application.ID, 0, uint64(i), marshal(randomAdvanceResult(uint64(i)))) + inputs[i] = newInput(env.app.Application.ID, 0, uint64(i), marshal(randomAdvanceResult(uint64(i)))) } - // Process the inputs - err = advancer.processInputs(context.Background(), app1.Application, inputs) + err := env.service.processInputs(context.Background(), env.app.Application, inputs) require.Nil(err) - - // Verify all inputs were processed - require.Len(repository.StoredResults, inputCount) + require.Len(env.repo.StoredResults, inputCount) }) } -// TestErrorRecovery tests how the advancer recovers from temporary failures +// TestErrorRecovery verifies that any store failure after a successful Advance() +// triggers node shutdown, because the machine and DB are now out of sync. func (s *AdvancerSuite) TestErrorRecovery() { - s.Run("TemporaryRepositoryFailure", func() { + s.Run("TransientStoreFailureTriggersShutdown", func() { require := s.Require() + env := s.setupOneApp() + env.repo.StoreAdvanceFailCount = 1 - machineManager := newMockMachineManager() - app1 := newMockMachine(1) - machineManager.Map[1] = *app1 + inputs := []*Input{ + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), + } - // Repository that fails on first attempt but succeeds on second - repository := &MockRepository{ - StoreAdvanceFailCount: 1, + err := env.service.processInputs(context.Background(), env.app.Application, inputs) + require.Error(err) + require.Contains(err.Error(), "temporary failure") + require.Error(env.service.Context.Err(), "shared context should be cancelled") + }) +} + +// TestContextCancelledBeforeProcessing verifies that when the context is +// already cancelled, processInputs returns the context error immediately +// without reaching the advance or store paths. +func (s *AdvancerSuite) TestContextCancelledBeforeProcessing() { + s.Run("ContextAlreadyCancelled", func() { + require := s.Require() + env := s.setupOneApp() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + inputs := []*Input{ + newInput(env.app.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), } - advancer, err := newMockAdvancerService(machineManager, repository) - require.NotNil(advancer) + err := env.service.processInputs(ctx, env.app.Application, inputs) + require.ErrorIs(err, context.Canceled) + }) +} + +// --------------------------------------------------------------------------- +// isAllEpochInputsProcessed tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestIsAllEpochInputsProcessed() { + s.Run("TrueWhenEpochHasNoInputs", func() { + require := s.Require() + env := s.setupOneApp() + + epoch := &Epoch{Index: 0, InputIndexLowerBound: 5, InputIndexUpperBound: 5} + result, perr := env.service.isAllEpochInputsProcessed(env.app.Application, epoch) + require.Nil(perr) + require.True(result) + }) + + s.Run("TrueWhenMachineProcessedAllInputs", func() { + require := s.Require() + env := s.setupOneApp() + env.mm.Map[1].machineImpl.processedInputs = 10 + + epoch := &Epoch{Index: 0, InputIndexLowerBound: 5, InputIndexUpperBound: 10} + result, perr := env.service.isAllEpochInputsProcessed(env.app.Application, epoch) + require.Nil(perr) + require.True(result) + }) + + s.Run("FalseWhenMoreInputsExist", func() { + require := s.Require() + env := s.setupOneApp() + env.mm.Map[1].machineImpl.processedInputs = 7 + + epoch := &Epoch{Index: 0, InputIndexLowerBound: 5, InputIndexUpperBound: 10} + result, perr := env.service.isAllEpochInputsProcessed(env.app.Application, epoch) + require.Nil(perr) + require.False(result) + }) + + s.Run("ErrorWhenNoMachineForApp", func() { + require := s.Require() + mm := newMockMachineManager() + repo := &MockRepository{} + svc, err := newMockAdvancerService(mm, repo) require.Nil(err) - // Create inputs - inputs := []*Input{ - newInput(app1.Application.ID, 0, 0, marshal(randomAdvanceResult(0))), - newInput(app1.Application.ID, 0, 1, marshal(randomAdvanceResult(1))), + app := &Application{ID: 999} + epoch := &Epoch{Index: 0, InputIndexLowerBound: 0, InputIndexUpperBound: 5} + _, perr := svc.isAllEpochInputsProcessed(app, epoch) + require.Error(perr) + require.ErrorIs(perr, ErrNoApp) + }) +} + +// --------------------------------------------------------------------------- +// isEpochLastInput tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestIsEpochLastInput() { + setupWithEpoch := func(epochStatus EpochStatus) (testEnv, *Application) { + env := s.setupOneApp() + env.repo.GetEpochReturn = &Epoch{Status: epochStatus} + return env, env.app.Application + } + + s.Run("TrueWhenLastInputInClosedEpoch", func() { + require := s.Require() + env, app := setupWithEpoch(EpochStatus_Closed) + + lastInput := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + env.repo.GetInputsReturn = map[common.Address][]*Input{ + app.IApplicationAddress: {lastInput}, } + env.repo.GetLastInputReturn = lastInput + + input := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + result, err := env.service.isEpochLastInput(context.Background(), app, input) + require.Nil(err) + require.True(result) + }) + + s.Run("FalseWhenEpochIsOpen", func() { + require := s.Require() + env, app := setupWithEpoch(EpochStatus_Open) + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() + result, err := env.service.isEpochLastInput(context.Background(), app, input) + require.Nil(err) + require.False(result) + }) - // First attempt should fail - err = advancer.processInputs(context.Background(), app1.Application, inputs) + s.Run("FalseWhenNotLastInput", func() { + require := s.Require() + env, app := setupWithEpoch(EpochStatus_Closed) + + env.repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() + result, err := env.service.isEpochLastInput(context.Background(), app, input) + require.Nil(err) + require.False(result) + }) + + s.Run("ErrorWhenNilInput", func() { + require := s.Require() + env, app := setupWithEpoch(EpochStatus_Closed) + + _, err := env.service.isEpochLastInput(context.Background(), app, nil) require.Error(err) - require.Contains(err.Error(), "temporary failure") + require.Contains(err.Error(), "must not be nil") + }) + + s.Run("ErrorWhenNilApplication", func() { + require := s.Require() + env, _ := setupWithEpoch(EpochStatus_Closed) + + input := repotest.NewInputBuilder().WithIndex(0).Build() + _, err := env.service.isEpochLastInput(context.Background(), nil, input) + require.Error(err) + require.Contains(err.Error(), "must not be nil") + }) + + s.Run("ErrorWhenGetEpochFails", func() { + require := s.Require() + env, app := setupWithEpoch(EpochStatus_Closed) + env.repo.GetEpochError = errors.New("get epoch error") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + _, err := env.service.isEpochLastInput(context.Background(), app, input) + require.Error(err) + require.Contains(err.Error(), "get epoch error") + }) + + s.Run("ErrorWhenGetLastInputFails", func() { + require := s.Require() + env, app := setupWithEpoch(EpochStatus_Closed) + env.repo.GetLastInputError = errors.New("get last input error") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + _, err := env.service.isEpochLastInput(context.Background(), app, input) + require.Error(err) + require.Contains(err.Error(), "get last input error") + }) +} + +// --------------------------------------------------------------------------- +// handleEpochAfterInputsProcessed tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestHandleEpochAfterInputsProcessed() { + s.Run("EmptyEpochIndex0GetsOutputsProofFromMachine", func() { + require := s.Require() + env := s.setupOneApp() + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) + require.Nil(err) + require.True(env.repo.OutputsProofUpdated) + }) + + s.Run("EmptyEpochIndex0ErrorOnOutputsProof", func() { + require := s.Require() + env := s.setupOneApp() + env.app.OutputsProofError = errors.New("proof error") + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) + require.Error(err) + require.Contains(err.Error(), "proof error") + }) + + s.Run("EmptyEpochIndexGt0RepeatsPreviousProof", func() { + require := s.Require() + env := s.setupOneApp() + + epoch := &Epoch{Index: 2, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) + require.Nil(err) + require.True(env.repo.RepeatOutputsProofCalled) + }) + + s.Run("EmptyEpochIndexGt0RepeatError", func() { + require := s.Require() + env := s.setupOneApp() + env.repo.RepeatOutputsProofError = errors.New("repeat error") + + epoch := &Epoch{Index: 2, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 0} + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) + require.Error(err) + require.Contains(err.Error(), "repeat error") + }) - // Second attempt should succeed - err = advancer.processInputs(context.Background(), app1.Application, inputs) + s.Run("NonEmptyEpochWithEveryEpochSnapshotPolicy", func() { + require := s.Require() + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch + env.service.snapshotsDir = s.T().TempDir() + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + lastInput := repotest.NewInputBuilder().WithIndex(2).WithEpochIndex(0). + WithStatus(InputCompletionStatus_Accepted).Build() + lastInput.EpochApplicationID = env.app.Application.ID + env.repo.GetLastProcessedInputReturn = lastInput + env.repo.GetLastInputReturn = lastInput + + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) + require.Nil(err) + require.True(env.repo.SnapshotURIUpdated) + }) + + s.Run("NonEmptyEpochNoSnapshotPolicy", func() { + require := s.Require() + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_None + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + lastInput := repotest.NewInputBuilder().WithIndex(2).WithEpochIndex(0). + WithStatus(InputCompletionStatus_Accepted).Build() + lastInput.EpochApplicationID = env.app.Application.ID + env.repo.GetLastProcessedInputReturn = lastInput + + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) + require.Nil(err) + require.False(env.repo.SnapshotURIUpdated) + }) + + s.Run("NoMachineReturnsError", func() { + require := s.Require() + mm := newMockMachineManager() + svc, err := newMockAdvancerService(mm, &MockRepository{}) + require.Nil(err) + + app := repotest.NewApplicationBuilder().Build() + app.ID = 999 + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + err = svc.handleEpochAfterInputsProcessed(context.Background(), app, epoch) + require.Error(err) + require.ErrorIs(err, ErrNoApp) + }) + + s.Run("GetLastProcessedInputError", func() { + require := s.Require() + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryEpoch + env.repo.GetLastProcessedInputError = errors.New("db connection lost") + + epoch := &Epoch{Index: 0, Status: EpochStatus_Closed, InputIndexLowerBound: 0, InputIndexUpperBound: 3} + err := env.service.handleEpochAfterInputsProcessed(context.Background(), env.app.Application, epoch) + require.Error(err) + require.Contains(err.Error(), "db connection lost") + }) +} + +// --------------------------------------------------------------------------- +// handleSnapshot tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestHandleSnapshot() { + setupSnapshot := func(policy SnapshotPolicy) (testEnv, *MockMachineInstance) { + env := s.setupOneApp() + env.app.Application.ExecutionParameters.SnapshotPolicy = policy + env.service.snapshotsDir = s.T().TempDir() + instance := env.mm.Map[1] + return env, instance + } + + s.Run("NonePolicy", func() { + require := s.Require() + env, machine := setupSnapshot(SnapshotPolicy_None) + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) + require.Nil(err) + require.False(env.repo.SnapshotURIUpdated) + }) + + s.Run("EveryInputPolicy", func() { + require := s.Require() + env, machine := setupSnapshot(SnapshotPolicy_EveryInput) + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) + require.Nil(err) + require.True(env.repo.SnapshotURIUpdated) + }) + + s.Run("EveryEpochPolicyLastInput", func() { + require := s.Require() + env, machine := setupSnapshot(SnapshotPolicy_EveryEpoch) + env.repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + env.repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + + input := repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) + require.Nil(err) + require.True(env.repo.SnapshotURIUpdated) + }) + + s.Run("EveryEpochPolicyNotLastInput", func() { + require := s.Require() + env, machine := setupSnapshot(SnapshotPolicy_EveryEpoch) + env.repo.GetEpochReturn = &Epoch{Status: EpochStatus_Closed} + env.repo.GetLastInputReturn = repotest.NewInputBuilder().WithIndex(5).WithEpochIndex(0).Build() + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) + require.Nil(err) + require.False(env.repo.SnapshotURIUpdated) + }) + + s.Run("EveryEpochPolicyOpenEpoch", func() { + require := s.Require() + env, machine := setupSnapshot(SnapshotPolicy_EveryEpoch) + env.repo.GetEpochReturn = &Epoch{Status: EpochStatus_Open} + + input := repotest.NewInputBuilder().WithIndex(0).WithEpochIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.handleSnapshot(context.Background(), env.app.Application, machine, input) require.Nil(err) - require.Len(repository.StoredResults, 2) + require.False(env.repo.SnapshotURIUpdated) }) } +// --------------------------------------------------------------------------- +// createSnapshot tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestCreateSnapshot() { + setupCreateSnapshot := func() (testEnv, *MockMachineInstance, string) { + env := s.setupOneApp() + env.app.Application.Name = "testapp" + env.app.Application.ExecutionParameters.SnapshotPolicy = SnapshotPolicy_EveryInput + tmpDir := s.T().TempDir() + env.service.snapshotsDir = tmpDir + instance := &MockMachineInstance{ + application: env.app.Application, + machineImpl: env.app, + } + return env, instance, tmpDir + } + + s.Run("Success", func() { + require := s.Require() + env, machine, tmpDir := setupCreateSnapshot() + + input := repotest.NewInputBuilder().WithIndex(3).WithEpochIndex(1).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) + require.Nil(err) + require.True(env.repo.SnapshotURIUpdated) + + require.NotNil(input.SnapshotURI) + expectedPath := filepath.Join(tmpDir, "testapp_epoch1_input3") + require.Equal(expectedPath, *input.SnapshotURI) + }) + + s.Run("SkipsIfAlreadyHasSnapshot", func() { + require := s.Require() + env, machine, _ := setupCreateSnapshot() + + existingPath := "/existing/snapshot" + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + input.SnapshotURI = &existingPath + + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) + require.Nil(err) + require.False(env.repo.SnapshotURIUpdated) + }) + + s.Run("RemovesPreviousSnapshot", func() { + require := s.Require() + env, machine, tmpDir := setupCreateSnapshot() + + prevPath := filepath.Join(tmpDir, "testapp_epoch0_input0") + require.Nil(os.MkdirAll(prevPath, 0755)) + env.repo.GetLastSnapshotReturn = &Input{SnapshotURI: &prevPath} + + input := repotest.NewInputBuilder().WithIndex(1).WithEpochIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) + require.Nil(err) + + _, statErr := os.Stat(prevPath) + require.True(os.IsNotExist(statErr)) + }) + + s.Run("CreateSnapshotError", func() { + require := s.Require() + env, machine, _ := setupCreateSnapshot() + machine.createSnapshotError = errors.New("snapshot failed") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) + require.Error(err) + require.Contains(err.Error(), "snapshot failed") + require.False(env.repo.SnapshotURIUpdated) + }) + + s.Run("MkdirAllError", func() { + require := s.Require() + + machineManager := newMockMachineManager() + app := newMockMachine(1) + app.Application.Name = "testapp" + machineManager.Map[1] = newMockInstance(app) + repository := &MockRepository{} + advancer, err := newMockAdvancerService(machineManager, repository) + require.Nil(err) + + // Create a read-only parent directory so MkdirAll fails + tmpDir := s.T().TempDir() + readonlyDir := filepath.Join(tmpDir, "readonly") + require.Nil(os.MkdirAll(readonlyDir, 0755)) + require.Nil(os.Chmod(readonlyDir, 0555)) + s.T().Cleanup(func() { os.Chmod(readonlyDir, 0755) }) //nolint:errcheck + advancer.snapshotsDir = filepath.Join(readonlyDir, "snapshots") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = app.Application.ID + + err = advancer.createSnapshot(context.Background(), app.Application, machineManager.Map[1], input) + require.Error(err) + require.Contains(err.Error(), "failed to create snapshots directory") + }) + + s.Run("UpdateSnapshotURIError", func() { + require := s.Require() + env, machine, _ := setupCreateSnapshot() + + env.repo.UpdateSnapshotURIError = errors.New("db error") + + input := repotest.NewInputBuilder().WithIndex(0).Build() + input.EpochApplicationID = env.app.Application.ID + + err := env.service.createSnapshot(context.Background(), env.app.Application, machine, input) + require.Error(err) + require.Contains(err.Error(), "db error") + }) +} + +// --------------------------------------------------------------------------- +// removeSnapshot tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestRemoveSnapshot() { + s.Run("RemovesExistingSnapshot", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + // Create a snapshot directory + snapshotPath := filepath.Join(tmpDir, "myapp_epoch0_input0") + require.Nil(os.MkdirAll(snapshotPath, 0755)) + + err := advancer.removeSnapshot(snapshotPath, "myapp") + require.Nil(err) + + _, statErr := os.Stat(snapshotPath) + require.True(os.IsNotExist(statErr)) + }) + + s.Run("NonExistentPathIsNoop", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + snapshotPath := filepath.Join(tmpDir, "myapp_epoch0_input0") + err := advancer.removeSnapshot(snapshotPath, "myapp") + require.Nil(err) + }) + + s.Run("RejectsDirectoryTraversal", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + // Try to traverse outside snapshotsDir + maliciousPath := filepath.Join(tmpDir, "..", "outside", "myapp_evil") + err := advancer.removeSnapshot(maliciousPath, "myapp") + require.Error(err) + require.Contains(err.Error(), "invalid snapshot path") + }) + + s.Run("RejectsMismatchedAppName", func() { + require := s.Require() + + tmpDir := s.T().TempDir() + advancer := &Service{snapshotsDir: tmpDir} + serviceArgs := &service.CreateInfo{Name: "advancer", Impl: advancer} + require.Nil(service.Create(context.Background(), serviceArgs, &advancer.Service)) + + snapshotPath := filepath.Join(tmpDir, "otherapp_epoch0_input0") + err := advancer.removeSnapshot(snapshotPath, "myapp") + require.Error(err) + require.Contains(err.Error(), "invalid snapshot path") + }) +} + +// --------------------------------------------------------------------------- +// Service.Create tests +// --------------------------------------------------------------------------- + +func (s *AdvancerSuite) TestServiceCreate() { + s.Run("NilRepository", func() { + require := s.Require() + c := &CreateInfo{} + c.Name = "advancer" + c.Config.AdvancerInputBatchSize = 500 + svc, err := Create(context.Background(), c) + require.Error(err) + require.Nil(svc) + require.Contains(err.Error(), "nil") + }) + + s.Run("ZeroBatchSize", func() { + require := s.Require() + c := &CreateInfo{} + c.Name = "advancer" + c.Config.AdvancerInputBatchSize = 0 + c.Repository = &MockFullRepository{} + svc, err := Create(context.Background(), c) + require.Error(err) + require.Nil(svc) + require.Contains(err.Error(), "batch size") + }) + + s.Run("CancelledContext", func() { + require := s.Require() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c := &CreateInfo{} + c.Name = "advancer" + c.Config.AdvancerInputBatchSize = 500 + c.Repository = &MockFullRepository{} + svc, err := Create(ctx, c) + require.Error(err) + require.Nil(svc) + }) +} + +// MockFullRepository satisfies the repository.Repository interface minimally +// for Create() validation tests. It panics on any actual DB call. +type MockFullRepository struct { + repository.Repository +} + type MockMachineImpl struct { - Application *Application - AdvanceBlock bool - AdvanceError error + Application *Application + AdvanceBlock bool + AdvanceError error + OutputsProofError error + processedInputs uint64 } func (mock *MockMachineImpl) Advance( @@ -572,33 +1092,33 @@ func newMockMachine(id int64) *MockMachineImpl { } } +// newMockInstance creates a MockMachineInstance from a MockMachineImpl, ready to store in MockMachineManager.Map. +func newMockInstance(impl *MockMachineImpl) *MockMachineInstance { + return &MockMachineInstance{ + application: impl.Application, + machineImpl: impl, + } +} + // ------------------------------------------------------------------------------------------------ type MockMachineManager struct { - Map map[int64]MockMachineImpl + Map map[int64]*MockMachineInstance UpdateMachinesError error } func newMockMachineManager() *MockMachineManager { return &MockMachineManager{ - Map: map[int64]MockMachineImpl{}, + Map: map[int64]*MockMachineInstance{}, } } func (mock *MockMachineManager) GetMachine(appID int64) (manager.MachineInstance, bool) { - machine, exists := mock.Map[appID] + instance, exists := mock.Map[appID] if !exists { return nil, false } - - // For testing purposes, we'll create a mock MachineInstance - // that has the same Application but delegates the methods to our mock - mockInstance := &MockMachineInstance{ - application: machine.Application, - machineImpl: &machine, - } - - return mockInstance, true + return instance, true } func (mock *MockMachineManager) UpdateMachines(ctx context.Context) error { @@ -608,7 +1128,7 @@ func (mock *MockMachineManager) UpdateMachines(ctx context.Context) error { func (mock *MockMachineManager) Applications() []*Application { apps := make([]*Application, 0, len(mock.Map)) for _, v := range mock.Map { - apps = append(apps, v.Application) + apps = append(apps, v.application) } return apps } @@ -618,10 +1138,15 @@ func (mock *MockMachineManager) HasMachine(appID int64) bool { return exists } +func (mock *MockMachineManager) Close() error { + return nil +} + // MockMachineInstance is a test implementation of manager.MachineInstance type MockMachineInstance struct { - application *Application - machineImpl *MockMachineImpl + application *Application + machineImpl *MockMachineImpl + createSnapshotError error } // Advance implements the MachineInstance interface for testing @@ -641,23 +1166,28 @@ func (m *MockMachineInstance) Application() *Application { } func (m *MockMachineInstance) ProcessedInputs() uint64 { - return 0 + return m.machineImpl.processedInputs } -func (m *MockMachineInstance) OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) { - return nil, nil +func (m *MockMachineInstance) OutputsProof(ctx context.Context) (*OutputsProof, error) { + if m.machineImpl.OutputsProofError != nil { + return nil, m.machineImpl.OutputsProofError + } + return &OutputsProof{ + OutputsHash: randomHash(), + MachineHash: randomHash(), + }, nil } // Synchronize implements the MachineInstance interface for testing -func (m *MockMachineInstance) Synchronize(ctx context.Context, repo manager.MachineRepository) error { +func (m *MockMachineInstance) Synchronize(ctx context.Context, repo manager.MachineRepository, batchSize uint64) error { // Not used in advancer tests, but needed to satisfy the interface return nil } // CreateSnapshot implements the MachineInstance interface for testing func (m *MockMachineInstance) CreateSnapshot(ctx context.Context, processInputs uint64, path string) error { - // Not used in advancer tests, but needed to satisfy the interface - return nil + return m.createSnapshotError } // Retrieves the hash of the current machine state @@ -689,11 +1219,21 @@ type MockRepository struct { GetLastSnapshotReturn *Input GetLastSnapshotError error RepeatOutputsProofError error + GetEpochReturn *Epoch + GetEpochError error + GetLastInputReturn *Input + GetLastInputError error + GetLastProcessedInputReturn *Input + GetLastProcessedInputError error + UpdateSnapshotURIError error StoredResults []*AdvanceResult ApplicationStateUpdates int LastApplicationState ApplicationState LastApplicationStateReason *string + OutputsProofUpdated bool + RepeatOutputsProofCalled bool + SnapshotURIUpdated bool mu sync.Mutex } @@ -739,7 +1279,23 @@ func (mock *MockRepository) ListInputs( } address := common.HexToAddress(nameOrAddress) - return mock.GetInputsReturn[address], uint64(len(mock.GetInputsReturn[address])), mock.GetInputsError + inputs := mock.GetInputsReturn[address] + total := uint64(len(inputs)) + + // Apply pagination if a limit is set + if p.Limit > 0 { + start := p.Offset + if start >= total { + return nil, total, mock.GetInputsError + } + end := start + p.Limit + if end > total { + end = total + } + inputs = inputs[start:end] + } + + return inputs, total, mock.GetInputsError } func (mock *MockRepository) StoreAdvanceResult( @@ -763,15 +1319,32 @@ func (mock *MockRepository) StoreAdvanceResult( } mock.StoredResults = append(mock.StoredResults, res) + + // Simulate real behavior: processed inputs change status and are no longer + // returned by queries filtering for unprocessed (Status_None) inputs. + // This prevents infinite loops in batched fetching. + if mock.StoreAdvanceError == nil { + for addr, inputs := range mock.GetInputsReturn { + for i, inp := range inputs { + if inp.EpochApplicationID == appID && inp.Index == res.InputIndex { + newInputs := make([]*Input, 0, len(inputs)-1) + newInputs = append(newInputs, inputs[:i]...) + newInputs = append(newInputs, inputs[i+1:]...) + mock.GetInputsReturn[addr] = newInputs + break + } + } + } + } + return mock.StoreAdvanceError } func (mock *MockRepository) UpdateEpochOutputsProof(ctx context.Context, appID int64, epochIndex uint64, proof *OutputsProof) error { - // Check for context cancellation if ctx.Err() != nil { return ctx.Err() } - + mock.OutputsProofUpdated = true return mock.UpdateOutputsProofError } @@ -797,15 +1370,28 @@ func (mock *MockRepository) UpdateApplicationState(ctx context.Context, appID in } func (mock *MockRepository) GetEpoch(ctx context.Context, nameOrAddress string, index uint64) (*Epoch, error) { - // Not used in most advancer tests, but needed to satisfy the interface + if ctx.Err() != nil { + return nil, ctx.Err() + } + if mock.GetEpochError != nil { + return nil, mock.GetEpochError + } + if mock.GetEpochReturn != nil { + return mock.GetEpochReturn, nil + } return &Epoch{Status: EpochStatus_Closed}, nil } func (mock *MockRepository) GetLastInput(ctx context.Context, appAddress string, epochIndex uint64) (*Input, error) { - // Check for context cancellation if ctx.Err() != nil { return nil, ctx.Err() } + if mock.GetLastInputError != nil { + return nil, mock.GetLastInputError + } + if mock.GetLastInputReturn != nil { + return mock.GetLastInputReturn, nil + } address := common.HexToAddress(appAddress) inputs := mock.GetInputsReturn[address] @@ -813,7 +1399,6 @@ func (mock *MockRepository) GetLastInput(ctx context.Context, appAddress string, return nil, nil } - // Find the last input for the given epoch var lastInput *Input for _, input := range inputs { if input.EpochIndex == epochIndex && (lastInput == nil || input.Index > lastInput.Index) { @@ -825,10 +1410,15 @@ func (mock *MockRepository) GetLastInput(ctx context.Context, appAddress string, } func (mock *MockRepository) GetLastProcessedInput(ctx context.Context, appAddress string) (*Input, error) { - if ctx.Err() != nil { return nil, ctx.Err() } + if mock.GetLastProcessedInputError != nil { + return nil, mock.GetLastProcessedInputError + } + if mock.GetLastProcessedInputReturn != nil { + return mock.GetLastProcessedInputReturn, nil + } address := common.HexToAddress(appAddress) inputs := mock.GetInputsReturn[address] @@ -836,7 +1426,6 @@ func (mock *MockRepository) GetLastProcessedInput(ctx context.Context, appAddres return nil, nil } - // Find the last input for the given epoch var lastInput *Input for _, input := range inputs { if input.Status != InputCompletionStatus_None && (lastInput == nil || input.Index > lastInput.Index) { @@ -848,8 +1437,11 @@ func (mock *MockRepository) GetLastProcessedInput(ctx context.Context, appAddres } func (mock *MockRepository) UpdateInputSnapshotURI(ctx context.Context, appId int64, inputIndex uint64, snapshotURI string) error { - // Not used in most advancer tests, but needed to satisfy the interface - return nil + if ctx.Err() != nil { + return ctx.Err() + } + mock.SnapshotURIUpdated = true + return mock.UpdateSnapshotURIError } func (mock *MockRepository) GetLastSnapshot(ctx context.Context, nameOrAddress string) (*Input, error) { @@ -862,10 +1454,10 @@ func (mock *MockRepository) GetLastSnapshot(ctx context.Context, nameOrAddress s } func (mock *MockRepository) RepeatPreviousEpochOutputsProof(ctx context.Context, appID int64, epochIndex uint64) error { - // Check for context cancellation if ctx.Err() != nil { return ctx.Err() } + mock.RepeatOutputsProofCalled = true return mock.RepeatOutputsProofError } diff --git a/internal/advancer/service.go b/internal/advancer/service.go index 1f8399331..168703a83 100644 --- a/internal/advancer/service.go +++ b/internal/advancer/service.go @@ -5,8 +5,10 @@ package advancer import ( "context" + "errors" "fmt" "net/http" + "time" "github.com/cartesi/rollups-node/internal/config" "github.com/cartesi/rollups-node/internal/inspect" @@ -15,9 +17,14 @@ import ( "github.com/cartesi/rollups-node/pkg/service" ) +// httpShutdownTimeout is how long to wait for in-flight inspect HTTP requests +// to drain before forcibly closing the server during shutdown. +const httpShutdownTimeout = 10 * time.Second //nolint: mnd + // Service is the main advancer service that processes inputs through Cartesi machines type Service struct { service.Service + inputBatchSize uint64 snapshotsDir string repository AdvancerRepository machineManager manager.MachineProvider @@ -53,12 +60,17 @@ func Create(ctx context.Context, c *CreateInfo) (*Service, error) { return nil, fmt.Errorf("repository on advancer service Create is nil") } + s.inputBatchSize = c.Config.AdvancerInputBatchSize + if s.inputBatchSize == 0 { + return nil, fmt.Errorf("advancer input batch size must be greater than 0") + } + // Create the machine manager manager := manager.NewMachineManager( - ctx, c.Repository, s.Logger, c.Config.FeatureMachineHashCheckEnabled, + s.inputBatchSize, ) s.machineManager = manager @@ -86,14 +98,41 @@ func (s *Service) Tick() []error { if err := s.Step(s.Context); err != nil { return []error{err} } - return []error{} + return nil } func (s *Service) Stop(b bool) []error { - return nil + var errs []error + + // Shut down the inspect HTTP server gracefully. + // Use a dedicated timeout context because s.Context may already be cancelled + // when Stop is called from the context.Done path. + if s.HTTPServer != nil { + s.Logger.Info("Shutting down inspect HTTP server") + shutdownCtx, cancel := context.WithTimeout(context.Background(), httpShutdownTimeout) + defer cancel() + if err := s.HTTPServer.Shutdown(shutdownCtx); err != nil { + errs = append(errs, fmt.Errorf("failed to shutdown inspect HTTP server: %w", err)) + } + } + + // Close all machine instances to avoid orphaned emulator processes + if s.machineManager != nil { + s.Logger.Info("Closing machine manager") + if err := s.machineManager.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close machine manager: %w", err)) + } + } + + return errs } func (s *Service) Serve() error { if s.inspector != nil && s.HTTPServerFunc != nil { - go s.HTTPServerFunc() + go func() { + if err := s.HTTPServerFunc(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.Logger.Error("Inspect HTTP server failed — shutting down", "error", err) + s.Cancel() + } + }() } return s.Service.Serve() } diff --git a/internal/config/generate/Config.toml b/internal/config/generate/Config.toml index 3233aedbc..daccd1e09 100644 --- a/internal/config/generate/Config.toml +++ b/internal/config/generate/Config.toml @@ -61,6 +61,14 @@ used-by = ["advancer", "node", "cli"] # Rollups # +[rollups.CARTESI_ADVANCER_INPUT_BATCH_SIZE] +default = "500" +go-type = "uint64" +description = """ +Maximum number of inputs fetched per database query during advance processing and machine +synchronization. Bounds memory usage for applications with large backlogs of unprocessed inputs.""" +used-by = ["advancer", "node"] + [rollups.CARTESI_ADVANCER_POLLING_INTERVAL] default = "3" go-type = "Duration" diff --git a/internal/config/generated.go b/internal/config/generated.go index 1fb1db0f7..2ac966345 100644 --- a/internal/config/generated.go +++ b/internal/config/generated.go @@ -53,6 +53,7 @@ const ( LOG_COLOR = "CARTESI_LOG_COLOR" LOG_LEVEL = "CARTESI_LOG_LEVEL" JSONRPC_MACHINE_LOG_LEVEL = "CARTESI_JSONRPC_MACHINE_LOG_LEVEL" + ADVANCER_INPUT_BATCH_SIZE = "CARTESI_ADVANCER_INPUT_BATCH_SIZE" ADVANCER_POLLING_INTERVAL = "CARTESI_ADVANCER_POLLING_INTERVAL" BLOCKCHAIN_HTTP_MAX_RETRIES = "CARTESI_BLOCKCHAIN_HTTP_MAX_RETRIES" BLOCKCHAIN_HTTP_RETRY_MAX_WAIT = "CARTESI_BLOCKCHAIN_HTTP_RETRY_MAX_WAIT" @@ -143,6 +144,8 @@ func SetDefaults() { viper.SetDefault(JSONRPC_MACHINE_LOG_LEVEL, "info") + viper.SetDefault(ADVANCER_INPUT_BATCH_SIZE, "500") + viper.SetDefault(ADVANCER_POLLING_INTERVAL, "3") viper.SetDefault(BLOCKCHAIN_HTTP_MAX_RETRIES, "4") @@ -201,6 +204,10 @@ type AdvancerConfig struct { // One of "trace", "debug", "info", "warning", "error", "fatal". JsonrpcMachineLogLevel string `mapstructure:"CARTESI_JSONRPC_MACHINE_LOG_LEVEL"` + // Maximum number of inputs fetched per database query during advance processing and machine + // synchronization. Bounds memory usage for applications with large backlogs of unprocessed inputs. + AdvancerInputBatchSize uint64 `mapstructure:"CARTESI_ADVANCER_INPUT_BATCH_SIZE"` + // How many seconds the node will wait before querying the database for new inputs. AdvancerPollingInterval Duration `mapstructure:"CARTESI_ADVANCER_POLLING_INTERVAL"` @@ -283,6 +290,13 @@ func LoadAdvancerConfig() (*AdvancerConfig, error) { return nil, fmt.Errorf("CARTESI_JSONRPC_MACHINE_LOG_LEVEL is required for the advancer service: %w", err) } + cfg.AdvancerInputBatchSize, err = GetAdvancerInputBatchSize() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_INPUT_BATCH_SIZE: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_ADVANCER_INPUT_BATCH_SIZE is required for the advancer service: %w", err) + } + cfg.AdvancerPollingInterval, err = GetAdvancerPollingInterval() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_POLLING_INTERVAL: %w", err) @@ -815,6 +829,10 @@ type NodeConfig struct { // One of "trace", "debug", "info", "warning", "error", "fatal". JsonrpcMachineLogLevel string `mapstructure:"CARTESI_JSONRPC_MACHINE_LOG_LEVEL"` + // Maximum number of inputs fetched per database query during advance processing and machine + // synchronization. Bounds memory usage for applications with large backlogs of unprocessed inputs. + AdvancerInputBatchSize uint64 `mapstructure:"CARTESI_ADVANCER_INPUT_BATCH_SIZE"` + // How many seconds the node will wait before querying the database for new inputs. AdvancerPollingInterval Duration `mapstructure:"CARTESI_ADVANCER_POLLING_INTERVAL"` @@ -981,6 +999,13 @@ func LoadNodeConfig() (*NodeConfig, error) { return nil, fmt.Errorf("CARTESI_JSONRPC_MACHINE_LOG_LEVEL is required for the node service: %w", err) } + cfg.AdvancerInputBatchSize, err = GetAdvancerInputBatchSize() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_INPUT_BATCH_SIZE: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_ADVANCER_INPUT_BATCH_SIZE is required for the node service: %w", err) + } + cfg.AdvancerPollingInterval, err = GetAdvancerPollingInterval() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_ADVANCER_POLLING_INTERVAL: %w", err) @@ -1347,6 +1372,7 @@ func (c *NodeConfig) ToAdvancerConfig() *AdvancerConfig { LogColor: c.LogColor, LogLevel: c.LogLevel, JsonrpcMachineLogLevel: c.JsonrpcMachineLogLevel, + AdvancerInputBatchSize: c.AdvancerInputBatchSize, AdvancerPollingInterval: c.AdvancerPollingInterval, MaxStartupTime: c.MaxStartupTime, SnapshotsDir: c.SnapshotsDir, @@ -1891,6 +1917,19 @@ func GetJsonrpcMachineLogLevel() (string, error) { return notDefinedstring(), fmt.Errorf("%s: %w", JSONRPC_MACHINE_LOG_LEVEL, ErrNotDefined) } +// GetAdvancerInputBatchSize returns the value for the environment variable CARTESI_ADVANCER_INPUT_BATCH_SIZE. +func GetAdvancerInputBatchSize() (uint64, error) { + s := viper.GetString(ADVANCER_INPUT_BATCH_SIZE) + if s != "" { + v, err := toUint64(s) + if err != nil { + return v, fmt.Errorf("failed to parse %s: %w", ADVANCER_INPUT_BATCH_SIZE, err) + } + return v, nil + } + return notDefineduint64(), fmt.Errorf("%s: %w", ADVANCER_INPUT_BATCH_SIZE, ErrNotDefined) +} + // GetAdvancerPollingInterval returns the value for the environment variable CARTESI_ADVANCER_POLLING_INTERVAL. func GetAdvancerPollingInterval() (Duration, error) { s := viper.GetString(ADVANCER_POLLING_INTERVAL) diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index 983c951c4..2d77c9d31 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -234,11 +234,11 @@ func (mock *MockMachine) ProcessedInputs() uint64 { return 0 } -func (m *MockMachine) OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) { +func (m *MockMachine) OutputsProof(ctx context.Context) (*OutputsProof, error) { return nil, nil } -func (mock *MockMachine) Synchronize(ctx context.Context, repo manager.MachineRepository) error { +func (mock *MockMachine) Synchronize(ctx context.Context, repo manager.MachineRepository, batchSize uint64) error { // Not used in inspect tests, but needed to satisfy the interface return nil } diff --git a/internal/manager/instance.go b/internal/manager/instance.go index f1ad2000a..447dbb50a 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -10,10 +10,12 @@ import ( "fmt" "log/slog" "sync" + "sync/atomic" "time" "github.com/cartesi/rollups-node/internal/manager/pmutex" . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" "github.com/cartesi/rollups-node/pkg/machine" "github.com/ethereum/go-ethereum/common" "golang.org/x/sync/semaphore" @@ -29,13 +31,26 @@ var ( ErrInvalidConcurrentLimit = errors.New("maximum concurrent inspects must not be zero") ) -// MachineInstanceImpl represents a running Cartesi machine for an application +// MachineInstanceImpl represents a running Cartesi machine for an application. +// +// Concurrency protocol: +// - runtime: Protected by PMutex. Written under HLock, read under LLock. +// - processedInputs: atomic.Uint64. Written under HLock (together with runtime swap, +// so writers see a consistent pair). Read lock-free via Load() — +// this is safe because only one advance runs at a time (advanceMutex) +// and the atomic store is visible to all goroutines immediately. +// - advanceMutex: Serializes all Advance calls. Only one input is processed at a time. +// - mutex (PMutex): HLock for advance/snapshot/hash/proof (may destroy runtime on error). +// LLock for inspect (read-only fork). HLock starves LLock by design. +// - inspectSemaphore: Bounds concurrent inspect operations. type MachineInstanceImpl struct { application *Application runtime machine.Machine - // How many inputs were processed by the machine - processedInputs uint64 + // How many inputs were processed by the machine. + // Written under HLock (together with runtime swap — the two MUST be updated + // atomically from the perspective of readers). Read without locks via Load(). + processedInputs atomic.Uint64 // Timeouts for operations advanceTimeout time.Duration @@ -47,6 +62,9 @@ type MachineInstanceImpl struct { advanceMutex sync.Mutex inspectSemaphore *semaphore.Weighted + // Timeout for draining in-flight inspects during Close + closeTimeout time.Duration + // Factory for creating machine runtimes runtimeFactory MachineRuntimeFactory @@ -126,15 +144,16 @@ func NewMachineInstanceWithFactory( instance := &MachineInstanceImpl{ application: app, runtime: runtime, - processedInputs: processedInputs, advanceTimeout: app.ExecutionParameters.AdvanceMaxDeadline, inspectTimeout: app.ExecutionParameters.InspectMaxDeadline, maxConcurrentInspects: app.ExecutionParameters.MaxConcurrentInspects, mutex: pmutex.New(), inspectSemaphore: semaphore.NewWeighted(int64(app.ExecutionParameters.MaxConcurrentInspects)), + closeTimeout: defaultCloseTimeout, runtimeFactory: factory, logger: logger.With("application", app.Name), } + instance.processedInputs.Store(processedInputs) return instance, nil } @@ -144,46 +163,74 @@ func (m *MachineInstanceImpl) Application() *Application { } func (m *MachineInstanceImpl) ProcessedInputs() uint64 { - return m.processedInputs + return m.processedInputs.Load() } -// Synchronize brings the machine up to date with processed inputs -func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepository) error { +// Synchronize brings the machine up to date with processed inputs. +// It handles both template-based instances (processedInputs == 0, replays all) +// and snapshot-based instances (processedInputs > 0, replays only remaining). +// Inputs are fetched in batches to bound memory usage. +func (m *MachineInstanceImpl) Synchronize(ctx context.Context, repo MachineRepository, batchSize uint64) error { appAddress := m.application.IApplicationAddress.String() - m.logger.Info("Synchronizing machine processed inputs", + currentProcessed := m.processedInputs.Load() + m.logger.Info("Synchronizing machine with processed inputs", "address", appAddress, - "processed_inputs", m.application.ProcessedInputs) + "app_processed_inputs", m.application.ProcessedInputs, + "machine_processed_inputs", currentProcessed) - // Get all processed inputs for this application - inputs, _, err := getProcessedInputs(ctx, repo, appAddress) - if err != nil { - return err - } + initialProcessedInputs := currentProcessed + replayed := uint64(0) + toReplay := uint64(0) - // Verify that the number of inputs matches what's expected - if uint64(len(inputs)) != m.application.ProcessedInputs { - errorMsg := fmt.Sprintf("processed inputs count mismatch: expected %d, got %d", - m.application.ProcessedInputs, len(inputs)) - m.logger.Error(errorMsg, "address", appAddress) - return fmt.Errorf("%w: %s", ErrMachineSynchronization, errorMsg) - } + for { + p := repository.Pagination{ + Limit: batchSize, + Offset: initialProcessedInputs + replayed, + } + inputs, totalCount, err := getProcessedInputs(ctx, repo, appAddress, p) + if err != nil { + return fmt.Errorf("%w: %w", ErrMachineSynchronization, err) + } - if len(inputs) == 0 { - m.logger.Info("No previous processed inputs to synchronize", "address", appAddress) - return nil - } + // Validate count on the first batch + if replayed == 0 { + if totalCount != m.application.ProcessedInputs { + errorMsg := fmt.Sprintf( + "processed inputs count mismatch: expected %d, got %d", + m.application.ProcessedInputs, totalCount) + m.logger.Error(errorMsg, "address", appAddress) + return fmt.Errorf("%w: %s", ErrMachineSynchronization, errorMsg) + } + if currentProcessed > totalCount { + return fmt.Errorf( + "%w: machine has processed %d inputs but DB only has %d", + ErrMachineSynchronization, currentProcessed, totalCount) + } + toReplay = totalCount - currentProcessed + if toReplay == 0 { + m.logger.Info("No inputs to replay during synchronization", + "address", appAddress) + return nil + } + } - // Process each input to bring the machine to the current state - for _, input := range inputs { - m.logger.Info("Replaying input during synchronization", - "address", appAddress, - "epoch_index", input.EpochIndex, - "input_index", input.Index) + for _, input := range inputs { + m.logger.Info("Replaying input during synchronization", + "address", appAddress, + "epoch_index", input.EpochIndex, + "input_index", input.Index, + "progress", fmt.Sprintf("%d/%d", replayed+1, toReplay)) + + _, err := m.Advance(ctx, input.RawData, input.EpochIndex, input.Index, false) + if err != nil { + return fmt.Errorf("%w: failed to replay input %d: %w", + ErrMachineSynchronization, input.Index, err) + } + replayed++ + } - _, err := m.Advance(ctx, input.RawData, input.EpochIndex, input.Index, false) - if err != nil { - return fmt.Errorf("%w: failed to replay input %d: %v", - ErrMachineSynchronization, input.Index, err) + if replayed >= toReplay || len(inputs) == 0 { + break } } @@ -201,8 +248,10 @@ func (m *MachineInstanceImpl) forkForAdvance(ctx context.Context, index uint64) } // Verify input index - if m.processedInputs != index { - return nil, fmt.Errorf("%w: processed inputs is %d and index is %d", ErrInvalidInputIndex, m.processedInputs, index) + current := m.processedInputs.Load() + if current != index { + return nil, fmt.Errorf("%w: processed inputs is %d and index is %d", + ErrInvalidInputIndex, current, index) } // Fork the machine @@ -253,8 +302,8 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn } // Process the input - accepted, outputs, reports, hashes, remaining, outputsHash, err := fork.Advance(advanceCtx, input, computeHashes) - status, err := toInputStatus(accepted, err) + advanceResp, err := fork.Advance(advanceCtx, input, computeHashes) + status, err := toInputStatus(advanceResp.Accepted, err) if err != nil { return nil, errors.Join(err, fork.Close()) } @@ -264,10 +313,10 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn EpochIndex: epochIndex, InputIndex: index, Status: status, - Outputs: outputs, - Reports: reports, - Hashes: hashes, - RemainingMetaCycles: remaining, + Outputs: advanceResp.Outputs, + Reports: advanceResp.Reports, + Hashes: advanceResp.Hashes, + RemainingMetaCycles: advanceResp.RemainingCycles, IsDaveConsensus: computeHashes, } @@ -278,7 +327,7 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn if err != nil { return nil, errors.Join(err, fork.Close()) } - result.OutputsHash = outputsHash + result.OutputsHash = advanceResp.OutputsHash result.OutputsHashProof, err = fork.OutputsHashProof(ctx) if err != nil { return nil, errors.Join(err, fork.Close()) @@ -286,13 +335,14 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn // Replace the current machine with the fork m.mutex.HLock() - if err = m.runtime.Close(); err != nil { - m.mutex.Unlock() - return nil, err - } + oldRuntime := m.runtime m.runtime = fork - m.processedInputs++ + m.processedInputs.Add(1) m.mutex.Unlock() + + if err := oldRuntime.Close(); err != nil { + m.logger.Warn("Failed to close old machine runtime", "error", err) + } } else { // Use the previous state for rejected inputs result.MachineHash = prevMachineHash @@ -300,15 +350,17 @@ func (m *MachineInstanceImpl) Advance(ctx context.Context, input []byte, epochIn result.OutputsHashProof = prevOutputsHashProof // Close the fork since we're not using it - err = fork.Close() + if err := fork.Close(); err != nil { + m.logger.Warn("Failed to close fork machine runtime", "error", err) + } // Update the processed inputs counter m.mutex.HLock() - m.processedInputs++ + m.processedInputs.Add(1) m.mutex.Unlock() } - return result, err + return result, nil } // forkForInspect creates a copy of the machine for inspect operations @@ -327,7 +379,7 @@ func (m *MachineInstanceImpl) forkForInspect(ctx context.Context) (machine.Machi return nil, 0, err } - return fork, m.processedInputs, nil + return fork, m.processedInputs.Load(), nil } // Inspect queries the machine state without modifying it @@ -380,8 +432,8 @@ func (m *MachineInstanceImpl) CreateSnapshot(ctx context.Context, processedInput m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - // Acquire a read lock on the machine - m.mutex.LLock() + // Acquire HLock since this operation may destroy the runtime on failure. + m.mutex.HLock() defer m.mutex.Unlock() if m.runtime == nil { @@ -389,21 +441,25 @@ func (m *MachineInstanceImpl) CreateSnapshot(ctx context.Context, processedInput } // Verify processed inputs - if m.processedInputs != processedInputs { - return fmt.Errorf("%w: machine processed inputs is %d and expected is %d", ErrInvalidSnapshotPoint, m.processedInputs, processedInputs) + current := m.processedInputs.Load() + if current != processedInputs { + return fmt.Errorf("%w: machine processed inputs is %d and expected is %d", + ErrInvalidSnapshotPoint, current, processedInputs) } - m.logger.Debug("Creating machine snapshot", "path", path, "processed_inputs", m.processedInputs) + m.logger.Debug("Creating machine snapshot", "path", path, "processed_inputs", current) // Create a context with a timeout for the store operation storeCtx, cancel := context.WithTimeout(ctx, m.application.ExecutionParameters.StoreDeadline) defer cancel() - // Store the machine state to the specified path + // Store the machine state to the specified path. + // A Store failure on a local child process indicates an unrecoverable + // condition (disk full, process crash, etc.) — destroy the runtime. err := m.runtime.Store(storeCtx, path) if err != nil { - m.logger.Error("Failed to create snapshot", "path", path, "error", err) - return err + m.logger.Error("Failed to create snapshot, destroying runtime", "path", path, "error", err) + return m.destroyRuntime(fmt.Errorf("failed to create snapshot: %w", err)) } m.logger.Debug("Snapshot created successfully", "path", path) @@ -415,8 +471,8 @@ func (m *MachineInstanceImpl) Hash(ctx context.Context) ([32]byte, error) { m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - // Acquire a read lock on the machine - m.mutex.LLock() + // Acquire HLock since this operation may destroy the runtime on failure. + m.mutex.HLock() defer m.mutex.Unlock() if m.runtime == nil { @@ -430,21 +486,21 @@ func (m *MachineInstanceImpl) Hash(ctx context.Context) ([32]byte, error) { hash, err := m.runtime.Hash(storeCtx) if err != nil { - m.logger.Error("Failed to retrieve machine root hash", "error", err) - return [32]byte{}, err + m.logger.Error("Failed to retrieve machine root hash, destroying runtime", "error", err) + return [32]byte{}, m.destroyRuntime(fmt.Errorf("failed to retrieve machine root hash: %w", err)) } m.logger.Debug("Machine root hash retrieved successfully", "hash", "0x"+hex.EncodeToString(hash[:])) return hash, nil } -func (m *MachineInstanceImpl) OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) { +func (m *MachineInstanceImpl) OutputsProof(ctx context.Context) (*OutputsProof, error) { // Acquire the advance mutex to ensure no advance operations are in progress m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - // Acquire a read lock on the machine - m.mutex.LLock() + // Acquire HLock since this operation may destroy the runtime on failure. + m.mutex.HLock() defer m.mutex.Unlock() if m.runtime == nil { @@ -456,20 +512,22 @@ func (m *MachineInstanceImpl) OutputsProof(ctx context.Context, processedInputs proofCtx, cancel := context.WithTimeout(ctx, m.application.ExecutionParameters.LoadDeadline) defer cancel() - // Get the machine state before processing + // The runtime is a local child process — errors here indicate the process + // crashed, ran out of resources, or is otherwise unrecoverable. + // Close the runtime to avoid leaving a broken process alive. machineHash, err := m.runtime.Hash(proofCtx) if err != nil { - return nil, errors.Join(err, m.runtime.Close()) + return nil, m.destroyRuntime(fmt.Errorf("failed to get machine hash: %w", err)) } outputsHash, err := m.runtime.OutputsHash(proofCtx) if err != nil { - return nil, errors.Join(err, m.runtime.Close()) + return nil, m.destroyRuntime(fmt.Errorf("failed to get outputs hash: %w", err)) } outputsHashProof, err := m.runtime.OutputsHashProof(proofCtx) if err != nil { - return nil, errors.Join(err, m.runtime.Close()) + return nil, m.destroyRuntime(fmt.Errorf("failed to get outputs hash proof: %w", err)) } proof := &OutputsProof{ @@ -478,22 +536,35 @@ func (m *MachineInstanceImpl) OutputsProof(ctx context.Context, processedInputs OutputsHashProof: outputsHashProof, } - m.logger.Debug("Machine machine hash, outputs merkle root and outputs merkle proof retrieved successfully", + m.logger.Debug("Machine hash, outputs merkle root and outputs merkle proof retrieved successfully", "hash", "0x"+hex.EncodeToString(machineHash[:])) return proof, nil } +// defaultCloseTimeout is how long Close waits for in-flight inspects to drain +// before forcibly closing the runtime. +const defaultCloseTimeout = 30 * time.Second + // Close shuts down the machine instance func (m *MachineInstanceImpl) Close() error { // Acquire all locks to ensure no operations are in progress m.advanceMutex.Lock() defer m.advanceMutex.Unlock() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), m.closeTimeout) + defer cancel() + + acquired := 0 for range int(m.maxConcurrentInspects) { - _ = m.inspectSemaphore.Acquire(ctx, 1) - defer m.inspectSemaphore.Release(1) + if err := m.inspectSemaphore.Acquire(ctx, 1); err != nil { + m.logger.Warn("Timed out waiting for in-flight inspects to drain; closing anyway", + "still_in_flight", int(m.maxConcurrentInspects)-acquired, + "drained", acquired) + break + } + acquired++ } + defer m.inspectSemaphore.Release(int64(acquired)) // Close the runtime m.mutex.HLock() @@ -508,6 +579,18 @@ func (m *MachineInstanceImpl) Close() error { return err } +// destroyRuntime closes the runtime and nils it out so that subsequent calls +// fail fast with ErrMachineClosed instead of talking to a broken process. +// Must be called while holding the appropriate locks. +func (m *MachineInstanceImpl) destroyRuntime(cause error) error { + if m.runtime == nil { + return cause + } + closeErr := m.runtime.Close() + m.runtime = nil + return errors.Join(cause, closeErr) +} + // MachineRuntimeFactory defines an interface for creating machine runtimes type MachineRuntimeFactory interface { CreateMachineRuntime( @@ -654,6 +737,8 @@ func toInputStatus(accepted bool, err error) (status InputCompletionStatus, _ er return InputCompletionStatus_MachineHalted, nil case errors.Is(err, machine.ErrOutputsLimitExceeded): return InputCompletionStatus_OutputsLimitExceeded, nil + case errors.Is(err, machine.ErrReportsLimitExceeded): + return InputCompletionStatus_ReportsLimitExceeded, nil case errors.Is(err, machine.ErrReachedTargetMcycle): return InputCompletionStatus_CycleLimitExceeded, nil case errors.Is(err, machine.ErrPayloadLengthLimitExceeded): diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 82f27640b..9f42fd130 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -14,6 +14,7 @@ import ( "github.com/cartesi/rollups-node/internal/manager/pmutex" "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" "github.com/cartesi/rollups-node/pkg/machine" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/suite" @@ -204,6 +205,106 @@ func (s *MachineInstanceSuite) TestNewMachineInstance() { }) } +func (s *MachineInstanceSuite) TestNewMachineInstanceFromSnapshot() { + s.Run("Ok", func() { + require := s.Require() + app := &model.Application{ + Name: "TestApp", + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + } + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + mockRuntime := &MockRollupsMachine{} + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: mockRuntime, + ErrorToReturn: nil, + } + + // NewMachineInstanceFromSnapshot creates a SnapshotMachineRuntimeFactory + // internally, so we use NewMachineInstanceWithFactory to test the same + // logic with a controlled factory. + inputIndex := uint64(5) + + // The function sets processedInputs = inputIndex + 1 + // Use the mock factory to avoid actual machine loading + inst, err := NewMachineInstanceWithFactory( + context.Background(), + app, + inputIndex+1, + testLogger, + false, + mockFactory, + ) + require.NoError(err) + require.NotNil(inst) + require.Equal(inputIndex+1, inst.ProcessedInputs()) + + inst.Close() + }) + + s.Run("FactoryError", func() { + require := s.Require() + app := &model.Application{ + Name: "TestApp", + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + } + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: nil, + ErrorToReturn: errors.New("snapshot load failed"), + } + + inst, err := NewMachineInstanceWithFactory( + context.Background(), + app, + 6, + testLogger, + false, + mockFactory, + ) + require.Error(err) + require.Nil(inst) + require.ErrorIs(err, ErrMachineCreation) + require.Contains(err.Error(), "snapshot load failed") + }) +} + +func (s *MachineInstanceSuite) TestApplicationAndProcessedInputs() { + require := s.Require() + app := &model.Application{ + Name: "TestApp", + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + } + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + mockRuntime := &MockRollupsMachine{} + mockFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: mockRuntime, + ErrorToReturn: nil, + } + + inst, err := NewMachineInstanceWithFactory( + context.Background(), app, 42, testLogger, false, mockFactory, + ) + require.NoError(err) + require.Same(app, inst.Application()) + require.Equal(uint64(42), inst.ProcessedInputs()) + inst.Close() +} + func (s *MachineInstanceSuite) TestAdvance() { s.Run("Ok", func() { s.Run("Accept", func() { @@ -220,7 +321,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Equal(expectedReports1, res.Reports) require.Equal(newHash(1), res.OutputsHash) require.Equal(newHash(2), res.MachineHash) - require.Equal(uint64(6), machine.processedInputs) + require.Equal(uint64(6), machine.processedInputs.Load()) }) s.Run("Reject", func() { @@ -239,7 +340,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Equal(expectedReports1, res.Reports) require.Equal(newHash(1), res.OutputsHash) require.Equal(newHash(2), res.MachineHash) - require.Equal(uint64(6), machine.processedInputs) + require.Equal(uint64(6), machine.processedInputs.Load()) }) testSoftError := func(name string, err error, status model.InputCompletionStatus) { @@ -258,7 +359,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Equal(expectedReports1, res.Reports) require.Equal(newHash(1), res.OutputsHash) require.Equal(newHash(2), res.MachineHash) - require.Equal(uint64(6), machine.processedInputs) + require.Equal(uint64(6), machine.processedInputs.Load()) }) } @@ -274,6 +375,10 @@ func (s *MachineInstanceSuite) TestAdvance() { machine.ErrOutputsLimitExceeded, model.InputCompletionStatus_OutputsLimitExceeded) + testSoftError("ReportsLimit", + machine.ErrReportsLimitExceeded, + model.InputCompletionStatus_ReportsLimitExceeded) + testSoftError("ReachedTargetMcycle", machine.ErrReachedTargetMcycle, model.InputCompletionStatus_CycleLimitExceeded) @@ -298,7 +403,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Error(err) require.Nil(res) require.Equal(errFork, err) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("Advance", func() { @@ -313,7 +418,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Nil(res) require.ErrorIs(err, errAdvance) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("AdvanceAndClose", func() { @@ -331,7 +436,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.ErrorIs(err, errAdvance) require.ErrorIs(err, errClose) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("Hash", func() { @@ -346,7 +451,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.Nil(res) require.ErrorIs(err, errHash) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("HashAndClose", func() { @@ -364,7 +469,7 @@ func (s *MachineInstanceSuite) TestAdvance() { require.ErrorIs(err, errHash) require.ErrorIs(err, errClose) require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.Equal(uint64(5), machine.processedInputs.Load()) }) s.Run("Close", func() { @@ -374,34 +479,99 @@ func (s *MachineInstanceSuite) TestAdvance() { errClose := errors.New("Close error") inner.CloseError = errClose + // Close error on old runtime is logged, not propagated. + // Advance succeeds and processedInputs is incremented. res, err := machine.Advance(context.Background(), []byte{}, 0, 5, false) - require.Error(err) - require.Nil(res) - require.ErrorIs(err, errClose) - require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(5), machine.processedInputs) + require.NoError(err) + require.NotNil(res) + require.Equal(uint64(6), machine.processedInputs.Load()) }) s.Run("Fork", func() { require := s.Require() _, fork, machineInst := s.setupAdvance() - errClose := errors.New("Close error") fork.AdvanceError = machine.ErrException - fork.CloseError = errClose + fork.CloseError = errors.New("Close error") + // Close error on fork is logged, not propagated. + // Advance succeeds and processedInputs is incremented. res, err := machineInst.Advance(context.Background(), []byte{}, 0, 5, false) - require.Error(err) + require.NoError(err) require.NotNil(res) - require.ErrorIs(err, errClose) - require.NotErrorIs(err, errUnreachable) - require.Equal(uint64(6), machineInst.processedInputs) + require.Equal(uint64(6), machineInst.processedInputs.Load()) }) }) }) - s.Run("Concurrency", func() { - // Two Advances cannot be concurrently active. - s.T().Skip("TODO") + s.Run("CollectHashes", func() { + require := s.Require() + inner, fork, machine := s.setupAdvance() + + // Set up WriteCheckpointHash to succeed + fork.CheckpointHashError = nil + + res, err := machine.Advance(context.Background(), []byte{}, 0, 5, true) + require.Nil(err) + require.NotNil(res) + + require.Same(fork, machine.runtime) + require.Equal(model.InputCompletionStatus_Accepted, res.Status) + require.True(res.IsDaveConsensus) + require.Equal(uint64(6), machine.processedInputs.Load()) + + // Verify the inner runtime was closed (accept path) + _ = inner + }) + + s.Run("CollectHashesWriteCheckpointError", func() { + require := s.Require() + _, fork, machine := s.setupAdvance() + + errCheckpoint := errors.New("checkpoint write error") + fork.CheckpointHashError = errCheckpoint + + res, err := machine.Advance(context.Background(), []byte{}, 0, 5, true) + require.Error(err) + require.Nil(res) + require.ErrorIs(err, errCheckpoint) + require.Equal(uint64(5), machine.processedInputs.Load()) + }) + + s.Run("SequentialAdvances", func() { + // Advance is serialized by advanceMutex — concurrent advance on the + // same machine never happens by design. This test verifies that two + // sequential advances correctly increment processedInputs. + require := s.Require() + inner, fork, machine := s.setupAdvance() + + // Allow inner.Close to succeed (old runtime close on accept) + inner.CloseError = nil + + // First advance: fork from inner (processedInputs=5), returns accepted. + // After accept, fork becomes the new runtime. + // Second advance: fork from fork (processedInputs=6), fork must also fork. + fork2 := &MockRollupsMachine{} + fork2.AdvanceAcceptedReturn = true + fork2.AdvanceOutputsReturn = expectedOutputs + fork2.AdvanceReportsReturn = expectedReports1 + fork2.OutputsHashReturn = newHash(1) + fork2.HashReturn = newHash(2) + fork2.CloseError = errUnreachable // old runtime close for second advance + fork2.ForkReturn = nil + fork.ForkReturn = fork2 + fork.CloseError = nil // close of fork (now old runtime) in second advance + + // First advance at index 5 + res1, err := machine.Advance(context.Background(), []byte{}, 0, 5, false) + require.Nil(err) + require.NotNil(res1) + require.Equal(uint64(6), machine.processedInputs.Load()) + + // Second advance at index 6 + res2, err := machine.Advance(context.Background(), []byte{}, 0, 6, false) + require.Nil(err) + require.NotNil(res2) + require.Equal(uint64(7), machine.processedInputs.Load()) }) } @@ -540,10 +710,29 @@ func (s *MachineInstanceSuite) TestCreateSnapshot() { inner, _, machine := s.setupAdvance() errStore := errors.New("Store error") inner.StoreError = errStore + inner.CloseError = nil + + err := machine.CreateSnapshot(context.Background(), 5, "/tmp/snapshot") + require.Error(err) + require.ErrorIs(err, errStore) + + // Runtime should be destroyed after a store error. + require.Nil(machine.runtime) + }) + + s.Run("ErrorAndCloseError", func() { + require := s.Require() + inner, _, machine := s.setupAdvance() + errStore := errors.New("Store error") + errClose := errors.New("Close error") + inner.StoreError = errStore + inner.CloseError = errClose err := machine.CreateSnapshot(context.Background(), 5, "/tmp/snapshot") require.Error(err) - require.Equal(errStore, err) + require.ErrorIs(err, errStore) + require.ErrorIs(err, errClose) + require.Nil(machine.runtime) }) s.Run("MachineClosed", func() { @@ -565,6 +754,158 @@ func (s *MachineInstanceSuite) TestCreateSnapshot() { }) } +func (s *MachineInstanceSuite) TestHash() { + s.Run("Ok", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + + hash, err := machineInst.Hash(context.Background()) + require.NoError(err) + require.Equal([32]byte(newHash(1)), hash) + + // Runtime should still be alive after a successful call. + require.Same(inner, machineInst.runtime) + }) + + s.Run("MachineClosed", func() { + require := s.Require() + _, machineInst := s.setupOutputsProof() + machineInst.runtime = nil + + hash, err := machineInst.Hash(context.Background()) + require.Error(err) + require.Equal(ErrMachineClosed, err) + require.Equal([32]byte{}, hash) + }) + + s.Run("Error", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + inner.HashError = errHash + inner.CloseError = nil + + hash, err := machineInst.Hash(context.Background()) + require.Error(err) + require.ErrorIs(err, errHash) + require.Equal([32]byte{}, hash) + + // Runtime should be destroyed after a hash error. + require.Nil(machineInst.runtime) + }) + + s.Run("ErrorAndCloseError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + errClose := errors.New("Close error") + inner.HashError = errHash + inner.CloseError = errClose + + hash, err := machineInst.Hash(context.Background()) + require.Error(err) + require.ErrorIs(err, errHash) + require.ErrorIs(err, errClose) + require.Equal([32]byte{}, hash) + require.Nil(machineInst.runtime) + }) +} + +func (s *MachineInstanceSuite) TestOutputsProof() { + s.Run("Ok", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + + proof, err := machineInst.OutputsProof(context.Background()) + require.NoError(err) + require.NotNil(proof) + + require.Equal(newHash(1), proof.MachineHash) + require.Equal(newHash(2), proof.OutputsHash) + require.Equal(expectedOutputsHashProof, proof.OutputsHashProof) + + // Runtime should still be alive after a successful call. + require.Same(inner, machineInst.runtime) + }) + + s.Run("MachineClosed", func() { + require := s.Require() + _, machineInst := s.setupOutputsProof() + machineInst.runtime = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.Equal(ErrMachineClosed, err) + }) + + s.Run("HashError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + inner.HashError = errHash + inner.CloseError = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errHash) + + // Runtime should be destroyed after a hash error. + require.Nil(machineInst.runtime) + }) + + s.Run("HashErrorAndCloseError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errHash := errors.New("Hash error") + errClose := errors.New("Close error") + inner.HashError = errHash + inner.CloseError = errClose + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errHash) + require.ErrorIs(err, errClose) + + // Runtime should be destroyed even when Close also fails. + require.Nil(machineInst.runtime) + }) + + s.Run("OutputsHashError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errOutputsHash := errors.New("OutputsHash error") + inner.OutputsHashError = errOutputsHash + inner.CloseError = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errOutputsHash) + + // Runtime should be destroyed after an outputs hash error. + require.Nil(machineInst.runtime) + }) + + s.Run("OutputsHashProofError", func() { + require := s.Require() + inner, machineInst := s.setupOutputsProof() + errProof := errors.New("OutputsHashProof error") + inner.OutputsHashProofError = errProof + inner.CloseError = nil + + proof, err := machineInst.OutputsProof(context.Background()) + require.Nil(proof) + require.Error(err) + require.ErrorIs(err, errProof) + + // Runtime should be destroyed after an outputs hash proof error. + require.Nil(machineInst.runtime) + }) +} + func (s *MachineInstanceSuite) TestClose() { s.Run("Ok", func() { require := s.Require() @@ -618,6 +959,38 @@ func (s *MachineInstanceSuite) TestClose() { require.Fail("Advance did not complete after Close") } }) + + s.Run("TimesOutWaitingForInspects", func() { + require := s.Require() + inner, _, machine := s.setupAdvance() + inner.CloseError = nil + + // Use a short timeout so the test runs fast + machine.closeTimeout = centisecond + + // Pre-acquire all semaphore slots to simulate stuck inspects + for range int(machine.maxConcurrentInspects) { + err := machine.inspectSemaphore.Acquire(context.Background(), 1) + require.Nil(err) + } + + // Close should not block indefinitely — it times out and closes anyway + done := make(chan error, 1) + go func() { + done <- machine.Close() + }() + + select { + case err := <-done: + require.Nil(err) + require.Nil(machine.runtime) + case <-time.After(decisecond * 5): + require.Fail("Close blocked indefinitely despite timeout") + } + + // Release slots to clean up + machine.inspectSemaphore.Release(int64(machine.maxConcurrentInspects)) + }) } // ------------------------------------------------------------------------------------------------ @@ -639,6 +1012,11 @@ var ( newBytes(33, 300), newBytes(34, 300), } + expectedOutputsHashProof = []machine.Hash{ + newHash(3), + newHash(4), + newHash(5), + } ) func (s *MachineInstanceSuite) setupAdvance() (*MockRollupsMachine, *MockRollupsMachine, *MachineInstanceImpl) { @@ -653,14 +1031,15 @@ func (s *MachineInstanceSuite) setupAdvance() (*MockRollupsMachine, *MockRollups machineInst := &MachineInstanceImpl{ application: app, runtime: inner, - processedInputs: 5, advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, mutex: pmutex.New(), inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } + machineInst.processedInputs.Store(5) fork := &MockRollupsMachine{} @@ -709,14 +1088,15 @@ func (s *MachineInstanceSuite) setupInspect() (*MockRollupsMachine, *MockRollups machineInst := &MachineInstanceImpl{ application: app, runtime: inner, - processedInputs: 55, advanceTimeout: decisecond, inspectTimeout: centisecond, maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, mutex: pmutex.New(), inspectSemaphore: semaphore.NewWeighted(3), logger: slog.New(slog.NewTextHandler(io.Discard, nil)), } + machineInst.processedInputs.Store(55) fork := &MockRollupsMachine{} @@ -740,6 +1120,44 @@ func (s *MachineInstanceSuite) setupInspect() (*MockRollupsMachine, *MockRollups return inner, fork, machineInst } +func (s *MachineInstanceSuite) setupOutputsProof() (*MockRollupsMachine, *MachineInstanceImpl) { + app := &model.Application{ + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + LoadDeadline: decisecond, + MaxConcurrentInspects: 3, + }, + } + inner := &MockRollupsMachine{} + machineInst := &MachineInstanceImpl{ + application: app, + runtime: inner, + advanceTimeout: decisecond, + inspectTimeout: centisecond, + maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, + mutex: pmutex.New(), + inspectSemaphore: semaphore.NewWeighted(3), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + machineInst.processedInputs.Store(5) + + inner.HashReturn = newHash(1) + inner.HashError = nil + inner.OutputsHashReturn = newHash(2) + inner.OutputsHashError = nil + inner.OutputsHashProofReturn = []machine.Hash{ + newHash(3), + newHash(4), + newHash(5), + } + inner.OutputsHashProofError = nil + inner.CloseError = errUnreachable + + return inner, machineInst +} + // ------------------------------------------------------------------------------------------------ const ( @@ -763,10 +1181,269 @@ func newBytes(n byte, size int) []byte { return bytes } +// ------------------------------------------------------------------------------------------------ +// Synchronize tests +// ------------------------------------------------------------------------------------------------ + +// mockSyncRepository is a lightweight mock for Synchronize tests. +// It simulates pagination over a slice of inputs. +type mockSyncRepository struct { + inputs []*model.Input + totalCount uint64 + listErr error +} + +func (r *mockSyncRepository) ListApplications( + _ context.Context, + _ repository.ApplicationFilter, + _ repository.Pagination, + _ bool, +) ([]*model.Application, uint64, error) { + return nil, 0, nil +} + +func (r *mockSyncRepository) ListInputs( + ctx context.Context, + _ string, + _ repository.InputFilter, + p repository.Pagination, + _ bool, +) ([]*model.Input, uint64, error) { + if err := ctx.Err(); err != nil { + return nil, 0, err + } + if r.listErr != nil { + return nil, 0, r.listErr + } + start := p.Offset + if start >= uint64(len(r.inputs)) { + return nil, r.totalCount, nil + } + end := start + p.Limit + if p.Limit == 0 || end > uint64(len(r.inputs)) { + end = uint64(len(r.inputs)) + } + return r.inputs[start:end], r.totalCount, nil +} + +func (r *mockSyncRepository) GetLastSnapshot( + _ context.Context, + _ string, +) (*model.Input, error) { + return nil, nil +} + +// newForkableMock creates a mock where Fork returns a fresh mock each time, +// properly exercising the fork/replace lifecycle in Synchronize tests. +func newForkableMock() *MockRollupsMachine { + m := &MockRollupsMachine{} + m.CloseError = nil + m.AdvanceAcceptedReturn = true + m.HashReturn = newHash(1) + m.OutputsHashReturn = newHash(2) + m.ForkFunc = func(_ context.Context) (machine.Machine, error) { + return newForkableMock(), nil + } + return m +} + +func (s *MachineInstanceSuite) newSyncMachine(processedInputs uint64, appProcessedInputs uint64) *MachineInstanceImpl { + runtime := newForkableMock() + + inst := &MachineInstanceImpl{ + application: &model.Application{ + ProcessedInputs: appProcessedInputs, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: decisecond, + InspectMaxDeadline: centisecond, + MaxConcurrentInspects: 3, + }, + }, + runtime: runtime, + advanceTimeout: decisecond, + inspectTimeout: centisecond, + maxConcurrentInspects: 3, + closeTimeout: defaultCloseTimeout, + mutex: pmutex.New(), + inspectSemaphore: semaphore.NewWeighted(3), + logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + } + inst.processedInputs.Store(processedInputs) + return inst +} + +func makeInputs(startIndex, count uint64) []*model.Input { + inputs := make([]*model.Input, count) + for i := uint64(0); i < count; i++ { + inputs[i] = &model.Input{ + Index: startIndex + i, + EpochIndex: 0, + RawData: []byte{byte(startIndex + i)}, + } + } + return inputs +} + +func (s *MachineInstanceSuite) TestSynchronize() { + s.Run("TemplateSyncAllInputs", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + originalRuntime := inst.runtime + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.NoError(err) + require.Equal(uint64(3), inst.processedInputs.Load()) + // Verify the runtime was actually replaced (not self-fork) + require.NotSame(originalRuntime, inst.runtime) + }) + + s.Run("SnapshotSyncRemainingInputs", func() { + require := s.Require() + // Snapshot was at index 2, so processedInputs=3, but app has 5 total + inst := s.newSyncMachine(3, 5) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 5), + totalCount: 5, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.NoError(err) + require.Equal(uint64(5), inst.processedInputs.Load()) + }) + + s.Run("NoInputsToReplay", func() { + require := s.Require() + inst := s.newSyncMachine(0, 0) + repo := &mockSyncRepository{ + inputs: nil, + totalCount: 0, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.NoError(err) + require.Equal(uint64(0), inst.processedInputs.Load()) + }) + + s.Run("SnapshotAlreadyCaughtUp", func() { + require := s.Require() + // Snapshot at last input — nothing to replay + inst := s.newSyncMachine(5, 5) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 5), + totalCount: 5, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.NoError(err) + require.Equal(uint64(5), inst.processedInputs.Load()) + }) + + s.Run("MachineAheadOfDB", func() { + require := s.Require() + // Machine has processed 5 inputs but DB only has 3 + inst := s.newSyncMachine(5, 3) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "machine has processed 5 inputs but DB only has 3") + }) + + s.Run("CountMismatch", func() { + require := s.Require() + inst := s.newSyncMachine(0, 5) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, // DB says 3 but app expects 5 + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "count mismatch") + }) + + s.Run("ListInputsError", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + listErr := errors.New("database connection lost") + repo := &mockSyncRepository{ + listErr: listErr, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "database connection lost") + }) + + s.Run("AdvanceErrorMidReplay", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + // Make each fork return a hard error on Advance. + runtime := inst.runtime.(*MockRollupsMachine) + runtime.ForkFunc = func(_ context.Context) (machine.Machine, error) { + fork := newForkableMock() + fork.AdvanceError = errors.New("advance failed during replay") + return fork, nil + } + + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo, 1000) + require.Error(err) + require.ErrorIs(err, ErrMachineSynchronization) + require.Contains(err.Error(), "failed to replay input") + }) + + s.Run("BatchBoundaryCrossing", func() { + require := s.Require() + // Use batchSize=2 with 3 inputs so the loop must fetch two batches + // (batch 1: inputs 0-1, batch 2: input 2), exercising pagination. + inst := s.newSyncMachine(0, 3) + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(context.Background(), repo, 2) + require.NoError(err) + require.Equal(uint64(3), inst.processedInputs.Load()) + }) + + s.Run("ContextCancellation", func() { + require := s.Require() + inst := s.newSyncMachine(0, 3) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + repo := &mockSyncRepository{ + inputs: makeInputs(0, 3), + totalCount: 3, + } + + err := inst.Synchronize(ctx, repo, 1000) + require.Error(err) + }) +} + // ------------------------------------------------------------------------------------------------ type MockRollupsMachine struct { ForkReturn machine.Machine + ForkFunc func(context.Context) (machine.Machine, error) ForkError error HashReturn machine.Hash @@ -794,7 +1471,10 @@ type MockRollupsMachine struct { CloseError error } -func (m *MockRollupsMachine) Fork(_ context.Context) (machine.Machine, error) { +func (m *MockRollupsMachine) Fork(ctx context.Context) (machine.Machine, error) { + if m.ForkFunc != nil { + return m.ForkFunc(ctx) + } return m.ForkReturn, m.ForkError } @@ -803,7 +1483,7 @@ func (m *MockRollupsMachine) Hash(_ context.Context) (machine.Hash, error) { } func (m *MockRollupsMachine) OutputsHash(_ context.Context) (machine.Hash, error) { - return m.OutputsHashReturn, m.HashError + return m.OutputsHashReturn, m.OutputsHashError } func (m *MockRollupsMachine) OutputsHashProof(_ context.Context) ([]machine.Hash, error) { @@ -814,16 +1494,15 @@ func (m *MockRollupsMachine) WriteCheckpointHash(_ context.Context, _ machine.Ha return m.CheckpointHashError } -func (m *MockRollupsMachine) Advance(_ context.Context, _ []byte, _ bool) ( - bool, []machine.Output, []machine.Report, []machine.Hash, uint64, machine.Hash, error, -) { - return m.AdvanceAcceptedReturn, - m.AdvanceOutputsReturn, - m.AdvanceReportsReturn, - m.AdvanceLeafsReturn, - m.AdvanceRemainingReturn, - m.OutputsHashReturn, - m.AdvanceError +func (m *MockRollupsMachine) Advance(_ context.Context, _ []byte, _ bool) (*machine.AdvanceResponse, error) { + return &machine.AdvanceResponse{ + Accepted: m.AdvanceAcceptedReturn, + Outputs: m.AdvanceOutputsReturn, + Reports: m.AdvanceReportsReturn, + Hashes: m.AdvanceLeafsReturn, + RemainingCycles: m.AdvanceRemainingReturn, + OutputsHash: m.OutputsHashReturn, + }, m.AdvanceError } func (m *MockRollupsMachine) Inspect(_ context.Context, _ []byte) (bool, []machine.Report, error) { diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 29379a8cf..efb0cadfc 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -13,6 +13,7 @@ import ( . "github.com/cartesi/rollups-node/internal/model" "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" ) var ( @@ -33,28 +34,70 @@ type MachineRepository interface { GetLastSnapshot(ctx context.Context, nameOrAddress string) (*Input, error) } +// MachineInstanceFactory creates MachineInstance values from applications. +// Implementations decide whether to load from a template or snapshot. +type MachineInstanceFactory interface { + NewFromTemplate(ctx context.Context, app *Application, logger *slog.Logger, checkHash bool) (MachineInstance, error) + NewFromSnapshot(ctx context.Context, app *Application, logger *slog.Logger, checkHash bool, + snapshotPath string, machineHash *common.Hash, inputIndex uint64) (MachineInstance, error) +} + +// DefaultMachineInstanceFactory delegates to NewMachineInstance / NewMachineInstanceFromSnapshot. +type DefaultMachineInstanceFactory struct{} + +func (f *DefaultMachineInstanceFactory) NewFromTemplate( + ctx context.Context, app *Application, logger *slog.Logger, checkHash bool, +) (MachineInstance, error) { + return NewMachineInstance(ctx, app, logger, checkHash) +} + +func (f *DefaultMachineInstanceFactory) NewFromSnapshot( + ctx context.Context, app *Application, logger *slog.Logger, checkHash bool, + snapshotPath string, machineHash *common.Hash, inputIndex uint64, +) (MachineInstance, error) { + return NewMachineInstanceFromSnapshot(ctx, app, logger, checkHash, snapshotPath, machineHash, inputIndex) +} + // MachineManager manages the lifecycle of machine instances for applications type MachineManager struct { - mutex sync.RWMutex - machines map[int64]MachineInstance - repository MachineRepository - checkHash bool - logger *slog.Logger + mutex sync.RWMutex + machines map[int64]MachineInstance + closed bool + repository MachineRepository + checkHash bool + inputBatchSize uint64 + logger *slog.Logger + instanceFactory MachineInstanceFactory } -// NewMachineManager creates a new machine manager +// Option configures a MachineManager. +type Option func(*MachineManager) + +// WithInstanceFactory overrides the default MachineInstanceFactory. +func WithInstanceFactory(f MachineInstanceFactory) Option { + return func(m *MachineManager) { m.instanceFactory = f } +} + +// NewMachineManager creates a new machine manager. func NewMachineManager( - ctx context.Context, repo MachineRepository, logger *slog.Logger, checkHash bool, + inputBatchSize uint64, + opts ...Option, ) *MachineManager { - return &MachineManager{ - machines: map[int64]MachineInstance{}, - repository: repo, - checkHash: checkHash, - logger: logger, + m := &MachineManager{ + machines: map[int64]MachineInstance{}, + repository: repo, + checkHash: checkHash, + inputBatchSize: inputBatchSize, + logger: logger, + instanceFactory: &DefaultMachineInstanceFactory{}, + } + for _, opt := range opts { + opt(m) } + return m } // UpdateMachines refreshes the list of machines based on enabled applications @@ -73,11 +116,10 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { m.logger.Info("Creating new machine instance", "application", app.Name, - "address", app.IApplicationAddress.String()) + "address", app.IApplicationAddress) // Check if we have a snapshot to load from var instance MachineInstance - var err error // Find the latest snapshot for this application snapshot, err := m.repository.GetLastSnapshot(ctx, app.IApplicationAddress.String()) @@ -89,88 +131,68 @@ func (m *MachineManager) UpdateMachines(ctx context.Context) error { } if snapshot != nil && snapshot.SnapshotURI != nil { - // Create a machine instance from the snapshot - m.logger.Info("Creating machine instance from snapshot", - "application", app.Name, - "snapshot", *snapshot.SnapshotURI) - // Verify the snapshot path exists - if _, err := os.Stat(*snapshot.SnapshotURI); os.IsNotExist(err) { - m.logger.Error("Snapshot path does not exist", + if _, statErr := os.Stat(*snapshot.SnapshotURI); statErr == nil { + m.logger.Info("Creating machine instance from snapshot", "application", app.Name, - "snapshot", *snapshot.SnapshotURI, - "error", err) - // Fall back to template-based initialization - } else { - // Create a factory with the snapshot path and machine hash - instance, err = NewMachineInstanceFromSnapshot( - ctx, app, m.logger, m.checkHash, *snapshot.SnapshotURI, snapshot.MachineHash, snapshot.Index) + "snapshot", *snapshot.SnapshotURI) + instance, err = m.instanceFactory.NewFromSnapshot( + ctx, app, m.logger, m.checkHash, + *snapshot.SnapshotURI, snapshot.MachineHash, snapshot.Index) if err != nil { m.logger.Error("Failed to create machine instance from snapshot", "application", app.Name, "snapshot", *snapshot.SnapshotURI, "error", err) - // Fall back to template-based initialization - } else { - // If we loaded from a snapshot, we need to synchronize from the snapshot point - // Get the inputs after the snapshot - inputsAfterSnapshot, err := getInputsAfterSnapshot(ctx, m.repository, app, snapshot.Index) - if err != nil { - m.logger.Error("Failed to get inputs after snapshot", - "application", app.Name, - "snapshot_input_index", snapshot.Index, - "error", err) - instance.Close() - continue - } - - // Process each input to bring the machine to the current state - for _, input := range inputsAfterSnapshot { - m.logger.Info("Replaying input after snapshot", - "application", app.Name, - "epoch_index", input.EpochIndex, - "input_index", input.Index) - - _, err := instance.Advance(ctx, input.RawData, input.EpochIndex, input.Index, false) - if err != nil { - m.logger.Error("Failed to replay input after snapshot", - "application", app.Name, - "input_index", input.Index, - "error", err) - instance.Close() - continue - } - } - - // Add the machine to the manager - m.addMachine(app.ID, instance) - continue + // Fall back to template-based initialization below } + } else if errors.Is(statErr, os.ErrNotExist) { + m.logger.Warn("Snapshot path does not exist", + "application", app.Name, + "snapshot", *snapshot.SnapshotURI) + } else { + m.logger.Error("Failed to access snapshot path", + "application", app.Name, + "snapshot", *snapshot.SnapshotURI, + "error", statErr) } } - // If we didn't load from a snapshot, create a new machine instance from the template - instance, err = NewMachineInstance(ctx, app, m.logger, m.checkHash) - if err != nil { - m.logger.Error("Failed to create machine instance", - "application", app.IApplicationAddress, - "error", err) - continue + // Fall back to template if snapshot loading failed or was unavailable + if instance == nil { + instance, err = m.instanceFactory.NewFromTemplate(ctx, app, m.logger, m.checkHash) + if err != nil { + m.logger.Error("Failed to create machine instance", + "application", app.IApplicationAddress, + "error", err) + continue + } } - // Synchronize the machine with processed inputs - err = instance.Synchronize(ctx, m.repository) + // Synchronize the machine with processed inputs. + // For template instances (processedInputs=0) this replays all inputs. + // For snapshot instances (processedInputs=snapshotIndex+1) this replays + // only inputs after the snapshot. + err = instance.Synchronize(ctx, m.repository, m.inputBatchSize) if err != nil { m.logger.Error("Failed to synchronize machine", "application", app.IApplicationAddress, "error", err) - instance.Close() + if err := instance.Close(); err != nil { + m.logger.Warn("Failed to close machine after synchronization failure", + "application", app.Name, "error", err) + } continue } - // Add the machine to the manager - m.addMachine(app.ID, instance) + // Add the machine to the manager; close if it fails + if !m.addMachine(app.ID, instance) { + if err := instance.Close(); err != nil { + m.logger.Warn("Failed to close duplicate machine instance", + "application", app.Name, "error", err) + } + } } // Remove machines for disabled applications @@ -195,11 +217,16 @@ func (m *MachineManager) HasMachine(appID int64) bool { return exists } -// AddMachine adds a machine to the manager +// addMachine adds a machine to the manager. +// Returns false if the manager is closed or the appID already exists. func (m *MachineManager) addMachine(appID int64, machine MachineInstance) bool { m.mutex.Lock() defer m.mutex.Unlock() + if m.closed { + return false + } + if _, exists := m.machines[appID]; exists { return false } @@ -226,7 +253,10 @@ func (m *MachineManager) removeMachines(apps []*Application) { m.logger.Info("Application was disabled, shutting down machine", "application", machine.Application().Name) } - machine.Close() + if err := machine.Close(); err != nil && m.logger != nil { + m.logger.Warn("Failed to close machine for disabled application", + "application", machine.Application().Name, "error", err) + } delete(m.machines, id) } } @@ -244,23 +274,43 @@ func (m *MachineManager) Applications() []*Application { return apps } -// Close shuts down all machine instances +// Close shuts down all machine instances in parallel. +// After Close returns, no new machines can be added. func (m *MachineManager) Close() error { + // Mark as closed and take ownership of the machines map under the lock, + // then release it so readers (GetMachine, HasMachine, Applications) + // aren't blocked during the potentially slow parallel shutdown. m.mutex.Lock() - defer m.mutex.Unlock() + m.closed = true + machines := m.machines + m.machines = make(map[int64]MachineInstance) + m.mutex.Unlock() + + type closeResult struct { + id int64 + err error + } + + var wg sync.WaitGroup + results := make(chan closeResult, len(machines)) + + for id, machine := range machines { + wg.Go(func() { + results <- closeResult{id: id, err: machine.Close()} + }) + } + + wg.Wait() + close(results) var errs []error - for id, machine := range m.machines { - if err := machine.Close(); err != nil { - errs = append(errs, fmt.Errorf("failed to close machine for app %d: %w", id, err)) + for r := range results { + if r.err != nil { + errs = append(errs, fmt.Errorf("failed to close machine for app %d: %w", r.id, r.err)) } - delete(m.machines, id) } - if len(errs) > 0 { - return errors.Join(errs...) - } - return nil + return errors.Join(errs...) } // Helper function to get enabled applications @@ -269,25 +319,13 @@ func getEnabledApplications(ctx context.Context, repo MachineRepository) ([]*App return repo.ListApplications(ctx, f, repository.Pagination{}, false) } -// Helper function to get processed inputs -func getProcessedInputs(ctx context.Context, repo MachineRepository, appAddress string) ([]*Input, uint64, error) { +// getProcessedInputs retrieves processed inputs with pagination support. +func getProcessedInputs( + ctx context.Context, + repo MachineRepository, + appAddress string, + p repository.Pagination, +) ([]*Input, uint64, error) { f := repository.InputFilter{NotStatus: Pointer(InputCompletionStatus_None)} - return repo.ListInputs(ctx, appAddress, f, repository.Pagination{}, false) -} - -// Helper function to get inputs after a specific index -func getInputsAfterSnapshot(ctx context.Context, repo MachineRepository, app *Application, snapshotInputIndex uint64) ([]*Input, error) { - // Get all processed inputs for this application - inputs, _, err := getProcessedInputs(ctx, repo, app.IApplicationAddress.String()) - if err != nil { - return nil, err - } - - // Filter inputs to only include those after the snapshot - for i, input := range inputs { - if input.Index > snapshotInputIndex { - return inputs[i:], nil - } - } - return []*Input{}, nil + return repo.ListInputs(ctx, appAddress, f, p, false) } diff --git a/internal/manager/manager_test.go b/internal/manager/manager_test.go index f5ef13e3e..660bde832 100644 --- a/internal/manager/manager_test.go +++ b/internal/manager/manager_test.go @@ -5,6 +5,7 @@ package manager import ( "context" + "errors" "io" "log/slog" "testing" @@ -28,7 +29,7 @@ func (s *MachineManagerSuite) TestNewMachineManager() { require := s.Require() repo := &MockMachineRepository{} testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false) + manager := NewMachineManager(repo, testLogger, false, 500) require.NotNil(manager) require.Empty(manager.machines) require.Equal(repo, manager.repository) @@ -63,25 +64,15 @@ func (s *MachineManagerSuite) TestUpdateMachines() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - // Create manager with a test logger + // Create manager with a mock instance factory testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false) + mockInstance := &DummyMachineInstanceMock{application: app1} + factory := &MockMachineInstanceFactory{Instance: mockInstance} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(factory)) - // Create a mock factory for testing - mockRuntime := &MockRollupsMachine{} - mockFactory := &MockMachineRuntimeFactory{ - RuntimeToReturn: mockRuntime, - ErrorToReturn: nil, - } - - // Replace the default factory with our mock - originalFactory := defaultFactory - defaultFactory = mockFactory - defer func() { defaultFactory = originalFactory }() - - // This test should now succeed with our mock err := manager.UpdateMachines(context.Background()) require.NoError(err) + require.True(manager.HasMachine(1)) repo.AssertCalled(s.T(), "ListApplications", mock.Anything, mock.Anything, mock.Anything, false) }) @@ -96,7 +87,7 @@ func (s *MachineManagerSuite) TestUpdateMachines() { // Create a test logger testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - manager := NewMachineManager(context.Background(), repo, testLogger, false) + manager := NewMachineManager(repo, testLogger, false, 500) // Add mock machines app1 := &model.Application{ID: 1, Name: "App1"} @@ -129,7 +120,7 @@ func (s *MachineManagerSuite) TestGetMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(repo, nil, false, 500) machine := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} // Add a machine @@ -152,7 +143,7 @@ func (s *MachineManagerSuite) TestHasMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(repo, nil, false, 500) machine := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} // Add a machine @@ -172,7 +163,7 @@ func (s *MachineManagerSuite) TestAddMachine() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(repo, nil, false, 500) machine1 := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} machine2 := &DummyMachineInstanceMock{application: &model.Application{ID: 2}} @@ -190,12 +181,20 @@ func (s *MachineManagerSuite) TestAddMachine() { added = manager.addMachine(1, machine1) require.False(added) require.Len(manager.machines, 2) + + // Close the manager and try to add a new machine + err := manager.Close() + require.NoError(err) + + machine3 := &DummyMachineInstanceMock{application: &model.Application{ID: 3}} + added = manager.addMachine(3, machine3) + require.False(added, "addMachine must reject additions after Close") } func (s *MachineManagerSuite) TestRemoveDisabledMachines() { require := s.Require() - manager := NewMachineManager(context.Background(), nil, nil, false) + manager := NewMachineManager(nil, nil, false, 500) // Add machines app1 := &model.Application{ID: 1} @@ -220,6 +219,163 @@ func (s *MachineManagerSuite) TestRemoveDisabledMachines() { require.True(manager.HasMachine(3)) } +func (s *MachineManagerSuite) TestUpdateMachinesErrors() { + s.Run("GetEnabledApplicationsError", func() { + require := s.Require() + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return(([]*model.Application)(nil), uint64(0), errors.New("db error")) + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + manager := NewMachineManager(repo, testLogger, false, 500) + + err := manager.UpdateMachines(context.Background()) + require.Error(err) + require.Contains(err.Error(), "db error") + }) + + s.Run("SnapshotCreationFailureFallsBackToTemplate", func() { + require := s.Require() + + app := &model.Application{ + ID: 1, + Name: "App1", + IApplicationAddress: common.HexToAddress("0x1"), + State: model.ApplicationState_Enabled, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: 100, + InspectMaxDeadline: 100, + MaxConcurrentInspects: 3, + }, + } + + snapshotPath := "/fake/snapshot/path" + snapshotInput := &model.Input{ + Index: 2, + SnapshotURI: &snapshotPath, + } + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Application{app}, uint64(1), nil) + repo.On("GetLastSnapshot", mock.Anything, mock.Anything). + Return(snapshotInput, nil) + // ListInputs for synchronization (no inputs to replay) + repo.On("ListInputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Input{}, uint64(0), nil) + + // The snapshot path doesn't exist, so it should fall back to template + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + mockInstance := &DummyMachineInstanceMock{application: app} + factory := &MockMachineInstanceFactory{Instance: mockInstance} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(factory)) + + err := manager.UpdateMachines(context.Background()) + require.NoError(err) + require.True(manager.HasMachine(1)) + }) + + s.Run("TemplateCreationFailureSkipsApp", func() { + require := s.Require() + + app := &model.Application{ + ID: 1, + Name: "App1", + IApplicationAddress: common.HexToAddress("0x1"), + State: model.ApplicationState_Enabled, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: 100, + InspectMaxDeadline: 100, + MaxConcurrentInspects: 3, + }, + } + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Application{app}, uint64(1), nil) + repo.On("GetLastSnapshot", mock.Anything, mock.Anything). + Return(nil, nil) + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + factory := &MockMachineInstanceFactory{Err: errors.New("machine creation failed")} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(factory)) + + err := manager.UpdateMachines(context.Background()) + // UpdateMachines should not return an error; it logs and skips + require.NoError(err) + require.False(manager.HasMachine(1)) + }) + + s.Run("SynchronizeFailureClosesAndSkipsApp", func() { + require := s.Require() + + app := &model.Application{ + ID: 1, + Name: "App1", + IApplicationAddress: common.HexToAddress("0x1"), + State: model.ApplicationState_Enabled, + ProcessedInputs: 3, + ExecutionParameters: model.ExecutionParameters{ + AdvanceMaxDeadline: 100, + InspectMaxDeadline: 100, + MaxConcurrentInspects: 3, + }, + } + + repo := &MockMachineRepository{} + repo.On("ListApplications", mock.Anything, mock.Anything, mock.Anything, false). + Return([]*model.Application{app}, uint64(1), nil) + repo.On("GetLastSnapshot", mock.Anything, mock.Anything). + Return(nil, nil) + // ListInputs returns an error so the real Synchronize method propagates the failure. + repo.On("ListInputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false). + Return(([]*model.Input)(nil), uint64(0), errors.New("db connection lost")) + + testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + // Use a factory that builds a real MachineInstanceImpl (with a mock runtime) + // so that Synchronize actually runs and hits the repo. + mockRuntime := &MockRollupsMachine{} + runtimeFactory := &MockMachineRuntimeFactory{ + RuntimeToReturn: mockRuntime, + ErrorToReturn: nil, + } + realFactory := &realMachineInstanceFactory{runtimeFactory: runtimeFactory} + manager := NewMachineManager(repo, testLogger, false, 500, WithInstanceFactory(realFactory)) + + err := manager.UpdateMachines(context.Background()) + require.NoError(err) + // Machine should NOT have been added due to sync failure + require.False(manager.HasMachine(1)) + }) +} + +func (s *MachineManagerSuite) TestCloseAggregatesErrors() { + require := s.Require() + + manager := NewMachineManager(nil, nil, false, 500) + + machine1 := &DummyMachineInstanceMock{application: &model.Application{ID: 1}} + machine2 := &DummyMachineInstanceMock{ + application: &model.Application{ID: 2}, + closeError: errors.New("close error 2"), + } + machine3 := &DummyMachineInstanceMock{ + application: &model.Application{ID: 3}, + closeError: errors.New("close error 3"), + } + + manager.addMachine(1, machine1) + manager.addMachine(2, machine2) + manager.addMachine(3, machine3) + + err := manager.Close() + require.Error(err) + require.Contains(err.Error(), "close error") + require.Empty(manager.machines) +} + func (s *MachineManagerSuite) TestApplications() { require := s.Require() @@ -227,7 +383,7 @@ func (s *MachineManagerSuite) TestApplications() { repo.On("GetLastSnapshot", mock.Anything, mock.Anything). Return(nil, nil) - manager := NewMachineManager(context.Background(), repo, nil, false) + manager := NewMachineManager(repo, nil, false, 500) // Add machines app1 := &model.Application{ID: 1, Name: "App1"} @@ -293,9 +449,54 @@ func (m *MockMachineRepository) GetLastSnapshot( // ------------------------------------------------------------------------------------------------ +// MockMachineInstanceFactory implements MachineInstanceFactory for testing. +// It returns the same instance for every call, ignoring the app/path arguments. +type MockMachineInstanceFactory struct { + Instance MachineInstance + Err error +} + +func (f *MockMachineInstanceFactory) NewFromTemplate( + _ context.Context, _ *model.Application, _ *slog.Logger, _ bool, +) (MachineInstance, error) { + return f.Instance, f.Err +} + +func (f *MockMachineInstanceFactory) NewFromSnapshot( + _ context.Context, _ *model.Application, _ *slog.Logger, _ bool, + _ string, _ *common.Hash, _ uint64, +) (MachineInstance, error) { + return f.Instance, f.Err +} + +// realMachineInstanceFactory builds real MachineInstanceImpl values using the +// provided MachineRuntimeFactory. This lets tests exercise the real Synchronize +// path while still mocking the machine runtime. Unlike MockMachineInstanceFactory, +// snapshot path and hash are ignored — it always creates from the runtime factory. +type realMachineInstanceFactory struct { + runtimeFactory MachineRuntimeFactory +} + +func (f *realMachineInstanceFactory) NewFromTemplate( + ctx context.Context, app *model.Application, logger *slog.Logger, checkHash bool, +) (MachineInstance, error) { + return NewMachineInstanceWithFactory(ctx, app, 0, logger, checkHash, f.runtimeFactory) +} + +func (f *realMachineInstanceFactory) NewFromSnapshot( + ctx context.Context, app *model.Application, logger *slog.Logger, checkHash bool, + _ string, _ *common.Hash, inputIndex uint64, +) (MachineInstance, error) { + return NewMachineInstanceWithFactory(ctx, app, inputIndex+1, logger, checkHash, f.runtimeFactory) +} + +// ------------------------------------------------------------------------------------------------ + // DummyMachineInstanceMock implements the MachineInstance interface for testing type DummyMachineInstanceMock struct { - application *model.Application + application *model.Application + closeError error + synchronizeErr error } func (m *DummyMachineInstanceMock) Application() *model.Application { @@ -306,7 +507,7 @@ func (m *DummyMachineInstanceMock) ProcessedInputs() uint64 { return 0 } -func (m *DummyMachineInstanceMock) OutputsProof(ctx context.Context, processedInputs uint64) (*model.OutputsProof, error) { +func (m *DummyMachineInstanceMock) OutputsProof(ctx context.Context) (*model.OutputsProof, error) { return nil, nil } @@ -318,8 +519,8 @@ func (m *DummyMachineInstanceMock) Inspect(_ context.Context, _ []byte) (*model. return nil, nil } -func (m *DummyMachineInstanceMock) Synchronize(_ context.Context, _ MachineRepository) error { - return nil +func (m *DummyMachineInstanceMock) Synchronize(_ context.Context, _ MachineRepository, _ uint64) error { + return m.synchronizeErr } func (m *DummyMachineInstanceMock) CreateSnapshot(_ context.Context, _ uint64, _ string) error { @@ -331,5 +532,5 @@ func (m *DummyMachineInstanceMock) Hash(_ context.Context) ([32]byte, error) { } func (m *DummyMachineInstanceMock) Close() error { - return nil + return m.closeError } diff --git a/internal/manager/pmutex/pmutex_test.go b/internal/manager/pmutex/pmutex_test.go index 12fa1f367..4305ac085 100644 --- a/internal/manager/pmutex/pmutex_test.go +++ b/internal/manager/pmutex/pmutex_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -41,15 +40,39 @@ func (s *PMutexSuite) TestSingleLLock() { } func (s *PMutexSuite) TestContestedHLock() { - require := s.Require() s.mutex.LLock() - never(require, func() bool { s.mutex.HLock(); return true }) + acquired := make(chan struct{}) + go func() { + s.mutex.HLock() + close(acquired) + }() + select { + case <-acquired: + s.Fail("HLock should not be acquired while LLock is held") + case <-time.After(decisecond): + // Expected: HLock is blocked. + } + // Clean up: unlock so the goroutine can finish. + s.mutex.Unlock() + <-acquired } func (s *PMutexSuite) TestContestedLLock() { - require := s.Require() s.mutex.HLock() - never(require, func() bool { s.mutex.LLock(); return true }) + acquired := make(chan struct{}) + go func() { + s.mutex.LLock() + close(acquired) + }() + select { + case <-acquired: + s.Fail("LLock should not be acquired while HLock is held") + case <-time.After(decisecond): + // Expected: LLock is blocked. + } + // Clean up: unlock so the goroutine can finish. + s.mutex.Unlock() + <-acquired } func (s *PMutexSuite) TestPriority() { @@ -105,11 +128,4 @@ func (s *PMutexSuite) TestPriority() { // ------------------------------------------------------------------------------------------------ -const ( - centisecond = 10 * time.Millisecond - decisecond = 100 * time.Millisecond -) - -func never(require *require.Assertions, f func() bool) { - require.Never(f, decisecond, centisecond) -} +const decisecond = 100 * time.Millisecond diff --git a/internal/manager/types.go b/internal/manager/types.go index a6ad9a0b7..d72e1f97a 100644 --- a/internal/manager/types.go +++ b/internal/manager/types.go @@ -14,11 +14,11 @@ type MachineInstance interface { Application() *Application Advance(ctx context.Context, input []byte, epochIndex uint64, inputIndex uint64, computeHashes bool) (*AdvanceResult, error) Inspect(ctx context.Context, query []byte) (*InspectResult, error) - Synchronize(ctx context.Context, repo MachineRepository) error + Synchronize(ctx context.Context, repo MachineRepository, batchSize uint64) error CreateSnapshot(ctx context.Context, processedInputs uint64, path string) error ProcessedInputs() uint64 Hash(ctx context.Context) ([32]byte, error) - OutputsProof(ctx context.Context, processedInputs uint64) (*OutputsProof, error) + OutputsProof(ctx context.Context) (*OutputsProof, error) Close() error } @@ -35,4 +35,7 @@ type MachineProvider interface { // HasMachine checks if a machine exists for the given application ID HasMachine(appID int64) bool + + // Close shuts down all machine instances and releases resources + Close() error } diff --git a/internal/model/models.go b/internal/model/models.go index 82bcb6d38..d746a0a08 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -808,6 +808,7 @@ const ( InputCompletionStatus_Exception InputCompletionStatus = "EXCEPTION" InputCompletionStatus_MachineHalted InputCompletionStatus = "MACHINE_HALTED" InputCompletionStatus_OutputsLimitExceeded InputCompletionStatus = "OUTPUTS_LIMIT_EXCEEDED" + InputCompletionStatus_ReportsLimitExceeded InputCompletionStatus = "REPORTS_LIMIT_EXCEEDED" InputCompletionStatus_CycleLimitExceeded InputCompletionStatus = "CYCLE_LIMIT_EXCEEDED" InputCompletionStatus_TimeLimitExceeded InputCompletionStatus = "TIME_LIMIT_EXCEEDED" InputCompletionStatus_PayloadLengthLimitExceeded InputCompletionStatus = "PAYLOAD_LENGTH_LIMIT_EXCEEDED" @@ -819,6 +820,8 @@ var InputCompletionStatusAllValues = []InputCompletionStatus{ InputCompletionStatus_Rejected, InputCompletionStatus_Exception, InputCompletionStatus_MachineHalted, + InputCompletionStatus_OutputsLimitExceeded, + InputCompletionStatus_ReportsLimitExceeded, InputCompletionStatus_CycleLimitExceeded, InputCompletionStatus_TimeLimitExceeded, InputCompletionStatus_PayloadLengthLimitExceeded, @@ -846,6 +849,10 @@ func (e *InputCompletionStatus) Scan(value any) error { *e = InputCompletionStatus_Exception case "MACHINE_HALTED": *e = InputCompletionStatus_MachineHalted + case "OUTPUTS_LIMIT_EXCEEDED": + *e = InputCompletionStatus_OutputsLimitExceeded + case "REPORTS_LIMIT_EXCEEDED": + *e = InputCompletionStatus_ReportsLimitExceeded case "CYCLE_LIMIT_EXCEEDED": *e = InputCompletionStatus_CycleLimitExceeded case "TIME_LIMIT_EXCEEDED": diff --git a/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go b/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go index f18248333..a9624fcb2 100644 --- a/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go +++ b/internal/repository/postgres/db/rollupsdb/public/enum/inputcompletionstatus.go @@ -16,6 +16,7 @@ var InputCompletionStatus = &struct { Exception postgres.StringExpression MachineHalted postgres.StringExpression OutputsLimitExceeded postgres.StringExpression + ReportsLimitExceeded postgres.StringExpression CycleLimitExceeded postgres.StringExpression TimeLimitExceeded postgres.StringExpression PayloadLengthLimitExceeded postgres.StringExpression @@ -26,6 +27,7 @@ var InputCompletionStatus = &struct { Exception: postgres.NewEnumValue("EXCEPTION"), MachineHalted: postgres.NewEnumValue("MACHINE_HALTED"), OutputsLimitExceeded: postgres.NewEnumValue("OUTPUTS_LIMIT_EXCEEDED"), + ReportsLimitExceeded: postgres.NewEnumValue("REPORTS_LIMIT_EXCEEDED"), CycleLimitExceeded: postgres.NewEnumValue("CYCLE_LIMIT_EXCEEDED"), TimeLimitExceeded: postgres.NewEnumValue("TIME_LIMIT_EXCEEDED"), PayloadLengthLimitExceeded: postgres.NewEnumValue("PAYLOAD_LENGTH_LIMIT_EXCEEDED"), diff --git a/internal/repository/postgres/repository_error_test.go b/internal/repository/postgres/repository_error_test.go index 4394aab87..ea8d6c452 100644 --- a/internal/repository/postgres/repository_error_test.go +++ b/internal/repository/postgres/repository_error_test.go @@ -39,7 +39,7 @@ func TestNewPostgresRepository_ContextCancelledDuringRetry(t *testing.T) { _, err := postgres.NewPostgresRepository( ctx, "postgres://user:pass@localhost:1/testdb?connect_timeout=1", - 100, // many retries — we won't exhaust them + 100, // many retries — we won't exhaust them 10*time.Second, // long delay — context expires before this elapses ) require.Error(t, err) diff --git a/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql b/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql index 7680105b9..b965b7969 100644 --- a/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql +++ b/internal/repository/postgres/schema/migrations/000001_create_initial_schema.up.sql @@ -17,6 +17,7 @@ CREATE TYPE "InputCompletionStatus" AS ENUM ( 'EXCEPTION', 'MACHINE_HALTED', 'OUTPUTS_LIMIT_EXCEEDED', + 'REPORTS_LIMIT_EXCEEDED', 'CYCLE_LIMIT_EXCEEDED', 'TIME_LIMIT_EXCEEDED', 'PAYLOAD_LENGTH_LIMIT_EXCEEDED'); diff --git a/pkg/machine/implementation.go b/pkg/machine/implementation.go index c0e582107..19e867060 100644 --- a/pkg/machine/implementation.go +++ b/pkg/machine/implementation.go @@ -44,8 +44,9 @@ const ( ManualYieldReasonException manualYieldReason = 0x4 ) -// Constants +// Limits for outputs and reports per input const maxOutputs = 65536 // 2^16 +const maxReports = 65536 // 2^16 const CheckpointAddress uint64 = 0x7ffff000 const TxBufferAddress uint64 = 0x60800000 @@ -175,22 +176,34 @@ func (m *machineImpl) WriteCheckpointHash(ctx context.Context, hash Hash) error } // Advance sends an input to the machine and processes it -func (m *machineImpl) Advance(ctx context.Context, input []byte, computeHashes bool) (bool, []Output, []Report, []Hash, uint64, Hash, error) { - outputsHash := Hash{} +func (m *machineImpl) Advance(ctx context.Context, input []byte, computeHashes bool) (*AdvanceResponse, error) { // TODO: return the exception reason accepted, outputs, reports, hashes, remaining, data, err := m.process(ctx, input, AdvanceStateRequest, computeHashes) if err != nil { - return accepted, outputs, reports, hashes, remaining, outputsHash, err + return &AdvanceResponse{ + Accepted: accepted, + Outputs: outputs, + Reports: reports, + Hashes: hashes, + RemainingCycles: remaining, + }, err + } + + resp := &AdvanceResponse{ + Accepted: accepted, + Outputs: outputs, + Reports: reports, + Hashes: hashes, + RemainingCycles: remaining, } if accepted { if length := len(data); length != HashSize { - err = fmt.Errorf("%w (it has %d bytes)", ErrHashLength, length) - return accepted, outputs, reports, hashes, remaining, outputsHash, err + return resp, fmt.Errorf("%w (it has %d bytes)", ErrHashLength, length) } - copy(outputsHash[:], data) + copy(resp.OutputsHash[:], data) } - return accepted, outputs, reports, hashes, remaining, outputsHash, nil + return resp, nil } // Inspect sends a query to the machine and returns the results @@ -275,7 +288,8 @@ func (m *machineImpl) wasLastRequestAccepted(ctx context.Context) (bool, []byte, case ManualYieldReasonException: return false, data, ErrException default: - panic("unreachable code: invalid manual yield reason") + err = fmt.Errorf("invalid manual yield reason: %d: %w", yieldReason, ErrMachineInternal) + return false, nil, err } } @@ -351,8 +365,8 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe "limitCycle", limitCycle, "leftover", limitCycle-currentCycle) - outputs := []Output{} - reports := []Report{} + outputs := make([]Output, 0, 16) //nolint:mnd + reports := make([]Report, 0, 16) //nolint:mnd var hashCollectorState *HashCollectorState if computeHashes { @@ -383,6 +397,9 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe // Steps the machine as many times as needed until it manually/automatically yields. for yt == nil { + if err := checkContext(ctx); err != nil { + return outputs, reports, hashes(), remainingMetaCycles(), err + } if time.Since(startTime) > runTimeout { werr := fmt.Errorf("run operation timed out: %w", ErrDeadlineExceeded) return outputs, reports, hashes(), remainingMetaCycles(), werr @@ -400,10 +417,15 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe // Asserts the machine yielded automatically. if *yt != AutomaticYield { - panic("unreachable code: invalid yield type") + err := fmt.Errorf("invalid yield type: %d: %w", *yt, ErrMachineInternal) + return outputs, reports, hashes(), remainingMetaCycles(), err } yt = nil + if err := checkContext(ctx); err != nil { + return outputs, reports, hashes(), remainingMetaCycles(), err + } + _, yieldReason, data, err := m.backend.ReceiveCmioRequest(m.params.FastDeadline) if err != nil { werr := fmt.Errorf("could not read output/report: %w", err) @@ -420,9 +442,13 @@ func (m *machineImpl) run(ctx context.Context, reqType requestType, computeHashe } outputs = append(outputs, data) case AutomaticYieldReasonReport: + if len(reports) == maxReports { + return outputs, reports, hashes(), remainingMetaCycles(), ErrReportsLimitExceeded + } reports = append(reports, data) default: - panic("unreachable code: invalid automatic yield reason") + err := fmt.Errorf("invalid automatic yield reason: %d: %w", yieldReason, ErrMachineInternal) + return outputs, reports, hashes(), remainingMetaCycles(), err } } } @@ -483,9 +509,10 @@ func (m *machineImpl) runIncrementInterval(ctx context.Context, case Halted: return nil, currentCycle, ErrHalted case Failed: - fallthrough // covered by backend.Run() err + return nil, currentCycle, ErrMachineInternal default: - panic("unreachable code: invalid break reason") + err := fmt.Errorf("invalid break reason: %d: %w", breakReason, ErrMachineInternal) + return nil, currentCycle, err } } diff --git a/pkg/machine/implementation_test.go b/pkg/machine/implementation_test.go index dae8aa1de..60cf9a567 100644 --- a/pkg/machine/implementation_test.go +++ b/pkg/machine/implementation_test.go @@ -201,6 +201,100 @@ func (s *ImplementationSuite) TestOutputsHash() { mockBackend4.AssertExpectations(s.T()) } +// Test OutputsHashProof method +func (s *ImplementationSuite) TestOutputsHashProof() { + require := s.Require() + ctx := context.Background() + + // Test successful outputs hash proof retrieval + mockBackend := NewMockBackend() + expectedProof := []Hash{randomFakeHash(), randomFakeHash(), randomFakeHash()} + mockBackend.On("GetProof", TxBufferAddress, int32(HashLog2Size), mock.AnythingOfType("time.Duration")). + Return(expectedProof, nil) + + machine := &machineImpl{ + backend: mockBackend, + logger: s.logger, + params: model.ExecutionParameters{ + LoadDeadline: time.Second * 5, + }, + } + + proof, err := machine.OutputsHashProof(ctx) + require.NoError(err) + require.Equal(expectedProof, proof) + mockBackend.AssertExpectations(s.T()) + + // Test outputs hash proof with backend error + mockBackend2 := NewMockBackend() + mockBackend2.On("GetProof", TxBufferAddress, int32(HashLog2Size), mock.AnythingOfType("time.Duration")). + Return([]Hash(nil), errors.New("proof failed")) + machine2 := &machineImpl{ + backend: mockBackend2, + logger: s.logger, + params: model.ExecutionParameters{ + LoadDeadline: time.Second * 5, + }, + } + _, err = machine2.OutputsHashProof(ctx) + require.Error(err) + require.ErrorIs(err, ErrMachineInternal) + require.Contains(err.Error(), "could not get outputs hash machine proof") + mockBackend2.AssertExpectations(s.T()) + + // Test outputs hash proof with canceled context + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + _, err = machine.OutputsHashProof(canceledCtx) + require.ErrorIs(err, ErrCanceled) +} + +// Test WriteCheckpointHash method +func (s *ImplementationSuite) TestWriteCheckpointHash() { + require := s.Require() + ctx := context.Background() + + // Test successful write + mockBackend := NewMockBackend() + hash := randomFakeHash() + mockBackend.On("WriteMemory", CheckpointAddress, hash[:], mock.AnythingOfType("time.Duration")). + Return(nil) + machine := &machineImpl{ + backend: mockBackend, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + }, + } + + err := machine.WriteCheckpointHash(ctx, hash) + require.NoError(err) + mockBackend.AssertExpectations(s.T()) + + // Test write with backend error + mockBackend2 := NewMockBackend() + mockBackend2.On("WriteMemory", CheckpointAddress, hash[:], mock.AnythingOfType("time.Duration")). + Return(errors.New("write failed")) + machine2 := &machineImpl{ + backend: mockBackend2, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + }, + } + err = machine2.WriteCheckpointHash(ctx, hash) + require.Error(err) + require.ErrorIs(err, ErrMachineInternal) + require.Contains(err.Error(), "could not write checkpoint hash") + mockBackend2.AssertExpectations(s.T()) + + // Test write with canceled context + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + err = machine.WriteCheckpointHash(canceledCtx, hash) + require.ErrorIs(err, ErrCanceled) +} + // Test Advance method func (s *ImplementationSuite) TestAdvance() { require := s.Require() @@ -223,12 +317,12 @@ func (s *ImplementationSuite) TestAdvance() { } input := []byte("test input") - accepted, outputs, reports, _, _, hash, err := machine.Advance(ctx, input, false) + resp, err := machine.Advance(ctx, input, false) require.NoError(err) - require.True(accepted) - require.Empty(outputs) - require.Empty(reports) - require.NotEqual(Hash{}, hash) + require.True(resp.Accepted) + require.Empty(resp.Outputs) + require.Empty(resp.Reports) + require.NotEqual(Hash{}, resp.OutputsHash) mockBackend.AssertExpectations(s.T()) // Test advance with rejection @@ -245,12 +339,12 @@ func (s *ImplementationSuite) TestAdvance() { AdvanceMaxDeadline: time.Second * 10, }, } - accepted, outputs, reports, _, _, hash, err = machine2.Advance(ctx, input, false) + resp, err = machine2.Advance(ctx, input, false) require.NoError(err) - require.False(accepted) - require.Empty(outputs) - require.Empty(reports) - require.Equal(Hash{}, hash) + require.False(resp.Accepted) + require.Empty(resp.Outputs) + require.Empty(resp.Reports) + require.Equal(Hash{}, resp.OutputsHash) mockBackend2.AssertExpectations(s.T()) // Test advance with exception @@ -267,10 +361,10 @@ func (s *ImplementationSuite) TestAdvance() { AdvanceMaxDeadline: time.Second * 10, }, } - accepted, outputs, reports, _, _, hash, err = machine3.Advance(ctx, input, false) + resp, err = machine3.Advance(ctx, input, false) require.ErrorIs(err, ErrException) - require.False(accepted) - require.Equal(Hash{}, hash) + require.False(resp.Accepted) + require.Equal(Hash{}, resp.OutputsHash) mockBackend3.AssertExpectations(s.T()) // Test advance with payload too large @@ -288,7 +382,7 @@ func (s *ImplementationSuite) TestAdvance() { }, } largeInput := make([]byte, 10) - _, _, _, _, _, _, err = machine4.Advance(ctx, largeInput, false) + _, err = machine4.Advance(ctx, largeInput, false) require.ErrorIs(err, ErrPayloadLengthLimitExceeded) mockBackend4.AssertExpectations(s.T()) @@ -311,7 +405,7 @@ func (s *ImplementationSuite) TestAdvance() { AdvanceMaxDeadline: time.Second * 10, }, } - _, _, _, _, _, _, err = machine5.Advance(ctx, input, false) + _, err = machine5.Advance(ctx, input, false) require.Error(err) require.ErrorIs(err, ErrHashLength) mockBackend5.AssertExpectations(s.T()) @@ -705,6 +799,89 @@ func (s *ImplementationSuite) TestRun() { require.NoError(err) mockBackend3.AssertExpectations(s.T()) + // Test run with ReceiveCmioRequest error during automatic yield + mockBackend4 := NewMockBackend() + mockBackend4.On("ReadMCycle", mock.AnythingOfType("time.Duration")).Return(uint64(0), nil) + mockBackend4.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")).Return(YieldedAutomatically, nil) + mockBackend4.On("ReceiveCmioRequest", mock.AnythingOfType("time.Duration")).Return( + uint8(0), uint16(0), []byte(nil), errors.New("cmio request failed")) + + machine4 := &machineImpl{ + backend: mockBackend4, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + AdvanceMaxCycles: 1000, + AdvanceIncCycles: 100, + AdvanceIncDeadline: time.Second * 1, + AdvanceMaxDeadline: time.Second * 10, + }, + } + + _, _, _, _, err = machine4.run(ctx, AdvanceStateRequest, false) + require.Error(err) + require.Contains(err.Error(), "could not read output/report") + require.Contains(err.Error(), "cmio request failed") + mockBackend4.AssertExpectations(s.T()) + + // Test run with automatic yield producing output then manual yield + mockBackend5 := NewMockBackend() + mockBackend5.On("ReadMCycle", mock.AnythingOfType("time.Duration")).Return(uint64(0), nil) + mockBackend5.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedAutomatically, nil).Once() + mockBackend5.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedManually, nil).Once() + mockBackend5.On("ReceiveCmioRequest", mock.AnythingOfType("time.Duration")).Return( + uint8(0), uint16(AutomaticYieldReasonOutput), []byte("output data"), nil).Once() + + machine5 := &machineImpl{ + backend: mockBackend5, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + AdvanceMaxCycles: 1000, + AdvanceIncCycles: 100, + AdvanceIncDeadline: time.Second * 1, + AdvanceMaxDeadline: time.Second * 10, + }, + } + + outputs5, reports5, _, _, err := machine5.run(ctx, AdvanceStateRequest, false) + require.NoError(err) + require.Len(outputs5, 1) + require.Equal([]byte("output data"), []byte(outputs5[0])) + require.Empty(reports5) + mockBackend5.AssertExpectations(s.T()) + + // Test run with automatic yield producing report then manual yield + mockBackend6 := NewMockBackend() + mockBackend6.On("ReadMCycle", mock.AnythingOfType("time.Duration")).Return(uint64(0), nil) + mockBackend6.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedAutomatically, nil).Once() + mockBackend6.On("Run", mock.AnythingOfType("uint64"), mock.AnythingOfType("time.Duration")). + Return(YieldedManually, nil).Once() + mockBackend6.On("ReceiveCmioRequest", mock.AnythingOfType("time.Duration")).Return( + uint8(0), uint16(AutomaticYieldReasonReport), []byte("report data"), nil).Once() + + machine6 := &machineImpl{ + backend: mockBackend6, + logger: s.logger, + params: model.ExecutionParameters{ + FastDeadline: time.Second * 5, + AdvanceMaxCycles: 1000, + AdvanceIncCycles: 100, + AdvanceIncDeadline: time.Second * 1, + AdvanceMaxDeadline: time.Second * 10, + }, + } + + outputs6, reports6, _, _, err := machine6.run(ctx, AdvanceStateRequest, false) + require.NoError(err) + require.Empty(outputs6) + require.Len(reports6, 1) + require.Equal([]byte("report data"), []byte(reports6[0])) + mockBackend6.AssertExpectations(s.T()) + } // Test step method diff --git a/pkg/machine/libcartesi.go b/pkg/machine/libcartesi.go index 5def485ad..7d396f1de 100644 --- a/pkg/machine/libcartesi.go +++ b/pkg/machine/libcartesi.go @@ -140,7 +140,6 @@ func (e *LibCartesiBackend) GetProof(address uint64, log2size int32, timeout tim proof := &proofJson{} err = json.Unmarshal([]byte(jsonMessage), proof) if err != nil { - println("Failed to unmarshal proof JSON:", err.Error()) return nil, fmt.Errorf("failed to unmarshal proof JSON: %w", err) } return proof.Siblings, nil diff --git a/pkg/machine/machine.go b/pkg/machine/machine.go index 2e29e3164..ccf8cb237 100644 --- a/pkg/machine/machine.go +++ b/pkg/machine/machine.go @@ -27,17 +27,28 @@ type ( Hash = [HashSize]byte ) +// AdvanceResponse contains the result of an advance operation. +type AdvanceResponse struct { + Accepted bool + Outputs []Output + Reports []Report + Hashes []Hash + RemainingCycles uint64 + OutputsHash Hash +} + // Common errors var ( ErrMachineInternal = errors.New("machine internal error") - ErrDeadlineExceeded = errors.New("machine operation deadline exceeded") - ErrCanceled = errors.New("machine operation canceled") + ErrDeadlineExceeded = fmt.Errorf("machine operation deadline exceeded: %w", context.DeadlineExceeded) + ErrCanceled = fmt.Errorf("machine operation canceled: %w", context.Canceled) ErrOrphanServer = errors.New("machine server was left orphan") ErrNotAtManualYield = errors.New("not at manual yield") ErrException = errors.New("last request yielded an exception") ErrRejected = errors.New("last request yielded as rejected") ErrHalted = errors.New("machine halted") ErrOutputsLimitExceeded = errors.New("outputs limit exceeded") + ErrReportsLimitExceeded = errors.New("reports limit exceeded") ErrReachedTargetMcycle = errors.New("machine reached target mcycle") ErrPayloadLengthLimitExceeded = errors.New("payload length limit exceeded") ErrHashLength = errors.New("hash does not have the exactly number of bytes") @@ -60,10 +71,11 @@ type Machine interface { WriteCheckpointHash(ctx context.Context, hash Hash) error // Advance sends an input to the machine. - // It returns a boolean indicating whether or not the request was accepted. - // It also returns the corresponding outputs, reports, and the hash of the outputs. - // In case the request is not accepted, the function does not return outputs. - Advance(ctx context.Context, input []byte, computeHashes bool) (bool, []Output, []Report, []Hash, uint64, Hash, error) + // It always returns a non-nil AdvanceResponse, even on error paths. + // The response contains whether the request was accepted, + // the corresponding outputs, reports, and the hash of the outputs. + // In case the request is not accepted, the response does not contain outputs. + Advance(ctx context.Context, input []byte, computeHashes bool) (*AdvanceResponse, error) // Inspect sends a query to the machine. // It returns a boolean indicating whether or not the request was accepted diff --git a/pkg/machine/machine_test.go b/pkg/machine/machine_test.go index 868acd8cc..1d613d9ae 100644 --- a/pkg/machine/machine_test.go +++ b/pkg/machine/machine_test.go @@ -115,6 +115,25 @@ func (s *MachineSuite) TestLoad() { require.ErrorIs(err, ErrNotAtManualYield) mockBackend.AssertExpectations(s.T()) + // Test with IsAtManualYield returning error + mockBackend = NewMockBackend() + mockBackend.On("NewMachineRuntimeConfig").Return(`{"concurrency":{"update_merkle_tree":1}}`, nil) + mockBackend.On("Load", + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + mock.AnythingOfType("time.Duration"), + ).Return(nil).Once() + mockBackend.On("IsAtManualYield", mock.AnythingOfType("time.Duration")).Return(false, errors.New("yield check failed")) + mockBackend.SetupForCleanup() + config = DefaultConfig("some/path") + config.BackendFactoryFn = MockBackendFactory(mockBackend) + machine, err = Load(ctx, s.logger, config) + require.Error(err) + require.Nil(machine) + require.ErrorIs(err, ErrMachineInternal) + require.Contains(err.Error(), "yield check failed") + mockBackend.AssertExpectations(s.T()) + // Test with machine last request not accepted mockBackend = NewMockBackend() mockBackend.On("NewMachineRuntimeConfig").Return(`{"concurrency":{"update_merkle_tree":1}}`, nil) @@ -214,15 +233,15 @@ func (s *MachineSuite) TestMachineInterface() { require.Equal(Hash{6, 7, 8, 9, 10}, outputsHash) // Test Advance - accepted, outputs, reports, _, _, advanceHash, err := machine.Advance(ctx, []byte("input"), false) + advanceResp, err := machine.Advance(ctx, []byte("input"), false) require.NoError(err) - require.True(accepted) - require.Len(outputs, 2) - require.Equal([]byte("output1"), outputs[0]) - require.Equal([]byte("output2"), outputs[1]) - require.Len(reports, 1) - require.Equal([]byte("report1"), reports[0]) - require.Equal(Hash{11, 12, 13, 14, 15}, advanceHash) + require.True(advanceResp.Accepted) + require.Len(advanceResp.Outputs, 2) + require.Equal([]byte("output1"), advanceResp.Outputs[0]) + require.Equal([]byte("output2"), advanceResp.Outputs[1]) + require.Len(advanceResp.Reports, 1) + require.Equal([]byte("report1"), advanceResp.Reports[0]) + require.Equal(Hash{11, 12, 13, 14, 15}, advanceResp.OutputsHash) // Test Inspect accepted, inspectReports, err := machine.Inspect(ctx, []byte("query")) @@ -279,7 +298,7 @@ func (s *MachineSuite) TestMachineInterfaceErrors() { require.Contains(err.Error(), "outputs hash error") // Test Advance error - _, _, _, _, _, _, err = machine.Advance(ctx, []byte("input"), false) + _, err = machine.Advance(ctx, []byte("input"), false) require.Error(err) require.Contains(err.Error(), "advance error") @@ -354,16 +373,15 @@ func (m *MockMachine) WriteCheckpointHash(_ context.Context, _ Hash) error { return m.CheckpointHashError } -func (m *MockMachine) Advance(_ context.Context, _ []byte, _ bool) ( - bool, []Output, []Report, []Hash, uint64, Hash, error, -) { - return m.AdvanceAcceptedReturn, - m.AdvanceOutputsReturn, - m.AdvanceReportsReturn, - m.AdvanceHashesReturn, - m.AdvanceRemainingReturn, - m.AdvanceHashReturn, - m.AdvanceError +func (m *MockMachine) Advance(_ context.Context, _ []byte, _ bool) (*AdvanceResponse, error) { + return &AdvanceResponse{ + Accepted: m.AdvanceAcceptedReturn, + Outputs: m.AdvanceOutputsReturn, + Reports: m.AdvanceReportsReturn, + Hashes: m.AdvanceHashesReturn, + RemainingCycles: m.AdvanceRemainingReturn, + OutputsHash: m.AdvanceHashReturn, + }, m.AdvanceError } func (m *MockMachine) Inspect(_ context.Context,