From 490c51c5b246b60a8bc4cb557ccbcb10590786a1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 Jan 2026 09:43:23 -0800 Subject: [PATCH] feat: Adding SafeFilesAsArtifactsPlugin PiperOrigin-RevId: 862275776 --- .../plugins/SaveFilesAsArtifactsPlugin.java | 202 +++++++++++++ .../SaveFilesAsArtifactsPluginTest.java | 265 ++++++++++++++++++ 2 files changed, 467 insertions(+) create mode 100644 core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java create mode 100644 core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java diff --git a/core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java b/core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java new file mode 100644 index 00000000..1063bff0 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/SaveFilesAsArtifactsPlugin.java @@ -0,0 +1,202 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.plugins; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.artifacts.BaseArtifactService; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A plugin that saves files embedded in user messages as artifacts. + * + *

This is useful to allow users to upload files in the chat experience and have those files + * available to the agent within the current session. + * + *

Artifacts with the same name will be overwritten. A placeholder with the artifact name will be + * put in place of the embedded file in the user message so the model knows where to find the file. + * You may want to add load_artifacts tool to the agent, or load the artifacts in your own tool to + * use the files. + */ +public class SaveFilesAsArtifactsPlugin extends BasePlugin { + private static final Logger logger = LoggerFactory.getLogger(SaveFilesAsArtifactsPlugin.class); + + private static final ImmutableSet MODEL_ACCESSIBLE_URI_SCHEMES = + ImmutableSet.of("gs", "https", "http"); + + public SaveFilesAsArtifactsPlugin(String name) { + super(name); + } + + public SaveFilesAsArtifactsPlugin() { + this("save_files_as_artifacts_plugin"); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + if (invocationContext.artifactService() == null) { + logger.warn("Artifact service is not set. SaveFilesAsArtifactsPlugin will not be enabled."); + return Maybe.just(userMessage); + } + + if (userMessage.parts().isEmpty() + || userMessage.parts().stream() + .flatMap(List::stream) + .noneMatch(part -> part.inlineData().isPresent())) { + return Maybe.empty(); + } + + AtomicBoolean modified = new AtomicBoolean(false); + AtomicInteger index = new AtomicInteger(0); + + return Flowable.fromIterable(userMessage.parts().get()) + .concatMapSingle( + part -> { + if (part.inlineData().isEmpty()) { + return Single.just(ImmutableList.of(part)); + } + modified.set(true); + return saveArtifactAndBuildParts(invocationContext, part, index.getAndIncrement()); + }) + .toList() // Collects Single> into a Single>> + .map( + listOfLists -> + listOfLists.stream() + .flatMap(List::stream) + .collect(toImmutableList())) // Flatten the list of lists + .filter(unused -> modified.get()) + .map( + parts -> + Content.builder().parts(parts).role(userMessage.role().orElse("user")).build()); + } + + private Single> saveArtifactAndBuildParts( + InvocationContext invocationContext, Part part, int index) { + Blob inlineData = part.inlineData().get(); + String fileName = + inlineData + .displayName() + .filter(s -> !s.isEmpty()) + .orElseGet( + () -> { + String generatedName = + String.format("artifact_%s_%d", invocationContext.invocationId(), index); + logger.info("No display_name found, using generated filename: {}", generatedName); + return generatedName; + }); + String displayName = fileName; + + return invocationContext + .artifactService() + .saveArtifact( + invocationContext.appName(), + invocationContext.userId(), + invocationContext.session().id(), + fileName, + part) + .flatMap( + version -> { + logger.info("Successfully saved artifact: {}", fileName); + Part placeholderPart = + Part.fromText(String.format("[Uploaded Artifact: \"%s\"]", displayName)); + + return buildFileReferencePart( + invocationContext, fileName, version, inlineData.mimeType(), displayName) + .map(filePart -> ImmutableList.of(placeholderPart, filePart)) + .defaultIfEmpty(ImmutableList.of(placeholderPart)); + }) + .onErrorReturn( + e -> { + logger.error("Failed to save artifact for part {}: {}", index, e.getMessage()); + return ImmutableList.of(part); // Keep original part if saving fails + }); + } + + private Maybe buildFileReferencePart( + InvocationContext invocationContext, + String filename, + int version, + Optional mimeType, + String displayName) { + BaseArtifactService artifactService = invocationContext.artifactService(); + if (artifactService == null) { + return Maybe.empty(); + } + + return artifactService + .loadArtifact( + invocationContext.appName(), + invocationContext.userId(), + invocationContext.session().id(), + filename, + Optional.of(version)) + .flatMap( + artifact -> { + Optional optionalPart = + artifact + .fileData() + .filter(fd -> fd.fileUri().map(this::isModelAccessibleUri).orElse(false)) + .map( + fd -> + Part.builder() + .fileData( + FileData.builder() + .fileUri(fd.fileUri().get()) + .mimeType( + mimeType + .or(fd::mimeType) + .orElse("application/octet-stream")) + .displayName(displayName) + .build()) + .build()); + if (optionalPart.isPresent()) { + return Maybe.just(optionalPart.get()); + } + return Maybe.empty(); + }) + .doOnError(e -> logger.warn("Failed to resolve artifact version for {}: {}", filename, e)) + .onErrorComplete(); + } + + private boolean isModelAccessibleUri(String uri) { + try { + URI parsed = new URI(uri); + return parsed.getScheme() != null + && MODEL_ACCESSIBLE_URI_SCHEMES.contains(parsed.getScheme().toLowerCase(Locale.ROOT)); + } catch (URISyntaxException e) { + return false; + } + } +} diff --git a/core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java b/core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java new file mode 100644 index 00000000..663294b9 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/SaveFilesAsArtifactsPluginTest.java @@ -0,0 +1,265 @@ +package com.google.adk.plugins; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import com.google.protobuf.ByteString; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class SaveFilesAsArtifactsPluginTest { + + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + private static final String APP_NAME = "test_app"; + private static final String USER_ID = "test_user"; + private static final String SESSION_ID = "test_session"; + private static final String INVOCATION_ID = "test_invocation"; + + @Mock private BaseArtifactService mockArtifactService; + private Session session; + + private SaveFilesAsArtifactsPlugin plugin; + private InvocationContext invocationContext; + private InvocationContext invocationContextWithNoArtifactService; + + @Before + public void setUp() { + session = Session.builder(SESSION_ID).appName(APP_NAME).userId(USER_ID).build(); + + invocationContext = + InvocationContext.builder() + .invocationId(INVOCATION_ID) + .session(session) + .artifactService(mockArtifactService) + .build(); + invocationContextWithNoArtifactService = + InvocationContext.builder() + .invocationId(INVOCATION_ID) + .session(session) + .artifactService(null) + .build(); + plugin = new SaveFilesAsArtifactsPlugin(); + } + + private Part createInlineDataPart(String mimeType, String data) { + return createInlineDataPart(mimeType, data, Optional.empty()); + } + + private Part createInlineDataPart(String mimeType, String data, Optional displayName) { + Blob.Builder blobBuilder = + Blob.builder().mimeType(mimeType).data(ByteString.copyFromUtf8(data).toByteArray()); + displayName.ifPresent(blobBuilder::displayName); + return Part.builder().inlineData(blobBuilder.build()).build(); + } + + @Test + public void onUserMessageCallback_noArtifactService_returnsMessage() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + + plugin + .onUserMessageCallback(invocationContextWithNoArtifactService, userMessage) + .test() + .assertValue(userMessage); + } + + @Test + public void onUserMessageCallback_noInlineData_returnsEmpty() { + Content userMessage = Content.builder().parts(Part.fromText("hello")).role("user").build(); + plugin.onUserMessageCallback(invocationContext, userMessage).test().assertNoValues(); + } + + @Test + public void onUserMessageCallback_withInlineDataAndSuccessfulSaveAndNoUri_returnsTextPart() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + + when(mockArtifactService.saveArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn(Single.just(1)); + // Load artifact returns part without FileData + when(mockArtifactService.loadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(Optional.of(1)))) + .thenReturn(Maybe.just(Part.fromText("a part without file data"))); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()).hasSize(1); + assertThat(result.parts().get().get(0).text()) + .hasValue("[Uploaded Artifact: \"" + fileName + "\"]"); + } + + @Test + public void + onUserMessageCallback_withInlineDataAndSuccessfulSaveAndAccessibleUri_returnsTextAndUriParts() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + String fileUri = "gs://my-bucket/artifact_test_invocation_0"; + String mimeType = "text/plain"; + + when(mockArtifactService.saveArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn(Single.just(1)); + when(mockArtifactService.loadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(Optional.of(1)))) + .thenReturn( + Maybe.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri) + .mimeType(mimeType) + .displayName(fileName) + .build()) + .build())); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()).hasSize(2); + assertThat(result.parts().get().get(0).text().get()) + .isEqualTo("[Uploaded Artifact: \"" + fileName + "\"]"); + assertThat(result.parts().get().get(1).fileData().get().fileUri().get()).isEqualTo(fileUri); + assertThat(result.parts().get().get(1).fileData().get().mimeType().get()).isEqualTo(mimeType); + assertThat(result.parts().get().get(1).fileData().get().displayName()).hasValue(fileName); + } + + @Test + public void + onUserMessageCallback_withInlineDataAndSuccessfulSaveAndInaccessibleUri_returnsTextPart() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + String fileUri = "file://my-bucket/artifact_test_invocation_0"; // Inaccessible scheme + String mimeType = "text/plain"; + + when(mockArtifactService.saveArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn(Single.just(1)); + when(mockArtifactService.loadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(Optional.of(1)))) + .thenReturn( + Maybe.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri) + .mimeType(mimeType) + .displayName(fileName) + .build()) + .build())); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()).hasSize(1); + assertThat(result.parts().get().get(0).text()) + .hasValue("[Uploaded Artifact: \"" + fileName + "\"]"); + } + + @Test + public void onUserMessageCallback_withInlineDataAndFailedSave_returnsOriginalPart() { + Part partWithInlineData = createInlineDataPart("text/plain", "hello"); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + String fileName = "artifact_" + INVOCATION_ID + "_0"; + + when(mockArtifactService.saveArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName), eq(partWithInlineData))) + .thenReturn(Single.error(new RuntimeException("Failed to save"))); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()).containsExactly(partWithInlineData); + } + + @Test + public void onUserMessageCallback_withInlineDataAndMultipleParts_returnsMixedParts() { + Part textPart = Part.fromText("this is text"); + Part partWithInlineData1 = createInlineDataPart("text/plain", "inline1"); + Part partWithInlineData2 = createInlineDataPart("image/png", "inline2"); + Content userMessage = + Content.builder() + .parts(ImmutableList.of(textPart, partWithInlineData1, partWithInlineData2)) + .role("user") + .build(); + + String fileName1 = "artifact_" + INVOCATION_ID + "_0"; + String fileName2 = "artifact_" + INVOCATION_ID + "_1"; + String fileUri1 = "gs://my-bucket/artifact_test_invocation_0"; + String mimeType1 = "text/plain"; + + when(mockArtifactService.saveArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName1), eq(partWithInlineData1))) + .thenReturn(Single.just(1)); + when(mockArtifactService.saveArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName2), eq(partWithInlineData2))) + .thenReturn(Single.just(2)); + + when(mockArtifactService.loadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName1), eq(Optional.of(1)))) + .thenReturn( + Maybe.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri(fileUri1) + .mimeType(mimeType1) + .displayName(fileName1) + .build()) + .build())); + // For 2nd artifact, do not return a file URI. + when(mockArtifactService.loadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(fileName2), eq(Optional.of(2)))) + .thenReturn(Maybe.empty()); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()).hasSize(4); + assertThat(result.parts().get().get(0)).isEqualTo(textPart); + assertThat(result.parts().get().get(1).text()) + .hasValue("[Uploaded Artifact: \"" + fileName1 + "\"]"); + assertThat(result.parts().get().get(2).fileData().get().fileUri()).hasValue(fileUri1); + assertThat(result.parts().get().get(3).text()) + .hasValue("[Uploaded Artifact: \"" + fileName2 + "\"]"); + } + + @Test + public void onUserMessageCallback_withDisplayName_usesDisplayNameAsFileName() { + String displayName = "mydocument.txt"; + Part partWithInlineData = createInlineDataPart("text/plain", "hello", Optional.of(displayName)); + Content userMessage = Content.builder().parts(partWithInlineData).role("user").build(); + + when(mockArtifactService.saveArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(displayName), eq(partWithInlineData))) + .thenReturn(Single.just(1)); + when(mockArtifactService.loadArtifact( + eq(APP_NAME), eq(USER_ID), eq(SESSION_ID), eq(displayName), eq(Optional.of(1)))) + .thenReturn(Maybe.empty()); + + Content result = plugin.onUserMessageCallback(invocationContext, userMessage).blockingGet(); + + assertThat(result.parts().get()).hasSize(1); + assertThat(result.parts().get().get(0).text()) + .hasValue("[Uploaded Artifact: \"" + displayName + "\"]"); + } +}