Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions core/src/main/java/com/google/adk/events/EventActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.adk.agents.BaseAgentState;
import com.google.adk.sessions.State;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.Part;
import java.util.Objects;
Expand Down Expand Up @@ -98,10 +99,20 @@ public ConcurrentMap<String, Object> stateDelta() {
return stateDelta;
}

@Deprecated // Use stateDelta(), addState() and removeStateByKey() instead.
public void setStateDelta(ConcurrentMap<String, Object> stateDelta) {
this.stateDelta = stateDelta;
}

/**
* Removes a key from the state delta.
*
* @param key The key to remove.
*/
public void removeStateByKey(String key) {
stateDelta.put(key, State.REMOVED);
}

@JsonProperty("artifactDelta")
public ConcurrentMap<String, Part> artifactDelta() {
return artifactDelta;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ public Single<Session> createSession(
.build();

sessions
.computeIfAbsent(appName, k -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, k -> new ConcurrentHashMap<>())
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.put(resolvedSessionId, newSession);

// Create a mutable copy for the return value
Expand All @@ -116,8 +116,8 @@ public Maybe<Session> getSession(

Session storedSession =
sessions
.getOrDefault(appName, new ConcurrentHashMap<>())
.getOrDefault(userId, new ConcurrentHashMap<>())
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.get(sessionId);

if (storedSession == null) {
Expand Down Expand Up @@ -166,7 +166,7 @@ public Single<ListSessionsResponse> listSessions(String appName, String userId)
Objects.requireNonNull(userId, "userId cannot be null");

Map<String, Session> userSessionsMap =
sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId);
sessions.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>()).get(userId);

if (userSessionsMap == null || userSessionsMap.isEmpty()) {
return Single.just(ListSessionsResponse.builder().build());
Expand All @@ -185,11 +185,12 @@ public Completable deleteSession(String appName, String userId, String sessionId
Objects.requireNonNull(userId, "userId cannot be null");
Objects.requireNonNull(sessionId, "sessionId cannot be null");

ConcurrentMap<String, Session> userSessionsMap =
sessions.getOrDefault(appName, new ConcurrentHashMap<>()).get(userId);

if (userSessionsMap != null) {
userSessionsMap.remove(sessionId);
ConcurrentMap<String, ConcurrentMap<String, Session>> appSessionsMap = sessions.get(appName);
if (appSessionsMap != null) {
ConcurrentMap<String, Session> userSessionsMap = appSessionsMap.get(userId);
if (userSessionsMap != null) {
userSessionsMap.remove(sessionId);
}
}
return Completable.complete();
}
Expand All @@ -202,8 +203,8 @@ public Single<ListEventsResponse> listEvents(String appName, String userId, Stri

Session storedSession =
sessions
.getOrDefault(appName, new ConcurrentHashMap<>())
.getOrDefault(userId, new ConcurrentHashMap<>())
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.get(sessionId);

if (storedSession == null) {
Expand Down Expand Up @@ -236,17 +237,34 @@ public Single<Event> appendEvent(Session session, Event event) {
(key, value) -> {
if (key.startsWith(State.APP_PREFIX)) {
String appStateKey = key.substring(State.APP_PREFIX.length());
appState
.computeIfAbsent(appName, k -> new ConcurrentHashMap<>())
.put(appStateKey, value);
if (value == State.REMOVED) {
appState
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.remove(appStateKey);
} else {
appState
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.put(appStateKey, value);
}
} else if (key.startsWith(State.USER_PREFIX)) {
String userStateKey = key.substring(State.USER_PREFIX.length());
userState
.computeIfAbsent(appName, k -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, k -> new ConcurrentHashMap<>())
.put(userStateKey, value);
} else {
session.state().put(key, value);
if (value == State.REMOVED) {
userState
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.remove(userStateKey);
} else {
userState
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.put(userStateKey, value);
}
} else if (!key.startsWith(State.TEMP_PREFIX)) {
if (value == State.REMOVED) {
session.state().remove(key);
} else {
session.state().put(key, value);
}
}
});
}
Expand All @@ -257,8 +275,8 @@ public Single<Event> appendEvent(Session session, Event event) {

// --- Update the session stored in this service ---
sessions
.computeIfAbsent(appName, k -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, k -> new ConcurrentHashMap<>())
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.put(sessionId, session);

mergeWithGlobalState(appName, userId, session);
Expand Down Expand Up @@ -307,12 +325,12 @@ private Session mergeWithGlobalState(String appName, String userId, Session sess

// Merge App State directly into the session's state map
appState
.getOrDefault(appName, new ConcurrentHashMap<String, Object>())
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.forEach((key, value) -> sessionState.put(State.APP_PREFIX + key, value));

userState
.getOrDefault(appName, new ConcurrentHashMap<>())
.getOrDefault(userId, new ConcurrentHashMap<>())
.computeIfAbsent(appName, unused -> new ConcurrentHashMap<>())
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
.forEach((key, value) -> sessionState.put(State.USER_PREFIX + key, value));

return session;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ static String convertEventToJson(Event event) {
if (event.actions() != null) {
Map<String, Object> actionsJson = new HashMap<>();
actionsJson.put("skipSummarization", event.actions().skipSummarization());
actionsJson.put("stateDelta", event.actions().stateDelta());
actionsJson.put("stateDelta", stateDeltaToJson(event.actions().stateDelta()));
actionsJson.put("artifactDelta", event.actions().artifactDelta());
actionsJson.put("transferAgent", event.actions().transferToAgent());
actionsJson.put("escalate", event.actions().escalate());
Expand Down Expand Up @@ -126,8 +126,7 @@ static String convertEventToJson(Event event) {
* @return parsed {@link Content}, or {@code null} if conversion fails.
*/
@Nullable
// Safe because we check instanceof Map before casting.
@SuppressWarnings("unchecked")
@SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting.
private static Content convertMapToContent(Object rawContentValue) {
if (rawContentValue == null) {
return null;
Expand All @@ -153,19 +152,15 @@ private static Content convertMapToContent(Object rawContentValue) {
*
* @return parsed {@link Event}.
*/
// Safe because we are parsing from a raw Map structure that follows a known schema.
@SuppressWarnings("unchecked")
@SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema.
static Event fromApiEvent(Map<String, Object> apiEvent) {
EventActions.Builder eventActionsBuilder = EventActions.builder();
if (apiEvent.get("actions") != null) {
Map<String, Object> actionsMap = (Map<String, Object>) apiEvent.get("actions");
if (actionsMap.get("skipSummarization") != null) {
eventActionsBuilder.skipSummarization((Boolean) actionsMap.get("skipSummarization"));
}
eventActionsBuilder.stateDelta(
actionsMap.get("stateDelta") != null
? new ConcurrentHashMap<>((Map<String, Object>) actionsMap.get("stateDelta"))
: new ConcurrentHashMap<>());
eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta")));
eventActionsBuilder.artifactDelta(
actionsMap.get("artifactDelta") != null
? convertToArtifactDeltaMap(actionsMap.get("artifactDelta"))
Expand Down Expand Up @@ -238,6 +233,32 @@ static Event fromApiEvent(Map<String, Object> apiEvent) {
return event;
}

@SuppressWarnings("unchecked") // stateDeltaFromMap is a Map<String, Object> from JSON.
private static ConcurrentMap<String, Object> stateDeltaFromJson(Object stateDeltaFromMap) {
if (stateDeltaFromMap == null) {
return new ConcurrentHashMap<>();
}
return ((Map<String, Object>) stateDeltaFromMap)
.entrySet().stream()
.collect(
ConcurrentHashMap::new,
(map, entry) ->
map.put(
entry.getKey(),
entry.getValue() == null ? State.REMOVED : entry.getValue()),
ConcurrentHashMap::putAll);
}

private static Map<String, Object> stateDeltaToJson(Map<String, Object> stateDelta) {
return stateDelta.entrySet().stream()
.collect(
HashMap::new,
(map, entry) ->
map.put(
entry.getKey(), entry.getValue() == State.REMOVED ? null : entry.getValue()),
HashMap::putAll);
}

/**
* Converts a timestamp from a Map or String into an {@link Instant}.
*
Expand All @@ -263,8 +284,7 @@ private static Instant convertToInstant(Object timestampObj) {
* @param artifactDeltaObj The raw object from which to parse the artifact delta.
* @return A {@link ConcurrentMap} representing the artifact delta.
*/
// Safe because we check instanceof Map before casting.
@SuppressWarnings("unchecked")
@SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting.
private static ConcurrentMap<String, Part> convertToArtifactDeltaMap(Object artifactDeltaObj) {
if (!(artifactDeltaObj instanceof Map)) {
return new ConcurrentHashMap<>();
Expand All @@ -287,8 +307,7 @@ private static ConcurrentMap<String, Part> convertToArtifactDeltaMap(Object arti
*
* @return thread-safe nested map.
*/
// Safe because we are parsing from a raw Map structure that follows a known schema.
@SuppressWarnings("unchecked")
@SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema.
private static ConcurrentMap<String, ConcurrentMap<String, Object>>
asConcurrentMapOfConcurrentMaps(Object value) {
return ((Map<String, Map<String, Object>>) value)
Expand All @@ -299,8 +318,7 @@ private static ConcurrentMap<String, Part> convertToArtifactDeltaMap(Object arti
ConcurrentHashMap::putAll);
}

// Safe because we are parsing from a raw Map structure that follows a known schema.
@SuppressWarnings("unchecked")
@SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema.
private static ConcurrentMap<String, BaseAgentState> asConcurrentMapOfAgentState(Object value) {
return ((Map<String, Object>) value)
.entrySet().stream()
Expand All @@ -313,8 +331,7 @@ private static ConcurrentMap<String, BaseAgentState> asConcurrentMapOfAgentState
ConcurrentHashMap::putAll);
}

// Safe because we are parsing from a raw Map structure that follows a known schema.
@SuppressWarnings("unchecked")
@SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema.
private static ConcurrentMap<String, ToolConfirmation> asConcurrentMapOfToolConfirmations(
Object value) {
return ((Map<String, Object>) value)
Expand Down
23 changes: 22 additions & 1 deletion core/src/main/java/com/google/adk/tools/AgentTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.adk.events.Event;
import com.google.adk.runner.InMemoryRunner;
import com.google.adk.runner.Runner;
import com.google.adk.sessions.State;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -154,7 +155,7 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
if (lastEvent.actions() != null
&& lastEvent.actions().stateDelta() != null
&& !lastEvent.actions().stateDelta().isEmpty()) {
toolContext.state().putAll(lastEvent.actions().stateDelta());
updateState(lastEvent.actions().stateDelta(), toolContext.state());
}

if (outputText.isEmpty()) {
Expand All @@ -174,4 +175,24 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
}
});
}

/**
* Updates the given state map with the state delta.
*
* <p>If a value in the delta is {@link State#REMOVED}, the key is removed from the state map.
* Otherwise, the key-value pair is put into the state map. This method does not distinguish
* between session, app, and user state based on key prefixes.
*
* @param state The state map to update.
*/
private void updateState(Map<String, Object> stateDelta, Map<String, Object> state) {
stateDelta.forEach(
(key, value) -> {
if (value == State.REMOVED) {
state.remove(key);
} else {
state.put(key, value);
}
});
}
}
10 changes: 10 additions & 0 deletions core/src/test/java/com/google/adk/events/EventActionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static com.google.common.truth.Truth.assertThat;

import com.google.adk.sessions.State;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
import com.google.genai.types.Part;
Expand Down Expand Up @@ -97,4 +98,13 @@ public void merge_mergesAllFields() {
assertThat(merged.endInvocation()).hasValue(true);
assertThat(merged.compaction()).hasValue(COMPACTION);
}

@Test
public void removeStateByKey_marksKeyAsRemoved() {
EventActions eventActions = new EventActions();
eventActions.stateDelta().put("key1", "value1");
eventActions.removeStateByKey("key1");

assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED);
}
}
Loading