From 52858ec1c791821d5b91685e819da2237a482931 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 29 Jan 2026 10:38:40 -0800 Subject: [PATCH] fix: Fixing a regression in InMemorySessionService PiperOrigin-RevId: 862801950 --- .../adk/sessions/InMemorySessionService.java | 2 +- .../com/google/adk/agents/LlmAgentTest.java | 55 +++++++++++++ .../sessions/InMemorySessionServiceTest.java | 80 ++++++++++++++----- .../com/google/adk/testing/TestUtils.java | 15 ++++ 4 files changed, 129 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b658f676..060fcaf6 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -259,7 +259,7 @@ public Single appendEvent(Session session, Event event) { .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(userStateKey, value); } - } else if (!key.startsWith(State.TEMP_PREFIX)) { + } else { if (value == State.REMOVED) { session.state().remove(key); } else { diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 519c9055..494145a9 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -34,8 +34,11 @@ import com.google.adk.agents.Callbacks.OnToolErrorCallback; import com.google.adk.events.Event; import com.google.adk.models.LlmRegistry; +import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.models.Model; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; @@ -49,6 +52,7 @@ import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -372,4 +376,55 @@ public void canonicalCallbacks_returnsListWhenPresent() { assertThat(agent.canonicalAfterToolCallbacks()).containsExactly(atc); assertThat(agent.canonicalOnToolErrorCallbacks()).containsExactly(otec); } + + @Test + public void run_sequentialAgents_shareTempStateViaSession() { + // 1. Setup Session Service and Session + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + // 2. Agent 1: runs and produces output "value1" to state "temp:key1" + Content model1Content = Content.fromParts(Part.fromText("value1")); + TestLlm testLlm1 = createTestLlm(createLlmResponse(model1Content)); + LlmAgent agent1 = + createTestAgentBuilder(testLlm1).name("agent1").outputKey("temp:key1").build(); + InvocationContext invocationContext1 = createInvocationContext(agent1, sessionService, session); + + List events1 = agent1.runAsync(invocationContext1).toList().blockingGet(); + assertThat(events1).hasSize(1); + Event event1 = events1.get(0); + assertThat(event1.actions()).isNotNull(); + assertThat(event1.actions().stateDelta()).containsEntry("temp:key1", "value1"); + + // 3. Simulate orchestrator: append event1 to session, updating its state + var unused = sessionService.appendEvent(session, event1).blockingGet(); + assertThat(session.state()).containsEntry("temp:key1", "value1"); + + // 4. Agent 2: uses Instruction.Provider to read "temp:key1" from session state + // and generates an instruction based on it. + TestLlm testLlm2 = + createTestLlm(createLlmResponse(Content.fromParts(Part.fromText("response2")))); + LlmAgent agent2 = + createTestAgentBuilder(testLlm2) + .name("agent2") + .instruction( + new Instruction.Provider( + ctx -> + Single.just( + "Instruction for Agent2 based on Agent1 output: " + + ctx.state().get("temp:key1")))) + .build(); + InvocationContext invocationContext2 = createInvocationContext(agent2, sessionService, session); + List events2 = agent2.runAsync(invocationContext2).toList().blockingGet(); + assertThat(events2).hasSize(1); + + // 5. Verify that agent2's LLM received an instruction containing agent1's output + assertThat(testLlm2.getRequests()).hasSize(1); + LlmRequest request2 = testLlm2.getRequests().get(0); + assertThat(request2.getFirstSystemInstruction().get()) + .contains("Instruction for Agent2 based on Agent1 output: value1"); + } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 97b18224..6223dd2f 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -89,8 +89,8 @@ public void lifecycle_listSessions() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("app:appKey", "appValue"); - stateDelta.put("user:userKey", "userValue"); + stateDelta.put("_app_appKey", "appValue"); + stateDelta.put("_user_userKey", "userValue"); stateDelta.put("temp:tempKey", "tempValue"); Event event = @@ -106,9 +106,9 @@ public void lifecycle_listSessions() { assertThat(listedSession.id()).isEqualTo(session.id()); assertThat(listedSession.events()).isEmpty(); assertThat(listedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(listedSession.state()).containsEntry("app:appKey", "appValue"); - assertThat(listedSession.state()).containsEntry("user:userKey", "userValue"); - assertThat(listedSession.state()).doesNotContainKey("temp:tempKey"); + assertThat(listedSession.state()).containsEntry("_app_appKey", "appValue"); + assertThat(listedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(listedSession.state()).containsEntry("temp:tempKey", "tempValue"); } @Test @@ -136,8 +136,8 @@ public void appendEvent_updatesSessionState() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("app:appKey", "appValue"); - stateDelta.put("user:userKey", "userValue"); + stateDelta.put("_app_appKey", "appValue"); + stateDelta.put("_user_userKey", "userValue"); stateDelta.put("temp:tempKey", "tempValue"); Event event = @@ -148,9 +148,9 @@ public void appendEvent_updatesSessionState() { // After appendEvent, session state in memory should contain session-specific state from delta // and merged global state. assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).containsEntry("app:appKey", "appValue"); - assertThat(session.state()).containsEntry("user:userKey", "userValue"); - assertThat(session.state()).doesNotContainKey("temp:tempKey"); + assertThat(session.state()).containsEntry("_app_appKey", "appValue"); + assertThat(session.state()).containsEntry("_user_userKey", "userValue"); + assertThat(session.state()).containsEntry("temp:tempKey", "tempValue"); // getSession should return session with merged state. Session retrievedSession = @@ -158,9 +158,9 @@ public void appendEvent_updatesSessionState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(retrievedSession.state()).containsEntry("app:appKey", "appValue"); - assertThat(retrievedSession.state()).containsEntry("user:userKey", "userValue"); - assertThat(retrievedSession.state()).doesNotContainKey("temp:tempKey"); + assertThat(retrievedSession.state()).containsEntry("_app_appKey", "appValue"); + assertThat(retrievedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSession.state()).containsEntry("temp:tempKey", "tempValue"); } @Test @@ -173,8 +173,8 @@ public void appendEvent_removesState() { ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); stateDeltaAdd.put("sessionKey", "sessionValue"); - stateDeltaAdd.put("app:appKey", "appValue"); - stateDeltaAdd.put("user:userKey", "userValue"); + stateDeltaAdd.put("_app_appKey", "appValue"); + stateDeltaAdd.put("_user_userKey", "userValue"); stateDeltaAdd.put("temp:tempKey", "tempValue"); Event eventAdd = @@ -188,15 +188,15 @@ public void appendEvent_removesState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(retrievedSessionAdd.state()).containsEntry("app:appKey", "appValue"); - assertThat(retrievedSessionAdd.state()).containsEntry("user:userKey", "userValue"); - assertThat(retrievedSessionAdd.state()).doesNotContainKey("temp:tempKey"); + assertThat(retrievedSessionAdd.state()).containsEntry("_app_appKey", "appValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("temp:tempKey", "tempValue"); // Prepare and append event to remove state ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); stateDeltaRemove.put("sessionKey", State.REMOVED); - stateDeltaRemove.put("app:appKey", State.REMOVED); - stateDeltaRemove.put("user:userKey", State.REMOVED); + stateDeltaRemove.put("_app_appKey", State.REMOVED); + stateDeltaRemove.put("_user_userKey", State.REMOVED); stateDeltaRemove.put("temp:tempKey", State.REMOVED); Event eventRemove = @@ -212,8 +212,44 @@ public void appendEvent_removesState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey"); - assertThat(retrievedSessionRemove.state()).doesNotContainKey("app:appKey"); - assertThat(retrievedSessionRemove.state()).doesNotContainKey("user:userKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_app_appKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("_user_userKey"); assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + + @Test + public void sequentialAgents_shareTempState() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + // Agent 1 writes to temp state + ConcurrentMap stateDelta1 = new ConcurrentHashMap<>(); + stateDelta1.put("temp:agent1_output", "data"); + Event event1 = + Event.builder().actions(EventActions.builder().stateDelta(stateDelta1).build()).build(); + var unused = sessionService.appendEvent(session, event1).blockingGet(); + + // Verify agent 1 output is in session state + assertThat(session.state()).containsEntry("temp:agent1_output", "data"); + + // Agent 2 reads "agent1_output", processes it, writes "agent2_output", and removes + // "agent1_output" + ConcurrentMap stateDelta2 = new ConcurrentHashMap<>(); + stateDelta2.put("temp:agent2_output", "processed_data"); + stateDelta2.put("temp:agent1_output", State.REMOVED); + Event event2 = + Event.builder().actions(EventActions.builder().stateDelta(stateDelta2).build()).build(); + unused = sessionService.appendEvent(session, event2).blockingGet(); + + // Verify final state after agent 2 processing + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output"); + assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data"); + } } diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index 2bdcf1fb..df94b76b 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -30,7 +30,9 @@ import com.google.adk.events.EventCompaction; import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; @@ -68,6 +70,19 @@ public static InvocationContext createInvocationContext(BaseAgent agent) { return createInvocationContext(agent, RunConfig.builder().build()); } + public static InvocationContext createInvocationContext( + BaseAgent agent, BaseSessionService sessionService, Session session) { + return InvocationContext.builder() + .sessionService(sessionService) + .artifactService(new InMemoryArtifactService()) + .invocationId("invocationId") + .agent(agent) + .session(session) + .userContent(Content.fromParts(Part.fromText("user content"))) + .runConfig(RunConfig.builder().build()) + .build(); + } + public static Event createEvent(String id) { return Event.builder() .id(id)