From bfebfbc42779cf371e306b64af08950949584956 Mon Sep 17 00:00:00 2001 From: Salman Muin Kayser Chishti <13schishti@gmail.com> Date: Fri, 23 Jan 2026 08:49:52 +0000 Subject: [PATCH 01/29] Upgrade GitHub Actions for Node 24 compatibility Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com> --- .github/workflows/pr-commit-check.yml | 2 +- .github/workflows/validation.yml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr-commit-check.yml b/.github/workflows/pr-commit-check.yml index ec6644311..1e31e42f3 100644 --- a/.github/workflows/pr-commit-check.yml +++ b/.github/workflows/pr-commit-check.yml @@ -21,7 +21,7 @@ jobs: # Step 1: Check out the code # This action checks out your repository under $GITHUB_WORKSPACE, so your workflow can access it. - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # We need to fetch all commits to accurately count them. # '0' means fetch all history for all branches and tags. diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index eeb16e1ff..26a276f05 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -20,16 +20,16 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Java ${{ matrix.java-version }} - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: temurin java-version: ${{ matrix.java-version }} - name: Cache Maven packages - uses: actions/cache@v3 + uses: actions/cache@v5 with: path: ~/.m2/repository key: ${{ runner.os }}-maven-${{ matrix.java-version }}-${{ hashFiles('**/pom.xml') }} From 8624d59de41205800e5538e51e2b1416b79bca48 Mon Sep 17 00:00:00 2001 From: Michael Vorburger Date: Tue, 3 Mar 2026 12:44:46 +0100 Subject: [PATCH 02/29] dev: Introduce initial AGENTS.md Intentionally named AGENTS.md instead of e.g. GEMINI.md to be fully model neutral; see https://agents.md for background. 
--- AGENTS.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..5d33d2172 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,3 @@ +# AGENTS.md + +Validate changes by running `./mvnw test`. From b6356d27c4dfbafdaa5803cb766b57ec5f09091a Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 16 Mar 2026 06:03:42 -0700 Subject: [PATCH 03/29] chore: update mcp dependency version to 0.17.2 PiperOrigin-RevId: 884394270 --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 11696db73..bd0caca0d 100644 --- a/pom.xml +++ b/pom.xml @@ -49,7 +49,7 @@ cloud libraries. Once they update their otel dependencies we can consider updating ours here as well --> 1.51.0 - 0.14.0 + 0.17.2 2.47.0 1.41.0 4.33.5 From 7ebeb07bf2ee72475484d8a31ccf7b4c601dda96 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Mar 2026 07:00:53 -0700 Subject: [PATCH 04/29] feat: init AGENTS.md file PiperOrigin-RevId: 884415542 --- AGENTS.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..e69de29bb From 567fdf048fee49afc86ca5d7d35f55424a6016ba Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 16 Mar 2026 09:11:14 -0700 Subject: [PATCH 05/29] fix: fix null handling in runAsyncImpl PiperOrigin-RevId: 884472852 --- .../java/com/google/adk/runner/Runner.java | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 29b2b76d3..5859c4786 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -458,6 +458,9 @@ protected Flowable runAsyncImpl( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { + Preconditions.checkNotNull(session, "session 
cannot be null"); + Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); + Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); return Flowable.defer( () -> { BaseAgent rootAgent = this.agent; @@ -476,19 +479,14 @@ protected Flowable runAsyncImpl( .defaultIfEmpty(newMessage) .flatMap( content -> - (content != null) - ? appendNewMessageToSession( - session, - content, - initialContext, - runConfig.saveInputBlobsAsArtifacts(), - stateDelta) - : Single.just(null)) + appendNewMessageToSession( + session, + content, + initialContext, + runConfig.saveInputBlobsAsArtifacts(), + stateDelta)) .flatMapPublisher( event -> { - if (event == null) { - return Flowable.empty(); - } // Get the updated session after the message and state delta are // applied return this.sessionService From b8cb7e2db6d5ce20f4d7a1b237bdc155563cf4bd Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Mon, 16 Mar 2026 09:50:32 -0700 Subject: [PATCH 06/29] feat: add type-safe runAsync methods to BaseTool PiperOrigin-RevId: 884493553 --- .../java/com/google/adk/tools/BaseTool.java | 81 +++++++++++++ .../com/google/adk/tools/BaseToolTest.java | 108 ++++++++++++++++++ 2 files changed, 189 insertions(+) diff --git a/core/src/main/java/com/google/adk/tools/BaseTool.java b/core/src/main/java/com/google/adk/tools/BaseTool.java index 1ea2808a1..01a399920 100644 --- a/core/src/main/java/com/google/adk/tools/BaseTool.java +++ b/core/src/main/java/com/google/adk/tools/BaseTool.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonAnySetter; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.models.LlmRequest; @@ -38,6 +39,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.function.Function; 
import javax.annotation.Nonnull; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; @@ -93,6 +95,85 @@ public Single> runAsync(Map args, ToolContex throw new UnsupportedOperationException("This method is not implemented."); } + /** + * Calls a tool with generic arguments and returns a map of results. The args type {@code T} need + * to be serializable with {@link JsonBaseModel#getMapper()} + */ + public final Single> runAsync(T args, ToolContext toolContext) { + return runAsync(args, toolContext, JsonBaseModel.getMapper()); + } + + /** + * Calls a tool with generic arguments using a custom {@link ObjectMapper} and returns a map of + * results. The args type {@code T} needs to be serializable with the provided {@link + * ObjectMapper}. + */ + public final Single> runAsync( + T args, ToolContext toolContext, ObjectMapper objectMapper) { + return runAsync(args, toolContext, objectMapper, output -> output); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified class. The input type {@code I} needs to be serializable and the + * output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. + */ + public final Single runAsync( + I args, ToolContext toolContext, ObjectMapper objectMapper, Class oClass) { + return runAsync( + args, toolContext, objectMapper, output -> objectMapper.convertValue(output, oClass)); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified type reference. The input type {@code I} needs to be serializable and + * the output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. 
+ */ + public final Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + TypeReference typeReference) { + return runAsync( + args, + toolContext, + objectMapper, + output -> objectMapper.convertValue(output, typeReference)); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified class. The + * input type {@code I} needs to be serializable and the output type {@code O} needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, Class oClass) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), oClass); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified type + * reference. The input type needs to be serializable and the output type needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, TypeReference typeReference) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), typeReference); + } + + private Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + Function, ? extends O> deserializer) { + return Single.defer( + () -> + Single.just( + objectMapper.convertValue(args, new TypeReference>() {}))) + .flatMap(argsMap -> runAsync(argsMap, toolContext)) + .map(deserializer::apply); + } + /** * Processes the outgoing {@link LlmRequest.Builder}. 
* diff --git a/core/src/test/java/com/google/adk/tools/BaseToolTest.java b/core/src/test/java/com/google/adk/tools/BaseToolTest.java index 2a07e7a44..d3c8da5aa 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolTest.java @@ -2,12 +2,15 @@ import static com.google.common.truth.Truth.assertThat; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.models.Gemini; import com.google.adk.models.LlmRequest; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GoogleMaps; @@ -17,6 +20,7 @@ import com.google.genai.types.UrlContext; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; import java.util.Map; import java.util.Optional; import org.junit.Test; @@ -27,6 +31,20 @@ @RunWith(JUnit4.class) public final class BaseToolTest { + private final BaseTool doublingBaseTool = + new BaseTool("doubling-test-tool", "returns doubled args") { + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + String sArg = (String) args.get("s"); + Integer iArg = (Integer) args.get("i"); + return Single.just( + ImmutableMap.of( + "s", sArg + sArg, + "i", iArg + iArg)); + } + }; + @Test public void processLlmRequestNoDeclarationReturnsSameRequest() { BaseTool tool = @@ -247,4 +265,94 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { assertThat(updatedLlmRequest.config().get().tools().get()) .containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); } + + @Test + public void 
runAsync_withTypeReference_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, new TypeReference() {}); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClass_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + + Single out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, objectMapper); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + 
doublingBaseTool.runAsync( + testToolArgs, + /* toolContext= */ null, + objectMapper, + new TypeReference() {}); + + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, objectMapper, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + public record TestToolArgs(int i, String s) {} } From fca43fbb9684ec8d080e437761f6bb4e38adf255 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 16 Mar 2026 12:38:53 -0700 Subject: [PATCH 07/29] fix: prevent ConcurrentModificationException when session events are modified by another thread during iteration PiperOrigin-RevId: 884587639 --- .../google/adk/flows/llmflows/Contents.java | 12 ++-- .../adk/flows/llmflows/ContentsTest.java | 60 +++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 6ebd39a9c..840a370c6 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -64,6 +64,11 @@ public Single processRequest( modelName = ""; } + ImmutableList sessionEvents; + synchronized (context.session().events()) { + sessionEvents = ImmutableList.copyOf(context.session().events()); + } + if (llmAgent.includeContents() == LlmAgent.IncludeContents.NONE) { return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -71,7 
+76,7 @@ public Single processRequest( .contents( getCurrentTurnContents( context.branch().orElse(null), - context.session().events(), + sessionEvents, context.agent().name(), modelName)) .build(), @@ -80,10 +85,7 @@ public Single processRequest( ImmutableList contents = getContents( - context.branch().orElse(null), - context.session().events(), - context.agent().name(), - modelName); + context.branch().orElse(null), sessionEvents, context.agent().name(), modelName); return Single.just( RequestProcessor.RequestProcessingResult.create( diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 85e78666d..7164991f3 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,10 +36,13 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -780,6 +783,63 @@ public void processRequest_notEmptyContent() { assertThat(contents).containsExactly(e.content().get()); } + @Test + public void processRequest_concurrentReadAndWrite_noException() throws Exception { + LlmAgent agent = + LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + Session session = + sessionService + .createSession("test-app", "test-user", new HashMap<>(), "test-session") + .blockingGet(); + + // Seed with dummy events to widen the race capability + for (int i = 0; i < 5000; i++) { + session.events().add(createUserEvent("dummy" + i, "dummy")); + } + + InvocationContext context = + 
InvocationContext.builder() + .invocationId("test-invocation") + .agent(agent) + .session(session) + .sessionService(sessionService) + .build(); + + LlmRequest initialRequest = LlmRequest.builder().build(); + + AtomicReference writerError = new AtomicReference<>(); + CountDownLatch startLatch = new CountDownLatch(1); + + Thread writerThread = + new Thread( + () -> { + startLatch.countDown(); + try { + for (int i = 0; i < 2000; i++) { + session.events().add(createUserEvent("writer" + i, "new data")); + } + } catch (Throwable t) { + writerError.set(t); + } + }); + + writerThread.start(); + startLatch.await(); // wait for writer to be ready + + // Process (read) requests concurrently to trigger race conditions + for (int i = 0; i < 200; i++) { + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); + if (writerError.get() != null) { + throw new RuntimeException("Writer failed", writerError.get()); + } + } + + writerThread.join(); + if (writerError.get() != null) { + throw new RuntimeException("Writer failed", writerError.get()); + } + } + private static Event createUserEvent(String id, String text) { return Event.builder() .id(id) From 28a8cd04ca9348dbe51a15d2be3a2b5307394174 Mon Sep 17 00:00:00 2001 From: Mateusz Krawiec Date: Tue, 17 Mar 2026 01:48:10 -0700 Subject: [PATCH 08/29] chore!: remove deprecated Example processor PiperOrigin-RevId: 884881559 --- .../java/com/google/adk/agents/LlmAgent.java | 47 +++------ .../com/google/adk/examples/ExampleUtils.java | 3 + .../google/adk/flows/llmflows/Examples.java | 57 ----------- .../google/adk/flows/llmflows/SingleFlow.java | 1 - .../com/google/adk/tools/ExampleTool.java | 4 +- .../com/google/adk/agents/LlmAgentTest.java | 28 ++++++ .../google/adk/examples/ExampleUtilsTest.java | 12 +-- .../adk/flows/llmflows/ExamplesTest.java | 99 ------------------- .../com/google/adk/tools/ExampleToolTest.java | 26 +++++ 9 files changed, 73 insertions(+), 204 deletions(-) delete mode 100644 
core/src/main/java/com/google/adk/flows/llmflows/Examples.java delete mode 100644 core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index d326d8154..89024a59b 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -45,8 +45,6 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.codeexecutors.BaseCodeExecutor; import com.google.adk.events.Event; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; import com.google.adk.flows.llmflows.AutoFlow; import com.google.adk.flows.llmflows.BaseLlmFlow; import com.google.adk.flows.llmflows.SingleFlow; @@ -97,8 +95,6 @@ public enum IncludeContents { private final List toolsUnion; private final ImmutableList toolsets; private final Optional generateContentConfig; - // TODO: Remove exampleProvider field - examples should only be provided via ExampleTool - private final Optional exampleProvider; private final IncludeContents includeContents; private final boolean planning; @@ -132,7 +128,6 @@ protected LlmAgent(Builder builder) { this.globalInstruction = requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); - this.exampleProvider = Optional.ofNullable(builder.exampleProvider); this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); @@ -180,7 +175,6 @@ public static class Builder extends BaseAgent.Builder { private Instruction globalInstruction; private ImmutableList toolsUnion; private GenerateContentConfig generateContentConfig; - private BaseExampleProvider exampleProvider; private 
IncludeContents includeContents; private Boolean planning; private Integer maxSteps; @@ -253,26 +247,6 @@ public Builder generateContentConfig(GenerateContentConfig generateContentConfig return this; } - // TODO: Remove these example provider methods and only use ExampleTool for providing examples. - // Direct example methods should be deprecated in favor of using ExampleTool consistently. - @CanIgnoreReturnValue - public Builder exampleProvider(BaseExampleProvider exampleProvider) { - this.exampleProvider = exampleProvider; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(List examples) { - this.exampleProvider = (unused) -> examples; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(Example... examples) { - this.exampleProvider = (unused) -> ImmutableList.copyOf(examples); - return this; - } - @CanIgnoreReturnValue public Builder includeContents(IncludeContents includeContents) { this.includeContents = includeContents; @@ -640,10 +614,18 @@ protected void validate() { + " transfer."); } if (this.toolsUnion != null && !this.toolsUnion.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, tools must be empty."); + boolean hasOtherTools = + this.toolsUnion.stream() + .anyMatch( + tool -> + !(tool instanceof BaseTool baseTool) + || !baseTool.name().equals("example_tool")); + if (hasOtherTools) { + throw new IllegalArgumentException( + "Invalid config for agent " + + this.name + + ": if outputSchema is set, tools must be empty."); + } } } } @@ -812,11 +794,6 @@ public Optional generateContentConfig() { return generateContentConfig; } - // TODO: Remove this getter - examples should only be provided via ExampleTool - public Optional exampleProvider() { - return exampleProvider; - } - public IncludeContents includeContents() { return includeContents; } diff --git a/core/src/main/java/com/google/adk/examples/ExampleUtils.java 
b/core/src/main/java/com/google/adk/examples/ExampleUtils.java index 9cce535dc..2f3927ece 100644 --- a/core/src/main/java/com/google/adk/examples/ExampleUtils.java +++ b/core/src/main/java/com/google/adk/examples/ExampleUtils.java @@ -64,6 +64,9 @@ public final class ExampleUtils { * @return string representation of the examples block. */ private static String convertExamplesToText(List examples) { + if (examples.isEmpty()) { + return ""; + } StringBuilder examplesStr = new StringBuilder(); // super header diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java b/core/src/main/java/com/google/adk/flows/llmflows/Examples.java deleted file mode 100644 index d9cee5fa0..000000000 --- a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.examples.ExampleUtils; -import com.google.adk.models.LlmRequest; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Single; - -/** {@link RequestProcessor} that populates examples in LLM request. 
*/ -public final class Examples implements RequestProcessor { - - public Examples() {} - - @Override - public Single processRequest( - InvocationContext context, LlmRequest request) { - if (!(context.agent() instanceof LlmAgent)) { - throw new IllegalArgumentException("Agent in InvocationContext is not an instance of Agent."); - } - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder builder = request.toBuilder(); - - String query = - context - .userContent() - .flatMap(Content::parts) - .filter(parts -> !parts.isEmpty()) - .map(parts -> parts.get(0).text().orElse("")) - .orElse(""); - agent - .exampleProvider() - .ifPresent( - exampleProvider -> - builder.appendInstructions( - ImmutableList.of(ExampleUtils.buildExampleSi(exampleProvider, query)))); - return Single.just( - RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); - } -} diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index de45ba702..f56cc61c3 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -32,7 +32,6 @@ public class SingleFlow extends BaseLlmFlow { new Identity(), new Compaction(), new Contents(), - new Examples(), CodeExecution.requestProcessor); protected static final ImmutableList RESPONSE_PROCESSORS = diff --git a/core/src/main/java/com/google/adk/tools/ExampleTool.java b/core/src/main/java/com/google/adk/tools/ExampleTool.java index d08481532..d03c2e4f1 100644 --- a/core/src/main/java/com/google/adk/tools/ExampleTool.java +++ b/core/src/main/java/com/google/adk/tools/ExampleTool.java @@ -85,7 +85,9 @@ public Completable processLlmRequest( return Completable.complete(); } - llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + if (!examplesBlock.isEmpty()) { + llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + } 
// Delegate to BaseTool to keep any declaration bookkeeping (none for this tool) return super.processLlmRequest(llmRequestBuilder, toolContext); } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 594e47fd8..3524c7755 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -35,6 +35,7 @@ import com.google.adk.agents.Callbacks.OnModelErrorCallback; import com.google.adk.agents.Callbacks.OnToolErrorCallback; import com.google.adk.events.Event; +import com.google.adk.examples.Example; import com.google.adk.models.LlmRegistry; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; @@ -46,6 +47,7 @@ import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ExampleTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -649,4 +651,30 @@ private SpanData findSpanByName(List spans, String name) { .findFirst() .orElseThrow(() -> new AssertionError("Span not found: " + name)); } + + @Test + public void run_withExampleTool_doesNotAddFunctionDeclarations() { + ExampleTool tool = + ExampleTool.builder() + .addExample( + Example.builder() + .input(Content.fromParts(Part.fromText("qin"))) + .output(ImmutableList.of(Content.fromParts(Part.fromText("qout")))) + .build()) + .build(); + + Content modelContent = Content.fromParts(Part.fromText("Real LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(tool).build(); + InvocationContext invocationContext = createInvocationContext(agent); + + var unused = agent.runAsync(invocationContext).toList().blockingGet(); + + 
assertThat(testLlm.getRequests()).hasSize(1); + LlmRequest request = testLlm.getRequests().get(0); + + assertThat(request.config().isPresent()).isTrue(); + var config = request.config().get(); + assertThat(config.tools().isPresent()).isFalse(); + } } diff --git a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java index 4a1dcf8e3..2d22ed3f1 100644 --- a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java +++ b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java @@ -49,17 +49,7 @@ public List getExamples(String query) { @Test public void buildFewShotFewShot_noExamples() { TestExampleProvider exampleProvider = new TestExampleProvider(ImmutableList.of()); - String expected = - """ - - Begin few-shot - The following are examples of user queries and model responses using the available tools. - - End few-shot - Now, try to follow these examples and complete the following conversation - \ - """; - assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEqualTo(expected); + assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEmpty(); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java deleted file mode 100644 index 7d1615dc2..000000000 --- a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; -import com.google.adk.models.LlmRequest; -import com.google.adk.sessions.InMemorySessionService; -import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import java.util.List; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class ExamplesTest { - - private static final InMemorySessionService sessionService = new InMemorySessionService(); - - private static class TestExampleProvider implements BaseExampleProvider { - @Override - public List getExamples(String query) { - return ImmutableList.of( - Example.builder() - .input(Content.fromParts(Part.fromText("input1"))) - .output( - ImmutableList.of( - Content.builder().parts(Part.fromText("output1")).role("model").build())) - .build()); - } - } - - @Test - public void processRequest_withExampleProvider_addsExamplesToInstructions() { - LlmAgent agent = - LlmAgent.builder().name("test-agent").exampleProvider(new TestExampleProvider()).build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - 
examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isNotEmpty(); - assertThat(result.updatedRequest().getSystemInstructions().get(0)) - .contains("[user]\ninput1\n\n[model]\noutput1\n"); - } - - @Test - public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructions() { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isEmpty(); - } -} diff --git a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java index 4e80ed0ff..55e5d8f93 100644 --- a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java +++ b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java @@ -305,4 +305,30 @@ static final class WrongTypeProviderHolder { private WrongTypeProviderHolder() {} } + + @Test + public void declaration_isEmpty() { + ExampleTool tool = ExampleTool.builder().build(); + assertThat(tool.declaration().isPresent()).isFalse(); + } + + @Test + public void processLlmRequest_doesNotAddFunctionDeclarations() { + ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); + InvocationContext ctx = buildInvocationContext(); + LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); + + tool.processLlmRequest(builder, 
ToolContext.builder(ctx).build()).blockingAwait(); + LlmRequest updated = builder.build(); + + if (updated.config().isPresent()) { + var config = updated.config().get(); + if (config.tools().isPresent()) { + var tools = config.tools().get(); + boolean hasFunctionDeclarations = + tools.stream().anyMatch(t -> t.functionDeclarations().isPresent()); + assertThat(hasFunctionDeclarations).isFalse(); + } + } + } } From 8556d4af16ff04c6e3b678dcfc3d4bb232abc550 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Mar 2026 06:59:16 -0700 Subject: [PATCH 09/29] feat: Propagating the otel context This change ensures that the OpenTelemetry context is correctly propagated across asynchronous boundaries throughout the ADK, primarily within RxJava streams. ### Key Changes * **Context Propagation:** Replaces manual `Scope` management (which often fails in reactive code) with `.compose(Tracing.withContext(context))`. This ensures the OTel context is preserved when work moves between different threads or schedulers. * **`Runner` Refactoring:** * Adds a top-level `"invocation"` span to `runAsync` and `runLive` calls. * Captures the context at entry points and propagates it through the internal execution flow (`runAsyncImpl`, `runLiveImpl`, `runAgentWithFreshSession`). * **`BaseLlmFlow` & `Functions`:** Updates preprocessing, postprocessing, and tool execution logic to maintain context. This ensures that spans created within tools or processors are correctly parented. * **`PluginManager`:** Ensures that plugin callbacks (like `afterRunCallback` and `onEventCallback`) execute within the captured context. * **Testing:** Adds several unit tests across `BaseLlmFlowTest`, `FunctionsTest`, `PluginManagerTest`, and `RunnerTest` that specifically verify context propagation using `ContextKey` and `Schedulers.computation()`. ### Files Modified * **`BaseLlmFlow.java`**, **`Functions.java`**, **`PluginManager.java`**, **`Runner.java`**: Core logic updates for context propagation. 
* **`LlmAgentTest.java`**, **`BaseLlmFlowTest.java`**, **`FunctionsTest.java`**, **`PluginManagerTest.java`**, **`RunnerTest.java`**: New tests for OTel integration. * **`BUILD` files**: Updated dependencies for OpenTelemetry APIs and SDK testing. PiperOrigin-RevId: 884998997 --- .../adk/flows/llmflows/BaseLlmFlow.java | 208 +++++++++++------- .../google/adk/flows/llmflows/Functions.java | 138 ++++++------ ...equestConfirmationLlmRequestProcessor.java | 13 +- .../com/google/adk/plugins/PluginManager.java | 15 +- .../java/com/google/adk/runner/Runner.java | 163 ++++++++------ .../com/google/adk/agents/LlmAgentTest.java | 9 +- .../adk/flows/llmflows/BaseLlmFlowTest.java | 76 ++++++- .../google/adk/plugins/PluginManagerTest.java | 85 +++++++ .../com/google/adk/runner/RunnerTest.java | 128 +++++++++++ 9 files changed, 601 insertions(+), 234 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index ab5f6567a..e00cf0cbf 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -91,8 +91,9 @@ public BaseLlmFlow( * RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the * events generated by them. */ - protected Flowable preprocess( + private Flowable preprocess( InvocationContext context, AtomicReference llmRequestRef) { + Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -114,6 +115,7 @@ protected Flowable preprocess( .concatMap( processor -> Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + .compose(Tracing.withContext(currentContext)) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? 
result.events() : ImmutableList.of())); @@ -128,7 +130,8 @@ protected Flowable postprocess( InvocationContext context, Event baseEventForLlmResponse, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + Context parentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); @@ -144,15 +147,16 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } - Context parentContext = Context.current(); - return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, + eventIterables, + context, + baseEventForLlmResponse, + llmRequest, + parentContext) + .compose(Tracing.withContext(parentContext))); } /** @@ -163,54 +167,80 @@ protected Flowable postprocess( * @param eventForCallbackUsage An Event object primarily for providing context (like actions) to * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ - private Flowable callLlm( + private Flowable callLlm( Context spanContext, InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) .toFlowable() + .concatMap( + llmResp -> + postprocess( + context, + eventForCallbackUsage, + llmRequestBuilder.build(), + llmResp, + spanContext)) .switchIfEmpty( Flowable.defer( () -> { + LlmAgent agent = (LlmAgent) context.agent(); BaseLlm llm = agent.resolvedModel().model().isPresent() ? 
agent.resolvedModel().model().get() : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose( - Tracing.trace("call_llm") - .setParent(spanContext) - .onSuccess( - (span, llmResp) -> - Tracing.traceCallLlm( - span, + LlmRequest finalLlmRequest = llmRequestBuilder.build(); + + Span span = + Tracing.getTracer() + .spanBuilder("call_llm") + .setParent(spanContext) + .startSpan(); + Context callLlmContext = spanContext.with(span); + + Flowable flowable = + llm.generateContent( + finalLlmRequest, + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnError( + error -> { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()) + .flatMap( + llmResp -> + postprocess( context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp))) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); + eventForCallbackUsage, + finalLlmRequest, + llmResp, + callLlmContext) + .doOnSubscribe( + s -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + finalLlmRequest, + llmResp))); + + return Tracing.traceFlowable(callLlmContext, span, () -> flowable); })); } @@ -222,6 +252,7 @@ 
private Flowable callLlm( */ private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -240,7 +271,11 @@ private Maybe handleBeforeModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); @@ -257,6 +292,7 @@ private Maybe handleOnModelErrorCallback( LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, Throwable throwable) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -277,7 +313,11 @@ private Maybe handleOnModelErrorCallback( () -> { LlmRequest llmRequest = llmRequestBuilder.build(); return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(currentContext))) .firstElement(); }); @@ -292,6 +332,7 @@ private Maybe handleOnModelErrorCallback( */ private Single handleAfterModelCallback( InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -310,7 +351,11 @@ private Single 
handleAfterModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); @@ -330,7 +375,6 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex return Flowable.defer( () -> { - Context currentContext = Context.current(); return preprocess(context, llmRequestRef) .concatWith( Flowable.defer( @@ -362,23 +406,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex context, llmRequestAfterPreprocess, mutableEventTemplate) - .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - String newId = Event.generateEventId(); - logger.debug( - "Resetting event ID from {} to {}", oldId, newId); - mutableEventTemplate.setId(newId); - }); - } + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + String newId = Event.generateEventId(); + logger.debug("Resetting event ID from {} to {}", oldId, newId); + mutableEventTemplate.setId(newId); }) .concatMap( event -> { @@ -545,6 +578,10 @@ public void onError(Throwable e) { .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)); + Span span = + Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan(); + Context callLlmContext = spanContext.with(span); + Flowable receiveFlow = connection .receive() @@ -556,7 +593,8 @@ public void onError(Throwable e) { invocationContext, baseEventForThisLlmResponse, llmRequestAfterPreprocess, - llmResponse); + llmResponse, + callLlmContext); }) .flatMap( event -> { @@ -592,7 +630,12 @@ 
public void onError(Throwable e) { } }); - return receiveFlow.takeWhile(event -> !event.actions().endInvocation().orElse(false)); + return Tracing.traceFlowable( + callLlmContext, + span, + () -> + receiveFlow.takeWhile( + event -> !event.actions().endInvocation().orElse(false))); })); } @@ -608,7 +651,8 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest) { + LlmRequest llmRequest, + Context parentContext) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -624,21 +668,23 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); + Flowable functionEvents; + try (Scope scope = parentContext.makeCurrent()) { + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? 
Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + } return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index c1a996064..84a8141ea 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -42,7 +42,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; -import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Observable; @@ -163,7 +162,9 @@ public static Maybe handleFunctionCalls( } return functionResponseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -226,7 +227,9 @@ public static Maybe handleFunctionCallsLive( return responseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -243,47 +246,45 @@ private static Function> getFunctionCallMapper( Context parentContext) { return functionCall -> Maybe.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - BaseTool 
tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation( - functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = - functionCall.args().map(HashMap::new).orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> { - try (Scope innerScope = parentContext.makeCurrent()) { - return isLive - ? processFunctionLive( - invocationContext, - tool, - toolContext, - functionCall, - functionArgs, - parentContext) - : callTool(tool, functionArgs, toolContext, parentContext); - } - })); - - return postProcessFunctionResult( - maybeFunctionResult, - invocationContext, - tool, - functionArgs, - toolContext, - isLive, - parentContext); - } - }); + () -> { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation( + functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = + functionCall.args().map(HashMap::new).orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> + isLive + ? 
processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs, + parentContext) + : callTool( + tool, functionArgs, toolContext, parentContext)) + .compose(Tracing.withContext(parentContext))); + + return postProcessFunctionResult( + maybeFunctionResult, + invocationContext, + tool, + functionArgs, + toolContext, + isLive, + parentContext); + }) + .compose(Tracing.withContext(parentContext)); } /** @@ -410,34 +411,27 @@ private static Maybe postProcessFunctionResult( }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace("tool_response [" + tool.name() + "]") - .setParent(parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } - }); + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event event = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + Tracing.traceToolResponse(event.id(), event); + return 
Maybe.just(event); + }); + }) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index e00c0093d..a93eb3cb4 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -29,6 +29,7 @@ import com.google.adk.events.Event; import com.google.adk.events.ToolConfirmation; import com.google.adk.models.LlmRequest; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -37,6 +38,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.Collection; @@ -216,10 +218,13 @@ private Maybe assembleEvent( .build()) .build(); - return toolsMapSingle.flatMapMaybe( - toolsMap -> - Functions.handleFunctionCalls( - invocationContext, functionCallEvent, toolsMap, toolConfirmations)); + Context parentContext = Context.current(); + return toolsMapSingle + .flatMapMaybe( + toolsMap -> + Functions.handleFunctionCalls( + invocationContext, functionCallEvent, toolsMap, toolConfirmations)) + .compose(Tracing.withContext(parentContext)); } private static Optional> maybeCreateToolConfirmationEntry( diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index e534da787..8d0366e9a 100644 --- 
a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,13 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -126,6 +128,7 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -136,11 +139,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e))) + .compose(Tracing.withContext(capturedContext)); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -149,7 +154,8 @@ public Completable close() { .doOnError( e -> logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + "[{}] Error during callback 'close'", plugin.getName(), e))) + .compose(Tracing.withContext(capturedContext)); } @Override @@ -227,7 +233,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -247,6 +253,7 @@ private Maybe runMaybeCallbacks( 
plugin.getName(), callbackName, e))) - .firstElement(); + .firstElement() + .compose(Tracing.withContext(capturedContext)); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 5859c4786..51e1b8f25 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,6 +52,7 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -375,20 +376,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. 
*/ @@ -441,7 +447,8 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); + return runAsyncImpl(session, newMessage, runConfig, stateDelta) + .compose(Tracing.trace("invocation")); } /** @@ -461,6 +468,7 @@ protected Flowable runAsyncImpl( Preconditions.checkNotNull(session, "session cannot be null"); Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); + Context capturedContext = Context.current(); return Flowable.defer( () -> { BaseAgent rootAgent = this.agent; @@ -476,6 +484,7 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) + .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> @@ -500,7 +509,8 @@ protected Flowable runAsyncImpl( event, invocationId, runConfig, - rootAgent)); + rootAgent)) + .compose(Tracing.withContext(capturedContext)); }); }) .doOnError( @@ -508,8 +518,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -562,12 +571,14 @@ private Flowable runAgentWithFreshSession( .toFlowable()); // If beforeRunCallback returns content, emit it and skip agent + Context capturedContext = Context.current(); return beforeRunEvent .toFlowable() .switchIfEmpty(agentEvents) .concatWith( Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) - .concatWith(Completable.defer(() -> compactEvents(updatedSession))); + .concatWith(Completable.defer(() -> compactEvents(updatedSession))) + .compose(Tracing.withContext(capturedContext)); } private Completable compactEvents(Session session) { @@ -632,46 +643,9 @@ private 
InvocationContext.Builder newInvocationContextBuilder(Session session) { .agent(this.findAgentToRun(session, rootAgent)); } - /** - * Runs the agent in live mode, appending generated events to the session. - * - * @return stream of events from the agent. - */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); } /** @@ -682,19 +656,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> 
this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** @@ -708,6 +688,49 @@ public Flowable runLive( return runLive(sessionKey.userId(), sessionKey.id(), liveRequestQueue, runConfig); } + /** + * Runs the agent in live mode, appending generated events to the session. + * + * @return stream of events from the agent. + */ + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return Flowable.defer( + () -> { + Context capturedContext = Context.current(); + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }) + .compose(Tracing.withContext(capturedContext)); + }); + } + 
/** * Runs the agent asynchronously with a default user ID. * diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 3524c7755..c193e4a65 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -574,8 +574,13 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + + // The tool calls and responses are children of the first LLM call that produced the function + // call. + String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); + toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolResponseSpans.forEach( + s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 4a0b345c6..6cae6c88a 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -43,9 +43,13 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import 
io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; import java.util.Map; import java.util.Optional; @@ -572,6 +576,71 @@ public Single> runAsync(Map args, ToolContex } } + @Test + public void run_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + RequestProcessor requestProcessor = + (ctx, request) -> { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + ResponseProcessor responseProcessor = + (ctx, response) -> { + return Single.just(ResponseProcessingResult.create(response, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + return Maybe.empty().subscribeOn(Schedulers.computation()); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + return Maybe.just(resp).subscribeOn(Schedulers.computation()); + }; + + Callbacks.OnModelErrorCallback onErrorCallback = + (ctx, req, err) -> { + return Maybe.just( + LlmResponse.builder().content(Content.fromParts(Part.fromText("error"))).build()) + .subscribeOn(Schedulers.computation()); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .onModelErrorCallback(onErrorCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow(ImmutableList.of(requestProcessor), ImmutableList.of(responseProcessor)); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + baseLlmFlow + .run(invocationContext) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + 
assertThat(events).hasSize(1); + assertThat(events.get(0).content()).hasValue(content); + } + @Test public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { GenerateContentResponseUsageMetadata usageMetadata = @@ -588,7 +657,12 @@ public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { List events = baseLlmFlow - .postprocess(invocationContext, baseEvent, LlmRequest.builder().build(), llmResponse) + .postprocess( + invocationContext, + baseEvent, + LlmRequest.builder().build(), + llmResponse, + Context.current()) .toList() .blockingGet(); diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index 4ae856fc7..3771143cf 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -37,8 +37,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -144,6 +148,87 @@ public void onUserMessageCallback_pluginOrderRespected() { inOrder.verify(plugin2).onUserMessageCallback(mockInvocationContext, content); } + @Test + public void contextPropagation_runMaybeCallbacks() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content expectedContent = Content.builder().build(); + when(plugin1.onUserMessageCallback(any(), any())) + .thenReturn(Maybe.just(expectedContent).subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); 
+ + Maybe resultMaybe; + try (Scope scope = testContext.makeCurrent()) { + resultMaybe = pluginManager.onUserMessageCallback(mockInvocationContext, content); + } + + // Assert downstream operators have the propagated context + resultMaybe + .doOnSuccess( + result -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(expectedContent); + + verify(plugin1).onUserMessageCallback(mockInvocationContext, content); + } + + @Test + public void contextPropagation_afterRunCallback() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.afterRunCallback(any())) + .thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.afterRunCallback(mockInvocationContext); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).afterRunCallback(mockInvocationContext); + } + + @Test + public void contextPropagation_close() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.close()).thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.close(); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + 
verify(plugin1).close(); + } + @Test public void afterRunCallback_allComplete() { when(plugin1.afterRunCallback(any())).thenReturn(Completable.complete()); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 8a0a84b08..2eb515fa2 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -57,6 +57,9 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Completable; @@ -977,6 +980,84 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } + @Test + public void runAsync_createsToolSpansWithCorrectParent() { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + + var unused = + runnerWithTool + .runAsync( + sessionWithTool.sessionKey(), + createContent("from user"), + RunConfig.builder().build()) + .toList() + .blockingGet(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + assertThat(llmSpans).hasSize(2); + 
assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + + @Test + public void runLive_createsToolSpansWithCorrectParent() throws Exception { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber = + runnerWithTool + .runLive(sessionWithTool.sessionKey(), liveRequestQueue, RunConfig.builder().build()) + .test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + // In runLive, there is one call_llm span for the execution + assertThat(llmSpans).hasSize(1); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String 
toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); @@ -1188,6 +1269,53 @@ public void close_closesPluginsAndCodeExecutors() { verify(plugin).close(); } + @Test + public void runAsync_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + runner + .runAsync("user", session.id(), createContent("test message")) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + } + + @Test + public void runLive_contextPropagation() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber; + try (Scope scope = testContext.makeCurrent()) { + testSubscriber = + runner + .runLive(session, liveRequestQueue, RunConfig.builder().build()) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test(); + } + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + } + @Test public void buildRunnerWithPlugins_success() { BasePlugin plugin1 = mockPlugin("test1"); From 2fcff3c30f5d0af4b4007e821af1f204e555fef9 Mon Sep 17 00:00:00 2001 From: 
Taylor Lanclos Date: Wed, 4 Mar 2026 19:12:16 +0000 Subject: [PATCH 10/29] Extract timestamp as double for InMemorySessionService events InMemorySessionService sets a Session's last modified time based on the timestamp of the last appended event. The timestamp in an event is recorded in millis while the Session's timestamp is an Instant. During the transformation, Events perform this conversion using division. Before this change, the timestamp was truncated to the second, yet the code was trying to extract nanos which were always 0. This fixes that bug with a simple type change. I've also added a test to prevent regressions. --- .../adk/sessions/InMemorySessionService.java | 13 ++----------- .../sessions/InMemorySessionServiceTest.java | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b2a584b11..d9bb047a3 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -154,19 +154,13 @@ public Maybe getSession( if (config.numRecentEvents().isEmpty() && config.afterTimestamp().isPresent()) { Instant threshold = config.afterTimestamp().get(); - eventsInCopy.removeIf( - event -> getEventTimestampEpochSeconds(event) < threshold.getEpochSecond()); + eventsInCopy.removeIf(event -> getInstantFromEvent(event).isBefore(threshold)); } // Merge state into the potentially filtered copy and return return Maybe.just(mergeWithGlobalState(appName, userId, sessionCopy)); } - // Helper to get event timestamp as epoch seconds - private long getEventTimestampEpochSeconds(Event event) { - return event.timestamp() / 1000L; - } - @Override public Single listSessions(String appName, String userId) { Objects.requireNonNull(appName, "appName cannot be null"); @@ -294,10 +288,7 @@ public Single 
appendEvent(Session session, Event event) { /** Converts an event's timestamp to an Instant. Adapt based on actual Event structure. */ // TODO: have Event.timestamp() return Instant directly private Instant getInstantFromEvent(Event event) { - double epochSeconds = getEventTimestampEpochSeconds(event); - long seconds = (long) epochSeconds; - long nanos = (long) ((epochSeconds % 1.0) * 1_000_000_000L); - return Instant.ofEpochSecond(seconds, nanos); + return Instant.ofEpochMilli(event.timestamp()); } /** diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 41e156ffd..0d9235b1b 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.time.Instant; import java.util.HashMap; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; @@ -214,6 +215,24 @@ public void appendEvent_removesState() { assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + @Test + public void appendEvent_updatesSessionTimestampWithFractionalSeconds() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + // Add an event with a timestamp that contains a fractional second + Event eventAdd = Event.builder().timestamp(5500).build(); + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify the last modified timestamp contains a fractional second + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + 
assertThat(retrievedSession.lastUpdateTime()).isEqualTo(Instant.ofEpochSecond(5, 500000000L)); + } + @Test public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); From 4eb3613b65cb1334e9432960d0f864ef09829c23 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Mar 2026 10:27:33 -0700 Subject: [PATCH 11/29] fix: improve processRequest_concurrentReadAndWrite_noException test case PiperOrigin-RevId: 885091550 --- .../adk/flows/llmflows/ContentsTest.java | 89 ++++++++++--------- 1 file changed, 48 insertions(+), 41 deletions(-) diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 7164991f3..1e6267dde 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,13 +36,15 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.HashMap; +import java.util.ArrayList; +import java.util.ConcurrentModificationException; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -787,15 +789,49 @@ public void processRequest_notEmptyContent() { public void processRequest_concurrentReadAndWrite_noException() throws Exception { LlmAgent agent = LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + List customEvents = + new ArrayList() { + private void checkLock() { + if (!Thread.holdsLock(this)) { + throw new ConcurrentModificationException("Unsynchronized iteration 
detected!"); + } + } + + @Override + public Iterator iterator() { + checkLock(); + return super.iterator(); + } + + @Override + public ListIterator listIterator() { + checkLock(); + return super.listIterator(); + } + + @Override + public ListIterator listIterator(int index) { + checkLock(); + return super.listIterator(index); + } + + @Override + public Stream stream() { + checkLock(); + return super.stream(); + } + }; + Session session = - sessionService - .createSession("test-app", "test-user", new HashMap<>(), "test-session") - .blockingGet(); + Session.builder("test-session") + .appName("test-app") + .userId("test-user") + .events(customEvents) + .build(); - // Seed with dummy events to widen the race capability - for (int i = 0; i < 5000; i++) { - session.events().add(createUserEvent("dummy" + i, "dummy")); - } + // The list must have at least one element so that operations interacting with events trigger + // iteration. + customEvents.add(createUserEvent("dummy", "dummy")); InvocationContext context = InvocationContext.builder() @@ -807,37 +843,8 @@ public void processRequest_concurrentReadAndWrite_noException() throws Exception LlmRequest initialRequest = LlmRequest.builder().build(); - AtomicReference writerError = new AtomicReference<>(); - CountDownLatch startLatch = new CountDownLatch(1); - - Thread writerThread = - new Thread( - () -> { - startLatch.countDown(); - try { - for (int i = 0; i < 2000; i++) { - session.events().add(createUserEvent("writer" + i, "new data")); - } - } catch (Throwable t) { - writerError.set(t); - } - }); - - writerThread.start(); - startLatch.await(); // wait for writer to be ready - - // Process (read) requests concurrently to trigger race conditions - for (int i = 0; i < 200; i++) { - var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); - if (writerError.get() != null) { - throw new RuntimeException("Writer failed", writerError.get()); - } - } - - writerThread.join(); - if (writerError.get() 
!= null) { - throw new RuntimeException("Writer failed", writerError.get()); - } + // This single call will throw the exception if the list is accessed insecurely. + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); } private static Event createUserEvent(String id, String text) { From c8ab0f96b09a6c9636728d634c62695fcd622246 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 17 Mar 2026 11:50:04 -0700 Subject: [PATCH 12/29] feat: Implement basic version of BigQuery Agent Analytics Plugin This change introduces a new plugin for the Agent Development Kit (ADK) that logs agent execution events to BigQuery. It includes: - `BigQueryAgentAnalyticsPlugin`: A plugin that captures various agent lifecycle events (user messages, tool calls, model invocations) and sends them to BigQuery. - `BigQueryLoggerConfig`: Configuration options for the plugin, including project/dataset/table IDs, batching, and retry settings. - `BigQuerySchema`: Defines the BigQuery and Arrow schemas used for the event table. - `BatchProcessor`: Handles batching of events and writing them to BigQuery using the Storage Write API with Arrow format. - `JsonFormatter`: Utility for safely formatting JSON content for BigQuery. 
PiperOrigin-RevId: 885133967 --- core/pom.xml | 20 + .../agentanalytics/BatchProcessor.java | 270 +++++++++++ .../BigQueryAgentAnalyticsPlugin.java | 436 +++++++++++++++++ .../agentanalytics/BigQueryLoggerConfig.java | 204 ++++++++ .../agentanalytics/BigQuerySchema.java | 304 ++++++++++++ .../plugins/agentanalytics/JsonFormatter.java | 111 +++++ .../agentanalytics/BatchProcessorTest.java | 367 ++++++++++++++ .../BigQueryAgentAnalyticsPluginTest.java | 457 ++++++++++++++++++ pom.xml | 11 + 9 files changed, 2180 insertions(+) create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java create mode 100644 core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java create mode 100644 core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java diff --git a/core/pom.xml b/core/pom.xml index 8c3c2069c..b3f2f5fd8 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -197,6 +197,26 @@ opentelemetry-sdk-testing test + + com.google.cloud + google-cloud-bigquery + 2.40.0 + + + org.apache.arrow + arrow-vector + 17.0.0 + + + org.apache.arrow + arrow-memory-core + 17.0.0 + + + org.apache.arrow + arrow-memory-netty + 17.0.0 + diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java new file mode 100644 index 000000000..ef826fb56 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -0,0 +1,270 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializationError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.VarCharVector; +import 
org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** Handles asynchronous batching and writing of events to BigQuery. */ +class BatchProcessor implements AutoCloseable { + private static final Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + + private final StreamWriter writer; + private final int batchSize; + private final Duration flushInterval; + @VisibleForTesting final BlockingQueue> queue; + private final ScheduledExecutorService executor; + @VisibleForTesting final BufferAllocator allocator; + final AtomicBoolean flushLock = new AtomicBoolean(false); + private final Schema arrowSchema; + private final VectorSchemaRoot root; + + public BatchProcessor( + StreamWriter writer, + int batchSize, + Duration flushInterval, + int queueMaxSize, + ScheduledExecutorService executor) { + this.writer = writer; + this.batchSize = batchSize; + this.flushInterval = flushInterval; + this.queue = new LinkedBlockingQueue<>(queueMaxSize); + this.executor = executor; + // It's safe to use Long.MAX_VALUE here as this is a top-level RootAllocator, + // and memory is properly managed via try-with-resources in the flush() method. + // The actual memory usage is bounded by the batchSize and individual row sizes. 
+ this.allocator = new RootAllocator(Long.MAX_VALUE); + this.arrowSchema = BigQuerySchema.getArrowSchema(); + this.root = VectorSchemaRoot.create(arrowSchema, allocator); + } + + public void start() { + @SuppressWarnings("unused") + var unused = + executor.scheduleWithFixedDelay( + () -> { + try { + flush(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Error in background flush", e); + } + }, + flushInterval.toMillis(), + flushInterval.toMillis(), + MILLISECONDS); + } + + public void append(Map row) { + if (!queue.offer(row)) { + logger.warning("BigQuery event queue is full, dropping event."); + return; + } + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + + public void flush() { + // Acquire the flushLock. If another flush is already in progress, return immediately. + if (!flushLock.compareAndSet(false, true)) { + return; + } + try { + if (queue.isEmpty()) { + return; + } + List> batch = new ArrayList<>(); + queue.drainTo(batch, batchSize); + if (batch.isEmpty()) { + return; + } + try { + root.allocateNew(); + for (int i = 0; i < batch.size(); i++) { + Map row = batch.get(i); + for (Field field : arrowSchema.getFields()) { + populateVector(root.getVector(field.getName()), i, row.get(field.getName())); + } + } + root.setRowCount(batch.size()); + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + AppendRowsResponse result = writer.append(recordBatch).get(); + if (result.hasError()) { + logger.severe("BigQuery append error: " + result.getError().getMessage()); + for (var error : result.getRowErrorsList()) { + logger.severe( + String.format("Row error at index %d: %s", error.getIndex(), error.getMessage())); + } + } else { + logger.fine("Successfully wrote " + batch.size() + " rows to BigQuery."); + } + } catch (AppendSerializationError ase) { + logger.log( + Level.SEVERE, "Failed to write batch to BigQuery due to serialization error", ase); + Map rowIndexToErrorMessage = 
ase.getRowIndexToErrorMessage(); + if (rowIndexToErrorMessage != null && !rowIndexToErrorMessage.isEmpty()) { + logger.severe("Row-level errors found:"); + for (Map.Entry entry : rowIndexToErrorMessage.entrySet()) { + logger.severe( + String.format("Row error at index %d: %s", entry.getKey(), entry.getValue())); + } + } else { + logger.severe( + "AppendSerializationError occurred, but no row-specific errors were provided."); + } + } + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); + } finally { + // Clear the vectors to release the memory. + root.clear(); + } + } finally { + flushLock.set(false); + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + } + + private void populateVector(FieldVector vector, int index, Object value) { + if (value == null || (value instanceof JsonNode jsonNode && jsonNode.isNull())) { + vector.setNull(index); + return; + } + if (vector instanceof VarCharVector varCharVector) { + String strValue = (value instanceof JsonNode jsonNode) ? jsonNode.asText() : value.toString(); + varCharVector.setSafe(index, strValue.getBytes(UTF_8)); + } else if (vector instanceof BigIntVector bigIntVector) { + long longValue; + if (value instanceof JsonNode jsonNode) { + longValue = jsonNode.asLong(); + } else if (value instanceof Number number) { + longValue = number.longValue(); + } else { + longValue = Long.parseLong(value.toString()); + } + bigIntVector.setSafe(index, longValue); + } else if (vector instanceof BitVector bitVector) { + boolean boolValue = + (value instanceof JsonNode jsonNode) ? jsonNode.asBoolean() : (Boolean) value; + bitVector.setSafe(index, boolValue ? 
1 : 0); + } else if (vector instanceof TimeStampVector timeStampVector) { + if (value instanceof Instant instant) { + long micros = + SECONDS.toMicros(instant.getEpochSecond()) + NANOSECONDS.toMicros(instant.getNano()); + timeStampVector.setSafe(index, micros); + } else if (value instanceof JsonNode jsonNode) { + timeStampVector.setSafe(index, jsonNode.asLong()); + } else if (value instanceof Long longValue) { + timeStampVector.setSafe(index, longValue); + } + } else if (vector instanceof ListVector listVector) { + int start = listVector.startNewValue(index); + if (value instanceof ArrayNode arrayNode) { + for (int i = 0; i < arrayNode.size(); i++) { + populateVector(listVector.getDataVector(), start + i, arrayNode.get(i)); + } + listVector.endValue(index, arrayNode.size()); + } else if (value instanceof List) { + List list = (List) value; + for (int i = 0; i < list.size(); i++) { + populateVector(listVector.getDataVector(), start + i, list.get(i)); + } + listVector.endValue(index, list.size()); + } + } else if (vector instanceof StructVector structVector) { + structVector.setIndexDefined(index); + if (value instanceof ObjectNode objectNode) { + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, objectNode.get(child.getName())); + } + } else if (value instanceof Map) { + Map map = (Map) value; + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, map.get(child.getName())); + } + } + } + } + + @Override + public void close() { + if (this.queue != null && !this.queue.isEmpty()) { + this.flush(); + } + if (this.allocator != null) { + try { + this.allocator.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close Buffer allocator", e); + } + } + if (this.root != null) { + try { + this.root.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close VectorSchemaRoot", e); + } + } + if (this.writer != null) { + try { + 
this.writer.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close BigQuery writer", e); + } + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java new file mode 100644 index 000000000..68b5fb5a1 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -0,0 +1,436 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryException; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Clustering; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardTableDefinition; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import 
java.util.logging.Level; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** + * BigQuery Agent Analytics Plugin for Java. + * + *

Logs agent execution events directly to a BigQuery table using the Storage Write API. + */ +public class BigQueryAgentAnalyticsPlugin extends BasePlugin { + private static final Logger logger = + Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + private static final ImmutableList DEFAULT_AUTH_SCOPES = + ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); + private static final AtomicLong threadCounter = new AtomicLong(0); + + private final BigQueryLoggerConfig config; + private final BigQuery bigQuery; + private final BigQueryWriteClient writeClient; + private final ScheduledExecutorService executor; + private final Object tableEnsuredLock = new Object(); + @VisibleForTesting final BatchProcessor batchProcessor; + private volatile boolean tableEnsured = false; + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { + this(config, createBigQuery(config)); + } + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) + throws IOException { + super("bigquery_agent_analytics"); + this.config = config; + this.bigQuery = bigQuery; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + this.writeClient = createWriteClient(config); + + if (config.enabled()) { + StreamWriter writer = createWriter(config); + this.batchProcessor = + new BatchProcessor( + writer, + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + this.batchProcessor.start(); + } else { + this.batchProcessor = null; + } + } + + private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { + BigQueryOptions.Builder builder = BigQueryOptions.newBuilder(); + if (config.credentials() != null) { + builder.setCredentials(config.credentials()); + } else { + builder.setCredentials( + 
GoogleCredentials.getApplicationDefault().createScoped(DEFAULT_AUTH_SCOPES)); + } + return builder.build().getService(); + } + + private void ensureTableExistsOnce() { + if (!tableEnsured) { + synchronized (tableEnsuredLock) { + if (!tableEnsured) { + // Table creation is expensive, so we only do it once per plugin instance. + tableEnsured = true; + ensureTableExists(bigQuery, config); + } + } + } + } + + private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { + TableId tableId = TableId.of(config.projectId(), config.datasetId(), config.tableName()); + Schema schema = BigQuerySchema.getEventsSchema(); + try { + Table table = bigQuery.getTable(tableId); + logger.info("BigQuery table: " + tableId); + if (table == null) { + logger.info("Creating BigQuery table: " + tableId); + StandardTableDefinition.Builder tableDefinitionBuilder = + StandardTableDefinition.newBuilder().setSchema(schema); + if (!config.clusteringFields().isEmpty()) { + tableDefinitionBuilder.setClustering( + Clustering.newBuilder().setFields(config.clusteringFields()).build()); + } + TableInfo tableInfo = TableInfo.newBuilder(tableId, tableDefinitionBuilder.build()).build(); + bigQuery.create(tableInfo); + } else if (config.autoSchemaUpgrade()) { + // TODO(b/491851868): Implement auto-schema upgrade. + logger.info("BigQuery table already exists and auto-schema upgrade is enabled: " + tableId); + logger.info("Auto-schema upgrade is not implemented yet."); + } + } catch (BigQueryException e) { + if (e.getMessage().contains("invalid_grant")) { + logger.log( + Level.SEVERE, + "Failed to authenticate with BigQuery. 
Please run 'gcloud auth application-default" + + " login' to refresh your credentials or provide valid credentials in" + + " BigQueryLoggerConfig.", + e); + } else { + logger.log( + Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + if (config.credentials() != null) { + return BigQueryWriteClient.create( + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) + .build()); + } + return BigQueryWriteClient.create(); + } + + protected String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(retryConfig.maxRetries()) + .setInitialRetryDelay(Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setRetryDelayMultiplier(retryConfig.multiplier()) + .setMaxRetryDelay(Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + private void logEvent( + String eventType, + InvocationContext invocationContext, + Optional callbackContext, + Object content, + Map extraAttributes) { + if (batchProcessor == null) { + return; + } + + 
ensureTableExistsOnce(); + + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", eventType); + row.put( + "agent", + callbackContext.map(CallbackContext::agentName).orElse(invocationContext.agent().name())); + row.put("session_id", invocationContext.session().id()); + row.put("invocation_id", invocationContext.invocationId()); + row.put("user_id", invocationContext.userId()); + + if (content instanceof Content contentParts) { + row.put( + "content_parts", + JsonFormatter.formatContentParts(Optional.of(contentParts), config.maxContentLength())); + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } else if (content != null) { + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } + + Map attributes = new HashMap<>(config.customTags()); + if (extraAttributes != null) { + attributes.putAll(extraAttributes); + } + row.put( + "attributes", + JsonFormatter.smartTruncate(attributes, config.maxContentLength()).toString()); + + addTraceDetails(row); + batchProcessor.append(row); + } + + // TODO(b/491849911): Implement own trace management functionality. 
+ private void addTraceDetails(Map row) { + SpanContext spanContext = Span.current().getSpanContext(); + if (spanContext.isValid()) { + row.put("trace_id", spanContext.getTraceId()); + row.put("span_id", spanContext.getSpanId()); + } + } + + @Override + public Completable close() { + if (batchProcessor != null) { + batchProcessor.close(); + } + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + return Completable.complete(); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + return Maybe.fromAction( + () -> logEvent("USER_MESSAGE", invocationContext, Optional.empty(), userMessage, null)); + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + return Maybe.fromAction( + () -> logEvent("INVOCATION_START", invocationContext, Optional.empty(), null, null)); + } + + @Override + public Maybe onEventCallback(InvocationContext invocationContext, Event event) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("event_author", event.author()); + logEvent( + "EVENT", invocationContext, Optional.empty(), event.content().orElse(null), attrs); + }); + } + + @Override + public Completable afterRunCallback(InvocationContext invocationContext) { + return Completable.fromAction( + () -> { + logEvent("INVOCATION_END", invocationContext, Optional.empty(), null, null); + batchProcessor.flush(); + }); + } + + @Override + public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_START", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe 
afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_END", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + LlmRequest req = llmRequest.build(); + attrs.put("model", req.model().orElse("unknown")); + logEvent( + "MODEL_REQUEST", + callbackContext.invocationContext(), + Optional.of(callbackContext), + req, + attrs); + }); + } + + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + llmResponse.usageMetadata().ifPresent(u -> attrs.put("usage_metadata", u)); + logEvent( + "MODEL_RESPONSE", + callbackContext.invocationContext(), + Optional.of(callbackContext), + llmResponse, + attrs); + }); + } + + @Override + public Maybe onModelErrorCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("error_message", error.getMessage()); + logEvent( + "MODEL_ERROR", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + attrs); + }); + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_START", + toolContext.invocationContext(), + Optional.of(toolContext), + toolArgs, + attrs); + }); + } + + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + 
"TOOL_END", toolContext.invocationContext(), Optional.of(toolContext), result, attrs); + }); + } + + @Override + public Maybe> onToolErrorCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + attrs.put("error_message", error.getMessage()); + logEvent( + "TOOL_ERROR", toolContext.invocationContext(), Optional.of(toolContext), null, attrs); + }); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java new file mode 100644 index 000000000..aa5bf37de --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -0,0 +1,204 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.google.auth.Credentials; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import javax.annotation.Nullable; + +/** Configuration for the BigQueryAgentAnalyticsPlugin. 
*/ +@AutoValue +public abstract class BigQueryLoggerConfig { + // Whether the plugin is enabled. + public abstract boolean enabled(); + + // List of event types to log. If None, all are allowed + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventAllowlist(); + + // List of event types to ignore. + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventDenylist(); + + // Max length for text content before truncation. + public abstract int maxContentLength(); + + // Project ID for the BigQuery table. + public abstract String projectId(); + + // Dataset ID for the BigQuery table. + public abstract String datasetId(); + + // Table name for the BigQuery table. + public abstract String tableName(); + + // Fields to cluster the table by. + public abstract ImmutableList clusteringFields(); + + // Whether to log multi-modal content. + // TODO(b/491852782): Implement logging of multi-modal content. + public abstract boolean logMultiModalContent(); + + // Retry configuration for BigQuery writes. + public abstract RetryConfig retryConfig(); + + // Number of rows to batch before flushing. + public abstract int batchSize(); + + // Duration to wait before flushing the queue. + public abstract Duration batchFlushInterval(); + + // Max time to wait for shutdown. + public abstract Duration shutdownTimeout(); + + // Max size of the batch processor queue. + public abstract int queueMaxSize(); + + // Optional custom formatter for content. + // TODO(b/491852782): Implement content formatter. + @Nullable + public abstract BiFunction contentFormatter(); + + // TODO(b/491852782): Implement connection id. + public abstract Optional connectionId(); + + // Toggle for session metadata (e.g. gchat thread-id). + // TODO(b/491852782): Implement logging of session metadata. + public abstract boolean logSessionMetadata(); + + // Static custom tags (e.g. 
{"agent_role": "sales"}). + // TODO(b/491852782): Implement custom tags. + public abstract ImmutableMap customTags(); + + // Automatically add new columns to existing tables when the plugin + // schema evolves. Only additive changes are made (columns are never + // dropped or altered). + // TODO(b/491852782): Implement auto-schema upgrade. + public abstract boolean autoSchemaUpgrade(); + + @Nullable + public abstract Credentials credentials(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig.Builder() + .setEnabled(true) + .setMaxContentLength(500 * 1024) + .setDatasetId("agent_analytics") + .setTableName("events") + .setClusteringFields(ImmutableList.of("event_type", "agent", "user_id")) + .setLogMultiModalContent(true) + .setRetryConfig(RetryConfig.builder().build()) + .setBatchSize(1) + .setBatchFlushInterval(Duration.ofSeconds(1)) + .setShutdownTimeout(Duration.ofSeconds(10)) + .setQueueMaxSize(10000) + .setLogSessionMetadata(true) + .setCustomTags(ImmutableMap.of()) + // TODO(b/491851868): Enable auto-schema upgrade once implemented. + .setAutoSchemaUpgrade(false); + } + + /** Builder for {@link BigQueryLoggerConfig}. 
*/ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setEnabled(boolean enabled); + + public abstract Builder setEventAllowlist(@Nullable List eventAllowlist); + + public abstract Builder setEventDenylist(@Nullable List eventDenylist); + + public abstract Builder setMaxContentLength(int maxContentLength); + + public abstract Builder setProjectId(String projectId); + + public abstract Builder setDatasetId(String datasetId); + + public abstract Builder setTableName(String tableName); + + public abstract Builder setClusteringFields(List clusteringFields); + + public abstract Builder setLogMultiModalContent(boolean logMultiModalContent); + + public abstract Builder setRetryConfig(RetryConfig retryConfig); + + public abstract Builder setBatchSize(int batchSize); + + public abstract Builder setBatchFlushInterval(Duration batchFlushInterval); + + public abstract Builder setShutdownTimeout(Duration shutdownTimeout); + + public abstract Builder setQueueMaxSize(int queueMaxSize); + + public abstract Builder setContentFormatter( + @Nullable BiFunction contentFormatter); + + public abstract Builder setConnectionId(String connectionId); + + public abstract Builder setLogSessionMetadata(boolean logSessionMetadata); + + public abstract Builder setCustomTags(Map customTags); + + public abstract Builder setAutoSchemaUpgrade(boolean autoSchemaUpgrade); + + public abstract Builder setCredentials(Credentials credentials); + + public abstract BigQueryLoggerConfig build(); + } + + /** Retry configuration for BigQuery writes. 
*/ + @AutoValue + public abstract static class RetryConfig { + public abstract int maxRetries(); + + public abstract Duration initialDelay(); + + public abstract double multiplier(); + + public abstract Duration maxDelay(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig_RetryConfig.Builder() + .setMaxRetries(3) + .setInitialDelay(Duration.ofSeconds(1)) + .setMultiplier(2.0) + .setMaxDelay(Duration.ofSeconds(10)); + } + + /** Builder for {@link RetryConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMaxRetries(int maxRetries); + + public abstract Builder setInitialDelay(Duration initialDelay); + + public abstract Builder setMultiplier(double multiplier); + + public abstract Builder setMaxDelay(Duration maxDelay); + + public abstract RetryConfig build(); + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java new file mode 100644 index 000000000..81181a1e0 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java @@ -0,0 +1,304 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardSQLTypeName; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema.Mode; +import com.google.cloud.bigquery.storage.v1.TableSchema; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; + +/** Utility for defining the BigQuery events table schema. */ +public final class BigQuerySchema { + + private BigQuerySchema() {} + + private static final ImmutableMap> + FIELD_TYPE_TO_ARROW_FIELD_METADATA = + ImmutableMap.of( + StandardSQLTypeName.JSON, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:json"), + StandardSQLTypeName.DATETIME, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:datetime"), + StandardSQLTypeName.GEOGRAPHY, + ImmutableMap.of( + "ARROW:extension:name", + "google:sqlType:geography", + "ARROW:extension:metadata", + "{\"encoding\": \"WKT\"}")); + + /** Returns the BigQuery schema for the events table. */ + // TODO(b/491848381): Rely on the same schema defined for python plugin. 
+ public static Schema getEventsSchema() { + return Schema.of( + Field.newBuilder("timestamp", StandardSQLTypeName.TIMESTAMP) + .setMode(Field.Mode.REQUIRED) + .setDescription("The UTC timestamp when the event occurred.") + .build(), + Field.newBuilder("event_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The category of the event.") + .build(), + Field.newBuilder("agent", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The name of the agent that generated this event.") + .build(), + Field.newBuilder("session_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for the entire conversation session.") + .build(), + Field.newBuilder("invocation_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for a single turn or execution.") + .build(), + Field.newBuilder("user_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The identifier of the end-user.") + .build(), + Field.newBuilder("trace_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry trace ID.") + .build(), + Field.newBuilder("span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry span ID.") + .build(), + Field.newBuilder("parent_span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry parent span ID.") + .build(), + Field.newBuilder("content", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("The primary payload of the event.") + .build(), + Field.newBuilder( + "content_parts", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("mime_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The MIME type of the content part.") + .build(), + Field.newBuilder("uri", StandardSQLTypeName.STRING) + 
.setMode(Field.Mode.NULLABLE) + .setDescription("The URI of the content part if stored externally.") + .build(), + Field.newBuilder( + "object_ref", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("uri", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("version", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("authorizer", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("details", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .build())) + .setMode(Field.Mode.NULLABLE) + .setDescription("The ObjectRef of the content part if stored externally.") + .build(), + Field.newBuilder("text", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The raw text content.") + .build(), + Field.newBuilder("part_index", StandardSQLTypeName.INT64) + .setMode(Field.Mode.NULLABLE) + .setDescription("The zero-based index of this part.") + .build(), + Field.newBuilder("part_attributes", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Additional metadata as a JSON object string.") + .build(), + Field.newBuilder("storage_mode", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates how the content part is stored.") + .build())) + .setMode(Field.Mode.REPEATED) + .setDescription("Multi-modal events content parts.") + .build(), + Field.newBuilder("attributes", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing arbitrary key-value pairs.") + .build(), + Field.newBuilder("latency_ms", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing latency measurements.") + .build(), + Field.newBuilder("status", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The outcome of the event.") + .build(), + 
Field.newBuilder("error_message", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Detailed error message if the status is 'ERROR'.") + .build(), + Field.newBuilder("is_truncated", StandardSQLTypeName.BOOL) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates if the 'content' field was truncated.") + .build()); + } + + /** Returns the Arrow schema for the events table. */ + public static org.apache.arrow.vector.types.pojo.Schema getArrowSchema() { + return new org.apache.arrow.vector.types.pojo.Schema( + getEventsSchema().getFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList())); + } + + /** Returns the serialized Arrow schema for the events table. */ + public static ByteString getSerializedArrowSchema() { + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), getArrowSchema()); + return ByteString.copyFrom(out.toByteArray()); + } catch (IOException e) { + throw new VerifyException("Failed to serialize arrow schema", e); + } + } + + private static org.apache.arrow.vector.types.pojo.Field convertToArrowField(Field field) { + ArrowType arrowType = convertTypeToArrow(field.getType().getStandardType()); + ImmutableList children = null; + if (field.getSubFields() != null) { + children = + field.getSubFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList()); + } + + ImmutableMap metadata = + FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(field.getType().getStandardType()); + + FieldType fieldType = + new FieldType(field.getMode() != Field.Mode.REQUIRED, arrowType, null, metadata); + org.apache.arrow.vector.types.pojo.Field arrowField = + new org.apache.arrow.vector.types.pojo.Field(field.getName(), fieldType, children); + + if (field.getMode() == Field.Mode.REPEATED) { + return new org.apache.arrow.vector.types.pojo.Field( + field.getName(), + new FieldType(false, new ArrowType.List(), 
null), + ImmutableList.of( + new org.apache.arrow.vector.types.pojo.Field( + "element", arrowField.getFieldType(), arrowField.getChildren()))); + } + return arrowField; + } + + private static ArrowType convertTypeToArrow(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> new ArrowType.Bool(); + case BYTES -> new ArrowType.Binary(); + case DATE -> new ArrowType.Date(DateUnit.DAY); + case DATETIME -> + // Arrow doesn't have a direct DATETIME, often mapped to Timestamp or Utf8 + new ArrowType.Utf8(); + case FLOAT64 -> new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case INT64 -> new ArrowType.Int(64, true); + case NUMERIC, BIGNUMERIC -> new ArrowType.Decimal(38, 9, 128); + case GEOGRAPHY, STRING, JSON -> new ArrowType.Utf8(); + case STRUCT -> new ArrowType.Struct(); + case TIME -> new ArrowType.Time(TimeUnit.MICROSECOND, 64); + case TIMESTAMP -> new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"); + default -> new ArrowType.Null(); + }; + } + + /** Returns names of fields to cluster by default. */ + public static ImmutableList getDefaultClusteringFields() { + return ImmutableList.of("event_type", "agent", "user_id"); + } + + /** Returns the BigQuery TableSchema for the events table (Storage Write API). */ + public static TableSchema getEventsTableSchema() { + return convertTableSchema(getEventsSchema()); + } + + private static TableSchema convertTableSchema(Schema schema) { + TableSchema.Builder result = TableSchema.newBuilder(); + for (int i = 0; i < schema.getFields().size(); i++) { + result.addFields(i, convertFieldSchema(schema.getFields().get(i))); + } + return result.build(); + } + + private static TableFieldSchema convertFieldSchema(Field field) { + TableFieldSchema.Builder result = TableFieldSchema.newBuilder(); + Field.Mode mode = field.getMode() != null ? 
field.getMode() : Field.Mode.NULLABLE; + + Mode resultMode = Mode.valueOf(mode.name()); + result.setMode(resultMode).setName(field.getName()); + + StandardSQLTypeName standardType = field.getType().getStandardType(); + TableFieldSchema.Type resultType = convertType(standardType); + result.setType(resultType); + + if (field.getDescription() != null) { + result.setDescription(field.getDescription()); + } + if (field.getSubFields() != null) { + for (int i = 0; i < field.getSubFields().size(); i++) { + result.addFields(i, convertFieldSchema(field.getSubFields().get(i))); + } + } + return result.build(); + } + + private static TableFieldSchema.Type convertType(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> TableFieldSchema.Type.BOOL; + case BYTES -> TableFieldSchema.Type.BYTES; + case DATE -> TableFieldSchema.Type.DATE; + case DATETIME -> TableFieldSchema.Type.DATETIME; + case FLOAT64 -> TableFieldSchema.Type.DOUBLE; + case GEOGRAPHY -> TableFieldSchema.Type.GEOGRAPHY; + case INT64 -> TableFieldSchema.Type.INT64; + case NUMERIC -> TableFieldSchema.Type.NUMERIC; + case STRING -> TableFieldSchema.Type.STRING; + case STRUCT -> TableFieldSchema.Type.STRUCT; + case TIME -> TableFieldSchema.Type.TIME; + case TIMESTAMP -> TableFieldSchema.Type.TIMESTAMP; + case BIGNUMERIC -> TableFieldSchema.Type.BIGNUMERIC; + case JSON -> TableFieldSchema.Type.JSON; + case INTERVAL -> TableFieldSchema.Type.INTERVAL; + default -> TableFieldSchema.Type.TYPE_UNSPECIFIED; + }; + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java new file mode 100644 index 000000000..b4b4a1049 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -0,0 +1,111 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; + +/** Utility for formatting and truncating content for BigQuery logging. */ +final class JsonFormatter { + private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); + + private JsonFormatter() {} + + /** Formats Content parts into an ArrayNode for BigQuery logging. 
*/ + public static ArrayNode formatContentParts(Optional content, int maxLength) { + ArrayNode partsArray = mapper.createArrayNode(); + if (content.isEmpty() || content.get().parts() == null) { + return partsArray; + } + + List parts = content.get().parts().orElse(ImmutableList.of()); + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ObjectNode partObj = mapper.createObjectNode(); + partObj.put("part_index", i); + partObj.put("storage_mode", "INLINE"); + + if (part.text().isPresent()) { + partObj.put("mime_type", "text/plain"); + partObj.put("text", truncateString(part.text().get(), maxLength)); + } else if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + partObj.put("mime_type", blob.mimeType().orElse("")); + partObj.put("text", "[BINARY DATA]"); + } else if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partObj.put("mime_type", fileData.mimeType().orElse("")); + partObj.put("uri", fileData.fileUri().orElse("")); + partObj.put("storage_mode", "EXTERNAL_URI"); + } + partsArray.add(partObj); + } + return partsArray; + } + + /** Recursively truncates long strings inside an object and returns a Jackson JsonNode. 
*/ + public static JsonNode smartTruncate(Object obj, int maxLength) { + if (obj == null) { + return mapper.nullNode(); + } + try { + return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); + } catch (IllegalArgumentException e) { + // Fallback for types that mapper can't handle directly as a tree + return mapper.valueToTree(String.valueOf(obj)); + } + } + + private static JsonNode recursiveSmartTruncate(JsonNode node, int maxLength) { + if (node.isTextual()) { + return mapper.valueToTree(truncateString(node.asText(), maxLength)); + } else if (node.isObject()) { + ObjectNode newNode = mapper.createObjectNode(); + node.properties() + .iterator() + .forEachRemaining( + entry -> { + newNode.set(entry.getKey(), recursiveSmartTruncate(entry.getValue(), maxLength)); + }); + return newNode; + } else if (node.isArray()) { + ArrayNode newNode = mapper.createArrayNode(); + for (JsonNode element : node) { + newNode.add(recursiveSmartTruncate(element, maxLength)); + } + return newNode; + } + return node; + } + + private static String truncateString(String s, int maxLength) { + if (s == null || s.length() <= maxLength) { + return s; + } + return s.substring(0, maxLength) + "...[truncated]"; + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java new file mode 100644 index 000000000..4f4350d1a --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java @@ -0,0 +1,367 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/*
 * Copyright 2026 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.adk.plugins.agentanalytics;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.api.core.ApiFutures;
import com.google.cloud.bigquery.storage.v1.AppendRowsResponse;
import com.google.cloud.bigquery.storage.v1.RowError;
import com.google.cloud.bigquery.storage.v1.StreamWriter;
import com.google.rpc.Status;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.TimeStampMicroTZVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/**
 * Unit tests for {@code BatchProcessor}: Arrow record-batch construction, flush triggering, and
 * error handling against a mocked Storage Write API {@link StreamWriter}.
 *
 * <p>Several tests verify vector contents from inside {@code thenAnswer} on the mocked {@code
 * append}, because the {@link ArrowRecordBatch} buffers are only valid while the call is in flight.
 * Assertion failures inside the answer cannot propagate to JUnit directly, so those tests collect
 * results into {@code checksPassed}/{@code failureMessage} arrays and assert afterwards.
 */
@RunWith(JUnit4.class)
public class BatchProcessorTest {
  @Rule public MockitoRule mockitoRule = MockitoJUnit.rule();

  @Mock private StreamWriter mockWriter;
  private ScheduledExecutorService executor;
  private BatchProcessor batchProcessor;
  // Arrow schema matching what BatchProcessor writes; used to reload captured batches.
  private Schema schema;
  // Captures java.util.logging output from BatchProcessor for log-based assertions.
  private Handler mockHandler;

  @Before
  public void setUp() {
    executor = Executors.newScheduledThreadPool(1);
    // NOTE(review): constructor args appear to be (writer, batchSize, flushInterval, maxQueue?,
    // executor) — BatchProcessor itself is not visible here; confirm against its definition.
    batchProcessor = new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor);
    schema = BigQuerySchema.getArrowSchema();

    // Default: every append succeeds immediately.
    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()));

    Logger logger = Logger.getLogger(BatchProcessor.class.getName());
    mockHandler = mock(Handler.class);
    logger.addHandler(mockHandler);
  }

  @After
  public void tearDown() {
    batchProcessor.close();
    executor.shutdown();
  }

  /** A row's Instant timestamp must land in the timestamp vector as epoch microseconds (UTC). */
  @Test
  public void flush_populatesTimestampFieldCorrectly() throws Exception {
    Instant now = Instant.parse("2026-03-02T19:11:49.631Z");
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", now);
    row.put("event_type", "TEST_EVENT");

    // Single-element arrays let the lambda below report results back to the test body.
    final boolean[] checksPassed = {false};
    final String[] failureMessage = {null};

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              // Reload the batch into a fresh root to inspect vector contents.
              // NOTE(review): uses batchProcessor.allocator directly — presumably package-visible
              // for tests; confirm against BatchProcessor.
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                if (root.getRowCount() != 1) {
                  failureMessage[0] = "Expected 1 row, got " + root.getRowCount();
                  return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
                }

                var timestampVector = root.getVector("timestamp");
                if (!(timestampVector instanceof TimeStampMicroTZVector tzVector)) {
                  failureMessage[0] = "Vector should be an instance of TimeStampMicroTZVector";
                  return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
                }
                if (tzVector.isNull(0)) {
                  failureMessage[0] = "Timestamp should NOT be null";
                } else if (tzVector.get(0) != now.toEpochMilli() * 1000) {
                  // millis * 1000 = expected microseconds
                  failureMessage[0] =
                      "Expected " + (now.toEpochMilli() * 1000) + ", got " + tzVector.get(0);
                } else {
                  checksPassed[0] = true;
                }
              } catch (RuntimeException e) {
                failureMessage[0] = "Exception during check: " + e.getMessage();
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
    assertTrue(failureMessage[0], checksPassed[0]);
  }

  /** String and boolean columns are populated from the row map. */
  @Test
  public void flush_populatesAllBasicFields() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());
    row.put("event_type", "BASIC_EVENT");
    row.put("is_truncated", true);

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                assertEquals("BASIC_EVENT", root.getVector("event_type").getObject(0).toString());
                // BitVector stores booleans as 0/1.
                assertEquals(1, ((BitVector) root.getVector("is_truncated")).get(0));
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  /** JSON columns carry their string payload through to the Arrow vectors unchanged. */
  @Test
  public void flush_populatesJsonFields() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());
    row.put("content", "{\"key\": \"value\"}");
    row.put("attributes", "{\"attr\": 123}");

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                assertEquals(
                    "{\"key\": \"value\"}", root.getVector("content").getObject(0).toString());
                assertEquals(
                    "{\"attr\": 123}", root.getVector("attributes").getObject(0).toString());
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  /** A list-of-maps row value becomes a populated List&lt;Struct&gt; vector (content_parts). */
  @Test
  public void flush_populatesNestedStructs() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());

    List<Map<String, Object>> contentParts = new ArrayList<>();
    Map<String, Object> part = new HashMap<>();
    part.put("mime_type", "text/plain");
    part.put("text", "hello world");
    part.put("part_index", 0L);
    contentParts.add(part);
    row.put("content_parts", contentParts);

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                // REPEATED STRUCT column = ListVector wrapping a StructVector.
                ListVector contentPartsVector = (ListVector) root.getVector("content_parts");
                StructVector structVector = (StructVector) contentPartsVector.getDataVector();

                assertEquals(1, ((List<?>) contentPartsVector.getObject(0)).size());
                VarCharVector mimeTypeVector = (VarCharVector) structVector.getChild("mime_type");
                assertEquals("text/plain", mimeTypeVector.getObject(0).toString());

                VarCharVector textVector = (VarCharVector) structVector.getChild("text");
                assertEquals("hello world", textVector.getObject(0).toString());

                BigIntVector partIndexVector = (BigIntVector) structVector.getChild("part_index");
                assertEquals(0L, partIndexVector.get(0));
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  /** A response containing global and per-row errors must not throw out of flush(). */
  @Test
  public void flush_handlesBigQueryErrorResponse() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("event_type", "ERROR_EVENT");

    AppendRowsResponse responseWithError =
        AppendRowsResponse.newBuilder()
            .setError(Status.newBuilder().setMessage("Global error").build())
            .addRowErrors(RowError.newBuilder().setIndex(0).setMessage("Row error").build())
            .build();

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenReturn(ApiFutures.immediateFuture(responseWithError));

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  /** A synchronous RuntimeException from the writer must be swallowed by flush(). */
  @Test
  public void flush_handlesGenericExceptionDuringAppend() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("event_type", "EXCEPTION_EVENT");

    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenThrow(new RuntimeException("Simulated failure"));

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
  }

  /** Reaching batchSize (2 here) schedules a flush on the executor; fewer rows do not. */
  @Test
  public void append_triggersFlushWhenBatchSizeReached() {
    ScheduledExecutorService mockExecutor = mock(ScheduledExecutorService.class);
    BatchProcessor bp = new BatchProcessor(mockWriter, 2, Duration.ofMinutes(1), 10, mockExecutor);

    Map<String, Object> row = new HashMap<>();
    bp.append(row);
    verify(mockExecutor, never()).execute(any(Runnable.class));

    bp.append(row);
    verify(mockExecutor).execute(any(Runnable.class));
  }

  /** Flushing an empty queue must not touch the writer at all. */
  @Test
  public void flush_doesNothingWhenQueueIsEmpty() throws Exception {
    batchProcessor.flush();
    verify(mockWriter, never()).append(any(ArrowRecordBatch.class));
  }

  /** Explicit null row values become null entries in the corresponding vectors. */
  @Test
  public void flush_handlesNullValues() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("timestamp", Instant.now());
    row.put("event_type", null);
    row.put("is_truncated", null);

    final boolean[] checksPassed = {false};
    when(mockWriter.append(any(ArrowRecordBatch.class)))
        .thenAnswer(
            invocation -> {
              ArrowRecordBatch recordedBatch = invocation.getArgument(0);
              try (VectorSchemaRoot root =
                  VectorSchemaRoot.create(schema, batchProcessor.allocator)) {
                VectorLoader loader = new VectorLoader(root);
                loader.load(recordedBatch);

                assertTrue(root.getVector("event_type").isNull(0));
                assertTrue(root.getVector("is_truncated").isNull(0));
                checksPassed[0] = true;
              }
              return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance());
            });

    batchProcessor.append(row);
    batchProcessor.flush();

    verify(mockWriter).append(any(ArrowRecordBatch.class));
    assertTrue("Null checks failed", checksPassed[0]);
  }

  /**
   * With the allocator limit forced to 1 byte, batch construction fails before any append; the
   * failure is logged at SEVERE rather than thrown.
   */
  @Test
  public void flush_handlesAllocationFailure() throws Exception {
    Map<String, Object> row = new HashMap<>();
    row.put("event_type", "ALLOC_FAIL_EVENT");
    batchProcessor.append(row);
    batchProcessor.allocator.setLimit(1);

    batchProcessor.flush();

    verify(mockWriter, never()).append(any(ArrowRecordBatch.class));
    ArgumentCaptor<LogRecord> captor = ArgumentCaptor.forClass(LogRecord.class);
    verify(mockHandler, atLeastOnce()).publish(captor.capture());
    boolean foundError = false;
    for (LogRecord record : captor.getAllValues()) {
      if (record.getLevel().equals(Level.SEVERE)
          && record.getMessage().contains("Failed to write batch to BigQuery")) {
        foundError = true;
        break;
      }
    }
    assertTrue("Expected SEVERE error log not found", foundError);
  }

  /** close() (via try-with-resources) flushes pending rows and closes the underlying writer. */
  @Test
  public void close_flushesAndClosesResources() throws Exception {
    try (BatchProcessor bp =
        new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor)) {
      Map<String, Object> row = new HashMap<>();
      row.put("event_type", "CLOSE_EVENT");
      bp.append(row);
    }

    verify(mockWriter).append(any(ArrowRecordBatch.class));
    verify(mockWriter).close();
  }
}
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.api.core.ApiFutures; +import com.google.auth.Credentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; 
+import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class BigQueryAgentAnalyticsPluginTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock private BigQuery mockBigQuery; + @Mock private StreamWriter mockWriter; + @Mock private BigQueryWriteClient mockWriteClient; + @Mock private InvocationContext mockInvocationContext; + private BaseAgent fakeAgent; + + private BigQueryLoggerConfig config; + private BigQueryAgentAnalyticsPlugin plugin; + private Handler mockHandler; + + @Before + public void setUp() throws Exception { + fakeAgent = new FakeAgent("agent_name"); + config = + BigQueryLoggerConfig.builder() + .setEnabled(true) + .setProjectId("project") + .setDatasetId("dataset") + .setTableName("table") + .setBatchSize(10) + .setBatchFlushInterval(Duration.ofSeconds(10)) + .setAutoSchemaUpgrade(false) + .setCredentials(mock(Credentials.class)) + .setCustomTags(ImmutableMap.of("global_tag", "global_value")) + .build(); + + when(mockBigQuery.getOptions()) + .thenReturn(BigQueryOptions.newBuilder().setProjectId("test-project").build()); + when(mockBigQuery.getTable(any(TableId.class))).thenReturn(mock(Table.class)); + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + plugin = + new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + 
return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + }; + + Session session = Session.builder("session_id").build(); + when(mockInvocationContext.session()).thenReturn(session); + when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); + when(mockInvocationContext.agent()).thenReturn(fakeAgent); + when(mockInvocationContext.userId()).thenReturn("user_id"); + + Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + } + + @Test + public void onUserMessageCallback_appendsToWriter() throws Exception { + Content content = Content.builder().build(); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void beforeRunCallback_appendsToWriter() throws Exception { + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void afterRunCallback_flushesAndAppends() throws Exception { + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void getStreamName_returnsCorrectFormat() { + BigQueryLoggerConfig config = + BigQueryLoggerConfig.builder() + .setProjectId("test-project") + .setDatasetId("test-dataset") + .setTableName("test-table") + .build(); + + String streamName = plugin.getStreamName(config); + + assertEquals( + "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default", + streamName); + } + + @Test + public void formatContentParts_populatesCorrectFields() { + Content content = Content.fromParts(Part.fromText("hello")); 
+ ArrayNode nodes = JsonFormatter.formatContentParts(Optional.of(content), 100); + assertEquals(1, nodes.size()); + ObjectNode node = (ObjectNode) nodes.get(0); + assertEquals(0, node.get("part_index").asInt()); + assertEquals("INLINE", node.get("storage_mode").asText()); + assertEquals("hello", node.get("text").asText()); + assertEquals("text/plain", node.get("mime_type").asText()); + } + + @Test + public void arrowSchema_hasJsonMetadata() { + Schema schema = BigQuerySchema.getArrowSchema(); + Field contentField = schema.findField("content"); + assertNotNull(contentField); + assertEquals("google:sqlType:json", contentField.getMetadata().get("ARROW:extension:name")); + } + + @Test + public void onUserMessageCallback_handlesTableCreationFailure() throws Exception { + Logger logger = Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + Handler mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + try { + when(mockBigQuery.getTable(any(TableId.class))) + .thenThrow(new RuntimeException("Table check failed")); + Content content = Content.builder().build(); + + // Should not throw exception + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + assertTrue( + captor + .getValue() + .getMessage() + .contains("Failed to check or create/upgrade BigQuery table")); + assertEquals(Level.WARNING, captor.getValue().getLevel()); + } finally { + logger.removeHandler(mockHandler); + } + } + + @Test + public void onUserMessageCallback_handlesAppendFailure() throws Exception { + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFailedFuture(new RuntimeException("Append failed"))); + Content content = Content.builder().build(); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + // 
Flush should handle the failed future from writer.append() + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + assertTrue(captor.getValue().getMessage().contains("Failed to write batch to BigQuery")); + assertEquals(Level.SEVERE, captor.getValue().getLevel()); + } + + @Test + public void ensureTableExists_calledOnlyOnce() throws Exception { + Content content = Content.builder().build(); + + // Multiple calls to logEvent via different callbacks + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + + // Verify getting table was only done once. Using fully qualified name to avoid ambiguity. + verify(mockBigQuery).getTable(any(TableId.class)); + } + + @Test + public void arrowSchema_handlesNestedFields() { + Schema schema = BigQuerySchema.getArrowSchema(); + Field contentPartsField = schema.findField("content_parts"); + assertNotNull(contentPartsField); + // Repeated struct becomes a List of Structs + assertTrue(contentPartsField.getType() instanceof ArrowType.List); + + Field element = contentPartsField.getChildren().get(0); + assertEquals("element", element.getName()); + + // Check object_ref which is a nested STRUCT + Field objectRef = + element.getChildren().stream() + .filter(f -> f.getName().equals("object_ref")) + .findFirst() + .orElse(null); + assertNotNull(objectRef); + assertTrue(objectRef.getType() instanceof ArrowType.Struct); + assertFalse(objectRef.getChildren().isEmpty()); + } + + @Test + public void arrowSchema_handlesFieldNullability() { + Schema schema = BigQuerySchema.getArrowSchema(); + + // timestamp is REQUIRED in BigQuerySchema.getEventsSchema() + Field timestampField = 
schema.findField("timestamp"); + assertNotNull(timestampField); + assertFalse(timestampField.isNullable()); + + // event_type is NULLABLE in BigQuerySchema.getEventsSchema() + Field eventTypeField = schema.findField("event_type"); + assertNotNull(eventTypeField); + assertTrue(eventTypeField.isNullable()); + } + + @Test + public void logEvent_populatesCommonFields() throws Exception { + final boolean[] checksPassed = {false}; + final String[] failureMessage = {null}; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + Schema schema = BigQuerySchema.getArrowSchema(); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, plugin.batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + if (root.getRowCount() != 1) { + failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); + } else if (!Objects.equals( + root.getVector("event_type").getObject(0).toString(), "USER_MESSAGE")) { + failureMessage[0] = + "Wrong event_type: " + root.getVector("event_type").getObject(0); + } else if (!root.getVector("agent").getObject(0).toString().equals("agent_name")) { + failureMessage[0] = "Wrong agent: " + root.getVector("agent").getObject(0); + } else if (!root.getVector("session_id") + .getObject(0) + .toString() + .equals("session_id")) { + failureMessage[0] = + "Wrong session_id: " + root.getVector("session_id").getObject(0); + } else if (!root.getVector("invocation_id") + .getObject(0) + .toString() + .equals("invocation_id")) { + failureMessage[0] = + "Wrong invocation_id: " + root.getVector("invocation_id").getObject(0); + } else if (!root.getVector("user_id").getObject(0).toString().equals("user_id")) { + failureMessage[0] = "Wrong user_id: " + root.getVector("user_id").getObject(0); + } else if (((TimeStampMicroTZVector) root.getVector("timestamp")).get(0) <= 0) { + failureMessage[0] = "Timestamp not 
populated"; + } else { + // Check content and content_parts + String contentJson = root.getVector("content").getObject(0).toString(); + if (!contentJson.contains("test message")) { + failureMessage[0] = "Wrong content: " + contentJson; + } else { + ListVector contentPartsVector = (ListVector) root.getVector("content_parts"); + if (((List) contentPartsVector.getObject(0)).isEmpty()) { + failureMessage[0] = "content_parts is empty"; + } else { + // Check attributes + String attributesJson = root.getVector("attributes").getObject(0).toString(); + if (!attributesJson.contains("global_tag") + || !attributesJson.contains("global_value")) { + failureMessage[0] = "Wrong attributes: " + attributesJson; + } else { + checksPassed[0] = true; + } + } + } + } + } catch (RuntimeException e) { + failureMessage[0] = "Exception during inspection: " + e.getMessage(); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + Content content = Content.fromParts(Part.fromText("test message")); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.batchProcessor.flush(); + + assertTrue(failureMessage[0], checksPassed[0]); + } + + @Test + public void logEvent_populatesTraceDetails() throws Exception { + String traceId = "4bf92f3577b34da6a3ce929d0e0e4736"; + String spanId = "00f067aa0ba902b7"; + + SpanContext mockSpanContext = mock(SpanContext.class); + when(mockSpanContext.isValid()).thenReturn(true); + when(mockSpanContext.getTraceId()).thenReturn(traceId); + when(mockSpanContext.getSpanId()).thenReturn(spanId); + + Span mockSpan = Span.wrap(mockSpanContext); + + try (Scope scope = mockSpan.makeCurrent()) { + Content content = Content.builder().build(); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals(traceId, row.get("trace_id")); + assertEquals(spanId, 
row.get("span_id")); + } + } + + @Test + public void complexType_appendsToWriter() throws Exception { + Part part = Part.fromText("test text"); + Content content = Content.fromParts(part); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void onEventCallback_populatesCorrectFields() throws Exception { + Event event = + Event.builder() + .author("agent_author") + .content(Content.fromParts(Part.fromText("event content"))) + .build(); + + plugin.onEventCallback(mockInvocationContext, event).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("EVENT", row.get("event_type")); + assertEquals("agent_name", row.get("agent")); + assertTrue(row.get("attributes").toString().contains("agent_author")); + assertTrue(row.get("content").toString().contains("event content")); + } + + @Test + public void onModelErrorCallback_populatesCorrectFields() throws Exception { + CallbackContext mockCallbackContext = mock(CallbackContext.class); + when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); + when(mockCallbackContext.agentName()).thenReturn("agent_in_context"); + LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); + Throwable error = new RuntimeException("model error message"); + + plugin + .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) + .blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("MODEL_ERROR", row.get("event_type")); + assertEquals("agent_in_context", row.get("agent")); + assertTrue(row.get("attributes").toString().contains("model error message")); + } + + private static class FakeAgent extends BaseAgent { + FakeAgent(String name) { + super(name, "description", null, 
null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } +} diff --git a/pom.xml b/pom.xml index bd0caca0d..62082cfc9 100644 --- a/pom.xml +++ b/pom.xml @@ -73,6 +73,8 @@ 2.15.0 3.9.0 5.6 + 4.1.118.Final + @{jacoco.agent.argLine} --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.text=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED -Dio.netty.tryReflectionSetAccessible=true @@ -85,6 +87,13 @@ pom import + + io.netty + netty-bom + ${netty.version} + pom + import + com.google.cloud libraries-bom @@ -338,6 +347,8 @@ + + ${surefire.argLine} plain From 551c31f495aafde8568461cc0aa0973d7df7e5ac Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 18 Mar 2026 03:39:02 -0700 Subject: [PATCH 13/29] fix: include saveArtifact invocations in event chain PiperOrigin-RevId: 885495376 --- .../java/com/google/adk/runner/Runner.java | 12 ++-- .../com/google/adk/runner/RunnerTest.java | 55 +++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 51e1b8f25..1f7d924ab 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -313,6 +313,7 @@ private Single appendNewMessageToSession( throw new IllegalArgumentException("No parts in the new_message."); } + Completable saveArtifactsFlow = Completable.complete(); if (this.artifactService != null && saveInputBlobsAsArtifacts) { // The runner directly saves the 
artifacts (if applicable) in the user message and replaces // the artifact data with a file name placeholder. @@ -322,9 +323,11 @@ private Single appendNewMessageToSession( continue; } String fileName = "artifact_" + invocationContext.invocationId() + "_" + i; - var unused = - this.artifactService.saveArtifact( - this.appName, session.userId(), session.id(), fileName, part); + saveArtifactsFlow = + saveArtifactsFlow.andThen( + this.artifactService + .saveArtifact(this.appName, session.userId(), session.id(), fileName, part) + .ignoreElement()); newMessage .parts() @@ -349,7 +352,8 @@ private Single appendNewMessageToSession( EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); } - return this.sessionService.appendEvent(session, eventBuilder.build()); + return saveArtifactsFlow.andThen( + this.sessionService.appendEvent(session, eventBuilder.build())); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 2eb515fa2..a3e21cb73 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -24,6 +24,8 @@ import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.adk.testing.TestUtils.simplifyEvents; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.stream; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -36,6 +38,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import 
com.google.adk.models.LlmResponse; @@ -65,12 +68,14 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -78,6 +83,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; @RunWith(JUnit4.class) public final class RunnerTest { @@ -849,6 +855,19 @@ private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } + private static Content createInlineDataContent(byte[]... data) { + return Content.builder() + .parts( + stream(data) + .map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream")) + .toArray(Part[]::new)) + .build(); + } + + private static Content createInlineDataContent(String... 
data) { + return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new)); + } + @Test public void runAsync_createsInvocationSpan() { var unused = @@ -1331,4 +1350,40 @@ public static ImmutableMap echoTool(String message) { return ImmutableMap.of("message", message); } } + + @Test + public void runner_executesSaveArtifactFlow() { + // arrange + final AtomicInteger artifactsSavedCounter = new AtomicInteger(); + BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class); + when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any())) + .thenReturn( + Single.defer( + () -> { + // we want to assert not only that the saveArtifact method was + // called, but also that the flow that it returned was run, so + // we need to record the call in a counter + artifactsSavedCounter.incrementAndGet(); + return Single.just(42); + })); + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .artifactService(mockArtifactService) + .build(); + session = runner.sessionService().createSession("test", "user").blockingGet(); + // each inline data will be saved using our mock artifact service + Content content = createInlineDataContent("test data", "test data 2"); + RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build(); + + // act + var events = runner.runAsync("user", session.id(), content, runConfig).test(); + + // assert + events.assertComplete(); + // artifacts were saved + assertThat(artifactsSavedCounter.get()).isEqualTo(2); + // agent was run + assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); + } } From e51f9112050955657da0dfc3aedc00f90ad739ec Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 05:37:53 -0700 Subject: [PATCH 14/29] feat: add handling the a2a metadata in the RemoteA2AAgent; Add the enum type for the metadata keys PiperOrigin-RevId: 885539894 --- 
.../adk/a2a/converters/A2AMetadataKey.java | 40 ++++ .../adk/a2a/converters/AdkMetadataKey.java | 35 ++++ .../adk/a2a/converters/PartConverter.java | 19 +- .../adk/a2a/converters/ResponseConverter.java | 131 +++++++++++-- .../adk/a2a/converters/PartConverterTest.java | 50 ++++- .../a2a/converters/ResponseConverterTest.java | 175 +++++++++++++++++- 6 files changed, 408 insertions(+), 42 deletions(-) create mode 100644 a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java create mode 100644 a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java new file mode 100644 index 000000000..d4f1fef58 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of A2A metadata. Adds a prefix used to differentiate ADK-related values stored + in metadata of an A2A event. 
+ */ +public enum A2AMetadataKey { + TYPE("type"), + IS_LONG_RUNNING("is_long_running"), + PARTIAL("partial"), + GROUNDING_METADATA("grounding_metadata"), + USAGE_METADATA("usage_metadata"), + CUSTOM_METADATA("custom_metadata"), + ERROR_CODE("error_code"); + + private final String type; + + private A2AMetadataKey(String type) { + this.type = "adk_" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java new file mode 100644 index 000000000..e38f28828 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of ADK metadata. Adds a prefix used to differentiate A2A-related values stored + * in custom metadata of an ADK session event. 
+ */ +public enum AdkMetadataKey { + TASK_ID("task_id"), + CONTEXT_ID("context_id"); + + private final String type; + + private AdkMetadataKey(String type) { + this.type = "a2a:" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 61f24fa21..714a79736 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -52,11 +52,7 @@ public final class PartConverter { private static final Logger logger = LoggerFactory.getLogger(PartConverter.class); private static final ObjectMapper objectMapper = new ObjectMapper(); - // Constants for metadata types. By convention metadata keys are prefixed with "adk_" to align - // with the Python and Golang libraries. - public static final String A2A_DATA_PART_METADATA_TYPE_KEY = "adk_type"; - public static final String A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = "adk_is_long_running"; - public static final String A2A_DATA_PART_METADATA_IS_PARTIAL_KEY = "adk_partial"; + // Constants for metadata types. 
public static final String LANGUAGE_KEY = "language"; public static final String OUTCOME_KEY = "outcome"; public static final String CODE_KEY = "code"; @@ -135,7 +131,7 @@ private static com.google.genai.types.Part convertDataPartToGenAiPart(DataPart d Map metadata = Optional.ofNullable(dataPart.getMetadata()).map(HashMap::new).orElseGet(HashMap::new); - String metadataType = metadata.getOrDefault(A2A_DATA_PART_METADATA_TYPE_KEY, "").toString(); + String metadataType = metadata.getOrDefault(A2AMetadataKey.TYPE.getType(), "").toString(); if ((data.containsKey(NAME_KEY) && data.containsKey(ARGS_KEY)) || metadataType.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) { @@ -218,7 +214,7 @@ private static DataPart createDataPartFromFunctionCall( addValueIfPresent(data, WILL_CONTINUE_KEY, functionCall.willContinue()); addValueIfPresent(data, PARTIAL_ARGS_KEY, functionCall.partialArgs()); - metadata.put(A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_CALL.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -245,7 +241,7 @@ private static DataPart createDataPartFromFunctionResponse( addValueIfPresent(data, PARTS_KEY, functionResponse.parts()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -268,7 +264,7 @@ private static DataPart createDataPartFromCodeExecutionResult( addValueIfPresent(data, OUTPUT_KEY, codeExecutionResult.output()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -290,8 +286,7 
@@ private static DataPart createDataPartFromExecutableCode( .orElse(Language.Known.LANGUAGE_UNSPECIFIED.toString())); addValueIfPresent(data, CODE_KEY, executableCode.code()); - metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -305,7 +300,7 @@ public static io.a2a.spec.Part fromGenaiPart(Part part, boolean isPartial) { } ImmutableMap.Builder metadata = ImmutableMap.builder(); if (isPartial) { - metadata.put(A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + metadata.put(A2AMetadataKey.PARTIAL.getType(), true); } if (part.text().isPresent()) { diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 503432a30..cffd76983 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -19,12 +19,20 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.zip; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import com.google.genai.types.Part; import 
io.a2a.client.ClientEvent; import io.a2a.client.MessageEvent; @@ -43,11 +51,13 @@ import java.util.Objects; import java.util.Optional; import java.util.UUID; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** Utility for converting ADK events to A2A spec messages (and back). */ public final class ResponseConverter { + private static final ObjectMapper objectMapper = new ObjectMapper(); private static final Logger logger = LoggerFactory.getLogger(ResponseConverter.class); private static final ImmutableSet PENDING_STATES = ImmutableSet.of(TaskState.WORKING, TaskState.SUBMITTED); @@ -74,12 +84,11 @@ public static Optional clientEventToEvent( throw new IllegalArgumentException("Unsupported ClientEvent type: " + event.getClass()); } - private static boolean isPartial(Map metadata) { + private static boolean isPartial(@Nullable Map metadata) { if (metadata == null) { return false; } - return Objects.equals( - metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, false), true); + return Objects.equals(metadata.getOrDefault(A2AMetadataKey.PARTIAL.getType(), false), true); } /** @@ -110,7 +119,12 @@ private static Optional handleTaskUpdate( // append=false, lastChunk=false: emit as partial, reset aggregation // append=true, lastChunk=true: emit as partial, update aggregation and emit as non-partial // append=false, lastChunk=true: emit as non-partial, drop aggregation - return Optional.of(eventPart); + return Optional.of( + updateEventMetadata( + eventPart, + artifactEvent.getMetadata(), + artifactEvent.getTaskId(), + artifactEvent.getContextId())); } if (updateEvent instanceof TaskStatusUpdateEvent statusEvent) { @@ -128,14 +142,21 @@ private static Optional handleTaskUpdate( }); if (statusEvent.isFinal()) { - return messageEvent - .map(Event::toBuilder) - .or(() -> Optional.of(remoteAgentEventBuilder(context))) - .map(builder -> builder.turnComplete(true)) - .map(builder -> builder.partial(false)) - 
.map(Event.Builder::build); + messageEvent = + messageEvent + .map(Event::toBuilder) + .or(() -> Optional.of(remoteAgentEventBuilder(context))) + .map(builder -> builder.turnComplete(true)) + .map(builder -> builder.partial(false)) + .map(Event.Builder::build); } - return messageEvent; + return messageEvent.map( + finalMessageEvent -> + updateEventMetadata( + finalMessageEvent, + statusEvent.getMetadata(), + statusEvent.getTaskId(), + statusEvent.getContextId())); } throw new IllegalArgumentException( "Unsupported TaskUpdateEvent type: " + updateEvent.getClass()); @@ -163,9 +184,13 @@ public static Event messageToFailedEvent(Message message, InvocationContext invo /** Converts an A2A message back to ADK events. */ public static Event messageToEvent(Message message, InvocationContext invocationContext) { - return remoteAgentEventBuilder(invocationContext) - .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) - .build(); + return updateEventMetadata( + remoteAgentEventBuilder(invocationContext) + .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) + .build(), + message.getMetadata(), + message.getTaskId(), + message.getContextId()); } /** @@ -228,7 +253,8 @@ public static Event taskToEvent(Task task, InvocationContext invocationContext) eventBuilder.longRunningToolIds(longRunningToolIds.build()); } eventBuilder.turnComplete(isFinal); - return eventBuilder.build(); + return updateEventMetadata( + eventBuilder.build(), task.getMetadata(), task.getId(), task.getContextId()); } private static ImmutableSet getLongRunningToolIds( @@ -241,9 +267,7 @@ private static ImmutableSet getLongRunningToolIds( return Optional.empty(); } Object isLongRunning = - dataPart - .getMetadata() - .get(PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY); + dataPart.getMetadata().get(A2AMetadataKey.IS_LONG_RUNNING.getType()); if (!Objects.equals(isLongRunning, true)) { return Optional.empty(); } @@ -256,6 +280,77 @@ private static ImmutableSet 
getLongRunningToolIds( .collect(toImmutableSet()); } + private static Event updateEventMetadata( + Event event, + @Nullable Map clientMetadata, + @Nullable String taskId, + @Nullable String contextId) { + if (taskId == null || contextId == null) { + logger.warn("Task ID or context ID is null, skipping metadata update."); + return event; + } + + if (clientMetadata == null) { + clientMetadata = ImmutableMap.of(); + } + Event.Builder eventBuilder = event.toBuilder(); + Object groundingMetadata = clientMetadata.get(A2AMetadataKey.GROUNDING_METADATA.getType()); + // if groundingMetadata is null, parseMetadata will return null as well. + eventBuilder.groundingMetadata(parseMetadata(groundingMetadata, GroundingMetadata.class)); + Object usageMetadata = clientMetadata.get(A2AMetadataKey.USAGE_METADATA.getType()); + // if usageMetadata is null, parseMetadata will return null as well. + eventBuilder.usageMetadata( + parseMetadata(usageMetadata, GenerateContentResponseUsageMetadata.class)); + + ImmutableList.Builder customMetadataList = ImmutableList.builder(); + customMetadataList + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.TASK_ID.getType()) + .stringValue(taskId) + .build()) + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.CONTEXT_ID.getType()) + .stringValue(contextId) + .build()); + Object customMetadata = clientMetadata.get(A2AMetadataKey.CUSTOM_METADATA.getType()); + if (customMetadata != null) { + customMetadataList.addAll( + parseMetadata(customMetadata, new TypeReference>() {})); + } + eventBuilder.customMetadata(customMetadataList.build()); + + Object errorCode = clientMetadata.get(A2AMetadataKey.ERROR_CODE.getType()); + eventBuilder.errorCode(parseMetadata(errorCode, FinishReason.class)); + + return eventBuilder.build(); + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, Class type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return 
objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type, e); + } + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, TypeReference type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type.getType(), e); + } + } + private static Event emptyEvent(InvocationContext invocationContext) { Event.Builder builder = Event.builder() diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java index d93466dd2..4a0828c43 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java @@ -8,9 +8,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.CodeExecutionResult; +import com.google.genai.types.ExecutableCode; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Language; +import com.google.genai.types.Outcome; import com.google.genai.types.Part; import io.a2a.spec.DataPart; import io.a2a.spec.FilePart; @@ -86,8 +90,7 @@ public void toGenaiPart_withDataPartFunctionCall_returnsGenaiFunctionCallPart() new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType())); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType())); Part result = 
PartConverter.toGenaiPart(dataPart); @@ -121,7 +124,7 @@ public void toGenaiPart_withDataPartFunctionResponse_returnsGenaiFunctionRespons new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType())); Part result = PartConverter.toGenaiPart(dataPart); @@ -188,7 +191,7 @@ public void fromGenaiPart_withTextPart_returnsTextPart() { assertThat(((TextPart) result).getText()).isEqualTo("text"); assertThat(((TextPart) result).getMetadata()).containsEntry("thought", true); assertThat(((TextPart) result).getMetadata()) - .containsEntry(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + .containsEntry(A2AMetadataKey.PARTIAL.getType(), true); } @Test @@ -226,6 +229,39 @@ public void fromGenaiPart_withInlineDataPart_returnsFilePartWithBytes() { assertThat(Base64.getDecoder().decode(fileWithBytes.bytes())).isEqualTo(bytes); } + @Test + public void fromGenaiPart_dataPart_executableCode_returnsDataPart() { + ExecutableCode executableCode = + ExecutableCode.builder().code("print('hello')").language(new Language("python")).build(); + Part part = Part.builder().executableCode(executableCode).build(); + io.a2a.spec.Part result = PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("code")).isEqualTo("print('hello')"); + assertThat(dataPart.getData().get("language")).isEqualTo("python"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("executable_code"); + } + + @Test + public void fromGenaiPart_dataPart_codeExecutionResult_returnsDataPart() { + CodeExecutionResult codeExecutionResult = + CodeExecutionResult.builder() + .outcome(new Outcome("OUTCOME_OK")) + .output("print('hello')") + .build(); + Part part = Part.builder().codeExecutionResult(codeExecutionResult).build(); + io.a2a.spec.Part result = 
PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("outcome")).isEqualTo("OUTCOME_OK"); + assertThat(dataPart.getData().get("output")).isEqualTo("print('hello')"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("code_execution_result"); + } + @Test public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { Part part = @@ -255,8 +291,7 @@ public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { true); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); } @Test @@ -275,8 +310,7 @@ public void fromGenaiPart_withFunctionResponsePart_returnsDataPart() { .containsExactly("name", "func", "id", "1", "response", ImmutableMap.of()); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); } @Test diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java index d84dc42cd..b61b00e1a 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java @@ -1,6 +1,8 @@ package com.google.adk.a2a.converters; import static com.google.common.truth.Truth.assertThat; +import static java.util.stream.Collectors.joining; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; @@ -13,6 +15,10 @@ import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import io.a2a.client.MessageEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.Artifact; @@ -136,6 +142,74 @@ public void taskToEvent_withStatusMessage_returnsEvent() { assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); } + @Test + public void taskToEvent_withGroundingMetadata_returnsEvent() { + GroundingMetadata groundingMetadata = + GroundingMetadata.builder().webSearchQueries("test-query").build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of( + A2AMetadataKey.GROUNDING_METADATA.getType(), groundingMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.groundingMetadata()).hasValue(groundingMetadata); + } + + @Test + public void taskToEvent_withCustomMetadata_returnsEvent() { + ImmutableList customMetadataList = + ImmutableList.of( + CustomMetadata.builder().key("test-key").stringValue("test-value").build()); + String customMetadataJson = + customMetadataList.stream().map(CustomMetadata::toJson).collect(joining(",", "[", "]")); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + 
testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.CUSTOM_METADATA.getType(), customMetadataJson)) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.customMetadata().get()) + .containsExactly( + CustomMetadata.builder().key("a2a:task_id").stringValue("task-1").build(), + CustomMetadata.builder().key("a2a:context_id").stringValue("context-1").build(), + CustomMetadata.builder().key("test-key").stringValue("test-value").build()) + .inOrder(); + } + + @Test + public void messageToEvent_withMissingTaskId_returnsEvent() { + Message a2aMessage = + new Message.Builder() + .messageId("msg-1") + .role(Message.Role.USER) + .taskId("task-1") + .parts(ImmutableList.of(new TextPart("test-message"))) + .build(); + Event event = ResponseConverter.messageToEvent(a2aMessage, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.customMetadata()).isEmpty(); + } + @Test public void taskToEvent_withNoMessage_returnsEmptyEvent() { TaskStatus status = new TaskStatus(TaskState.WORKING, null, null); @@ -152,18 +226,18 @@ public void taskToEvent_withInputRequired_parsesLongRunningToolIds() { ImmutableMap.of("name", "myTool", "id", "call_123", "args", ImmutableMap.of()); ImmutableMap metadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart dataPart = new DataPart(data, metadata); ImmutableMap statusData = ImmutableMap.of("name", "messageTools", "id", "msg_123", "args", ImmutableMap.of()); ImmutableMap statusMetadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - 
PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart statusDataPart = new DataPart(statusData, statusMetadata); Message statusMessage = @@ -361,6 +435,99 @@ public void clientEventToEvent_withFailedTaskStatusUpdateEvent_returnsErrorEvent assertThat(resultEvent.turnComplete()).hasValue(true); } + @Test + public void taskToEvent_withInvalidMetadata_throwsException() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.GROUNDING_METADATA.getType(), "{ invalid json ]")) + .build(); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ResponseConverter.taskToEvent(task, invocationContext)); + assertThat(exception).hasMessageThat().contains("Failed to parse metadata"); + assertThat(exception).hasMessageThat().contains("GroundingMetadata"); + } + + @Test + public void taskToEvent_withErrorCode_returnsEvent() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.ERROR_CODE.getType(), "\"STOP\"")) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.errorCode()).hasValue(new FinishReason(FinishReason.Known.STOP)); + } + + @Test + public void taskToEvent_withUsageMetadata_returnsEvent() { + GenerateContentResponseUsageMetadata usageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + 
.candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.USAGE_METADATA.getType(), usageMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.usageMetadata()).hasValue(usageMetadata); + } + + @Test + public void clientEventToEvent_withTaskArtifactUpdateEventAndPartialTrue_returnsEmpty() { + io.a2a.spec.Part a2aPart = new TextPart("Artifact content"); + Artifact artifact = + new Artifact.Builder().artifactId("artifact-1").parts(ImmutableList.of(a2aPart)).build(); + Task task = + testTask() + .status(new TaskStatus(TaskState.COMPLETED)) + .artifacts(ImmutableList.of(artifact)) + .build(); + TaskArtifactUpdateEvent updateEvent = + new TaskArtifactUpdateEvent.Builder() + .lastChunk(true) + .metadata(ImmutableMap.of(A2AMetadataKey.PARTIAL.getType(), true)) + .contextId("context-1") + .artifact(artifact) + .taskId("task-id-1") + .build(); + TaskUpdateEvent event = new TaskUpdateEvent(task, updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isEmpty(); + } + private static final class TestAgent extends BaseAgent { TestAgent() { super("test_agent", "test", ImmutableList.of(), null, null); From 0d1e5c7b0c42cea66b178cf8fedf08a8c20f7fd0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 07:02:14 -0700 Subject: [PATCH 15/29] feat: update stateDelta builder input to Map from ConcurrentMap PiperOrigin-RevId: 885570460 --- .../main/java/com/google/adk/events/EventActions.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) 
diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 1ca856b45..3565c3e99 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -287,8 +287,14 @@ public Builder skipSummarization(boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") - public Builder stateDelta(ConcurrentMap value) { - this.stateDelta = value; + public Builder stateDelta(@Nullable Map value) { + if (value == null) { + this.stateDelta = new ConcurrentHashMap<>(); + } else if (value instanceof ConcurrentMap) { + this.stateDelta = (ConcurrentMap) value; + } else { + this.stateDelta = new ConcurrentHashMap<>(value); + } return this; } From de3b2767748436b07f55e7d00034d77d7d940579 Mon Sep 17 00:00:00 2001 From: Greg Brail Date: Mon, 12 Jan 2026 15:26:20 -0800 Subject: [PATCH 16/29] Remove ADK dependency for langchain4j module --- contrib/langchain4j/pom.xml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index c2326fa0a..3dd2d1132 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -58,11 +58,6 @@ google-adk ${project.version} - - com.google.adk - google-adk-dev - ${project.version} - com.google.genai google-genai From 3ba04d33dc8f2ef8b151abe1be4d1c8b7afcc25a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 07:59:31 -0700 Subject: [PATCH 17/29] fix: workaround for the client config streaming settings are not respected (#983) PiperOrigin-RevId: 885595843 --- .../com/google/adk/a2a/agent/RemoteA2AAgent.java | 13 ++++++++++++- .../google/adk/a2a/agent/RemoteA2AAgentTest.java | 16 +++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index 
ccb662b7c..4a375980c 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -117,7 +117,7 @@ private RemoteA2AAgent(Builder builder) { if (this.description.isEmpty() && this.agentCard.description() != null) { this.description = this.agentCard.description(); } - this.streaming = this.agentCard.capabilities().streaming(); + this.streaming = builder.streaming && this.agentCard.capabilities().streaming(); } public static Builder builder() { @@ -133,6 +133,13 @@ public static class Builder { private List subAgents; private List beforeAgentCallback; private List afterAgentCallback; + private boolean streaming; + + @CanIgnoreReturnValue + public Builder streaming(boolean streaming) { + this.streaming = streaming; + return this; + } @CanIgnoreReturnValue public Builder name(String name) { @@ -181,6 +188,10 @@ public RemoteA2AAgent build() { } } + public boolean isStreaming() { + return streaming; + } + private Message.Builder newA2AMessage(Message.Role role, List> parts) { return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts); } diff --git a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index b1ffa248a..0609c3b04 100644 --- a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -113,6 +113,20 @@ public void setUp() { .build(); } + @Test + public void createAgent_streaming_false_returnsNonStreamingAgent() { + // With streaming false, the agent should not stream even if the AgentCard supports streaming. 
+ RemoteA2AAgent agent = getAgentBuilder().streaming(false).build(); + assertThat(agent.isStreaming()).isFalse(); + } + + @Test + public void createAgent_streaming_true_returnsStreamingAgent() { + // With streaming true, the agent should support streaming if the AgentCard supports streaming. + RemoteA2AAgent agent = getAgentBuilder().streaming(true).build(); + assertThat(agent.isStreaming()).isTrue(); + } + @Test public void runAsync_aggregatesPartialEvents() { RemoteA2AAgent agent = createAgent(); @@ -763,7 +777,7 @@ private RemoteA2AAgent.Builder getAgentBuilder() { } private RemoteA2AAgent createAgent() { - return getAgentBuilder().build(); + return getAgentBuilder().streaming(true).build(); } @SuppressWarnings("unchecked") // cast for Mockito From 94de7f199f86b39bdb7cce6e9800eb05008a8953 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 18 Mar 2026 09:37:12 -0700 Subject: [PATCH 18/29] fix: Use ConcurrentHashMap in InvocationReplayState fixes #1009 PiperOrigin-RevId: 885641755 --- .../java/com/google/adk/plugins/InvocationReplayState.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java index 7d70a0efb..eab293de0 100644 --- a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java +++ b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java @@ -16,8 +16,8 @@ package com.google.adk.plugins; import com.google.adk.plugins.recordings.Recordings; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** Per-invocation replay state to isolate concurrent runs. 
*/ class InvocationReplayState { @@ -33,7 +33,7 @@ public InvocationReplayState(String testCasePath, int userMessageIndex, Recordin this.testCasePath = testCasePath; this.userMessageIndex = userMessageIndex; this.recordings = recordings; - this.agentReplayIndices = new HashMap<>(); + this.agentReplayIndices = new ConcurrentHashMap<>(); } public String getTestCasePath() { @@ -57,7 +57,6 @@ public void setAgentReplayIndex(String agentName, int index) { } public void incrementAgentReplayIndex(String agentName) { - int currentIndex = getAgentReplayIndex(agentName); - setAgentReplayIndex(agentName, currentIndex + 1); + agentReplayIndices.merge(agentName, 1, Integer::sum); } } From 2c71ba1332e052189115cd4644b7a473c31ed414 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fatih=20C=C3=BCre?= Date: Wed, 18 Mar 2026 10:06:35 +0300 Subject: [PATCH 19/29] feat: Enhance LangChain4j to support MCP tools with parametersJsonSchema --- .../adk/models/langchain4j/LangChain4j.java | 23 +++- .../models/langchain4j/LangChain4jTest.java | 124 ++++++++++++++++++ .../com/google/adk/models/LlmRequest.java | 2 +- 3 files changed, 144 insertions(+), 5 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 80c25610d..3ccb1e029 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; @@ -428,8 +429,24 @@ private List toToolSpecifications(LlmRequest llmRequest) { baseTool -> { if 
(baseTool.declaration().isPresent()) { FunctionDeclaration functionDeclaration = baseTool.declaration().get(); - if (functionDeclaration.parameters().isPresent()) { - Schema schema = functionDeclaration.parameters().get(); + Schema schema = null; + if (functionDeclaration.parametersJsonSchema().isPresent()) { + Object jsonSchemaObj = functionDeclaration.parametersJsonSchema().get(); + try { + if (jsonSchemaObj instanceof Schema) { + schema = (Schema) jsonSchemaObj; + } else { + schema = JsonBaseModel.getMapper().convertValue(jsonSchemaObj, Schema.class); + } + } catch (Exception e) { + throw new IllegalStateException( + "Failed to convert parametersJsonSchema to Schema: " + e.getMessage(), e); + } + } else if (functionDeclaration.parameters().isPresent()) { + schema = functionDeclaration.parameters().get(); + } + + if (schema != null) { ToolSpecification toolSpecification = ToolSpecification.builder() .name(baseTool.name()) @@ -438,11 +455,9 @@ private List toToolSpecifications(LlmRequest llmRequest) { .build(); toolSpecifications.add(toolSpecification); } else { - // TODO exception or something else? throw new IllegalStateException("Tool lacking parameters: " + baseTool); } } else { - // TODO exception or something else? 
throw new IllegalStateException("Tool lacking declaration: " + baseTool); } }); diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 428a5660c..076bb79a3 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -688,4 +688,128 @@ void testGenerateContentWithStructuredResponseJsonSchema() { final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0); assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe"); } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema") + void testGenerateContentWithMcpToolParametersJsonSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // MCP tools use parametersJsonSchema() instead of parameters() + // Create a JSON schema object (Map representation) + final Map jsonSchemaMap = + Map.of( + "type", + "object", + "properties", + Map.of("city", Map.of("type", "string", "description", "City name")), + "required", + List.of("city")); + + // Mock parametersJsonSchema() to return the JSON schema object + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(jsonSchemaMap)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + 
.contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema when it's already a Schema") + void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // Create a Schema object directly (when 
parametersJsonSchema returns Schema) + final Schema cityPropertySchema = + Schema.builder().type("STRING").description("City name").build(); + + final Schema objectSchema = + Schema.builder() + .type("OBJECT") + .properties(Map.of("city", cityPropertySchema)) + .required(List.of("city")) + .build(); + + // Mock parametersJsonSchema() to return Schema directly + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(objectSchema)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } } diff --git 
a/core/src/main/java/com/google/adk/models/LlmRequest.java b/core/src/main/java/com/google/adk/models/LlmRequest.java index 1a45c3a95..760a7c1c6 100644 --- a/core/src/main/java/com/google/adk/models/LlmRequest.java +++ b/core/src/main/java/com/google/adk/models/LlmRequest.java @@ -150,7 +150,7 @@ private static Builder create() { abstract LiveConnectConfig liveConnectConfig(); @CanIgnoreReturnValue - abstract Builder tools(Map tools); + public abstract Builder tools(Map tools); abstract Map tools(); From fa67101fe0555e8bbed5cf304d00550a56308222 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 18 Mar 2026 14:05:17 -0700 Subject: [PATCH 20/29] ADK changes PiperOrigin-RevId: 885777704 --- .../google/adk/agents/InvocationContext.java | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 91ce13a87..365f4f8c1 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -90,16 +90,6 @@ public Builder toBuilder() { return new Builder(this); } - /** - * Creates a shallow copy of the given {@link InvocationContext}. - * - * @deprecated Use {@code other.toBuilder().build()} instead. - */ - @Deprecated(forRemoval = true) - public static InvocationContext copyOf(InvocationContext other) { - return other.toBuilder().build(); - } - /** Returns the session service for managing session state. */ public BaseSessionService sessionService() { return sessionService; @@ -156,16 +146,6 @@ public BaseAgent agent() { return agent; } - /** - * Sets the [agent] being invoked. This is useful when delegating to a sub-agent. - * - * @deprecated Use {@link #toBuilder()} and {@link Builder#agent(BaseAgent)} instead. 
- */ - @Deprecated(forRemoval = true) - public void agent(BaseAgent agent) { - this.agent = agent; - } - /** Returns the session associated with this invocation. */ public Session session() { return session; From d7e03eeb067b83abd2afa3ea9bb5fc1c16143245 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 01:55:07 -0700 Subject: [PATCH 21/29] fix: Relaxing constraints for output schema These changes are now in sync with Python ADK PiperOrigin-RevId: 886040294 --- .../java/com/google/adk/agents/LlmAgent.java | 34 ------- .../com/google/adk/agents/LlmAgentTest.java | 98 +++++-------------- 2 files changed, 26 insertions(+), 106 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 89024a59b..b387aee34 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -594,40 +594,6 @@ protected void validate() { this.disallowTransferToParent != null && this.disallowTransferToParent; this.disallowTransferToPeers = this.disallowTransferToPeers != null && this.disallowTransferToPeers; - - if (this.outputSchema != null) { - if (!this.disallowTransferToParent || !this.disallowTransferToPeers) { - logger.warn( - "Invalid config for agent {}: outputSchema cannot co-exist with agent transfer" - + " configurations. 
Setting disallowTransferToParent=true and" - + " disallowTransferToPeers=true.", - this.name); - this.disallowTransferToParent = true; - this.disallowTransferToPeers = true; - } - - if (this.subAgents != null && !this.subAgents.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, subAgents must be empty to disable agent" - + " transfer."); - } - if (this.toolsUnion != null && !this.toolsUnion.isEmpty()) { - boolean hasOtherTools = - this.toolsUnion.stream() - .anyMatch( - tool -> - !(tool instanceof BaseTool baseTool) - || !baseTool.name().equals("example_tool")); - if (hasOtherTools) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, tools must be empty."); - } - } - } } @Override diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index c193e4a65..a9e7a6f8d 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -26,7 +26,6 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterModelCallback; import com.google.adk.agents.Callbacks.AfterToolCallback; @@ -52,9 +51,9 @@ import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; -import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import com.google.genai.types.Type; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; @@ -63,7 +62,6 @@ import 
io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.After; @@ -213,75 +211,6 @@ public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() { assertEqualIgnoringFunctionIds(events.get(3).content().get(), expectedFunctionResponseContent); } - @Test - public void build_withOutputSchemaAndTools_throwsIllegalArgumentException() { - BaseTool tool = - new BaseTool("test_tool", "test_description") { - @Override - public Optional declaration() { - return Optional.empty(); - } - }; - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .tools(ImmutableList.of(tool)) // Set tools (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set, tools" - + " must be empty"); - } - - @Test - public void build_withOutputSchemaAndSubAgents_throwsIllegalArgumentException() { - ImmutableList subAgents = - ImmutableList.of( - createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("test_sub_agent") - .description("test_sub_agent_description") - .build()); - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an 
IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .subAgents(subAgents) // Set subAgents (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set," - + " subAgents must be empty to disable agent transfer."); - } - @Test public void testBuild_withNullInstruction_setsInstructionToEmptyString() { LlmAgent agent = @@ -645,6 +574,31 @@ public void runAsync_withSubAgents_createsSpans() throws InterruptedException { assertThat(llmSpans).hasSize(2); // One for main agent, one for sub agent } + @Test + public void run_outputSchemaWithTools_allowed() { + Schema personShema = + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + "name", Schema.builder().type(Type.Known.STRING).build(), + "age", Schema.builder().type(Type.Known.INTEGER).build(), + "city", Schema.builder().type(Type.Known.STRING).build())) + .build(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .outputSchema(personShema) + .tools(new EchoTool()) + .build(); + assertThat(agent.outputSchema()).hasValue(personShema); + assertThat( + agent + .canonicalTools(new ReadonlyContext(createInvocationContext(agent))) + .count() + .blockingGet()) + .isEqualTo(1); + } + private List findSpansByName(List spans, String name) { return spans.stream().filter(s -> s.getName().equals(name)).toList(); } From e534f12bd5c7cadb8a6100b00ac2ae771a868ab0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 05:59:19 -0700 Subject: [PATCH 22/29] refactor: Update map handling in EventActions to always use defensive copy and add null handling for 
`artifactDelta` in the Builder PiperOrigin-RevId: 886130618 --- .../sessions/FirestoreSessionServiceTest.java | 99 ------------------- .../com/google/adk/events/EventActions.java | 17 ++-- .../google/adk/events/EventActionsTest.java | 11 --- 3 files changed, 6 insertions(+), 121 deletions(-) diff --git a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java index ffcbcf8f5..43ca6889f 100644 --- a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java +++ b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java @@ -47,15 +47,11 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.observers.TestObserver; import java.time.Instant; -import java.util.AbstractMap; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -530,44 +526,6 @@ void appendAndGet_withAllPartTypes_serializesAndDeserializesCorrectly() { }); } - /** - * A wrapper class that implements ConcurrentMap but delegates to a HashMap. This is a workaround - * to allow putting null values, which ConcurrentHashMap forbids, for testing state removal logic. 
- */ - private static class HashMapAsConcurrentMap extends AbstractMap - implements ConcurrentMap { - private final HashMap map; - - public HashMapAsConcurrentMap(Map map) { - this.map = new HashMap<>(map); - } - - @Override - public Set> entrySet() { - return map.entrySet(); - } - - @Override - public V putIfAbsent(K key, V value) { - return map.putIfAbsent(key, value); - } - - @Override - public boolean remove(Object key, Object value) { - return map.remove(key, value); - } - - @Override - public boolean replace(K key, V oldValue, V newValue) { - return map.replace(key, oldValue, newValue); - } - - @Override - public V replace(K key, V value) { - return map.replace(key, value); - } - } - /** Tests that appendEvent with only app state deltas updates the correct stores. */ @Test void appendEvent_withAppOnlyStateDeltas_updatesCorrectStores() { @@ -662,63 +620,6 @@ void appendEvent_withUserOnlyStateDeltas_updatesCorrectStores() { verify(mockSessionDocRef, never()).update(eq(Constants.KEY_STATE), any()); } - /** - * Tests that appendEvent with all types of state deltas updates the correct stores and session - * state. 
- */ - @Test - void appendEvent_withAllStateDeltas_updatesCorrectStores() { - // Arrange - Session session = - Session.builder(SESSION_ID) - .appName(APP_NAME) - .userId(USER_ID) - .state(new ConcurrentHashMap<>()) // The session state itself must be concurrent - .build(); - session.state().put("keyToRemove", "someValue"); - - Map stateDeltaMap = new HashMap<>(); - stateDeltaMap.put("sessionKey", "sessionValue"); - stateDeltaMap.put("_app_appKey", "appValue"); - stateDeltaMap.put("_user_userKey", "userValue"); - stateDeltaMap.put("keyToRemove", null); - - // Use the wrapper to satisfy the ConcurrentMap interface for the builder - EventActions actions = - EventActions.builder().stateDelta(new HashMapAsConcurrentMap<>(stateDeltaMap)).build(); - - Event event = - Event.builder() - .author("model") - .content(Content.builder().parts(List.of(Part.fromText("..."))).build()) - .actions(actions) - .build(); - - when(mockSessionsCollection.document(SESSION_ID)).thenReturn(mockSessionDocRef); - when(mockEventsCollection.document()).thenReturn(mockEventDocRef); - when(mockEventDocRef.getId()).thenReturn(EVENT_ID); - // THIS IS THE MISSING MOCK: Stub the call to get the document by its specific ID. 
- when(mockEventsCollection.document(EVENT_ID)).thenReturn(mockEventDocRef); - // Add the missing mock for the final session update call - when(mockSessionDocRef.update(anyMap())) - .thenReturn(ApiFutures.immediateFuture(mockWriteResult)); - - // Act - sessionService.appendEvent(session, event).test().assertComplete(); - - // Assert - assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).doesNotContainKey("keyToRemove"); - - ArgumentCaptor> appStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockAppStateDocRef).set(appStateCaptor.capture(), any(SetOptions.class)); - assertThat(appStateCaptor.getValue()).containsEntry("appKey", "appValue"); - - ArgumentCaptor> userStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockUserStateUserDocRef).set(userStateCaptor.capture(), any(SetOptions.class)); - assertThat(userStateCaptor.getValue()).containsEntry("userKey", "userValue"); - } - /** Tests that getSession skips malformed events and returns only the well-formed ones. 
*/ @Test @SuppressWarnings("unchecked") diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 3565c3e99..83fd60e54 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -157,9 +157,6 @@ public void setRequestedToolConfirmations( Map requestedToolConfirmations) { if (requestedToolConfirmations == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - } else if (requestedToolConfirmations instanceof ConcurrentMap) { - this.requestedToolConfirmations = - (ConcurrentMap) requestedToolConfirmations; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(requestedToolConfirmations); } @@ -290,8 +287,6 @@ public Builder skipSummarization(boolean skipSummarization) { public Builder stateDelta(@Nullable Map value) { if (value == null) { this.stateDelta = new ConcurrentHashMap<>(); - } else if (value instanceof ConcurrentMap) { - this.stateDelta = (ConcurrentMap) value; } else { this.stateDelta = new ConcurrentHashMap<>(value); } @@ -300,8 +295,12 @@ public Builder stateDelta(@Nullable Map value) { @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(Map value) { - this.artifactDelta = new ConcurrentHashMap<>(value); + public Builder artifactDelta(@Nullable Map value) { + if (value == null) { + this.artifactDelta = new ConcurrentHashMap<>(); + } else { + this.artifactDelta = new ConcurrentHashMap<>(value); + } return this; } @@ -339,10 +338,6 @@ public Builder requestedAuthConfigs( public Builder requestedToolConfirmations(@Nullable Map value) { if (value == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - return this; - } - if (value instanceof ConcurrentMap) { - this.requestedToolConfirmations = (ConcurrentMap) value; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(value); } diff --git 
a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 22bb94e64..b1e645e1a 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -177,17 +177,6 @@ public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); } - @Test - public void setRequestedToolConfirmations_withConcurrentMap_usesSameInstance() { - ConcurrentHashMap map = new ConcurrentHashMap<>(); - map.put("tool", TOOL_CONFIRMATION); - - EventActions actions = new EventActions(); - actions.setRequestedToolConfirmations(map); - - assertThat(actions.requestedToolConfirmations()).isSameInstanceAs(map); - } - @Test public void setRequestedToolConfirmations_withRegularMap_createsConcurrentMap() { ImmutableMap map = ImmutableMap.of("tool", TOOL_CONFIRMATION); From cd56902b803d4f7a1f3c718529842823d9e4370a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 06:35:06 -0700 Subject: [PATCH 23/29] feat: Update return type of toolsets() from ImmutableList to List PiperOrigin-RevId: 886145022 --- core/src/main/java/com/google/adk/agents/LlmAgent.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index b387aee34..077068283 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -772,7 +772,7 @@ public List toolsUnion() { return toolsUnion; } - public ImmutableList toolsets() { + public List toolsets() { return toolsets; } From 9a080763d83c319f539d1bacac4595d13b299e7e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 07:11:45 -0700 Subject: [PATCH 24/29] feat: fixing context propagation for agent transfers 
PiperOrigin-RevId: 886159283 --- .../adk/flows/llmflows/BaseLlmFlow.java | 16 +- .../java/com/google/adk/telemetry/README.md | 156 ++++++++++ .../adk/telemetry/ContextPropagationTest.java | 269 +++++++++++++----- 3 files changed, 368 insertions(+), 73 deletions(-) create mode 100644 core/src/main/java/com/google/adk/telemetry/README.md diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index e00cf0cbf..8fabc978d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -430,7 +430,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex "Agent not found: " + agentToTransfer))); } return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.get().runAsync(context))); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runAsync(context); + } + })); } return postProcessedEvents; }); @@ -488,6 +493,8 @@ private Flowable run( public Flowable runLive(InvocationContext invocationContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + // Capture agent context at assembly time to use as parent for agent transfer at subscription + // time. See Flowable.defer() usages below. 
Context spanContext = Context.current(); return preprocessEvents.concatWith( @@ -608,7 +615,12 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent.get().runLive(invocationContext); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runLive(invocationContext); + } + }); events = Flowable.concat(events, nextAgentEvents); } return events; diff --git a/core/src/main/java/com/google/adk/telemetry/README.md b/core/src/main/java/com/google/adk/telemetry/README.md new file mode 100644 index 000000000..8665b3352 --- /dev/null +++ b/core/src/main/java/com/google/adk/telemetry/README.md @@ -0,0 +1,156 @@ +# ADK Telemetry and Tracing + +This package contains classes for capturing and reporting telemetry data within +the ADK, primarily for tracing agent execution leveraging OpenTelemetry. + +## Overview + +The `Tracing` utility class provides methods to trace various aspects of an +agent's execution, including: + +* Agent invocations +* LLM requests and responses +* Tool calls and responses + +These traces can be exported and visualized in telemetry backends like Google +Cloud Trace or Zipkin, or viewed through the ADK Dev Server UI, providing +observability into agent behavior. + +## How Tracing is Used + +Tracing is deeply integrated into the ADK's RxJava-based asynchronous workflows. + +### Agent Invocations + +Every agent's `runAsync` or `runLive` execution is wrapped in a span named +`invoke_agent `. The top-level agent invocation initiated by +`Runner.runAsync` or `Runner.runLive` is captured in a span named `invocation`. +Agent-specific metadata like name and description are added as span attributes, +following OpenTelemetry semantic conventions (e.g., `gen_ai.agent.name`). + +### LLM Calls + +Calls to Large Language Models (LLMs) are traced within a `call_llm` span. 
The +`traceCallLlm` method attaches detailed attributes to this span, including: + +* The LLM request (excluding large data like images) and response. +* Model name (`gen_ai.request.model`). +* Token usage (`gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens`). +* Configuration parameters (`gen_ai.request.top_p`, + `gen_ai.request.max_tokens`). +* Response finish reason (`gen_ai.response.finish_reasons`). + +### Tool Calls and Responses + +Tool executions triggered by the LLM are traced using `tool_call []` +and `tool_response []` spans. + +* `traceToolCall` records tool arguments in the + `gcp.vertex.agent.tool_call_args` attribute. +* `traceToolResponse` records tool output in the + `gcp.vertex.agent.tool_response` attribute. +* If multiple tools are called in parallel, a single `tool_response` span may + be created for the merged result. + +### Context Propagation + +ADK is built on RxJava and heavily uses asynchronous processing, which means +that work is often handed off between different threads. For tracing to work +correctly in such an environment, it's crucial that the active span's context +is propagated across these thread boundaries. If context is not propagated, +new spans may be orphaned or attached to the wrong parent, making traces +difficult to interpret. + +OpenTelemetry stores the currently active span in a thread-local variable. +When an asynchronous operation switches threads, this thread-local context is +lost. To solve this, ADK's `Tracing` class provides functionality to capture +the context on one thread and restore it on another when an asynchronous +operation resumes. This ensures that spans created on different threads are +correctly parented under the same trace. + +The primary mechanism for this is the `Tracing.withContext(context)` method, +which returns an RxJava transformer. 
When applied to an RxJava stream via +`.compose()`, this transformer ensures that the provided `Context` (containing +the parent span) is re-activated before any `onNext`, `onError`, `onComplete`, +or `onSuccess` signals are propagated downstream. It achieves this by wrapping +the downstream observer with a `TracingObserver`, which uses +`context.makeCurrent()` in a try-with-resources block around each callback, +guaranteeing that the correct span is active when downstream operators execute, +regardless of the thread. + +### RxJava Integration + +ADK integrates OpenTelemetry with RxJava streams to simplify span creation and +ensure context propagation: + +* **Span Creation**: The `Tracing.trace(spanName)` method returns an RxJava + transformer that can be applied to a `Flowable`, `Single`, `Maybe`, or + `Completable` using `.compose()`. This transformer wraps the stream's + execution in a new OpenTelemetry span. +* **Context Propagation**: The `Tracing.withContext(context)` transformer is + used with `.compose()` to ensure that the correct OpenTelemetry `Context` + (and thus the correct parent span) is active when stream operators or + subscriptions are executed, even across thread boundaries. + +## Trace Hierarchy Example + +A typical agent interaction might produce a trace hierarchy like the following: + +``` +invocation +└── invoke_agent my_agent + ├── call_llm + │ ├── tool_call [search_flights] + │ └── tool_response [search_flights] + └── call_llm +``` + +This shows: + +1. The overall `invocation` started by the `Runner`. +2. The invocation of `my_agent`. +3. The first `call_llm` made by `my_agent`. +4. A `tool_call` to `search_flights` and its corresponding `tool_response`. +5. A second `call_llm` made by `my_agent` to generate the final user response. + +### Nested Agents + +ADK supports nested agents, where one agent invokes another. If an agent has +sub-agents, it can transfer control to one of them using the built-in +`transfer_to_agent` tool. 
When `AgentA` calls `transfer_to_agent` to transfer +control to `AgentB`, the `invoke_agent AgentB` span will appear as a child of +the `invoke_agent AgentA` span, like so: + +``` +invocation +└── invoke_agent AgentA + ├── call_llm + │ ├── tool_call [transfer_to_agent] + │ └── tool_response [transfer_to_agent] + └── invoke_agent AgentB + ├── call_llm + └── ... +``` + +This structure allows you to see how `AgentA` delegated work to `AgentB`. + +## Span Creation References + +The following classes are the primary places where spans are created: + +* **`com.google.adk.runner.Runner`**: Initiates the top-level `invocation` + span for `runAsync` and `runLive`. +* **`com.google.adk.agents.BaseAgent`**: Creates the `invoke_agent + ` span for each agent execution. +* **`com.google.adk.flows.llmflows.BaseLlmFlow`**: Creates the `call_llm` span + when the LLM is invoked. +* **`com.google.adk.flows.llmflows.Functions`**: Creates `tool_call [...]` and + `tool_response [...]` spans when handling tool calls and responses. + +## Configuration + +**ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS**: This environment variable controls +whether LLM request/response content and tool arguments/responses are captured +in span attributes. It defaults to `true`. Set to `false` to exclude potentially +large or sensitive data from traces, in which case a `{}` JSON object will be +recorded instead. 
diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index e5795d61f..1ee018848 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -16,6 +16,7 @@ package com.google.adk.telemetry; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -32,10 +33,15 @@ import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; +import com.google.adk.testing.TestLlm; +import com.google.adk.testing.TestUtils; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GenerateContentResponseUsageMetadata; @@ -54,6 +60,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -123,10 +130,7 @@ public void testToolCallSpanLinksToParent() { parentSpanData.getSpanContext().getTraceId(), toolCallSpanData.getSpanContext().getTraceId()); - assertEquals( - "Tool call's parent should be the parent span", - parentSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, toolCallSpanData); } @Test @@ -146,7 +150,7 @@ public void 
testToolCallWithoutParentCreatesRootSpan() { // Then: Should create root span (backward compatible) List spans = openTelemetryRule.getSpans(); - assertEquals("Should have exactly 1 span", 1, spans.size()); + assertThat(spans).hasSize(1); SpanData toolCallSpanData = spans.get(0); assertFalse( @@ -193,7 +197,7 @@ public void testNestedSpanHierarchy() { List spans = openTelemetryRule.getSpans(); // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response // [testTool]". - assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); + assertThat(spans).hasSize(4); SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); @@ -210,22 +214,13 @@ public void testNestedSpanHierarchy() { SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); // invocation should be child of parent - assertEquals( - "Invocation should be child of parent", - parentSpanData.getSpanContext().getSpanId(), - invocationSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, invocationSpanData); // tool_call should be child of invocation - assertEquals( - "Tool call should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, toolCallSpanData); // tool_response should be child of tool_call - assertEquals( - "Tool response should be child of tool call", - toolCallSpanData.getSpanContext().getSpanId(), - toolResponseSpanData.getParentSpanContext().getSpanId()); + assertParent(toolCallSpanData, toolResponseSpanData); } @Test @@ -253,7 +248,6 @@ public void testMultipleSpansInParallel() { // Verify all tool calls link to same parent SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - String parentSpanId = parentSpanData.getSpanContext().getSpanId(); // All tool calls should have same 
trace ID and parent span ID List toolCallSpans = @@ -261,7 +255,7 @@ public void testMultipleSpansInParallel() { .filter(s -> s.getName().startsWith("tool_call")) .toList(); - assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); + assertThat(toolCallSpans).hasSize(3); toolCallSpans.forEach( span -> { @@ -269,10 +263,7 @@ public void testMultipleSpansInParallel() { "Tool call should have same trace ID as parent", parentTraceId, span.getSpanContext().getTraceId()); - assertEquals( - "Tool call should have parent as parent span", - parentSpanId, - span.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, span); }); } @@ -298,10 +289,7 @@ public void testInvokeAgentSpanLinksToInvocation() { SpanData invocationSpanData = findSpanByName("invocation"); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - assertEquals( - "Agent run should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - invokeAgentSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, invokeAgentSpanData); } @Test @@ -323,15 +311,12 @@ public void testCallLlmSpanLinksToAgentRun() { } List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans", 2, spans.size()); + assertThat(spans).hasSize(2); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); SpanData callLlmSpanData = findSpanByName("call_llm"); - assertEquals( - "Call LLM should be child of agent run", - invokeAgentSpanData.getSpanContext().getSpanId(), - callLlmSpanData.getParentSpanContext().getSpanId()); + assertParent(invokeAgentSpanData, callLlmSpanData); } @Test @@ -349,10 +334,7 @@ public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { SpanData parentSpanData = findSpanByName("invocation"); SpanData agentSpanData = findSpanByName("invoke_agent"); - assertEquals( - "Agent span should be a child of the invocation span", - parentSpanData.getSpanContext().getSpanId(), - 
agentSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, agentSpanData); } @Test @@ -380,9 +362,7 @@ public void testTraceFlowable() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData flowableSpanData = findSpanByName("flowable"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - flowableSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, flowableSpanData); assertTrue(flowableSpanData.hasEnded()); } @@ -469,9 +449,7 @@ public void testTraceTransformer() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData transformerSpanData = findSpanByName("transformer"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - transformerSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, transformerSpanData); assertTrue(transformerSpanData.hasEnded()); } @@ -485,7 +463,7 @@ public void testTraceAgentInvocation() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -504,7 +482,7 @@ public void testTraceToolCall() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -541,7 +519,7 @@ public void testTraceToolResponse() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -578,7 +556,7 @@ 
public void testTraceCallLlm() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); @@ -606,12 +584,12 @@ public void testTraceSendData() { Tracing.traceSendData( buildInvocationContext(), "event-1", - ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); + ImmutableList.of(Content.fromParts(Part.fromText("hello")))); } finally { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals( @@ -653,37 +631,23 @@ public void baseAgentRunAsync_propagatesContext() throws InterruptedException { } SpanData parent = findSpanByName("parent"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals(parent.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, agentSpan); } @Test public void runnerRunAsync_propagatesContext() throws InterruptedException { BaseAgent agent = new TestAgent(); - Runner runner = Runner.builder().agent(agent).appName("test-app").build(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { - Session session = - runner - .sessionService() - .createSession(new SessionKey("test-app", "test-user", "test-session")) - .blockingGet(); - Content newMessage = Content.fromParts(Part.fromText("hi")); - RunConfig runConfig = RunConfig.builder().build(); - runner - .runAsync(session.userId(), session.id(), newMessage, runConfig, null) - .test() - .await() - .assertComplete(); + runAgent(agent); } finally { parentSpan.end(); } SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); 
SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); } @Test @@ -713,10 +677,173 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); + } + + @Test + public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when an agent calls an LLM, + // which then invokes a tool. 
The expected hierarchy is: + // invocation + // └── invoke_agent test_agent + // ├── call_llm + // │ ├── tool_call [search_flights] + // │ └── tool_response [search_flights] + // └── call_llm + + SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); + + TestLlm testLlm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + searchFlightsTool.name(), ImmutableMap.of("destination", "SFO"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("done")))); + + LlmAgent agentWithTool = + LlmAgent.builder() + .name("test_agent") + .description("description") + .model(testLlm) + .tools(ImmutableList.of(searchFlightsTool)) + .build(); + + runAgent(agentWithTool); + + SpanData invocation = findSpanByName("invocation"); + SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); + SpanData toolCall = findSpanByName("tool_call [search_flights]"); + SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + SpanData callLlm1 = callLlmSpans.get(0); + SpanData callLlm2 = callLlmSpans.get(1); + + // Assert hierarchy: + // invocation + // └── invoke_agent test_agent + assertParent(invocation, invokeAgent); + // ├── call_llm 1 + assertParent(invokeAgent, callLlm1); + // │ ├── tool_call [search_flights] + assertParent(callLlm1, toolCall); + // │ └── tool_response [search_flights] + assertParent(callLlm1, toolResponse); + // └── call_llm 2 + assertParent(invokeAgent, callLlm2); + } + + @Test + public void testNestedAgentTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when AgentA transfers to AgentB. 
+ // The expected hierarchy is: + // invocation + // └── invoke_agent AgentA + // ├── call_llm + // │ ├── tool_call [transfer_to_agent] + // │ └── tool_response [transfer_to_agent] + // └── invoke_agent AgentB + // └── call_llm + TestLlm llm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + "transfer_to_agent", ImmutableMap.of("agent_name", "AgentB"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("agent b response")))); + LlmAgent agentB = LlmAgent.builder().name("AgentB").description("Agent B").model(llm).build(); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .description("Agent A") + .model(llm) + .subAgents(ImmutableList.of(agentB)) + .build(); + + runAgent(agentA); + + SpanData invocation = findSpanByName("invocation"); + SpanData agentASpan = findSpanByName("invoke_agent AgentA"); + SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); + SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); + + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + + SpanData agentACallLlm1 = callLlmSpans.get(0); + SpanData agentBCallLlm = callLlmSpans.get(1); + + assertParent(invocation, agentASpan); + assertParent(agentASpan, agentACallLlm1); + assertParent(agentACallLlm1, toolCall); + assertParent(agentACallLlm1, toolResponse); + assertParent(agentASpan, agentBSpan); + assertParent(agentBSpan, agentBCallLlm); + } + + private void runAgent(BaseAgent agent) throws InterruptedException { + Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Session session = + runner + .sessionService() + .createSession(new SessionKey("test-app", "test-user", "test-session")) 
+ .blockingGet(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + runner + .runAsync(session.sessionKey(), newMessage, runConfig, null) + .test() + .await() + .assertComplete(); + } + + /** Tool for testing. */ + public static class SearchFlightsTool extends BaseTool { + public SearchFlightsTool() { + super("search_flights", "Search for flights tool"); + } + + @Override + public Single> runAsync(Map args, ToolContext context) { + return Single.just(ImmutableMap.of("result", args)); + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name("search_flights") + .description("Search for flights tool") + .build()); + } + } + + /** + * Asserts that the parent span is the parent of the child span. + * + * @param parent The parent span. + * @param child The child span. + */ + private void assertParent(SpanData parent, SpanData child) { + assertEquals(parent.getSpanContext().getSpanId(), child.getParentSpanContext().getSpanId()); } /** From 0af82e61a3c0dbbd95166a10b450cb507115ab60 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 19 Mar 2026 07:29:15 -0700 Subject: [PATCH 25/29] fix: Removing deprecated methods in Runner PiperOrigin-RevId: 886166671 --- .../java/com/google/adk/runner/Runner.java | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1f7d924ab..849a3cd04 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -425,36 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. 
- */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -735,18 +705,6 @@ protected Flowable runLiveImpl( }); } - /** - * Runs the agent asynchronously with a default user ID. - * - * @return stream of generated events. - */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); - } - /** * Checks if the agent and its parent chain allow transfer up the tree. 
* From dc5d794c066571c7d87f006767bd32298e2a3ba8 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 19 Mar 2026 08:39:07 -0700 Subject: [PATCH 26/29] chore: set version to 1.0.0-rc.1 Release-As: 1.0.0-rc.1 PiperOrigin-RevId: 886198912 --- .release-please-manifest.json | 1 - 1 file changed, 1 deletion(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b0f3ba770..6db3039d0 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,4 +1,3 @@ { ".": "0.9.0" } - From dfbab955314a428ccb17855a69a77386e924c92b Mon Sep 17 00:00:00 2001 From: ddobrin Date: Tue, 17 Mar 2026 09:30:34 -0400 Subject: [PATCH 27/29] Updated tests in Spring AI, Langchain4j, dependency for Spering AI and GenAI SDK --- .../LangChain4jIntegrationTest.java | 16 ++--- contrib/spring-ai/pom.xml | 2 +- .../AnthropicApiIntegrationTest.java | 62 ++++++++++++------- pom.xml | 2 +- 4 files changed, 48 insertions(+), 34 deletions(-) diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 3fafb046d..191e48017 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -44,7 +44,7 @@ class LangChain4jIntegrationTest { - public static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String CLAUDE_4_6_SONNET = "claude-sonnet-4-6"; public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; public static final String GPT_4_O_MINI = "gpt-4o-mini"; @@ -55,14 +55,14 @@ void testSimpleAgent() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); 
LlmAgent agent = LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) .instruction( """ You are a helpful science teacher that explains science concepts @@ -91,14 +91,14 @@ void testSingleAgentWithTools() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); BaseAgent agent = LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) .instruction( """ You are a friendly assistant. @@ -155,7 +155,7 @@ void testSingleAgentWithTools() { List partsThree = contentThree.parts().get(); assertEquals(1, partsThree.size()); assertTrue(partsThree.get(0).text().isPresent()); - assertTrue(partsThree.get(0).text().get().contains("beautiful")); + assertTrue(partsThree.get(0).text().get().contains("sunny")); } @Test @@ -352,10 +352,10 @@ void testSimpleStreamingResponse() { AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); + LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_4_6_SONNET); // when Flowable responses = diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index b24fa4b63..08d237ab5 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -29,7 +29,7 @@ Spring AI integration for the Agent Development Kit. 
- 2.0.0-M2 + 2.0.0-M3 1.21.3 diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java index f21b07ae9..c59a94f82 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java @@ -34,7 +34,6 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; -import org.springframework.ai.anthropic.api.AnthropicApi; /** * Integration tests with real Anthropic API. @@ -53,10 +52,14 @@ void testSimpleAgentWithRealAnthropicApi() throws InterruptedException { Thread.sleep(2000); // Create Anthropic model using Spring AI's builder pattern - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); // Wrap with SpringAI SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -92,10 +95,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { // Add delay to avoid rapid requests Thread.sleep(2000); - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + 
.maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -134,10 +141,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { @Test void testAgentWithToolsAndRealApi() { - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); LlmAgent agent = LlmAgent.builder() @@ -175,10 +186,13 @@ void testAgentWithToolsAndRealApi() { @Test void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { // Test both non-streaming and streaming with the same model to compare behavior - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -271,13 +285,13 @@ void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { @Test void testConfigurationOptions() { // Test with custom configuration - AnthropicChatOptions options = - AnthropicChatOptions.builder().model(CLAUDE_MODEL).temperature(0.7).maxTokens(100).build(); - - AnthropicApi anthropicApi = - 
AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).defaultOptions(options).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); diff --git a/pom.xml b/pom.xml index 62082cfc9..0be05a629 100644 --- a/pom.xml +++ b/pom.xml @@ -51,7 +51,7 @@ 1.51.0 0.17.2 2.47.0 - 1.41.0 + 1.43.0 4.33.5 5.11.4 5.20.0 From dbb139439d38157b4b9af38c52824b1e8405a495 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 19 Mar 2026 10:14:14 -0700 Subject: [PATCH 28/29] feat!: remove McpToolset constructors taking Optional parameters PiperOrigin-RevId: 886244600 --- .../com/google/adk/tools/mcp/McpToolset.java | 174 ++++++++++++------ .../google/adk/tools/mcp/McpToolsetTest.java | 9 +- 2 files changed, 120 insertions(+), 63 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java index 207243ceb..4cafb9681 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java @@ -24,6 +24,8 @@ import com.google.adk.agents.ReadonlyContext; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolPredicate; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Booleans; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -32,6 +34,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -51,7 +54,7 @@ public class McpToolset implements BaseToolset { private final McpSessionManager mcpSessionManager; private McpSyncClient mcpSession; private final ObjectMapper objectMapper; - private final Optional toolFilter; + private final @Nullable Object toolFilter; private static final int MAX_RETRIES = 3; private static final long RETRY_DELAY_MILLIS = 100; @@ -62,17 +65,29 @@ public class McpToolset implements BaseToolset { * * @param connectionParams The SSE connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( SseServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with SSE server parameters. + * + * @param connectionParams The SSE connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. 
+ * @param toolNames A list of tool names + */ + public McpToolset( + SseServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** @@ -82,7 +97,9 @@ public McpToolset( * @param objectMapper An ObjectMapper instance for parsing schemas. */ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -90,36 +107,39 @@ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMappe * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( - ServerParameters connectionParams, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** - * Initializes the McpToolset with local server parameters and no tool filter. + * Initializes the McpToolset with local server parameters. 
* * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names */ - public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + public McpToolset( + ServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** - * Initializes the McpToolset with SSE server parameters, using the ObjectMapper used across the - * ADK. + * Initializes the McpToolset with local server parameters and no tool filter. * - * @param connectionParams The SSE connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param connectionParams The local server connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. */ - public McpToolset(SseServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -129,28 +149,31 @@ public McpToolset(SseServerParameters connectionParams, Optional toolFil * @param connectionParams The SSE connection parameters to the MCP server. 
*/ public McpToolset(SseServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } /** * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK. + * ADK and no tool filter. * * @param connectionParams The local server connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. */ - public McpToolset(ServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams) { + this(connectionParams, JsonBaseModel.getMapper()); } /** - * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK and no tool filter. + * Initializes the McpToolset with an McpSessionManager. * - * @param connectionParams The local server connection parameters to the MCP server. + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolPredicate A {@link ToolPredicate} */ - public McpToolset(ServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + public McpToolset( + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** @@ -158,33 +181,69 @@ public McpToolset(ServerParameters connectionParams) { * * @param mcpSessionManager A McpSessionManager instance for testing. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. 
+ * @param toolNames A list of tool names */ public McpToolset( - McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(mcpSessionManager); - Objects.requireNonNull(objectMapper); - this.mcpSessionManager = mcpSessionManager; - this.objectMapper = objectMapper; - this.toolFilter = toolFilter; + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, List toolNames) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with an McpSessionManager and no tool filter. + * + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(McpSessionManager mcpSessionManager, ObjectMapper objectMapper) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = null; } /** - * Initializes the McpToolset with Steamable HTTP server parameters. + * Initializes the McpToolset with Streamable HTTP server parameters. * * @param connectionParams The Streamable HTTP connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. 
+ * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names + */ + public McpToolset( + StreamableHttpServerParameters connectionParams, + ObjectMapper objectMapper, + List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters and no tool filter. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -194,7 +253,7 @@ public McpToolset( * @param connectionParams The Streamable HTTP connection parameters to the MCP server. 
*/ public McpToolset(StreamableHttpServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } @Override @@ -215,8 +274,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { tool -> new McpTool( tool, this.mcpSession, this.mcpSessionManager, this.objectMapper)) - .filter( - tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext))); + .filter(tool -> isToolSelected(tool, toolFilter, readonlyContext))); }) .retryWhen( errorObservable -> @@ -357,16 +415,18 @@ public static McpToolset fromConfig(BaseTool.ToolConfig config, String configAbs + " for McpToolset"); } - // Convert tool filter to Optional - Optional toolFilter = Optional.ofNullable(mcpToolsetConfig.toolFilter()); - + List toolNames = mcpToolsetConfig.toolFilter(); Object connectionParameters = Optional.ofNullable(mcpToolsetConfig.stdioConnectionParams()) .or(() -> Optional.ofNullable(mcpToolsetConfig.sseServerParams())) .orElse(mcpToolsetConfig.stdioConnectionParams()); // Create McpToolset with McpSessionManager having appropriate connection parameters - return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolFilter); + if (toolNames != null) { + return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolNames); + } else { + return new McpToolset(new McpSessionManager(connectionParameters), mapper); + } } catch (IllegalArgumentException e) { throw new ConfigurationException("Failed to parse McpToolsetConfig from ToolArgsConfig", e); } diff --git a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java index 3d322b73f..0db218347 100644 --- a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java @@ -34,7 +34,6 @@ import io.modelcontextprotocol.json.McpJsonMapper; import 
io.modelcontextprotocol.spec.McpSchema; import java.util.List; -import java.util.Optional; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -324,7 +323,7 @@ public void getTools_withToolFilter_returnsFilteredTools() { when(mockMcpSyncClient.listTools()).thenReturn(mockResult); McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.of(toolFilter)); + new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), toolFilter); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); @@ -340,8 +339,7 @@ public void getTools_retriesAndFailsAfterMaxRetries() { when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient); when(mockMcpSyncClient.listTools()).thenThrow(new RuntimeException("Test Exception")); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); toolset .getTools(mockReadonlyContext) @@ -362,8 +360,7 @@ public void getTools_succeedsOnLastRetryAttempt() { .thenThrow(new RuntimeException("Attempt 2 failed")) .thenReturn(mockResult); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); From 897f9d9776b75d66b6e7e01c98427d4d36d4dd5a Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 19 Mar 2026 10:16:25 -0700 Subject: [PATCH 29/29] chore: add test-jar goal in core sub-project PiperOrigin-RevId: 886245655 --- core/pom.xml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index b3f2f5fd8..02c75f88b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -229,6 +229,16 @@ maven-compiler-plugin + + maven-jar-plugin + + + + test-jar + + + + maven-surefire-plugin