-
Notifications
You must be signed in to change notification settings - Fork 300
Updating and cleaning PR #623 #1054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
glaforge
wants to merge
20
commits into
google:main
Choose a base branch
from
glaforge:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+264
−11
Open
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
c4fab4d
Merge pull request #2 from google/main
glaforge 422437a
Merge branch 'google:main' into main
glaforge a0d4a2a
Merge branch 'google:main' into main
glaforge 3821edb
Merge branch 'google:main' into main
glaforge cee5bb3
Merge branch 'google:main' into main
glaforge fd3bf1c
Merge branch 'google:main' into main
glaforge e786cb6
Merge branch 'google:main' into main
glaforge b2ea63e
Merge branch 'google:main' into main
glaforge 327fe5d
Merge branch 'google:main' into main
glaforge b76f4b9
Merge branch 'google:main' into main
glaforge 0e30c6b
Merge branch 'google:main' into main
glaforge 8afb64b
Merge branch 'google:main' into main
glaforge 032d69e
Merge branch 'google:main' into main
glaforge 23b6827
Merge branch 'google:main' into main
glaforge 23965a6
Merge branch 'google:main' into main
glaforge 5e0e2b2
Merge branch 'google:main' into main
glaforge 99d1e07
Merge branch 'google:main' into main
glaforge 8c93428
Merge branch 'google:main' into main
glaforge d21ee82
chore: clean PR #623
glaforge f343ee8
Merge branch 'google:main' into main
glaforge File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
| import com.google.genai.types.FunctionDeclaration; | ||
| import com.google.genai.types.FunctionResponse; | ||
| import com.google.genai.types.GenerateContentConfig; | ||
| import com.google.genai.types.GenerateContentResponseUsageMetadata; | ||
| import com.google.genai.types.Part; | ||
| import com.google.genai.types.Schema; | ||
| import com.google.genai.types.ToolConfig; | ||
|
|
@@ -52,6 +53,7 @@ | |
| import dev.langchain4j.data.pdf.PdfFile; | ||
| import dev.langchain4j.data.video.Video; | ||
| import dev.langchain4j.exception.UnsupportedFeatureException; | ||
| import dev.langchain4j.model.TokenCountEstimator; | ||
| import dev.langchain4j.model.chat.ChatModel; | ||
| import dev.langchain4j.model.chat.StreamingChatModel; | ||
| import dev.langchain4j.model.chat.request.ChatRequest; | ||
|
|
@@ -65,6 +67,7 @@ | |
| import dev.langchain4j.model.chat.request.json.JsonStringSchema; | ||
| import dev.langchain4j.model.chat.response.ChatResponse; | ||
| import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; | ||
| import dev.langchain4j.model.output.TokenUsage; | ||
| import io.reactivex.rxjava3.core.BackpressureStrategy; | ||
| import io.reactivex.rxjava3.core.Flowable; | ||
| import java.util.ArrayList; | ||
|
|
@@ -84,24 +87,109 @@ public class LangChain4j extends BaseLlm { | |
| private final ChatModel chatModel; | ||
| private final StreamingChatModel streamingChatModel; | ||
| private final ObjectMapper objectMapper; | ||
| private final TokenCountEstimator tokenCountEstimator; | ||
|
|
||
| public static Builder builder() { | ||
| return new Builder(); | ||
| } | ||
|
|
||
| public static class Builder { | ||
| private ChatModel chatModel; | ||
| private StreamingChatModel streamingChatModel; | ||
| private String modelName; | ||
| private TokenCountEstimator tokenCountEstimator; | ||
|
|
||
| private Builder() {} | ||
|
|
||
| public Builder chatModel(ChatModel chatModel) { | ||
| this.chatModel = chatModel; | ||
| return this; | ||
| } | ||
|
|
||
| public Builder streamingChatModel(StreamingChatModel streamingChatModel) { | ||
| this.streamingChatModel = streamingChatModel; | ||
| return this; | ||
| } | ||
|
|
||
| public Builder modelName(String modelName) { | ||
| this.modelName = modelName; | ||
| return this; | ||
| } | ||
|
|
||
| public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) { | ||
| this.tokenCountEstimator = tokenCountEstimator; | ||
| return this; | ||
| } | ||
|
|
||
| public LangChain4j build() { | ||
| if (chatModel == null && streamingChatModel == null) { | ||
| throw new IllegalStateException( | ||
| "At least one of chatModel or streamingChatModel must be provided"); | ||
| } | ||
|
|
||
| String effectiveModelName = modelName; | ||
| if (effectiveModelName == null) { | ||
| if (chatModel != null) { | ||
| effectiveModelName = chatModel.defaultRequestParameters().modelName(); | ||
| } else { | ||
| effectiveModelName = streamingChatModel.defaultRequestParameters().modelName(); | ||
| } | ||
| } | ||
|
|
||
| if (effectiveModelName == null) { | ||
| throw new IllegalStateException("Model name cannot be null"); | ||
| } | ||
|
|
||
| return new LangChain4j( | ||
| chatModel, streamingChatModel, effectiveModelName, tokenCountEstimator); | ||
| } | ||
| } | ||
|
|
||
| private LangChain4j( | ||
| ChatModel chatModel, | ||
| StreamingChatModel streamingChatModel, | ||
| String modelName, | ||
| TokenCountEstimator tokenCountEstimator) { | ||
| super(Objects.requireNonNull(modelName, "model name cannot be null")); | ||
| this.chatModel = chatModel; | ||
| this.streamingChatModel = streamingChatModel; | ||
| this.objectMapper = new ObjectMapper(); | ||
| this.tokenCountEstimator = tokenCountEstimator; | ||
| } | ||
|
|
||
| public LangChain4j(ChatModel chatModel) { | ||
| this(chatModel, (TokenCountEstimator) null); | ||
| } | ||
|
|
||
| public LangChain4j(ChatModel chatModel, TokenCountEstimator tokenCountEstimator) { | ||
| super( | ||
| Objects.requireNonNull( | ||
| chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null")); | ||
| this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); | ||
| this.streamingChatModel = null; | ||
| this.objectMapper = new ObjectMapper(); | ||
| this.tokenCountEstimator = tokenCountEstimator; | ||
| } | ||
|
|
||
| public LangChain4j(ChatModel chatModel, String modelName) { | ||
| this(chatModel, modelName, (TokenCountEstimator) null); | ||
| } | ||
|
|
||
| public LangChain4j( | ||
| ChatModel chatModel, String modelName, TokenCountEstimator tokenCountEstimator) { | ||
| super(Objects.requireNonNull(modelName, "chat model name cannot be null")); | ||
| this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); | ||
| this.streamingChatModel = null; | ||
| this.objectMapper = new ObjectMapper(); | ||
| this.tokenCountEstimator = tokenCountEstimator; | ||
| } | ||
|
|
||
| public LangChain4j(StreamingChatModel streamingChatModel) { | ||
| this(streamingChatModel, (TokenCountEstimator) null); | ||
| } | ||
|
|
||
| public LangChain4j( | ||
| StreamingChatModel streamingChatModel, TokenCountEstimator tokenCountEstimator) { | ||
| super( | ||
| Objects.requireNonNull( | ||
| streamingChatModel.defaultRequestParameters().modelName(), | ||
|
|
@@ -110,22 +198,23 @@ public LangChain4j(StreamingChatModel streamingChatModel) { | |
| this.streamingChatModel = | ||
| Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); | ||
| this.objectMapper = new ObjectMapper(); | ||
| this.tokenCountEstimator = tokenCountEstimator; | ||
| } | ||
|
|
||
| public LangChain4j(StreamingChatModel streamingChatModel, String modelName) { | ||
| super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); | ||
| this.chatModel = null; | ||
| this.streamingChatModel = | ||
| Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); | ||
| this.objectMapper = new ObjectMapper(); | ||
| this(streamingChatModel, modelName, (TokenCountEstimator) null); | ||
| } | ||
|
|
||
| public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) { | ||
| super(Objects.requireNonNull(modelName, "model name cannot be null")); | ||
| this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null"); | ||
| public LangChain4j( | ||
| StreamingChatModel streamingChatModel, | ||
| String modelName, | ||
| TokenCountEstimator tokenCountEstimator) { | ||
| super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null")); | ||
| this.chatModel = null; | ||
| this.streamingChatModel = | ||
| Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null"); | ||
| this.objectMapper = new ObjectMapper(); | ||
| this.tokenCountEstimator = tokenCountEstimator; | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -186,7 +275,7 @@ public void onError(Throwable throwable) { | |
|
|
||
| ChatRequest chatRequest = toChatRequest(llmRequest); | ||
| ChatResponse chatResponse = chatModel.chat(chatRequest); | ||
| LlmResponse llmResponse = toLlmResponse(chatResponse); | ||
| LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest); | ||
|
|
||
| return Flowable.just(llmResponse); | ||
| } | ||
|
|
@@ -511,11 +600,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { | |
| } | ||
| } | ||
|
|
||
| private LlmResponse toLlmResponse(ChatResponse chatResponse) { | ||
| private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { | ||
| Content content = | ||
| Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); | ||
|
|
||
| return LlmResponse.builder().content(content).build(); | ||
| LlmResponse.Builder builder = LlmResponse.builder().content(content); | ||
| TokenUsage tokenUsage = chatResponse.tokenUsage(); | ||
| if (tokenCountEstimator != null) { | ||
| try { | ||
| int estimatedInput = | ||
| tokenCountEstimator.estimateTokenCountInMessages(chatRequest.messages()); | ||
| int estimatedOutput = | ||
| tokenCountEstimator.estimateTokenCountInText(chatResponse.aiMessage().text()); | ||
| int estimatedTotal = estimatedInput + estimatedOutput; | ||
| builder.usageMetadata( | ||
| GenerateContentResponseUsageMetadata.builder() | ||
| .promptTokenCount(estimatedInput) | ||
| .candidatesTokenCount(estimatedOutput) | ||
| .totalTokenCount(estimatedTotal) | ||
| .build()); | ||
| } catch (Exception e) { | ||
| e.printStackTrace(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

we should probably use a logger or rethrow the exception higher up |
||
| } | ||
| } else if (tokenUsage != null) { | ||
| builder.usageMetadata( | ||
| GenerateContentResponseUsageMetadata.builder() | ||
| .promptTokenCount(tokenUsage.inputTokenCount()) | ||
| .candidatesTokenCount(tokenUsage.outputTokenCount()) | ||
| .totalTokenCount(tokenUsage.totalTokenCount()) | ||
| .build()); | ||
| } | ||
|
|
||
| return builder.build(); | ||
| } | ||
|
|
||
| private List<Part> toParts(AiMessage aiMessage) { | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about using AutoValue for that?