@@ -30,6 +30,7 @@
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import com.google.genai.types.ToolConfig;
@@ -52,6 +53,7 @@
import dev.langchain4j.data.pdf.PdfFile;
import dev.langchain4j.data.video.Video;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
@@ -65,6 +67,7 @@
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.TokenUsage;
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import java.util.ArrayList;
@@ -84,24 +87,109 @@ public class LangChain4j extends BaseLlm {
private final ChatModel chatModel;
private final StreamingChatModel streamingChatModel;
private final ObjectMapper objectMapper;
private final TokenCountEstimator tokenCountEstimator;

public static Builder builder() {
return new Builder();
}

public static class Builder {
private ChatModel chatModel;
private StreamingChatModel streamingChatModel;
private String modelName;
private TokenCountEstimator tokenCountEstimator;

private Builder() {}

public Builder chatModel(ChatModel chatModel) {
this.chatModel = chatModel;
return this;
}

public Builder streamingChatModel(StreamingChatModel streamingChatModel) {
this.streamingChatModel = streamingChatModel;
return this;
}

public Builder modelName(String modelName) {
this.modelName = modelName;
return this;
}

public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) {
this.tokenCountEstimator = tokenCountEstimator;
return this;
}

public LangChain4j build() {
if (chatModel == null && streamingChatModel == null) {
throw new IllegalStateException(
"At least one of chatModel or streamingChatModel must be provided");
}

String effectiveModelName = modelName;
if (effectiveModelName == null) {
if (chatModel != null) {
effectiveModelName = chatModel.defaultRequestParameters().modelName();
} else {
effectiveModelName = streamingChatModel.defaultRequestParameters().modelName();
}
}

if (effectiveModelName == null) {
throw new IllegalStateException("Model name cannot be null");
}

return new LangChain4j(
chatModel, streamingChatModel, effectiveModelName, tokenCountEstimator);
}
}
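
// Usage sketch (illustrative; `myChatModel` and `myEstimator` are hypothetical
// stand-ins for any ChatModel / TokenCountEstimator implementations):
//
//   LangChain4j llm =
//       LangChain4j.builder()
//           .chatModel(myChatModel)             // or streamingChatModel(...)
//           .modelName("my-model")              // optional if the model reports its name
//           .tokenCountEstimator(myEstimator)   // optional; enables estimated token usage
//           .build();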

private LangChain4j(
ChatModel chatModel,
StreamingChatModel streamingChatModel,
String modelName,
TokenCountEstimator tokenCountEstimator) {
super(Objects.requireNonNull(modelName, "model name cannot be null"));
this.chatModel = chatModel;
this.streamingChatModel = streamingChatModel;
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

// The (TokenCountEstimator) null casts here and in the overloads below resolve
// the ambiguity with the String modelName overloads.
public LangChain4j(ChatModel chatModel) {
this(chatModel, (TokenCountEstimator) null);
}

public LangChain4j(ChatModel chatModel, TokenCountEstimator tokenCountEstimator) {
super(
Objects.requireNonNull(
chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null"));
this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
this.streamingChatModel = null;
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

public LangChain4j(ChatModel chatModel, String modelName) {
this(chatModel, modelName, (TokenCountEstimator) null);
}

public LangChain4j(
ChatModel chatModel, String modelName, TokenCountEstimator tokenCountEstimator) {
super(Objects.requireNonNull(modelName, "chat model name cannot be null"));
this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
this.streamingChatModel = null;
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

public LangChain4j(StreamingChatModel streamingChatModel) {
this(streamingChatModel, (TokenCountEstimator) null);
}

public LangChain4j(
StreamingChatModel streamingChatModel, TokenCountEstimator tokenCountEstimator) {
super(
Objects.requireNonNull(
streamingChatModel.defaultRequestParameters().modelName(),
"streaming chat model name cannot be null"));
this.chatModel = null;
@@ -110,22 +198,23 @@ public LangChain4j(StreamingChatModel streamingChatModel) {
this.streamingChatModel =
Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
this(streamingChatModel, modelName, (TokenCountEstimator) null);
}

public LangChain4j(
StreamingChatModel streamingChatModel,
String modelName,
TokenCountEstimator tokenCountEstimator) {
super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
this.chatModel = null;
this.streamingChatModel =
Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
this.objectMapper = new ObjectMapper();
this.tokenCountEstimator = tokenCountEstimator;
}

@Override
@@ -186,7 +275,7 @@ public void onError(Throwable throwable) {

ChatRequest chatRequest = toChatRequest(llmRequest);
ChatResponse chatResponse = chatModel.chat(chatRequest);
LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);

return Flowable.just(llmResponse);
}
@@ -511,11 +600,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
}
}

private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) {
Content content =
Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build();

LlmResponse.Builder builder = LlmResponse.builder().content(content);
TokenUsage tokenUsage = chatResponse.tokenUsage();
// A configured estimator takes priority over provider-reported token usage.
if (tokenCountEstimator != null) {
try {
int estimatedInput =
tokenCountEstimator.estimateTokenCountInMessages(chatRequest.messages());
int estimatedOutput =
tokenCountEstimator.estimateTokenCountInText(chatResponse.aiMessage().text());
int estimatedTotal = estimatedInput + estimatedOutput;
builder.usageMetadata(
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(estimatedInput)
.candidatesTokenCount(estimatedOutput)
.totalTokenCount(estimatedTotal)
.build());
} catch (Exception e) {
// Estimation failed (e.g. no message text): report it and fall through
// without usage metadata rather than failing the whole response.
e.printStackTrace();
}
} else if (tokenUsage != null) {
builder.usageMetadata(
GenerateContentResponseUsageMetadata.builder()
.promptTokenCount(tokenUsage.inputTokenCount())
.candidatesTokenCount(tokenUsage.outputTokenCount())
.totalTokenCount(tokenUsage.totalTokenCount())
.build());
}

return builder.build();
}
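
// A minimal TokenCountEstimator sketch, for illustration only (not part of
// this change); it assumes the interface exposes the text/message/messages
// estimation methods used above and applies a crude ~4-chars-per-token
// heuristic. Real integrations would plug in a model-specific estimator.
//
//   class CharBasedEstimator implements TokenCountEstimator {
//     @Override
//     public int estimateTokenCountInText(String text) {
//       return Math.max(1, text.length() / 4);
//     }
//     @Override
//     public int estimateTokenCountInMessage(ChatMessage message) {
//       return estimateTokenCountInText(message.toString());
//     }
//     @Override
//     public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
//       int total = 0;
//       for (ChatMessage message : messages) {
//         total += estimateTokenCountInMessage(message);
//       }
//       return total;
//     }
//   }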

private List<Part> toParts(AiMessage aiMessage) {
@@ -26,13 +26,15 @@
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.TokenUsage;
import io.reactivex.rxjava3.core.Flowable;
import java.util.ArrayList;
import java.util.List;
@@ -690,6 +692,141 @@ void testGenerateContentWithStructuredResponseJsonSchema() {
}

@Test
@DisplayName(
"Should use TokenCountEstimator to estimate token usage when TokenUsage is not available")
void testTokenCountEstimatorFallback() {
// Given
// Create a mock TokenCountEstimator
final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens
when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens

// Create LangChain4j with the TokenCountEstimator using Builder
final LangChain4j langChain4jWithEstimator =
LangChain4j.builder()
.chatModel(chatModel)
.modelName(MODEL_NAME)
.tokenCountEstimator(tokenCountEstimator)
.build();

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
.build();

// Mock ChatResponse WITHOUT TokenUsage (simulating when LLM doesn't provide token counts)
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response has usage metadata estimated by TokenCountEstimator
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("The weather is sunny today.");

// IMPORTANT: Verify that token usage was estimated via the TokenCountEstimator
assertThat(response.usageMetadata()).isPresent();
final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator
assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator
assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20

// Verify the estimator was actually called
verify(tokenCountEstimator).estimateTokenCountInMessages(any());
verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
}

@Test
@DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided")
void testTokenCountEstimatorPriority() {
// Given
// Create a mock TokenCountEstimator
final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator
when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator

// Create LangChain4j with the TokenCountEstimator using Builder
final LangChain4j langChain4jWithEstimator =
LangChain4j.builder()
.chatModel(chatModel)
.modelName(MODEL_NAME)
.tokenCountEstimator(tokenCountEstimator)
.build();

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
.build();

// Mock ChatResponse WITH actual TokenUsage from the LLM
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// IMPORTANT: When TokenCountEstimator is present, it takes priority over TokenUsage
assertThat(response).isNotNull();
assertThat(response.usageMetadata()).isPresent();
final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator
assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator
assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50

// Verify the estimator was called (it takes priority)
verify(tokenCountEstimator).estimateTokenCountInMessages(any());
verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
}

@Test
@DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided")
void testNoUsageMetadataWithoutEstimator() {
// Given
// Create LangChain4j WITHOUT TokenCountEstimator (default behavior)
final LangChain4j langChain4jNoEstimator = new LangChain4j(chatModel, MODEL_NAME);

// Create a LlmRequest
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("Hello, world!"))))
.build();

// Mock ChatResponse WITHOUT TokenUsage
final ChatResponse chatResponse = mock(ChatResponse.class);
final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?");
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from LLM
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response =
langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response does NOT have usage metadata
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?");

// IMPORTANT: usageMetadata should be empty when no TokenUsage and no estimator
assertThat(response.usageMetadata()).isEmpty();
}

@DisplayName("Should handle MCP tools with parametersJsonSchema")
void testGenerateContentWithMcpToolParametersJsonSchema() {
// Given