diff --git a/core/src/main/java/com/google/adk/tools/AgentTool.java b/core/src/main/java/com/google/adk/tools/AgentTool.java index 1a8dbc52..2a50605a 100644 --- a/core/src/main/java/com/google/adk/tools/AgentTool.java +++ b/core/src/main/java/com/google/adk/tools/AgentTool.java @@ -37,6 +37,7 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Single; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -83,15 +84,42 @@ BaseAgent getAgent() { return agent; } + private Optional getInputSchema(BaseAgent agent) { + BaseAgent currentAgent = agent; + while (true) { + if (currentAgent instanceof LlmAgent llmAgent) { + return llmAgent.inputSchema(); + } + List subAgents = currentAgent.subAgents(); + if (subAgents == null || subAgents.isEmpty()) { + return Optional.empty(); + } + // For composite agents, check the first sub-agent. + currentAgent = subAgents.get(0); + } + } + + private Optional getOutputSchema(BaseAgent agent) { + BaseAgent currentAgent = agent; + while (true) { + if (currentAgent instanceof LlmAgent llmAgent) { + return llmAgent.outputSchema(); + } + List subAgents = currentAgent.subAgents(); + if (subAgents == null || subAgents.isEmpty()) { + return Optional.empty(); + } + // For composite agents, check the last sub-agent. + currentAgent = subAgents.get(subAgents.size() - 1); + } + } + @Override public Optional declaration() { FunctionDeclaration.Builder builder = FunctionDeclaration.builder().description(this.description()).name(this.name()); - Optional agentInputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentInputSchema = llmAgent.inputSchema(); - } + Optional agentInputSchema = getInputSchema(agent); if (agentInputSchema.isPresent()) { builder.parameters(agentInputSchema.get()); @@ -113,10 +141,7 @@ public Single> runAsync(Map args, ToolContex toolContext.setActions(toolContext.actions().toBuilder().skipSummarization(true).build()); } - Optional agentInputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentInputSchema = llmAgent.inputSchema(); - } + Optional agentInputSchema = getInputSchema(agent); final Content content; if (agentInputSchema.isPresent()) { @@ -163,10 +188,7 @@ public Single> runAsync(Map args, ToolContex } String output = outputText.get(); - Optional agentOutputSchema = Optional.empty(); - if (agent instanceof LlmAgent llmAgent) { - agentOutputSchema = llmAgent.outputSchema(); - } + Optional agentOutputSchema = getOutputSchema(agent); if (agentOutputSchema.isPresent()) { return SchemaUtils.validateOutputSchema(output, agentOutputSchema.get()); diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index d43d9d03..c961e654 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -21,10 +21,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.SequentialAgent; import com.google.adk.models.LlmResponse; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; @@ -451,7 +453,176 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio assertThat(toolContext.state()).containsEntry("test_key", "test_value"); } - private static ToolContext createToolContext(LlmAgent agent) { + @Test + public void + declaration_sequentialAgentWithFirstSubAgentInputSchema_returnsDeclarationWithSchema() { + Schema inputSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "query", + Schema.builder().type("STRING").build(), + "language", + Schema.builder().type("STRING").build())) + .required(ImmutableList.of("query", "language")) + .build(); + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .inputSchema(inputSchema) + .build(); + LlmAgent secondAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("second_agent") + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("sequence"); + assertThat(declaration.description().get()) + .isEqualTo("Process the query through multiple steps"); + assertThat(declaration.parameters().get()).isEqualTo(inputSchema); + } + + @Test + public void declaration_sequentialAgentWithoutInputSchema_fallsBackToRequest() { + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .build(); + LlmAgent secondAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("second_agent") + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("sequence"); + assertThat(declaration.description().get()) + .isEqualTo("Process the query through multiple steps"); + assertThat(declaration.parameters().get()) + .isEqualTo( + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("request", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("request")) + .build()); + } + + @Test + public void call_sequentialAgentWithLastSubAgentOutputSchema_successful() throws Exception { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "is_valid", + Schema.builder().type("BOOLEAN").build(), + "message", + Schema.builder().type("STRING").build())) + .required(ImmutableList.of("is_valid", "message")) + .build(); + LlmAgent firstAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("first_agent") + .build(); + LlmAgent secondAgent = + createTestAgentBuilder( + createTestLlm( + LlmResponse.builder() + .content( + Content.fromParts( + Part.fromText( + "{\"is_valid\": true, " + "\"message\": \"success\"}"))) + .build())) + .name("second_agent") + .outputSchema(outputSchema) + .build(); + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("sequence") + .description("Process the query through multiple steps") + .subAgents(ImmutableList.of(firstAgent, secondAgent)) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + ToolContext toolContext = createToolContext(sequentialAgent); + + Map result = + agentTool.runAsync(ImmutableMap.of("request", "test"), toolContext).blockingGet(); + + assertThat(result).containsExactly("is_valid", true, "message", "success"); + } + + @Test + public void declaration_nestedSequentialAgentInputSchema_returnsDeclarationWithSchema() { + Schema inputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("deep_query", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("deep_query")) + .build(); + LlmAgent innerAgent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .name("inner_agent") + .inputSchema(inputSchema) + .build(); + SequentialAgent innerSequence = + SequentialAgent.builder() + .name("inner_sequence") + .subAgents(ImmutableList.of(innerAgent)) + .build(); + SequentialAgent outerSequence = + SequentialAgent.builder() + .name("outer_sequence") + .description("Nested sequence") + .subAgents(ImmutableList.of(innerSequence)) + .build(); + AgentTool agentTool = AgentTool.create(outerSequence); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("outer_sequence"); + assertThat(declaration.parameters().get()).isEqualTo(inputSchema); + } + + @Test + public void declaration_emptySequentialAgent_fallsBackToRequest() { + SequentialAgent sequentialAgent = + SequentialAgent.builder() + .name("empty_sequence") + .description("An empty sequence") + .subAgents(ImmutableList.of()) + .build(); + AgentTool agentTool = AgentTool.create(sequentialAgent); + + FunctionDeclaration declaration = agentTool.declaration().get(); + + assertThat(declaration.name().get()).isEqualTo("empty_sequence"); + assertThat(declaration.parameters().get()) + .isEqualTo( + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("request", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("request")) + .build()); + } + + private static ToolContext createToolContext(BaseAgent agent) { return ToolContext.builder( InvocationContext.builder() .invocationId(InvocationContext.newInvocationContextId())