diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 63909ee1a..493fa4b27 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.adk.agents.BaseAgentState; +import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Part; import java.util.Objects; @@ -98,10 +99,20 @@ public ConcurrentMap stateDelta() { return stateDelta; } + @Deprecated // Use stateDelta(), addState() and removeStateByKey() instead. public void setStateDelta(ConcurrentMap stateDelta) { this.stateDelta = stateDelta; } + /** + * Removes a key from the state delta. + * + * @param key The key to remove. + */ + public void removeStateByKey(String key) { + stateDelta.put(key, State.REMOVED); + } + @JsonProperty("artifactDelta") public ConcurrentMap artifactDelta() { return artifactDelta; diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 80c277fce..b658f6767 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -96,8 +96,8 @@ public Single createSession( .build(); sessions - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(resolvedSessionId, newSession); // Create a mutable copy for the return value @@ -116,8 +116,8 @@ public Maybe getSession( Session storedSession = sessions - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .get(sessionId); if (storedSession == null) { @@ -166,7 +166,7 @@ public Single listSessions(String appName, String userId) Objects.requireNonNull(userId, "userId cannot be null"); Map userSessionsMap = - sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId); + sessions.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()).get(userId); if (userSessionsMap == null || userSessionsMap.isEmpty()) { return Single.just(ListSessionsResponse.builder().build()); @@ -185,11 +185,12 @@ public Completable deleteSession(String appName, String userId, String sessionId Objects.requireNonNull(userId, "userId cannot be null"); Objects.requireNonNull(sessionId, "sessionId cannot be null"); - ConcurrentMap userSessionsMap = - sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId); - - if (userSessionsMap != null) { - userSessionsMap.remove(sessionId); + ConcurrentMap> appSessionsMap = sessions.get(appName); + if (appSessionsMap != null) { + ConcurrentMap userSessionsMap = appSessionsMap.get(userId); + if (userSessionsMap != null) { + userSessionsMap.remove(sessionId); + } } return Completable.complete(); } @@ -202,8 +203,8 @@ public Single listEvents(String appName, String userId, Stri Session storedSession = sessions - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .get(sessionId); if (storedSession == null) { @@ -236,17 +237,34 @@ public Single appendEvent(Session session, Event event) { (key, value) -> { if (key.startsWith(State.APP_PREFIX)) { String appStateKey = key.substring(State.APP_PREFIX.length()); - appState - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .put(appStateKey, value); + if (value == State.REMOVED) { + appState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .remove(appStateKey); + } else { + appState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .put(appStateKey, value); + } } else if (key.startsWith(State.USER_PREFIX)) { String userStateKey = key.substring(State.USER_PREFIX.length()); - userState - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) - .put(userStateKey, value); - } else { - session.state().put(key, value); + if (value == State.REMOVED) { + userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + .remove(userStateKey); + } else { + userState + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) + .put(userStateKey, value); + } + } else if (!key.startsWith(State.TEMP_PREFIX)) { + if (value == State.REMOVED) { + session.state().remove(key); + } else { + session.state().put(key, value); + } } }); } @@ -257,8 +275,8 @@ public Single appendEvent(Session session, Event event) { // --- Update the session stored in this service --- sessions - .computeIfAbsent(appName, k -> new ConcurrentHashMap<>()) - .computeIfAbsent(userId, k -> new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .put(sessionId, session); mergeWithGlobalState(appName, userId, session); @@ -307,12 +325,12 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess // Merge App State directly into the session's state map appState - .getOrDefault(appName, new ConcurrentHashMap()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) .forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value)); userState - .getOrDefault(appName, new ConcurrentHashMap<>()) - .getOrDefault(userId, new ConcurrentHashMap<>()) + .computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()) + .computeIfAbsent(userId, unused -> new ConcurrentHashMap<>()) .forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value)); return session; diff --git a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java index 5dbbe76c7..d1a661a91 100644 --- a/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java +++ b/core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java @@ -91,7 +91,7 @@ static String convertEventToJson(Event event) { if (event.actions() != null) { Map actionsJson = new HashMap<>(); actionsJson.put("skipSummarization", event.actions().skipSummarization()); - actionsJson.put("stateDelta", event.actions().stateDelta()); + actionsJson.put("stateDelta", stateDeltaToJson(event.actions().stateDelta())); actionsJson.put("artifactDelta", event.actions().artifactDelta()); actionsJson.put("transferAgent", event.actions().transferToAgent()); actionsJson.put("escalate", event.actions().escalate()); @@ -126,8 +126,7 @@ static String convertEventToJson(Event event) { * @return parsed {@link Content}, or {@code null} if conversion fails. */ @Nullable - // Safe because we check instanceof Map before casting. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting. private static Content convertMapToContent(Object rawContentValue) { if (rawContentValue == null) { return null; @@ -153,8 +152,7 @@ private static Content convertMapToContent(Object rawContentValue) { * * @return parsed {@link Event}. */ - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. static Event fromApiEvent(Map apiEvent) { EventActions.Builder eventActionsBuilder = EventActions.builder(); if (apiEvent.get("actions") != null) { @@ -162,10 +160,7 @@ static Event fromApiEvent(Map apiEvent) { if (actionsMap.get("skipSummarization") != null) { eventActionsBuilder.skipSummarization((Boolean) actionsMap.get("skipSummarization")); } - eventActionsBuilder.stateDelta( - actionsMap.get("stateDelta") != null - ? new ConcurrentHashMap<>((Map) actionsMap.get("stateDelta")) - : new ConcurrentHashMap<>()); + eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta"))); eventActionsBuilder.artifactDelta( actionsMap.get("artifactDelta") != null ? convertToArtifactDeltaMap(actionsMap.get("artifactDelta")) @@ -238,6 +233,32 @@ static Event fromApiEvent(Map apiEvent) { return event; } + @SuppressWarnings("unchecked") // stateDeltaFromMap is a Map from JSON. + private static ConcurrentMap stateDeltaFromJson(Object stateDeltaFromMap) { + if (stateDeltaFromMap == null) { + return new ConcurrentHashMap<>(); + } + return ((Map) stateDeltaFromMap) + .entrySet().stream() + .collect( + ConcurrentHashMap::new, + (map, entry) -> + map.put( + entry.getKey(), + entry.getValue() == null ? State.REMOVED : entry.getValue()), + ConcurrentHashMap::putAll); + } + + private static Map stateDeltaToJson(Map stateDelta) { + return stateDelta.entrySet().stream() + .collect( + HashMap::new, + (map, entry) -> + map.put( + entry.getKey(), entry.getValue() == State.REMOVED ? null : entry.getValue()), + HashMap::putAll); + } + /** * Converts a timestamp from a Map or String into an {@link Instant}. * @@ -263,8 +284,7 @@ private static Instant convertToInstant(Object timestampObj) { * @param artifactDeltaObj The raw object from which to parse the artifact delta. * @return A {@link ConcurrentMap} representing the artifact delta. */ - // Safe because we check instanceof Map before casting. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting. private static ConcurrentMap convertToArtifactDeltaMap(Object artifactDeltaObj) { if (!(artifactDeltaObj instanceof Map)) { return new ConcurrentHashMap<>(); @@ -287,8 +307,7 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti * * @return thread-safe nested map. */ - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap> asConcurrentMapOfConcurrentMaps(Object value) { return ((Map>) value) @@ -299,8 +318,7 @@ private static ConcurrentMap convertToArtifactDeltaMap(Object arti ConcurrentHashMap::putAll); } - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap asConcurrentMapOfAgentState(Object value) { return ((Map) value) .entrySet().stream() @@ -313,8 +331,7 @@ private static ConcurrentMap asConcurrentMapOfAgentState ConcurrentHashMap::putAll); } - // Safe because we are parsing from a raw Map structure that follows a known schema. - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema. private static ConcurrentMap asConcurrentMapOfToolConfirmations( Object value) { return ((Map) value) diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index a531361f2..1a8dbc527 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -28,6 +28,7 @@ import com.google.adk.events.Event; import com.google.adk.runner.InMemoryRunner; import com.google.adk.runner.Runner; +import com.google.adk.sessions.State; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -154,7 +155,7 @@ public Single> runAsync(Map args, ToolContex if (lastEvent.actions() != null && lastEvent.actions().stateDelta() != null && !lastEvent.actions().stateDelta().isEmpty()) { - toolContext.state().putAll(lastEvent.actions().stateDelta()); + updateState(lastEvent.actions().stateDelta(), toolContext.state()); } if (outputText.isEmpty()) { @@ -174,4 +175,24 @@ public Single> runAsync(Map args, ToolContex } }); } + + /** + * Updates the given state map with the state delta. + * + *

If a value in the delta is {@link State#REMOVED}, the key is removed from the state map. + * Otherwise, the key-value pair is put into the state map. This method does not distinguish + * between session, app, and user state based on key prefixes. + * + * @param state The state map to update. + */ + private void updateState(Map stateDelta, Map state) { + stateDelta.forEach( + (key, value) -> { + if (value == State.REMOVED) { + state.remove(key); + } else { + state.put(key, value); + } + }); + } } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index a9e3693d5..18870ad44 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.Part; @@ -97,4 +98,13 @@ public void merge_mergesAllFields() { assertThat(merged.endInvocation()).hasValue(true); assertThat(merged.compaction()).hasValue(COMPACTION); } + + @Test + public void removeStateByKey_marksKeyAsRemoved() { + EventActions eventActions = new EventActions(); + eventActions.stateDelta().put("key1", "value1"); + eventActions.removeStateByKey("key1"); + + assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); + } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 4c35f5b90..97b182249 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -89,8 +89,9 @@ public void lifecycle_listSessions() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("_app_appKey", "appValue"); - stateDelta.put("_user_userKey", "userValue"); + stateDelta.put("app:appKey", "appValue"); + stateDelta.put("user:userKey", "userValue"); + stateDelta.put("temp:tempKey", "tempValue"); Event event = Event.builder().actions(EventActions.builder().stateDelta(stateDelta).build()).build(); @@ -105,8 +106,9 @@ public void lifecycle_listSessions() { assertThat(listedSession.id()).isEqualTo(session.id()); assertThat(listedSession.events()).isEmpty(); assertThat(listedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(listedSession.state()).containsEntry("_app_appKey", "appValue"); - assertThat(listedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(listedSession.state()).containsEntry("app:appKey", "appValue"); + assertThat(listedSession.state()).containsEntry("user:userKey", "userValue"); + assertThat(listedSession.state()).doesNotContainKey("temp:tempKey"); } @Test @@ -134,8 +136,9 @@ public void appendEvent_updatesSessionState() { ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); - stateDelta.put("_app_appKey", "appValue"); - stateDelta.put("_user_userKey", "userValue"); + stateDelta.put("app:appKey", "appValue"); + stateDelta.put("user:userKey", "userValue"); + stateDelta.put("temp:tempKey", "tempValue"); Event event = Event.builder().actions(EventActions.builder().stateDelta(stateDelta).build()).build(); @@ -145,8 +148,9 @@ public void appendEvent_updatesSessionState() { // After appendEvent, session state in memory should contain session-specific state from delta // and merged global state. assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).containsEntry("_app_appKey", "appValue"); - assertThat(session.state()).containsEntry("_user_userKey", "userValue"); + assertThat(session.state()).containsEntry("app:appKey", "appValue"); + assertThat(session.state()).containsEntry("user:userKey", "userValue"); + assertThat(session.state()).doesNotContainKey("temp:tempKey"); // getSession should return session with merged state. Session retrievedSession = @@ -154,7 +158,62 @@ public void appendEvent_updatesSessionState() { .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) .blockingGet(); assertThat(retrievedSession.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(retrievedSession.state()).containsEntry("_app_appKey", "appValue"); - assertThat(retrievedSession.state()).containsEntry("_user_userKey", "userValue"); + assertThat(retrievedSession.state()).containsEntry("app:appKey", "appValue"); + assertThat(retrievedSession.state()).containsEntry("user:userKey", "userValue"); + assertThat(retrievedSession.state()).doesNotContainKey("temp:tempKey"); + } + + @Test + public void appendEvent_removesState() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService + .createSession("app", "user", new ConcurrentHashMap<>(), "session1") + .blockingGet(); + + ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); + stateDeltaAdd.put("sessionKey", "sessionValue"); + stateDeltaAdd.put("app:appKey", "appValue"); + stateDeltaAdd.put("user:userKey", "userValue"); + stateDeltaAdd.put("temp:tempKey", "tempValue"); + + Event eventAdd = + Event.builder().actions(EventActions.builder().stateDelta(stateDeltaAdd).build()).build(); + + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify state is added + Session retrievedSessionAdd = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("app:appKey", "appValue"); + assertThat(retrievedSessionAdd.state()).containsEntry("user:userKey", "userValue"); + assertThat(retrievedSessionAdd.state()).doesNotContainKey("temp:tempKey"); + + // Prepare and append event to remove state + ConcurrentMap stateDeltaRemove = new ConcurrentHashMap<>(); + stateDeltaRemove.put("sessionKey", State.REMOVED); + stateDeltaRemove.put("app:appKey", State.REMOVED); + stateDeltaRemove.put("user:userKey", State.REMOVED); + stateDeltaRemove.put("temp:tempKey", State.REMOVED); + + Event eventRemove = + Event.builder() + .actions(EventActions.builder().stateDelta(stateDeltaRemove).build()) + .build(); + + unused = sessionService.appendEvent(session, eventRemove).blockingGet(); + + // Verify state is removed + Session retrievedSessionRemove = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("app:appKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("user:userKey"); + assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } } diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 5e8f3d992..111b1dce3 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -8,6 +8,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -164,6 +165,18 @@ private ApiResponse handleAppendEvent(String path, InvocationOnMock invocation) eventsData.add(newEventData); eventMap.put(sessionId, mapper.writeValueAsString(eventsData)); + + // Apply stateDelta to session state + extractObjectMap(newEventData, "actions") + .flatMap(actions -> extractObjectMap(actions, "stateDelta")) + .ifPresent( + stateDelta -> { + try { + applyStateDelta(sessionId, stateDelta); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); } catch (Exception e) { throw new RuntimeException(e); } @@ -213,4 +226,30 @@ private ApiResponse handleDeleteSession(String path) { sessionMap.remove(sessionIdToDelete); return responseWithBody(""); } + + private void applyStateDelta(String sessionId, Map stateDelta) throws Exception { + String sessionDataString = sessionMap.get(sessionId); + if (sessionDataString == null) { + return; + } + Map sessionData = + mapper.readValue(sessionDataString, new TypeReference>() {}); + Map sessionState = + extractObjectMap(sessionData, "sessionState").map(HashMap::new).orElseGet(HashMap::new); + + for (Map.Entry entry : stateDelta.entrySet()) { + if (entry.getValue() == null) { + sessionState.remove(entry.getKey()); + } else { + sessionState.put(entry.getKey(), entry.getValue()); + } + } + sessionData.put("sessionState", sessionState); + sessionMap.put(sessionId, mapper.writeValueAsString(sessionData)); + } + + @SuppressWarnings("unchecked") // Safe because map values are Maps read from JSON. + private Optional> extractObjectMap(Map map, String key) { + return Optional.ofNullable((Map) map.get(key)); + } } diff --git a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java index b77d6f267..827e810aa 100644 --- a/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java +++ b/core/src/test/java/com/google/adk/sessions/SessionJsonConverterTest.java @@ -221,4 +221,51 @@ public void fromApiEvent_missingMetadataFields_success() { assertThat(event.turnComplete().get()).isFalse(); assertThat(event.interrupted().get()).isFalse(); } + + @Test + public void convertEventToJson_withStateRemoved_success() throws JsonProcessingException { + EventActions actions = + EventActions.builder() + .stateDelta( + new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", State.REMOVED))) + .build(); + + Event event = + Event.builder() + .author("user") + .invocationId("inv-123") + .timestamp(Instant.parse("2023-01-01T00:00:00Z").toEpochMilli()) + .actions(actions) + .build(); + + String json = SessionJsonConverter.convertEventToJson(event); + JsonNode jsonNode = objectMapper.readTree(json); + + JsonNode actionsNode = jsonNode.get("actions"); + assertThat(actionsNode.get("stateDelta").get("key1").asText()).isEqualTo("value1"); + assertThat(actionsNode.get("stateDelta").get("key2").isNull()).isTrue(); + } + + @Test + public void fromApiEvent_withNullStateDeltaValue_success() { + Map apiEvent = new HashMap<>(); + apiEvent.put("name", "sessions/123/events/456"); + apiEvent.put("invocationId", "inv-123"); + apiEvent.put("author", "model"); + apiEvent.put("timestamp", "2023-01-01T00:00:00Z"); + + Map stateDelta = new HashMap<>(); + stateDelta.put("key1", "value1"); + stateDelta.put("key2", null); + + Map actions = new HashMap<>(); + actions.put("stateDelta", stateDelta); + apiEvent.put("actions", actions); + + Event event = SessionJsonConverter.fromApiEvent(apiEvent); + + EventActions eventActions = event.actions(); + assertThat(eventActions.stateDelta()).containsEntry("key1", "value1"); + assertThat(eventActions.stateDelta()).containsEntry("key2", State.REMOVED); + } } diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 775b465ff..36eab1d16 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -337,4 +337,32 @@ public void listEmptySession_success() { .events()) .isEmpty(); } + + @Test + public void appendEvent_withStateRemoved_updatesSessionState() { + String userId = "userB"; + ConcurrentMap initialState = + new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); + Session session = + vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet(); + + ConcurrentMap stateDelta = + new ConcurrentHashMap<>(ImmutableMap.of("key2", State.REMOVED)); + Event event = + Event.builder() + .invocationId("456") + .author(userId) + .timestamp(Instant.parse("2024-12-12T12:12:12.123456Z").toEpochMilli()) + .actions(EventActions.builder().stateDelta(stateDelta).build()) + .build(); + var unused = vertexAiSessionService.appendEvent(session, event).blockingGet(); + + Session updatedSession = + vertexAiSessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + + assertThat(updatedSession.state()).containsExactly("key1", "value1"); + assertThat(updatedSession.state()).doesNotContainKey("key2"); + } }