From d21ee82048ff184a01547558bd24c6eb2de5addf Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Wed, 18 Mar 2026 21:46:07 +0100 Subject: [PATCH] chore: clean PR #623 * feat(langchain4j): add UsageMetadata extraction from TokenCountEstimator or TokenUsage Signed-off-by: Rhuan Rocha feat(langchain4j): fixing exception treatment Signed-off-by: Rhuan Rocha feat(langchain4j): fixing exception treatment Signed-off-by: Rhuan Rocha feat(langchain4j): refactoring constructor Signed-off-by: Rhuan Rocha * Delete contrib/samples/a2a_basic/bin/.settings/org.eclipse.core.resources.prefs * Delete contrib/samples/a2a_basic/bin/.settings/org.eclipse.m2e.core.prefs * Delete contrib/samples/a2a_basic/bin/.project * Delete contrib/samples/a2a_basic/bin/pom.xml * Delete contrib/samples/a2a_basic/bin/README.md * chore: format --------- Signed-off-by: Rhuan Rocha Co-authored-by: Rhuan Rocha --- .../adk/models/langchain4j/LangChain4j.java | 138 ++++++++++++++++-- .../models/langchain4j/LangChain4jTest.java | 137 +++++++++++++++++ 2 files changed, 264 insertions(+), 11 deletions(-) diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java index 3ccb1e029..9f7b8259f 100644 --- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java +++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java @@ -30,6 +30,7 @@ import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import com.google.genai.types.Schema; import com.google.genai.types.ToolConfig; @@ -52,6 +53,7 @@ import dev.langchain4j.data.pdf.PdfFile; import dev.langchain4j.data.video.Video; import dev.langchain4j.exception.UnsupportedFeatureException; 
+import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -65,6 +67,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; @@ -84,24 +87,109 @@ public class LangChain4j extends BaseLlm { private final ChatModel chatModel; private final StreamingChatModel streamingChatModel; private final ObjectMapper objectMapper; + private final TokenCountEstimator tokenCountEstimator; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private ChatModel chatModel; + private StreamingChatModel streamingChatModel; + private String modelName; + private TokenCountEstimator tokenCountEstimator; + + private Builder() {} + + public Builder chatModel(ChatModel chatModel) { + this.chatModel = chatModel; + return this; + } + + public Builder streamingChatModel(StreamingChatModel streamingChatModel) { + this.streamingChatModel = streamingChatModel; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) { + this.tokenCountEstimator = tokenCountEstimator; + return this; + } + + public LangChain4j build() { + if (chatModel == null && streamingChatModel == null) { + throw new IllegalStateException( + "At least one of chatModel or streamingChatModel must be provided"); + } + + String effectiveModelName = modelName; + if (effectiveModelName == null) { + if (chatModel != null) { + effectiveModelName = 
chatModel.defaultRequestParameters().modelName(); + } else { + effectiveModelName = streamingChatModel.defaultRequestParameters().modelName(); + } + } + + if (effectiveModelName == null) { + throw new IllegalStateException("Model name cannot be null"); + } + + return new LangChain4j( + chatModel, streamingChatModel, effectiveModelName, tokenCountEstimator); + } + } + + private LangChain4j( + ChatModel chatModel, + StreamingChatModel streamingChatModel, + String modelName, + TokenCountEstimator tokenCountEstimator) { + super(Objects.requireNonNull(modelName, "model name cannot be null")); + this.chatModel = chatModel; + this.streamingChatModel = streamingChatModel; + this.objectMapper = new ObjectMapper(); + this.tokenCountEstimator = tokenCountEstimator; + } public LangChain4j(ChatModel chatModel) { + this(chatModel, (TokenCountEstimator) null); + } + + public LangChain4j(ChatModel chatModel, TokenCountEstimator tokenCountEstimator) { super( Objects.requireNonNull( chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); this.streamingChatModel = null; this.objectMapper = new ObjectMapper(); + this.tokenCountEstimator = tokenCountEstimator; } public LangChain4j(ChatModel chatModel, String modelName) { + this(chatModel, modelName, (TokenCountEstimator) null); + } + + public LangChain4j( + ChatModel chatModel, String modelName, TokenCountEstimator tokenCountEstimator) { super(Objects.requireNonNull(modelName, "chat model name cannot be null")); this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); this.streamingChatModel = null; this.objectMapper = new ObjectMapper(); + this.tokenCountEstimator = tokenCountEstimator; } public LangChain4j(StreamingChatModel streamingChatModel) { + this(streamingChatModel, (TokenCountEstimator) null); + } + + public LangChain4j( + StreamingChatModel streamingChatModel, TokenCountEstimator 
tokenCountEstimator) { super( Objects.requireNonNull( streamingChatModel.defaultRequestParameters().modelName(), @@ -110,22 +198,23 @@ public LangChain4j(StreamingChatModel streamingChatModel) { this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); this.objectMapper = new ObjectMapper(); + this.tokenCountEstimator = tokenCountEstimator; } public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); - this.chatModel = null; - this.streamingChatModel = - Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); - this.objectMapper = new ObjectMapper(); + this(streamingChatModel, modelName, (TokenCountEstimator) null); } - public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { - super(Objects.requireNonNull(modelName, "model name cannot be null")); - this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); + public LangChain4j( + StreamingChatModel streamingChatModel, + String modelName, + TokenCountEstimator tokenCountEstimator) { + super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); + this.chatModel = null; this.streamingChatModel = Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); this.objectMapper = new ObjectMapper(); + this.tokenCountEstimator = tokenCountEstimator; } @Override @@ -186,7 +275,7 @@ public void onError(Throwable throwable) { ChatRequest chatRequest = toChatRequest(llmRequest); ChatResponse chatResponse = chatModel.chat(chatRequest); - LlmResponse llmResponse = toLlmResponse(chatResponse); + LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest); return Flowable.just(llmResponse); } @@ -511,11 +600,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { } } - private LlmResponse toLlmResponse(ChatResponse 
chatResponse) { + private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { Content content = Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); - return LlmResponse.builder().content(content).build(); + LlmResponse.Builder builder = LlmResponse.builder().content(content); + TokenUsage tokenUsage = chatResponse.tokenUsage(); + if (tokenCountEstimator != null) { + try { + int estimatedInput = + tokenCountEstimator.estimateTokenCountInMessages(chatRequest.messages()); + int estimatedOutput = + tokenCountEstimator.estimateTokenCountInText(chatResponse.aiMessage().text()); + int estimatedTotal = estimatedInput + estimatedOutput; + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(estimatedInput) + .candidatesTokenCount(estimatedOutput) + .totalTokenCount(estimatedTotal) + .build()); + } catch (Exception e) { + e.printStackTrace(); // TODO(review): replace with a logger; consider falling back to chatResponse.tokenUsage() instead of emitting no usage metadata + } + } else if (tokenUsage != null) { + builder.usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(tokenUsage.inputTokenCount()) + .candidatesTokenCount(tokenUsage.outputTokenCount()) + .totalTokenCount(tokenUsage.totalTokenCount()) + .build()); + } + + return builder.build(); } private List toParts(AiMessage aiMessage) { diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java index 076bb79a3..8d0c3533a 100644 --- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -26,6 +26,7 @@ import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.TokenCountEstimator; import dev.langchain4j.model.chat.ChatModel; import 
dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.chat.request.ChatRequest; @@ -33,6 +34,7 @@ import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import dev.langchain4j.model.output.TokenUsage; import io.reactivex.rxjava3.core.Flowable; import java.util.ArrayList; import java.util.List; @@ -690,6 +692,141 @@ void testGenerateContentWithStructuredResponseJsonSchema() { } @Test + @DisplayName( + "Should use TokenCountEstimator to estimate token usage when TokenUsage is not available") + void testTokenCountEstimatorFallback() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts) + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, 
false).blockingFirst(); + + // Then + // Verify the response has usage metadata estimated by TokenCountEstimator + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("The weather is sunny today."); + + // IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20 + + // Verify the estimator was actually called + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided") + void testTokenCountEstimatorPriority() { + // Given + // Create a mock TokenCountEstimator + final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class); + when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator + when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator + + // Create LangChain4j with the TokenCountEstimator using Builder + final LangChain4j langChain4jWithEstimator = + LangChain4j.builder() + .chatModel(chatModel) + .modelName(MODEL_NAME) + .tokenCountEstimator(tokenCountEstimator) + .build(); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?")))) + .build(); + + // Mock ChatResponse WITH actual TokenUsage from the LLM + final ChatResponse chatResponse = 
mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("The weather is sunny today."); + final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage + assertThat(response).isNotNull(); + assertThat(response.usageMetadata()).isPresent(); + final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get(); + assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator + assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator + assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50 + + // Verify the estimator was called (it takes priority) + verify(tokenCountEstimator).estimateTokenCountInMessages(any()); + verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today."); + } + + @Test + @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided") + void testNoUsageMetadataWithoutEstimator() { + // Given + // Create LangChain4j WITHOUT TokenCountEstimator (default behavior) + final LangChain4j langChain4jNoEstimator = new LangChain4j(chatModel, MODEL_NAME); + + // Create a LlmRequest + final LlmRequest llmRequest = + LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Hello, world!")))) + .build(); + + // Mock ChatResponse WITHOUT TokenUsage + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello! 
How can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = + langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response does NOT have usage metadata + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?"); + + // IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator + assertThat(response.usageMetadata()).isEmpty(); + } + + @Test @DisplayName("Should handle MCP tools with parametersJsonSchema") void testGenerateContentWithMcpToolParametersJsonSchema() { // Given