Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public Single<Event> appendEvent(Session session, Event event) {
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.put(userStateKey, value);
}
} else if (!key.startsWith(State.TEMP_PREFIX)) {
} else {
if (value == State.REMOVED) {
session.state().remove(key);
} else {
Expand Down
55 changes: 55 additions & 0 deletions core/src/test/java/com/google/adk/agents/LlmAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@
import com.google.adk.agents.Callbacks.OnToolErrorCallback;
import com.google.adk.events.Event;
import com.google.adk.models.LlmRegistry;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;
import com.google.adk.models.Model;
import com.google.adk.sessions.InMemorySessionService;
import com.google.adk.sessions.Session;
import com.google.adk.testing.TestLlm;
import com.google.adk.testing.TestUtils.EchoTool;
import com.google.adk.tools.BaseTool;
Expand All @@ -49,6 +52,7 @@
import io.reactivex.rxjava3.core.Single;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -372,4 +376,55 @@ public void canonicalCallbacks_returnsListWhenPresent() {
assertThat(agent.canonicalAfterToolCallbacks()).containsExactly(atc);
assertThat(agent.canonicalOnToolErrorCallbacks()).containsExactly(otec);
}

@Test
public void run_sequentialAgents_shareTempStateViaSession() {
// 1. Setup Session Service and Session
InMemorySessionService sessionService = new InMemorySessionService();
Session session =
sessionService
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
.blockingGet();

// 2. Agent 1: runs and produces output "value1" to state "temp:key1"
Content model1Content = Content.fromParts(Part.fromText("value1"));
TestLlm testLlm1 = createTestLlm(createLlmResponse(model1Content));
LlmAgent agent1 =
createTestAgentBuilder(testLlm1).name("agent1").outputKey("temp:key1").build();
InvocationContext invocationContext1 = createInvocationContext(agent1, sessionService, session);

List<Event> events1 = agent1.runAsync(invocationContext1).toList().blockingGet();
assertThat(events1).hasSize(1);
Event event1 = events1.get(0);
assertThat(event1.actions()).isNotNull();
assertThat(event1.actions().stateDelta()).containsEntry("temp:key1", "value1");

// 3. Simulate orchestrator: append event1 to session, updating its state
var unused = sessionService.appendEvent(session, event1).blockingGet();
assertThat(session.state()).containsEntry("temp:key1", "value1");

// 4. Agent 2: uses Instruction.Provider to read "temp:key1" from session state
// and generates an instruction based on it.
TestLlm testLlm2 =
createTestLlm(createLlmResponse(Content.fromParts(Part.fromText("response2"))));
LlmAgent agent2 =
createTestAgentBuilder(testLlm2)
.name("agent2")
.instruction(
new Instruction.Provider(
ctx ->
Single.just(
"Instruction for Agent2 based on Agent1 output: "
+ ctx.state().get("temp:key1"))))
.build();
InvocationContext invocationContext2 = createInvocationContext(agent2, sessionService, session);
List<Event> events2 = agent2.runAsync(invocationContext2).toList().blockingGet();
assertThat(events2).hasSize(1);

// 5. Verify that agent2's LLM received an instruction containing agent1's output
assertThat(testLlm2.getRequests()).hasSize(1);
LlmRequest request2 = testLlm2.getRequests().get(0);
assertThat(request2.getFirstSystemInstruction().get())
.contains("Instruction for Agent2 based on Agent1 output: value1");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ public void lifecycle_listSessions() {

ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
stateDelta.put("sessionKey", "sessionValue");
stateDelta.put("app:appKey", "appValue");
stateDelta.put("user:userKey", "userValue");
stateDelta.put("_app_appKey", "appValue");
stateDelta.put("_user_userKey", "userValue");
stateDelta.put("temp:tempKey", "tempValue");

Event event =
Expand All @@ -106,9 +106,9 @@ public void lifecycle_listSessions() {
assertThat(listedSession.id()).isEqualTo(session.id());
assertThat(listedSession.events()).isEmpty();
assertThat(listedSession.state()).containsEntry("sessionKey", "sessionValue");
assertThat(listedSession.state()).containsEntry("app:appKey", "appValue");
assertThat(listedSession.state()).containsEntry("user:userKey", "userValue");
assertThat(listedSession.state()).doesNotContainKey("temp:tempKey");
assertThat(listedSession.state()).containsEntry("_app_appKey", "appValue");
assertThat(listedSession.state()).containsEntry("_user_userKey", "userValue");
assertThat(listedSession.state()).containsEntry("temp:tempKey", "tempValue");
}

@Test
Expand Down Expand Up @@ -136,8 +136,8 @@ public void appendEvent_updatesSessionState() {

ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
stateDelta.put("sessionKey", "sessionValue");
stateDelta.put("app:appKey", "appValue");
stateDelta.put("user:userKey", "userValue");
stateDelta.put("_app_appKey", "appValue");
stateDelta.put("_user_userKey", "userValue");
stateDelta.put("temp:tempKey", "tempValue");

Event event =
Expand All @@ -148,19 +148,19 @@ public void appendEvent_updatesSessionState() {
// After appendEvent, session state in memory should contain session-specific state from delta
// and merged global state.
assertThat(session.state()).containsEntry("sessionKey", "sessionValue");
assertThat(session.state()).containsEntry("app:appKey", "appValue");
assertThat(session.state()).containsEntry("user:userKey", "userValue");
assertThat(session.state()).doesNotContainKey("temp:tempKey");
assertThat(session.state()).containsEntry("_app_appKey", "appValue");
assertThat(session.state()).containsEntry("_user_userKey", "userValue");
assertThat(session.state()).containsEntry("temp:tempKey", "tempValue");

// getSession should return session with merged state.
Session retrievedSession =
sessionService
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
.blockingGet();
assertThat(retrievedSession.state()).containsEntry("sessionKey", "sessionValue");
assertThat(retrievedSession.state()).containsEntry("app:appKey", "appValue");
assertThat(retrievedSession.state()).containsEntry("user:userKey", "userValue");
assertThat(retrievedSession.state()).doesNotContainKey("temp:tempKey");
assertThat(retrievedSession.state()).containsEntry("_app_appKey", "appValue");
assertThat(retrievedSession.state()).containsEntry("_user_userKey", "userValue");
assertThat(retrievedSession.state()).containsEntry("temp:tempKey", "tempValue");
}

@Test
Expand All @@ -173,8 +173,8 @@ public void appendEvent_removesState() {

ConcurrentMap<String, Object> stateDeltaAdd = new ConcurrentHashMap<>();
stateDeltaAdd.put("sessionKey", "sessionValue");
stateDeltaAdd.put("app:appKey", "appValue");
stateDeltaAdd.put("user:userKey", "userValue");
stateDeltaAdd.put("_app_appKey", "appValue");
stateDeltaAdd.put("_user_userKey", "userValue");
stateDeltaAdd.put("temp:tempKey", "tempValue");

Event eventAdd =
Expand All @@ -188,15 +188,15 @@ public void appendEvent_removesState() {
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
.blockingGet();
assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue");
assertThat(retrievedSessionAdd.state()).containsEntry("app:appKey", "appValue");
assertThat(retrievedSessionAdd.state()).containsEntry("user:userKey", "userValue");
assertThat(retrievedSessionAdd.state()).doesNotContainKey("temp:tempKey");
assertThat(retrievedSessionAdd.state()).containsEntry("_app_appKey", "appValue");
assertThat(retrievedSessionAdd.state()).containsEntry("_user_userKey", "userValue");
assertThat(retrievedSessionAdd.state()).containsEntry("temp:tempKey", "tempValue");

// Prepare and append event to remove state
ConcurrentMap<String, Object> stateDeltaRemove = new ConcurrentHashMap<>();
stateDeltaRemove.put("sessionKey", State.REMOVED);
stateDeltaRemove.put("app:appKey", State.REMOVED);
stateDeltaRemove.put("user:userKey", State.REMOVED);
stateDeltaRemove.put("_app_appKey", State.REMOVED);
stateDeltaRemove.put("_user_userKey", State.REMOVED);
stateDeltaRemove.put("temp:tempKey", State.REMOVED);

Event eventRemove =
Expand All @@ -212,8 +212,44 @@ public void appendEvent_removesState() {
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
.blockingGet();
assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey");
assertThat(retrievedSessionRemove.state()).doesNotContainKey("app:appKey");
assertThat(retrievedSessionRemove.state()).doesNotContainKey("user:userKey");
assertThat(retrievedSessionRemove.state()).doesNotContainKey("_app_appKey");
assertThat(retrievedSessionRemove.state()).doesNotContainKey("_user_userKey");
assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey");
}

@Test
public void sequentialAgents_shareTempState() {
InMemorySessionService sessionService = new InMemorySessionService();
Session session =
sessionService
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
.blockingGet();

// Agent 1 writes to temp state
ConcurrentMap<String, Object> stateDelta1 = new ConcurrentHashMap<>();
stateDelta1.put("temp:agent1_output", "data");
Event event1 =
Event.builder().actions(EventActions.builder().stateDelta(stateDelta1).build()).build();
var unused = sessionService.appendEvent(session, event1).blockingGet();

// Verify agent 1 output is in session state
assertThat(session.state()).containsEntry("temp:agent1_output", "data");

// Agent 2 reads "agent1_output", processes it, writes "agent2_output", and removes
// "agent1_output"
ConcurrentMap<String, Object> stateDelta2 = new ConcurrentHashMap<>();
stateDelta2.put("temp:agent2_output", "processed_data");
stateDelta2.put("temp:agent1_output", State.REMOVED);
Event event2 =
Event.builder().actions(EventActions.builder().stateDelta(stateDelta2).build()).build();
unused = sessionService.appendEvent(session, event2).blockingGet();

// Verify final state after agent 2 processing
Session retrievedSession =
sessionService
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
.blockingGet();
assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output");
assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data");
}
}
15 changes: 15 additions & 0 deletions core/src/test/java/com/google/adk/testing/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import com.google.adk.events.EventCompaction;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.LlmResponse;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.InMemorySessionService;
import com.google.adk.sessions.Session;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -68,6 +70,19 @@ public static InvocationContext createInvocationContext(BaseAgent agent) {
return createInvocationContext(agent, RunConfig.builder().build());
}

public static InvocationContext createInvocationContext(
BaseAgent agent, BaseSessionService sessionService, Session session) {
return InvocationContext.builder()
.sessionService(sessionService)
.artifactService(new InMemoryArtifactService())
.invocationId("invocationId")
.agent(agent)
.session(session)
.userContent(Content.fromParts(Part.fromText("user content")))
.runConfig(RunConfig.builder().build())
.build();
}

public static Event createEvent(String id) {
return Event.builder()
.id(id)
Expand Down