Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions core/src/main/java/com/google/adk/tools/AgentTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import io.reactivex.rxjava3.core.Single;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -83,15 +84,42 @@ BaseAgent getAgent() {
return agent;
}

private Optional<Schema> getInputSchema(BaseAgent agent) {
BaseAgent currentAgent = agent;
while (true) {
if (currentAgent instanceof LlmAgent llmAgent) {
return llmAgent.inputSchema();
}
List<? extends BaseAgent> subAgents = currentAgent.subAgents();
if (subAgents == null || subAgents.isEmpty()) {
return Optional.empty();
}
// For composite agents, check the first sub-agent.
currentAgent = subAgents.get(0);
}
}

private Optional<Schema> getOutputSchema(BaseAgent agent) {
BaseAgent currentAgent = agent;
while (true) {
if (currentAgent instanceof LlmAgent llmAgent) {
return llmAgent.outputSchema();
}
List<? extends BaseAgent> subAgents = currentAgent.subAgents();
if (subAgents == null || subAgents.isEmpty()) {
return Optional.empty();
}
// For composite agents, check the last sub-agent.
currentAgent = subAgents.get(subAgents.size() - 1);
}
}

@Override
public Optional<FunctionDeclaration> declaration() {
FunctionDeclaration.Builder builder =
FunctionDeclaration.builder().description(this.description()).name(this.name());

Optional<Schema> agentInputSchema = Optional.empty();
if (agent instanceof LlmAgent llmAgent) {
agentInputSchema = llmAgent.inputSchema();
}
Optional<Schema> agentInputSchema = getInputSchema(agent);

if (agentInputSchema.isPresent()) {
builder.parameters(agentInputSchema.get());
Expand All @@ -113,10 +141,7 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
toolContext.setActions(toolContext.actions().toBuilder().skipSummarization(true).build());
}

Optional<Schema> agentInputSchema = Optional.empty();
if (agent instanceof LlmAgent llmAgent) {
agentInputSchema = llmAgent.inputSchema();
}
Optional<Schema> agentInputSchema = getInputSchema(agent);

final Content content;
if (agentInputSchema.isPresent()) {
Expand Down Expand Up @@ -163,10 +188,7 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
}
String output = outputText.get();

Optional<Schema> agentOutputSchema = Optional.empty();
if (agent instanceof LlmAgent llmAgent) {
agentOutputSchema = llmAgent.outputSchema();
}
Optional<Schema> agentOutputSchema = getOutputSchema(agent);

if (agentOutputSchema.isPresent()) {
return SchemaUtils.validateOutputSchema(output, agentOutputSchema.get());
Expand Down
173 changes: 172 additions & 1 deletion core/src/test/java/com/google/adk/tools/AgentToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.Callbacks.AfterAgentCallback;
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.SequentialAgent;
import com.google.adk.models.LlmResponse;
import com.google.adk.sessions.Session;
import com.google.adk.testing.TestLlm;
Expand Down Expand Up @@ -451,7 +453,176 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio
assertThat(toolContext.state()).containsEntry("test_key", "test_value");
}

private static ToolContext createToolContext(LlmAgent agent) {
@Test
public void
declaration_sequentialAgentWithFirstSubAgentInputSchema_returnsDeclarationWithSchema() {
Schema inputSchema =
Schema.builder()
.type("OBJECT")
.properties(
ImmutableMap.of(
"query",
Schema.builder().type("STRING").build(),
"language",
Schema.builder().type("STRING").build()))
.required(ImmutableList.of("query", "language"))
.build();
LlmAgent firstAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("first_agent")
.inputSchema(inputSchema)
.build();
LlmAgent secondAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("second_agent")
.build();
SequentialAgent sequentialAgent =
SequentialAgent.builder()
.name("sequence")
.description("Process the query through multiple steps")
.subAgents(ImmutableList.of(firstAgent, secondAgent))
.build();
AgentTool agentTool = AgentTool.create(sequentialAgent);

FunctionDeclaration declaration = agentTool.declaration().get();

assertThat(declaration.name().get()).isEqualTo("sequence");
assertThat(declaration.description().get())
.isEqualTo("Process the query through multiple steps");
assertThat(declaration.parameters().get()).isEqualTo(inputSchema);
}

@Test
public void declaration_sequentialAgentWithoutInputSchema_fallsBackToRequest() {
LlmAgent firstAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("first_agent")
.build();
LlmAgent secondAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("second_agent")
.build();
SequentialAgent sequentialAgent =
SequentialAgent.builder()
.name("sequence")
.description("Process the query through multiple steps")
.subAgents(ImmutableList.of(firstAgent, secondAgent))
.build();
AgentTool agentTool = AgentTool.create(sequentialAgent);

FunctionDeclaration declaration = agentTool.declaration().get();

assertThat(declaration.name().get()).isEqualTo("sequence");
assertThat(declaration.description().get())
.isEqualTo("Process the query through multiple steps");
assertThat(declaration.parameters().get())
.isEqualTo(
Schema.builder()
.type("OBJECT")
.properties(ImmutableMap.of("request", Schema.builder().type("STRING").build()))
.required(ImmutableList.of("request"))
.build());
}

@Test
public void call_sequentialAgentWithLastSubAgentOutputSchema_successful() throws Exception {
Schema outputSchema =
Schema.builder()
.type("OBJECT")
.properties(
ImmutableMap.of(
"is_valid",
Schema.builder().type("BOOLEAN").build(),
"message",
Schema.builder().type("STRING").build()))
.required(ImmutableList.of("is_valid", "message"))
.build();
LlmAgent firstAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("first_agent")
.build();
LlmAgent secondAgent =
createTestAgentBuilder(
createTestLlm(
LlmResponse.builder()
.content(
Content.fromParts(
Part.fromText(
"{\"is_valid\": true, " + "\"message\": \"success\"}")))
.build()))
.name("second_agent")
.outputSchema(outputSchema)
.build();
SequentialAgent sequentialAgent =
SequentialAgent.builder()
.name("sequence")
.description("Process the query through multiple steps")
.subAgents(ImmutableList.of(firstAgent, secondAgent))
.build();
AgentTool agentTool = AgentTool.create(sequentialAgent);
ToolContext toolContext = createToolContext(sequentialAgent);

Map<String, Object> result =
agentTool.runAsync(ImmutableMap.of("request", "test"), toolContext).blockingGet();

assertThat(result).containsExactly("is_valid", true, "message", "success");
}

@Test
public void declaration_nestedSequentialAgentInputSchema_returnsDeclarationWithSchema() {
Schema inputSchema =
Schema.builder()
.type("OBJECT")
.properties(ImmutableMap.of("deep_query", Schema.builder().type("STRING").build()))
.required(ImmutableList.of("deep_query"))
.build();
LlmAgent innerAgent =
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
.name("inner_agent")
.inputSchema(inputSchema)
.build();
SequentialAgent innerSequence =
SequentialAgent.builder()
.name("inner_sequence")
.subAgents(ImmutableList.of(innerAgent))
.build();
SequentialAgent outerSequence =
SequentialAgent.builder()
.name("outer_sequence")
.description("Nested sequence")
.subAgents(ImmutableList.of(innerSequence))
.build();
AgentTool agentTool = AgentTool.create(outerSequence);

FunctionDeclaration declaration = agentTool.declaration().get();

assertThat(declaration.name().get()).isEqualTo("outer_sequence");
assertThat(declaration.parameters().get()).isEqualTo(inputSchema);
}

@Test
public void declaration_emptySequentialAgent_fallsBackToRequest() {
SequentialAgent sequentialAgent =
SequentialAgent.builder()
.name("empty_sequence")
.description("An empty sequence")
.subAgents(ImmutableList.of())
.build();
AgentTool agentTool = AgentTool.create(sequentialAgent);

FunctionDeclaration declaration = agentTool.declaration().get();

assertThat(declaration.name().get()).isEqualTo("empty_sequence");
assertThat(declaration.parameters().get())
.isEqualTo(
Schema.builder()
.type("OBJECT")
.properties(ImmutableMap.of("request", Schema.builder().type("STRING").build()))
.required(ImmutableList.of("request"))
.build());
}

private static ToolContext createToolContext(BaseAgent agent) {
return ToolContext.builder(
InvocationContext.builder()
.invocationId(InvocationContext.newInvocationContextId())
Expand Down