diff --git a/chat/src/app/layout.tsx b/chat/src/app/layout.tsx index 830124ed..7c44c440 100644 --- a/chat/src/app/layout.tsx +++ b/chat/src/app/layout.tsx @@ -29,7 +29,7 @@ export default function RootLayout({ disableTransitionOnChange > {children} - + diff --git a/chat/src/components/chat-provider.tsx b/chat/src/components/chat-provider.tsx index 21a2ee3f..e8789ac0 100644 --- a/chat/src/components/chat-provider.tsx +++ b/chat/src/components/chat-provider.tsx @@ -36,6 +36,12 @@ interface StatusChangeEvent { agent_type: string; } +interface ErrorEventData { + message: string; + level: string; + time: string; +} + interface APIErrorDetail { location: string; message: string; @@ -215,6 +221,25 @@ export function ChatProvider({ children }: PropsWithChildren) { setAgentType(data.agent_type === "" ? "unknown" : data.agent_type as AgentType); }); + // Handle agent error events + eventSource.addEventListener("agent_error", (event) => { + const messageEvent = event as MessageEvent; + try { + const data: ErrorEventData = JSON.parse(messageEvent.data); + + // Display error as toast notification that persists until manually dismissed + if (data.level === "error") { + toast.error(data.message, { duration: Infinity }); + } else if (data.level === "warning") { + toast.warning(data.message, { duration: Infinity }); + } else { + toast.info(data.message, { duration: Infinity }); + } + } catch (e) { + console.error("Failed to parse agent_error event data:", e); + } + }); + // Handle connection open (server is online) eventSource.onopen = () => { // Connection is established, but we'll wait for status_change event diff --git a/cmd/server/server.go b/cmd/server/server.go index 6d5cdec3..e578a17d 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -8,9 +8,14 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "sort" + "strconv" "strings" + "syscall" + "time" + "github.com/coder/agentapi/lib/screentracker" "github.com/mattn/go-isatty" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -103,6 +108,43 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } + // Get the variables related to state management + stateFile := viper.GetString(FlagStateFile) + loadState := false + saveState := false + + // Validate state file configuration + if stateFile != "" { + if !viper.IsSet(FlagLoadState) { + loadState = true + } else { + loadState = viper.GetBool(FlagLoadState) + } + + if !viper.IsSet(FlagSaveState) { + saveState = true + } else { + saveState = viper.GetBool(FlagSaveState) + } + } else { + if viper.IsSet(FlagLoadState) && viper.GetBool(FlagLoadState) { + return xerrors.Errorf("--load-state requires --state-file to be set") + } + if viper.IsSet(FlagSaveState) && viper.GetBool(FlagSaveState) { + return xerrors.Errorf("--save-state requires --state-file to be set") + } + } + + pidFile := viper.GetString(FlagPidFile) + + // Write PID file if configured + if pidFile != "" { + if err := writePIDFile(pidFile, logger); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + defer cleanupPIDFile(pidFile, logger) + } + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -128,7 +170,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, + StatePersistenceConfig: screentracker.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: loadState, + SaveState: saveState, + }, }) + if err != nil { return xerrors.Errorf("failed to create server: %w", err) } @@ -136,10 +184,21 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } + + // Create a context for graceful shutdown + gracefulCtx, gracefulCancel := context.WithCancel(ctx) + defer gracefulCancel() + + // Setup signal handlers (they will call gracefulCancel) + handleSignals(gracefulCtx, gracefulCancel, logger, srv) + logger.Info("Starting server on port", "port", port) + + // Monitor process exit processExitCh := make(chan error, 1) go func() { defer close(processExitCh) + defer gracefulCancel() if err := process.Wait(); err != nil { if errors.Is(err, termexec.ErrNonZeroExitCode) { processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err) @@ -147,17 +206,46 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) } } - if err := srv.Stop(ctx); err != nil { - logger.Error("Failed to stop server", "error", err) + }() + + // Start the server + serverErrCh := make(chan error, 1) + go func() { + defer close(serverErrCh) + if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { + serverErrCh <- err } }() - if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { - return xerrors.Errorf("failed to start server: %w", err) + + select { + case err := <-serverErrCh: + if err != nil { + return xerrors.Errorf("failed to start server: %w", err) + } + case <-gracefulCtx.Done(): } + + if err := srv.SaveState("shutdown"); err != nil { + logger.Error("Failed to save state during shutdown", "error", err) + } + + // Stop the HTTP server + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop HTTP server", "error", err) + } + select { case err := <-processExitCh: - return xerrors.Errorf("agent exited with error: %w", err) + if err != nil { + return xerrors.Errorf("agent exited with error: %w", err) + } default: + // Close the process + if err := process.Close(logger, 5*time.Second); err != nil { + logger.Error("Failed to close process cleanly", "error", err) + } } return nil } @@ -171,6 +259,58 @@ var agentNames = (func() []string { return names })() +// writePIDFile writes the current process ID to the specified file +func writePIDFile(pidFile string, logger *slog.Logger) error { + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(pidFile) + if err := os.MkdirAll(dir, 0o700); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Check if PID file already exists + if existingPIDData, err := os.ReadFile(pidFile); err == nil { + existingPIDStr := strings.TrimSpace(string(existingPIDData)) + if existingPID, err := strconv.Atoi(existingPIDStr); err == nil { + if isProcessRunning(existingPID) { + return xerrors.Errorf("another instance is already running with PID %d (PID file: %s)", existingPID, pidFile) + } + logger.Warn("Found stale PID file, will overwrite", "pidFile", pidFile, "stalePID", existingPID) + } + } else if !os.IsNotExist(err) { + return xerrors.Errorf("failed to read existing PID file: %w", err) + } + + // Write PID file + if err := os.WriteFile(pidFile, []byte(pidContent), 0o600); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func cleanupPIDFile(pidFile string, logger *slog.Logger) { + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err) + } else if err == nil { + logger.Info("Removed PID file", "pidFile", pidFile) + } +} + +// isProcessRunning checks if a process with the given PID is running +func isProcessRunning(pid int) bool { + process, err := os.FindProcess(pid) + if err != nil { + return false + } + err = process.Signal(syscall.Signal(0)) + return err == nil || errors.Is(err, syscall.EPERM) +} + type flagSpec struct { name string shorthand string @@ -190,6 +330,10 @@ const ( FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" FlagInitialPrompt = "initial-prompt" + FlagStateFile = "state-file" + FlagLoadState = "load-state" + FlagSaveState = "save-state" + FlagPidFile = "pid-file" ) func CreateServerCmd() *cobra.Command { @@ -228,6 +372,10 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {FlagStateFile, "s", "", "Path to file for saving/loading server state", "string"}, + {FlagLoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, + {FlagSaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, + {FlagPidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, } for _, spec := range flagSpecs { diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index bd07fc63..7b9372c1 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -2,6 +2,8 @@ package server import ( "fmt" + "io" + "log/slog" "os" "strings" "testing" @@ -477,6 +479,218 @@ func TestServerCmd_AllowedHosts(t *testing.T) { } } +func TestServerCmd_StatePersistenceFlags(t *testing.T) { + // NOTE: These tests use --exit flag to test flag parsing and defaults. + // Runtime validation that happens in runServer (e.g., "--load-state requires --state-file") + // would call os.Exit(1) which terminates the test process, so those validations + // are tested through integration/E2E tests instead. + + t.Run("state-file with defaults", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + // load-state and save-state default to true when state-file is set (validated in runServer) + }) + + t.Run("state-file with explicit load-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--load-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, false, viper.GetBool(FlagLoadState)) + }) + + t.Run("state-file with explicit save-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--save-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, false, viper.GetBool(FlagSaveState)) + }) + + t.Run("state-file with explicit load-state=true and save-state=true", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--load-state=true", + "--save-state=true", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, true, viper.GetBool(FlagLoadState)) + assert.Equal(t, true, viper.GetBool(FlagSaveState)) + }) + + t.Run("load-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--load-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(FlagLoadState)) + }) + + t.Run("save-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--save-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(FlagSaveState)) + }) + + t.Run("pid-file can be set independently", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--pid-file", "/tmp/server.pid", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/server.pid", viper.GetString(FlagPidFile)) + }) + + t.Run("state-file and pid-file can be set together", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--pid-file", "/tmp/server.pid", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, "/tmp/server.pid", viper.GetString(FlagPidFile)) + }) +} + +func TestPIDFileOperations(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("writePIDFile creates file with process ID", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify content contains current PID + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("writePIDFile creates directory if not exists", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nested/deep/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify directory was created + _, err = os.Stat(tmpDir + "/nested/deep") + require.NoError(t, err) + }) + + t.Run("writePIDFile overwrites existing file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Write initial PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Overwrite with current PID + err = writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify content is updated + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("cleanupPIDFile removes file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Create PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Cleanup + cleanupPIDFile(pidFile, discardLogger) + + // Verify file is removed + _, err = os.Stat(pidFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("cleanupPIDFile handles non-existent file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nonexistent.pid" + + // Should not panic or error + cleanupPIDFile(pidFile, discardLogger) + }) + + t.Run("cleanupPIDFile handles directory removal error gracefully", func(t *testing.T) { + // Create a file in a protected directory (this is system-dependent) + // Just verify it doesn't panic when it can't remove the file + pidFile := "/this/should/not/exist/test.pid" + + // Should not panic + cleanupPIDFile(pidFile, discardLogger) + }) +} + func TestServerCmd_AllowedOrigins(t *testing.T) { tests := []struct { name string diff --git a/cmd/server/signals_unix.go b/cmd/server/signals_unix.go new file mode 100644 index 00000000..b15b5b2b --- /dev/null +++ b/cmd/server/signals_unix.go @@ -0,0 +1,46 @@ +//go:build unix + +package server + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/httpapi" +) + +// handleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: trigger graceful shutdown by canceling the context +// - SIGUSR1: save conversation state without exiting +func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + logger.Info("Received shutdown signal", "signal", sig) + cancel() + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + defer signal.Stop(saveOnlyCh) + for { + select { + case <-saveOnlyCh: + logger.Info("Received SIGUSR1, saving state without exiting") + if err := srv.SaveState("SIGUSR1"); err != nil { + logger.Error("Failed to save state on SIGUSR1", "error", err) + } + case <-ctx.Done(): + return + } + } + }() +} diff --git a/cmd/server/signals_windows.go b/cmd/server/signals_windows.go new file mode 100644 index 00000000..4ada79c3 --- /dev/null +++ b/cmd/server/signals_windows.go @@ -0,0 +1,24 @@ +//go:build windows + +package server + +import ( + "context" + "log/slog" + "os" + "os/signal" + + "github.com/coder/agentapi/lib/httpapi" +) + +// handleSignals sets up signal handlers for Windows. +func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) { + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + logger.Info("Received shutdown signal", "signal", sig) + cancel() + }() +} diff --git a/e2e/echo_test.go b/e2e/echo_test.go index 765521cf..d529a4bc 100644 --- a/e2e/echo_test.go +++ b/e2e/echo_test.go @@ -40,7 +40,8 @@ func TestE2E(t *testing.T) { t.Run("basic", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t, nil) + script, apiClient, cleanup := setup(ctx, t, nil, true) + defer cleanup() messageReq := agentapisdk.PostMessageParams{ Content: "This is a test message.", Type: agentapisdk.MessageTypeUser, @@ -60,7 +61,8 @@ func TestE2E(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t, nil) + script, apiClient, cleanup := setup(ctx, t, nil, true) + defer cleanup() messageReq := agentapisdk.PostMessageParams{ Content: "What is the answer to life, the universe, and everything?", Type: agentapisdk.MessageTypeUser, @@ -86,13 +88,14 @@ func TestE2E(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t, ¶ms{ + script, apiClient, cleanup := setup(ctx, t, ¶ms{ cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { defCmd, defArgs := defaultCmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath) script := fmt.Sprintf(`echo "hello agent" | %s %s`, defCmd, strings.Join(defArgs, " ")) return "/bin/sh", []string{"-c", script} }, - }) + }, true) + defer cleanup() require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "stdin")) msgResp, err := apiClient.GetMessages(ctx) require.NoError(t, err, "Failed to get messages via SDK") @@ -100,27 +103,218 @@ func TestE2E(t *testing.T) { require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content)) require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content)) }) + + t.Run("state_persistence", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Create a temporary state file + stateFile := filepath.Join(t.TempDir(), "state.json") + scriptFilePath := filepath.Join("testdata", "state_persistence.json") + + // Step 1: Start server with state persistence enabled and send first message + script, apiClient, cleanup := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + }, true) + + // Send first message + messageReq := agentapisdk.PostMessageParams{ + Content: "First message before state save.", + Type: agentapisdk.MessageTypeUser, + } + _, err := apiClient.PostMessage(ctx, messageReq) + require.NoError(t, err, "Failed to send first message") + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "first message")) + + // Verify messages before shutdown + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages before shutdown") + require.Len(t, msgResp.Messages, 3, "Expected 3 messages before shutdown") + require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[0].Content)) + require.Equal(t, script[1].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content)) + require.Equal(t, script[1].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content)) + + // Step 2: Stop server (triggers state save) + cleanup() + + // Verify state file was created + require.FileExists(t, stateFile, "State file should exist after shutdown") + + // Step 3: Start new server instance and load state + // Note: We don't wait for stable here because the echo agent will try to replay + // from the beginning, which conflicts with restored state. We just verify the + // state was loaded and messages are present. + _, apiClient2, cleanup2 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + }, false) + defer cleanup2() + + // Step 4: Wait for state to be restored by retrying until we get expected messages + msgResp2, err := waitForMessagesWithCount(ctx, t, apiClient2, 3, operationTimeout, "state restore") + require.NoError(t, err, "Failed to get messages after state restore") + require.Len(t, msgResp2.Messages, 3, "Expected 3 messages after state restore") + + // Verify all messages match the state before shutdown + require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp2.Messages[0].Content)) + require.Equal(t, script[1].ExpectMessage, strings.TrimSpace(msgResp2.Messages[1].Content)) + require.Equal(t, script[1].ResponseMessage, strings.TrimSpace(msgResp2.Messages[2].Content)) + }) + + t.Run("state_persistence_initial_prompt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Create a temporary state file + stateFile := filepath.Join(t.TempDir(), "state.json") + scriptFilePath := filepath.Join("testdata", "state_persistence_initial_prompt.json") + + // Step 1: Start server with initial prompt + initialPrompt1 := "Test initial prompt" + _, apiClient, cleanup := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + initialPrompt: initialPrompt1, + }, true) + + // Verify initial prompt was sent (should have 3 messages: agent greeting + initial prompt + response) + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after initial prompt") + require.Len(t, msgResp.Messages, 3, "Expected 3 messages after initial prompt") + require.Equal(t, "Hello! I'm ready to help you.", strings.TrimSpace(msgResp.Messages[0].Content)) + require.Equal(t, initialPrompt1, strings.TrimSpace(msgResp.Messages[1].Content)) + require.Equal(t, "Echo: Test initial prompt", strings.TrimSpace(msgResp.Messages[2].Content)) + + // Step 2: Close server + cleanup() + require.FileExists(t, stateFile, "State file should exist after shutdown") + + // Step 3: Restart WITHOUT an initial prompt + _, apiClient2, cleanup2 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + }, false) + defer cleanup2() + + // Step 4: Wait for state to be restored and verify initial prompt was NOT sent again + msgResp2, err := waitForMessagesWithCount(ctx, t, apiClient2, 3, operationTimeout, "restart without initial prompt") + require.NoError(t, err, "Failed to get messages after restart without initial prompt") + require.Len(t, msgResp2.Messages, 3, "Expected 3 messages (initial prompt should not be sent again)") + require.Equal(t, initialPrompt1, strings.TrimSpace(msgResp2.Messages[1].Content)) + + // Step 5: Close server + cleanup2() + + // Step 6: Restart with same initial prompt + _, apiClient3, cleanup3 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + initialPrompt: initialPrompt1, + }, false) + defer cleanup3() + + // Step 7: Wait for state to be restored and verify same initial prompt was NOT sent again + msgResp3, err := waitForMessagesWithCount(ctx, t, apiClient3, 3, operationTimeout, "restart with same initial prompt") + require.NoError(t, err, "Failed to get messages after restart with same initial prompt") + require.Len(t, msgResp3.Messages, 3, "Expected 3 messages (same initial prompt should not be sent again)") + + }) + + t.Run("state_persistence_different_initial_prompt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Create a temporary state file + stateFile := filepath.Join(t.TempDir(), "state.json") + + // Step 1: Start server with initial prompt "Test initial prompt" using phase1 script + initialPrompt1 := "Test initial prompt" + _, apiClient, cleanup := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: filepath.Join("testdata", "state_persistence_different_initial_prompt_phase1.json"), + initialPrompt: initialPrompt1, + }, true) + + // Verify initial prompt was sent (3 messages: greeting + prompt + response) + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after initial prompt") + require.Len(t, msgResp.Messages, 3, "Expected 3 messages after initial prompt") + require.Equal(t, "Hello! I'm ready to help you.", strings.TrimSpace(msgResp.Messages[0].Content)) + require.Equal(t, initialPrompt1, strings.TrimSpace(msgResp.Messages[1].Content)) + require.Equal(t, "Echo: Test initial prompt", strings.TrimSpace(msgResp.Messages[2].Content)) + + // Step 2: Close server + cleanup() + require.FileExists(t, stateFile, "State file should exist after shutdown") + + // Step 3: Restart with DIFFERENT initial prompt using a different script + initialPrompt2 := "Different initial prompt" + _, apiClient2, cleanup2 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: filepath.Join("testdata", "state_persistence_different_initial_prompt.json"), + initialPrompt: initialPrompt2, + }, false) + defer cleanup2() + + // Wait for initial prompt to be processed and state to stabilize + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient2, operationTimeout, "after different initial prompt")) + + // Step 4: Verify new initial prompt WAS sent (5 messages: 3 previous + 2 new) + msgResp2, err := waitForMessagesWithCount(ctx, t, apiClient2, 5, operationTimeout, "different initial prompt processed") + require.NoError(t, err, "Failed to get messages after different initial prompt") + require.Len(t, msgResp2.Messages, 5, "Expected 5 messages after different initial prompt (3 previous + 2 new)") + // Verify the new initial prompt and response were added + require.Equal(t, initialPrompt2, strings.TrimSpace(msgResp2.Messages[3].Content)) + require.Equal(t, "Echo: Different initial prompt", strings.TrimSpace(msgResp2.Messages[4].Content)) + + }) } type params struct { - cmdFn func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) + cmdFn func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) + stateFile string + scriptFilePath string + initialPrompt string } func defaultCmdFn(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { return binaryPath, []string{"server", fmt.Sprintf("--port=%d", serverPort), "--", "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath} } -func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agentapisdk.Client) { +func stateCmdFn(stateFile, initialPrompt string) func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { + return func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { + args := []string{ + "server", + fmt.Sprintf("--port=%d", serverPort), + fmt.Sprintf("--state-file=%s", stateFile), + } + if initialPrompt != "" { + args = append(args, fmt.Sprintf("--initial-prompt=%s", initialPrompt)) + } + args = append(args, "--", "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath) + return binaryPath, args + } +} + +func setup(ctx context.Context, t testing.TB, p *params, waitForStable bool) ([]ScriptEntry, *agentapisdk.Client, func()) { t.Helper() if p == nil { p = ¶ms{} } if p.cmdFn == nil { - p.cmdFn = defaultCmdFn + if p.stateFile != "" { + p.cmdFn = stateCmdFn(p.stateFile, p.initialPrompt) + } else { + p.cmdFn = defaultCmdFn + } } - scriptFilePath := filepath.Join("testdata", filepath.Base(t.Name())+".json") + scriptFilePath := p.scriptFilePath + if scriptFilePath == "" { + scriptFilePath = filepath.Join("testdata", filepath.Base(t.Name())+".json") + } data, err := os.ReadFile(scriptFilePath) require.NoError(t, err, "Failed to read test script file: %s", scriptFilePath) @@ -175,22 +369,37 @@ func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agenta logOutput(t, "SERVER-STDERR", stderr) }() - // Clean up process - t.Cleanup(func() { + // Create cleanup function + cleanup := func() { if cmd.Process != nil { - _ = cmd.Process.Kill() - _ = cmd.Wait() + // Send SIGINT to allow graceful shutdown and state save + _ = cmd.Process.Signal(os.Interrupt) + // Wait for process to exit gracefully (with timeout) + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + select { + case <-done: + // Process exited gracefully + case <-time.After(10 * time.Second): + // Timeout, force kill + _ = cmd.Process.Kill() + <-done + } } wg.Wait() - }) + } serverURL := fmt.Sprintf("http://localhost:%d", serverPort) require.NoError(t, waitForServer(ctx, t, serverURL, healthCheckTimeout), "Server not ready") apiClient, err := agentapisdk.NewClient(serverURL) require.NoError(t, err, "Failed to create agentapi SDK client") - require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "setup")) - return script, apiClient + if waitForStable { + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "setup")) + } + return script, apiClient, cleanup } // logOutput logs process output with prefix @@ -263,6 +472,46 @@ func waitAgentAPIStable(ctx context.Context, t testing.TB, apiClient *agentapisd } } +// waitForMessagesWithCount retries GetMessages until it returns the expected number of messages or the timeout is reached. +func waitForMessagesWithCount(ctx context.Context, t testing.TB, apiClient *agentapisdk.Client, expectedCount int, timeout time.Duration, msg string) (*agentapisdk.GetMessagesResponse, error) { + t.Helper() + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + start := time.Now() + var lastErr error + var lastCount int + + for { + select { + case <-waitCtx.Done(): + if lastErr != nil { + return nil, fmt.Errorf("%s: failed to get %d messages after %v (last error: %w, last count: %d)", + msg, expectedCount, time.Since(start).Round(100*time.Millisecond), lastErr, lastCount) + } + return nil, fmt.Errorf("%s: timeout waiting for %d messages after %v (last count: %d)", + msg, expectedCount, time.Since(start).Round(100*time.Millisecond), lastCount) + case <-ticker.C: + resp, err := apiClient.GetMessages(ctx) + if err != nil { + lastErr = err + t.Logf("%s: GetMessages failed (will retry): %v", msg, err) + continue + } + lastCount = len(resp.Messages) + if lastCount == expectedCount { + elapsed := time.Since(start) + t.Logf("%s: got expected %d messages (elapsed: %s)", msg, expectedCount, elapsed.Round(100*time.Millisecond)) + return resp, nil + } + t.Logf("%s: got %d messages, expecting %d (will retry)", msg, lastCount, expectedCount) + } + } +} + // getFreePort returns a free TCP port func getFreePort() (int, error) { addr, err := net.ResolveTCPAddr("tcp", "localhost:0") diff --git a/e2e/testdata/state_persistence.json b/e2e/testdata/state_persistence.json new file mode 100644 index 00000000..b7fe071d --- /dev/null +++ b/e2e/testdata/state_persistence.json @@ -0,0 +1,18 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "First message before state save.", + "responseMessage": "Echo: First message before state save." + }, + { + "expectMessage": "Test initial prompt", + "responseMessage": "Echo: Test initial prompt" + }, + { + "expectMessage": "Different initial prompt", + "responseMessage": "Echo: Different initial prompt" + } +] diff --git a/e2e/testdata/state_persistence_different_initial_prompt.json b/e2e/testdata/state_persistence_different_initial_prompt.json new file mode 100644 index 00000000..60c610fd --- /dev/null +++ b/e2e/testdata/state_persistence_different_initial_prompt.json @@ -0,0 +1,10 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "Different initial prompt", + "responseMessage": "Echo: Different initial prompt" + } +] diff --git a/e2e/testdata/state_persistence_different_initial_prompt_phase1.json b/e2e/testdata/state_persistence_different_initial_prompt_phase1.json new file mode 100644 index 00000000..685ac187 --- /dev/null +++ b/e2e/testdata/state_persistence_different_initial_prompt_phase1.json @@ -0,0 +1,10 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "Test initial prompt", + "responseMessage": "Echo: Test initial prompt" + } +] diff --git a/e2e/testdata/state_persistence_initial_prompt.json b/e2e/testdata/state_persistence_initial_prompt.json new file mode 100644 index 00000000..cdd8d767 --- /dev/null +++ b/e2e/testdata/state_persistence_initial_prompt.json @@ -0,0 +1,14 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "Test initial prompt", + "responseMessage": "Echo: Test initial prompt" + }, + { + "expectMessage": "Different initial prompt", + "responseMessage": "Echo: Different initial prompt" + } +] diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 906a3a42..c92bb48f 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "github.com/coder/quartz" + mf "github.com/coder/agentapi/lib/msgfmt" st "github.com/coder/agentapi/lib/screentracker" "github.com/coder/agentapi/lib/util" @@ -18,6 +20,7 @@ const ( EventTypeMessageUpdate EventType = "message_update" EventTypeStatusChange EventType = "status_change" EventTypeScreenUpdate EventType = "screen_update" + EventTypeError EventType = "agent_error" ) type AgentStatus string @@ -52,6 +55,12 @@ type ScreenUpdateBody struct { Screen string `json:"screen"` } +type ErrorBody struct { + Message string `json:"message" doc:"Error message"` + Level st.ErrorLevel `json:"level" doc:"Error level"` + Time time.Time `json:"time" doc:"Timestamp when the error occurred"` +} + type Event struct { Type EventType Payload any @@ -66,6 +75,8 @@ type EventEmitter struct { chanIdx int subscriptionBufSize uint screen string + errors []ErrorBody + clock quartz.Clock } func convertStatus(status st.ConversationStatus) AgentStatus { @@ -101,6 +112,12 @@ func WithAgentType(agentType mf.AgentType) EventEmitterOption { } } +func WithClock(clock quartz.Clock) EventEmitterOption { + return func(e *EventEmitter) { + e.clock = clock + } +} + func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter { e := &EventEmitter{ messages: make([]st.ConversationMessage, 0), @@ -111,6 +128,9 @@ func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter { for _, opt := range opts { opt(e) } + if e.clock == nil { + e.clock = quartz.NewReal() + } return e } @@ -137,7 +157,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { } } -// Assumes that only the last message can change or new messages can be added. +// EmitMessages assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { e.mu.Lock() @@ -194,6 +214,22 @@ func (e *EventEmitter) EmitScreen(newScreen string) { e.screen = newScreen } +func (e *EventEmitter) EmitError(message string, level st.ErrorLevel) { + e.mu.Lock() + defer e.mu.Unlock() + + errorBody := ErrorBody{ + Message: message, + Level: level, + Time: e.clock.Now(), + } + + // Store the error so new subscribers can receive all errors + e.errors = append(e.errors, errorBody) + + e.notifyChannels(EventTypeError, errorBody) +} + // Assumes the caller holds the lock. func (e *EventEmitter) currentStateAsEvents() []Event { events := make([]Event, 0, len(e.messages)+2) @@ -211,6 +247,15 @@ func (e *EventEmitter) currentStateAsEvents() []Event { Type: EventTypeScreenUpdate, Payload: ScreenUpdateBody{Screen: strings.TrimRight(e.screen, mf.WhiteSpaceChars)}, }) + + // Include all error events + for _, err := range e.errors { + events = append(events, Event{ + Type: EventTypeError, + Payload: err, + }) + } + return events } diff --git a/lib/httpapi/events_test.go b/lib/httpapi/events_test.go index a1d024c4..106766af 100644 --- a/lib/httpapi/events_test.go +++ b/lib/httpapi/events_test.go @@ -6,6 +6,7 @@ import ( "time" st "github.com/coder/agentapi/lib/screentracker" + "github.com/coder/quartz" "github.com/stretchr/testify/assert" ) @@ -97,4 +98,40 @@ func TestEventEmitter(t *testing.T) { t.Fatalf("read should not block") } }) + + t.Run("clock-injection", func(t *testing.T) { + mockClock := quartz.NewMock(t) + fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + mockClock.Set(fixedTime) + + emitter := NewEventEmitter(WithClock(mockClock), WithSubscriptionBufSize(10)) + _, ch, stateEvents := emitter.Subscribe() + + // Verify initial state events + assert.Len(t, stateEvents, 2) + + // Emit an error and verify it uses the mock clock time + emitter.EmitError("test error", st.ErrorLevelError) + + event := <-ch + assert.Equal(t, EventTypeError, event.Type) + errorBody, ok := event.Payload.(ErrorBody) + assert.True(t, ok) + assert.Equal(t, "test error", errorBody.Message) + assert.Equal(t, st.ErrorLevelError, errorBody.Level) + assert.Equal(t, fixedTime, errorBody.Time) + + // Advance the clock and emit another error + newTime := fixedTime.Add(1 * time.Hour) + mockClock.Set(newTime) + emitter.EmitError("another error", st.ErrorLevelWarning) + + event = <-ch + assert.Equal(t, EventTypeError, event.Type) + errorBody, ok = event.Payload.(ErrorBody) + assert.True(t, ok) + assert.Equal(t, "another error", errorBody.Message) + assert.Equal(t, st.ErrorLevelWarning, errorBody.Level) + assert.Equal(t, newTime, errorBody.Time) + }) } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 956cfb8a..4fb102fe 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -40,6 +41,7 @@ type Server struct { port int srv *http.Server mu sync.RWMutex + stopOnce sync.Once logger *slog.Logger conversation st.Conversation agentio *termexec.Process @@ -48,6 +50,8 @@ type Server struct { chatBasePath string tempDir string clock quartz.Clock + shutdownCtx context.Context + shutdown context.CancelFunc } func (s *Server) NormalizeSchema(schema any) any { @@ -97,14 +101,15 @@ func (s *Server) GetOpenAPI() string { const snapshotInterval = 25 * time.Millisecond type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string - Clock quartz.Clock + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + Clock quartz.Clock + StatePersistenceConfig st.StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -253,16 +258,17 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } conversation := st.NewPTY(ctx, st.PTYConversationConfig{ - AgentType: config.AgentType, - AgentIO: config.Process, - Clock: config.Clock, - SnapshotInterval: snapshotInterval, - ScreenStabilityLength: 2 * time.Second, - FormatMessage: formatMessage, - ReadyForInitialPrompt: isAgentReadyForInitialPrompt, - FormatToolCall: formatToolCall, - InitialPrompt: initialPrompt, - Logger: logger, + AgentType: config.AgentType, + AgentIO: config.Process, + Clock: config.Clock, + SnapshotInterval: snapshotInterval, + ScreenStabilityLength: 2 * time.Second, + FormatMessage: formatMessage, + ReadyForInitialPrompt: isAgentReadyForInitialPrompt, + FormatToolCall: formatToolCall, + InitialPrompt: initialPrompt, + Logger: logger, + StatePersistenceConfig: config.StatePersistenceConfig, }, emitter) // Create temporary directory for uploads @@ -272,6 +278,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } logger.Info("Created temporary directory for uploads", "tempDir", tempDir) + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + s := &Server{ router: router, api: api, @@ -284,6 +292,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), tempDir: tempDir, clock: config.Clock, + shutdownCtx: shutdownCtx, + shutdown: shutdownCancel, } // Register API routes @@ -387,6 +397,7 @@ func (s *Server) registerRoutes() { // Mapping of event type name to Go struct for that event. "message_update": MessageUpdateBody{}, "status_change": StatusChangeBody{}, + "agent_error": ErrorBody{}, }, s.subscribeEvents) sse.Register(s.api, huma.Operation{ @@ -511,6 +522,7 @@ func (s *Server) uploadFiles(ctx context.Context, input *struct { func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.Sender) { subscriberId, ch, stateEvents := s.emitter.Subscribe() defer s.emitter.Unsubscribe(subscriberId) + s.logger.Info("New subscriber", "subscriberId", subscriberId) for _, event := range stateEvents { if event.Type == EventTypeScreenUpdate { @@ -536,6 +548,9 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse. s.logger.Error("Failed to send event", "subscriberId", subscriberId, "error", err) return } + case <-s.shutdownCtx.Done(): + s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) + return case <-ctx.Done(): s.logger.Info("Context done", "subscriberId", subscriberId) return @@ -570,6 +585,9 @@ func (s *Server) subscribeScreen(ctx context.Context, input *struct{}, send sse. s.logger.Error("Failed to send screen event", "subscriberId", subscriberId, "error", err) return } + case <-s.shutdownCtx.Done(): + s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) + return case <-ctx.Done(): s.logger.Info("Screen context done", "subscriberId", subscriberId) return @@ -588,15 +606,22 @@ func (s *Server) Start() error { return s.srv.ListenAndServe() } -// Stop gracefully stops the HTTP server +// Stop gracefully stops the HTTP server. It is safe to call multiple times. func (s *Server) Stop(ctx context.Context) error { - // Clean up temporary directory - s.cleanupTempDir() + var err error + s.stopOnce.Do(func() { + s.shutdown() - if s.srv != nil { - return s.srv.Shutdown(ctx) - } - return nil + // Clean up temporary directory + s.cleanupTempDir() + + if s.srv != nil { + if err = s.srv.Shutdown(ctx); errors.Is(err, http.ErrServerClosed) { + err = nil + } + } + }) + return err } // cleanupTempDir removes the temporary directory and all its contents @@ -608,6 +633,14 @@ func (s *Server) cleanupTempDir() { } } +func (s *Server) SaveState(source string) error { + if err := s.conversation.SaveState(); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + return err + } + return nil +} + // registerStaticFileRoutes sets up routes for serving static files func (s *Server) registerStaticFileRoutes() { chatHandler := FileServerWithIndexFallback(s.chatBasePath) diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index c8e8b23c..82fc6713 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" @@ -956,3 +957,36 @@ func TestServer_UploadFiles_Errors(t *testing.T) { require.Contains(t, string(body), "file size exceeds 10MB limit") }) } + +func TestServer_Stop_Idempotency(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, + }) + require.NoError(t, err) + + // First call to Stop should succeed + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = srv.Stop(stopCtx) + require.NoError(t, err) + + // Second call to Stop should also succeed (no-op) + stopCtx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + err = srv.Stop(stopCtx2) + require.NoError(t, err) + + // Third call to Stop should also succeed (no-op) + stopCtx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel3() + err = srv.Stop(stopCtx3) + require.NoError(t, err) +} diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 16203041..c8d95b6e 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,10 +4,7 @@ import ( "context" "fmt" "os" - "os/signal" "strings" - "syscall" - "time" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -45,16 +42,5 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return nil, err } } - - // Handle SIGINT (Ctrl+C) and send it to the process - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - go func() { - <-signalCh - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Error closing process", "error", err) - } - }() - return process, nil } diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 8299faa1..e11415a7 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -2,6 +2,7 @@ package screentracker import ( "context" + "strings" "time" "github.com/coder/agentapi/lib/util" @@ -33,6 +34,22 @@ var ConversationRoleValues = []ConversationRole{ ConversationRoleAgent, } +type ErrorLevel string + +func (e ErrorLevel) Schema(r huma.Registry) *huma.Schema { + return util.OpenAPISchema(r, "ErrorLevel", ErrorLevelValues) +} + +const ( + ErrorLevelWarning ErrorLevel = "warning" + ErrorLevelError ErrorLevel = "error" +) + +var ErrorLevelValues = []ErrorLevel{ + ErrorLevelWarning, + ErrorLevelError, +} + var ( ErrMessageValidationWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") ErrMessageValidationEmpty = xerrors.New("message must not be empty") @@ -49,6 +66,14 @@ type MessagePart interface { String() string } +func buildStringFromMessageParts(parts []MessagePart) string { + var sb strings.Builder + for _, part := range parts { + sb.WriteString(part.String()) + } + return sb.String() +} + // Conversation represents a conversation between a user and an agent. // It is intended as the primary interface for interacting with a session. // Implementations must support the following capabilities: @@ -63,6 +88,7 @@ type Conversation interface { Start(context.Context) Status() ConversationStatus Text() string + SaveState() error } // Emitter receives conversation state updates. @@ -70,11 +96,18 @@ type Emitter interface { EmitMessages([]ConversationMessage) EmitStatus(ConversationStatus) EmitScreen(string) + EmitError(message string, level ErrorLevel) } type ConversationMessage struct { - Id int - Message string - Role ConversationRole - Time time.Time + Id int `json:"id"` + Message string `json:"message"` + Role ConversationRole `json:"role"` + Time time.Time `json:"time"` +} + +type StatePersistenceConfig struct { + StateFile string + LoadState bool + SaveState bool } diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 27283775..37e5c374 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -2,9 +2,11 @@ package screentracker import ( "context" + "encoding/json" "fmt" "log/slog" - "strings" + "os" + "path/filepath" "sync" "time" @@ -26,6 +28,25 @@ type MessagePartText struct { Hidden bool } +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + InitialPromptSent bool `json:"initial_prompt_sent"` +} + +// LoadStateStatus represents the state of loading persisted conversation state. +type LoadStateStatus int + +const ( + // LoadStatePending indicates state loading has not been attempted yet. + LoadStatePending LoadStateStatus = iota + // LoadStateSucceeded indicates state was successfully loaded. + LoadStateSucceeded + // LoadStateFailed indicates state loading was attempted but failed. + LoadStateFailed +) + var _ MessagePart = &MessagePartText{} func (p MessagePartText) Do(writer AgentIO) error { @@ -67,8 +88,9 @@ type PTYConversationConfig struct { // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready - InitialPrompt []MessagePart - Logger *slog.Logger + InitialPrompt []MessagePart + Logger *slog.Logger + StatePersistenceConfig StatePersistenceConfig } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -107,9 +129,17 @@ type PTYConversation struct { stableSignal chan struct{} // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message toolCallMessageSet map[string]bool + // dirty tracks whether the conversation state has changed since the last save + dirty bool + // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state + userSentMessageAfterLoadState bool + // loadStateStatus tracks the status of loading conversation state from file. + loadStateStatus LoadStateStatus // initialPromptReady is set to true when ReadyForInitialPrompt returns true. // Checked inline in the snapshot loop on each tick. initialPromptReady bool + // initialPromptSent is set to true when the initial prompt has been enqueued to the outbound queue. + initialPromptSent bool } var _ Conversation = &PTYConversation{} @@ -119,6 +149,7 @@ type noopEmitter struct{} func (noopEmitter) EmitMessages([]ConversationMessage) {} func (noopEmitter) EmitStatus(ConversationStatus) {} func (noopEmitter) EmitScreen(string) {} +func (noopEmitter) EmitError(_ string, _ ErrorLevel) {} func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation { if cfg.Clock == nil { @@ -140,13 +171,12 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PT Time: cfg.Clock.Now(), }, }, - outboundQueue: make(chan outboundMessage, 1), - stableSignal: make(chan struct{}, 1), - toolCallMessageSet: make(map[string]bool), - } - // If we have an initial prompt, enqueue it - if len(cfg.InitialPrompt) > 0 { - c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil} + outboundQueue: make(chan outboundMessage, 1), + stableSignal: make(chan struct{}, 1), + toolCallMessageSet: make(map[string]bool), + dirty: false, + userSentMessageAfterLoadState: false, + loadStateStatus: LoadStatePending, } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } @@ -169,6 +199,23 @@ func (c *PTYConversation) Start(ctx context.Context) { if !c.initialPromptReady && c.cfg.ReadyForInitialPrompt(screen) { c.initialPromptReady = true } + + if c.initialPromptReady && c.loadStateStatus == LoadStatePending && c.cfg.StatePersistenceConfig.LoadState { + if err := c.loadStateLocked(); err != nil { + c.cfg.Logger.Error("Failed to load state", "error", err) + c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), ErrorLevelWarning) + c.loadStateStatus = LoadStateFailed + } else { + c.loadStateStatus = LoadStateSucceeded + } + } + + if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent { + c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil} + c.initialPromptSent = true + c.dirty = true + } + if c.initialPromptReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() { select { case c.stableSignal <- struct{}{}: @@ -245,6 +292,9 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } + if c.loadStateStatus == LoadStateSucceeded && !c.userSentMessageAfterLoadState && len(c.messages) > 0 { + agentMessage = c.messages[len(c.messages)-1].Message + } if c.cfg.FormatToolCall != nil { agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) } @@ -274,6 +324,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1] = conversationMessage } c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + + c.dirty = true } // caller MUST hold c.lock @@ -288,11 +340,7 @@ func (c *PTYConversation) snapshotLocked(screen string) { func (c *PTYConversation) Send(messageParts ...MessagePart) error { // Validate message content before enqueueing - var sb strings.Builder - for _, part := range messageParts { - sb.WriteString(part.String()) - } - message := sb.String() + message := buildStringFromMessageParts(messageParts) if message != msgfmt.TrimWhitespace(message) { return ErrMessageValidationWhitespace } @@ -316,11 +364,7 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { // around the parts that access shared state, but releases it during // writeStabilize to avoid blocking the snapshot loop. func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...MessagePart) error { - var sb strings.Builder - for _, part := range messageParts { - sb.WriteString(part.String()) - } - message := sb.String() + message := buildStringFromMessageParts(messageParts) c.lock.Lock() screenBeforeMessage := c.cfg.AgentIO.ReadScreen() @@ -350,6 +394,8 @@ func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...Messa Role: ConversationRoleUser, Time: now, }) + c.userSentMessageAfterLoadState = true + c.lock.Unlock() return nil } @@ -497,3 +543,144 @@ func (c *PTYConversation) Text() string { } return snapshots[len(snapshots)-1].screen } + +func (c *PTYConversation) SaveState() error { + c.lock.Lock() + defer c.lock.Unlock() + + stateFile := c.cfg.StatePersistenceConfig.StateFile + saveState := c.cfg.StatePersistenceConfig.SaveState + + if !saveState { + c.cfg.Logger.Info("State persistence is disabled") + return nil + } + + // Skip if not dirty + if !c.dirty { + c.cfg.Logger.Info("Skipping state save: no changes since last save") + return nil + } + + conversation := c.messagesLocked() + + // Serialize initial prompt from message parts + var initialPromptStr string + if len(c.cfg.InitialPrompt) > 0 { + initialPromptStr = buildStringFromMessageParts(c.cfg.InitialPrompt) + } + + // Create directory if it doesn't exist + dir := filepath.Dir(stateFile) + if err := os.MkdirAll(dir, 0o700); err != nil { + return xerrors.Errorf("failed to create state directory: %w", err) + } + + // Use atomic write: write to temp file, then rename to target path + tempFile := stateFile + ".tmp" + f, err := os.OpenFile(tempFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return xerrors.Errorf("failed to create temp state file: %w", err) + } + + // Clean up temp file on error (before successful rename) + var renamed bool + defer func() { + if !renamed { + if removeErr := os.Remove(tempFile); removeErr != nil && !os.IsNotExist(removeErr) { + c.cfg.Logger.Warn("Failed to clean up temp state file", "path", tempFile, "err", removeErr) + } + } + }() + + // Encode directly to file to avoid loading entire JSON into memory + encoder := json.NewEncoder(f) + if err := encoder.Encode(AgentState{ + Version: 1, + Messages: conversation, + InitialPrompt: initialPromptStr, + InitialPromptSent: c.initialPromptSent, + }); err != nil { + _ = f.Close() + return xerrors.Errorf("failed to encode state: %w", err) + } + + // Close file before rename + if err := f.Close(); err != nil { + return xerrors.Errorf("failed to close temp state file: %w", err) + } + + // Atomic rename + if err := os.Rename(tempFile, stateFile); err != nil { + return xerrors.Errorf("failed to rename state file: %w", err) + } + renamed = true + + // Clear dirty flag after successful save + c.dirty = false + + c.cfg.Logger.Info("State saved successfully", "path", stateFile) + + return nil +} + +// loadStateLocked loads the state, this method assumes that caller holds the Lock +func (c *PTYConversation) loadStateLocked() error { + stateFile := c.cfg.StatePersistenceConfig.StateFile + loadState := c.cfg.StatePersistenceConfig.LoadState + + if !loadState || c.loadStateStatus != LoadStatePending { + return nil + } + + // Check if file exists + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) + return nil + } + + // Open state file + f, err := os.Open(stateFile) + if err != nil { + return xerrors.Errorf("failed to open state file: %w", err) + } + defer func() { + if closeErr := f.Close(); closeErr != nil { + c.cfg.Logger.Warn("Failed to close state file", "path", stateFile, "err", closeErr) + } + }() + + var agentState AgentState + decoder := json.NewDecoder(f) + if err := decoder.Decode(&agentState); err != nil { + return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) + } + + // Validate version + if agentState.Version != 1 { + return xerrors.Errorf("unsupported state file version %d (expected 1)", agentState.Version) + } + + // Handle initial prompt restoration: + // - If a new initial prompt was provided via flags, check if it differs from the saved one. + // If different, mark as not sent (will be sent). If same, preserve sent status. + // - If no new prompt provided, restore the saved prompt and its sent status. + c.initialPromptSent = agentState.InitialPromptSent + if len(c.cfg.InitialPrompt) > 0 { + isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt + c.initialPromptSent = !isDifferent + } else if agentState.InitialPrompt != "" { + c.cfg.InitialPrompt = []MessagePart{MessagePartText{ + Content: agentState.InitialPrompt, + Alias: "", + Hidden: false, + }} + } + + c.messages = agentState.Messages + + c.dirty = false + + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) + return nil +} diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index 19b4511b..6342bd74 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,9 +2,11 @@ package screentracker_test import ( "context" + "encoding/json" "fmt" "io" "log/slog" + "os" "sync" "testing" "time" @@ -54,6 +56,7 @@ type testEmitter struct{} func (testEmitter) EmitMessages([]st.ConversationMessage) {} func (testEmitter) EmitStatus(st.ConversationStatus) {} func (testEmitter) EmitScreen(string) {} +func (testEmitter) EmitError(_ string, _ st.ErrorLevel) {} // advanceFor is a shorthand for advanceUntil with a time-based condition. func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { @@ -446,10 +449,500 @@ func TestMessages(t *testing.T) { }) } +func TestStatePersistence(t *testing.T) { + t.Run("SaveState creates file with correct structure", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + // Create temp directory for state file + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate some conversation + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Save state + err := c.SaveState() + require.NoError(t, err) + + // Read and verify the saved file + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.Equal(t, 1, agentState.Version) + assert.Equal(t, "test prompt", agentState.InitialPrompt) + assert.NotEmpty(t, agentState.Messages) + }) + + t.Run("SaveState creates valid JSON", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + fixedTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + mClock.Set(fixedTime) + + agent := &testAgent{screen: ""} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + // Change screen on each write so writeStabilize can detect changes + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + ReadyForInitialPrompt: func(message string) bool { + return message == "Hello! Ready to help." + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Step 1: Agent shows initial greeting + agent.setScreen("Hello! Ready to help.") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Step 2: Wait for initial prompt to be sent (uses advanceUntil like TestInitialPromptReadiness) + advanceUntil(ctx, t, mClock, func() bool { + return len(c.Messages()) >= 2 // greeting + user prompt + }) + + // Step 3: Agent shows response + agent.setScreen("Response to test prompt") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Save state - this creates state.json + err := c.SaveState() + require.NoError(t, err) + + // Read the saved state.json + actualData, err := os.ReadFile(stateFile) + require.NoError(t, err) + + // Read the expected golden file + expectedData, err := os.ReadFile("testdata/expected_saved_state.json") + require.NoError(t, err) + + // Parse both JSON files + var actualState, expectedState st.AgentState + err = json.Unmarshal(actualData, &actualState) + require.NoError(t, err) + err = json.Unmarshal(expectedData, &expectedState) + require.NoError(t, err) + + // Compare the state files field by field + assert.Equal(t, expectedState.Version, actualState.Version, "version should match") + assert.Equal(t, expectedState.InitialPrompt, actualState.InitialPrompt, "initial_prompt should match") + assert.Equal(t, expectedState.InitialPromptSent, actualState.InitialPromptSent, "initial_prompt_sent should match") + assert.Equal(t, len(expectedState.Messages), len(actualState.Messages), "message count should match") + + // Compare each message + for i := range expectedState.Messages { + if i >= len(actualState.Messages) { + break + } + assert.Equal(t, expectedState.Messages[i].Id, actualState.Messages[i].Id, "message %d: id should match", i) + assert.Equal(t, expectedState.Messages[i].Message, actualState.Messages[i].Message, "message %d: message should match", i) + assert.Equal(t, expectedState.Messages[i].Role, actualState.Messages[i].Role, "message %d: role should match", i) + } + }) + + t.Run("SaveState skips when not configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + err := c.SaveState() + require.NoError(t, err) + + // File should not be created + _, err = os.Stat(stateFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("SaveState honors dirty flag", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate conversation and save + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + err := c.SaveState() + require.NoError(t, err) + + // Get file modification time + info1, err := os.Stat(stateFile) + require.NoError(t, err) + modTime1 := info1.ModTime() + + // Save again without changes - file should not be modified + err = c.SaveState() + require.NoError(t, err) + + info2, err := os.Stat(stateFile) + require.NoError(t, err) + modTime2 := info2.ModTime() + + // File modification time should be the same (dirty flag prevents save) + assert.Equal(t, modTime1, modTime2) + }) + + t.Run("SaveState creates directory if not exists", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nested/deep/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + // Verify file and directory were created + _, err = os.Stat(stateFile) + assert.NoError(t, err) + }) + + t.Run("LoadState restores conversation from file", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with test data + testState := st.AgentState{ + Version: 1, + InitialPrompt: "restored prompt", + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message 1", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "user message 1", Role: st.ConversationRoleUser, Time: time.Now()}, + {Id: 2, Message: "agent message 2", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with LoadState enabled + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until agent is ready and state is loaded + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Verify messages were restored + messages := c.Messages() + assert.Len(t, messages, 3) + assert.Equal(t, "agent message 1", messages[0].Message) + assert.Equal(t, "user message 1", messages[1].Message) + // The last agent message may have adjustments from adjustScreenAfterStateLoad + assert.Contains(t, messages[2].Message, "agent message 2") + }) + + t.Run("LoadState handles missing file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nonexistent.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles empty file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/empty.json" + + // Create empty file + err := os.WriteFile(stateFile, []byte(""), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles corrupted JSON gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/corrupted.json" + + // Create corrupted JSON file + err := os.WriteFile(stateFile, []byte("{invalid json}"), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic - logs warning and continues + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState rejects unsupported version", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/unsupported_version.json" + + // Create state file with unsupported version + unsupportedState := map[string]interface{}{ + "version": 999, // Unsupported version + "messages": []interface{}{}, + "initial_prompt": "", + "initial_prompt_sent": false, + } + stateBytes, err := json.Marshal(unsupportedState) + require.NoError(t, err) + err = os.WriteFile(stateFile, stateBytes, 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic - logs error and continues with empty state + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message (version error causes fallback to empty state) + messages := c.Messages() + assert.Len(t, messages, 1) + }) +} + func TestInitialPromptReadiness(t *testing.T) { discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - t.Run("agent not ready - status remains changing", func(t *testing.T) { + t.Run("agent not ready - status is stable until agent becomes ready", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) t.Cleanup(cancel) mClock := quartz.NewMock(t) @@ -472,12 +965,12 @@ func TestInitialPromptReadiness(t *testing.T) { // Take a snapshot with "loading...". Threshold is 1 (stability 0 / interval 1s = 0 + 1 = 1). advanceFor(ctx, t, mClock, 1*time.Second) - // Even though screen is stable, status should be changing because - // the initial prompt is still in the outbound queue. - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + // Screen is stable and agent is not ready, so initial prompt hasn't been enqueued yet. + // Status should be stable. + assert.Equal(t, st.ConversationStatusStable, c.Status()) }) - t.Run("agent becomes ready - status stays changing until initial prompt sent", func(t *testing.T) { + t.Run("agent becomes ready - prompt enqueued and status changes to changing", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) t.Cleanup(cancel) mClock := quartz.NewMock(t) @@ -497,12 +990,11 @@ func TestInitialPromptReadiness(t *testing.T) { c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) - // Agent not ready initially. + // Agent not ready initially, status should be stable advanceFor(ctx, t, mClock, 1*time.Second) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + assert.Equal(t, st.ConversationStatusStable, c.Status()) - // Agent becomes ready, but status stays "changing" because the - // initial prompt is still in the outbound queue. + // Agent becomes ready, prompt gets enqueued, status becomes "changing" agent.setScreen("ready") advanceFor(ctx, t, mClock, 1*time.Second) assert.Equal(t, st.ConversationStatusChanging, c.Status()) @@ -533,12 +1025,12 @@ func TestInitialPromptReadiness(t *testing.T) { c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) - // Status is "changing" while waiting for readiness. + // Status is "stable" while waiting for readiness (prompt not yet enqueued). advanceFor(ctx, t, mClock, 1*time.Second) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + assert.Equal(t, st.ConversationStatusStable, c.Status()) - // Agent becomes ready. The readiness loop detects this, the snapshot - // loop sees queue + stable + ready and signals the send loop. + // Agent becomes ready. The snapshot loop detects this, enqueues the prompt, + // then sees queue + stable + ready and signals the send loop. // writeStabilize runs with onWrite changing the screen, so it completes. agent.setScreen("ready") // Drive clock until the initial prompt is sent (queue drains). @@ -611,3 +1103,385 @@ func TestInitialPromptReadiness(t *testing.T) { assert.Equal(t, st.ConversationStatusStable, c.Status()) }) } + +func TestInitialPromptSent(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("initialPromptSent is set when initial prompt is sent", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready and initial prompt is sent + agent.setScreen("ready") + advanceUntil(ctx, t, mClock, func() bool { + return len(c.Messages()) >= 2 + }) + + // Save state and verify initialPromptSent is persisted + agent.setScreen("response") + advanceFor(ctx, t, mClock, 2*time.Second) + + err := c.SaveState() + require.NoError(t, err) + + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.True(t, agentState.InitialPromptSent, "initialPromptSent should be true after initial prompt is sent") + assert.Equal(t, "test prompt", agentState.InitialPrompt) + }) + + t.Run("initialPromptSent prevents re-sending prompt after state load", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with initialPromptSent=true + testState := st.AgentState{ + Version: 1, + InitialPrompt: "test prompt", + InitialPromptSent: true, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "test prompt", Role: st.ConversationRoleUser, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with same initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + writeCount := 0 + agent.onWrite = func(data []byte) { + writeCount++ + agent.screen = "after_write" + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until ready and state is loaded + advanceFor(ctx, t, mClock, 500*time.Millisecond) + + // Verify the prompt was NOT re-sent (no writes occurred) + assert.Equal(t, 0, writeCount, "initial prompt should not be re-sent when already sent") + + // Messages should be restored from state (at minimum, the original 2) + messages := c.Messages() + assert.GreaterOrEqual(t, len(messages), 2, "messages should be restored from state") + // Verify the first two messages match what we saved + assert.Equal(t, "agent message", messages[0].Message) + assert.Equal(t, st.ConversationRoleAgent, messages[0].Role) + assert.Equal(t, "test prompt", messages[1].Message) + assert.Equal(t, st.ConversationRoleUser, messages[1].Role) + }) + + t.Run("new initial prompt is sent if different from saved prompt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with old prompt + testState := st.AgentState{ + Version: 1, + InitialPrompt: "old prompt", + InitialPromptSent: true, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with different initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "new prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready + agent.setScreen("ready") + + // Advance until the new prompt is sent + advanceUntil(ctx, t, mClock, func() bool { + msgs := c.Messages() + // Look for the new prompt in messages + for _, msg := range msgs { + if msg.Role == st.ConversationRoleUser && msg.Message == "new prompt" { + return true + } + } + return false + }) + + // Verify the new prompt was sent + messages := c.Messages() + found := false + for _, msg := range messages { + if msg.Role == st.ConversationRoleUser && msg.Message == "new prompt" { + found = true + break + } + } + assert.True(t, found, "new prompt should be sent when different from saved prompt") + }) + + t.Run("initialPromptSent not set when no initial prompt configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.False(t, agentState.InitialPromptSent, "initialPromptSent should be false when no initial prompt configured") + }) + + t.Run("restored prompt used when no new prompt provided", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with a prompt + testState := st.AgentState{ + Version: 1, + InitialPrompt: "saved prompt", + InitialPromptSent: false, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation without providing an initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready + agent.setScreen("ready") + + // Advance until the saved prompt is sent + advanceUntil(ctx, t, mClock, func() bool { + msgs := c.Messages() + for _, msg := range msgs { + if msg.Role == st.ConversationRoleUser && msg.Message == "saved prompt" { + return true + } + } + return false + }) + + // Verify the saved prompt was sent + messages := c.Messages() + found := false + for _, msg := range messages { + if msg.Role == st.ConversationRoleUser && msg.Message == "saved prompt" { + found = true + break + } + } + assert.True(t, found, "saved prompt should be sent when no new prompt provided") + }) + + t.Run("empty prompt from state is not restored", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create state file with empty prompt + emptyPromptState := st.AgentState{ + Version: 1, + Messages: []st.ConversationMessage{}, + InitialPrompt: "", // Empty prompt + InitialPromptSent: false, + } + stateBytes, err := json.Marshal(emptyPromptState) + require.NoError(t, err) + err = os.WriteFile(stateFile, stateBytes, 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: discardLogger, + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready + agent.setScreen("ready") + + // Advance time to ensure any prompt would be sent + advanceFor(ctx, t, mClock, 500*time.Millisecond) + + // Verify no prompt was sent (should only have the initial screen message) + messages := c.Messages() + for _, msg := range messages { + if msg.Role == st.ConversationRoleUser { + t.Errorf("Unexpected user message sent: %q (empty prompt should not be restored)", msg.Message) + } + } + }) +} diff --git a/lib/screentracker/testdata/expected_saved_state.json b/lib/screentracker/testdata/expected_saved_state.json new file mode 100644 index 00000000..fb41d16b --- /dev/null +++ b/lib/screentracker/testdata/expected_saved_state.json @@ -0,0 +1 @@ +{"version":1,"messages":[{"id":0,"message":"Hello! Ready to help.","role":"agent","time":"2025-01-01T00:00:00.5Z"},{"id":1,"message":"test prompt","role":"user","time":"2025-01-01T00:00:00.5Z"},{"id":2,"message":"Response to test prompt","role":"agent","time":"2025-01-01T00:00:01.9Z"}],"initial_prompt":"test prompt","initial_prompt_sent":true} diff --git a/lib/termexec/termexec.go b/lib/termexec/termexec.go index edad9b13..05403690 100644 --- a/lib/termexec/termexec.go +++ b/lib/termexec/termexec.go @@ -163,7 +163,7 @@ func (p *Process) Close(logger *slog.Logger, timeout time.Duration) error { case err := <-exited: var pathErr *os.SyscallError // ECHILD is expected if the process has already exited - if err != nil && !(errors.As(err, &pathErr) && pathErr.Err == syscall.ECHILD) { + if err != nil && !(errors.As(err, &pathErr) && errors.Is(pathErr.Err, syscall.ECHILD)) { exitErr = xerrors.Errorf("process exited with error: %w", err) } } diff --git a/openapi.json b/openapi.json index dda817cc..e77c6d4f 100644 --- a/openapi.json +++ b/openapi.json @@ -19,6 +19,30 @@ "title": "ConversationRole", "type": "string" }, + "ErrorBody": { + "additionalProperties": false, + "properties": { + "level": { + "$ref": "#/components/schemas/ErrorLevel", + "description": "Error level" + }, + "message": { + "description": "Error message", + "type": "string" + }, + "time": { + "description": "Timestamp when the error occurred", + "format": "date-time", + "type": "string" + } + }, + "required": [ + "level", + "message", + "time" + ], + "type": "object" + }, "ErrorDetail": { "additionalProperties": false, "properties": { @@ -36,6 +60,15 @@ }, "type": "object" }, + "ErrorLevel": { + "enum": [ + "error", + "warning" + ], + "example": "warning", + "title": "ErrorLevel", + "type": "string" + }, "ErrorModel": { "additionalProperties": false, "properties": { @@ -326,6 +359,32 @@ "description": "Each oneOf object in the array represents one possible Server Sent Events (SSE) message, serialized as UTF-8 text according to the SSE specification.", "items": { "oneOf": [ + { + "properties": { + "data": { + "$ref": "#/components/schemas/ErrorBody" + }, + "event": { + "const": "agent_error", + "description": "The event name.", + "type": "string" + }, + "id": { + "description": "The event ID.", + "type": "integer" + }, + "retry": { + "description": "The retry time in milliseconds.", + "type": "integer" + } + }, + "required": [ + "data", + "event" + ], + "title": "Event agent_error", + "type": "object" + }, { "properties": { "data": {