diff --git a/.github/workflows/validation.yml b/.github/workflows/validation.yml index 65e66f8fd..241dc1ccb 100644 --- a/.github/workflows/validation.yml +++ b/.github/workflows/validation.yml @@ -23,13 +23,13 @@ jobs: uses: actions/checkout@v6 - name: Set up Java ${{ matrix.java-version }} - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: temurin java-version: ${{ matrix.java-version }} - name: Cache Maven packages - uses: actions/cache@v3 + uses: actions/cache@v5 with: path: ~/.m2/repository key: ${{ runner.os }}-maven-${{ matrix.java-version }}-${{ hashFiles('**/pom.xml') }} diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b0f3ba770..6db3039d0 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,4 +1,3 @@ { ".": "0.9.0" } - diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..5d33d2172 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,3 @@ +# AGENTS.md + +Validate changes by running `./mvnw test`. diff --git a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java index ccb662b7c..4a375980c 100644 --- a/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java +++ b/a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java @@ -117,7 +117,7 @@ private RemoteA2AAgent(Builder builder) { if (this.description.isEmpty() && this.agentCard.description() != null) { this.description = this.agentCard.description(); } - this.streaming = this.agentCard.capabilities().streaming(); + this.streaming = builder.streaming && this.agentCard.capabilities().streaming(); } public static Builder builder() { @@ -133,6 +133,13 @@ public static class Builder { private List subAgents; private List beforeAgentCallback; private List afterAgentCallback; + private boolean streaming; + + @CanIgnoreReturnValue + public Builder streaming(boolean streaming) { + this.streaming = streaming; + return this; + } @CanIgnoreReturnValue public Builder name(String name) { @@ -181,6 +188,10 @@ public RemoteA2AAgent build() { } } + public boolean isStreaming() { + return streaming; + } + private Message.Builder newA2AMessage(Message.Role role, List> parts) { return new Message.Builder().messageId(UUID.randomUUID().toString()).role(role).parts(parts); } diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java new file mode 100644 index 000000000..d4f1fef58 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/A2AMetadataKey.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of A2A metadata. Adds a prefix used to differentiage ADK-related values stored + * in Metadata an A2A event. + */ +public enum A2AMetadataKey { + TYPE("type"), + IS_LONG_RUNNING("is_long_running"), + PARTIAL("partial"), + GROUNDING_METADATA("grounding_metadata"), + USAGE_METADATA("usage_metadata"), + CUSTOM_METADATA("custom_metadata"), + ERROR_CODE("error_code"); + + private final String type; + + private A2AMetadataKey(String type) { + this.type = "adk_" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java new file mode 100644 index 000000000..e38f28828 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/converters/AdkMetadataKey.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.a2a.converters; + +/** + * Enum for the type of ADK metadata. Adds a prefix used to differentiate A2A-related values stored + * in custom metadata of an ADK session event. + */ +public enum AdkMetadataKey { + TASK_ID("task_id"), + CONTEXT_ID("context_id"); + + private final String type; + + private AdkMetadataKey(String type) { + this.type = "a2a:" + type; + } + + public String getType() { + return type; + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index 61f24fa21..714a79736 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -52,11 +52,7 @@ public final class PartConverter { private static final Logger logger = LoggerFactory.getLogger(PartConverter.class); private static final ObjectMapper objectMapper = new ObjectMapper(); - // Constants for metadata types. By convention metadata keys are prefixed with "adk_" to align - // with the Python and Golang libraries. - public static final String A2A_DATA_PART_METADATA_TYPE_KEY = "adk_type"; - public static final String A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = "adk_is_long_running"; - public static final String A2A_DATA_PART_METADATA_IS_PARTIAL_KEY = "adk_partial"; + // Constants for metadata types. public static final String LANGUAGE_KEY = "language"; public static final String OUTCOME_KEY = "outcome"; public static final String CODE_KEY = "code"; @@ -135,7 +131,7 @@ private static com.google.genai.types.Part convertDataPartToGenAiPart(DataPart d Map metadata = Optional.ofNullable(dataPart.getMetadata()).map(HashMap::new).orElseGet(HashMap::new); - String metadataType = metadata.getOrDefault(A2A_DATA_PART_METADATA_TYPE_KEY, "").toString(); + String metadataType = metadata.getOrDefault(A2AMetadataKey.TYPE.getType(), "").toString(); if ((data.containsKey(NAME_KEY) && data.containsKey(ARGS_KEY)) || metadataType.equals(A2ADataPartMetadataType.FUNCTION_CALL.getType())) { @@ -218,7 +214,7 @@ private static DataPart createDataPartFromFunctionCall( addValueIfPresent(data, WILL_CONTINUE_KEY, functionCall.willContinue()); addValueIfPresent(data, PARTIAL_ARGS_KEY, functionCall.partialArgs()); - metadata.put(A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_CALL.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -245,7 +241,7 @@ private static DataPart createDataPartFromFunctionResponse( addValueIfPresent(data, PARTS_KEY, functionResponse.parts()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -268,7 +264,7 @@ private static DataPart createDataPartFromCodeExecutionResult( addValueIfPresent(data, OUTPUT_KEY, codeExecutionResult.output()); metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.CODE_EXECUTION_RESULT.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -290,8 +286,7 @@ private static DataPart createDataPartFromExecutableCode( .orElse(Language.Known.LANGUAGE_UNSPECIFIED.toString())); addValueIfPresent(data, CODE_KEY, executableCode.code()); - metadata.put( - A2A_DATA_PART_METADATA_TYPE_KEY, A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); + metadata.put(A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.EXECUTABLE_CODE.getType()); return new DataPart(data.buildOrThrow(), metadata.buildOrThrow()); } @@ -305,7 +300,7 @@ public static io.a2a.spec.Part fromGenaiPart(Part part, boolean isPartial) { } ImmutableMap.Builder metadata = ImmutableMap.builder(); if (isPartial) { - metadata.put(A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + metadata.put(A2AMetadataKey.PARTIAL.getType(), true); } if (part.text().isPresent()) { diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 503432a30..cffd76983 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -19,12 +19,20 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.zip; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import com.google.genai.types.Part; import io.a2a.client.ClientEvent; import io.a2a.client.MessageEvent; @@ -43,11 +51,13 @@ import java.util.Objects; import java.util.Optional; import java.util.UUID; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** Utility for converting ADK events to A2A spec messages (and back). */ public final class ResponseConverter { + private static final ObjectMapper objectMapper = new ObjectMapper(); private static final Logger logger = LoggerFactory.getLogger(ResponseConverter.class); private static final ImmutableSet PENDING_STATES = ImmutableSet.of(TaskState.WORKING, TaskState.SUBMITTED); @@ -74,12 +84,11 @@ public static Optional clientEventToEvent( throw new IllegalArgumentException("Unsupported ClientEvent type: " + event.getClass()); } - private static boolean isPartial(Map metadata) { + private static boolean isPartial(@Nullable Map metadata) { if (metadata == null) { return false; } - return Objects.equals( - metadata.getOrDefault(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, false), true); + return Objects.equals(metadata.getOrDefault(A2AMetadataKey.PARTIAL.getType(), false), true); } /** @@ -110,7 +119,12 @@ private static Optional handleTaskUpdate( // append=false, lastChunk=false: emit as partial, reset aggregation // append=true, lastChunk=true: emit as partial, update aggregation and emit as non-partial // append=false, lastChunk=true: emit as non-partial, drop aggregation - return Optional.of(eventPart); + return Optional.of( + updateEventMetadata( + eventPart, + artifactEvent.getMetadata(), + artifactEvent.getTaskId(), + artifactEvent.getContextId())); } if (updateEvent instanceof TaskStatusUpdateEvent statusEvent) { @@ -128,14 +142,21 @@ private static Optional handleTaskUpdate( }); if (statusEvent.isFinal()) { - return messageEvent - .map(Event::toBuilder) - .or(() -> Optional.of(remoteAgentEventBuilder(context))) - .map(builder -> builder.turnComplete(true)) - .map(builder -> builder.partial(false)) - .map(Event.Builder::build); + messageEvent = + messageEvent + .map(Event::toBuilder) + .or(() -> Optional.of(remoteAgentEventBuilder(context))) + .map(builder -> builder.turnComplete(true)) + .map(builder -> builder.partial(false)) + .map(Event.Builder::build); } - return messageEvent; + return messageEvent.map( + finalMessageEvent -> + updateEventMetadata( + finalMessageEvent, + statusEvent.getMetadata(), + statusEvent.getTaskId(), + statusEvent.getContextId())); } throw new IllegalArgumentException( "Unsupported TaskUpdateEvent type: " + updateEvent.getClass()); @@ -163,9 +184,13 @@ public static Event messageToFailedEvent(Message message, InvocationContext invo /** Converts an A2A message back to ADK events. */ public static Event messageToEvent(Message message, InvocationContext invocationContext) { - return remoteAgentEventBuilder(invocationContext) - .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) - .build(); + return updateEventMetadata( + remoteAgentEventBuilder(invocationContext) + .content(fromModelParts(PartConverter.toGenaiParts(message.getParts()))) + .build(), + message.getMetadata(), + message.getTaskId(), + message.getContextId()); } /** @@ -228,7 +253,8 @@ public static Event taskToEvent(Task task, InvocationContext invocationContext) eventBuilder.longRunningToolIds(longRunningToolIds.build()); } eventBuilder.turnComplete(isFinal); - return eventBuilder.build(); + return updateEventMetadata( + eventBuilder.build(), task.getMetadata(), task.getId(), task.getContextId()); } private static ImmutableSet getLongRunningToolIds( @@ -241,9 +267,7 @@ private static ImmutableSet getLongRunningToolIds( return Optional.empty(); } Object isLongRunning = - dataPart - .getMetadata() - .get(PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY); + dataPart.getMetadata().get(A2AMetadataKey.IS_LONG_RUNNING.getType()); if (!Objects.equals(isLongRunning, true)) { return Optional.empty(); } @@ -256,6 +280,77 @@ private static ImmutableSet getLongRunningToolIds( .collect(toImmutableSet()); } + private static Event updateEventMetadata( + Event event, + @Nullable Map clientMetadata, + @Nullable String taskId, + @Nullable String contextId) { + if (taskId == null || contextId == null) { + logger.warn("Task ID or context ID is null, skipping metadata update."); + return event; + } + + if (clientMetadata == null) { + clientMetadata = ImmutableMap.of(); + } + Event.Builder eventBuilder = event.toBuilder(); + Object groundingMetadata = clientMetadata.get(A2AMetadataKey.GROUNDING_METADATA.getType()); + // if groundingMetadata is null, parseMetadata will return null as well. + eventBuilder.groundingMetadata(parseMetadata(groundingMetadata, GroundingMetadata.class)); + Object usageMetadata = clientMetadata.get(A2AMetadataKey.USAGE_METADATA.getType()); + // if usageMetadata is null, parseMetadata will return null as well. + eventBuilder.usageMetadata( + parseMetadata(usageMetadata, GenerateContentResponseUsageMetadata.class)); + + ImmutableList.Builder customMetadataList = ImmutableList.builder(); + customMetadataList + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.TASK_ID.getType()) + .stringValue(taskId) + .build()) + .add( + CustomMetadata.builder() + .key(AdkMetadataKey.CONTEXT_ID.getType()) + .stringValue(contextId) + .build()); + Object customMetadata = clientMetadata.get(A2AMetadataKey.CUSTOM_METADATA.getType()); + if (customMetadata != null) { + customMetadataList.addAll( + parseMetadata(customMetadata, new TypeReference>() {})); + } + eventBuilder.customMetadata(customMetadataList.build()); + + Object errorCode = clientMetadata.get(A2AMetadataKey.ERROR_CODE.getType()); + eventBuilder.errorCode(parseMetadata(errorCode, FinishReason.class)); + + return eventBuilder.build(); + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, Class type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type, e); + } + } + + private static @Nullable T parseMetadata(@Nullable Object metadata, TypeReference type) { + try { + if (metadata instanceof String jsonString) { + return objectMapper.readValue(jsonString, type); + } else { + return objectMapper.convertValue(metadata, type); + } + } catch (IllegalArgumentException | JsonProcessingException e) { + throw new IllegalArgumentException("Failed to parse metadata of type " + type.getType(), e); + } + } + private static Event emptyEvent(InvocationContext invocationContext) { Event.Builder builder = Event.builder() diff --git a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java index b1ffa248a..0609c3b04 100644 --- a/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/agent/RemoteA2AAgentTest.java @@ -113,6 +113,20 @@ public void setUp() { .build(); } + @Test + public void createAgent_streaming_false_returnsNonStreamingAgent() { + // With streaming false, the agent should not stream even if the AgentCard supports streaming. + RemoteA2AAgent agent = getAgentBuilder().streaming(false).build(); + assertThat(agent.isStreaming()).isFalse(); + } + + @Test + public void createAgent_streaming_true_returnsStreamingAgent() { + // With streaming true, the agent should support streaming if the AgentCard supports streaming. + RemoteA2AAgent agent = getAgentBuilder().streaming(true).build(); + assertThat(agent.isStreaming()).isTrue(); + } + @Test public void runAsync_aggregatesPartialEvents() { RemoteA2AAgent agent = createAgent(); @@ -763,7 +777,7 @@ private RemoteA2AAgent.Builder getAgentBuilder() { } private RemoteA2AAgent createAgent() { - return getAgentBuilder().build(); + return getAgentBuilder().streaming(true).build(); } @SuppressWarnings("unchecked") // cast for Mockito diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java index d93466dd2..4a0828c43 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/PartConverterTest.java @@ -8,9 +8,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Blob; +import com.google.genai.types.CodeExecutionResult; +import com.google.genai.types.ExecutableCode; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Language; +import com.google.genai.types.Outcome; import com.google.genai.types.Part; import io.a2a.spec.DataPart; import io.a2a.spec.FilePart; @@ -86,8 +90,7 @@ public void toGenaiPart_withDataPartFunctionCall_returnsGenaiFunctionCallPart() new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType())); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType())); Part result = PartConverter.toGenaiPart(dataPart); @@ -121,7 +124,7 @@ public void toGenaiPart_withDataPartFunctionResponse_returnsGenaiFunctionRespons new DataPart( data, ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType())); Part result = PartConverter.toGenaiPart(dataPart); @@ -188,7 +191,7 @@ public void fromGenaiPart_withTextPart_returnsTextPart() { assertThat(((TextPart) result).getText()).isEqualTo("text"); assertThat(((TextPart) result).getMetadata()).containsEntry("thought", true); assertThat(((TextPart) result).getMetadata()) - .containsEntry(PartConverter.A2A_DATA_PART_METADATA_IS_PARTIAL_KEY, true); + .containsEntry(A2AMetadataKey.PARTIAL.getType(), true); } @Test @@ -226,6 +229,39 @@ public void fromGenaiPart_withInlineDataPart_returnsFilePartWithBytes() { assertThat(Base64.getDecoder().decode(fileWithBytes.bytes())).isEqualTo(bytes); } + @Test + public void fromGenaiPart_dataPart_executableCode_returnsDataPart() { + ExecutableCode executableCode = + ExecutableCode.builder().code("print('hello')").language(new Language("python")).build(); + Part part = Part.builder().executableCode(executableCode).build(); + io.a2a.spec.Part result = PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("code")).isEqualTo("print('hello')"); + assertThat(dataPart.getData().get("language")).isEqualTo("python"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("executable_code"); + } + + @Test + public void fromGenaiPart_dataPart_codeExecutionResult_returnsDataPart() { + CodeExecutionResult codeExecutionResult = + CodeExecutionResult.builder() + .outcome(new Outcome("OUTCOME_OK")) + .output("print('hello')") + .build(); + Part part = Part.builder().codeExecutionResult(codeExecutionResult).build(); + io.a2a.spec.Part result = PartConverter.fromGenaiPart(part, false); + + assertThat(result).isInstanceOf(DataPart.class); + DataPart dataPart = (DataPart) result; + assertThat(dataPart.getData().get("outcome")).isEqualTo("OUTCOME_OK"); + assertThat(dataPart.getData().get("output")).isEqualTo("print('hello')"); + assertThat(dataPart.getMetadata().get(A2AMetadataKey.TYPE.getType())) + .isEqualTo("code_execution_result"); + } + @Test public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { Part part = @@ -255,8 +291,7 @@ public void fromGenaiPart_withFunctionCallPart_returnsDataPart() { true); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_CALL.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_CALL.getType()); } @Test @@ -275,8 +310,7 @@ public void fromGenaiPart_withFunctionResponsePart_returnsDataPart() { .containsExactly("name", "func", "id", "1", "response", ImmutableMap.of()); assertThat(dataPart.getMetadata()) .containsEntry( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, - A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); + A2AMetadataKey.TYPE.getType(), A2ADataPartMetadataType.FUNCTION_RESPONSE.getType()); } @Test diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java index d84dc42cd..b61b00e1a 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java @@ -1,6 +1,8 @@ package com.google.adk.a2a.converters; import static com.google.common.truth.Truth.assertThat; +import static java.util.stream.Collectors.joining; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; @@ -13,6 +15,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.GroundingMetadata; import io.a2a.client.MessageEvent; import io.a2a.client.TaskUpdateEvent; import io.a2a.spec.Artifact; @@ -136,6 +142,74 @@ public void taskToEvent_withStatusMessage_returnsEvent() { assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); } + @Test + public void taskToEvent_withGroundingMetadata_returnsEvent() { + GroundingMetadata groundingMetadata = + GroundingMetadata.builder().webSearchQueries("test-query").build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of( + A2AMetadataKey.GROUNDING_METADATA.getType(), groundingMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.groundingMetadata()).hasValue(groundingMetadata); + } + + @Test + public void taskToEvent_withCustomMetadata_returnsEvent() { + ImmutableList customMetadataList = + ImmutableList.of( + CustomMetadata.builder().key("test-key").stringValue("test-value").build()); + String customMetadataJson = + customMetadataList.stream().map(CustomMetadata::toJson).collect(joining(",", "[", "]")); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.CUSTOM_METADATA.getType(), customMetadataJson)) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.content().get().parts().get().get(0).text()).hasValue("Status message"); + assertThat(event.customMetadata().get()) + .containsExactly( + CustomMetadata.builder().key("a2a:task_id").stringValue("task-1").build(), + CustomMetadata.builder().key("a2a:context_id").stringValue("context-1").build(), + CustomMetadata.builder().key("test-key").stringValue("test-value").build()) + .inOrder(); + } + + @Test + public void messageToEvent_withMissingTaskId_returnsEvent() { + Message a2aMessage = + new Message.Builder() + .messageId("msg-1") + .role(Message.Role.USER) + .taskId("task-1") + .parts(ImmutableList.of(new TextPart("test-message"))) + .build(); + Event event = ResponseConverter.messageToEvent(a2aMessage, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.customMetadata()).isEmpty(); + } + @Test public void taskToEvent_withNoMessage_returnsEmptyEvent() { TaskStatus status = new TaskStatus(TaskState.WORKING, null, null); @@ -152,18 +226,18 @@ public void taskToEvent_withInputRequired_parsesLongRunningToolIds() { ImmutableMap.of("name", "myTool", "id", "call_123", "args", ImmutableMap.of()); ImmutableMap metadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart dataPart = new DataPart(data, metadata); ImmutableMap statusData = ImmutableMap.of("name", "messageTools", "id", "msg_123", "args", ImmutableMap.of()); ImmutableMap statusMetadata = ImmutableMap.of( - PartConverter.A2A_DATA_PART_METADATA_TYPE_KEY, + A2AMetadataKey.TYPE.getType(), "function_call", - PartConverter.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, + A2AMetadataKey.IS_LONG_RUNNING.getType(), true); DataPart statusDataPart = new DataPart(statusData, statusMetadata); Message statusMessage = @@ -361,6 +435,99 @@ public void clientEventToEvent_withFailedTaskStatusUpdateEvent_returnsErrorEvent assertThat(resultEvent.turnComplete()).hasValue(true); } + @Test + public void taskToEvent_withInvalidMetadata_throwsException() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.GROUNDING_METADATA.getType(), "{ invalid json ]")) + .build(); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ResponseConverter.taskToEvent(task, invocationContext)); + assertThat(exception).hasMessageThat().contains("Failed to parse metadata"); + assertThat(exception).hasMessageThat().contains("GroundingMetadata"); + } + + @Test + public void taskToEvent_withErrorCode_returnsEvent() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata(ImmutableMap.of(A2AMetadataKey.ERROR_CODE.getType(), "\"STOP\"")) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.errorCode()).hasValue(new FinishReason(FinishReason.Known.STOP)); + } + + @Test + public void taskToEvent_withUsageMetadata_returnsEvent() { + GenerateContentResponseUsageMetadata usageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.WORKING, statusMessage, null); + Task task = + testTask() + .status(status) + .artifacts(null) + .metadata( + ImmutableMap.of(A2AMetadataKey.USAGE_METADATA.getType(), usageMetadata.toJson())) + .build(); + Event event = ResponseConverter.taskToEvent(task, invocationContext); + assertThat(event).isNotNull(); + assertThat(event.usageMetadata()).hasValue(usageMetadata); + } + + @Test + public void clientEventToEvent_withTaskArtifactUpdateEventAndPartialTrue_returnsEmpty() { + io.a2a.spec.Part a2aPart = new TextPart("Artifact content"); + Artifact artifact = + new Artifact.Builder().artifactId("artifact-1").parts(ImmutableList.of(a2aPart)).build(); + Task task = + testTask() + .status(new TaskStatus(TaskState.COMPLETED)) + .artifacts(ImmutableList.of(artifact)) + .build(); + TaskArtifactUpdateEvent updateEvent = + new TaskArtifactUpdateEvent.Builder() + .lastChunk(true) + .metadata(ImmutableMap.of(A2AMetadataKey.PARTIAL.getType(), true)) + .contextId("context-1") + .artifact(artifact) + .taskId("task-id-1") + .build(); + TaskUpdateEvent event = new TaskUpdateEvent(task, updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isEmpty(); + } + private static final class TestAgent extends BaseAgent { TestAgent() { super("test_agent", "test", ImmutableList.of(), null, null); diff --git a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java index ffcbcf8f5..43ca6889f 100644 --- a/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java +++ b/contrib/firestore-session-service/src/test/java/com/google/adk/sessions/FirestoreSessionServiceTest.java @@ -47,15 +47,11 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.observers.TestObserver; import java.time.Instant; -import java.util.AbstractMap; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -530,44 +526,6 @@ void appendAndGet_withAllPartTypes_serializesAndDeserializesCorrectly() { }); } - /** - * A wrapper class that implements ConcurrentMap but delegates to a HashMap. This is a workaround - * to allow putting null values, which ConcurrentHashMap forbids, for testing state removal logic. - */ - private static class HashMapAsConcurrentMap extends AbstractMap - implements ConcurrentMap { - private final HashMap map; - - public HashMapAsConcurrentMap(Map map) { - this.map = new HashMap<>(map); - } - - @Override - public Set> entrySet() { - return map.entrySet(); - } - - @Override - public V putIfAbsent(K key, V value) { - return map.putIfAbsent(key, value); - } - - @Override - public boolean remove(Object key, Object value) { - return map.remove(key, value); - } - - @Override - public boolean replace(K key, V oldValue, V newValue) { - return map.replace(key, oldValue, newValue); - } - - @Override - public V replace(K key, V value) { - return map.replace(key, value); - } - } - /** Tests that appendEvent with only app state deltas updates the correct stores. */ @Test void appendEvent_withAppOnlyStateDeltas_updatesCorrectStores() { @@ -662,63 +620,6 @@ void appendEvent_withUserOnlyStateDeltas_updatesCorrectStores() { verify(mockSessionDocRef, never()).update(eq(Constants.KEY_STATE), any()); } - /** - * Tests that appendEvent with all types of state deltas updates the correct stores and session - * state. - */ - @Test - void appendEvent_withAllStateDeltas_updatesCorrectStores() { - // Arrange - Session session = - Session.builder(SESSION_ID) - .appName(APP_NAME) - .userId(USER_ID) - .state(new ConcurrentHashMap<>()) // The session state itself must be concurrent - .build(); - session.state().put("keyToRemove", "someValue"); - - Map stateDeltaMap = new HashMap<>(); - stateDeltaMap.put("sessionKey", "sessionValue"); - stateDeltaMap.put("_app_appKey", "appValue"); - stateDeltaMap.put("_user_userKey", "userValue"); - stateDeltaMap.put("keyToRemove", null); - - // Use the wrapper to satisfy the ConcurrentMap interface for the builder - EventActions actions = - EventActions.builder().stateDelta(new HashMapAsConcurrentMap<>(stateDeltaMap)).build(); - - Event event = - Event.builder() - .author("model") - .content(Content.builder().parts(List.of(Part.fromText("..."))).build()) - .actions(actions) - .build(); - - when(mockSessionsCollection.document(SESSION_ID)).thenReturn(mockSessionDocRef); - when(mockEventsCollection.document()).thenReturn(mockEventDocRef); - when(mockEventDocRef.getId()).thenReturn(EVENT_ID); - // THIS IS THE MISSING MOCK: Stub the call to get the document by its specific ID. - when(mockEventsCollection.document(EVENT_ID)).thenReturn(mockEventDocRef); - // Add the missing mock for the final session update call - when(mockSessionDocRef.update(anyMap())) - .thenReturn(ApiFutures.immediateFuture(mockWriteResult)); - - // Act - sessionService.appendEvent(session, event).test().assertComplete(); - - // Assert - assertThat(session.state()).containsEntry("sessionKey", "sessionValue"); - assertThat(session.state()).doesNotContainKey("keyToRemove"); - - ArgumentCaptor> appStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockAppStateDocRef).set(appStateCaptor.capture(), any(SetOptions.class)); - assertThat(appStateCaptor.getValue()).containsEntry("appKey", "appValue"); - - ArgumentCaptor> userStateCaptor = ArgumentCaptor.forClass(Map.class); - verify(mockUserStateUserDocRef).set(userStateCaptor.capture(), any(SetOptions.class)); - assertThat(userStateCaptor.getValue()).containsEntry("userKey", "userValue"); - } - /** Tests that getSession skips malformed events and returns only the well-formed ones. */ @Test @SuppressWarnings("unchecked") diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index c2326fa0a..3dd2d1132 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -58,11 +58,6 @@ google-adk ${project.version} - - com.google.adk - google-adk-dev - ${project.version} - com.google.genai google-genai diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 80c25610d..3ccb1e029 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; import com.google.adk.models.LlmRequest; @@ -428,8 +429,24 @@ private List toToolSpecifications(LlmRequest llmRequest) { baseTool -> { if (baseTool.declaration().isPresent()) { FunctionDeclaration functionDeclaration = baseTool.declaration().get(); - if (functionDeclaration.parameters().isPresent()) { - Schema schema = functionDeclaration.parameters().get(); + Schema schema = null; + if (functionDeclaration.parametersJsonSchema().isPresent()) { + Object jsonSchemaObj = functionDeclaration.parametersJsonSchema().get(); + try { + if (jsonSchemaObj instanceof Schema) { + schema = (Schema) jsonSchemaObj; + } else { + schema = JsonBaseModel.getMapper().convertValue(jsonSchemaObj, Schema.class); + } + } catch (Exception e) { + throw new IllegalStateException( + "Failed to convert parametersJsonSchema to Schema: " + e.getMessage(), e); + } + } else if (functionDeclaration.parameters().isPresent()) { + schema = functionDeclaration.parameters().get(); + } + + if (schema != null) { ToolSpecification toolSpecification = ToolSpecification.builder() .name(baseTool.name()) @@ -438,11 +455,9 @@ private List toToolSpecifications(LlmRequest llmRequest) { .build(); toolSpecifications.add(toolSpecification); } else { - // TODO exception or something else? throw new IllegalStateException("Tool lacking parameters: " + baseTool); } } else { - // TODO exception or something else? throw new IllegalStateException("Tool lacking declaration: " + baseTool); } }); diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java index 3fafb046d..191e48017 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -44,7 +44,7 @@ class LangChain4jIntegrationTest { - public static final String CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String CLAUDE_4_6_SONNET = "claude-sonnet-4-6"; public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; public static final String GPT_4_O_MINI = "gpt-4o-mini"; @@ -55,14 +55,14 @@ void testSimpleAgent() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); LlmAgent agent = LlmAgent.builder() .name("science-app") .description("Science teacher agent") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) .instruction( """ You are a helpful science teacher that explains science concepts @@ -91,14 +91,14 @@ void testSingleAgentWithTools() { AnthropicChatModel claudeModel = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); BaseAgent agent = LlmAgent.builder() .name("friendly-weather-app") .description("Friend agent that knows about the weather") - .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .model(new LangChain4j(claudeModel, CLAUDE_4_6_SONNET)) .instruction( """ You are a friendly assistant. @@ -155,7 +155,7 @@ void testSingleAgentWithTools() { List partsThree = contentThree.parts().get(); assertEquals(1, partsThree.size()); assertTrue(partsThree.get(0).text().isPresent()); - assertTrue(partsThree.get(0).text().get().contains("beautiful")); + assertTrue(partsThree.get(0).text().get().contains("sunny")); } @Test @@ -352,10 +352,10 @@ void testSimpleStreamingResponse() { AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) - .modelName(CLAUDE_3_7_SONNET_20250219) + .modelName(CLAUDE_4_6_SONNET) .build(); - LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); + LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_4_6_SONNET); // when Flowable responses = diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 428a5660c..076bb79a3 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -688,4 +688,128 @@ void testGenerateContentWithStructuredResponseJsonSchema() { final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0); assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe"); } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema") + void testGenerateContentWithMcpToolParametersJsonSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // MCP tools use parametersJsonSchema() instead of parameters() + // Create a JSON schema object (Map representation) + final Map jsonSchemaMap = + Map.of( + "type", + "object", + "properties", + Map.of("city", Map.of("type", "string", "description", "City name")), + "required", + List.of("city")); + + // Mock parametersJsonSchema() to return the JSON schema object + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(jsonSchemaMap)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } + + @Test + @DisplayName("Should handle MCP tools with parametersJsonSchema when it's already a Schema") + void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() { + // Given + // Create a mock BaseTool for MCP tool + final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class); + when(mcpTool.name()).thenReturn("mcpTool"); + when(mcpTool.description()).thenReturn("An MCP tool"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // Create a Schema object directly (when parametersJsonSchema returns Schema) + final Schema cityPropertySchema = + Schema.builder().type("STRING").description("City name").build(); + + final Schema objectSchema = + Schema.builder() + .type("OBJECT") + .properties(Map.of("city", cityPropertySchema)) + .required(List.of("city")) + .build(); + + // Mock parametersJsonSchema() to return Schema directly + when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(objectSchema)); + when(functionDeclaration.parameters()).thenReturn(Optional.empty()); + + // Create a LlmRequest with the MCP tool + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool")))) + .tools(Map.of("mcpTool", mcpTool)) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("Tool executed successfully"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Tool executed successfully"); + + // Verify the request was built correctly with the tool specification + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications were created from parametersJsonSchema + assertThat(capturedRequest.toolSpecifications()).isNotEmpty(); + assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool"); + assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool"); + } } diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index b24fa4b63..08d237ab5 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -29,7 +29,7 @@ Spring AI integration for the Agent Development Kit. - 2.0.0-M2 + 2.0.0-M3 1.21.3 diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java index f21b07ae9..c59a94f82 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/integrations/AnthropicApiIntegrationTest.java @@ -34,7 +34,6 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.anthropic.AnthropicChatModel; import org.springframework.ai.anthropic.AnthropicChatOptions; -import org.springframework.ai.anthropic.api.AnthropicApi; /** * Integration tests with real Anthropic API. @@ -53,10 +52,14 @@ void testSimpleAgentWithRealAnthropicApi() throws InterruptedException { Thread.sleep(2000); // Create Anthropic model using Spring AI's builder pattern - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); // Wrap with SpringAI SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -92,10 +95,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { // Add delay to avoid rapid requests Thread.sleep(2000); - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -134,10 +141,14 @@ void testStreamingWithRealAnthropicApi() throws InterruptedException { @Test void testAgentWithToolsAndRealApi() { - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); LlmAgent agent = LlmAgent.builder() @@ -175,10 +186,13 @@ void testAgentWithToolsAndRealApi() { @Test void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { // Test both non-streaming and streaming with the same model to compare behavior - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); @@ -271,13 +285,13 @@ void testDirectComparisonNonStreamingVsStreaming() throws InterruptedException { @Test void testConfigurationOptions() { // Test with custom configuration - AnthropicChatOptions options = - AnthropicChatOptions.builder().model(CLAUDE_MODEL).temperature(0.7).maxTokens(100).build(); - - AnthropicApi anthropicApi = - AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - AnthropicChatModel anthropicModel = - AnthropicChatModel.builder().anthropicApi(anthropicApi).defaultOptions(options).build(); + var options = + AnthropicChatOptions.builder() + .model(CLAUDE_MODEL) + .maxTokens(1024) + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .build(); + AnthropicChatModel anthropicModel = AnthropicChatModel.builder().options(options).build(); SpringAI springAI = new SpringAI(anthropicModel, CLAUDE_MODEL); diff --git a/core/pom.xml b/core/pom.xml index 8c3c2069c..02c75f88b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -197,6 +197,26 @@ opentelemetry-sdk-testing test + + com.google.cloud + google-cloud-bigquery + 2.40.0 + + + org.apache.arrow + arrow-vector + 17.0.0 + + + org.apache.arrow + arrow-memory-core + 17.0.0 + + + org.apache.arrow + arrow-memory-netty + 17.0.0 + @@ -209,6 +229,16 @@ maven-compiler-plugin + + maven-jar-plugin + + + + test-jar + + + + maven-surefire-plugin diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 91ce13a87..365f4f8c1 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -90,16 +90,6 @@ public Builder toBuilder() { return new Builder(this); } - /** - * Creates a shallow copy of the given {@link InvocationContext}. - * - * @deprecated Use {@code other.toBuilder().build()} instead. - */ - @Deprecated(forRemoval = true) - public static InvocationContext copyOf(InvocationContext other) { - return other.toBuilder().build(); - } - /** Returns the session service for managing session state. */ public BaseSessionService sessionService() { return sessionService; @@ -156,16 +146,6 @@ public BaseAgent agent() { return agent; } - /** - * Sets the [agent] being invoked. This is useful when delegating to a sub-agent. - * - * @deprecated Use {@link #toBuilder()} and {@link Builder#agent(BaseAgent)} instead. - */ - @Deprecated(forRemoval = true) - public void agent(BaseAgent agent) { - this.agent = agent; - } - /** Returns the session associated with this invocation. */ public Session session() { return session; diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index d326d8154..077068283 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -45,8 +45,6 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.codeexecutors.BaseCodeExecutor; import com.google.adk.events.Event; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; import com.google.adk.flows.llmflows.AutoFlow; import com.google.adk.flows.llmflows.BaseLlmFlow; import com.google.adk.flows.llmflows.SingleFlow; @@ -97,8 +95,6 @@ public enum IncludeContents { private final List toolsUnion; private final ImmutableList toolsets; private final Optional generateContentConfig; - // TODO: Remove exampleProvider field - examples should only be provided via ExampleTool - private final Optional exampleProvider; private final IncludeContents includeContents; private final boolean planning; @@ -132,7 +128,6 @@ protected LlmAgent(Builder builder) { this.globalInstruction = requireNonNullElse(builder.globalInstruction, new Instruction.Static("")); this.generateContentConfig = Optional.ofNullable(builder.generateContentConfig); - this.exampleProvider = Optional.ofNullable(builder.exampleProvider); this.includeContents = requireNonNullElse(builder.includeContents, IncludeContents.DEFAULT); this.planning = builder.planning != null && builder.planning; this.maxSteps = Optional.ofNullable(builder.maxSteps); @@ -180,7 +175,6 @@ public static class Builder extends BaseAgent.Builder { private Instruction globalInstruction; private ImmutableList toolsUnion; private GenerateContentConfig generateContentConfig; - private BaseExampleProvider exampleProvider; private IncludeContents includeContents; private Boolean planning; private Integer maxSteps; @@ -253,26 +247,6 @@ public Builder generateContentConfig(GenerateContentConfig generateContentConfig return this; } - // TODO: Remove these example provider methods and only use ExampleTool for providing examples. - // Direct example methods should be deprecated in favor of using ExampleTool consistently. - @CanIgnoreReturnValue - public Builder exampleProvider(BaseExampleProvider exampleProvider) { - this.exampleProvider = exampleProvider; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(List examples) { - this.exampleProvider = (unused) -> examples; - return this; - } - - @CanIgnoreReturnValue - public Builder exampleProvider(Example... examples) { - this.exampleProvider = (unused) -> ImmutableList.copyOf(examples); - return this; - } - @CanIgnoreReturnValue public Builder includeContents(IncludeContents includeContents) { this.includeContents = includeContents; @@ -620,32 +594,6 @@ protected void validate() { this.disallowTransferToParent != null && this.disallowTransferToParent; this.disallowTransferToPeers = this.disallowTransferToPeers != null && this.disallowTransferToPeers; - - if (this.outputSchema != null) { - if (!this.disallowTransferToParent || !this.disallowTransferToPeers) { - logger.warn( - "Invalid config for agent {}: outputSchema cannot co-exist with agent transfer" - + " configurations. Setting disallowTransferToParent=true and" - + " disallowTransferToPeers=true.", - this.name); - this.disallowTransferToParent = true; - this.disallowTransferToPeers = true; - } - - if (this.subAgents != null && !this.subAgents.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, subAgents must be empty to disable agent" - + " transfer."); - } - if (this.toolsUnion != null && !this.toolsUnion.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, tools must be empty."); - } - } } @Override @@ -812,11 +760,6 @@ public Optional generateContentConfig() { return generateContentConfig; } - // TODO: Remove this getter - examples should only be provided via ExampleTool - public Optional exampleProvider() { - return exampleProvider; - } - public IncludeContents includeContents() { return includeContents; } @@ -829,7 +772,7 @@ public List toolsUnion() { return toolsUnion; } - public ImmutableList toolsets() { + public List toolsets() { return toolsets; } diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 1ca856b45..83fd60e54 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -157,9 +157,6 @@ public void setRequestedToolConfirmations( Map requestedToolConfirmations) { if (requestedToolConfirmations == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - } else if (requestedToolConfirmations instanceof ConcurrentMap) { - this.requestedToolConfirmations = - (ConcurrentMap) requestedToolConfirmations; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(requestedToolConfirmations); } @@ -287,15 +284,23 @@ public Builder skipSummarization(boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") - public Builder stateDelta(ConcurrentMap value) { - this.stateDelta = value; + public Builder stateDelta(@Nullable Map value) { + if (value == null) { + this.stateDelta = new ConcurrentHashMap<>(); + } else { + this.stateDelta = new ConcurrentHashMap<>(value); + } return this; } @CanIgnoreReturnValue @JsonProperty("artifactDelta") - public Builder artifactDelta(Map value) { - this.artifactDelta = new ConcurrentHashMap<>(value); + public Builder artifactDelta(@Nullable Map value) { + if (value == null) { + this.artifactDelta = new ConcurrentHashMap<>(); + } else { + this.artifactDelta = new ConcurrentHashMap<>(value); + } return this; } @@ -333,10 +338,6 @@ public Builder requestedAuthConfigs( public Builder requestedToolConfirmations(@Nullable Map value) { if (value == null) { this.requestedToolConfirmations = new ConcurrentHashMap<>(); - return this; - } - if (value instanceof ConcurrentMap) { - this.requestedToolConfirmations = (ConcurrentMap) value; } else { this.requestedToolConfirmations = new ConcurrentHashMap<>(value); } diff --git a/core/src/main/java/com/google/adk/examples/ExampleUtils.java b/core/src/main/java/com/google/adk/examples/ExampleUtils.java index 9cce535dc..2f3927ece 100644 --- a/core/src/main/java/com/google/adk/examples/ExampleUtils.java +++ b/core/src/main/java/com/google/adk/examples/ExampleUtils.java @@ -64,6 +64,9 @@ public final class ExampleUtils { * @return string representation of the examples block. */ private static String convertExamplesToText(List examples) { + if (examples.isEmpty()) { + return ""; + } StringBuilder examplesStr = new StringBuilder(); // super header diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index ab5f6567a..8fabc978d 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -91,8 +91,9 @@ public BaseLlmFlow( * RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the * events generated by them. */ - protected Flowable preprocess( + private Flowable preprocess( InvocationContext context, AtomicReference llmRequestRef) { + Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -114,6 +115,7 @@ protected Flowable preprocess( .concatMap( processor -> Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + .compose(Tracing.withContext(currentContext)) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? result.events() : ImmutableList.of())); @@ -128,7 +130,8 @@ protected Flowable postprocess( InvocationContext context, Event baseEventForLlmResponse, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + Context parentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); @@ -144,15 +147,16 @@ protected Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } - Context parentContext = Context.current(); - return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - try (Scope scope = parentContext.makeCurrent()) { - return buildPostprocessingEvents( - updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); - } - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, + eventIterables, + context, + baseEventForLlmResponse, + llmRequest, + parentContext) + .compose(Tracing.withContext(parentContext))); } /** @@ -163,54 +167,80 @@ protected Flowable postprocess( * @param eventForCallbackUsage An Event object primarily for providing context (like actions) to * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ - private Flowable callLlm( + private Flowable callLlm( Context spanContext, InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) .toFlowable() + .concatMap( + llmResp -> + postprocess( + context, + eventForCallbackUsage, + llmRequestBuilder.build(), + llmResp, + spanContext)) .switchIfEmpty( Flowable.defer( () -> { + LlmAgent agent = (LlmAgent) context.agent(); BaseLlm llm = agent.resolvedModel().model().isPresent() ? agent.resolvedModel().model().get() : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, llmRequestBuilder, eventForCallbackUsage, exception) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .compose( - Tracing.trace("call_llm") - .setParent(spanContext) - .onSuccess( - (span, llmResp) -> - Tracing.traceCallLlm( - span, + LlmRequest finalLlmRequest = llmRequestBuilder.build(); + + Span span = + Tracing.getTracer() + .spanBuilder("call_llm") + .setParent(spanContext) + .startSpan(); + Context callLlmContext = spanContext.with(span); + + Flowable flowable = + llm.generateContent( + finalLlmRequest, + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, + llmRequestBuilder, + eventForCallbackUsage, + exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnError( + error -> { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()) + .flatMap( + llmResp -> + postprocess( context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp))) - .concatMap( - llmResp -> - handleAfterModelCallback(context, llmResp, eventForCallbackUsage) - .toFlowable()); + eventForCallbackUsage, + finalLlmRequest, + llmResp, + callLlmContext) + .doOnSubscribe( + s -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + finalLlmRequest, + llmResp))); + + return Tracing.traceFlowable(callLlmContext, span, () -> flowable); })); } @@ -222,6 +252,7 @@ private Flowable callLlm( */ private Maybe handleBeforeModelCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -240,7 +271,11 @@ private Maybe handleBeforeModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequestBuilder) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); @@ -257,6 +292,7 @@ private Maybe handleOnModelErrorCallback( LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, Throwable throwable) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -277,7 +313,11 @@ private Maybe handleOnModelErrorCallback( () -> { LlmRequest llmRequest = llmRequestBuilder.build(); return Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmRequest, ex) + .compose(Tracing.withContext(currentContext))) .firstElement(); }); @@ -292,6 +332,7 @@ private Maybe handleOnModelErrorCallback( */ private Single handleAfterModelCallback( InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + Context currentContext = Context.current(); Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); @@ -310,7 +351,11 @@ private Single handleAfterModelCallback( Maybe.defer( () -> Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .concatMapMaybe( + callback -> + callback + .call(callbackContext, llmResponse) + .compose(Tracing.withContext(currentContext))) .firstElement()); return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); @@ -330,7 +375,6 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex return Flowable.defer( () -> { - Context currentContext = Context.current(); return preprocess(context, llmRequestRef) .concatWith( Flowable.defer( @@ -362,23 +406,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex context, llmRequestAfterPreprocess, mutableEventTemplate) - .concatMap( - llmResponse -> { - try (Scope postScope = currentContext.makeCurrent()) { - return postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - String newId = Event.generateEventId(); - logger.debug( - "Resetting event ID from {} to {}", oldId, newId); - mutableEventTemplate.setId(newId); - }); - } + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + String newId = Event.generateEventId(); + logger.debug("Resetting event ID from {} to {}", oldId, newId); + mutableEventTemplate.setId(newId); }) .concatMap( event -> { @@ -397,7 +430,12 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex "Agent not found: " + agentToTransfer))); } return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.get().runAsync(context))); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runAsync(context); + } + })); } return postProcessedEvents; }); @@ -455,6 +493,8 @@ private Flowable run( public Flowable runLive(InvocationContext invocationContext) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); + // Capture agent context at assembly time to use as parent for agent transfer at subscription + // time. See Flowable.defer() usages below. Context spanContext = Context.current(); return preprocessEvents.concatWith( @@ -545,6 +585,10 @@ public void onError(Throwable e) { .author(invocationContext.agent().name()) .branch(invocationContext.branch().orElse(null)); + Span span = + Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan(); + Context callLlmContext = spanContext.with(span); + Flowable receiveFlow = connection .receive() @@ -556,7 +600,8 @@ public void onError(Throwable e) { invocationContext, baseEventForThisLlmResponse, llmRequestAfterPreprocess, - llmResponse); + llmResponse, + callLlmContext); }) .flatMap( event -> { @@ -570,7 +615,12 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - nextAgent.get().runLive(invocationContext); + Flowable.defer( + () -> { + try (Scope s = spanContext.makeCurrent()) { + return nextAgent.get().runLive(invocationContext); + } + }); events = Flowable.concat(events, nextAgentEvents); } return events; @@ -592,7 +642,12 @@ public void onError(Throwable e) { } }); - return receiveFlow.takeWhile(event -> !event.actions().endInvocation().orElse(false)); + return Tracing.traceFlowable( + callLlmContext, + span, + () -> + receiveFlow.takeWhile( + event -> !event.actions().endInvocation().orElse(false))); })); } @@ -608,7 +663,8 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest) { + LlmRequest llmRequest, + Context parentContext) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -624,21 +680,23 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); + Flowable functionEvents; + try (Scope scope = parentContext.makeCurrent()) { + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + } return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 6ebd39a9c..840a370c6 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -64,6 +64,11 @@ public Single processRequest( modelName = ""; } + ImmutableList sessionEvents; + synchronized (context.session().events()) { + sessionEvents = ImmutableList.copyOf(context.session().events()); + } + if (llmAgent.includeContents() == LlmAgent.IncludeContents.NONE) { return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -71,7 +76,7 @@ public Single processRequest( .contents( getCurrentTurnContents( context.branch().orElse(null), - context.session().events(), + sessionEvents, context.agent().name(), modelName)) .build(), @@ -80,10 +85,7 @@ public Single processRequest( ImmutableList contents = getContents( - context.branch().orElse(null), - context.session().events(), - context.agent().name(), - modelName); + context.branch().orElse(null), sessionEvents, context.agent().name(), modelName); return Single.just( RequestProcessor.RequestProcessingResult.create( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java b/core/src/main/java/com/google/adk/flows/llmflows/Examples.java deleted file mode 100644 index d9cee5fa0..000000000 --- a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2025 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.examples.ExampleUtils; -import com.google.adk.models.LlmRequest; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Single; - -/** {@link RequestProcessor} that populates examples in LLM request. */ -public final class Examples implements RequestProcessor { - - public Examples() {} - - @Override - public Single processRequest( - InvocationContext context, LlmRequest request) { - if (!(context.agent() instanceof LlmAgent)) { - throw new IllegalArgumentException("Agent in InvocationContext is not an instance of Agent."); - } - LlmAgent agent = (LlmAgent) context.agent(); - LlmRequest.Builder builder = request.toBuilder(); - - String query = - context - .userContent() - .flatMap(Content::parts) - .filter(parts -> !parts.isEmpty()) - .map(parts -> parts.get(0).text().orElse("")) - .orElse(""); - agent - .exampleProvider() - .ifPresent( - exampleProvider -> - builder.appendInstructions( - ImmutableList.of(ExampleUtils.buildExampleSi(exampleProvider, query)))); - return Single.just( - RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); - } -} diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index c1a996064..84a8141ea 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -42,7 +42,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; -import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Observable; @@ -163,7 +162,9 @@ public static Maybe handleFunctionCalls( } return functionResponseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -226,7 +227,9 @@ public static Maybe handleFunctionCallsLive( return responseEventsObservable .toList() - .flatMapMaybe( + .toMaybe() + .compose(Tracing.withContext(parentContext)) + .flatMap( events -> { if (events.isEmpty()) { return Maybe.empty(); @@ -243,47 +246,45 @@ private static Function> getFunctionCallMapper( Context parentContext) { return functionCall -> Maybe.defer( - () -> { - try (Scope scope = parentContext.makeCurrent()) { - BaseTool tool = tools.get(functionCall.name().get()); - ToolContext toolContext = - ToolContext.builder(invocationContext) - .functionCallId(functionCall.id().orElse("")) - .toolConfirmation( - functionCall.id().map(toolConfirmations::get).orElse(null)) - .build(); - - Map functionArgs = - functionCall.args().map(HashMap::new).orElse(new HashMap<>()); - - Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) - .switchIfEmpty( - Maybe.defer( - () -> { - try (Scope innerScope = parentContext.makeCurrent()) { - return isLive - ? processFunctionLive( - invocationContext, - tool, - toolContext, - functionCall, - functionArgs, - parentContext) - : callTool(tool, functionArgs, toolContext, parentContext); - } - })); - - return postProcessFunctionResult( - maybeFunctionResult, - invocationContext, - tool, - functionArgs, - toolContext, - isLive, - parentContext); - } - }); + () -> { + BaseTool tool = tools.get(functionCall.name().get()); + ToolContext toolContext = + ToolContext.builder(invocationContext) + .functionCallId(functionCall.id().orElse("")) + .toolConfirmation( + functionCall.id().map(toolConfirmations::get).orElse(null)) + .build(); + + Map functionArgs = + functionCall.args().map(HashMap::new).orElse(new HashMap<>()); + + Maybe> maybeFunctionResult = + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) + .switchIfEmpty( + Maybe.defer( + () -> + isLive + ? processFunctionLive( + invocationContext, + tool, + toolContext, + functionCall, + functionArgs, + parentContext) + : callTool( + tool, functionArgs, toolContext, parentContext)) + .compose(Tracing.withContext(parentContext))); + + return postProcessFunctionResult( + maybeFunctionResult, + invocationContext, + tool, + functionArgs, + toolContext, + isLive, + parentContext); + }) + .compose(Tracing.withContext(parentContext)); } /** @@ -410,34 +411,27 @@ private static Maybe postProcessFunctionResult( }) .flatMapMaybe( optionalInitialResult -> { - try (Scope scope = parentContext.makeCurrent()) { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, tool, functionArgs, toolContext, initialFunctionResult) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = - finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, - finalFunctionResult, - toolContext, - invocationContext)) - .compose( - Tracing.trace("tool_response [" + tool.name() + "]") - .setParent(parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); - } - }); + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + Event event = + buildResponseEvent( + tool, finalFunctionResult, toolContext, invocationContext); + Tracing.traceToolResponse(event.id(), event); + return Maybe.just(event); + }); + }) + .compose( + Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index e00c0093d..a93eb3cb4 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -29,6 +29,7 @@ import com.google.adk.events.Event; import com.google.adk.events.ToolConfirmation; import com.google.adk.models.LlmRequest; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -37,6 +38,7 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.Collection; @@ -216,10 +218,13 @@ private Maybe assembleEvent( .build()) .build(); - return toolsMapSingle.flatMapMaybe( - toolsMap -> - Functions.handleFunctionCalls( - invocationContext, functionCallEvent, toolsMap, toolConfirmations)); + Context parentContext = Context.current(); + return toolsMapSingle + .flatMapMaybe( + toolsMap -> + Functions.handleFunctionCalls( + invocationContext, functionCallEvent, toolsMap, toolConfirmations)) + .compose(Tracing.withContext(parentContext)); } private static Optional> maybeCreateToolConfirmationEntry( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index de45ba702..f56cc61c3 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -32,7 +32,6 @@ public class SingleFlow extends BaseLlmFlow { new Identity(), new Compaction(), new Contents(), - new Examples(), CodeExecution.requestProcessor); protected static final ImmutableList RESPONSE_PROCESSORS = diff --git a/core/src/main/java/com/google/adk/models/LlmRequest.java b/core/src/main/java/com/google/adk/models/LlmRequest.java index 1a45c3a95..760a7c1c6 100644 --- a/core/src/main/java/com/google/adk/models/LlmRequest.java +++ b/core/src/main/java/com/google/adk/models/LlmRequest.java @@ -150,7 +150,7 @@ private static Builder create() { abstract LiveConnectConfig liveConnectConfig(); @CanIgnoreReturnValue - abstract Builder tools(Map tools); + public abstract Builder tools(Map tools); abstract Map tools(); diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index e534da787..8d0366e9a 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,11 +21,13 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; +import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -126,6 +128,7 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -136,11 +139,13 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e))); + e))) + .compose(Tracing.withContext(capturedContext)); } @Override public Completable close() { + Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -149,7 +154,8 @@ public Completable close() { .doOnError( e -> logger.error( - "[{}] Error during callback 'close'", plugin.getName(), e))); + "[{}] Error during callback 'close'", plugin.getName(), e))) + .compose(Tracing.withContext(capturedContext)); } @Override @@ -227,7 +233,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - + Context capturedContext = Context.current(); return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -247,6 +253,7 @@ private Maybe runMaybeCallbacks( plugin.getName(), callbackName, e))) - .firstElement(); + .firstElement() + .compose(Tracing.withContext(capturedContext)); } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java new file mode 100644 index 000000000..ef826fb56 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -0,0 +1,270 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializationError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** Handles asynchronous batching and writing of events to BigQuery. */ +class BatchProcessor implements AutoCloseable { + private static final Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + + private final StreamWriter writer; + private final int batchSize; + private final Duration flushInterval; + @VisibleForTesting final BlockingQueue> queue; + private final ScheduledExecutorService executor; + @VisibleForTesting final BufferAllocator allocator; + final AtomicBoolean flushLock = new AtomicBoolean(false); + private final Schema arrowSchema; + private final VectorSchemaRoot root; + + public BatchProcessor( + StreamWriter writer, + int batchSize, + Duration flushInterval, + int queueMaxSize, + ScheduledExecutorService executor) { + this.writer = writer; + this.batchSize = batchSize; + this.flushInterval = flushInterval; + this.queue = new LinkedBlockingQueue<>(queueMaxSize); + this.executor = executor; + // It's safe to use Long.MAX_VALUE here as this is a top-level RootAllocator, + // and memory is properly managed via try-with-resources in the flush() method. + // The actual memory usage is bounded by the batchSize and individual row sizes. + this.allocator = new RootAllocator(Long.MAX_VALUE); + this.arrowSchema = BigQuerySchema.getArrowSchema(); + this.root = VectorSchemaRoot.create(arrowSchema, allocator); + } + + public void start() { + @SuppressWarnings("unused") + var unused = + executor.scheduleWithFixedDelay( + () -> { + try { + flush(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Error in background flush", e); + } + }, + flushInterval.toMillis(), + flushInterval.toMillis(), + MILLISECONDS); + } + + public void append(Map row) { + if (!queue.offer(row)) { + logger.warning("BigQuery event queue is full, dropping event."); + return; + } + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + + public void flush() { + // Acquire the flushLock. If another flush is already in progress, return immediately. + if (!flushLock.compareAndSet(false, true)) { + return; + } + try { + if (queue.isEmpty()) { + return; + } + List> batch = new ArrayList<>(); + queue.drainTo(batch, batchSize); + if (batch.isEmpty()) { + return; + } + try { + root.allocateNew(); + for (int i = 0; i < batch.size(); i++) { + Map row = batch.get(i); + for (Field field : arrowSchema.getFields()) { + populateVector(root.getVector(field.getName()), i, row.get(field.getName())); + } + } + root.setRowCount(batch.size()); + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + AppendRowsResponse result = writer.append(recordBatch).get(); + if (result.hasError()) { + logger.severe("BigQuery append error: " + result.getError().getMessage()); + for (var error : result.getRowErrorsList()) { + logger.severe( + String.format("Row error at index %d: %s", error.getIndex(), error.getMessage())); + } + } else { + logger.fine("Successfully wrote " + batch.size() + " rows to BigQuery."); + } + } catch (AppendSerializationError ase) { + logger.log( + Level.SEVERE, "Failed to write batch to BigQuery due to serialization error", ase); + Map rowIndexToErrorMessage = ase.getRowIndexToErrorMessage(); + if (rowIndexToErrorMessage != null && !rowIndexToErrorMessage.isEmpty()) { + logger.severe("Row-level errors found:"); + for (Map.Entry entry : rowIndexToErrorMessage.entrySet()) { + logger.severe( + String.format("Row error at index %d: %s", entry.getKey(), entry.getValue())); + } + } else { + logger.severe( + "AppendSerializationError occurred, but no row-specific errors were provided."); + } + } + } catch (Exception e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + logger.log(Level.SEVERE, "Failed to write batch to BigQuery", e); + } finally { + // Clear the vectors to release the memory. + root.clear(); + } + } finally { + flushLock.set(false); + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + } + + private void populateVector(FieldVector vector, int index, Object value) { + if (value == null || (value instanceof JsonNode jsonNode && jsonNode.isNull())) { + vector.setNull(index); + return; + } + if (vector instanceof VarCharVector varCharVector) { + String strValue = (value instanceof JsonNode jsonNode) ? jsonNode.asText() : value.toString(); + varCharVector.setSafe(index, strValue.getBytes(UTF_8)); + } else if (vector instanceof BigIntVector bigIntVector) { + long longValue; + if (value instanceof JsonNode jsonNode) { + longValue = jsonNode.asLong(); + } else if (value instanceof Number number) { + longValue = number.longValue(); + } else { + longValue = Long.parseLong(value.toString()); + } + bigIntVector.setSafe(index, longValue); + } else if (vector instanceof BitVector bitVector) { + boolean boolValue = + (value instanceof JsonNode jsonNode) ? jsonNode.asBoolean() : (Boolean) value; + bitVector.setSafe(index, boolValue ? 1 : 0); + } else if (vector instanceof TimeStampVector timeStampVector) { + if (value instanceof Instant instant) { + long micros = + SECONDS.toMicros(instant.getEpochSecond()) + NANOSECONDS.toMicros(instant.getNano()); + timeStampVector.setSafe(index, micros); + } else if (value instanceof JsonNode jsonNode) { + timeStampVector.setSafe(index, jsonNode.asLong()); + } else if (value instanceof Long longValue) { + timeStampVector.setSafe(index, longValue); + } + } else if (vector instanceof ListVector listVector) { + int start = listVector.startNewValue(index); + if (value instanceof ArrayNode arrayNode) { + for (int i = 0; i < arrayNode.size(); i++) { + populateVector(listVector.getDataVector(), start + i, arrayNode.get(i)); + } + listVector.endValue(index, arrayNode.size()); + } else if (value instanceof List) { + List list = (List) value; + for (int i = 0; i < list.size(); i++) { + populateVector(listVector.getDataVector(), start + i, list.get(i)); + } + listVector.endValue(index, list.size()); + } + } else if (vector instanceof StructVector structVector) { + structVector.setIndexDefined(index); + if (value instanceof ObjectNode objectNode) { + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, objectNode.get(child.getName())); + } + } else if (value instanceof Map) { + Map map = (Map) value; + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, map.get(child.getName())); + } + } + } + } + + @Override + public void close() { + if (this.queue != null && !this.queue.isEmpty()) { + this.flush(); + } + if (this.allocator != null) { + try { + this.allocator.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close Buffer allocator", e); + } + } + if (this.root != null) { + try { + this.root.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close VectorSchemaRoot", e); + } + } + if (this.writer != null) { + try { + this.writer.close(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to close BigQuery writer", e); + } + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java new file mode 100644 index 000000000..68b5fb5a1 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -0,0 +1,436 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryException; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Clustering; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardTableDefinition; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** + * BigQuery Agent Analytics Plugin for Java. + * + *

Logs agent execution events directly to a BigQuery table using the Storage Write API. + */ +public class BigQueryAgentAnalyticsPlugin extends BasePlugin { + private static final Logger logger = + Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + private static final ImmutableList DEFAULT_AUTH_SCOPES = + ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); + private static final AtomicLong threadCounter = new AtomicLong(0); + + private final BigQueryLoggerConfig config; + private final BigQuery bigQuery; + private final BigQueryWriteClient writeClient; + private final ScheduledExecutorService executor; + private final Object tableEnsuredLock = new Object(); + @VisibleForTesting final BatchProcessor batchProcessor; + private volatile boolean tableEnsured = false; + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { + this(config, createBigQuery(config)); + } + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) + throws IOException { + super("bigquery_agent_analytics"); + this.config = config; + this.bigQuery = bigQuery; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + this.writeClient = createWriteClient(config); + + if (config.enabled()) { + StreamWriter writer = createWriter(config); + this.batchProcessor = + new BatchProcessor( + writer, + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + this.batchProcessor.start(); + } else { + this.batchProcessor = null; + } + } + + private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { + BigQueryOptions.Builder builder = BigQueryOptions.newBuilder(); + if (config.credentials() != null) { + builder.setCredentials(config.credentials()); + } else { + builder.setCredentials( + GoogleCredentials.getApplicationDefault().createScoped(DEFAULT_AUTH_SCOPES)); + } + return builder.build().getService(); + } + + private void ensureTableExistsOnce() { + if (!tableEnsured) { + synchronized (tableEnsuredLock) { + if (!tableEnsured) { + // Table creation is expensive, so we only do it once per plugin instance. + tableEnsured = true; + ensureTableExists(bigQuery, config); + } + } + } + } + + private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { + TableId tableId = TableId.of(config.projectId(), config.datasetId(), config.tableName()); + Schema schema = BigQuerySchema.getEventsSchema(); + try { + Table table = bigQuery.getTable(tableId); + logger.info("BigQuery table: " + tableId); + if (table == null) { + logger.info("Creating BigQuery table: " + tableId); + StandardTableDefinition.Builder tableDefinitionBuilder = + StandardTableDefinition.newBuilder().setSchema(schema); + if (!config.clusteringFields().isEmpty()) { + tableDefinitionBuilder.setClustering( + Clustering.newBuilder().setFields(config.clusteringFields()).build()); + } + TableInfo tableInfo = TableInfo.newBuilder(tableId, tableDefinitionBuilder.build()).build(); + bigQuery.create(tableInfo); + } else if (config.autoSchemaUpgrade()) { + // TODO(b/491851868): Implement auto-schema upgrade. + logger.info("BigQuery table already exists and auto-schema upgrade is enabled: " + tableId); + logger.info("Auto-schema upgrade is not implemented yet."); + } + } catch (BigQueryException e) { + if (e.getMessage().contains("invalid_grant")) { + logger.log( + Level.SEVERE, + "Failed to authenticate with BigQuery. Please run 'gcloud auth application-default" + + " login' to refresh your credentials or provide valid credentials in" + + " BigQueryLoggerConfig.", + e); + } else { + logger.log( + Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + if (config.credentials() != null) { + return BigQueryWriteClient.create( + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) + .build()); + } + return BigQueryWriteClient.create(); + } + + protected String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(retryConfig.maxRetries()) + .setInitialRetryDelay(Duration.ofMillis(retryConfig.initialDelay().toMillis())) + .setRetryDelayMultiplier(retryConfig.multiplier()) + .setMaxRetryDelay(Duration.ofMillis(retryConfig.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + private void logEvent( + String eventType, + InvocationContext invocationContext, + Optional callbackContext, + Object content, + Map extraAttributes) { + if (batchProcessor == null) { + return; + } + + ensureTableExistsOnce(); + + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", eventType); + row.put( + "agent", + callbackContext.map(CallbackContext::agentName).orElse(invocationContext.agent().name())); + row.put("session_id", invocationContext.session().id()); + row.put("invocation_id", invocationContext.invocationId()); + row.put("user_id", invocationContext.userId()); + + if (content instanceof Content contentParts) { + row.put( + "content_parts", + JsonFormatter.formatContentParts(Optional.of(contentParts), config.maxContentLength())); + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } else if (content != null) { + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } + + Map attributes = new HashMap<>(config.customTags()); + if (extraAttributes != null) { + attributes.putAll(extraAttributes); + } + row.put( + "attributes", + JsonFormatter.smartTruncate(attributes, config.maxContentLength()).toString()); + + addTraceDetails(row); + batchProcessor.append(row); + } + + // TODO(b/491849911): Implement own trace management functionality. + private void addTraceDetails(Map row) { + SpanContext spanContext = Span.current().getSpanContext(); + if (spanContext.isValid()) { + row.put("trace_id", spanContext.getTraceId()); + row.put("span_id", spanContext.getSpanId()); + } + } + + @Override + public Completable close() { + if (batchProcessor != null) { + batchProcessor.close(); + } + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + return Completable.complete(); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + return Maybe.fromAction( + () -> logEvent("USER_MESSAGE", invocationContext, Optional.empty(), userMessage, null)); + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + return Maybe.fromAction( + () -> logEvent("INVOCATION_START", invocationContext, Optional.empty(), null, null)); + } + + @Override + public Maybe onEventCallback(InvocationContext invocationContext, Event event) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("event_author", event.author()); + logEvent( + "EVENT", invocationContext, Optional.empty(), event.content().orElse(null), attrs); + }); + } + + @Override + public Completable afterRunCallback(InvocationContext invocationContext) { + return Completable.fromAction( + () -> { + logEvent("INVOCATION_END", invocationContext, Optional.empty(), null, null); + batchProcessor.flush(); + }); + } + + @Override + public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_START", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_END", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + LlmRequest req = llmRequest.build(); + attrs.put("model", req.model().orElse("unknown")); + logEvent( + "MODEL_REQUEST", + callbackContext.invocationContext(), + Optional.of(callbackContext), + req, + attrs); + }); + } + + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + llmResponse.usageMetadata().ifPresent(u -> attrs.put("usage_metadata", u)); + logEvent( + "MODEL_RESPONSE", + callbackContext.invocationContext(), + Optional.of(callbackContext), + llmResponse, + attrs); + }); + } + + @Override + public Maybe onModelErrorCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("error_message", error.getMessage()); + logEvent( + "MODEL_ERROR", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + attrs); + }); + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_START", + toolContext.invocationContext(), + Optional.of(toolContext), + toolArgs, + attrs); + }); + } + + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_END", toolContext.invocationContext(), Optional.of(toolContext), result, attrs); + }); + } + + @Override + public Maybe> onToolErrorCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + attrs.put("error_message", error.getMessage()); + logEvent( + "TOOL_ERROR", toolContext.invocationContext(), Optional.of(toolContext), null, attrs); + }); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java new file mode 100644 index 000000000..aa5bf37de --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -0,0 +1,204 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.google.auth.Credentials; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import javax.annotation.Nullable; + +/** Configuration for the BigQueryAgentAnalyticsPlugin. */ +@AutoValue +public abstract class BigQueryLoggerConfig { + // Whether the plugin is enabled. + public abstract boolean enabled(); + + // List of event types to log. If None, all are allowed + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventAllowlist(); + + // List of event types to ignore. + // TODO(b/491852782): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventDenylist(); + + // Max length for text content before truncation. + public abstract int maxContentLength(); + + // Project ID for the BigQuery table. + public abstract String projectId(); + + // Dataset ID for the BigQuery table. + public abstract String datasetId(); + + // Table name for the BigQuery table. + public abstract String tableName(); + + // Fields to cluster the table by. + public abstract ImmutableList clusteringFields(); + + // Whether to log multi-modal content. + // TODO(b/491852782): Implement logging of multi-modal content. + public abstract boolean logMultiModalContent(); + + // Retry configuration for BigQuery writes. + public abstract RetryConfig retryConfig(); + + // Number of rows to batch before flushing. + public abstract int batchSize(); + + // Duration to wait before flushing the queue. + public abstract Duration batchFlushInterval(); + + // Max time to wait for shutdown. + public abstract Duration shutdownTimeout(); + + // Max size of the batch processor queue. + public abstract int queueMaxSize(); + + // Optional custom formatter for content. + // TODO(b/491852782): Implement content formatter. + @Nullable + public abstract BiFunction contentFormatter(); + + // TODO(b/491852782): Implement connection id. + public abstract Optional connectionId(); + + // Toggle for session metadata (e.g. gchat thread-id). + // TODO(b/491852782): Implement logging of session metadata. + public abstract boolean logSessionMetadata(); + + // Static custom tags (e.g. {"agent_role": "sales"}). + // TODO(b/491852782): Implement custom tags. + public abstract ImmutableMap customTags(); + + // Automatically add new columns to existing tables when the plugin + // schema evolves. Only additive changes are made (columns are never + // dropped or altered). + // TODO(b/491852782): Implement auto-schema upgrade. + public abstract boolean autoSchemaUpgrade(); + + @Nullable + public abstract Credentials credentials(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig.Builder() + .setEnabled(true) + .setMaxContentLength(500 * 1024) + .setDatasetId("agent_analytics") + .setTableName("events") + .setClusteringFields(ImmutableList.of("event_type", "agent", "user_id")) + .setLogMultiModalContent(true) + .setRetryConfig(RetryConfig.builder().build()) + .setBatchSize(1) + .setBatchFlushInterval(Duration.ofSeconds(1)) + .setShutdownTimeout(Duration.ofSeconds(10)) + .setQueueMaxSize(10000) + .setLogSessionMetadata(true) + .setCustomTags(ImmutableMap.of()) + // TODO(b/491851868): Enable auto-schema upgrade once implemented. + .setAutoSchemaUpgrade(false); + } + + /** Builder for {@link BigQueryLoggerConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setEnabled(boolean enabled); + + public abstract Builder setEventAllowlist(@Nullable List eventAllowlist); + + public abstract Builder setEventDenylist(@Nullable List eventDenylist); + + public abstract Builder setMaxContentLength(int maxContentLength); + + public abstract Builder setProjectId(String projectId); + + public abstract Builder setDatasetId(String datasetId); + + public abstract Builder setTableName(String tableName); + + public abstract Builder setClusteringFields(List clusteringFields); + + public abstract Builder setLogMultiModalContent(boolean logMultiModalContent); + + public abstract Builder setRetryConfig(RetryConfig retryConfig); + + public abstract Builder setBatchSize(int batchSize); + + public abstract Builder setBatchFlushInterval(Duration batchFlushInterval); + + public abstract Builder setShutdownTimeout(Duration shutdownTimeout); + + public abstract Builder setQueueMaxSize(int queueMaxSize); + + public abstract Builder setContentFormatter( + @Nullable BiFunction contentFormatter); + + public abstract Builder setConnectionId(String connectionId); + + public abstract Builder setLogSessionMetadata(boolean logSessionMetadata); + + public abstract Builder setCustomTags(Map customTags); + + public abstract Builder setAutoSchemaUpgrade(boolean autoSchemaUpgrade); + + public abstract Builder setCredentials(Credentials credentials); + + public abstract BigQueryLoggerConfig build(); + } + + /** Retry configuration for BigQuery writes. */ + @AutoValue + public abstract static class RetryConfig { + public abstract int maxRetries(); + + public abstract Duration initialDelay(); + + public abstract double multiplier(); + + public abstract Duration maxDelay(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig_RetryConfig.Builder() + .setMaxRetries(3) + .setInitialDelay(Duration.ofSeconds(1)) + .setMultiplier(2.0) + .setMaxDelay(Duration.ofSeconds(10)); + } + + /** Builder for {@link RetryConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMaxRetries(int maxRetries); + + public abstract Builder setInitialDelay(Duration initialDelay); + + public abstract Builder setMultiplier(double multiplier); + + public abstract Builder setMaxDelay(Duration maxDelay); + + public abstract RetryConfig build(); + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java new file mode 100644 index 000000000..81181a1e0 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java @@ -0,0 +1,304 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardSQLTypeName; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema.Mode; +import com.google.cloud.bigquery.storage.v1.TableSchema; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; + +/** Utility for defining the BigQuery events table schema. */ +public final class BigQuerySchema { + + private BigQuerySchema() {} + + private static final ImmutableMap> + FIELD_TYPE_TO_ARROW_FIELD_METADATA = + ImmutableMap.of( + StandardSQLTypeName.JSON, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:json"), + StandardSQLTypeName.DATETIME, + ImmutableMap.of("ARROW:extension:name", "google:sqlType:datetime"), + StandardSQLTypeName.GEOGRAPHY, + ImmutableMap.of( + "ARROW:extension:name", + "google:sqlType:geography", + "ARROW:extension:metadata", + "{\"encoding\": \"WKT\"}")); + + /** Returns the BigQuery schema for the events table. */ + // TODO(b/491848381): Rely on the same schema defined for python plugin. + public static Schema getEventsSchema() { + return Schema.of( + Field.newBuilder("timestamp", StandardSQLTypeName.TIMESTAMP) + .setMode(Field.Mode.REQUIRED) + .setDescription("The UTC timestamp when the event occurred.") + .build(), + Field.newBuilder("event_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The category of the event.") + .build(), + Field.newBuilder("agent", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The name of the agent that generated this event.") + .build(), + Field.newBuilder("session_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for the entire conversation session.") + .build(), + Field.newBuilder("invocation_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for a single turn or execution.") + .build(), + Field.newBuilder("user_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The identifier of the end-user.") + .build(), + Field.newBuilder("trace_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry trace ID.") + .build(), + Field.newBuilder("span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry span ID.") + .build(), + Field.newBuilder("parent_span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry parent span ID.") + .build(), + Field.newBuilder("content", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("The primary payload of the event.") + .build(), + Field.newBuilder( + "content_parts", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("mime_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The MIME type of the content part.") + .build(), + Field.newBuilder("uri", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The URI of the content part if stored externally.") + .build(), + Field.newBuilder( + "object_ref", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("uri", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("version", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("authorizer", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("details", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .build())) + .setMode(Field.Mode.NULLABLE) + .setDescription("The ObjectRef of the content part if stored externally.") + .build(), + Field.newBuilder("text", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The raw text content.") + .build(), + Field.newBuilder("part_index", StandardSQLTypeName.INT64) + .setMode(Field.Mode.NULLABLE) + .setDescription("The zero-based index of this part.") + .build(), + Field.newBuilder("part_attributes", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Additional metadata as a JSON object string.") + .build(), + Field.newBuilder("storage_mode", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates how the content part is stored.") + .build())) + .setMode(Field.Mode.REPEATED) + .setDescription("Multi-modal events content parts.") + .build(), + Field.newBuilder("attributes", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing arbitrary key-value pairs.") + .build(), + Field.newBuilder("latency_ms", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing latency measurements.") + .build(), + Field.newBuilder("status", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The outcome of the event.") + .build(), + Field.newBuilder("error_message", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Detailed error message if the status is 'ERROR'.") + .build(), + Field.newBuilder("is_truncated", StandardSQLTypeName.BOOL) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates if the 'content' field was truncated.") + .build()); + } + + /** Returns the Arrow schema for the events table. */ + public static org.apache.arrow.vector.types.pojo.Schema getArrowSchema() { + return new org.apache.arrow.vector.types.pojo.Schema( + getEventsSchema().getFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList())); + } + + /** Returns the serialized Arrow schema for the events table. */ + public static ByteString getSerializedArrowSchema() { + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), getArrowSchema()); + return ByteString.copyFrom(out.toByteArray()); + } catch (IOException e) { + throw new VerifyException("Failed to serialize arrow schema", e); + } + } + + private static org.apache.arrow.vector.types.pojo.Field convertToArrowField(Field field) { + ArrowType arrowType = convertTypeToArrow(field.getType().getStandardType()); + ImmutableList children = null; + if (field.getSubFields() != null) { + children = + field.getSubFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList()); + } + + ImmutableMap metadata = + FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(field.getType().getStandardType()); + + FieldType fieldType = + new FieldType(field.getMode() != Field.Mode.REQUIRED, arrowType, null, metadata); + org.apache.arrow.vector.types.pojo.Field arrowField = + new org.apache.arrow.vector.types.pojo.Field(field.getName(), fieldType, children); + + if (field.getMode() == Field.Mode.REPEATED) { + return new org.apache.arrow.vector.types.pojo.Field( + field.getName(), + new FieldType(false, new ArrowType.List(), null), + ImmutableList.of( + new org.apache.arrow.vector.types.pojo.Field( + "element", arrowField.getFieldType(), arrowField.getChildren()))); + } + return arrowField; + } + + private static ArrowType convertTypeToArrow(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> new ArrowType.Bool(); + case BYTES -> new ArrowType.Binary(); + case DATE -> new ArrowType.Date(DateUnit.DAY); + case DATETIME -> + // Arrow doesn't have a direct DATETIME, often mapped to Timestamp or Utf8 + new ArrowType.Utf8(); + case FLOAT64 -> new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case INT64 -> new ArrowType.Int(64, true); + case NUMERIC, BIGNUMERIC -> new ArrowType.Decimal(38, 9, 128); + case GEOGRAPHY, STRING, JSON -> new ArrowType.Utf8(); + case STRUCT -> new ArrowType.Struct(); + case TIME -> new ArrowType.Time(TimeUnit.MICROSECOND, 64); + case TIMESTAMP -> new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"); + default -> new ArrowType.Null(); + }; + } + + /** Returns names of fields to cluster by default. */ + public static ImmutableList getDefaultClusteringFields() { + return ImmutableList.of("event_type", "agent", "user_id"); + } + + /** Returns the BigQuery TableSchema for the events table (Storage Write API). */ + public static TableSchema getEventsTableSchema() { + return convertTableSchema(getEventsSchema()); + } + + private static TableSchema convertTableSchema(Schema schema) { + TableSchema.Builder result = TableSchema.newBuilder(); + for (int i = 0; i < schema.getFields().size(); i++) { + result.addFields(i, convertFieldSchema(schema.getFields().get(i))); + } + return result.build(); + } + + private static TableFieldSchema convertFieldSchema(Field field) { + TableFieldSchema.Builder result = TableFieldSchema.newBuilder(); + Field.Mode mode = field.getMode() != null ? field.getMode() : Field.Mode.NULLABLE; + + Mode resultMode = Mode.valueOf(mode.name()); + result.setMode(resultMode).setName(field.getName()); + + StandardSQLTypeName standardType = field.getType().getStandardType(); + TableFieldSchema.Type resultType = convertType(standardType); + result.setType(resultType); + + if (field.getDescription() != null) { + result.setDescription(field.getDescription()); + } + if (field.getSubFields() != null) { + for (int i = 0; i < field.getSubFields().size(); i++) { + result.addFields(i, convertFieldSchema(field.getSubFields().get(i))); + } + } + return result.build(); + } + + private static TableFieldSchema.Type convertType(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> TableFieldSchema.Type.BOOL; + case BYTES -> TableFieldSchema.Type.BYTES; + case DATE -> TableFieldSchema.Type.DATE; + case DATETIME -> TableFieldSchema.Type.DATETIME; + case FLOAT64 -> TableFieldSchema.Type.DOUBLE; + case GEOGRAPHY -> TableFieldSchema.Type.GEOGRAPHY; + case INT64 -> TableFieldSchema.Type.INT64; + case NUMERIC -> TableFieldSchema.Type.NUMERIC; + case STRING -> TableFieldSchema.Type.STRING; + case STRUCT -> TableFieldSchema.Type.STRUCT; + case TIME -> TableFieldSchema.Type.TIME; + case TIMESTAMP -> TableFieldSchema.Type.TIMESTAMP; + case BIGNUMERIC -> TableFieldSchema.Type.BIGNUMERIC; + case JSON -> TableFieldSchema.Type.JSON; + case INTERVAL -> TableFieldSchema.Type.INTERVAL; + default -> TableFieldSchema.Type.TYPE_UNSPECIFIED; + }; + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java new file mode 100644 index 000000000..b4b4a1049 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -0,0 +1,111 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; + +/** Utility for formatting and truncating content for BigQuery logging. */ +final class JsonFormatter { + private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); + + private JsonFormatter() {} + + /** Formats Content parts into an ArrayNode for BigQuery logging. */ + public static ArrayNode formatContentParts(Optional content, int maxLength) { + ArrayNode partsArray = mapper.createArrayNode(); + if (content.isEmpty() || content.get().parts() == null) { + return partsArray; + } + + List parts = content.get().parts().orElse(ImmutableList.of()); + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ObjectNode partObj = mapper.createObjectNode(); + partObj.put("part_index", i); + partObj.put("storage_mode", "INLINE"); + + if (part.text().isPresent()) { + partObj.put("mime_type", "text/plain"); + partObj.put("text", truncateString(part.text().get(), maxLength)); + } else if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + partObj.put("mime_type", blob.mimeType().orElse("")); + partObj.put("text", "[BINARY DATA]"); + } else if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partObj.put("mime_type", fileData.mimeType().orElse("")); + partObj.put("uri", fileData.fileUri().orElse("")); + partObj.put("storage_mode", "EXTERNAL_URI"); + } + partsArray.add(partObj); + } + return partsArray; + } + + /** Recursively truncates long strings inside an object and returns a Jackson JsonNode. */ + public static JsonNode smartTruncate(Object obj, int maxLength) { + if (obj == null) { + return mapper.nullNode(); + } + try { + return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); + } catch (IllegalArgumentException e) { + // Fallback for types that mapper can't handle directly as a tree + return mapper.valueToTree(String.valueOf(obj)); + } + } + + private static JsonNode recursiveSmartTruncate(JsonNode node, int maxLength) { + if (node.isTextual()) { + return mapper.valueToTree(truncateString(node.asText(), maxLength)); + } else if (node.isObject()) { + ObjectNode newNode = mapper.createObjectNode(); + node.properties() + .iterator() + .forEachRemaining( + entry -> { + newNode.set(entry.getKey(), recursiveSmartTruncate(entry.getValue(), maxLength)); + }); + return newNode; + } else if (node.isArray()) { + ArrayNode newNode = mapper.createArrayNode(); + for (JsonNode element : node) { + newNode.add(recursiveSmartTruncate(element, maxLength)); + } + return newNode; + } + return node; + } + + private static String truncateString(String s, int maxLength) { + if (s == null || s.length() <= maxLength) { + return s; + } + return s.substring(0, maxLength) + "...[truncated]"; + } +} diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 29b2b76d3..849a3cd04 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,6 +52,7 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -312,6 +313,7 @@ private Single appendNewMessageToSession( throw new IllegalArgumentException("No parts in the new_message."); } + Completable saveArtifactsFlow = Completable.complete(); if (this.artifactService != null && saveInputBlobsAsArtifacts) { // The runner directly saves the artifacts (if applicable) in the user message and replaces // the artifact data with a file name placeholder. @@ -321,9 +323,11 @@ private Single appendNewMessageToSession( continue; } String fileName = "artifact_" + invocationContext.invocationId() + "_" + i; - var unused = - this.artifactService.saveArtifact( - this.appName, session.userId(), session.id(), fileName, part); + saveArtifactsFlow = + saveArtifactsFlow.andThen( + this.artifactService + .saveArtifact(this.appName, session.userId(), session.id(), fileName, part) + .ignoreElement()); newMessage .parts() @@ -348,7 +352,8 @@ private Single appendNewMessageToSession( EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); } - return this.sessionService.appendEvent(session, eventBuilder.build()); + return saveArtifactsFlow.andThen( + this.sessionService.appendEvent(session, eventBuilder.build())); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -375,20 +380,25 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - Maybe maybeSession = - this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); - return maybeSession - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) + .compose(Tracing.trace("invocation")); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -415,35 +425,6 @@ public Flowable runAsync(String userId, String sessionId, Content newMess return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } - /** - * See {@link #runAsync(Session, Content, RunConfig, Map)}. - * - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { - return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); - } - - /** - * Runs the agent asynchronously using a provided Session object. - * - * @param session The session to run the agent in. - * @param newMessage The new message from the user to process. - * @param runConfig Configuration for the agent run. - * @param stateDelta Optional map of state updates to merge into the session for this run. - * @return A Flowable stream of {@link Event} objects generated by the agent during execution. - * @deprecated Use runAsync with sessionId. - */ - @Deprecated(since = "0.4.0", forRemoval = true) - public Flowable runAsync( - Session session, - Content newMessage, - RunConfig runConfig, - @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta); - } - /** * Runs the agent asynchronously using a provided Session object. * @@ -458,6 +439,10 @@ protected Flowable runAsyncImpl( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { + Preconditions.checkNotNull(session, "session cannot be null"); + Preconditions.checkNotNull(newMessage, "newMessage cannot be null"); + Preconditions.checkNotNull(runConfig, "runConfig cannot be null"); + Context capturedContext = Context.current(); return Flowable.defer( () -> { BaseAgent rootAgent = this.agent; @@ -473,22 +458,18 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) + .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> - (content != null) - ? appendNewMessageToSession( - session, - content, - initialContext, - runConfig.saveInputBlobsAsArtifacts(), - stateDelta) - : Single.just(null)) + appendNewMessageToSession( + session, + content, + initialContext, + runConfig.saveInputBlobsAsArtifacts(), + stateDelta)) .flatMapPublisher( event -> { - if (event == null) { - return Flowable.empty(); - } // Get the updated session after the message and state delta are // applied return this.sessionService @@ -502,7 +483,8 @@ protected Flowable runAsyncImpl( event, invocationId, runConfig, - rootAgent)); + rootAgent)) + .compose(Tracing.withContext(capturedContext)); }); }) .doOnError( @@ -510,8 +492,7 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }) - .compose(Tracing.trace("invocation")); + }); } private Flowable runAgentWithFreshSession( @@ -564,12 +545,14 @@ private Flowable runAgentWithFreshSession( .toFlowable()); // If beforeRunCallback returns content, emit it and skip agent + Context capturedContext = Context.current(); return beforeRunEvent .toFlowable() .switchIfEmpty(agentEvents) .concatWith( Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) - .concatWith(Completable.defer(() -> compactEvents(updatedSession))); + .concatWith(Completable.defer(() -> compactEvents(updatedSession))) + .compose(Tracing.withContext(capturedContext)); } private Completable compactEvents(Session session) { @@ -634,46 +617,9 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .agent(this.findAgentToRun(session, rootAgent)); } - /** - * Runs the agent in live mode, appending generated events to the session. - * - * @return stream of events from the agent. - */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return Flowable.defer( - () -> { - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }) - .compose(Tracing.trace("invocation")); + return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); } /** @@ -684,19 +630,25 @@ public Flowable runLive( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession(appName, userId, null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format("Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); + return Flowable.defer( + () -> + this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession( + appName, userId, (Map) null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format( + "Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher( + session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) + .compose(Tracing.trace("invocation")); } /** @@ -711,15 +663,46 @@ public Flowable runLive( } /** - * Runs the agent asynchronously with a default user ID. + * Runs the agent in live mode, appending generated events to the session. * - * @return stream of generated events. + * @return stream of events from the agent. */ - @Deprecated(since = "0.5.0", forRemoval = true) - public Flowable runWithSessionId( - String sessionId, Content newMessage, RunConfig runConfig) { - // TODO(b/410859954): Add user_id to getter or method signature. Assuming "tmp-user" for now. - return this.runAsync("tmp-user", sessionId, newMessage, runConfig); + protected Flowable runLiveImpl( + Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return Flowable.defer( + () -> { + Context capturedContext = Context.current(); + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }) + .compose(Tracing.withContext(capturedContext)); + }); } /** diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index b2a584b11..d9bb047a3 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -154,19 +154,13 @@ public Maybe getSession( if (config.numRecentEvents().isEmpty() && config.afterTimestamp().isPresent()) { Instant threshold = config.afterTimestamp().get(); - eventsInCopy.removeIf( - event -> getEventTimestampEpochSeconds(event) < threshold.getEpochSecond()); + eventsInCopy.removeIf(event -> getInstantFromEvent(event).isBefore(threshold)); } // Merge state into the potentially filtered copy and return return Maybe.just(mergeWithGlobalState(appName, userId, sessionCopy)); } - // Helper to get event timestamp as epoch seconds - private long getEventTimestampEpochSeconds(Event event) { - return event.timestamp() / 1000L; - } - @Override public Single listSessions(String appName, String userId) { Objects.requireNonNull(appName, "appName cannot be null"); @@ -294,10 +288,7 @@ public Single appendEvent(Session session, Event event) { /** Converts an event's timestamp to an Instant. Adapt based on actual Event structure. */ // TODO: have Event.timestamp() return Instant directly private Instant getInstantFromEvent(Event event) { - double epochSeconds = getEventTimestampEpochSeconds(event); - long seconds = (long) epochSeconds; - long nanos = (long) ((epochSeconds % 1.0) * 1_000_000_000L); - return Instant.ofEpochSecond(seconds, nanos); + return Instant.ofEpochMilli(event.timestamp()); } /** diff --git a/core/src/main/java/com/google/adk/telemetry/README.md b/core/src/main/java/com/google/adk/telemetry/README.md new file mode 100644 index 000000000..8665b3352 --- /dev/null +++ b/core/src/main/java/com/google/adk/telemetry/README.md @@ -0,0 +1,156 @@ +# ADK Telemetry and Tracing + +This package contains classes for capturing and reporting telemetry data within +the ADK, primarily for tracing agent execution leveraging OpenTelemetry. + +## Overview + +The `Tracing` utility class provides methods to trace various aspects of an +agent's execution, including: + +* Agent invocations +* LLM requests and responses +* Tool calls and responses + +These traces can be exported and visualized in telemetry backends like Google +Cloud Trace or Zipkin, or viewed through the ADK Dev Server UI, providing +observability into agent behavior. + +## How Tracing is Used + +Tracing is deeply integrated into the ADK's RxJava-based asynchronous workflows. + +### Agent Invocations + +Every agent's `runAsync` or `runLive` execution is wrapped in a span named +`invoke_agent `. The top-level agent invocation initiated by +`Runner.runAsync` or `Runner.runLive` is captured in a span named `invocation`. +Agent-specific metadata like name and description are added as span attributes, +following OpenTelemetry semantic conventions (e.g., `gen_ai.agent.name`). + +### LLM Calls + +Calls to Large Language Models (LLMs) are traced within a `call_llm` span. The +`traceCallLlm` method attaches detailed attributes to this span, including: + +* The LLM request (excluding large data like images) and response. +* Model name (`gen_ai.request.model`). +* Token usage (`gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens`). +* Configuration parameters (`gen_ai.request.top_p`, + `gen_ai.request.max_tokens`). +* Response finish reason (`gen_ai.response.finish_reasons`). + +### Tool Calls and Responses + +Tool executions triggered by the LLM are traced using `tool_call []` +and `tool_response []` spans. + +* `traceToolCall` records tool arguments in the + `gcp.vertex.agent.tool_call_args` attribute. +* `traceToolResponse` records tool output in the + `gcp.vertex.agent.tool_response` attribute. +* If multiple tools are called in parallel, a single `tool_response` span may + be created for the merged result. + +### Context Propagation + +ADK is built on RxJava and heavily uses asynchronous processing, which means +that work is often handed off between different threads. For tracing to work +correctly in such an environment, it's crucial that the active span's context +is propagated across these thread boundaries. If context is not propagated, +new spans may be orphaned or attached to the wrong parent, making traces +difficult to interpret. + +OpenTelemetry stores the currently active span in a thread-local variable. +When an asynchronous operation switches threads, this thread-local context is +lost. To solve this, ADK's `Tracing` class provides functionality to capture +the context on one thread and restore it on another when an asynchronous +operation resumes. This ensures that spans created on different threads are +correctly parented under the same trace. + +The primary mechanism for this is the `Tracing.withContext(context)` method, +which returns an RxJava transformer. When applied to an RxJava stream via +`.compose()`, this transformer ensures that the provided `Context` (containing +the parent span) is re-activated before any `onNext`, `onError`, `onComplete`, +or `onSuccess` signals are propagated downstream. It achieves this by wrapping +the downstream observer with a `TracingObserver`, which uses +`context.makeCurrent()` in a try-with-resources block around each callback, +guaranteeing that the correct span is active when downstream operators execute, +regardless of the thread. + +### RxJava Integration + +ADK integrates OpenTelemetry with RxJava streams to simplify span creation and +ensure context propagation: + +* **Span Creation**: The `Tracing.trace(spanName)` method returns an RxJava + transformer that can be applied to a `Flowable`, `Single`, `Maybe`, or + `Completable` using `.compose()`. This transformer wraps the stream's + execution in a new OpenTelemetry span. +* **Context Propagation**: The `Tracing.withContext(context)` transformer is + used with `.compose()` to ensure that the correct OpenTelemetry `Context` + (and thus the correct parent span) is active when stream operators or + subscriptions are executed, even across thread boundaries. + +## Trace Hierarchy Example + +A typical agent interaction might produce a trace hierarchy like the following: + +``` +invocation +└── invoke_agent my_agent + ├── call_llm + │ ├── tool_call [search_flights] + │ └── tool_response [search_flights] + └── call_llm +``` + +This shows: + +1. The overall `invocation` started by the `Runner`. +2. The invocation of `my_agent`. +3. The first `call_llm` made by `my_agent`. +4. A `tool_call` to `search_flights` and its corresponding `tool_response`. +5. A second `call_llm` made by `my_agent` to generate the final user response. + +### Nested Agents + +ADK supports nested agents, where one agent invokes another. If an agent has +sub-agents, it can transfer control to one of them using the built-in +`transfer_to_agent` tool. When `AgentA` calls `transfer_to_agent` to transfer +control to `AgentB`, the `invoke_agent AgentB` span will appear as a child of +the `invoke_agent AgentA` span, like so: + +``` +invocation +└── invoke_agent AgentA + ├── call_llm + │ ├── tool_call [transfer_to_agent] + │ └── tool_response [transfer_to_agent] + └── invoke_agent AgentB + ├── call_llm + └── ... +``` + +This structure allows you to see how `AgentA` delegated work to `AgentB`. + +## Span Creation References + +The following classes are the primary places where spans are created: + +* **`com.google.adk.runner.Runner`**: Initiates the top-level `invocation` + span for `runAsync` and `runLive`. +* **`com.google.adk.agents.BaseAgent`**: Creates the `invoke_agent + ` span for each agent execution. +* **`com.google.adk.flows.llmflows.BaseLlmFlow`**: Creates the `call_llm` span + when the LLM is invoked. +* **`com.google.adk.flows.llmflows.Functions`**: Creates `tool_call [...]` and + `tool_response [...]` spans when handling tool calls and responses. + +## Configuration + +**ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS**: This environment variable controls +whether LLM request/response content and tool arguments/responses are captured +in span attributes. It defaults to `true`. Set to `false` to exclude potentially +large or sensitive data from traces, in which case a `{}` JSON object will be +recorded instead. diff --git a/core/src/main/java/com/google/adk/tools/BaseTool.java b/core/src/main/java/com/google/adk/tools/BaseTool.java index 1ea2808a1..01a399920 100644 --- a/core/src/main/java/com/google/adk/tools/BaseTool.java +++ b/core/src/main/java/com/google/adk/tools/BaseTool.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonAnySetter; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.JsonBaseModel; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.models.LlmRequest; @@ -38,6 +39,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import javax.annotation.Nonnull; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; @@ -93,6 +95,85 @@ public Single> runAsync(Map args, ToolContex throw new UnsupportedOperationException("This method is not implemented."); } + /** + * Calls a tool with generic arguments and returns a map of results. The args type {@code T} need + * to be serializable with {@link JsonBaseModel#getMapper()} + */ + public final Single> runAsync(T args, ToolContext toolContext) { + return runAsync(args, toolContext, JsonBaseModel.getMapper()); + } + + /** + * Calls a tool with generic arguments using a custom {@link ObjectMapper} and returns a map of + * results. The args type {@code T} needs to be serializable with the provided {@link + * ObjectMapper}. + */ + public final Single> runAsync( + T args, ToolContext toolContext, ObjectMapper objectMapper) { + return runAsync(args, toolContext, objectMapper, output -> output); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified class. The input type {@code I} needs to be serializable and the + * output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. + */ + public final Single runAsync( + I args, ToolContext toolContext, ObjectMapper objectMapper, Class oClass) { + return runAsync( + args, toolContext, objectMapper, output -> objectMapper.convertValue(output, oClass)); + } + + /** + * Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results + * converted to a specified type reference. The input type {@code I} needs to be serializable and + * the output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}. + */ + public final Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + TypeReference typeReference) { + return runAsync( + args, + toolContext, + objectMapper, + output -> objectMapper.convertValue(output, typeReference)); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified class. The + * input type {@code I} needs to be serializable and the output type {@code O} needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, Class oClass) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), oClass); + } + + /** + * Calls a tool with generic arguments, returning the results converted to a specified type + * reference. The input type needs to be serializable and the output type needs to be + * deserializable with {@link JsonBaseModel#getMapper()} + */ + public final Single runAsync( + I args, ToolContext toolContext, TypeReference typeReference) { + return runAsync(args, toolContext, JsonBaseModel.getMapper(), typeReference); + } + + private Single runAsync( + I args, + ToolContext toolContext, + ObjectMapper objectMapper, + Function, ? extends O> deserializer) { + return Single.defer( + () -> + Single.just( + objectMapper.convertValue(args, new TypeReference>() {}))) + .flatMap(argsMap -> runAsync(argsMap, toolContext)) + .map(deserializer::apply); + } + /** * Processes the outgoing {@link LlmRequest.Builder}. * diff --git a/core/src/main/java/com/google/adk/tools/ExampleTool.java b/core/src/main/java/com/google/adk/tools/ExampleTool.java index d08481532..d03c2e4f1 100644 --- a/core/src/main/java/com/google/adk/tools/ExampleTool.java +++ b/core/src/main/java/com/google/adk/tools/ExampleTool.java @@ -85,7 +85,9 @@ public Completable processLlmRequest( return Completable.complete(); } - llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + if (!examplesBlock.isEmpty()) { + llmRequestBuilder.appendInstructions(ImmutableList.of(examplesBlock)); + } // Delegate to BaseTool to keep any declaration bookkeeping (none for this tool) return super.processLlmRequest(llmRequestBuilder, toolContext); } diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java index 207243ceb..4cafb9681 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java @@ -24,6 +24,8 @@ import com.google.adk.agents.ReadonlyContext; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolPredicate; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Booleans; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -32,6 +34,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,7 +54,7 @@ public class McpToolset implements BaseToolset { private final McpSessionManager mcpSessionManager; private McpSyncClient mcpSession; private final ObjectMapper objectMapper; - private final Optional toolFilter; + private final @Nullable Object toolFilter; private static final int MAX_RETRIES = 3; private static final long RETRY_DELAY_MILLIS = 100; @@ -62,17 +65,29 @@ public class McpToolset implements BaseToolset { * * @param connectionParams The SSE connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( SseServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with SSE server parameters. + * + * @param connectionParams The SSE connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names + */ + public McpToolset( + SseServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** @@ -82,7 +97,9 @@ public McpToolset( * @param objectMapper An ObjectMapper instance for parsing schemas. */ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -90,36 +107,39 @@ public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMappe * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( - ServerParameters connectionParams, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** - * Initializes the McpToolset with local server parameters and no tool filter. + * Initializes the McpToolset with local server parameters. * * @param connectionParams The local server connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names */ - public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { - this(connectionParams, objectMapper, Optional.empty()); + public McpToolset( + ServerParameters connectionParams, ObjectMapper objectMapper, List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); } /** - * Initializes the McpToolset with SSE server parameters, using the ObjectMapper used across the - * ADK. + * Initializes the McpToolset with local server parameters and no tool filter. * - * @param connectionParams The SSE connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param connectionParams The local server connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. */ - public McpToolset(SseServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -129,28 +149,31 @@ public McpToolset(SseServerParameters connectionParams, Optional toolFil * @param connectionParams The SSE connection parameters to the MCP server. */ public McpToolset(SseServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } /** * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK. + * ADK and no tool filter. * * @param connectionParams The local server connection parameters to the MCP server. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. */ - public McpToolset(ServerParameters connectionParams, Optional toolFilter) { - this(connectionParams, JsonBaseModel.getMapper(), toolFilter); + public McpToolset(ServerParameters connectionParams) { + this(connectionParams, JsonBaseModel.getMapper()); } /** - * Initializes the McpToolset with local server parameters, using the ObjectMapper used across the - * ADK and no tool filter. + * Initializes the McpToolset with an McpSessionManager. * - * @param connectionParams The local server connection parameters to the MCP server. + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolPredicate A {@link ToolPredicate} */ - public McpToolset(ServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + public McpToolset( + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, ToolPredicate toolPredicate) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = Objects.requireNonNull(toolPredicate); } /** @@ -158,33 +181,69 @@ public McpToolset(ServerParameters connectionParams) { * * @param mcpSessionManager A McpSessionManager instance for testing. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolNames A list of tool names */ public McpToolset( - McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional toolFilter) { - Objects.requireNonNull(mcpSessionManager); - Objects.requireNonNull(objectMapper); - this.mcpSessionManager = mcpSessionManager; - this.objectMapper = objectMapper; - this.toolFilter = toolFilter; + McpSessionManager mcpSessionManager, ObjectMapper objectMapper, List toolNames) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with an McpSessionManager and no tool filter. + * + * @param mcpSessionManager A McpSessionManager instance for testing. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(McpSessionManager mcpSessionManager, ObjectMapper objectMapper) { + this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager); + this.objectMapper = Objects.requireNonNull(objectMapper); + this.toolFilter = null; } /** - * Initializes the McpToolset with Steamable HTTP server parameters. + * Initializes the McpToolset with Streamable HTTP server parameters. * * @param connectionParams The Streamable HTTP connection parameters to the MCP server. * @param objectMapper An ObjectMapper instance for parsing schemas. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. + * @param toolPredicate A {@link ToolPredicate} */ public McpToolset( StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper, - Optional toolFilter) { - Objects.requireNonNull(connectionParams); - Objects.requireNonNull(objectMapper); - this.objectMapper = objectMapper; - this.mcpSessionManager = new McpSessionManager(connectionParams); - this.toolFilter = toolFilter; + ToolPredicate toolPredicate) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = Objects.requireNonNull(toolPredicate); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolNames A list of tool names + */ + public McpToolset( + StreamableHttpServerParameters connectionParams, + ObjectMapper objectMapper, + List toolNames) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = ImmutableList.copyOf(toolNames); + } + + /** + * Initializes the McpToolset with Streamable HTTP server parameters and no tool filter. + * + * @param connectionParams The Streamable HTTP connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + */ + public McpToolset(StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper) { + this.objectMapper = Objects.requireNonNull(objectMapper); + this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams)); + this.toolFilter = null; } /** @@ -194,7 +253,7 @@ public McpToolset( * @param connectionParams The Streamable HTTP connection parameters to the MCP server. */ public McpToolset(StreamableHttpServerParameters connectionParams) { - this(connectionParams, JsonBaseModel.getMapper(), Optional.empty()); + this(connectionParams, JsonBaseModel.getMapper()); } @Override @@ -215,8 +274,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { tool -> new McpTool( tool, this.mcpSession, this.mcpSessionManager, this.objectMapper)) - .filter( - tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext))); + .filter(tool -> isToolSelected(tool, toolFilter, readonlyContext))); }) .retryWhen( errorObservable -> @@ -357,16 +415,18 @@ public static McpToolset fromConfig(BaseTool.ToolConfig config, String configAbs + " for McpToolset"); } - // Convert tool filter to Optional - Optional toolFilter = Optional.ofNullable(mcpToolsetConfig.toolFilter()); - + List toolNames = mcpToolsetConfig.toolFilter(); Object connectionParameters = Optional.ofNullable(mcpToolsetConfig.stdioConnectionParams()) .or(() -> Optional.ofNullable(mcpToolsetConfig.sseServerParams())) .orElse(mcpToolsetConfig.stdioConnectionParams()); // Create McpToolset with McpSessionManager having appropriate connection parameters - return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolFilter); + if (toolNames != null) { + return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolNames); + } else { + return new McpToolset(new McpSessionManager(connectionParameters), mapper); + } } catch (IllegalArgumentException e) { throw new ConfigurationException("Failed to parse McpToolsetConfig from ToolArgsConfig", e); } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 594e47fd8..a9e7a6f8d 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -26,7 +26,6 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; import com.google.adk.agents.Callbacks.AfterModelCallback; import com.google.adk.agents.Callbacks.AfterToolCallback; @@ -35,6 +34,7 @@ import com.google.adk.agents.Callbacks.OnModelErrorCallback; import com.google.adk.agents.Callbacks.OnToolErrorCallback; import com.google.adk.events.Event; +import com.google.adk.examples.Example; import com.google.adk.models.LlmRegistry; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; @@ -46,13 +46,14 @@ import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ExampleTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Content; -import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Part; import com.google.genai.types.Schema; +import com.google.genai.types.Type; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; @@ -61,7 +62,6 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.After; @@ -211,75 +211,6 @@ public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() { assertEqualIgnoringFunctionIds(events.get(3).content().get(), expectedFunctionResponseContent); } - @Test - public void build_withOutputSchemaAndTools_throwsIllegalArgumentException() { - BaseTool tool = - new BaseTool("test_tool", "test_description") { - @Override - public Optional declaration() { - return Optional.empty(); - } - }; - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .tools(ImmutableList.of(tool)) // Set tools (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set, tools" - + " must be empty"); - } - - @Test - public void build_withOutputSchemaAndSubAgents_throwsIllegalArgumentException() { - ImmutableList subAgents = - ImmutableList.of( - createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) - .name("test_sub_agent") - .description("test_sub_agent_description") - .build()); - - Schema outputSchema = - Schema.builder() - .type("OBJECT") - .properties(ImmutableMap.of("status", Schema.builder().type("STRING").build())) - .required(ImmutableList.of("status")) - .build(); - - // Expecting an IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .subAgents(subAgents) // Set subAgents (this should cause the error) - .build()); // Attempt to build the agent - - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set," - + " subAgents must be empty to disable agent transfer."); - } - @Test public void testBuild_withNullInstruction_setsInstructionToEmptyString() { LlmAgent agent = @@ -572,8 +503,13 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolCallSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); + + // The tool calls and responses are children of the first LLM call that produced the function + // call. + String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); + toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolResponseSpans.forEach( + s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test @@ -638,6 +574,31 @@ public void runAsync_withSubAgents_createsSpans() throws InterruptedException { assertThat(llmSpans).hasSize(2); // One for main agent, one for sub agent } + @Test + public void run_outputSchemaWithTools_allowed() { + Schema personShema = + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + "name", Schema.builder().type(Type.Known.STRING).build(), + "age", Schema.builder().type(Type.Known.INTEGER).build(), + "city", Schema.builder().type(Type.Known.STRING).build())) + .build(); + LlmAgent agent = + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .outputSchema(personShema) + .tools(new EchoTool()) + .build(); + assertThat(agent.outputSchema()).hasValue(personShema); + assertThat( + agent + .canonicalTools(new ReadonlyContext(createInvocationContext(agent))) + .count() + .blockingGet()) + .isEqualTo(1); + } + private List findSpansByName(List spans, String name) { return spans.stream().filter(s -> s.getName().equals(name)).toList(); } @@ -649,4 +610,30 @@ private SpanData findSpanByName(List spans, String name) { .findFirst() .orElseThrow(() -> new AssertionError("Span not found: " + name)); } + + @Test + public void run_withExampleTool_doesNotAddFunctionDeclarations() { + ExampleTool tool = + ExampleTool.builder() + .addExample( + Example.builder() + .input(Content.fromParts(Part.fromText("qin"))) + .output(ImmutableList.of(Content.fromParts(Part.fromText("qout")))) + .build()) + .build(); + + Content modelContent = Content.fromParts(Part.fromText("Real LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(modelContent)); + LlmAgent agent = createTestAgentBuilder(testLlm).tools(tool).build(); + InvocationContext invocationContext = createInvocationContext(agent); + + var unused = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(testLlm.getRequests()).hasSize(1); + LlmRequest request = testLlm.getRequests().get(0); + + assertThat(request.config().isPresent()).isTrue(); + var config = request.config().get(); + assertThat(config.tools().isPresent()).isFalse(); + } } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 22bb94e64..b1e645e1a 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -177,17 +177,6 @@ public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); } - @Test - public void setRequestedToolConfirmations_withConcurrentMap_usesSameInstance() { - ConcurrentHashMap map = new ConcurrentHashMap<>(); - map.put("tool", TOOL_CONFIRMATION); - - EventActions actions = new EventActions(); - actions.setRequestedToolConfirmations(map); - - assertThat(actions.requestedToolConfirmations()).isSameInstanceAs(map); - } - @Test public void setRequestedToolConfirmations_withRegularMap_createsConcurrentMap() { ImmutableMap map = ImmutableMap.of("tool", TOOL_CONFIRMATION); diff --git a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java index 4a1dcf8e3..2d22ed3f1 100644 --- a/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java +++ b/core/src/test/java/com/google/adk/examples/ExampleUtilsTest.java @@ -49,17 +49,7 @@ public List getExamples(String query) { @Test public void buildFewShotFewShot_noExamples() { TestExampleProvider exampleProvider = new TestExampleProvider(ImmutableList.of()); - String expected = - """ - - Begin few-shot - The following are examples of user queries and model responses using the available tools. - - End few-shot - Now, try to follow these examples and complete the following conversation - \ - """; - assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEqualTo(expected); + assertThat(ExampleUtils.buildExampleSi(exampleProvider, "test query")).isEmpty(); } @Test diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 4a0b345c6..6cae6c88a 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -43,9 +43,13 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; import java.util.Map; import java.util.Optional; @@ -572,6 +576,71 @@ public Single> runAsync(Map args, ToolContex } } + @Test + public void run_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + RequestProcessor requestProcessor = + (ctx, request) -> { + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + ResponseProcessor responseProcessor = + (ctx, response) -> { + return Single.just(ResponseProcessingResult.create(response, ImmutableList.of())) + .subscribeOn(Schedulers.computation()); + }; + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + return Maybe.empty().subscribeOn(Schedulers.computation()); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + return Maybe.just(resp).subscribeOn(Schedulers.computation()); + }; + + Callbacks.OnModelErrorCallback onErrorCallback = + (ctx, req, err) -> { + return Maybe.just( + LlmResponse.builder().content(Content.fromParts(Part.fromText("error"))).build()) + .subscribeOn(Schedulers.computation()); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .onModelErrorCallback(onErrorCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow(ImmutableList.of(requestProcessor), ImmutableList.of(responseProcessor)); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + baseLlmFlow + .run(invocationContext) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(events).hasSize(1); + assertThat(events.get(0).content()).hasValue(content); + } + @Test public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { GenerateContentResponseUsageMetadata usageMetadata = @@ -588,7 +657,12 @@ public void postprocess_noResponseProcessors_onlyUsageMetadata_returnsEvent() { List events = baseLlmFlow - .postprocess(invocationContext, baseEvent, LlmRequest.builder().build(), llmResponse) + .postprocess( + invocationContext, + baseEvent, + LlmRequest.builder().build(), + llmResponse, + Context.current()) .toList() .blockingGet(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 85e78666d..1e6267dde 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -36,10 +36,15 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; +import java.util.ArrayList; +import java.util.ConcurrentModificationException; +import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.stream.Stream; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -780,6 +785,68 @@ public void processRequest_notEmptyContent() { assertThat(contents).containsExactly(e.content().get()); } + @Test + public void processRequest_concurrentReadAndWrite_noException() throws Exception { + LlmAgent agent = + LlmAgent.builder().name(AGENT).includeContents(LlmAgent.IncludeContents.DEFAULT).build(); + List customEvents = + new ArrayList() { + private void checkLock() { + if (!Thread.holdsLock(this)) { + throw new ConcurrentModificationException("Unsynchronized iteration detected!"); + } + } + + @Override + public Iterator iterator() { + checkLock(); + return super.iterator(); + } + + @Override + public ListIterator listIterator() { + checkLock(); + return super.listIterator(); + } + + @Override + public ListIterator listIterator(int index) { + checkLock(); + return super.listIterator(index); + } + + @Override + public Stream stream() { + checkLock(); + return super.stream(); + } + }; + + Session session = + Session.builder("test-session") + .appName("test-app") + .userId("test-user") + .events(customEvents) + .build(); + + // The list must have at least one element so that operations interacting with events trigger + // iteration. + customEvents.add(createUserEvent("dummy", "dummy")); + + InvocationContext context = + InvocationContext.builder() + .invocationId("test-invocation") + .agent(agent) + .session(session) + .sessionService(sessionService) + .build(); + + LlmRequest initialRequest = LlmRequest.builder().build(); + + // This single call will throw the exception if the list is accessed insecurely. + var unused = contentsProcessor.processRequest(context, initialRequest).blockingGet(); + } + private static Event createUserEvent(String id, String text) { return Event.builder() .id(id) diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java deleted file mode 100644 index 7d1615dc2..000000000 --- a/core/src/test/java/com/google/adk/flows/llmflows/ExamplesTest.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2026 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.adk.flows.llmflows; - -import static com.google.common.truth.Truth.assertThat; - -import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.examples.BaseExampleProvider; -import com.google.adk.examples.Example; -import com.google.adk.models.LlmRequest; -import com.google.adk.sessions.InMemorySessionService; -import com.google.adk.sessions.Session; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Content; -import com.google.genai.types.Part; -import java.util.List; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public final class ExamplesTest { - - private static final InMemorySessionService sessionService = new InMemorySessionService(); - - private static class TestExampleProvider implements BaseExampleProvider { - @Override - public List getExamples(String query) { - return ImmutableList.of( - Example.builder() - .input(Content.fromParts(Part.fromText("input1"))) - .output( - ImmutableList.of( - Content.builder().parts(Part.fromText("output1")).role("model").build())) - .build()); - } - } - - @Test - public void processRequest_withExampleProvider_addsExamplesToInstructions() { - LlmAgent agent = - LlmAgent.builder().name("test-agent").exampleProvider(new TestExampleProvider()).build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isNotEmpty(); - assertThat(result.updatedRequest().getSystemInstructions().get(0)) - .contains("[user]\ninput1\n\n[model]\noutput1\n"); - } - - @Test - public void processRequest_withoutExampleProvider_doesNotAddExamplesToInstructions() { - LlmAgent agent = LlmAgent.builder().name("test-agent").build(); - InvocationContext context = - InvocationContext.builder() - .invocationId("invocation1") - .session(Session.builder("session1").build()) - .sessionService(sessionService) - .agent(agent) - .userContent(Content.fromParts(Part.fromText("what is up?"))) - .runConfig(RunConfig.builder().build()) - .build(); - LlmRequest request = LlmRequest.builder().build(); - Examples examplesProcessor = new Examples(); - - RequestProcessor.RequestProcessingResult result = - examplesProcessor.processRequest(context, request).blockingGet(); - - assertThat(result.updatedRequest().getSystemInstructions()).isEmpty(); - } -} diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index 4ae856fc7..3771143cf 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -37,8 +37,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.schedulers.Schedulers; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -144,6 +148,87 @@ public void onUserMessageCallback_pluginOrderRespected() { inOrder.verify(plugin2).onUserMessageCallback(mockInvocationContext, content); } + @Test + public void contextPropagation_runMaybeCallbacks() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + Content expectedContent = Content.builder().build(); + when(plugin1.onUserMessageCallback(any(), any())) + .thenReturn(Maybe.just(expectedContent).subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Maybe resultMaybe; + try (Scope scope = testContext.makeCurrent()) { + resultMaybe = pluginManager.onUserMessageCallback(mockInvocationContext, content); + } + + // Assert downstream operators have the propagated context + resultMaybe + .doOnSuccess( + result -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(expectedContent); + + verify(plugin1).onUserMessageCallback(mockInvocationContext, content); + } + + @Test + public void contextPropagation_afterRunCallback() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.afterRunCallback(any())) + .thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.afterRunCallback(mockInvocationContext); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).afterRunCallback(mockInvocationContext); + } + + @Test + public void contextPropagation_close() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + when(plugin1.close()).thenReturn(Completable.complete().subscribeOn(Schedulers.computation())); + pluginManager.registerPlugin(plugin1); + + Completable resultCompletable; + try (Scope scope = testContext.makeCurrent()) { + resultCompletable = pluginManager.close(); + } + + // Assert downstream operators have the propagated context + resultCompletable + .doOnComplete( + () -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test() + .await() + .assertResult(); + + verify(plugin1).close(); + } + @Test public void afterRunCallback_allComplete() { when(plugin1.afterRunCallback(any())).thenReturn(Completable.complete()); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java new file mode 100644 index 000000000..4f4350d1a --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java @@ -0,0 +1,367 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.core.ApiFutures; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.RowError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.rpc.Status; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class BatchProcessorTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock private StreamWriter mockWriter; + private ScheduledExecutorService executor; + private BatchProcessor batchProcessor; + private Schema schema; + private Handler mockHandler; + + @Before + public void setUp() { + executor = Executors.newScheduledThreadPool(1); + batchProcessor = new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor); + schema = BigQuerySchema.getArrowSchema(); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + } + + @After + public void tearDown() { + batchProcessor.close(); + executor.shutdown(); + } + + @Test + public void flush_populatesTimestampFieldCorrectly() throws Exception { + Instant now = Instant.parse("2026-03-02T19:11:49.631Z"); + Map row = new HashMap<>(); + row.put("timestamp", now); + row.put("event_type", "TEST_EVENT"); + + final boolean[] checksPassed = {false}; + final String[] failureMessage = {null}; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + if (root.getRowCount() != 1) { + failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + } + + var timestampVector = root.getVector("timestamp"); + if (!(timestampVector instanceof TimeStampMicroTZVector tzVector)) { + failureMessage[0] = "Vector should be an instance of TimeStampMicroTZVector"; + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + } + if (tzVector.isNull(0)) { + failureMessage[0] = "Timestamp should NOT be null"; + } else if (tzVector.get(0) != now.toEpochMilli() * 1000) { + failureMessage[0] = + "Expected " + (now.toEpochMilli() * 1000) + ", got " + tzVector.get(0); + } else { + checksPassed[0] = true; + } + } catch (RuntimeException e) { + failureMessage[0] = "Exception during check: " + e.getMessage(); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + assertTrue(failureMessage[0], checksPassed[0]); + } + + @Test + public void flush_populatesAllBasicFields() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", "BASIC_EVENT"); + row.put("is_truncated", true); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + assertEquals("BASIC_EVENT", root.getVector("event_type").getObject(0).toString()); + assertEquals(1, ((BitVector) root.getVector("is_truncated")).get(0)); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_populatesJsonFields() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("content", "{\"key\": \"value\"}"); + row.put("attributes", "{\"attr\": 123}"); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + assertEquals( + "{\"key\": \"value\"}", root.getVector("content").getObject(0).toString()); + assertEquals( + "{\"attr\": 123}", root.getVector("attributes").getObject(0).toString()); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_populatesNestedStructs() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + + List> contentParts = new ArrayList<>(); + Map part = new HashMap<>(); + part.put("mime_type", "text/plain"); + part.put("text", "hello world"); + part.put("part_index", 0L); + contentParts.add(part); + row.put("content_parts", contentParts); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + ListVector contentPartsVector = (ListVector) root.getVector("content_parts"); + StructVector structVector = (StructVector) contentPartsVector.getDataVector(); + + assertEquals(1, ((List) contentPartsVector.getObject(0)).size()); + VarCharVector mimeTypeVector = (VarCharVector) structVector.getChild("mime_type"); + assertEquals("text/plain", mimeTypeVector.getObject(0).toString()); + + VarCharVector textVector = (VarCharVector) structVector.getChild("text"); + assertEquals("hello world", textVector.getObject(0).toString()); + + BigIntVector partIndexVector = (BigIntVector) structVector.getChild("part_index"); + assertEquals(0L, partIndexVector.get(0)); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_handlesBigQueryErrorResponse() throws Exception { + Map row = new HashMap<>(); + row.put("event_type", "ERROR_EVENT"); + + AppendRowsResponse responseWithError = + AppendRowsResponse.newBuilder() + .setError(Status.newBuilder().setMessage("Global error").build()) + .addRowErrors(RowError.newBuilder().setIndex(0).setMessage("Row error").build()) + .build(); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(responseWithError)); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_handlesGenericExceptionDuringAppend() throws Exception { + Map row = new HashMap<>(); + row.put("event_type", "EXCEPTION_EVENT"); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenThrow(new RuntimeException("Simulated failure")); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + } + + @Test + public void append_triggersFlushWhenBatchSizeReached() { + ScheduledExecutorService mockExecutor = mock(ScheduledExecutorService.class); + BatchProcessor bp = new BatchProcessor(mockWriter, 2, Duration.ofMinutes(1), 10, mockExecutor); + + Map row = new HashMap<>(); + bp.append(row); + verify(mockExecutor, never()).execute(any(Runnable.class)); + + bp.append(row); + verify(mockExecutor).execute(any(Runnable.class)); + } + + @Test + public void flush_doesNothingWhenQueueIsEmpty() throws Exception { + batchProcessor.flush(); + verify(mockWriter, never()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void flush_handlesNullValues() throws Exception { + Map row = new HashMap<>(); + row.put("timestamp", Instant.now()); + row.put("event_type", null); + row.put("is_truncated", null); + + final boolean[] checksPassed = {false}; + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + assertTrue(root.getVector("event_type").isNull(0)); + assertTrue(root.getVector("is_truncated").isNull(0)); + checksPassed[0] = true; + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + assertTrue("Null checks failed", checksPassed[0]); + } + + @Test + public void flush_handlesAllocationFailure() throws Exception { + Map row = new HashMap<>(); + row.put("event_type", "ALLOC_FAIL_EVENT"); + batchProcessor.append(row); + batchProcessor.allocator.setLimit(1); + + batchProcessor.flush(); + + verify(mockWriter, never()).append(any(ArrowRecordBatch.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + boolean foundError = false; + for (LogRecord record : captor.getAllValues()) { + if (record.getLevel().equals(Level.SEVERE) + && record.getMessage().contains("Failed to write batch to BigQuery")) { + foundError = true; + break; + } + } + assertTrue("Expected SEVERE error log not found", foundError); + } + + @Test + public void close_flushesAndClosesResources() throws Exception { + try (BatchProcessor bp = + new BatchProcessor(mockWriter, 10, Duration.ofMinutes(1), 100, executor)) { + Map row = new HashMap<>(); + row.put("event_type", "CLOSE_EVENT"); + bp.append(row); + } + + verify(mockWriter).append(any(ArrowRecordBatch.class)); + verify(mockWriter).close(); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java new file mode 100644 index 000000000..8147c5cc6 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -0,0 +1,457 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.api.core.ApiFutures; +import com.google.auth.Credentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class BigQueryAgentAnalyticsPluginTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock private BigQuery mockBigQuery; + @Mock private StreamWriter mockWriter; + @Mock private BigQueryWriteClient mockWriteClient; + @Mock private InvocationContext mockInvocationContext; + private BaseAgent fakeAgent; + + private BigQueryLoggerConfig config; + private BigQueryAgentAnalyticsPlugin plugin; + private Handler mockHandler; + + @Before + public void setUp() throws Exception { + fakeAgent = new FakeAgent("agent_name"); + config = + BigQueryLoggerConfig.builder() + .setEnabled(true) + .setProjectId("project") + .setDatasetId("dataset") + .setTableName("table") + .setBatchSize(10) + .setBatchFlushInterval(Duration.ofSeconds(10)) + .setAutoSchemaUpgrade(false) + .setCredentials(mock(Credentials.class)) + .setCustomTags(ImmutableMap.of("global_tag", "global_value")) + .build(); + + when(mockBigQuery.getOptions()) + .thenReturn(BigQueryOptions.newBuilder().setProjectId("test-project").build()); + when(mockBigQuery.getTable(any(TableId.class))).thenReturn(mock(Table.class)); + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + plugin = + new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + }; + + Session session = Session.builder("session_id").build(); + when(mockInvocationContext.session()).thenReturn(session); + when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); + when(mockInvocationContext.agent()).thenReturn(fakeAgent); + when(mockInvocationContext.userId()).thenReturn("user_id"); + + Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + } + + @Test + public void onUserMessageCallback_appendsToWriter() throws Exception { + Content content = Content.builder().build(); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void beforeRunCallback_appendsToWriter() throws Exception { + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void afterRunCallback_flushesAndAppends() throws Exception { + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void getStreamName_returnsCorrectFormat() { + BigQueryLoggerConfig config = + BigQueryLoggerConfig.builder() + .setProjectId("test-project") + .setDatasetId("test-dataset") + .setTableName("test-table") + .build(); + + String streamName = plugin.getStreamName(config); + + assertEquals( + "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default", + streamName); + } + + @Test + public void formatContentParts_populatesCorrectFields() { + Content content = Content.fromParts(Part.fromText("hello")); + ArrayNode nodes = JsonFormatter.formatContentParts(Optional.of(content), 100); + assertEquals(1, nodes.size()); + ObjectNode node = (ObjectNode) nodes.get(0); + assertEquals(0, node.get("part_index").asInt()); + assertEquals("INLINE", node.get("storage_mode").asText()); + assertEquals("hello", node.get("text").asText()); + assertEquals("text/plain", node.get("mime_type").asText()); + } + + @Test + public void arrowSchema_hasJsonMetadata() { + Schema schema = BigQuerySchema.getArrowSchema(); + Field contentField = schema.findField("content"); + assertNotNull(contentField); + assertEquals("google:sqlType:json", contentField.getMetadata().get("ARROW:extension:name")); + } + + @Test + public void onUserMessageCallback_handlesTableCreationFailure() throws Exception { + Logger logger = Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + Handler mockHandler = mock(Handler.class); + logger.addHandler(mockHandler); + try { + when(mockBigQuery.getTable(any(TableId.class))) + .thenThrow(new RuntimeException("Table check failed")); + Content content = Content.builder().build(); + + // Should not throw exception + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + assertTrue( + captor + .getValue() + .getMessage() + .contains("Failed to check or create/upgrade BigQuery table")); + assertEquals(Level.WARNING, captor.getValue().getLevel()); + } finally { + logger.removeHandler(mockHandler); + } + } + + @Test + public void onUserMessageCallback_handlesAppendFailure() throws Exception { + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFailedFuture(new RuntimeException("Append failed"))); + Content content = Content.builder().build(); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + // Flush should handle the failed future from writer.append() + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + assertTrue(captor.getValue().getMessage().contains("Failed to write batch to BigQuery")); + assertEquals(Level.SEVERE, captor.getValue().getLevel()); + } + + @Test + public void ensureTableExists_calledOnlyOnce() throws Exception { + Content content = Content.builder().build(); + + // Multiple calls to logEvent via different callbacks + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + + // Verify getting table was only done once. Using fully qualified name to avoid ambiguity. + verify(mockBigQuery).getTable(any(TableId.class)); + } + + @Test + public void arrowSchema_handlesNestedFields() { + Schema schema = BigQuerySchema.getArrowSchema(); + Field contentPartsField = schema.findField("content_parts"); + assertNotNull(contentPartsField); + // Repeated struct becomes a List of Structs + assertTrue(contentPartsField.getType() instanceof ArrowType.List); + + Field element = contentPartsField.getChildren().get(0); + assertEquals("element", element.getName()); + + // Check object_ref which is a nested STRUCT + Field objectRef = + element.getChildren().stream() + .filter(f -> f.getName().equals("object_ref")) + .findFirst() + .orElse(null); + assertNotNull(objectRef); + assertTrue(objectRef.getType() instanceof ArrowType.Struct); + assertFalse(objectRef.getChildren().isEmpty()); + } + + @Test + public void arrowSchema_handlesFieldNullability() { + Schema schema = BigQuerySchema.getArrowSchema(); + + // timestamp is REQUIRED in BigQuerySchema.getEventsSchema() + Field timestampField = schema.findField("timestamp"); + assertNotNull(timestampField); + assertFalse(timestampField.isNullable()); + + // event_type is NULLABLE in BigQuerySchema.getEventsSchema() + Field eventTypeField = schema.findField("event_type"); + assertNotNull(eventTypeField); + assertTrue(eventTypeField.isNullable()); + } + + @Test + public void logEvent_populatesCommonFields() throws Exception { + final boolean[] checksPassed = {false}; + final String[] failureMessage = {null}; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + Schema schema = BigQuerySchema.getArrowSchema(); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, plugin.batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + if (root.getRowCount() != 1) { + failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); + } else if (!Objects.equals( + root.getVector("event_type").getObject(0).toString(), "USER_MESSAGE")) { + failureMessage[0] = + "Wrong event_type: " + root.getVector("event_type").getObject(0); + } else if (!root.getVector("agent").getObject(0).toString().equals("agent_name")) { + failureMessage[0] = "Wrong agent: " + root.getVector("agent").getObject(0); + } else if (!root.getVector("session_id") + .getObject(0) + .toString() + .equals("session_id")) { + failureMessage[0] = + "Wrong session_id: " + root.getVector("session_id").getObject(0); + } else if (!root.getVector("invocation_id") + .getObject(0) + .toString() + .equals("invocation_id")) { + failureMessage[0] = + "Wrong invocation_id: " + root.getVector("invocation_id").getObject(0); + } else if (!root.getVector("user_id").getObject(0).toString().equals("user_id")) { + failureMessage[0] = "Wrong user_id: " + root.getVector("user_id").getObject(0); + } else if (((TimeStampMicroTZVector) root.getVector("timestamp")).get(0) <= 0) { + failureMessage[0] = "Timestamp not populated"; + } else { + // Check content and content_parts + String contentJson = root.getVector("content").getObject(0).toString(); + if (!contentJson.contains("test message")) { + failureMessage[0] = "Wrong content: " + contentJson; + } else { + ListVector contentPartsVector = (ListVector) root.getVector("content_parts"); + if (((List) contentPartsVector.getObject(0)).isEmpty()) { + failureMessage[0] = "content_parts is empty"; + } else { + // Check attributes + String attributesJson = root.getVector("attributes").getObject(0).toString(); + if (!attributesJson.contains("global_tag") + || !attributesJson.contains("global_value")) { + failureMessage[0] = "Wrong attributes: " + attributesJson; + } else { + checksPassed[0] = true; + } + } + } + } + } catch (RuntimeException e) { + failureMessage[0] = "Exception during inspection: " + e.getMessage(); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + Content content = Content.fromParts(Part.fromText("test message")); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + plugin.batchProcessor.flush(); + + assertTrue(failureMessage[0], checksPassed[0]); + } + + @Test + public void logEvent_populatesTraceDetails() throws Exception { + String traceId = "4bf92f3577b34da6a3ce929d0e0e4736"; + String spanId = "00f067aa0ba902b7"; + + SpanContext mockSpanContext = mock(SpanContext.class); + when(mockSpanContext.isValid()).thenReturn(true); + when(mockSpanContext.getTraceId()).thenReturn(traceId); + when(mockSpanContext.getSpanId()).thenReturn(spanId); + + Span mockSpan = Span.wrap(mockSpanContext); + + try (Scope scope = mockSpan.makeCurrent()) { + Content content = Content.builder().build(); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals(traceId, row.get("trace_id")); + assertEquals(spanId, row.get("span_id")); + } + } + + @Test + public void complexType_appendsToWriter() throws Exception { + Part part = Part.fromText("test text"); + Content content = Content.fromParts(part); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void onEventCallback_populatesCorrectFields() throws Exception { + Event event = + Event.builder() + .author("agent_author") + .content(Content.fromParts(Part.fromText("event content"))) + .build(); + + plugin.onEventCallback(mockInvocationContext, event).blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("EVENT", row.get("event_type")); + assertEquals("agent_name", row.get("agent")); + assertTrue(row.get("attributes").toString().contains("agent_author")); + assertTrue(row.get("content").toString().contains("event content")); + } + + @Test + public void onModelErrorCallback_populatesCorrectFields() throws Exception { + CallbackContext mockCallbackContext = mock(CallbackContext.class); + when(mockCallbackContext.invocationContext()).thenReturn(mockInvocationContext); + when(mockCallbackContext.agentName()).thenReturn("agent_in_context"); + LlmRequest.Builder mockLlmRequestBuilder = mock(LlmRequest.Builder.class); + Throwable error = new RuntimeException("model error message"); + + plugin + .onModelErrorCallback(mockCallbackContext, mockLlmRequestBuilder, error) + .blockingSubscribe(); + + Map row = plugin.batchProcessor.queue.poll(); + assertNotNull("Row not found in queue", row); + assertEquals("MODEL_ERROR", row.get("event_type")); + assertEquals("agent_in_context", row.get("agent")); + assertTrue(row.get("attributes").toString().contains("model error message")); + } + + private static class FakeAgent extends BaseAgent { + FakeAgent(String name) { + super(name, "description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } +} diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 8a0a84b08..a3e21cb73 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -24,6 +24,8 @@ import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.adk.testing.TestUtils.simplifyEvents; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.stream; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.CALLS_REAL_METHODS; import static org.mockito.Mockito.mock; @@ -36,6 +38,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.Functions; import com.google.adk.models.LlmResponse; @@ -57,17 +60,22 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; +import io.opentelemetry.context.Scope; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -75,6 +83,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; @RunWith(JUnit4.class) public final class RunnerTest { @@ -846,6 +855,19 @@ private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } + private static Content createInlineDataContent(byte[]... data) { + return Content.builder() + .parts( + stream(data) + .map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream")) + .toArray(Part[]::new)) + .build(); + } + + private static Content createInlineDataContent(String... data) { + return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new)); + } + @Test public void runAsync_createsInvocationSpan() { var unused = @@ -977,6 +999,84 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan.get().hasEnded()).isTrue(); } + @Test + public void runAsync_createsToolSpansWithCorrectParent() { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + + var unused = + runnerWithTool + .runAsync( + sessionWithTool.sessionKey(), + createContent("from user"), + RunConfig.builder().build()) + .toList() + .blockingGet(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + assertThat(llmSpans).hasSize(2); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + + @Test + public void runLive_createsToolSpansWithCorrectParent() throws Exception { + LlmAgent agentWithTool = + createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); + Runner runnerWithTool = + Runner.builder().app(App.builder().name("test").rootAgent(agentWithTool).build()).build(); + Session sessionWithTool = + runnerWithTool.sessionService().createSession("test", "user").blockingGet(); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber = + runnerWithTool + .runLive(sessionWithTool.sessionKey(), liveRequestQueue, RunConfig.builder().build()) + .test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + + List spans = openTelemetryRule.getSpans(); + List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); + List toolCallSpans = + spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); + List toolResponseSpans = + spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + + // In runLive, there is one call_llm span for the execution + assertThat(llmSpans).hasSize(1); + assertThat(toolCallSpans).hasSize(1); + assertThat(toolResponseSpans).hasSize(1); + + List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); + String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); + String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + + assertThat(toolCallParentId).isEqualTo(toolResponseParentId); + assertThat(llmSpanIds).contains(toolCallParentId); + } + @Test public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); @@ -1188,6 +1288,53 @@ public void close_closesPluginsAndCodeExecutors() { verify(plugin).close(); } + @Test + public void runAsync_contextPropagation() { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + + List events; + try (Scope scope = testContext.makeCurrent()) { + events = + runner + .runAsync("user", session.id(), createContent("test message")) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .toList() + .blockingGet(); + } + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + } + + @Test + public void runLive_contextPropagation() throws Exception { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.current().with(testKey, "test-value"); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber; + try (Scope scope = testContext.makeCurrent()) { + testSubscriber = + runner + .runLive(session, liveRequestQueue, RunConfig.builder().build()) + .doOnNext( + event -> { + assertThat(Context.current().get(testKey)).isEqualTo("test-value"); + }) + .test(); + } + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + } + @Test public void buildRunnerWithPlugins_success() { BasePlugin plugin1 = mockPlugin("test1"); @@ -1203,4 +1350,40 @@ public static ImmutableMap echoTool(String message) { return ImmutableMap.of("message", message); } } + + @Test + public void runner_executesSaveArtifactFlow() { + // arrange + final AtomicInteger artifactsSavedCounter = new AtomicInteger(); + BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class); + when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any())) + .thenReturn( + Single.defer( + () -> { + // we want to assert not only that the saveArtifact method was + // called, but also that the flow that it returned was run, so + // we need to record the call in a counter + artifactsSavedCounter.incrementAndGet(); + return Single.just(42); + })); + Runner runner = + Runner.builder() + .app(App.builder().name("test").rootAgent(agent).build()) + .artifactService(mockArtifactService) + .build(); + session = runner.sessionService().createSession("test", "user").blockingGet(); + // each inline data will be saved using our mock artifact service + Content content = createInlineDataContent("test data", "test data 2"); + RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build(); + + // act + var events = runner.runAsync("user", session.id(), content, runConfig).test(); + + // assert + events.assertComplete(); + // artifacts were saved + assertThat(artifactsSavedCounter.get()).isEqualTo(2); + // agent was run + assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); + } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 41e156ffd..0d9235b1b 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.time.Instant; import java.util.HashMap; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; @@ -214,6 +215,24 @@ public void appendEvent_removesState() { assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey"); } + @Test + public void appendEvent_updatesSessionTimestampWithFractionalSeconds() { + InMemorySessionService sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); + + // Add an event with a timestamp that contains a fractional second + Event eventAdd = Event.builder().timestamp(5500).build(); + var unused = sessionService.appendEvent(session, eventAdd).blockingGet(); + + // Verify the last modified timestamp contains a fractional second + Session retrievedSession = + sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .blockingGet(); + assertThat(retrievedSession.lastUpdateTime()).isEqualTo(Instant.ofEpochSecond(5, 500000000L)); + } + @Test public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index e5795d61f..1ee018848 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -16,6 +16,7 @@ package com.google.adk.telemetry; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -32,10 +33,15 @@ import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.SessionKey; +import com.google.adk.testing.TestLlm; +import com.google.adk.testing.TestUtils; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GenerateContentResponseUsageMetadata; @@ -54,6 +60,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.schedulers.Schedulers; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -123,10 +130,7 @@ public void testToolCallSpanLinksToParent() { parentSpanData.getSpanContext().getTraceId(), toolCallSpanData.getSpanContext().getTraceId()); - assertEquals( - "Tool call's parent should be the parent span", - parentSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, toolCallSpanData); } @Test @@ -146,7 +150,7 @@ public void testToolCallWithoutParentCreatesRootSpan() { // Then: Should create root span (backward compatible) List spans = openTelemetryRule.getSpans(); - assertEquals("Should have exactly 1 span", 1, spans.size()); + assertThat(spans).hasSize(1); SpanData toolCallSpanData = spans.get(0); assertFalse( @@ -193,7 +197,7 @@ public void testNestedSpanHierarchy() { List spans = openTelemetryRule.getSpans(); // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response // [testTool]". - assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); + assertThat(spans).hasSize(4); SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); @@ -210,22 +214,13 @@ public void testNestedSpanHierarchy() { SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); // invocation should be child of parent - assertEquals( - "Invocation should be child of parent", - parentSpanData.getSpanContext().getSpanId(), - invocationSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, invocationSpanData); // tool_call should be child of invocation - assertEquals( - "Tool call should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, toolCallSpanData); // tool_response should be child of tool_call - assertEquals( - "Tool response should be child of tool call", - toolCallSpanData.getSpanContext().getSpanId(), - toolResponseSpanData.getParentSpanContext().getSpanId()); + assertParent(toolCallSpanData, toolResponseSpanData); } @Test @@ -253,7 +248,6 @@ public void testMultipleSpansInParallel() { // Verify all tool calls link to same parent SpanData parentSpanData = findSpanByName("parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - String parentSpanId = parentSpanData.getSpanContext().getSpanId(); // All tool calls should have same trace ID and parent span ID List toolCallSpans = @@ -261,7 +255,7 @@ public void testMultipleSpansInParallel() { .filter(s -> s.getName().startsWith("tool_call")) .toList(); - assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); + assertThat(toolCallSpans).hasSize(3); toolCallSpans.forEach( span -> { @@ -269,10 +263,7 @@ public void testMultipleSpansInParallel() { "Tool call should have same trace ID as parent", parentTraceId, span.getSpanContext().getTraceId()); - assertEquals( - "Tool call should have parent as parent span", - parentSpanId, - span.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, span); }); } @@ -298,10 +289,7 @@ public void testInvokeAgentSpanLinksToInvocation() { SpanData invocationSpanData = findSpanByName("invocation"); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - assertEquals( - "Agent run should be child of invocation", - invocationSpanData.getSpanContext().getSpanId(), - invokeAgentSpanData.getParentSpanContext().getSpanId()); + assertParent(invocationSpanData, invokeAgentSpanData); } @Test @@ -323,15 +311,12 @@ public void testCallLlmSpanLinksToAgentRun() { } List spans = openTelemetryRule.getSpans(); - assertEquals("Should have 2 spans", 2, spans.size()); + assertThat(spans).hasSize(2); SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); SpanData callLlmSpanData = findSpanByName("call_llm"); - assertEquals( - "Call LLM should be child of agent run", - invokeAgentSpanData.getSpanContext().getSpanId(), - callLlmSpanData.getParentSpanContext().getSpanId()); + assertParent(invokeAgentSpanData, callLlmSpanData); } @Test @@ -349,10 +334,7 @@ public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { SpanData parentSpanData = findSpanByName("invocation"); SpanData agentSpanData = findSpanByName("invoke_agent"); - assertEquals( - "Agent span should be a child of the invocation span", - parentSpanData.getSpanContext().getSpanId(), - agentSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, agentSpanData); } @Test @@ -380,9 +362,7 @@ public void testTraceFlowable() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData flowableSpanData = findSpanByName("flowable"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - flowableSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, flowableSpanData); assertTrue(flowableSpanData.hasEnded()); } @@ -469,9 +449,7 @@ public void testTraceTransformer() throws InterruptedException { SpanData parentSpanData = findSpanByName("parent"); SpanData transformerSpanData = findSpanByName("transformer"); - assertEquals( - parentSpanData.getSpanContext().getSpanId(), - transformerSpanData.getParentSpanContext().getSpanId()); + assertParent(parentSpanData, transformerSpanData); assertTrue(transformerSpanData.hasEnded()); } @@ -485,7 +463,7 @@ public void testTraceAgentInvocation() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -504,7 +482,7 @@ public void testTraceToolCall() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -541,7 +519,7 @@ public void testTraceToolResponse() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); @@ -578,7 +556,7 @@ public void testTraceCallLlm() { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); @@ -606,12 +584,12 @@ public void testTraceSendData() { Tracing.traceSendData( buildInvocationContext(), "event-1", - ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); + ImmutableList.of(Content.fromParts(Part.fromText("hello")))); } finally { span.end(); } List spans = openTelemetryRule.getSpans(); - assertEquals(1, spans.size()); + assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals( @@ -653,37 +631,23 @@ public void baseAgentRunAsync_propagatesContext() throws InterruptedException { } SpanData parent = findSpanByName("parent"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals(parent.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, agentSpan); } @Test public void runnerRunAsync_propagatesContext() throws InterruptedException { BaseAgent agent = new TestAgent(); - Runner runner = Runner.builder().agent(agent).appName("test-app").build(); Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { - Session session = - runner - .sessionService() - .createSession(new SessionKey("test-app", "test-user", "test-session")) - .blockingGet(); - Content newMessage = Content.fromParts(Part.fromText("hi")); - RunConfig runConfig = RunConfig.builder().build(); - runner - .runAsync(session.userId(), session.id(), newMessage, runConfig, null) - .test() - .await() - .assertComplete(); + runAgent(agent); } finally { parentSpan.end(); } SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); } @Test @@ -713,10 +677,173 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { SpanData parent = findSpanByName("parent"); SpanData invocation = findSpanByName("invocation"); SpanData agentSpan = findSpanByName("invoke_agent test-agent"); - assertEquals( - parent.getSpanContext().getSpanId(), invocation.getParentSpanContext().getSpanId()); - assertEquals( - invocation.getSpanContext().getSpanId(), agentSpan.getParentSpanContext().getSpanId()); + assertParent(parent, invocation); + assertParent(invocation, agentSpan); + } + + @Test + public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when an agent calls an LLM, + // which then invokes a tool. The expected hierarchy is: + // invocation + // └── invoke_agent test_agent + // ├── call_llm + // │ ├── tool_call [search_flights] + // │ └── tool_response [search_flights] + // └── call_llm + + SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); + + TestLlm testLlm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + searchFlightsTool.name(), ImmutableMap.of("destination", "SFO"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("done")))); + + LlmAgent agentWithTool = + LlmAgent.builder() + .name("test_agent") + .description("description") + .model(testLlm) + .tools(ImmutableList.of(searchFlightsTool)) + .build(); + + runAgent(agentWithTool); + + SpanData invocation = findSpanByName("invocation"); + SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); + SpanData toolCall = findSpanByName("tool_call [search_flights]"); + SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + SpanData callLlm1 = callLlmSpans.get(0); + SpanData callLlm2 = callLlmSpans.get(1); + + // Assert hierarchy: + // invocation + // └── invoke_agent test_agent + assertParent(invocation, invokeAgent); + // ├── call_llm 1 + assertParent(invokeAgent, callLlm1); + // │ ├── tool_call [search_flights] + assertParent(callLlm1, toolCall); + // │ └── tool_response [search_flights] + assertParent(callLlm1, toolResponse); + // └── call_llm 2 + assertParent(invokeAgent, callLlm2); + } + + @Test + public void testNestedAgentTraceHierarchy() throws InterruptedException { + // This test verifies the trace hierarchy created when AgentA transfers to AgentB. + // The expected hierarchy is: + // invocation + // └── invoke_agent AgentA + // ├── call_llm + // │ ├── tool_call [transfer_to_agent] + // │ └── tool_response [transfer_to_agent] + // └── invoke_agent AgentB + // └── call_llm + TestLlm llm = + TestUtils.createTestLlm( + TestUtils.createLlmResponse( + Content.builder() + .role("model") + .parts( + Part.fromFunctionCall( + "transfer_to_agent", ImmutableMap.of("agent_name", "AgentB"))) + .build()), + TestUtils.createLlmResponse(Content.fromParts(Part.fromText("agent b response")))); + LlmAgent agentB = LlmAgent.builder().name("AgentB").description("Agent B").model(llm).build(); + + LlmAgent agentA = + LlmAgent.builder() + .name("AgentA") + .description("Agent A") + .model(llm) + .subAgents(ImmutableList.of(agentB)) + .build(); + + runAgent(agentA); + + SpanData invocation = findSpanByName("invocation"); + SpanData agentASpan = findSpanByName("invoke_agent AgentA"); + SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); + SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); + + List callLlmSpans = + openTelemetryRule.getSpans().stream() + .filter(s -> s.getName().equals("call_llm")) + .sorted(Comparator.comparing(SpanData::getStartEpochNanos)) + .toList(); + assertThat(callLlmSpans).hasSize(2); + + SpanData agentACallLlm1 = callLlmSpans.get(0); + SpanData agentBCallLlm = callLlmSpans.get(1); + + assertParent(invocation, agentASpan); + assertParent(agentASpan, agentACallLlm1); + assertParent(agentACallLlm1, toolCall); + assertParent(agentACallLlm1, toolResponse); + assertParent(agentASpan, agentBSpan); + assertParent(agentBSpan, agentBCallLlm); + } + + private void runAgent(BaseAgent agent) throws InterruptedException { + Runner runner = Runner.builder().agent(agent).appName("test-app").build(); + Session session = + runner + .sessionService() + .createSession(new SessionKey("test-app", "test-user", "test-session")) + .blockingGet(); + Content newMessage = Content.fromParts(Part.fromText("hi")); + RunConfig runConfig = RunConfig.builder().build(); + runner + .runAsync(session.sessionKey(), newMessage, runConfig, null) + .test() + .await() + .assertComplete(); + } + + /** Tool for testing. */ + public static class SearchFlightsTool extends BaseTool { + public SearchFlightsTool() { + super("search_flights", "Search for flights tool"); + } + + @Override + public Single> runAsync(Map args, ToolContext context) { + return Single.just(ImmutableMap.of("result", args)); + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name("search_flights") + .description("Search for flights tool") + .build()); + } + } + + /** + * Asserts that the parent span is the parent of the child span. + * + * @param parent The parent span. + * @param child The child span. + */ + private void assertParent(SpanData parent, SpanData child) { + assertEquals(parent.getSpanContext().getSpanId(), child.getParentSpanContext().getSpanId()); } /** diff --git a/core/src/test/java/com/google/adk/tools/BaseToolTest.java b/core/src/test/java/com/google/adk/tools/BaseToolTest.java index 2a07e7a44..d3c8da5aa 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolTest.java @@ -2,12 +2,15 @@ import static com.google.common.truth.Truth.assertThat; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.models.Gemini; import com.google.adk.models.LlmRequest; import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GoogleMaps; @@ -17,6 +20,7 @@ import com.google.genai.types.UrlContext; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.observers.TestObserver; import java.util.Map; import java.util.Optional; import org.junit.Test; @@ -27,6 +31,20 @@ @RunWith(JUnit4.class) public final class BaseToolTest { + private final BaseTool doublingBaseTool = + new BaseTool("doubling-test-tool", "returns doubled args") { + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + String sArg = (String) args.get("s"); + Integer iArg = (Integer) args.get("i"); + return Single.just( + ImmutableMap.of( + "s", sArg + sArg, + "i", iArg + iArg)); + } + }; + @Test public void processLlmRequestNoDeclarationReturnsSameRequest() { BaseTool tool = @@ -247,4 +265,94 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { assertThat(updatedLlmRequest.config().get().tools().get()) .containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); } + + @Test + public void runAsync_withTypeReference_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, new TypeReference() {}); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClass_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + + Single out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single> out = + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, objectMapper); + TestObserver> testObserver = out.test(); + + testObserver.assertComplete(); + ImmutableMap expected = ImmutableMap.of("i", 22, "s", "bazbaz"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, + /* toolContext= */ null, + objectMapper, + new TypeReference() {}); + + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(84, "foofoo"); + testObserver.assertValue(expected); + } + + @Test + public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception { + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); + ObjectMapper objectMapper = new ObjectMapper(); + + Single out = + doublingBaseTool.runAsync( + testToolArgs, /* toolContext= */ null, objectMapper, TestToolArgs.class); + TestObserver testObserver = out.test(); + + testObserver.assertComplete(); + TestToolArgs expected = new TestToolArgs(42, "barbar"); + testObserver.assertValue(expected); + } + + public record TestToolArgs(int i, String s) {} } diff --git a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java index 4e80ed0ff..55e5d8f93 100644 --- a/core/src/test/java/com/google/adk/tools/ExampleToolTest.java +++ b/core/src/test/java/com/google/adk/tools/ExampleToolTest.java @@ -305,4 +305,30 @@ static final class WrongTypeProviderHolder { private WrongTypeProviderHolder() {} } + + @Test + public void declaration_isEmpty() { + ExampleTool tool = ExampleTool.builder().build(); + assertThat(tool.declaration().isPresent()).isFalse(); + } + + @Test + public void processLlmRequest_doesNotAddFunctionDeclarations() { + ExampleTool tool = ExampleTool.builder().addExample(makeExample("qin", "qout")).build(); + InvocationContext ctx = buildInvocationContext(); + LlmRequest.Builder builder = LlmRequest.builder().model("gemini-2.0-flash"); + + tool.processLlmRequest(builder, ToolContext.builder(ctx).build()).blockingAwait(); + LlmRequest updated = builder.build(); + + if (updated.config().isPresent()) { + var config = updated.config().get(); + if (config.tools().isPresent()) { + var tools = config.tools().get(); + boolean hasFunctionDeclarations = + tools.stream().anyMatch(t -> t.functionDeclarations().isPresent()); + assertThat(hasFunctionDeclarations).isFalse(); + } + } + } } diff --git a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java index 3d322b73f..0db218347 100644 --- a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java @@ -34,7 +34,6 @@ import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import java.util.List; -import java.util.Optional; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -324,7 +323,7 @@ public void getTools_withToolFilter_returnsFilteredTools() { when(mockMcpSyncClient.listTools()).thenReturn(mockResult); McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.of(toolFilter)); + new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), toolFilter); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); @@ -340,8 +339,7 @@ public void getTools_retriesAndFailsAfterMaxRetries() { when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient); when(mockMcpSyncClient.listTools()).thenThrow(new RuntimeException("Test Exception")); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); toolset .getTools(mockReadonlyContext) @@ -362,8 +360,7 @@ public void getTools_succeedsOnLastRetryAttempt() { .thenThrow(new RuntimeException("Attempt 2 failed")) .thenReturn(mockResult); - McpToolset toolset = - new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty()); + McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper()); List tools = toolset.getTools(mockReadonlyContext).toList().blockingGet(); diff --git a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java index 7d70a0efb..eab293de0 100644 --- a/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java +++ b/dev/src/main/java/com/google/adk/plugins/InvocationReplayState.java @@ -16,8 +16,8 @@ package com.google.adk.plugins; import com.google.adk.plugins.recordings.Recordings; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** Per-invocation replay state to isolate concurrent runs. */ class InvocationReplayState { @@ -33,7 +33,7 @@ public InvocationReplayState(String testCasePath, int userMessageIndex, Recordin this.testCasePath = testCasePath; this.userMessageIndex = userMessageIndex; this.recordings = recordings; - this.agentReplayIndices = new HashMap<>(); + this.agentReplayIndices = new ConcurrentHashMap<>(); } public String getTestCasePath() { @@ -57,7 +57,6 @@ public void setAgentReplayIndex(String agentName, int index) { } public void incrementAgentReplayIndex(String agentName) { - int currentIndex = getAgentReplayIndex(agentName); - setAgentReplayIndex(agentName, currentIndex + 1); + agentReplayIndices.merge(agentName, 1, Integer::sum); } } diff --git a/pom.xml b/pom.xml index 11696db73..0be05a629 100644 --- a/pom.xml +++ b/pom.xml @@ -49,9 +49,9 @@ cloud libraries. Once they update their otel dependencies we can consider updating ours here as well --> 1.51.0 - 0.14.0 + 0.17.2 2.47.0 - 1.41.0 + 1.43.0 4.33.5 5.11.4 5.20.0 @@ -73,6 +73,8 @@ 2.15.0 3.9.0 5.6 + 4.1.118.Final + @{jacoco.agent.argLine} --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.text=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED -Dio.netty.tryReflectionSetAccessible=true @@ -85,6 +87,13 @@ pom import + + io.netty + netty-bom + ${netty.version} + pom + import + com.google.cloud libraries-bom @@ -338,6 +347,8 @@ + + ${surefire.argLine} plain