Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions e2e/acp_echo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//go:build ignore

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/signal"
	"strings"
	"time"

	acp "github.com/coder/acp-go-sdk"
)

// ScriptEntry is one step of the JSON test script that drives the echo agent.
// A script file is a JSON array of these objects.
type ScriptEntry struct {
	// ExpectMessage is the prompt text this step expects, read from the
	// "expectMessage" key.
	ExpectMessage string `json:"expectMessage"`
	// ThinkDurationMS is a think time in milliseconds, read from the
	// "thinkDurationMS" key.
	ThinkDurationMS int64 `json:"thinkDurationMS"`
	// ResponseMessage is the text the agent replies with for this step, read
	// from the "responseMessage" key.
	ResponseMessage string `json:"responseMessage"`
}

// acpEchoAgent implements the ACP Agent interface for testing.
type acpEchoAgent struct {
script []ScriptEntry
scriptIndex int
conn *acp.AgentSideConnection
sessionID acp.SessionId
}

var _ acp.Agent = (*acpEchoAgent)(nil)

// main loads the script named by the single command-line argument, serves the
// echo agent over stdin/stdout, and blocks until the connection is done.
// It exits with status 1 on a usage error, an unreadable/unparsable script, or
// an empty script, and with status 0 when interrupted.
func main() {
	args := os.Args
	if len(args) != 2 {
		fmt.Fprintln(os.Stderr, "Usage: acp_echo <script.json>")
		os.Exit(1)
	}

	steps, err := loadScript(args[1])
	switch {
	case err != nil:
		fmt.Fprintf(os.Stderr, "Error loading script: %v\n", err)
		os.Exit(1)
	case len(steps) == 0:
		fmt.Fprintln(os.Stderr, "Script is empty")
		os.Exit(1)
	}

	// Exit cleanly on an interrupt signal instead of waiting for the peer.
	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt)
	go func() {
		<-interrupts
		os.Exit(0)
	}()

	// The agent writes to stdout and reads from stdin; it needs a reference
	// to its own connection so Prompt can send session updates.
	echo := &acpEchoAgent{script: steps}
	echo.conn = acp.NewAgentSideConnection(echo, os.Stdout, os.Stdin)

	<-echo.conn.Done()
}

// Initialize answers the initialize request with the SDK's protocol version
// number and an empty capability set. The request itself is ignored.
func (a *acpEchoAgent) Initialize(_ context.Context, _ acp.InitializeRequest) (acp.InitializeResponse, error) {
	var resp acp.InitializeResponse
	resp.ProtocolVersion = acp.ProtocolVersionNumber
	resp.AgentCapabilities = acp.AgentCapabilities{}
	return resp, nil
}

// Authenticate accepts every authentication request unconditionally and
// returns an empty response.
func (a *acpEchoAgent) Authenticate(_ context.Context, _ acp.AuthenticateRequest) (acp.AuthenticateResponse, error) {
	var resp acp.AuthenticateResponse
	return resp, nil
}

// Cancel receives a cancel notification and takes no action on it.
// It always returns nil.
func (a *acpEchoAgent) Cancel(_ context.Context, _ acp.CancelNotification) error {
	return nil
}

// NewSession records the fixed session ID "test-session" on the agent and
// returns it to the client. The request contents are ignored.
func (a *acpEchoAgent) NewSession(_ context.Context, _ acp.NewSessionRequest) (acp.NewSessionResponse, error) {
	const id acp.SessionId = "test-session"
	a.sessionID = id
	return acp.NewSessionResponse{SessionId: id}, nil
}

// Prompt handles one prompt turn by replaying the next script entry.
//
// The prompt text is the first text block in params.Prompt, with surrounding
// whitespace trimmed. It is compared against the entry's ExpectMessage (also
// trimmed); an empty ExpectMessage matches any prompt. On a mismatch an error
// is returned and the script position does not advance.
//
// On a match the entry is consumed, the agent waits for the entry's
// ThinkDurationMS (if positive), sends ResponseMessage to the client as a
// session update, and ends the turn with StopReasonEndTurn. If ctx is
// cancelled during the wait, the turn ends with StopReasonCancelled and no
// response is sent.
//
// Once the script is exhausted, every further prompt ends the turn
// immediately without sending anything.
func (a *acpEchoAgent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.PromptResponse, error) {
	// Only the first text block of the prompt is considered.
	var promptText string
	for _, block := range params.Prompt {
		if block.Text != nil {
			promptText = block.Text.Text
			break
		}
	}
	promptText = strings.TrimSpace(promptText)

	// Script exhausted: end the turn without emitting a response.
	if a.scriptIndex >= len(a.script) {
		return acp.PromptResponse{
			StopReason: acp.StopReasonEndTurn,
		}, nil
	}

	entry := a.script[a.scriptIndex]
	expected := strings.TrimSpace(entry.ExpectMessage)

	// Empty ExpectMessage matches any prompt.
	if expected != "" && expected != promptText {
		return acp.PromptResponse{}, fmt.Errorf("expected message %q but got %q", expected, promptText)
	}

	a.scriptIndex++

	// Honor the scripted think time, but stop waiting as soon as the request
	// context is cancelled so the turn is never held open needlessly.
	if entry.ThinkDurationMS > 0 {
		timer := time.NewTimer(time.Duration(entry.ThinkDurationMS) * time.Millisecond)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-ctx.Done():
			return acp.PromptResponse{
				StopReason: acp.StopReasonCancelled,
			}, nil
		}
	}

	// Deliver the scripted response via a session update on the session the
	// prompt was addressed to.
	if err := a.conn.SessionUpdate(ctx, acp.SessionNotification{
		SessionId: params.SessionId,
		Update:    acp.UpdateAgentMessageText(entry.ResponseMessage),
	}); err != nil {
		return acp.PromptResponse{}, err
	}

	return acp.PromptResponse{
		StopReason: acp.StopReasonEndTurn,
	}, nil
}

// SetSessionMode accepts every mode change request without acting on it and
// returns an empty response.
func (a *acpEchoAgent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (acp.SetSessionModeResponse, error) {
	var resp acp.SetSessionModeResponse
	return resp, nil
}

// loadScript reads the file at scriptPath and decodes it as a JSON array of
// ScriptEntry values.
//
// It returns a wrapped error if the file cannot be read or if its contents
// are not valid JSON for that shape; in both cases the returned slice is nil.
func loadScript(scriptPath string) ([]ScriptEntry, error) {
	raw, readErr := os.ReadFile(scriptPath)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read script file: %w", readErr)
	}

	var entries []ScriptEntry
	parseErr := json.Unmarshal(raw, &entries)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse script JSON: %w", parseErr)
	}
	return entries, nil
}
28 changes: 28 additions & 0 deletions e2e/echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,34 @@ func TestE2E(t *testing.T) {
require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content))
require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content))
})

t.Run("acp_basic", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

script, apiClient := setup(ctx, t, &params{
cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) {
return binaryPath, []string{
"server",
fmt.Sprintf("--port=%d", serverPort),
"--experimental-acp",
"--", "go", "run", filepath.Join(cwd, "acp_echo.go"), scriptFilePath,
}
},
})
messageReq := agentapisdk.PostMessageParams{
Content: "This is a test message.",
Type: agentapisdk.MessageTypeUser,
}
_, err := apiClient.PostMessage(ctx, messageReq)
require.NoError(t, err, "Failed to send message via SDK")
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "post message"))
msgResp, err := apiClient.GetMessages(ctx)
require.NoError(t, err, "Failed to get messages via SDK")
require.Len(t, msgResp.Messages, 2)
require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[0].Content))
require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[1].Content))
})
}

type params struct {
Expand Down
6 changes: 6 additions & 0 deletions e2e/testdata/acp_basic.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[
{
"expectMessage": "This is a test message.",
"responseMessage": "Echo: This is a test message."
}
]
26 changes: 25 additions & 1 deletion x/acpio/acp_conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"slices"
"strings"
"sync"
"time"

st "github.com/coder/agentapi/lib/screentracker"
"github.com/coder/quartz"
Expand All @@ -31,7 +32,8 @@ type ACPConversation struct {
agentIO ChunkableAgentIO
messages []st.ConversationMessage
nextID int // monotonically increasing message ID
prompting bool // true while agent is processing
prompting bool // true while agent is processing
chunkReceived chan struct{} // signals that handleChunk has accumulated a chunk
streamingResponse strings.Builder
logger *slog.Logger
emitter st.Emitter
Expand Down Expand Up @@ -68,6 +70,7 @@ func NewACPConversation(ctx context.Context, agentIO ChunkableAgentIO, logger *s
initialPrompt: initialPrompt,
emitter: emitter,
clock: clock,
chunkReceived: make(chan struct{}, 1),
}
return c
}
Expand Down Expand Up @@ -202,13 +205,25 @@ func (c *ACPConversation) handleChunk(chunk string) {
screen := c.streamingResponse.String()
c.mu.Unlock()

// Signal that a chunk has been received (non-blocking; a pending signal is sufficient).
select {
case c.chunkReceived <- struct{}{}:
default:
}

c.emitter.EmitMessages(messages)
c.emitter.EmitStatus(status)
c.emitter.EmitScreen(screen)
}

// executePrompt runs the actual agent request and returns any error.
func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) error {
// Drain any stale signal before sending the prompt.
select {
case <-c.chunkReceived:
default:
}

var err error
for _, part := range messageParts {
if c.ctx.Err() != nil {
Expand All @@ -221,6 +236,15 @@ func (c *ACPConversation) executePrompt(messageParts []st.MessagePart) error {
}
}

// The ACP SDK dispatches SessionUpdate notifications as goroutines, so
// the chunk may arrive after conn.Prompt() returns. Wait up to 100ms.
timer := c.clock.NewTimer(100 * time.Millisecond)
select {
case <-c.chunkReceived:
case <-timer.C:
}
Comment on lines 239 to 245
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review: I'm not super happy with this but it seems to do the job.

timer.Stop()

c.mu.Lock()
c.prompting = false

Expand Down
32 changes: 27 additions & 5 deletions x/acpio/acp_conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ func Test_Send_AddsUserMessage(t *testing.T) {
assert.Equal(t, "hello", messages[0].Message)
assert.Equal(t, screentracker.ConversationRoleAgent, messages[1].Role)

// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
mock.SimulateChunks("hello response")

// Unblock the write to let Send complete
close(done)
require.NoError(t, <-errCh)
Expand Down Expand Up @@ -290,6 +293,9 @@ func Test_Send_RejectsDuplicateSend(t *testing.T) {
err := conv.Send(screentracker.MessagePartText{Content: "second"})
assert.ErrorIs(t, err, screentracker.ErrMessageValidationChanging)

// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
mock.SimulateChunks("first response")

// Unblock the write to let the test complete cleanly
close(done)
require.NoError(t, <-errCh)
Expand Down Expand Up @@ -318,6 +324,9 @@ func Test_Status_ChangesWhileProcessing(t *testing.T) {
// Status should be changing while processing
assert.Equal(t, screentracker.ConversationStatusChanging, conv.Status())

// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
mock.SimulateChunks("test response")

// Unblock the write
close(done)

Expand Down Expand Up @@ -428,6 +437,9 @@ func Test_InitialPrompt_SentOnStart(t *testing.T) {
assert.Equal(t, screentracker.ConversationRoleUser, messages[0].Role)
assert.Equal(t, "initial prompt", messages[0].Message)

// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
mock.SimulateChunks("initial response")

// Unblock the write to let the test complete cleanly
close(done)
}
Expand Down Expand Up @@ -457,6 +469,9 @@ func Test_Messages_AreCopied(t *testing.T) {
originalMessages := conv.Messages()
assert.Equal(t, "test", originalMessages[0].Message)

// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
mock.SimulateChunks("test response")

// Unblock the write to let Send complete
close(done)
require.NoError(t, <-errCh)
Expand Down Expand Up @@ -518,12 +533,15 @@ func Test_ErrorRemovesPartialMessage(t *testing.T) {
// Send a second message — IDs must not reuse the removed agent message's ID (1).
mock.mu.Lock()
mock.writeErr = nil
mock.writeBlock = nil
mock.writeStarted = nil
mock.mu.Unlock()

err := conv.Send(screentracker.MessagePartText{Content: "retry"})
require.NoError(t, err)
started2, done2 := mock.BlockWrite()
errCh2 := make(chan error, 1)
go func() { errCh2 <- conv.Send(screentracker.MessagePartText{Content: "retry"}) }()
<-started2
// Signal a chunk so executePrompt's timer wait doesn't hang on the mock clock.
mock.SimulateChunks("retry response")
close(done2)
require.NoError(t, <-errCh2)

messages = conv.Messages()
require.Len(t, messages, 3, "first user + second user + second agent")
Expand All @@ -548,6 +566,10 @@ func Test_LateChunkAfterError_DoesNotCorruptUserMessage(t *testing.T) {
mock.mu.Lock()
mock.writeErr = assert.AnError
mock.mu.Unlock()

// Signal a chunk before unblocking; the error path still waits on chunkReceived
// or the timer, so pre-signaling avoids a hang on the mock clock.
mock.SimulateChunks("unexpected chunk")
close(done)

require.ErrorIs(t, <-errCh, assert.AnError)
Expand Down