diff --git a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java new file mode 100644 index 000000000..6df01694a --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java @@ -0,0 +1,185 @@ +package com.google.adk.a2a; + +import com.google.adk.a2a.converters.EventConverter; +import com.google.adk.a2a.converters.PartConverter; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Content; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.events.EventQueue; +import io.a2a.server.tasks.TaskUpdater; +import io.a2a.spec.InvalidAgentResponseError; +import io.a2a.spec.Message; +import io.a2a.spec.Part; +import io.a2a.spec.TextPart; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.disposables.CompositeDisposable; +import io.reactivex.rxjava3.disposables.Disposable; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implementation of the A2A AgentExecutor interface that uses ADK to execute agent tasks. + * + *

**EXPERIMENTAL:** Subject to change, rename, or removal in any future patch release. Do not + * use in production code. + */ +public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor { + + private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); + private static final String USER_ID_PREFIX = "A2A_USER_"; + private static final RunConfig DEFAULT_RUN_CONFIG = + RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); + + private final Runner runner; + private final Map activeTasks = new ConcurrentHashMap<>(); + + private AgentExecutor(Runner runner) { + this.runner = runner; + } + + /** Builder for {@link AgentExecutor}. */ + public static class Builder { + private Runner runner; + + @CanIgnoreReturnValue + public Builder runner(Runner runner) { + this.runner = runner; + return this; + } + + @CanIgnoreReturnValue + public AgentExecutor build() { + if (runner == null) { + throw new IllegalStateException("Runner must be provided."); + } + return new AgentExecutor(runner); + } + } + + @Override + public void cancel(RequestContext ctx, EventQueue eventQueue) { + TaskUpdater updater = new TaskUpdater(ctx, eventQueue); + updater.cancel(); + cleanupTask(ctx.getTaskId()); + } + + @Override + public void execute(RequestContext ctx, EventQueue eventQueue) { + TaskUpdater updater = new TaskUpdater(ctx, eventQueue); + Message message = ctx.getMessage(); + if (message == null) { + throw new IllegalArgumentException("Message cannot be null"); + } + + // Submits a new task if there is no active task. + if (ctx.getTask() == null) { + updater.submit(); + } + + // Group all reactive work for this task into one container + CompositeDisposable taskDisposables = new CompositeDisposable(); + // Check if the task with the task id is already running, put if absent. + if (activeTasks.putIfAbsent(ctx.getTaskId(), taskDisposables) != null) { + throw new IllegalStateException(String.format("Task %s already running", ctx.getTaskId())); + } + + EventProcessor p = new EventProcessor(); + Content content = PartConverter.messageToContent(message); + + taskDisposables.add( + prepareSession(ctx, runner.sessionService()) + .flatMapPublisher( + session -> { + updater.startWork(); + return runner.runAsync(getUserId(ctx), session.id(), content, DEFAULT_RUN_CONFIG); + }) + .subscribe( + event -> { + p.process(event, updater); + }, + error -> { + logger.error("Runner failed with {}", error); + updater.fail(failedMessage(ctx, error)); + cleanupTask(ctx.getTaskId()); + }, + () -> { + updater.complete(); + cleanupTask(ctx.getTaskId()); + })); + } + + private void cleanupTask(String taskId) { + Disposable d = activeTasks.remove(taskId); + if (d != null) { + d.dispose(); // Stops all streams in the CompositeDisposable + } + } + + private String getUserId(RequestContext ctx) { + return USER_ID_PREFIX + ctx.getContextId(); + } + + private Maybe prepareSession(RequestContext ctx, BaseSessionService service) { + return service + .getSession(runner.appName(), getUserId(ctx), ctx.getContextId(), Optional.empty()) + .switchIfEmpty( + Maybe.defer( + () -> { + return service.createSession(runner.appName(), getUserId(ctx)).toMaybe(); + })); + } + + private static Message failedMessage(RequestContext context, Throwable e) { + return new Message.Builder() + .messageId(UUID.randomUUID().toString()) + .contextId(context.getContextId()) + .taskId(context.getTaskId()) + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart(e.getMessage()))) + .build(); + } + + // Processor that will process all events related to the one runner invocation. + private static class EventProcessor { + + // All artifacts related to the invocation should have the same artifact id. + private EventProcessor() { + artifactId = UUID.randomUUID().toString(); + } + + private final String artifactId; + + private void process(Event event, TaskUpdater updater) { + if (event.errorCode().isPresent()) { + throw new InvalidAgentResponseError( + null, // Uses default code -32006 + "Agent returned an error: " + event.errorCode().get(), + null); + } + + ImmutableList> parts = EventConverter.contentToParts(event.content()); + + // Mark all parts as partial if the event is partial. + if (event.partial().orElse(false)) { + parts.forEach( + part -> { + Map metadata = part.getMetadata(); + metadata.put("adk_partial", true); + }); + } + + updater.addArtifact(parts, artifactId, null, ImmutableMap.of()); + } + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java index cd8bcefb0..f5b1178c0 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java @@ -1,7 +1,10 @@ package com.google.adk.a2a.converters; +import static com.google.common.collect.ImmutableList.toImmutableList; + import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; +import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; import io.a2a.spec.Message; @@ -37,6 +40,16 @@ public enum AggregationMode { EXTERNAL_HANDOFF } + public static ImmutableList> contentToParts(Optional content) { + if (content.isPresent() && content.get().parts().isPresent()) { + return content.get().parts().get().stream() + .map(PartConverter::fromGenaiPart) + .flatMap(Optional::stream) + .collect(toImmutableList()); + } + return ImmutableList.of(); + } + public static Optional convertEventsToA2AMessage(InvocationContext context) { return convertEventsToA2AMessage(context, AggregationMode.AS_IS); } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 0b5ea5503..c6ef06400 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.Content; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -16,6 +17,7 @@ import io.a2a.spec.FilePart; import io.a2a.spec.FileWithBytes; import io.a2a.spec.FileWithUri; +import io.a2a.spec.Message; import io.a2a.spec.TextPart; import java.util.Base64; import java.util.HashMap; @@ -181,6 +183,17 @@ private static Optional convertDataPartToGenAiPart( } } + /** + * Converts an A2A Message to a Google GenAI Content object. + * + * @param message The A2A Message to convert. + * @return The converted Google GenAI Content object. + */ + public static Content messageToContent(Message message) { + ImmutableList parts = toGenaiParts(message.getParts()); + return Content.builder().role("user").parts(parts).build(); + } + /** * Creates an A2A DataPart from a Google GenAI FunctionResponse. * @@ -227,7 +240,7 @@ public static Optional> fromGenaiPart(Part part) { } if (part.text().isPresent()) { - return Optional.of(new TextPart(part.text().get())); + return Optional.of(new TextPart(part.text().get(), new HashMap<>())); } if (part.fileData().isPresent()) { @@ -235,7 +248,7 @@ public static Optional> fromGenaiPart(Part part) { String uri = fileData.fileUri().orElse(null); String mime = fileData.mimeType().orElse(null); String name = fileData.displayName().orElse(null); - return Optional.of(new FilePart(new FileWithUri(mime, name, uri))); + return Optional.of(new FilePart(new FileWithUri(mime, name, uri), new HashMap<>())); } if (part.inlineData().isPresent()) { @@ -244,7 +257,7 @@ public static Optional> fromGenaiPart(Part part) { String encoded = bytes != null ? Base64.getEncoder().encodeToString(bytes) : null; String mime = blob.mimeType().orElse(null); String name = blob.displayName().orElse(null); - return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded))); + return Optional.of(new FilePart(new FileWithBytes(mime, name, encoded), new HashMap<>())); } if (part.functionCall().isPresent() || part.functionResponse().isPresent()) {