diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index c47ed4ff2..f75640640 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -7,6 +7,7 @@ + diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 3522ad60b..f314a519b 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -1329,6 +1329,7 @@ await Rpc.SessionFs.SetProviderAsync( _options.SessionFs.InitialCwd, _options.SessionFs.SessionStatePath, _options.SessionFs.Conventions, + _options.SessionFs.Capabilities, cancellationToken: cancellationToken); } @@ -1345,8 +1346,16 @@ private void ConfigureSessionFsHandlers(CopilotSession session, Func +/// Result of a SQLite query execution via . +/// Same shape as but without the Error field, +/// since providers signal errors by throwing. +/// +public class SessionFsSqliteResult +{ + /// Column names from the result set. + public IList Columns { get; set; } = []; + + /// For SELECT: rows as column-keyed dictionaries. For others: empty. + public IList> Rows { get; set; } = []; + + /// Number of rows affected (for INSERT/UPDATE/DELETE). + public long RowsAffected { get; set; } + + /// Last inserted row ID (for INSERT). + public long? LastInsertRowid { get; set; } +} + +/// +/// Optional interface for subclasses that support +/// per-session SQLite databases. Implement this interface on your provider to enable +/// the runtime's SQL tool to route queries through your SessionFs implementation. +/// +public interface ISessionFsSqliteProvider +{ + /// + /// Executes a SQLite query against the per-session database. + /// + /// How to execute: "exec" for DDL/multi-statement, "query" for SELECT, "run" for INSERT/UPDATE/DELETE. + /// SQL query to execute. + /// Optional named bind parameters. + /// Cancellation token. + /// The query result, or null for exec-type queries. + Task QueryAsync( + SessionFsSqliteQueryType queryType, + string query, + IDictionary? bindParams, + CancellationToken cancellationToken); + + /// + /// Checks whether the per-session SQLite database already exists, without creating it. + /// + /// Cancellation token. + Task ExistsAsync(CancellationToken cancellationToken); +} + /// /// Base class for session filesystem providers. Subclasses override the /// virtual methods and use normal C# patterns (return values, throw exceptions). /// The base class catches exceptions and converts them to /// results expected by the runtime. +/// To add SQLite support, also implement . /// public abstract class SessionFsProvider : ISessionFsHandler { @@ -75,24 +124,6 @@ public abstract class SessionFsProvider : ISessionFsHandler /// Cancellation token. protected abstract Task RenameAsync(string src, string dest, CancellationToken cancellationToken); - /// Executes a SQLite query against the per-session database. - /// Target session identifier. - /// SQL query to execute. - /// How to execute the query. - /// Optional named bind parameters. - /// Cancellation token. - protected abstract Task SqliteQueryAsync( - string sessionId, - string query, - SessionFsSqliteQueryType queryType, - IDictionary? parameters, - CancellationToken cancellationToken); - - /// Checks whether the per-session SQLite database already exists. - /// Target session identifier. - /// Cancellation token. - protected abstract Task SqliteExistsAsync(string sessionId, CancellationToken cancellationToken); - // ---- ISessionFsHandler implementation (private, handles error mapping) ---- async Task ISessionFsHandler.ReadFileAsync(SessionFsReadFileRequest request, CancellationToken cancellationToken) @@ -246,13 +277,27 @@ async Task ISessionFsHandler.ReaddirWithTypesAs async Task ISessionFsHandler.SqliteQueryAsync(SessionFsSqliteQueryRequest request, CancellationToken cancellationToken) { - ArgumentNullException.ThrowIfNull(request); + if (this is not ISessionFsSqliteProvider sqliteProvider) + { + return new SessionFsSqliteQueryResult + { + Error = new SessionFsError { Code = SessionFsErrorCode.UNKNOWN, Message = "SQLite is not supported by this provider." }, + }; + } try { - return await SqliteQueryAsync(request.SessionId, request.Query, request.QueryType, request.Params, cancellationToken).ConfigureAwait(false); + var result = await sqliteProvider.QueryAsync(request.QueryType, request.Query, request.Params, cancellationToken).ConfigureAwait(false); + + return new SessionFsSqliteQueryResult + { + Rows = result?.Rows ?? [], + Columns = result?.Columns ?? [], + RowsAffected = result?.RowsAffected ?? 0, + LastInsertRowid = result?.LastInsertRowid, + }; } - catch (Exception ex) when (!IsCriticalException(ex)) + catch (Exception ex) { return new SessionFsSqliteQueryResult { Error = ToSessionFsError(ex) }; } @@ -260,28 +305,22 @@ async Task ISessionFsHandler.SqliteQueryAsync(Sessio async Task ISessionFsHandler.SqliteExistsAsync(SessionFsSqliteExistsRequest request, CancellationToken cancellationToken) { - ArgumentNullException.ThrowIfNull(request); + if (this is not ISessionFsSqliteProvider sqliteProvider) + { + return new SessionFsSqliteExistsResult { Exists = false }; + } try { - var exists = await SqliteExistsAsync(request.SessionId, cancellationToken).ConfigureAwait(false); + var exists = await sqliteProvider.ExistsAsync(cancellationToken).ConfigureAwait(false); return new SessionFsSqliteExistsResult { Exists = exists }; } - catch (Exception ex) when (!IsCriticalException(ex)) + catch { return new SessionFsSqliteExistsResult { Exists = false }; } } - private static bool IsCriticalException(Exception ex) => - ex is OperationCanceledException - or OutOfMemoryException - or StackOverflowException - or AccessViolationException - or AppDomainUnloadedException - or BadImageFormatException - or CannotUnloadAppDomainException - or InvalidProgramException; private static SessionFsError ToSessionFsError(Exception ex) { diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 42747bcb1..18cab63d4 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -323,6 +323,13 @@ public sealed class SessionFsConfig /// Path conventions used by this filesystem provider. /// public required SessionFsSetProviderConventions Conventions { get; init; } + + /// + /// Optional capabilities that this filesystem provider supports. + /// When is true, + /// the runtime routes SQLite queries through the provider instead of using a local database file. + /// + public SessionFsSetProviderCapabilities? Capabilities { get; init; } } /// diff --git a/dotnet/test/E2E/InMemorySessionFsSqliteHandler.cs b/dotnet/test/E2E/InMemorySessionFsSqliteHandler.cs new file mode 100644 index 000000000..2bfc4b1d8 --- /dev/null +++ b/dotnet/test/E2E/InMemorySessionFsSqliteHandler.cs @@ -0,0 +1,205 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Collections.Concurrent; +using GitHub.Copilot.SDK; +using GitHub.Copilot.SDK.Rpc; +using Microsoft.Data.Sqlite; + +namespace GitHub.Copilot.SDK.Test.E2E; + +internal record SqliteCall(string SessionId, string QueryType, string Query); + +/// +/// A SessionFsProvider that implements with a real +/// in-memory SQLite database, and uses a simple +/// for file operations instead of touching disk. +/// +internal sealed class InMemorySessionFsSqliteHandler(string sessionId, List sqliteCalls) + : SessionFsProvider, ISessionFsSqliteProvider +{ + internal ConcurrentDictionary Files { get; } = new(); + private readonly ConcurrentDictionary _directories = new(); + private SqliteConnection? _db; + + private SqliteConnection GetOrCreateDb() + { + if (_db is not null) + { + return _db; + } + + _db = new SqliteConnection("Data Source=:memory:"); + _db.Open(); + using var cmd = _db.CreateCommand(); + cmd.CommandText = "PRAGMA busy_timeout = 5000"; + cmd.ExecuteNonQuery(); + return _db; + } + + // ---- ISessionFsSqliteProvider ---- + + public Task QueryAsync( + SessionFsSqliteQueryType queryType, + string query, + IDictionary? bindParams, + CancellationToken cancellationToken) + { + sqliteCalls.Add(new SqliteCall(sessionId, queryType.Value, query)); + + var trimmed = query.Trim(); + if (trimmed.Length == 0) + { + return Task.FromResult(null); + } + + var db = GetOrCreateDb(); + + if (queryType == SessionFsSqliteQueryType.Exec) + { + using var cmd = db.CreateCommand(); + cmd.CommandText = trimmed; + cmd.ExecuteNonQuery(); + return Task.FromResult(null); + } + + if (queryType == SessionFsSqliteQueryType.Query) + { + using var cmd = db.CreateCommand(); + cmd.CommandText = trimmed; + AddParams(cmd, bindParams); + + using var reader = cmd.ExecuteReader(); + var columns = new List(); + for (var i = 0; i < reader.FieldCount; i++) + { + columns.Add(reader.GetName(i)); + } + + var rows = new List>(); + while (reader.Read()) + { + var row = new Dictionary(reader.FieldCount); + for (var i = 0; i < reader.FieldCount; i++) + { + row[columns[i]] = reader.IsDBNull(i) ? null! : reader.GetValue(i); + } + rows.Add(row); + } + + return Task.FromResult(new SessionFsSqliteResult + { + Columns = columns, + Rows = rows, + RowsAffected = 0, + }); + } + + if (queryType == SessionFsSqliteQueryType.Run) + { + using var cmd = db.CreateCommand(); + cmd.CommandText = trimmed; + AddParams(cmd, bindParams); + + var rowsAffected = cmd.ExecuteNonQuery(); + + using var rowidCmd = db.CreateCommand(); + rowidCmd.CommandText = "SELECT last_insert_rowid()"; + var lastRowid = rowidCmd.ExecuteScalar(); + + return Task.FromResult(new SessionFsSqliteResult + { + Columns = [], + Rows = [], + RowsAffected = rowsAffected, + LastInsertRowid = lastRowid is long l ? l : null, + }); + } + + throw new ArgumentException($"Unknown queryType: {queryType}"); + } + + public Task ExistsAsync(CancellationToken cancellationToken) + { + return Task.FromResult(_db is not null); + } + + private static void AddParams(SqliteCommand cmd, IDictionary? bindParams) + { + if (bindParams is null) return; + foreach (var (key, value) in bindParams) + { + cmd.Parameters.AddWithValue(key.StartsWith(':') || key.StartsWith('$') || key.StartsWith('@') ? key : $":{key}", value ?? DBNull.Value); + } + } + + // ---- File operations (in-memory) ---- + + private string Resolve(string path) => $"/{sessionId}{(path.StartsWith('/') ? path : "/" + path)}"; + + protected override Task ReadFileAsync(string path, CancellationToken cancellationToken) + { + var key = Resolve(path); + if (!Files.TryGetValue(key, out var content)) + throw new FileNotFoundException($"File not found: {path}"); + return Task.FromResult(content); + } + + protected override Task WriteFileAsync(string path, string content, int? mode, CancellationToken cancellationToken) + { + Files[Resolve(path)] = content; + return Task.CompletedTask; + } + + protected override Task AppendFileAsync(string path, string content, int? mode, CancellationToken cancellationToken) + { + Files.AddOrUpdate(Resolve(path), content, (_, existing) => existing + content); + return Task.CompletedTask; + } + + protected override Task ExistsAsync(string path, CancellationToken cancellationToken) + { + var key = Resolve(path); + return Task.FromResult(Files.ContainsKey(key) || _directories.ContainsKey(key)); + } + + protected override Task StatAsync(string path, CancellationToken cancellationToken) + { + var key = Resolve(path); + if (Files.TryGetValue(key, out var fileContent)) + return Task.FromResult(new SessionFsStatResult { IsFile = true, IsDirectory = false, Size = fileContent.Length }); + if (_directories.ContainsKey(key)) + return Task.FromResult(new SessionFsStatResult { IsFile = false, IsDirectory = true, Size = 0 }); + throw new FileNotFoundException($"Path does not exist: {path}"); + } + + protected override Task MkdirAsync(string path, bool recursive, int? mode, CancellationToken cancellationToken) + { + _directories[Resolve(path)] = 0; + return Task.CompletedTask; + } + + protected override Task> ReaddirAsync(string path, CancellationToken cancellationToken) + => Task.FromResult>([]); + + protected override Task> ReaddirWithTypesAsync(string path, CancellationToken cancellationToken) + => Task.FromResult>([]); + + protected override Task RmAsync(string path, bool recursive, bool force, CancellationToken cancellationToken) + { + var key = Resolve(path); + Files.TryRemove(key, out _); + _directories.TryRemove(key, out _); + return Task.CompletedTask; + } + + protected override Task RenameAsync(string src, string dest, CancellationToken cancellationToken) + { + var srcKey = Resolve(src); + var destKey = Resolve(dest); + if (Files.TryRemove(srcKey, out var content)) + Files[destKey] = content; + return Task.CompletedTask; + } +} diff --git a/dotnet/test/E2E/SessionFsE2ETests.cs b/dotnet/test/E2E/SessionFsE2ETests.cs index c3c317b6f..7649c160c 100644 --- a/dotnet/test/E2E/SessionFsE2ETests.cs +++ b/dotnet/test/E2E/SessionFsE2ETests.cs @@ -576,7 +576,7 @@ private static string NormalizeRelativePathSegment(string segment, string paramN return normalized; } - private sealed class ThrowingSessionFsProvider(Exception exception) : SessionFsProvider + private sealed class ThrowingSessionFsProvider(Exception exception) : SessionFsProvider, ISessionFsSqliteProvider { protected override Task ReadFileAsync(string path, CancellationToken cancellationToken) => Task.FromException(exception); @@ -608,15 +608,10 @@ protected override Task RmAsync(string path, bool recursive, bool force, Cancell protected override Task RenameAsync(string src, string dest, CancellationToken cancellationToken) => Task.FromException(exception); - protected override Task SqliteQueryAsync( - string sessionId, - string query, - SessionFsSqliteQueryType queryType, - IDictionary? parameters, - CancellationToken cancellationToken) => - Task.FromException(exception); + Task ISessionFsSqliteProvider.QueryAsync(SessionFsSqliteQueryType queryType, string query, IDictionary? bindParams, CancellationToken cancellationToken) => + Task.FromException(exception); - protected override Task SqliteExistsAsync(string sessionId, CancellationToken cancellationToken) => + Task ISessionFsSqliteProvider.ExistsAsync(CancellationToken cancellationToken) => Task.FromException(exception); } @@ -751,18 +746,6 @@ protected override Task RenameAsync(string src, string dest, CancellationToken c return Task.CompletedTask; } - protected override Task SqliteQueryAsync( - string sessionId, - string query, - SessionFsSqliteQueryType queryType, - IDictionary? parameters, - CancellationToken cancellationToken) => - Task.FromException( - new NotSupportedException("SQLite session filesystem operations are not supported by this provider.")); - - protected override Task SqliteExistsAsync(string sessionId, CancellationToken cancellationToken) => - Task.FromResult(false); - private string ResolvePath(string sessionPath) { var normalizedSessionId = NormalizeRelativePathSegment(sessionId, nameof(sessionId)); diff --git a/dotnet/test/E2E/SessionFsSqliteE2ETests.cs b/dotnet/test/E2E/SessionFsSqliteE2ETests.cs new file mode 100644 index 000000000..f495fca3d --- /dev/null +++ b/dotnet/test/E2E/SessionFsSqliteE2ETests.cs @@ -0,0 +1,118 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.SDK.Rpc; +using GitHub.Copilot.SDK.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.SDK.Test.E2E; + +public class SessionFsSqliteE2ETests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "session_fs_sqlite", output) +{ + private static readonly SessionFsConfig SessionFsConfig = new() + { + InitialCwd = "/", + SessionStatePath = "/session-state", + Conventions = SessionFsSetProviderConventions.Posix, + Capabilities = new SessionFsSetProviderCapabilities { Sqlite = true }, + }; + + private readonly List _sqliteCalls = []; + + [Fact] + public async Task Should_Route_Sql_Queries_Through_The_Sessionfs_Sqlite_Handler() + { + await using var client = CreateSessionFsClient(); + + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => new InMemorySessionFsSqliteHandler(s.SessionId, _sqliteCalls), + }); + + var msg = await session.SendAndWaitAsync(new MessageOptions + { + Prompt = + "Use the sql tool to create a table called \"items\" with columns id (TEXT PRIMARY KEY) and name (TEXT). " + + "Then insert a row with id \"a1\" and name \"Widget\".", + }); + + var sessionCalls = _sqliteCalls.Where(c => c.SessionId == session.SessionId).ToList(); + Assert.NotEmpty(sessionCalls); + Assert.Contains(sessionCalls, c => c.Query.Contains("CREATE TABLE", StringComparison.OrdinalIgnoreCase)); + Assert.Contains(sessionCalls, c => c.Query.Contains("INSERT", StringComparison.OrdinalIgnoreCase)); + + Assert.Contains(sessionCalls, c => c.QueryType == "exec"); + Assert.Contains(sessionCalls, c => c.QueryType == "run"); + + await session.DisposeAsync(); + } + + [Fact] + public async Task Should_Allow_Subagents_To_Use_Sql_Tool_Via_Inherited_Sessionfs() + { + await using var client = CreateSessionFsClient(); + + var handler = (InMemorySessionFsSqliteHandler?)null; + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => + { + handler = new InMemorySessionFsSqliteHandler(s.SessionId, _sqliteCalls); + return handler; + }, + }); + + var events = new List(); + using var _ = session.On(evt => events.Add(evt)); + + await session.SendAndWaitAsync(new MessageOptions + { + Prompt = + "Use the task tool to ask a task agent to do the following: " + + "Use the sql tool to run this query: INSERT INTO todos (id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done')", + }); + + await session.DisposeAsync(); + + var sessionCalls = _sqliteCalls.Where(c => c.SessionId == session.SessionId).ToList(); + var insertCalls = sessionCalls.Where(c => c.Query.Contains("INSERT", StringComparison.OrdinalIgnoreCase)).ToList(); + Assert.NotEmpty(insertCalls); + + // Verify that the sql tool execution in events.jsonl came from the subagent (has agentId) + Assert.NotNull(handler); + var eventsKey = $"/{session.SessionId}/session-state/events.jsonl"; + await TestHelper.WaitForConditionAsync( + () => Task.FromResult(handler!.Files.ContainsKey(eventsKey)), + timeout: TimeSpan.FromSeconds(30), + timeoutMessage: "Timed out waiting for events.jsonl to be written."); + Assert.True(handler!.Files.TryGetValue(eventsKey, out var content)); + var lines = content.Split('\n', StringSplitOptions.RemoveEmptyEntries); + var sqlToolEvents = lines + .Select(line => System.Text.Json.JsonDocument.Parse(line)) + .Where(doc => + doc.RootElement.TryGetProperty("type", out var type) && type.GetString() == "tool.execution_start" && + doc.RootElement.TryGetProperty("data", out var data) && data.TryGetProperty("toolName", out var toolName) && toolName.GetString() == "sql") + .ToList(); + Assert.NotEmpty(sqlToolEvents); + Assert.All(sqlToolEvents, evt => + { + Assert.True(evt.RootElement.TryGetProperty("agentId", out var agentId)); + Assert.False(string.IsNullOrEmpty(agentId.GetString())); + }); + } + + private CopilotClient CreateSessionFsClient() + { + return Ctx.CreateClient( + useStdio: true, + options: new CopilotClientOptions + { + SessionFs = SessionFsConfig, + }); + } +} diff --git a/dotnet/test/GitHub.Copilot.SDK.Test.csproj b/dotnet/test/GitHub.Copilot.SDK.Test.csproj index 0eb5a626c..cdff9b014 100644 --- a/dotnet/test/GitHub.Copilot.SDK.Test.csproj +++ b/dotnet/test/GitHub.Copilot.SDK.Test.csproj @@ -16,6 +16,7 @@ + diff --git a/go/client.go b/go/client.go index 392ccd595..dcf793d5a 100644 --- a/go/client.go +++ b/go/client.go @@ -373,11 +373,18 @@ func (c *Client) Start(ctx context.Context) error { // If a session filesystem provider was configured, register it. if c.options.SessionFs != nil { - _, err := c.RPC.SessionFs.SetProvider(ctx, &rpc.SessionFsSetProviderRequest{ + req := &rpc.SessionFsSetProviderRequest{ InitialCwd: c.options.SessionFs.InitialCwd, SessionStatePath: c.options.SessionFs.SessionStatePath, Conventions: c.options.SessionFs.Conventions, - }) + } + if c.options.SessionFs.Capabilities != nil { + sqlite := c.options.SessionFs.Capabilities.Sqlite + req.Capabilities = &rpc.SessionFsSetProviderCapabilities{ + Sqlite: &sqlite, + } + } + _, err := c.RPC.SessionFs.SetProvider(ctx, req) if err != nil { killErr := c.killProcess() c.state = StateError @@ -737,7 +744,16 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses c.sessionsMux.Unlock() return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") } - session.clientSessionApis.SessionFs = newSessionFsAdapter(config.CreateSessionFsHandler(session)) + provider := config.CreateSessionFsHandler(session) + if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { + if _, ok := provider.(SessionFsSqliteProvider); !ok { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") + } + } + session.clientSessionApis.SessionFs = newSessionFsAdapter(provider) } result, err := c.client.Request("session.create", req) @@ -913,7 +929,16 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, c.sessionsMux.Unlock() return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") } - session.clientSessionApis.SessionFs = newSessionFsAdapter(config.CreateSessionFsHandler(session)) + provider := config.CreateSessionFsHandler(session) + if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { + if _, ok := provider.(SessionFsSqliteProvider); !ok { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") + } + } + session.clientSessionApis.SessionFs = newSessionFsAdapter(provider) } result, err := c.client.Request("session.resume", req) diff --git a/go/go.mod b/go/go.mod index ed06061a0..16114a0ab 100644 --- a/go/go.mod +++ b/go/go.mod @@ -10,6 +10,7 @@ require ( require ( github.com/google/uuid v1.6.0 go.opentelemetry.io/otel v1.35.0 + go.opentelemetry.io/otel/trace v1.35.0 ) require ( @@ -17,5 +18,4 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect - go.opentelemetry.io/otel/trace v1.35.0 // indirect ) diff --git a/go/internal/e2e/session_fs_e2e_test.go b/go/internal/e2e/session_fs_e2e_test.go index a23613c72..d56dc14a3 100644 --- a/go/internal/e2e/session_fs_e2e_test.go +++ b/go/internal/e2e/session_fs_e2e_test.go @@ -473,23 +473,6 @@ func (h *testSessionFsHandler) Rename(src string, dest string) error { return os.Rename(providerPath(h.root, h.sessionID, src), destPath) } -func (h *testSessionFsHandler) SqliteQuery(sessionID string, query string, queryType rpc.SessionFsSqliteQueryType, params map[string]any) (*rpc.SessionFsSqliteQueryResult, error) { - return &rpc.SessionFsSqliteQueryResult{ - Columns: []string{"sessionId", "query", "queryType", "answer"}, - Rows: []map[string]any{{ - "sessionId": sessionID, - "query": query, - "queryType": string(queryType), - "answer": params["answer"], - }}, - RowsAffected: 0, - }, nil -} - -func (h *testSessionFsHandler) SqliteExists(sessionID string) (bool, error) { - return sessionID == h.sessionID, nil -} - func providerPath(root string, sessionID string, path string) string { trimmed := strings.TrimPrefix(path, "/") if trimmed == "" { @@ -653,28 +636,6 @@ func TestSessionFsHandlerOperationsE2E(t *testing.T) { if _, err := handler.Stat("/workspace/nested/missing.txt"); err == nil || !os.IsNotExist(err) { t.Errorf("Expected os.ErrNotExist from Stat on missing file, got %v", err) } - - sqliteResult, err := handler.SqliteQuery(sessionID, "select :answer as answer", rpc.SessionFsSqliteQueryTypeQuery, map[string]any{"answer": 42}) - if err != nil { - t.Fatalf("SqliteQuery failed: %v", err) - } - if len(sqliteResult.Columns) != 4 || sqliteResult.Columns[3] != "answer" { - t.Errorf("Expected SQLite result columns to include answer, got %v", sqliteResult.Columns) - } - if len(sqliteResult.Rows) != 1 || sqliteResult.Rows[0]["answer"] != 42 { - t.Errorf("Expected SQLite result row to include answer=42, got %+v", sqliteResult.Rows) - } - if sqliteResult.RowsAffected != 0 { - t.Errorf("Expected RowsAffected=0, got %d", sqliteResult.RowsAffected) - } - - sqliteExists, err := handler.SqliteExists(sessionID) - if err != nil { - t.Fatalf("SqliteExists failed: %v", err) - } - if !sqliteExists { - t.Error("Expected SQLite database to exist for the handler session") - } } func sliceContains(slice []string, value string) bool { diff --git a/go/internal/e2e/session_fs_sqlite_e2e_test.go b/go/internal/e2e/session_fs_sqlite_e2e_test.go new file mode 100644 index 000000000..f73cf2e34 --- /dev/null +++ b/go/internal/e2e/session_fs_sqlite_e2e_test.go @@ -0,0 +1,414 @@ +package e2e + +import ( + "encoding/json" + "fmt" + "sort" + "strings" + "sync" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" + "github.com/github/copilot-sdk/go/rpc" +) + +type sqliteCall struct { + SessionID string + QueryType string + Query string +} + +// inMemorySqliteProvider is a SessionFsProvider backed by in-memory maps with a stub SQLite handler. +// The stub returns plausible canned responses based on query type rather than executing real SQL. +// This avoids pulling in a real SQLite dependency (which would force a go directive bump across +// all scenario go.mod files). +type inMemorySqliteProvider struct { + mu sync.Mutex + sessionID string + files map[string]string + dirs map[string]bool + hadQuery bool + sqliteCalls *[]sqliteCall +} + +func newInMemorySqliteProvider(sessionID string, calls *[]sqliteCall) *inMemorySqliteProvider { + return &inMemorySqliteProvider{ + sessionID: sessionID, + files: make(map[string]string), + dirs: map[string]bool{"/": true}, + sqliteCalls: calls, + } +} + +func (p *inMemorySqliteProvider) ensureParent(path string) { + parts := strings.Split(strings.TrimRight(path, "/"), "/") + for i := 1; i < len(parts); i++ { + p.dirs[strings.Join(parts[:i], "/")] = true + } +} + +func (p *inMemorySqliteProvider) ReadFile(path string) (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + content, ok := p.files[path] + if !ok { + return "", fmt.Errorf("file not found: %s", path) + } + return content, nil +} + +func (p *inMemorySqliteProvider) WriteFile(path string, content string, mode *int) error { + p.mu.Lock() + defer p.mu.Unlock() + p.ensureParent(path) + p.files[path] = content + return nil +} + +func (p *inMemorySqliteProvider) AppendFile(path string, content string, mode *int) error { + p.mu.Lock() + defer p.mu.Unlock() + p.ensureParent(path) + p.files[path] = p.files[path] + content + return nil +} + +func (p *inMemorySqliteProvider) Exists(path string) (bool, error) { + p.mu.Lock() + defer p.mu.Unlock() + _, isFile := p.files[path] + _, isDir := p.dirs[path] + return isFile || isDir, nil +} + +func (p *inMemorySqliteProvider) Stat(path string) (*copilot.SessionFsFileInfo, error) { + p.mu.Lock() + defer p.mu.Unlock() + now := time.Now().UTC() + if p.dirs[path] { + return &copilot.SessionFsFileInfo{ + IsFile: false, IsDirectory: true, Size: 0, Mtime: now, Birthtime: now, + }, nil + } + if content, ok := p.files[path]; ok { + return &copilot.SessionFsFileInfo{ + IsFile: true, IsDirectory: false, Size: int64(len(content)), Mtime: now, Birthtime: now, + }, nil + } + return nil, fmt.Errorf("not found: %s", path) +} + +func (p *inMemorySqliteProvider) Mkdir(path string, recursive bool, mode *int) error { + p.mu.Lock() + defer p.mu.Unlock() + if recursive { + parts := strings.Split(strings.TrimRight(path, "/"), "/") + for i := 1; i <= len(parts); i++ { + p.dirs[strings.Join(parts[:i], "/")] = true + } + } else { + p.dirs[path] = true + } + return nil +} + +func (p *inMemorySqliteProvider) Readdir(path string) ([]string, error) { + p.mu.Lock() + defer p.mu.Unlock() + prefix := strings.TrimRight(path, "/") + "/" + names := map[string]bool{} + for f := range p.files { + if strings.HasPrefix(f, prefix) { + rest := f[len(prefix):] + if rest != "" { + names[strings.SplitN(rest, "/", 2)[0]] = true + } + } + } + for d := range p.dirs { + if strings.HasPrefix(d, prefix) { + rest := d[len(prefix):] + if rest != "" { + names[strings.SplitN(rest, "/", 2)[0]] = true + } + } + } + result := make([]string, 0, len(names)) + for n := range names { + result = append(result, n) + } + sort.Strings(result) + return result, nil +} + +func (p *inMemorySqliteProvider) ReaddirWithTypes(path string) ([]rpc.SessionFsReaddirWithTypesEntry, error) { + p.mu.Lock() + defer p.mu.Unlock() + prefix := strings.TrimRight(path, "/") + "/" + entries := map[string]rpc.SessionFsReaddirWithTypesEntryType{} + for d := range p.dirs { + if strings.HasPrefix(d, prefix) { + rest := d[len(prefix):] + if rest != "" { + name := strings.SplitN(rest, "/", 2)[0] + entries[name] = rpc.SessionFsReaddirWithTypesEntryTypeDirectory + } + } + } + for f := range p.files { + if strings.HasPrefix(f, prefix) { + rest := f[len(prefix):] + if rest != "" { + name := strings.SplitN(rest, "/", 2)[0] + if _, exists := entries[name]; !exists { + entries[name] = rpc.SessionFsReaddirWithTypesEntryTypeFile + } + } + } + } + result := make([]rpc.SessionFsReaddirWithTypesEntry, 0, len(entries)) + for name, typ := range entries { + result = append(result, rpc.SessionFsReaddirWithTypesEntry{Name: name, Type: typ}) + } + sort.Slice(result, func(i, j int) bool { return result[i].Name < result[j].Name }) + return result, nil +} + +func (p *inMemorySqliteProvider) Rm(path string, recursive bool, force bool) error { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.files, path) + delete(p.dirs, path) + return nil +} + +func (p *inMemorySqliteProvider) Rename(src string, dest string) error { + p.mu.Lock() + defer p.mu.Unlock() + if content, ok := p.files[src]; ok { + p.ensureParent(dest) + p.files[dest] = content + delete(p.files, src) + } + return nil +} + +func (p *inMemorySqliteProvider) SqliteQuery(queryType rpc.SessionFsSqliteQueryType, query string, params map[string]any) (*copilot.SessionFsSqliteQueryResult, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.hadQuery = true + *p.sqliteCalls = append(*p.sqliteCalls, sqliteCall{ + SessionID: p.sessionID, + QueryType: string(queryType), + Query: query, + }) + + // Return canned results based on query type. The agent doesn't know or care + // whether a real SQLite database is behind this — it just receives SQL tool + // results. These stubs return plausible responses so the agent can proceed + // normally without pulling in a real SQLite dependency. + upper := strings.ToUpper(strings.TrimSpace(query)) + switch queryType { + case rpc.SessionFsSqliteQueryTypeExec: + return &copilot.SessionFsSqliteQueryResult{Columns: []string{}, Rows: []map[string]any{}}, nil + case rpc.SessionFsSqliteQueryTypeRun: + lastID := int64(1) + return &copilot.SessionFsSqliteQueryResult{ + Columns: []string{}, + Rows: []map[string]any{}, + RowsAffected: 1, + LastInsertRowid: &lastID, + }, nil + case rpc.SessionFsSqliteQueryTypeQuery: + if strings.Contains(upper, "SELECT") { + return &copilot.SessionFsSqliteQueryResult{ + Columns: []string{"id", "name"}, + Rows: []map[string]any{{"id": "a1", "name": "Widget"}}, + }, nil + } + return &copilot.SessionFsSqliteQueryResult{Columns: []string{}, Rows: []map[string]any{}}, nil + } + return &copilot.SessionFsSqliteQueryResult{Columns: []string{}, Rows: []map[string]any{}}, nil +} + +func (p *inMemorySqliteProvider) SqliteExists() (bool, error) { + p.mu.Lock() + defer p.mu.Unlock() + return p.hadQuery, nil +} + +func TestSessionFsSqliteE2E(t *testing.T) { + ctx := testharness.NewTestContext(t) + sessionStatePath := createSessionStatePath(t) + sessionFsConfig := &copilot.SessionFsConfig{ + InitialCwd: "/", + SessionStatePath: sessionStatePath, + Conventions: rpc.SessionFsSetProviderConventionsPosix, + Capabilities: &copilot.SessionFsCapabilities{Sqlite: true}, + } + + var sqliteCalls []sqliteCall + var providers sync.Map + + createSessionFsHandler := func(session *copilot.Session) copilot.SessionFsProvider { + p := newInMemorySqliteProvider(session.SessionID, &sqliteCalls) + providers.Store(session.SessionID, p) + return p + } + + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.SessionFs = sessionFsConfig + }) + t.Cleanup(func() { client.ForceStop() }) + + t.Run("should route sql queries through the sessionfs sqlite handler", func(t *testing.T) { + ctx.ConfigureForTest(t) + sqliteCalls = nil + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: createSessionFsHandler, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + msg, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ + Prompt: `Use the sql tool to create a table called "items" with columns id (TEXT PRIMARY KEY) and name (TEXT). ` + + `Then insert a row with id "a1" and name "Widget".`, + }) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + _ = msg + + // Verify sqlite handler was called + sessionCalls := filterCalls(sqliteCalls, session.SessionID) + if len(sessionCalls) == 0 { + t.Fatal("Expected sqlite handler to be called") + } + assertCallContains(t, sessionCalls, "CREATE TABLE") + assertCallContains(t, sessionCalls, "INSERT") + + // Verify queryType is set correctly + assertQueryType(t, sessionCalls, "exec") + assertQueryType(t, sessionCalls, "run") + + if err := session.Disconnect(); err != nil { + t.Fatalf("Failed to disconnect: %v", err) + } + }) + + t.Run("should allow subagents to use sql tool via inherited sessionfs", func(t *testing.T) { + ctx.ConfigureForTest(t) + sqliteCalls = nil + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: createSessionFsHandler, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ + Prompt: "Use the task tool to ask a task agent to do the following: " + + "Use the sql tool to run this query: INSERT INTO todos " + + "(id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done')", + }) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + if err := session.Disconnect(); err != nil { + t.Fatalf("Failed to disconnect: %v", err) + } + + // Verify INSERT calls were routed + sessionCalls := filterCalls(sqliteCalls, session.SessionID) + insertCalls := filterByQuery(sessionCalls, "INSERT") + if len(insertCalls) == 0 { + t.Fatal("Expected INSERT calls from subagent") + } + + // Read events.jsonl from in-memory FS + val, ok := providers.Load(session.SessionID) + if !ok { + t.Fatal("Provider not found for session") + } + provider := val.(*inMemorySqliteProvider) + eventsPath := sessionStatePath + "/events.jsonl" + content, err := provider.ReadFile(eventsPath) + if err != nil { + t.Fatalf("Failed to read events.jsonl: %v", err) + } + lines := strings.Split(strings.TrimSpace(content), "\n") + var sqlToolEvents []map[string]any + for _, line := range lines { + if line == "" { + continue + } + var event map[string]any + if err := json.Unmarshal([]byte(line), &event); err != nil { + continue + } + if event["type"] == "tool.execution_start" { + if data, ok := event["data"].(map[string]any); ok { + if data["toolName"] == "sql" { + sqlToolEvents = append(sqlToolEvents, event) + } + } + } + } + if len(sqlToolEvents) == 0 { + t.Fatal("Expected sql tool events in events.jsonl") + } + for _, e := range sqlToolEvents { + if e["agentId"] == nil || e["agentId"] == "" { + t.Error("Expected agentId on sql tool event") + } + } + }) +} + +func filterCalls(calls []sqliteCall, sessionID string) []sqliteCall { + var result []sqliteCall + for _, c := range calls { + if c.SessionID == sessionID { + result = append(result, c) + } + } + return result +} + +func filterByQuery(calls []sqliteCall, keyword string) []sqliteCall { + var result []sqliteCall + for _, c := range calls { + if strings.Contains(strings.ToUpper(c.Query), keyword) { + result = append(result, c) + } + } + return result +} + +func assertCallContains(t *testing.T, calls []sqliteCall, keyword string) { + t.Helper() + for _, c := range calls { + if strings.Contains(strings.ToUpper(c.Query), keyword) { + return + } + } + t.Errorf("Expected a call with query containing %q", keyword) +} + +func assertQueryType(t *testing.T, calls []sqliteCall, queryType string) { + t.Helper() + for _, c := range calls { + if c.QueryType == queryType { + return + } + } + t.Errorf("Expected a call with queryType %q", queryType) +} diff --git a/go/session_fs_provider.go b/go/session_fs_provider.go index 6051a5e4a..3a6f297f8 100644 --- a/go/session_fs_provider.go +++ b/go/session_fs_provider.go @@ -15,6 +15,8 @@ import ( // SessionFsProvider is the interface that SDK users implement to provide // a session filesystem. Methods use idiomatic Go error handling: return an // error for failures (the adapter maps os.ErrNotExist → ENOENT automatically). +// +// To add SQLite support, also implement [SessionFsSqliteProvider] on the same type. type SessionFsProvider interface { // ReadFile reads the full content of a file. Return os.ErrNotExist (or wrap it) // if the file does not exist. @@ -44,10 +46,30 @@ type SessionFsProvider interface { Rm(path string, recursive bool, force bool) error // Rename moves/renames a file or directory. Rename(src string, dest string) error +} + +// SessionFsSqliteProvider is an optional interface that a [SessionFsProvider] +// may also implement to support per-session SQLite databases. The adapter +// checks for this interface at runtime using a type assertion. If the +// provider does not implement it, SQLite requests return an "unsupported" error. +// +// Providers are already session-scoped (created per session by the factory), +// so these methods do not take a session ID parameter. +type SessionFsSqliteProvider interface { // SqliteQuery executes a SQLite query against the provider's per-session database. - SqliteQuery(sessionID string, query string, queryType rpc.SessionFsSqliteQueryType, params map[string]any) (*rpc.SessionFsSqliteQueryResult, error) + SqliteQuery(queryType rpc.SessionFsSqliteQueryType, query string, params map[string]any) (*SessionFsSqliteQueryResult, error) // SqliteExists checks whether the provider has a SQLite database for the session. - SqliteExists(sessionID string) (bool, error) + SqliteExists() (bool, error) +} + +// SessionFsSqliteQueryResult holds the result of a SQLite query execution. +// Same shape as the generated RPC type but without the Error field, +// since providers signal errors by returning a Go error. +type SessionFsSqliteQueryResult struct { + Columns []string `json:"columns"` + Rows []map[string]any `json:"rows"` + RowsAffected int64 `json:"rowsAffected"` + LastInsertRowid *int64 `json:"lastInsertRowid,omitempty"` } // SessionFsFileInfo holds file metadata returned by SessionFsProvider.Stat. @@ -169,7 +191,17 @@ func (a *sessionFsAdapter) Rename(request *rpc.SessionFsRenameRequest) (*rpc.Ses } func (a *sessionFsAdapter) SqliteQuery(request *rpc.SessionFsSqliteQueryRequest) (*rpc.SessionFsSqliteQueryResult, error) { - result, err := a.provider.SqliteQuery(request.SessionID, request.Query, request.QueryType, request.Params) + sp, ok := a.provider.(SessionFsSqliteProvider) + if !ok { + msg := "SQLite is not supported by this session filesystem provider" + return &rpc.SessionFsSqliteQueryResult{ + Columns: []string{}, + Rows: []map[string]any{}, + RowsAffected: 0, + Error: &rpc.SessionFsError{Code: rpc.SessionFsErrorCodeUNKNOWN, Message: &msg}, + }, nil + } + result, err := sp.SqliteQuery(request.QueryType, request.Query, request.Params) if err != nil { return &rpc.SessionFsSqliteQueryResult{ Columns: []string{}, @@ -178,11 +210,32 @@ func (a *sessionFsAdapter) SqliteQuery(request *rpc.SessionFsSqliteQueryRequest) Error: toSessionFsError(err), }, nil } - return result, nil + if result == nil { + return &rpc.SessionFsSqliteQueryResult{ + Columns: []string{}, + Rows: []map[string]any{}, + RowsAffected: 0, + }, nil + } + var wireRowid *float64 + if result.LastInsertRowid != nil { + f := float64(*result.LastInsertRowid) + wireRowid = &f + } + return &rpc.SessionFsSqliteQueryResult{ + Columns: result.Columns, + Rows: result.Rows, + RowsAffected: result.RowsAffected, + LastInsertRowid: wireRowid, + }, nil } func (a *sessionFsAdapter) SqliteExists(request *rpc.SessionFsSqliteExistsRequest) (*rpc.SessionFsSqliteExistsResult, error) { - exists, err := a.provider.SqliteExists(request.SessionID) + sp, ok := a.provider.(SessionFsSqliteProvider) + if !ok { + return &rpc.SessionFsSqliteExistsResult{Exists: false}, nil + } + exists, err := sp.SqliteExists() if err != nil { return &rpc.SessionFsSqliteExistsResult{Exists: false}, nil } diff --git a/go/types.go b/go/types.go index cc30ff9d7..f568d1325 100644 --- a/go/types.go +++ b/go/types.go @@ -567,6 +567,12 @@ type InfiniteSessionConfig struct { BufferExhaustionThreshold *float64 `json:"bufferExhaustionThreshold,omitempty"` } +// SessionFsCapabilities declares optional provider capabilities. +type SessionFsCapabilities struct { + // Sqlite indicates whether the provider supports SQLite query/exists operations. + Sqlite bool +} + // SessionFsConfig configures a custom session filesystem provider. type SessionFsConfig struct { // InitialCwd is the initial working directory for sessions. @@ -576,6 +582,8 @@ type SessionFsConfig struct { SessionStatePath string // Conventions identifies the path conventions used by this filesystem provider. Conventions rpc.SessionFsSetProviderConventions + // Capabilities declares optional provider capabilities such as SQLite support. + Capabilities *SessionFsCapabilities } // SessionConfig configures a new session diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 42d838ad2..b7f474d1d 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -33,7 +33,7 @@ import { } from "./generated/rpc.js"; import { getSdkProtocolVersion } from "./sdkProtocolVersion.js"; import { CopilotSession, NO_RESULT_PERMISSION_V2_ERROR } from "./session.js"; -import { createSessionFsAdapter } from "./sessionFsProvider.js"; +import { createSessionFsAdapter, type SessionFsProvider } from "./sessionFsProvider.js"; import { getTraceContext } from "./telemetry.js"; import type { AutoModeSwitchRequest, @@ -450,6 +450,27 @@ export class CopilotClient { } } + private setupSessionFs( + session: CopilotSession, + config: { createSessionFsHandler?: (session: CopilotSession) => SessionFsProvider } + ): void { + if (!this.sessionFsConfig) { + return; + } + if (!config.createSessionFsHandler) { + throw new Error( + "createSessionFsHandler is required in session config when sessionFs is enabled in client options." + ); + } + const provider = config.createSessionFsHandler(session); + if (this.sessionFsConfig.capabilities?.sqlite && !provider.sqlite) { + throw new Error( + "SessionFsConfig declares capabilities.sqlite but the provider does not implement sqlite." + ); + } + session.clientSessionApis.sessionFs = createSessionFsAdapter(provider); + } + /** * Starts the CLI server and establishes a connection. * @@ -493,6 +514,7 @@ export class CopilotClient { initialCwd: this.sessionFsConfig.initialCwd, sessionStatePath: this.sessionFsConfig.sessionStatePath, conventions: this.sessionFsConfig.conventions, + capabilities: this.sessionFsConfig.capabilities, }); } @@ -772,17 +794,7 @@ export class CopilotClient { session.on(config.onEvent); } this.sessions.set(sessionId, session); - if (this.sessionFsConfig) { - if (config.createSessionFsHandler) { - session.clientSessionApis.sessionFs = createSessionFsAdapter( - config.createSessionFsHandler(session) - ); - } else { - throw new Error( - "createSessionFsHandler is required in session config when sessionFs is enabled in client options." - ); - } - } + this.setupSessionFs(session, config); try { const response = await this.connection!.sendRequest("session.create", { @@ -920,17 +932,7 @@ export class CopilotClient { session.on(config.onEvent); } this.sessions.set(sessionId, session); - if (this.sessionFsConfig) { - if (config.createSessionFsHandler) { - session.clientSessionApis.sessionFs = createSessionFsAdapter( - config.createSessionFsHandler(session) - ); - } else { - throw new Error( - "createSessionFsHandler is required in session config when sessionFs is enabled in client options." - ); - } - } + this.setupSessionFs(session, config); try { const response = await this.connection!.sendRequest("session.resume", { diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index b588aaf57..b4484987c 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -79,6 +79,9 @@ export type { SessionFsConfig, SessionFsProvider, SessionFsFileInfo, + SessionFsSqliteQueryResult, + SessionFsSqliteQueryType, + SessionFsSqliteProvider, SystemMessageAppendConfig, SystemMessageConfig, SystemMessageCustomizeConfig, diff --git a/nodejs/src/sessionFsProvider.ts b/nodejs/src/sessionFsProvider.ts index 589a30358..a2da12307 100644 --- a/nodejs/src/sessionFsProvider.ts +++ b/nodejs/src/sessionFsProvider.ts @@ -7,10 +7,12 @@ import type { SessionFsError, SessionFsStatResult, SessionFsReaddirWithTypesEntry, - SessionFsSqliteQueryResult, + SessionFsSqliteQueryResult as GeneratedSqliteQueryResult, SessionFsSqliteQueryType, } from "./generated/rpc.js"; +export type { SessionFsSqliteQueryType }; + /** * File metadata returned by {@link SessionFsProvider.stat}. * Same shape as the generated {@link SessionFsStatResult} but without the @@ -18,6 +20,37 @@ import type { */ export type SessionFsFileInfo = Omit; +/** + * Result of a SQLite query execution via {@link SessionFsSqliteProvider.query}. + * Same shape as the generated {@link GeneratedSqliteQueryResult} but without the + * `error` field, since providers signal errors by throwing. + */ +export type SessionFsSqliteQueryResult = Omit; + +/** + * SQLite operations for the per-session database. + * Implementers provide query execution and existence checking. + */ +export interface SessionFsSqliteProvider { + /** + * Execute a SQLite query against the per-session database. + * + * @param queryType - How to execute: `"exec"` for DDL/multi-statement, `"query"` for SELECT, `"run"` for INSERT/UPDATE/DELETE. + * @param query - SQL query to execute. + * @param params - Optional named bind parameters. + */ + query( + queryType: SessionFsSqliteQueryType, + query: string, + params?: Record + ): Promise; + + /** + * Check whether the per-session database already exists, without creating it. + */ + exists(): Promise; +} + /** * Interface for session filesystem providers. Implementers use idiomatic * TypeScript patterns: throw on error, return values directly. Use @@ -58,16 +91,8 @@ export interface SessionFsProvider { /** Renames/moves a file or directory. */ rename(src: string, dest: string): Promise; - /** Executes a SQLite query against the provider's per-session database. */ - sqliteQuery( - sessionId: string, - query: string, - queryType: SessionFsSqliteQueryType, - params?: Record - ): Promise; - - /** Checks whether the provider has a SQLite database for the session. */ - sqliteExists(sessionId: string): Promise; + /** Per-session SQLite database operations. Optional — omit if the provider does not support SQLite. */ + sqlite?: SessionFsSqliteProvider; } /** @@ -162,24 +187,23 @@ export function createSessionFsAdapter(provider: SessionFsProvider): SessionFsHa return toSessionFsError(err); } }, - sqliteQuery: async ({ sessionId, query, queryType, params }) => { - try { - return await provider.sqliteQuery(sessionId, query, queryType, params); - } catch (err) { - return { - columns: [], - rows: [], - rowsAffected: 0, - error: toSessionFsError(err), - }; + // Unlike the FS methods above, SQLite methods let errors propagate to the JSON-RPC layer + // rather than catching and mapping via toSessionFsError. The FS error mapping is specifically + // for translating Node.js errno codes (e.g., ENOENT) into SessionFsError, which isn't + // meaningful for SQL errors. Letting exceptions propagate preserves the original error + // message in the JSON-RPC error response. + sqliteQuery: async ({ queryType, query, params: bindParams }) => { + if (!provider.sqlite) { + throw new Error("SQLite is not supported by this provider"); } + const result = await provider.sqlite.query(queryType, query, bindParams); + return result ?? { rows: [], columns: [], rowsAffected: 0 }; }, - sqliteExists: async ({ sessionId }) => { - try { - return { exists: await provider.sqliteExists(sessionId) }; - } catch { - return { exists: false }; + sqliteExists: async () => { + if (!provider.sqlite) { + throw new Error("SQLite is not supported by this provider"); } + return { exists: await provider.sqlite.exists() }; }, }; } diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index f18e18ac1..a8e3bdfe5 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -16,6 +16,9 @@ export type SessionEvent = GeneratedSessionEvent; export type { SessionFsProvider } from "./sessionFsProvider.js"; export { createSessionFsAdapter } from "./sessionFsProvider.js"; export type { SessionFsFileInfo } from "./sessionFsProvider.js"; +export type { SessionFsSqliteQueryResult } from "./sessionFsProvider.js"; +export type { SessionFsSqliteQueryType } from "./sessionFsProvider.js"; +export type { SessionFsSqliteProvider } from "./sessionFsProvider.js"; /** * Options for creating a CopilotClient @@ -1793,6 +1796,20 @@ export interface SessionFsConfig { * Path conventions used by this filesystem provider. */ conventions: "windows" | "posix"; + + /** + * Optional capabilities declared by this provider. + * The runtime uses these to determine which features are available. + */ + capabilities?: { + /** + * Whether this provider supports SQLite query/exists operations. + * When false or omitted, the runtime will not offer SQL tools or + * todo tracking for sessions using this provider. + * @default false + */ + sqlite?: boolean; + }; } /** diff --git a/nodejs/test/e2e/harness/CapiProxy.ts b/nodejs/test/e2e/harness/CapiProxy.ts index a5fffc37a..a6232587e 100644 --- a/nodejs/test/e2e/harness/CapiProxy.ts +++ b/nodejs/test/e2e/harness/CapiProxy.ts @@ -6,6 +6,7 @@ import { CopilotUserResponse, ParsedHttpExchange, } from "../../../../test/harness/replayingCapiProxy"; +import { isCI } from "./sdkTestContext"; const HARNESS_SERVER_PATH = resolve(__dirname, "../../../../test/harness/server.ts"); const NO_PROXY = "127.0.0.1,localhost,::1"; @@ -92,9 +93,13 @@ export class CapiProxy { CURL_CA_BUNDLE: this.startupInfo.caFilePath, GIT_SSL_CAINFO: this.startupInfo.caFilePath, GH_TOKEN: "", - GITHUB_TOKEN: "", GH_ENTERPRISE_TOKEN: "", GITHUB_ENTERPRISE_TOKEN: "", + + // In CI we never want it to make real network requests, so there should be no need for auth + // But when running locally you have to be able to generate snapshots and that does require real auth, + // so you should set GH_TOKEN and we need to pass it through into the test app. + ...(isCI ? { GITHUB_TOKEN: "" } : undefined), }; } diff --git a/nodejs/test/e2e/harness/sdkTestContext.ts b/nodejs/test/e2e/harness/sdkTestContext.ts index af9642a50..970cfcbb9 100644 --- a/nodejs/test/e2e/harness/sdkTestContext.ts +++ b/nodejs/test/e2e/harness/sdkTestContext.ts @@ -11,7 +11,7 @@ import { fileURLToPath } from "url"; import { afterAll, afterEach, beforeEach, onTestFailed, TestContext } from "vitest"; import { CopilotClient, CopilotClientOptions } from "../../../src"; import { CapiProxy } from "./CapiProxy"; -import { retry, formatError } from "./sdkTestHelper"; +import { formatError, retry } from "./sdkTestHelper"; export const isCI = process.env.GITHUB_ACTIONS === "true"; export const DEFAULT_GITHUB_TOKEN = "fake-token-for-e2e-tests"; @@ -45,15 +45,19 @@ export async function createSdkTestContext({ }, analytics_tracking_id: "e2e-test-tracking-id", }); + const authTokenToUse = isCI + ? DEFAULT_GITHUB_TOKEN + : (process.env.GITHUB_TOKEN ?? DEFAULT_GITHUB_TOKEN); + const env = { ...process.env, ...openAiEndpoint.getProxyEnv(), COPILOT_API_URL: proxyUrl, COPILOT_HOME: copilotHomeDir, - COPILOT_SDK_AUTH_TOKEN: DEFAULT_GITHUB_TOKEN, + COPILOT_SDK_AUTH_TOKEN: "", GH_CONFIG_DIR: homeDir, - GH_TOKEN: DEFAULT_GITHUB_TOKEN, - GITHUB_TOKEN: DEFAULT_GITHUB_TOKEN, + GH_TOKEN: "", + GITHUB_TOKEN: "", // TODO: I'm not convinced the SDK should default to using whatever config you happen to have in your homedir. // The SDK config should be independent of the regular CLI app. Likewise it shouldn't mix sessions from the @@ -67,7 +71,7 @@ export async function createSdkTestContext({ env, logLevel: logLevel || "error", cliPath: process.env.COPILOT_CLI_PATH, - gitHubToken: DEFAULT_GITHUB_TOKEN, + gitHubToken: authTokenToUse, useStdio: useStdio, ...copilotClientOptions, }); diff --git a/nodejs/test/e2e/session_fs.e2e.test.ts b/nodejs/test/e2e/session_fs.e2e.test.ts index 4181152aa..3987012b1 100644 --- a/nodejs/test/e2e/session_fs.e2e.test.ts +++ b/nodejs/test/e2e/session_fs.e2e.test.ts @@ -45,22 +45,36 @@ describe("Session Fs", async () => { copilotClientOptions: { sessionFs: sessionFsConfig }, }); - it("should route file operations through the session fs provider", async () => { - const session = await client.createSession({ - onPermissionRequest: approveAll, - createSessionFsHandler, - }); + it( + "should route file operations through the session fs provider", + { timeout: 60000 }, + async () => { + const session = await client.createSession({ + onPermissionRequest: approveAll, + createSessionFsHandler, + }); + + const errors: SessionEvent[] = []; + session.on((event) => { + if (event.type === "session.error") { + errors.push(event); + } + }); - const msg = await session.sendAndWait({ prompt: "What is 100 + 200?" }); - expect(msg?.data.content).toContain("300"); - await session.disconnect(); + const msg = await session.sendAndWait({ prompt: "What is 100 + 200?" }); + expect(msg?.data.content).toContain("300"); + await session.disconnect(); - const buf = await provider.readFile( - p(session.sessionId, `${sessionStatePath}/events.jsonl`) - ); - const content = buf.toString("utf8"); - expect(content).toContain("300"); - }); + const buf = await provider.readFile( + p(session.sessionId, `${sessionStatePath}/events.jsonl`) + ); + const content = buf.toString("utf8"); + expect(content).toContain("300"); + + // No sqlite capabilities declared — verify no errors from missing sqlite + expect(errors).toHaveLength(0); + } + ); it("should load session data from fs provider on resume", async () => { const session1 = await client.createSession({ @@ -269,15 +283,24 @@ describe("Session Fs Adapter", () => { async rename(src: string, dest: string): Promise { await provider.rename(src, dest); }, - async sqliteQuery(sessionId, query, queryType, params) { - return { - columns: ["sessionId", "query", "queryType", "answer"], - rows: [{ sessionId, query, queryType, answer: params?.answer }], - rowsAffected: 0, - }; - }, - async sqliteExists(sessionId) { - return sessionId === "handler-session"; + sqlite: { + async query(queryType, query, params) { + return { + columns: ["sessionId", "query", "queryType", "answer"], + rows: [ + { + sessionId: "handler-session", + query, + queryType, + answer: params?.answer, + }, + ], + rowsAffected: 0, + }; + }, + async exists() { + return true; + }, }, }; const handler = createSessionFsAdapter(userProvider); @@ -405,11 +428,13 @@ describe("Session Fs Adapter", () => { rename: async () => { throw enoent; }, - sqliteQuery: async () => { - throw enoent; - }, - sqliteExists: async () => { - throw enoent; + sqlite: { + query: async () => { + throw enoent; + }, + exists: async () => { + throw enoent; + }, }, }; @@ -445,18 +470,18 @@ describe("Session Fs Adapter", () => { assertEnoent((await handler.readdirWithTypes({ path: "missing-dir" } as never)).error); assertEnoent(await handler.rm({ path: "missing.txt" } as never)); assertEnoent(await handler.rename({ src: "missing.txt", dest: "dest.txt" } as never)); - const sqliteQuery = await handler.sqliteQuery({ - sessionId: "throw-session", - query: "select 1", - queryType: "query", - }); - assertEnoent(sqliteQuery.error); - expect(sqliteQuery.columns).toEqual([]); - expect(sqliteQuery.rows).toEqual([]); - expect(sqliteQuery.rowsAffected).toBe(0); - const sqliteExistsResult = await handler.sqliteExists({ sessionId: "throw-session" }); - expect(sqliteExistsResult.exists).toBe(false); + // sqlite methods let errors propagate (no try/catch wrapping) + await expect( + handler.sqliteQuery({ + sessionId: "throw-session", + query: "select 1", + queryType: "query", + }) + ).rejects.toThrow("missing"); + await expect(handler.sqliteExists({ sessionId: "throw-session" })).rejects.toThrow( + "missing" + ); // Non-ENOENT errors map to UNKNOWN. const unknown: SessionFsProvider = { @@ -555,15 +580,17 @@ function createTestSessionFsHandler( async rename(src: string, dest: string): Promise { await provider.rename(sp(src), sp(dest)); }, - async sqliteQuery() { - return { - columns: [], - rows: [], - rowsAffected: 0, - }; - }, - async sqliteExists(sessionId) { - return sessionId === session.sessionId; + sqlite: { + async query() { + return { + columns: [], + rows: [], + rowsAffected: 0, + }; + }, + async exists() { + return true; + }, }, }; } diff --git a/nodejs/test/e2e/session_fs_sqlite.e2e.test.ts b/nodejs/test/e2e/session_fs_sqlite.e2e.test.ts new file mode 100644 index 000000000..cde6ee8cb --- /dev/null +++ b/nodejs/test/e2e/session_fs_sqlite.e2e.test.ts @@ -0,0 +1,252 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { DatabaseSync } from "node:sqlite"; +import { MemoryProvider, VirtualProvider } from "@platformatic/vfs"; +import { mkdtempSync, realpathSync } from "fs"; +import { tmpdir } from "os"; +import { join } from "path"; +import { describe, expect, it } from "vitest"; +import type { SessionFsReaddirWithTypesEntry } from "../../src/generated/rpc.js"; +import { + approveAll, + CopilotSession, + SessionEvent, + type SessionFsConfig, + type SessionFsProvider, + type SessionFsFileInfo, + type SessionFsSqliteQueryResult, + type SessionFsSqliteQueryType, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const sessionStatePath = + process.platform === "win32" + ? "/session-state" + : join( + realpathSync(mkdtempSync(join(tmpdir(), "copilot-sqlite-state-"))), + "session-state" + ).replace(/\\/g, "/"); + +const sessionFsConfig: SessionFsConfig = { + initialCwd: "/", + sessionStatePath, + conventions: "posix", + capabilities: { sqlite: true }, +}; + +describe("Session Fs SQLite", async () => { + const provider = new MemoryProvider(); + /** Track which queries were received, per session */ + const sqliteCalls: { sessionId: string; queryType: string; query: string }[] = []; + /** Per-session SQLite databases, keyed by session ID. + * Stored at describe scope so the database survives if the CLI + * re-creates the handler (e.g., on reconnect). */ + const sessionDbs = new Map(); + + const createSessionFsHandler = (session: CopilotSession) => + createTestSessionFsHandlerWithSqlite(session, provider, sqliteCalls, sessionDbs); + + // Helpers to build session-namespaced paths for direct provider assertions + const p = (sessionId: string, path: string) => + `/${sessionId}${path.startsWith("/") ? path : "/" + path}`; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { sessionFs: sessionFsConfig }, + }); + + it( + "should route SQL queries through the sessionFs sqlite handler", + { timeout: 60000 }, + async () => { + const session = await client.createSession({ + onPermissionRequest: approveAll, + createSessionFsHandler, + }); + + // Ask the agent to create a table and insert data using the SQL tool + await session.sendAndWait({ + prompt: + 'Use the sql tool to create a table called "items" with columns id (TEXT PRIMARY KEY) and name (TEXT). ' + + 'Then insert a row with id "a1" and name "Widget".', + }); + + // Verify the sqlite handler was called with the right operations + const sessionCalls = sqliteCalls.filter((c) => c.sessionId === session.sessionId); + expect(sessionCalls.length).toBeGreaterThan(0); + expect(sessionCalls.some((c) => c.query.toUpperCase().includes("CREATE TABLE"))).toBe( + true + ); + expect(sessionCalls.some((c) => c.query.toUpperCase().includes("INSERT"))).toBe(true); + + // Verify queryType is set correctly + expect(sessionCalls.some((c) => c.queryType === "exec")).toBe(true); + expect(sessionCalls.some((c) => c.queryType === "run")).toBe(true); + + await session.disconnect(); + } + ); + + it( + "should allow subagents to use SQL tool via inherited sessionFs", + { timeout: 60000 }, + async () => { + const session = await client.createSession({ + onPermissionRequest: approveAll, + createSessionFsHandler, + }); + + const events: SessionEvent[] = []; + session.on((event) => { + events.push(event); + }); + + // Ask the agent to use the task tool to spawn a subagent that uses SQL + await session.sendAndWait({ + prompt: + "Use the task tool to ask a task agent to do the following: " + + "Use the sql tool to run this query: INSERT INTO todos (id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done')", + }); + + await session.disconnect(); + + // Verify that the subagent's SQL queries were routed through the sessionFs sqlite handler + const sessionCalls = sqliteCalls.filter((c) => c.sessionId === session.sessionId); + const insertCalls = sessionCalls.filter((c) => + c.query.toUpperCase().includes("INSERT") + ); + expect(insertCalls.length).toBeGreaterThan(0); + + // Verify that the sql tool execution in events.jsonl came from the subagent (has agentId) + const buf = await provider.readFile( + p(session.sessionId, `${sessionStatePath}/events.jsonl`) + ); + const content = buf.toString("utf8"); + const lines = content.split("\n").filter(Boolean); + const parsed = lines.map((line) => JSON.parse(line)); + const sqlToolEvents = parsed.filter( + (e: { type?: string; data?: { toolName?: string } }) => + e.type === "tool.execution_start" && e.data?.toolName === "sql" + ); + expect(sqlToolEvents.length).toBeGreaterThan(0); + expect(sqlToolEvents.every((e: { agentId?: string }) => !!e.agentId)).toBe(true); + } + ); +}); + +function createTestSessionFsHandlerWithSqlite( + session: CopilotSession, + provider: VirtualProvider, + sqliteCalls: { sessionId: string; queryType: string; query: string }[], + sessionDbs: Map +): SessionFsProvider { + const sp = (path: string) => `/${session.sessionId}${path.startsWith("/") ? path : "/" + path}`; + + function getOrCreateDb(): DatabaseSync { + let db = sessionDbs.get(session.sessionId); + if (!db) { + db = new DatabaseSync(":memory:"); + db.exec("PRAGMA busy_timeout = 5000"); + sessionDbs.set(session.sessionId, db); + } + return db; + } + + return { + async readFile(path: string): Promise { + return (await provider.readFile(sp(path), "utf8")) as string; + }, + async writeFile(path: string, content: string): Promise { + await provider.writeFile(sp(path), content); + }, + async appendFile(path: string, content: string): Promise { + await provider.appendFile(sp(path), content); + }, + async exists(path: string): Promise { + return provider.exists(sp(path)); + }, + async stat(path: string): Promise { + const st = await provider.stat(sp(path)); + return { + isFile: st.isFile(), + isDirectory: st.isDirectory(), + size: st.size, + mtime: new Date(st.mtimeMs).toISOString(), + birthtime: new Date(st.birthtimeMs).toISOString(), + }; + }, + async mkdir(path: string, recursive: boolean, mode?: number): Promise { + await provider.mkdir(sp(path), { recursive, mode }); + }, + async readdir(path: string): Promise { + return (await provider.readdir(sp(path))) as string[]; + }, + async readdirWithTypes(path: string): Promise { + const names = (await provider.readdir(sp(path))) as string[]; + return Promise.all( + names.map(async (name) => { + const st = await provider.stat(sp(`${path}/${name}`)); + return { + name, + type: st.isDirectory() ? ("directory" as const) : ("file" as const), + }; + }) + ); + }, + async rm(path: string): Promise { + await provider.unlink(sp(path)); + }, + async rename(src: string, dest: string): Promise { + await provider.rename(sp(src), sp(dest)); + }, + sqlite: { + async query( + queryType: SessionFsSqliteQueryType, + query: string, + params?: Record + ): Promise { + sqliteCalls.push({ sessionId: session.sessionId, queryType, query }); + + const database = getOrCreateDb(); + const trimmed = query.trim(); + if (trimmed.length === 0) { + return undefined; + } + + switch (queryType) { + case "exec": + database.exec(trimmed); + return undefined; + + case "query": { + const stmt = database.prepare(trimmed); + const rows = (params ? stmt.all(params) : stmt.all()) as Record< + string, + unknown + >[]; + const columns = rows.length > 0 ? Object.keys(rows[0]) : []; + return { rows, columns, rowsAffected: 0 }; + } + + case "run": { + const stmt = database.prepare(trimmed); + const result = params ? stmt.run(params) : stmt.run(); + return { + rows: [], + columns: [], + rowsAffected: Number(result.changes), + lastInsertRowid: + result.lastInsertRowid !== undefined + ? Number(result.lastInsertRowid) + : undefined, + }; + } + } + }, + async exists(): Promise { + return sessionDbs.has(session.sessionId); + }, + }, + }; +} diff --git a/nodejs/test/session_fs_adapter.test.ts b/nodejs/test/session_fs_adapter.test.ts index 7bed1f8c1..fb62d9904 100644 --- a/nodejs/test/session_fs_adapter.test.ts +++ b/nodejs/test/session_fs_adapter.test.ts @@ -59,17 +59,17 @@ describe("SessionFsAdapter", () => { async rename(src, dest) { await memoryProvider.rename(sp(src), sp(dest)); }, - async sqliteQuery(actualSessionId, query, queryType, params) { - return { - columns: ["sessionId", "query", "queryType", "answer"], - rows: [ - { sessionId: actualSessionId, query, queryType, answer: params?.answer }, - ], - rowsAffected: 0, - }; - }, - async sqliteExists(actualSessionId) { - return actualSessionId === sessionId; + sqlite: { + async query(queryType, query, params) { + return { + columns: ["sessionId", "query", "queryType", "answer"], + rows: [{ sessionId, query, queryType, answer: params?.answer }], + rowsAffected: 0, + }; + }, + async exists() { + return true; + }, }, }; @@ -203,8 +203,10 @@ describe("SessionFsAdapter", () => { readdirWithTypes: () => Promise.reject(error), rm: () => Promise.reject(error), rename: () => Promise.reject(error), - sqliteQuery: () => Promise.reject(error), - sqliteExists: () => Promise.reject(error), + sqlite: { + query: () => Promise.reject(error), + exists: () => Promise.reject(error), + }, }; } @@ -235,16 +237,12 @@ describe("SessionFsAdapter", () => { assertEnoent((await handler.readdirWithTypes({ sessionId, path: "missing-dir" })).error); assertEnoent(await handler.rm({ sessionId, path: "missing.txt" })); assertEnoent(await handler.rename({ sessionId, src: "missing.txt", dest: "dest.txt" })); - const sqliteQuery = await handler.sqliteQuery({ - sessionId, - query: "select 1", - queryType: "query", - }); - assertEnoent(sqliteQuery.error); - expect(sqliteQuery.columns).toEqual([]); - expect(sqliteQuery.rows).toEqual([]); - expect(sqliteQuery.rowsAffected).toBe(0); - expect((await handler.sqliteExists({ sessionId })).exists).toBe(false); + + // sqlite methods let errors propagate (no try/catch wrapping) + await expect( + handler.sqliteQuery({ sessionId, query: "select 1", queryType: "query" }) + ).rejects.toThrow("missing file"); + await expect(handler.sqliteExists({ sessionId })).rejects.toThrow("missing file"); const unknownProvider = createSessionFsAdapter(makeThrowingProvider(makeError("bad path"))); const unknownError = await unknownProvider.writeFile({ diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 58973ea83..c7a37ea0b 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -34,6 +34,7 @@ InputOptions, ProviderConfig, SessionCapabilities, + SessionFsCapabilities, SessionFsConfig, SessionUiApi, SessionUiCapabilities, @@ -41,6 +42,8 @@ from .session_fs_provider import ( SessionFsFileInfo, SessionFsProvider, + SessionFsSqliteProvider, + SessionFsSqliteQueryResult, create_session_fs_adapter, ) from .tools import ( @@ -81,9 +84,12 @@ "ProviderConfig", "RemoteSessionMode", "SessionCapabilities", + "SessionFsCapabilities", "SessionFsConfig", "SessionFsFileInfo", "SessionFsProvider", + "SessionFsSqliteProvider", + "SessionFsSqliteQueryResult", "create_session_fs_adapter", "SessionUiApi", "SessionUiCapabilities", diff --git a/python/copilot/client.py b/python/copilot/client.py index 16cef6dde..cb5c98c90 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -68,7 +68,7 @@ UserInputHandler, _PermissionHandlerFn, ) -from .session_fs_provider import create_session_fs_adapter +from .session_fs_provider import SessionFsProvider, create_session_fs_adapter from .tools import Tool, ToolInvocation, ToolResult logger = logging.getLogger(__name__) @@ -1621,9 +1621,17 @@ async def create_session( "create_session_fs_handler is required in session config when " "session_fs is enabled in client options." ) - session._client_session_apis.session_fs = create_session_fs_adapter( - create_session_fs_handler(session) - ) + fs_provider: SessionFsProvider = create_session_fs_handler(session) + caps = self._session_fs_config.get("capabilities") + if caps and caps.get("sqlite"): + from .session_fs_provider import SessionFsSqliteProvider + + if not isinstance(fs_provider, SessionFsSqliteProvider): + raise ValueError( + "SessionFs capabilities declare SQLite support but the provider " + "does not implement SessionFsSqliteProvider" + ) + session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) @@ -1966,9 +1974,17 @@ async def resume_session( "create_session_fs_handler is required in session config when " "session_fs is enabled in client options." ) - session._client_session_apis.session_fs = create_session_fs_adapter( - create_session_fs_handler(session) - ) + fs_provider: SessionFsProvider = create_session_fs_handler(session) + caps = self._session_fs_config.get("capabilities") + if caps and caps.get("sqlite"): + from .session_fs_provider import SessionFsSqliteProvider + + if not isinstance(fs_provider, SessionFsSqliteProvider): + raise ValueError( + "SessionFs capabilities declare SQLite support but the provider " + "does not implement SessionFsSqliteProvider" + ) + session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) @@ -2940,14 +2956,15 @@ async def _set_session_fs_provider(self) -> None: if not self._session_fs_config or not self._client: return - await self._client.request( - "sessionFs.setProvider", - { - "initialCwd": self._session_fs_config["initial_cwd"], - "sessionStatePath": self._session_fs_config["session_state_path"], - "conventions": self._session_fs_config["conventions"], - }, - ) + params: dict[str, Any] = { + "initialCwd": self._session_fs_config["initial_cwd"], + "sessionStatePath": self._session_fs_config["session_state_path"], + "conventions": self._session_fs_config["conventions"], + } + if "capabilities" in self._session_fs_config: + params["capabilities"] = self._session_fs_config["capabilities"] + + await self._client.request("sessionFs.setProvider", params) def _get_client_session_handlers(self, session_id: str) -> ClientSessionApiHandlers: with self._sessions_lock: diff --git a/python/copilot/session.py b/python/copilot/session.py index 380c47e12..4789724fb 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -79,10 +79,15 @@ SessionFsConventions = Literal["posix", "windows"] +class SessionFsCapabilities(TypedDict, total=False): + sqlite: bool + + class SessionFsConfig(TypedDict): initial_cwd: str session_state_path: str conventions: SessionFsConventions + capabilities: NotRequired[SessionFsCapabilities] # ============================================================================ diff --git a/python/copilot/session_fs_provider.py b/python/copilot/session_fs_provider.py index eb8882336..1421ffaf4 100644 --- a/python/copilot/session_fs_provider.py +++ b/python/copilot/session_fs_provider.py @@ -33,10 +33,12 @@ SessionFSReaddirWithTypesResult, SessionFSReadFileResult, SessionFSSqliteExistsResult, - SessionFSSqliteQueryResult, SessionFSSqliteQueryType, SessionFSStatResult, ) +from .generated.rpc import ( + SessionFSSqliteQueryResult as _GeneratedSqliteQueryResult, +) @dataclass @@ -99,19 +101,52 @@ async def rm(self, path: str, recursive: bool, force: bool) -> None: async def rename(self, src: str, dest: str) -> None: """Rename / move a file or directory.""" + +class SessionFsSqliteProvider(abc.ABC): + """Optional ABC for providers that support SQLite operations. + + To add SQLite support, subclass *both* :class:`SessionFsProvider` and + :class:`SessionFsSqliteProvider`:: + + class MyProvider(SessionFsProvider, SessionFsSqliteProvider): ... + + The adapter checks ``isinstance(provider, SessionFsSqliteProvider)`` at + runtime to decide whether SQLite calls should be dispatched. + + Providers are already session-scoped (created per session by the factory), + so these methods do not take a ``session_id`` parameter. + """ + @abc.abstractmethod async def sqlite_query( self, - session_id: str, - query: str, query_type: SessionFSSqliteQueryType, + query: str, params: dict[str, float | str | None] | None = None, - ) -> SessionFSSqliteQueryResult: - """Execute a SQLite query against the provider's per-session database.""" + ) -> SessionFsSqliteQueryResult | None: + """Execute a SQLite query against the provider's per-session database. + + Return ``None`` for exec-type queries (DDL / multi-statement) where + no result set is produced; the adapter will substitute an empty result. + """ @abc.abstractmethod - async def sqlite_exists(self, session_id: str) -> bool: - """Return whether the provider has a SQLite database for *session_id*.""" + async def sqlite_exists(self) -> bool: + """Return whether the provider has a SQLite database for this session.""" + + +@dataclass +class SessionFsSqliteQueryResult: + """Result of a SQLite query execution. + + Same shape as the generated RPC type but without the ``error`` field, + since providers signal errors by raising exceptions. + """ + + columns: list[str] + rows: list[dict[str, Any]] + rows_affected: int + last_insert_rowid: int | None = None def create_session_fs_adapter(provider: SessionFsProvider) -> SessionFsHandler: @@ -131,7 +166,7 @@ def __init__(self, provider: SessionFsProvider) -> None: async def read_file(self, params: Any) -> SessionFSReadFileResult: try: - content = await self._p.read_file(params.path) # type: ignore[attr-defined] + content = await self._p.read_file(params.path) return SessionFSReadFileResult.from_dict({"content": content}) except Exception as exc: err = _to_session_fs_error(exc) @@ -139,28 +174,28 @@ async def read_file(self, params: Any) -> SessionFSReadFileResult: async def write_file(self, params: Any) -> SessionFSError | None: try: - await self._p.write_file(params.path, params.content, getattr(params, "mode", None)) # type: ignore[attr-defined] + await self._p.write_file(params.path, params.content, getattr(params, "mode", None)) return None except Exception as exc: return _to_session_fs_error(exc) async def append_file(self, params: Any) -> SessionFSError | None: try: - await self._p.append_file(params.path, params.content, getattr(params, "mode", None)) # type: ignore[attr-defined] + await self._p.append_file(params.path, params.content, getattr(params, "mode", None)) return None except Exception as exc: return _to_session_fs_error(exc) async def exists(self, params: Any) -> SessionFSExistsResult: try: - result = await self._p.exists(params.path) # type: ignore[attr-defined] + result = await self._p.exists(params.path) return SessionFSExistsResult.from_dict({"exists": result}) except Exception: return SessionFSExistsResult.from_dict({"exists": False}) async def stat(self, params: Any) -> SessionFSStatResult: try: - info = await self._p.stat(params.path) # type: ignore[attr-defined] + info = await self._p.stat(params.path) return SessionFSStatResult( is_file=info.is_file, is_directory=info.is_directory, @@ -183,7 +218,7 @@ async def stat(self, params: Any) -> SessionFSStatResult: async def mkdir(self, params: Any) -> SessionFSError | None: try: await self._p.mkdir( - params.path, # type: ignore[attr-defined] + params.path, getattr(params, "recursive", False), getattr(params, "mode", None), ) @@ -193,7 +228,7 @@ async def mkdir(self, params: Any) -> SessionFSError | None: async def readdir(self, params: Any) -> SessionFSReaddirResult: try: - entries = await self._p.readdir(params.path) # type: ignore[attr-defined] + entries = await self._p.readdir(params.path) return SessionFSReaddirResult.from_dict({"entries": entries}) except Exception as exc: err = _to_session_fs_error(exc) @@ -201,7 +236,7 @@ async def readdir(self, params: Any) -> SessionFSReaddirResult: async def readdir_with_types(self, params: Any) -> SessionFSReaddirWithTypesResult: try: - entries = await self._p.readdir_with_types(params.path) # type: ignore[attr-defined] + entries = await self._p.readdir_with_types(params.path) return SessionFSReaddirWithTypesResult(entries=list(entries)) except Exception as exc: err = _to_session_fs_error(exc) @@ -212,7 +247,7 @@ async def readdir_with_types(self, params: Any) -> SessionFSReaddirWithTypesResu async def rm(self, params: Any) -> SessionFSError | None: try: await self._p.rm( - params.path, # type: ignore[attr-defined] + params.path, getattr(params, "recursive", False), getattr(params, "force", False), ) @@ -222,30 +257,50 @@ async def rm(self, params: Any) -> SessionFSError | None: async def rename(self, params: Any) -> SessionFSError | None: try: - await self._p.rename(params.src, params.dest) # type: ignore[attr-defined] + await self._p.rename(params.src, params.dest) return None except Exception as exc: return _to_session_fs_error(exc) - async def sqlite_query(self, params: Any) -> SessionFSSqliteQueryResult: - try: - return await self._p.sqlite_query( # type: ignore[attr-defined] - params.session_id, - params.query, - params.query_type, - getattr(params, "params", None), + async def sqlite_query(self, params: Any) -> _GeneratedSqliteQueryResult: + # SQLite methods intentionally skip toSessionFsError wrapping — FS errno + # mapping (ENOENT) isn't meaningful for SQL errors and the JSON-RPC layer + # already handles uncaught exceptions. + if not isinstance(self._p, SessionFsSqliteProvider): + return _GeneratedSqliteQueryResult( + columns=[], + rows=[], + rows_affected=0, + error=SessionFSError( + code=SessionFSErrorCode.UNKNOWN, + message="SQLite is not supported by this SessionFs provider", + ), ) - except Exception as exc: - return SessionFSSqliteQueryResult( + result = await self._p.sqlite_query( + params.query_type, + params.query, + getattr(params, "params", None), + ) + if result is None: + return _GeneratedSqliteQueryResult( columns=[], rows=[], rows_affected=0, - error=_to_session_fs_error(exc), ) + rowid = result.last_insert_rowid + wire_rowid = float(rowid) if rowid is not None else None + return _GeneratedSqliteQueryResult( + columns=result.columns, + rows=result.rows, + rows_affected=result.rows_affected, + last_insert_rowid=wire_rowid, + ) async def sqlite_exists(self, params: Any) -> SessionFSSqliteExistsResult: + if not isinstance(self._p, SessionFsSqliteProvider): + return SessionFSSqliteExistsResult.from_dict({"exists": False}) try: - result = await self._p.sqlite_exists(params.session_id) # type: ignore[attr-defined] + result = await self._p.sqlite_exists() return SessionFSSqliteExistsResult.from_dict({"exists": result}) except Exception: return SessionFSSqliteExistsResult.from_dict({"exists": False}) diff --git a/python/e2e/test_rpc_session_state_e2e.py b/python/e2e/test_rpc_session_state_e2e.py index 0c841465a..cba7e2164 100644 --- a/python/e2e/test_rpc_session_state_e2e.py +++ b/python/e2e/test_rpc_session_state_e2e.py @@ -332,7 +332,6 @@ async def test_should_set_and_get_each_session_mode_value(self, ctx: E2ETestCont await session.disconnect() async def test_should_reject_workspace_file_path_traversal(self, ctx: E2ETestContext): - for traversal_path in [ "../escaped.txt", "../../escaped.txt", diff --git a/python/e2e/test_session_fs_e2e.py b/python/e2e/test_session_fs_e2e.py index 328ad9e02..3b5487d00 100644 --- a/python/e2e/test_session_fs_e2e.py +++ b/python/e2e/test_session_fs_e2e.py @@ -17,8 +17,6 @@ from copilot.generated.rpc import ( SessionFSReaddirWithTypesEntry, SessionFSReaddirWithTypesEntryType, - SessionFSSqliteQueryResult, - SessionFSSqliteQueryType, ) from copilot.generated.session_events import SessionCompactionCompleteData, SessionEvent from copilot.session import PermissionHandler @@ -285,6 +283,7 @@ async def test_should_map_all_sessionfs_handler_operations(self, ctx: E2ETestCon SessionFSRmRequest, SessionFSSqliteExistsRequest, SessionFSSqliteQueryRequest, + SessionFSSqliteQueryType, SessionFSStatRequest, SessionFSWriteFileRequest, ) @@ -396,30 +395,22 @@ async def test_should_map_all_sessionfs_handler_operations(self, ctx: E2ETestCon assert missing.error.code == SessionFSErrorCode.ENOENT + # SQLite methods are not on the non-sqlite provider, so the adapter + # should return unsupported/not-found results. sqlite_query = await handler.sqlite_query( SessionFSSqliteQueryRequest( session_id=session_id, - query="select :answer as answer", + query="select 1", query_type=SessionFSSqliteQueryType.QUERY, - params={"answer": 42}, ) ) - assert "answer" in sqlite_query.columns - assert sqlite_query.rows == [ - { - "sessionId": session_id, - "query": "select :answer as answer", - "queryType": "query", - "answer": 42, - } - ] - assert sqlite_query.rows_affected == 0 - assert sqlite_query.error is None + assert sqlite_query.error is not None + assert sqlite_query.error.code == SessionFSErrorCode.UNKNOWN sqlite_exists = await handler.sqlite_exists( SessionFSSqliteExistsRequest(session_id=session_id) ) - assert sqlite_exists.exists is True + assert sqlite_exists.exists is False finally: try: import shutil @@ -441,6 +432,7 @@ async def test_sessionfsprovider_converts_exceptions_to_rpc_errors(self): SessionFSRmRequest, SessionFSSqliteExistsRequest, SessionFSSqliteQueryRequest, + SessionFSSqliteQueryType, SessionFSStatRequest, SessionFSWriteFileRequest, ) @@ -480,12 +472,6 @@ async def rm(self, path, recursive, force): async def rename(self, src, dest): raise self._exc - async def sqlite_query(self, session_id, query, query_type, params=None): - raise self._exc - - async def sqlite_exists(self, session_id): - raise self._exc - def assert_fs_error(error) -> None: assert error is not None assert error.code == SessionFSErrorCode.ENOENT @@ -542,12 +528,15 @@ def assert_fs_error(error) -> None: SessionFSRenameRequest(session_id=sid, src="missing.txt", dest="dest.txt") ) ) + # _ThrowingProvider does not implement SessionFsSqliteProvider, so the + # adapter returns "not supported" results rather than propagating throws. sqlite_query = await handler.sqlite_query( SessionFSSqliteQueryRequest( session_id=sid, query="select 1", query_type=SessionFSSqliteQueryType.QUERY ) ) - assert_fs_error(sqlite_query.error) + assert sqlite_query.error is not None + assert sqlite_query.error.code == SessionFSErrorCode.UNKNOWN assert sqlite_query.columns == [] assert sqlite_query.rows == [] assert sqlite_query.rows_affected == 0 @@ -630,29 +619,6 @@ async def rename(self, src: str, dest: str) -> None: d.parent.mkdir(parents=True, exist_ok=True) self._path(src).rename(d) - async def sqlite_query( - self, - session_id: str, - query: str, - query_type: SessionFSSqliteQueryType, - params: dict[str, float | str | None] | None = None, - ) -> SessionFSSqliteQueryResult: - return SessionFSSqliteQueryResult( - columns=["sessionId", "query", "queryType", "answer"], - rows=[ - { - "sessionId": session_id, - "query": query, - "queryType": query_type.value, - "answer": params["answer"] if params else None, - } - ], - rows_affected=0, - ) - - async def sqlite_exists(self, session_id: str) -> bool: - return session_id == self._session_id - def create_test_session_fs_handler(provider_root: Path): def create_handler(session): diff --git a/python/e2e/test_session_fs_sqlite_e2e.py b/python/e2e/test_session_fs_sqlite_e2e.py new file mode 100644 index 000000000..92d68e94b --- /dev/null +++ b/python/e2e/test_session_fs_sqlite_e2e.py @@ -0,0 +1,285 @@ +"""E2E SessionFs SQLite tests mirroring nodejs/test/e2e/session_fs_sqlite.e2e.test.ts.""" + +from __future__ import annotations + +import datetime as dt +import json +import os +import sqlite3 +import tempfile +from pathlib import Path + +import pytest +import pytest_asyncio + +from copilot import CopilotClient, SessionFsConfig +from copilot.client import SubprocessConfig +from copilot.generated.rpc import ( + SessionFSReaddirWithTypesEntry, + SessionFSReaddirWithTypesEntryType, + SessionFSSqliteQueryType, +) +from copilot.session import PermissionHandler +from copilot.session_fs_provider import ( + SessionFsFileInfo, + SessionFsProvider, + SessionFsSqliteProvider, + SessionFsSqliteQueryResult, +) + +from .testharness import DEFAULT_GITHUB_TOKEN, E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +SESSION_STATE_PATH = ( + "/session-state" + if os.name == "nt" + else (Path(tempfile.mkdtemp(prefix="copilot-sessionfs-sqlite-")) / "session-state") + .resolve() + .as_posix() +) + +SESSION_FS_CONFIG: SessionFsConfig = { + "initial_cwd": "/", + "session_state_path": SESSION_STATE_PATH, + "conventions": "posix", + "capabilities": {"sqlite": True}, +} + + +class _InMemorySessionFsSqliteProvider(SessionFsProvider, SessionFsSqliteProvider): + """In-memory SessionFsProvider with real SQLite for E2E tests.""" + + def __init__(self, session_id: str, sqlite_calls: list[dict]): + self._session_id = session_id + self._sqlite_calls = sqlite_calls + self._files: dict[str, str] = {} + self._dirs: set[str] = {"/"} + self._db: sqlite3.Connection | None = None + + def _get_or_create_db(self) -> sqlite3.Connection: + if self._db is None: + self._db = sqlite3.connect(":memory:") + self._db.execute("PRAGMA busy_timeout = 5000") + return self._db + + def _ensure_parent(self, path: str) -> None: + parts = path.rstrip("/").split("/") + for i in range(1, len(parts)): + self._dirs.add("/".join(parts[:i]) or "/") + + async def read_file(self, path: str) -> str: + if path not in self._files: + raise FileNotFoundError(path) + return self._files[path] + + async def write_file(self, path: str, content: str, mode: int | None = None) -> None: + self._ensure_parent(path) + self._files[path] = content + + async def append_file(self, path: str, content: str, mode: int | None = None) -> None: + self._ensure_parent(path) + self._files[path] = self._files.get(path, "") + content + + async def exists(self, path: str) -> bool: + return path in self._files or path in self._dirs + + async def stat(self, path: str) -> SessionFsFileInfo: + now = dt.datetime.now(tz=dt.UTC) + if path in self._dirs: + return SessionFsFileInfo( + is_file=False, is_directory=True, size=0, mtime=now, birthtime=now + ) + if path in self._files: + return SessionFsFileInfo( + is_file=True, + is_directory=False, + size=len(self._files[path].encode()), + mtime=now, + birthtime=now, + ) + raise FileNotFoundError(path) + + async def mkdir(self, path: str, recursive: bool, mode: int | None = None) -> None: + if recursive: + parts = path.rstrip("/").split("/") + for i in range(1, len(parts) + 1): + self._dirs.add("/".join(parts[:i]) or "/") + else: + self._dirs.add(path) + + async def readdir(self, path: str) -> list[str]: + prefix = path.rstrip("/") + "/" + names: set[str] = set() + for p in list(self._files.keys()) + list(self._dirs): + if p.startswith(prefix): + rest = p[len(prefix) :] + if rest: + names.add(rest.split("/")[0]) + return sorted(names) + + async def readdir_with_types(self, path: str) -> list[SessionFSReaddirWithTypesEntry]: + prefix = path.rstrip("/") + "/" + entries: dict[str, SessionFSReaddirWithTypesEntryType] = {} + for p in self._dirs: + if p.startswith(prefix): + rest = p[len(prefix) :] + if rest: + name = rest.split("/")[0] + entries[name] = SessionFSReaddirWithTypesEntryType.DIRECTORY + for p in self._files: + if p.startswith(prefix): + rest = p[len(prefix) :] + if rest: + name = rest.split("/")[0] + if name not in entries: + entries[name] = SessionFSReaddirWithTypesEntryType.FILE + return [SessionFSReaddirWithTypesEntry(name=n, type=t) for n, t in sorted(entries.items())] + + async def rm(self, path: str, recursive: bool, force: bool) -> None: + self._files.pop(path, None) + self._dirs.discard(path) + + async def rename(self, src: str, dest: str) -> None: + if src in self._files: + self._ensure_parent(dest) + self._files[dest] = self._files.pop(src) + + async def sqlite_query( + self, + query_type: SessionFSSqliteQueryType, + query: str, + params: dict[str, float | str | None] | None = None, + ) -> SessionFsSqliteQueryResult | None: + self._sqlite_calls.append( + { + "sessionId": self._session_id, + "queryType": query_type.value, + "query": query, + } + ) + + db = self._get_or_create_db() + trimmed = query.strip() + if not trimmed: + return SessionFsSqliteQueryResult(columns=[], rows=[], rows_affected=0) + + if query_type == SessionFSSqliteQueryType.EXEC: + db.executescript(trimmed) + db.commit() + return SessionFsSqliteQueryResult(columns=[], rows=[], rows_affected=0) + + if query_type == SessionFSSqliteQueryType.QUERY: + cursor = db.execute(trimmed, params or {}) + columns = [desc[0] for desc in cursor.description] if cursor.description else [] + rows = [dict(zip(columns, row)) for row in cursor.fetchall()] + return SessionFsSqliteQueryResult(columns=columns, rows=rows, rows_affected=0) + + # run (INSERT/UPDATE/DELETE) + cursor = db.execute(trimmed, params or {}) + db.commit() + return SessionFsSqliteQueryResult( + columns=[], + rows=[], + rows_affected=cursor.rowcount, + last_insert_rowid=cursor.lastrowid if cursor.lastrowid else None, + ) + + async def sqlite_exists(self) -> bool: + return self._db is not None + + +def _create_sqlite_handler(sqlite_calls: list[dict]): + def factory(session): + return _InMemorySessionFsSqliteProvider(session.session_id, sqlite_calls) + + return factory + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def sqlite_client(ctx: E2ETestContext): + client = CopilotClient( + SubprocessConfig( + cli_path=ctx.cli_path, + cwd=ctx.work_dir, + env=ctx.get_env(), + github_token=DEFAULT_GITHUB_TOKEN, + session_fs=SESSION_FS_CONFIG, + ) + ) + yield client + try: + await client.stop() + except Exception: + await client.force_stop() + + +class TestSessionFsSqlite: + async def test_should_route_sql_queries_through_the_sessionfs_sqlite_handler( + self, sqlite_client: CopilotClient + ): + sqlite_calls: list[dict] = [] + session = await sqlite_client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=_create_sqlite_handler(sqlite_calls), + ) + + await session.send_and_wait( + 'Use the sql tool to create a table called "items" with columns ' + "id (TEXT PRIMARY KEY) and name (TEXT). " + 'Then insert a row with id "a1" and name "Widget".' + ) + + session_calls = [c for c in sqlite_calls if c["sessionId"] == session.session_id] + assert len(session_calls) > 0 + assert any("CREATE TABLE" in c["query"].upper() for c in session_calls) + assert any("INSERT" in c["query"].upper() for c in session_calls) + + assert any(c["queryType"] == "exec" for c in session_calls) + assert any(c["queryType"] == "run" for c in session_calls) + + await session.disconnect() + + async def test_should_allow_subagents_to_use_sql_tool_via_inherited_sessionfs( + self, sqlite_client: CopilotClient + ): + sqlite_calls: list[dict] = [] + providers: dict[str, _InMemorySessionFsSqliteProvider] = {} + + def handler_factory(session): + provider = _InMemorySessionFsSqliteProvider(session.session_id, sqlite_calls) + providers[session.session_id] = provider + return provider + + session = await sqlite_client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=handler_factory, + ) + + await session.send_and_wait( + "Use the task tool to ask a task agent to do the following: " + "Use the sql tool to run this query: INSERT INTO todos " + "(id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done')" + ) + + await session.disconnect() + + session_calls = [c for c in sqlite_calls if c["sessionId"] == session.session_id] + insert_calls = [c for c in session_calls if "INSERT" in c["query"].upper()] + assert len(insert_calls) > 0 + + # Read events.jsonl from in-memory FS + provider = providers[session.session_id] + events_path = f"{SESSION_STATE_PATH}/events.jsonl" + content = await provider.read_file(events_path) + lines = [line for line in content.split("\n") if line.strip()] + parsed = [json.loads(line) for line in lines] + sql_tool_events = [ + e + for e in parsed + if e.get("type") == "tool.execution_start" + and e.get("data", {}).get("toolName") == "sql" + ] + assert len(sql_tool_events) > 0 + assert all(e.get("agentId") for e in sql_tool_events) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3065822e7..9c4790a6e 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -203,6 +203,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.4.1" @@ -348,6 +360,7 @@ dependencies = [ "getrandom 0.2.17", "parking_lot", "regex", + "rusqlite", "schemars", "serde", "serde_json", @@ -381,6 +394,15 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -548,6 +570,17 @@ dependencies = [ "redox_syscall 0.7.4", ] +[[package]] +name = "libsqlite3-sys" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947e6816f7825b2b45027c2c32e7085da9934defa535de4a6a46b10a4d5257fa" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -802,6 +835,20 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rusqlite" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a22715a5d6deef63c637207afbe68d0c72c3f8d0022d7cf9714c442d6157606b" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustix" version = "1.1.4" @@ -1294,6 +1341,12 @@ dependencies = [ "getrandom 0.4.2", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 182707bf1..b2a2b4f54 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -55,6 +55,7 @@ uuid = { version = "1", default-features = false, features = ["v4"] } zstd = { version = "0.13", optional = true } [dev-dependencies] +rusqlite = { version = "0.35", features = ["bundled"] } schemars = "1" serial_test = "3" tempfile = "3" diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 6585676ec..abb1a72a4 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -884,6 +884,7 @@ struct ClientInner { on_list_models: Option>, models_cache: parking_lot::Mutex>>>, session_fs_configured: bool, + session_fs_sqlite_declared: bool, on_get_trace_context: Option>, /// Token sent in the `connect` handshake. Auto-generated when the /// SDK spawns its own CLI in TCP mode and no explicit token is set; @@ -962,6 +963,10 @@ impl Client { options.tcp_connection_token = effective_connection_token.clone(); } let session_fs_config = options.session_fs.clone(); + let session_fs_sqlite_declared = session_fs_config + .as_ref() + .and_then(|c| c.capabilities.as_ref()) + .is_some_and(|caps| caps.sqlite); let program = match &options.program { CliProgram::Path(path) => { info!(path = %path.display(), "using explicit copilot CLI path"); @@ -1006,6 +1011,7 @@ impl Client { options.cwd, options.on_list_models, session_fs_config.is_some(), + session_fs_sqlite_declared, options.on_get_trace_context, effective_connection_token.clone(), )? @@ -1028,6 +1034,7 @@ impl Client { options.cwd, options.on_list_models, session_fs_config.is_some(), + session_fs_sqlite_declared, options.on_get_trace_context, effective_connection_token.clone(), )? @@ -1044,6 +1051,7 @@ impl Client { options.cwd, options.on_list_models, session_fs_config.is_some(), + session_fs_sqlite_declared, options.on_get_trace_context, effective_connection_token.clone(), )? @@ -1061,8 +1069,13 @@ impl Client { ); if let Some(cfg) = session_fs_config { let session_fs_start = Instant::now(); + let capabilities = cfg.capabilities.as_ref().map(|c| { + crate::generated::api_types::SessionFsSetProviderCapabilities { + sqlite: Some(c.sqlite), + } + }); let request = crate::generated::api_types::SessionFsSetProviderRequest { - capabilities: None, + capabilities, conventions: cfg.conventions.into_wire(), initial_cwd: cfg.initial_cwd, session_state_path: cfg.session_state_path, @@ -1088,7 +1101,7 @@ impl Client { writer: impl AsyncWrite + Unpin + Send + 'static, cwd: PathBuf, ) -> Result { - Self::from_transport(reader, writer, None, cwd, None, false, None, None) + Self::from_transport(reader, writer, None, cwd, None, false, false, None, None) } /// Construct a [`Client`] from raw streams with a @@ -1105,7 +1118,17 @@ impl Client { cwd: PathBuf, provider: Arc, ) -> Result { - Self::from_transport(reader, writer, None, cwd, None, false, Some(provider), None) + Self::from_transport( + reader, + writer, + None, + cwd, + None, + false, + false, + Some(provider), + None, + ) } /// Construct a [`Client`] from raw streams with a preset @@ -1118,7 +1141,7 @@ impl Client { cwd: PathBuf, token: Option, ) -> Result { - Self::from_transport(reader, writer, None, cwd, None, false, None, token) + Self::from_transport(reader, writer, None, cwd, None, false, false, None, token) } /// Public test-only wrapper around the random connection-token @@ -1139,6 +1162,7 @@ impl Client { cwd: PathBuf, on_list_models: Option>, session_fs_configured: bool, + session_fs_sqlite_declared: bool, on_get_trace_context: Option>, effective_connection_token: Option, ) -> Result { @@ -1169,6 +1193,7 @@ impl Client { on_list_models, models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), session_fs_configured, + session_fs_sqlite_declared, on_get_trace_context, effective_connection_token, }), @@ -2548,6 +2573,7 @@ mod tests { on_list_models: Some(handler), models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), session_fs_configured: false, + session_fs_sqlite_declared: false, on_get_trace_context: None, effective_connection_token: None, }), diff --git a/rust/src/session.rs b/rust/src/session.rs index 124884866..970381724 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -769,6 +769,16 @@ impl Client { if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(Error::Session(SessionError::SessionFsProviderRequired)); } + if self.inner.session_fs_sqlite_declared + && let Some(ref provider) = session_fs_provider + && provider.sqlite().is_none() + { + return Err(Error::InvalidConfig( + "SessionFs capabilities declare SQLite support but the provider \ + does not implement SessionFsSqliteProvider" + .to_string(), + )); + } if hooks.is_some() && config.hooks.is_none() { config.hooks = Some(true); @@ -890,6 +900,16 @@ impl Client { if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(Error::Session(SessionError::SessionFsProviderRequired)); } + if self.inner.session_fs_sqlite_declared + && let Some(ref provider) = session_fs_provider + && provider.sqlite().is_none() + { + return Err(Error::InvalidConfig( + "SessionFs capabilities declare SQLite support but the provider \ + does not implement SessionFsSqliteProvider" + .to_string(), + )); + } if hooks.is_some() && config.hooks.is_none() { config.hooks = Some(true); diff --git a/rust/src/session_fs.rs b/rust/src/session_fs.rs index 8474235c7..0e13be7d7 100644 --- a/rust/src/session_fs.rs +++ b/rust/src/session_fs.rs @@ -44,11 +44,32 @@ use std::collections::HashMap; use async_trait::async_trait; +pub use crate::generated::api_types::SessionFsSqliteQueryType; use crate::generated::api_types::{ SessionFsError, SessionFsErrorCode, SessionFsReaddirWithTypesEntry, SessionFsReaddirWithTypesEntryType, SessionFsSetProviderConventions, SessionFsStatResult, }; -pub use crate::generated::api_types::{SessionFsSqliteQueryResult, SessionFsSqliteQueryType}; + +/// Optional capabilities declared by a session filesystem provider. +#[non_exhaustive] +#[derive(Debug, Clone, Default)] +pub struct SessionFsCapabilities { + /// Whether the provider supports SQLite query/exists operations. + pub sqlite: bool, +} + +impl SessionFsCapabilities { + /// Create a new capabilities struct with default values. + pub fn new() -> Self { + Self::default() + } + + /// Enable SQLite support. + pub fn with_sqlite(mut self, sqlite: bool) -> Self { + self.sqlite = sqlite; + self + } +} /// Configuration for a custom session filesystem provider. /// @@ -65,6 +86,8 @@ pub struct SessionFsConfig { pub session_state_path: String, /// Path conventions used by this filesystem provider. pub conventions: SessionFsConventions, + /// Optional capabilities such as SQLite support. + pub capabilities: Option, } impl SessionFsConfig { @@ -78,8 +101,15 @@ impl SessionFsConfig { initial_cwd: initial_cwd.into(), session_state_path: session_state_path.into(), conventions, + capabilities: None, } } + + /// Set the capabilities on this config and return it (builder pattern). + pub fn with_capabilities(mut self, capabilities: SessionFsCapabilities) -> Self { + self.capabilities = Some(capabilities); + self + } } /// Path conventions used by a session filesystem provider. @@ -347,23 +377,62 @@ pub trait SessionFsProvider: Send + Sync + 'static { Err(FsError::Other("rename not supported".to_string())) } + /// Return a reference to the SQLite provider, if this provider supports + /// SQLite operations. The default returns `None`. Providers that support + /// SQLite should also implement [`SessionFsSqliteProvider`] and override + /// this to return `Some(self)`. + fn sqlite(&self) -> Option<&dyn SessionFsSqliteProvider> { + None + } +} + +/// Optional trait for providers that support SQLite operations. +/// +/// Providers are already session-scoped (created per session by the factory), +/// so these methods do not take a `session_id` parameter. +/// +/// To opt in, implement this trait on your provider and override +/// [`SessionFsProvider::sqlite`] to return `Some(self)`: +/// +/// ```ignore +/// impl SessionFsSqliteProvider for MyProvider { /* ... */ } +/// +/// #[async_trait] +/// impl SessionFsProvider for MyProvider { +/// fn sqlite(&self) -> Option<&dyn SessionFsSqliteProvider> { +/// Some(self) +/// } +/// // ... other methods ... +/// } +/// ``` +#[async_trait] +pub trait SessionFsSqliteProvider: Send + Sync { /// Execute a SQLite query against the provider's per-session database. async fn sqlite_query( &self, - session_id: &str, - query: &str, query_type: SessionFsSqliteQueryType, + query: &str, params: Option<&HashMap>, - ) -> Result { - let _ = (session_id, query, query_type, params); - Err(FsError::Other("sqlite_query not supported".to_string())) - } + ) -> Result, FsError>; - /// Check whether the provider has a SQLite database for the session. - async fn sqlite_exists(&self, session_id: &str) -> Result { - let _ = session_id; - Err(FsError::Other("sqlite_exists not supported".to_string())) - } + /// Check whether the provider has a SQLite database for this session. + async fn sqlite_exists(&self) -> Result; +} + +/// Result of a SQLite query execution via [`SessionFsSqliteProvider::sqlite_query`]. +/// +/// Same shape as the generated RPC type but without the `error` field, +/// since providers signal errors by returning `Err`. +#[derive(Debug, Clone, Default)] +pub struct SessionFsSqliteQueryResult { + /// Column names from the result set. + pub columns: Vec, + /// For SELECT: array of row objects. For others: empty array. + pub rows: Vec>, + /// Number of rows affected (for INSERT/UPDATE/DELETE). + pub rows_affected: i64, + /// Last inserted row ID (for INSERT). + pub last_insert_rowid: Option, } #[cfg(test)] diff --git a/rust/src/session_fs_dispatch.rs b/rust/src/session_fs_dispatch.rs index 3810d978f..4a09666f4 100644 --- a/rust/src/session_fs_dispatch.rs +++ b/rust/src/session_fs_dispatch.rs @@ -12,13 +12,13 @@ use serde_json::Value; use tracing::warn; use crate::generated::api_types::{ - SessionFsAppendFileRequest, SessionFsExistsRequest, SessionFsExistsResult, - SessionFsMkdirRequest, SessionFsReadFileRequest, SessionFsReadFileResult, - SessionFsReaddirRequest, SessionFsReaddirResult, SessionFsReaddirWithTypesRequest, - SessionFsReaddirWithTypesResult, SessionFsRenameRequest, SessionFsRmRequest, - SessionFsSqliteExistsParams, SessionFsSqliteExistsResult, SessionFsSqliteQueryRequest, - SessionFsSqliteQueryResult, SessionFsStatRequest, SessionFsStatResult, - SessionFsWriteFileRequest, + SessionFsAppendFileRequest, SessionFsError, SessionFsErrorCode, SessionFsExistsRequest, + SessionFsExistsResult, SessionFsMkdirRequest, SessionFsReadFileRequest, + SessionFsReadFileResult, SessionFsReaddirRequest, SessionFsReaddirResult, + SessionFsReaddirWithTypesRequest, SessionFsReaddirWithTypesResult, SessionFsRenameRequest, + SessionFsRmRequest, SessionFsSqliteExistsParams, SessionFsSqliteExistsResult, + SessionFsSqliteQueryRequest, SessionFsSqliteQueryResult as GeneratedSqliteQueryResult, + SessionFsStatRequest, SessionFsStatResult, SessionFsWriteFileRequest, }; use crate::session_fs::SessionFsProvider; use crate::{Client, JsonRpcRequest, JsonRpcResponse, error_codes}; @@ -316,18 +316,51 @@ pub(crate) async fn sqlite_query( } }; let id = request.id; + let sqlite = match provider.sqlite() { + Some(s) => s, + None => { + // SQLite not supported — return a result-level error, not a + // transport error, so the CLI can surface it gracefully. + respond( + client, + id, + GeneratedSqliteQueryResult { + columns: Vec::new(), + error: Some(SessionFsError { + code: SessionFsErrorCode::UNKNOWN, + message: Some( + "SQLite is not supported by this SessionFs provider".to_string(), + ), + }), + last_insert_rowid: None, + rows: Vec::new(), + rows_affected: 0, + }, + ) + .await; + return; + } + }; let sqlite_params = (!params.params.is_empty()).then_some(¶ms.params); - let result = match provider - .sqlite_query( - params.session_id.as_ref(), - ¶ms.query, - params.query_type, - sqlite_params, - ) + let result = match sqlite + .sqlite_query(params.query_type, ¶ms.query, sqlite_params) .await { - Ok(result) => result, - Err(e) => SessionFsSqliteQueryResult { + Ok(Some(result)) => GeneratedSqliteQueryResult { + columns: result.columns, + rows: result.rows, + rows_affected: result.rows_affected, + last_insert_rowid: result.last_insert_rowid.map(|v| v as f64), + error: None, + }, + Ok(None) => GeneratedSqliteQueryResult { + columns: Vec::new(), + rows: Vec::new(), + rows_affected: 0, + last_insert_rowid: None, + error: None, + }, + Err(e) => GeneratedSqliteQueryResult { columns: Vec::new(), error: Some(e.into_wire()), last_insert_rowid: None, @@ -343,7 +376,7 @@ pub(crate) async fn sqlite_exists( provider: &Arc, request: JsonRpcRequest, ) { - let params: SessionFsSqliteExistsParams = match parse_params(&request) { + let _params: SessionFsSqliteExistsParams = match parse_params(&request) { Some(p) => p, None => { send_error(client, request.id, "invalid sessionFs.sqliteExists params").await; @@ -351,9 +384,12 @@ pub(crate) async fn sqlite_exists( } }; let id = request.id; - let result = match provider.sqlite_exists(params.session_id.as_ref()).await { - Ok(exists) => SessionFsSqliteExistsResult { exists }, - Err(_) => SessionFsSqliteExistsResult { exists: false }, + let result = match provider.sqlite() { + Some(sqlite) => match sqlite.sqlite_exists().await { + Ok(exists) => SessionFsSqliteExistsResult { exists }, + Err(_) => SessionFsSqliteExistsResult { exists: false }, + }, + None => SessionFsSqliteExistsResult { exists: false }, }; respond(client, id, result).await; } diff --git a/rust/src/types.rs b/rust/src/types.rs index 0f242445e..cadf46271 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -15,8 +15,9 @@ use serde_json::Value; use crate::handler::SessionHandler; use crate::hooks::SessionHooks; pub use crate::session_fs::{ - DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConfig, SessionFsConventions, - SessionFsProvider, SessionFsSqliteQueryResult, SessionFsSqliteQueryType, + DirEntry, DirEntryKind, FileInfo, FsError, SessionFsCapabilities, SessionFsConfig, + SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult, + SessionFsSqliteQueryType, }; pub use crate::trace_context::{TraceContext, TraceContextProvider}; use crate::transforms::SystemMessageTransform; diff --git a/rust/tests/e2e.rs b/rust/tests/e2e.rs index cb75dfec5..7a4bd4b04 100644 --- a/rust/tests/e2e.rs +++ b/rust/tests/e2e.rs @@ -71,6 +71,8 @@ mod session; mod session_config; #[path = "e2e/session_fs.rs"] mod session_fs; +#[path = "e2e/session_fs_sqlite.rs"] +mod session_fs_sqlite; #[path = "e2e/session_lifecycle.rs"] mod session_lifecycle; #[path = "e2e/skills.rs"] diff --git a/rust/tests/e2e/session_fs.rs b/rust/tests/e2e/session_fs.rs index 217e3e883..f069f6ffe 100644 --- a/rust/tests/e2e/session_fs.rs +++ b/rust/tests/e2e/session_fs.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use github_copilot_sdk::generated::api_types::PlanUpdateRequest; use github_copilot_sdk::{ Client, DirEntry, DirEntryKind, FileInfo, FsError, SessionConfig, SessionFsConfig, - SessionFsConventions, SessionFsProvider, SessionFsSqliteQueryResult, SessionFsSqliteQueryType, + SessionFsConventions, SessionFsProvider, }; use super::support::{assistant_message_content, wait_for_condition, with_e2e_context}; @@ -206,26 +206,6 @@ async fn should_map_all_sessionfs_handler_operations() { provider.stat("/workspace/nested/missing.txt").await, Err(FsError::NotFound(_)) )); - let sqlite_params = - std::collections::HashMap::from([("answer".to_string(), serde_json::Value::from(42))]); - let sqlite_result = provider - .sqlite_query( - "handler-session", - "select :answer as answer", - SessionFsSqliteQueryType::Query, - Some(&sqlite_params), - ) - .await - .expect("sqlite query"); - assert_eq!(sqlite_result.columns[3], "answer"); - assert_eq!(sqlite_result.rows[0]["answer"], 42); - assert_eq!(sqlite_result.rows_affected, 0); - assert!( - provider - .sqlite_exists("handler-session") - .await - .expect("sqlite exists") - ); let _ = std::fs::remove_dir_all(root); } @@ -622,52 +602,6 @@ impl SessionFsProvider for TestSessionFsProvider { } std::fs::rename(src, dest).map_err(FsError::from) } - - async fn sqlite_query( - &self, - session_id: &str, - query: &str, - query_type: SessionFsSqliteQueryType, - params: Option<&std::collections::HashMap>, - ) -> Result { - let mut row = std::collections::HashMap::new(); - row.insert("sessionId".to_string(), session_id.to_string().into()); - row.insert("query".to_string(), query.to_string().into()); - row.insert( - "queryType".to_string(), - match query_type { - SessionFsSqliteQueryType::Exec => "exec", - SessionFsSqliteQueryType::Query => "query", - SessionFsSqliteQueryType::Run => "run", - SessionFsSqliteQueryType::Unknown => "unknown", - } - .into(), - ); - row.insert( - "answer".to_string(), - params - .and_then(|params| params.get("answer")) - .cloned() - .unwrap_or(serde_json::Value::Null), - ); - - Ok(SessionFsSqliteQueryResult { - columns: vec![ - "sessionId".to_string(), - "query".to_string(), - "queryType".to_string(), - "answer".to_string(), - ], - rows: vec![row], - rows_affected: 0, - last_insert_rowid: None, - error: None, - }) - } - - async fn sqlite_exists(&self, session_id: &str) -> Result { - Ok(session_id == self.session_id) - } } #[derive(Clone)] diff --git a/rust/tests/e2e/session_fs_sqlite.rs b/rust/tests/e2e/session_fs_sqlite.rs new file mode 100644 index 000000000..cd8758c31 --- /dev/null +++ b/rust/tests/e2e/session_fs_sqlite.rs @@ -0,0 +1,509 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use github_copilot_sdk::{ + Client, DirEntry, DirEntryKind, FileInfo, FsError, SessionConfig, SessionFsCapabilities, + SessionFsConfig, SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, + SessionFsSqliteQueryResult, SessionFsSqliteQueryType, +}; +use rusqlite::Connection; + +use super::support::with_e2e_context; + +#[derive(Debug)] +struct SqliteCall { + session_id: String, + query_type: String, + query: String, +} + +struct InMemorySqliteProvider { + session_id: String, + files: Mutex>, + dirs: Mutex>, + db: Mutex>, + sqlite_calls: Arc>>, +} + +impl InMemorySqliteProvider { + fn new(session_id: &str, calls: Arc>>) -> Self { + let mut dirs = std::collections::HashSet::new(); + dirs.insert("/".to_string()); + Self { + session_id: session_id.to_string(), + files: Mutex::new(HashMap::new()), + dirs: Mutex::new(dirs), + db: Mutex::new(None), + sqlite_calls: calls, + } + } + + fn ensure_parent(dirs: &mut std::collections::HashSet, path: &str) { + let parts: Vec<&str> = path.trim_end_matches('/').split('/').collect(); + for i in 1..parts.len() { + let parent = parts[..i].join("/"); + if parent.is_empty() { + dirs.insert("/".to_string()); + } else { + dirs.insert(parent); + } + } + } + + fn get_or_create_db(db: &mut Option) -> Result<&mut Connection, FsError> { + if db.is_none() { + let conn = Connection::open_in_memory().map_err(|e| FsError::Other(e.to_string()))?; + conn.execute_batch("PRAGMA busy_timeout = 5000;") + .map_err(|e| FsError::Other(e.to_string()))?; + *db = Some(conn); + } + Ok(db.as_mut().unwrap()) + } +} + +#[async_trait] +impl SessionFsProvider for InMemorySqliteProvider { + async fn read_file(&self, path: &str) -> Result { + let files = self.files.lock().unwrap(); + files + .get(path) + .cloned() + .ok_or_else(|| FsError::NotFound(path.to_string())) + } + + async fn write_file( + &self, + path: &str, + content: &str, + _mode: Option, + ) -> Result<(), FsError> { + let mut files = self.files.lock().unwrap(); + let mut dirs = self.dirs.lock().unwrap(); + Self::ensure_parent(&mut dirs, path); + files.insert(path.to_string(), content.to_string()); + Ok(()) + } + + async fn append_file( + &self, + path: &str, + content: &str, + _mode: Option, + ) -> Result<(), FsError> { + let mut files = self.files.lock().unwrap(); + let mut dirs = self.dirs.lock().unwrap(); + Self::ensure_parent(&mut dirs, path); + let entry = files.entry(path.to_string()).or_default(); + entry.push_str(content); + Ok(()) + } + + async fn exists(&self, path: &str) -> Result { + let files = self.files.lock().unwrap(); + let dirs = self.dirs.lock().unwrap(); + Ok(files.contains_key(path) || dirs.contains(path)) + } + + async fn stat(&self, path: &str) -> Result { + let files = self.files.lock().unwrap(); + let dirs = self.dirs.lock().unwrap(); + let now = "1970-01-01T00:00:00Z"; + if dirs.contains(path) { + Ok(FileInfo::new(false, true, 0, now, now)) + } else if let Some(content) = files.get(path) { + Ok(FileInfo::new(true, false, content.len() as i64, now, now)) + } else { + Err(FsError::NotFound(path.to_string())) + } + } + + async fn mkdir(&self, path: &str, recursive: bool, _mode: Option) -> Result<(), FsError> { + let mut dirs = self.dirs.lock().unwrap(); + if recursive { + let parts: Vec<&str> = path.trim_end_matches('/').split('/').collect(); + for i in 1..=parts.len() { + let p = parts[..i].join("/"); + if p.is_empty() { + dirs.insert("/".to_string()); + } else { + dirs.insert(p); + } + } + } else { + dirs.insert(path.to_string()); + } + Ok(()) + } + + async fn readdir(&self, path: &str) -> Result, FsError> { + let files = self.files.lock().unwrap(); + let dirs = self.dirs.lock().unwrap(); + let prefix = format!("{}/", path.trim_end_matches('/')); + let mut names = std::collections::BTreeSet::new(); + for p in files.keys().chain(dirs.iter()) { + if let Some(name) = p + .strip_prefix(&prefix) + .and_then(|rest| rest.split('/').next()) + .filter(|n| !n.is_empty()) + { + names.insert(name.to_string()); + } + } + Ok(names.into_iter().collect()) + } + + async fn readdir_with_types(&self, path: &str) -> Result, FsError> { + let files = self.files.lock().unwrap(); + let dirs = self.dirs.lock().unwrap(); + let prefix = format!("{}/", path.trim_end_matches('/')); + let mut entries: HashMap = HashMap::new(); + for d in dirs.iter() { + if let Some(name) = d + .strip_prefix(&prefix) + .and_then(|rest| rest.split('/').next()) + .filter(|n| !n.is_empty()) + { + entries.insert(name.to_string(), DirEntryKind::Directory); + } + } + for f in files.keys() { + if let Some(name) = f + .strip_prefix(&prefix) + .and_then(|rest| rest.split('/').next()) + .filter(|n| !n.is_empty()) + { + entries + .entry(name.to_string()) + .or_insert(DirEntryKind::File); + } + } + let mut result: Vec = entries + .into_iter() + .map(|(name, kind)| DirEntry::new(name, kind)) + .collect(); + result.sort_by(|a, b| a.name.cmp(&b.name)); + Ok(result) + } + + async fn rm(&self, path: &str, _recursive: bool, _force: bool) -> Result<(), FsError> { + let mut files = self.files.lock().unwrap(); + let mut dirs = self.dirs.lock().unwrap(); + files.remove(path); + dirs.remove(path); + Ok(()) + } + + async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { + let mut files = self.files.lock().unwrap(); + let mut dirs = self.dirs.lock().unwrap(); + if let Some(content) = files.remove(src) { + Self::ensure_parent(&mut dirs, dest); + files.insert(dest.to_string(), content); + } + Ok(()) + } + + fn sqlite(&self) -> Option<&dyn SessionFsSqliteProvider> { + Some(self) + } +} + +#[async_trait] +impl SessionFsSqliteProvider for InMemorySqliteProvider { + async fn sqlite_query( + &self, + query_type: SessionFsSqliteQueryType, + query: &str, + _params: Option<&HashMap>, + ) -> Result, FsError> { + let qt_str = match query_type { + SessionFsSqliteQueryType::Exec => "exec", + SessionFsSqliteQueryType::Query => "query", + SessionFsSqliteQueryType::Run => "run", + SessionFsSqliteQueryType::Unknown => "unknown", + }; + self.sqlite_calls.lock().unwrap().push(SqliteCall { + session_id: self.session_id.clone(), + query_type: qt_str.to_string(), + query: query.to_string(), + }); + + let mut db_guard = self.db.lock().unwrap(); + let db = Self::get_or_create_db(&mut db_guard)?; + let trimmed = query.trim(); + if trimmed.is_empty() { + return Ok(Some(SessionFsSqliteQueryResult { + columns: vec![], + rows: vec![], + rows_affected: 0, + last_insert_rowid: None, + })); + } + + match query_type { + SessionFsSqliteQueryType::Exec => { + db.execute_batch(trimmed) + .map_err(|e| FsError::Other(e.to_string()))?; + Ok(Some(SessionFsSqliteQueryResult { + columns: vec![], + rows: vec![], + rows_affected: 0, + last_insert_rowid: None, + })) + } + SessionFsSqliteQueryType::Query => { + let mut stmt = db + .prepare(trimmed) + .map_err(|e| FsError::Other(e.to_string()))?; + let col_count = stmt.column_count(); + let columns: Vec = (0..col_count) + .map(|i| stmt.column_name(i).unwrap().to_string()) + .collect(); + let mut rows = vec![]; + let mut query_rows = stmt.query([]).map_err(|e| FsError::Other(e.to_string()))?; + while let Some(row) = query_rows + .next() + .map_err(|e| FsError::Other(e.to_string()))? + { + let mut map = HashMap::new(); + for (i, col) in columns.iter().enumerate() { + let val: rusqlite::types::Value = + row.get(i).map_err(|e| FsError::Other(e.to_string()))?; + let json_val = match val { + rusqlite::types::Value::Null => serde_json::Value::Null, + rusqlite::types::Value::Integer(n) => { + serde_json::Value::Number(n.into()) + } + rusqlite::types::Value::Real(f) => serde_json::Value::Number( + serde_json::Number::from_f64(f).unwrap_or(0.into()), + ), + rusqlite::types::Value::Text(s) => serde_json::Value::String(s), + rusqlite::types::Value::Blob(b) => { + serde_json::Value::String(String::from_utf8_lossy(&b).into_owned()) + } + }; + map.insert(col.clone(), json_val); + } + rows.push(map); + } + Ok(Some(SessionFsSqliteQueryResult { + columns, + rows, + rows_affected: 0, + last_insert_rowid: None, + })) + } + SessionFsSqliteQueryType::Run => { + let affected = db + .execute(trimmed, []) + .map_err(|e| FsError::Other(e.to_string()))?; + let last_id = db.last_insert_rowid(); + Ok(Some(SessionFsSqliteQueryResult { + columns: vec![], + rows: vec![], + rows_affected: affected as i64, + last_insert_rowid: Some(last_id), + })) + } + _ => Ok(Some(SessionFsSqliteQueryResult { + columns: vec![], + rows: vec![], + rows_affected: 0, + last_insert_rowid: None, + })), + } + } + + async fn sqlite_exists(&self) -> Result { + Ok(self.db.lock().unwrap().is_some()) + } +} + +fn session_state_path_sqlite() -> String { + if cfg!(windows) { + "/session-state".to_string() + } else { + std::env::temp_dir() + .join("copilot-rust-sessionfs-sqlite-state") + .join("session-state") + .to_string_lossy() + .replace('\\', "/") + } +} + +fn sqlite_session_fs_config() -> SessionFsConfig { + SessionFsConfig::new( + "/", + session_state_path_sqlite(), + SessionFsConventions::Posix, + ) + .with_capabilities(SessionFsCapabilities::new().with_sqlite(true)) +} + +async fn start_sqlite_client(ctx: &super::support::E2eContext) -> Client { + Client::start( + ctx.client_options() + .with_session_fs(sqlite_session_fs_config()), + ) + .await + .expect("start sqlite client") +} + +fn sqlite_session_config( + ctx: &super::support::E2eContext, + provider: Arc, +) -> SessionConfig { + ctx.approve_all_session_config() + .with_session_fs_provider(provider) +} + +#[tokio::test] +async fn should_route_sql_queries_through_the_sessionfs_sqlite_handler() { + with_e2e_context( + "session_fs_sqlite", + "should_route_sql_queries_through_the_sessionfs_sqlite_handler", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let session_id = "00000000-0000-4000-8000-000000000201"; + let sqlite_calls = Arc::new(Mutex::new(Vec::new())); + let provider = Arc::new(InMemorySqliteProvider::new( + session_id, + sqlite_calls.clone(), + )); + let client = start_sqlite_client(ctx).await; + let session = client + .create_session( + sqlite_session_config(ctx, provider).with_session_id(session_id), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait( + "Use the sql tool to create a table called \"items\" with columns \ + id (TEXT PRIMARY KEY) and name (TEXT). \ + Then insert a row with id \"a1\" and name \"Widget\".", + ) + .await + .expect("send") + .expect("assistant message"); + let _ = answer; + + { + let calls = sqlite_calls.lock().unwrap(); + let session_calls: Vec<&SqliteCall> = calls + .iter() + .filter(|c| c.session_id == session_id) + .collect(); + assert!(!session_calls.is_empty(), "expected sqlite calls"); + assert!( + session_calls + .iter() + .any(|c| c.query.to_uppercase().contains("CREATE TABLE")), + "expected CREATE TABLE" + ); + assert!( + session_calls + .iter() + .any(|c| c.query.to_uppercase().contains("INSERT")), + "expected INSERT" + ); + assert!( + session_calls.iter().any(|c| c.query_type == "exec"), + "expected exec queryType" + ); + assert!( + session_calls.iter().any(|c| c.query_type == "run"), + "expected run queryType" + ); + } + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +#[tokio::test] +async fn should_allow_subagents_to_use_sql_tool_via_inherited_sessionfs() { + with_e2e_context( + "session_fs_sqlite", + "should_allow_subagents_to_use_sql_tool_via_inherited_sessionfs", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let session_id = "00000000-0000-4000-8000-000000000202"; + let sqlite_calls = Arc::new(Mutex::new(Vec::new())); + let provider = Arc::new(InMemorySqliteProvider::new(session_id, sqlite_calls.clone())); + let provider_ref = provider.clone(); + let client = start_sqlite_client(ctx).await; + let session = client + .create_session( + sqlite_session_config(ctx, provider).with_session_id(session_id), + ) + .await + .expect("create session"); + + session + .send_and_wait( + "Use the task tool to ask a task agent to do the following: \ + Use the sql tool to run this query: INSERT INTO todos \ + (id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done')", + ) + .await + .expect("send"); + + session.disconnect().await.expect("disconnect session"); + + { + let calls = sqlite_calls.lock().unwrap(); + let session_calls: Vec<&SqliteCall> = + calls.iter().filter(|c| c.session_id == session_id).collect(); + let insert_calls: Vec<&&SqliteCall> = session_calls + .iter() + .filter(|c| c.query.to_uppercase().contains("INSERT")) + .collect(); + assert!(!insert_calls.is_empty(), "expected INSERT calls from subagent"); + } + + // Read events.jsonl from in-memory FS + let events_path = format!("{}/events.jsonl", session_state_path_sqlite()); + let content = provider_ref + .read_file(&events_path) + .await + .expect("read events.jsonl"); + let lines: Vec<&str> = content.lines().filter(|l| !l.is_empty()).collect(); + let sql_tool_events: Vec = lines + .iter() + .filter_map(|line| serde_json::from_str::(line).ok()) + .filter(|e| { + e.get("type").and_then(|t| t.as_str()) == Some("tool.execution_start") + && e.get("data") + .and_then(|d| d.get("toolName")) + .and_then(|t| t.as_str()) + == Some("sql") + }) + .collect(); + assert!( + !sql_tool_events.is_empty(), + "expected sql tool events in events.jsonl" + ); + for e in &sql_tool_events { + assert!( + e.get("agentId").is_some() + && e.get("agentId") != Some(&serde_json::Value::Null) + && e.get("agentId").and_then(|v| v.as_str()) != Some(""), + "expected agentId on sql tool event" + ); + } + + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 3a60f4663..b9c28d30d 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -2912,7 +2912,7 @@ async fn command_execute_handler_error_propagates_to_ack() { use github_copilot_sdk::session_fs::{ DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConventions, SessionFsProvider, - SessionFsSqliteQueryResult, SessionFsSqliteQueryType, + SessionFsSqliteProvider, SessionFsSqliteQueryResult, SessionFsSqliteQueryType, }; struct RecordingFsProvider { @@ -2985,18 +2985,20 @@ impl SessionFsProvider for RecordingFsProvider { Ok(()) } + fn sqlite(&self) -> Option<&dyn SessionFsSqliteProvider> { + Some(self) + } +} + +#[async_trait] +impl SessionFsSqliteProvider for RecordingFsProvider { async fn sqlite_query( &self, - session_id: &str, - query: &str, query_type: SessionFsSqliteQueryType, + query: &str, params: Option<&std::collections::HashMap>, - ) -> Result { + ) -> Result, FsError> { let mut row = std::collections::HashMap::new(); - row.insert( - "sessionId".to_string(), - serde_json::Value::String(session_id.to_string()), - ); row.insert( "query".to_string(), serde_json::Value::String(query.to_string()), @@ -3020,9 +3022,8 @@ impl SessionFsProvider for RecordingFsProvider { .cloned() .unwrap_or(serde_json::Value::Null), ); - Ok(SessionFsSqliteQueryResult { + Ok(Some(SessionFsSqliteQueryResult { columns: vec![ - "sessionId".to_string(), "query".to_string(), "queryType".to_string(), "answer".to_string(), @@ -3030,12 +3031,11 @@ impl SessionFsProvider for RecordingFsProvider { rows: vec![row], rows_affected: 0, last_insert_rowid: None, - error: None, - }) + })) } - async fn sqlite_exists(&self, session_id: &str) -> Result { - Ok(!session_id.is_empty()) + async fn sqlite_exists(&self) -> Result { + Ok(true) } } @@ -3177,15 +3177,11 @@ async fn session_fs_dispatches_sqlite_query_to_provider() { let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); assert_eq!(response["id"], 9); - assert_eq!(response["result"]["columns"][3], "answer"); + assert_eq!(response["result"]["columns"][2], "answer"); assert_eq!( response["result"]["rows"][0]["query"], "select :answer as answer" ); - assert_eq!( - response["result"]["rows"][0]["sessionId"], - server.session_id.to_string() - ); assert_eq!(response["result"]["rows"][0]["queryType"], "query"); assert_eq!(response["result"]["rows"][0]["answer"], 42); assert_eq!(response["result"]["rowsAffected"], 0); @@ -3216,17 +3212,22 @@ async fn session_fs_maps_sqlite_errors_to_results() { struct AlwaysFails; #[async_trait] impl SessionFsProvider for AlwaysFails { + fn sqlite(&self) -> Option<&dyn SessionFsSqliteProvider> { + Some(self) + } + } + #[async_trait] + impl SessionFsSqliteProvider for AlwaysFails { async fn sqlite_query( &self, - _session_id: &str, - _query: &str, _query_type: SessionFsSqliteQueryType, + _query: &str, _params: Option<&std::collections::HashMap>, - ) -> Result { + ) -> Result, FsError> { Err(FsError::Other("sqlite unavailable".to_string())) } - async fn sqlite_exists(&self, _session_id: &str) -> Result { + async fn sqlite_exists(&self) -> Result { Err(FsError::Other("sqlite unavailable".to_string())) } } diff --git a/test/snapshots/session_fs_sqlite/should_allow_subagents_to_use_sql_tool_via_inherited_sessionfs.yaml b/test/snapshots/session_fs_sqlite/should_allow_subagents_to_use_sql_tool_via_inherited_sessionfs.yaml new file mode 100644 index 000000000..edeeecec7 --- /dev/null +++ b/test/snapshots/session_fs_sqlite/should_allow_subagents_to_use_sql_tool_via_inherited_sessionfs.yaml @@ -0,0 +1,98 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: "Use the task tool to ask a task agent to do the following: Use the sql tool to run this query: INSERT INTO + todos (id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done')" + - role: assistant + content: I'll delegate this SQL insert task to a task agent. + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Delegating SQL task"}' + - role: assistant + tool_calls: + - id: toolcall_1 + type: function + function: + name: task + arguments: "{\"name\":\"sql-insert-test\",\"agent_type\":\"task\",\"description\":\"Running SQL + insert\",\"prompt\":\"Use the sql tool to run this exact query:\\n\\nINSERT INTO todos (id, title, + status) VALUES ('subagent-test', 'Created by subagent', 'done')\\n\\nMake sure to provide a description + for the SQL operation (e.g., \\\"Insert test todo\\\").\",\"mode\":\"sync\"}" + - messages: + - role: system + content: ${system} + - role: user + content: |- + Use the sql tool to run this exact query: + + INSERT INTO todos (id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done') + + Make sure to provide a description for the SQL operation (e.g., "Insert test todo"). + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: sql + arguments: "{\"description\":\"Insert test todo\",\"query\":\"INSERT INTO todos (id, title, status) VALUES + ('subagent-test', 'Created by subagent', 'done')\",\"database\":\"session\"}" + - role: tool + tool_call_id: toolcall_0 + content: "1 row(s) inserted. Last inserted row ID: 1." + - role: assistant + content: >- + ✅ Successfully inserted the test todo into the database. The query executed exactly as requested, inserting a + todo with: + + - ID: `subagent-test` + + - Title: `Created by subagent` + + - Status: `done` + - messages: + - role: system + content: ${system} + - role: user + content: "Use the task tool to ask a task agent to do the following: Use the sql tool to run this query: INSERT INTO + todos (id, title, status) VALUES ('subagent-test', 'Created by subagent', 'done')" + - role: assistant + content: I'll delegate this SQL insert task to a task agent. + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Delegating SQL task"}' + - id: toolcall_1 + type: function + function: + name: task + arguments: "{\"name\":\"sql-insert-test\",\"agent_type\":\"task\",\"description\":\"Running SQL + insert\",\"prompt\":\"Use the sql tool to run this exact query:\\n\\nINSERT INTO todos (id, title, + status) VALUES ('subagent-test', 'Created by subagent', 'done')\\n\\nMake sure to provide a description + for the SQL operation (e.g., \\\"Insert test todo\\\").\",\"mode\":\"sync\"}" + - role: tool + tool_call_id: toolcall_0 + content: Intent logged + - role: tool + tool_call_id: toolcall_1 + content: >- + ✅ Successfully inserted the test todo into the database. The query executed exactly as requested, inserting a + todo with: + + - ID: `subagent-test` + + - Title: `Created by subagent` + + - Status: `done` + - role: assistant + content: Perfect! The task agent successfully ran the SQL insert query. The todo record with ID `subagent-test`, title + `Created by subagent`, and status `done` has been inserted into the todos table. diff --git a/test/snapshots/session_fs_sqlite/should_route_sql_queries_through_the_sessionfs_sqlite_handler.yaml b/test/snapshots/session_fs_sqlite/should_route_sql_queries_through_the_sessionfs_sqlite_handler.yaml new file mode 100644 index 000000000..4d9f8f6c8 --- /dev/null +++ b/test/snapshots/session_fs_sqlite/should_route_sql_queries_through_the_sessionfs_sqlite_handler.yaml @@ -0,0 +1,68 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: Use the sql tool to create a table called "items" with columns id (TEXT PRIMARY KEY) and name (TEXT). Then + insert a row with id "a1" and name "Widget". + - role: assistant + content: I'll create the table and insert the row using the SQL tool. + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Creating database table"}' + - role: assistant + tool_calls: + - id: toolcall_1 + type: function + function: + name: sql + arguments: '{"description":"Create items table","query":"CREATE TABLE items (id TEXT PRIMARY KEY, name TEXT)"}' + - role: assistant + tool_calls: + - id: toolcall_2 + type: function + function: + name: sql + arguments: "{\"description\":\"Insert Widget row\",\"query\":\"INSERT INTO items (id, name) VALUES ('a1', 'Widget')\"}" + - messages: + - role: system + content: ${system} + - role: user + content: Use the sql tool to create a table called "items" with columns id (TEXT PRIMARY KEY) and name (TEXT). Then + insert a row with id "a1" and name "Widget". + - role: assistant + content: I'll create the table and insert the row using the SQL tool. + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Creating database table"}' + - id: toolcall_1 + type: function + function: + name: sql + arguments: '{"description":"Create items table","query":"CREATE TABLE items (id TEXT PRIMARY KEY, name TEXT)"}' + - id: toolcall_2 + type: function + function: + name: sql + arguments: "{\"description\":\"Insert Widget row\",\"query\":\"INSERT INTO items (id, name) VALUES ('a1', 'Widget')\"}" + - role: tool + tool_call_id: toolcall_0 + content: Intent logged + - role: tool + tool_call_id: toolcall_1 + content: Schema operation completed successfully. + - role: tool + tool_call_id: toolcall_2 + content: "1 row(s) inserted. Last inserted row ID: 1." + - role: assistant + content: Done! I've created the `items` table with `id` and `name` columns, and inserted the row with id "a1" and name + "Widget".