Skip to content

Commit 5ea48cd

Browse files
committed
ISSUE-777: Ensure token usage metadata included with streaming responses
1 parent aabf15a commit 5ea48cd

2 files changed

Lines changed: 211 additions & 5 deletions

File tree

core/src/main/java/com/google/adk/models/Gemini.java

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
239239
p ->
240240
p.functionCall().isPresent()
241241
|| p.functionResponse().isPresent()
242-
|| p.text().map(t -> !t.isBlank()).orElse(false)))
242+
|| p.text().isPresent()))
243243
.orElse(false));
244244
} else {
245245
logger.debug("Sending generateContent request to model {}", effectiveModelName);
@@ -272,11 +272,17 @@ static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentRespons
272272
if (part.get().thought().orElse(false)) {
273273
accumulatedThoughtText.append(currentTextChunk);
274274
responsesToEmit.add(
275-
thinkingResponseFromText(currentTextChunk).toBuilder().partial(true).build());
275+
thinkingResponseFromText(currentTextChunk).toBuilder()
276+
.usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null))
277+
.partial(true)
278+
.build());
276279
} else {
277280
accumulatedText.append(currentTextChunk);
278281
responsesToEmit.add(
279-
responseFromText(currentTextChunk).toBuilder().partial(true).build());
282+
responseFromText(currentTextChunk).toBuilder()
283+
.usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null))
284+
.partial(true)
285+
.build());
280286
}
281287
} else {
282288
if (accumulatedThoughtText.length() > 0
@@ -316,11 +322,20 @@ static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentRespons
316322
List<LlmResponse> finalResponses = new ArrayList<>();
317323
if (accumulatedThoughtText.length() > 0) {
318324
finalResponses.add(
319-
thinkingResponseFromText(accumulatedThoughtText.toString()));
325+
thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder()
326+
.usageMetadata(
327+
accumulatedText.length() > 0
328+
? null
329+
: finalRawResp.usageMetadata().orElse(null))
330+
.build());
320331
}
321332
if (accumulatedText.length() > 0) {
322-
finalResponses.add(responseFromText(accumulatedText.toString()));
333+
finalResponses.add(
334+
responseFromText(accumulatedText.toString()).toBuilder()
335+
.usageMetadata(finalRawResp.usageMetadata().orElse(null))
336+
.build());
323337
}
338+
324339
return Flowable.fromIterable(finalResponses);
325340
}
326341
return Flowable.empty();

core/src/test/java/com/google/adk/models/GeminiTest.java

Lines changed: 191 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import com.google.genai.types.Content;
2323
import com.google.genai.types.FinishReason;
2424
import com.google.genai.types.GenerateContentResponse;
25+
import com.google.genai.types.GenerateContentResponseUsageMetadata;
2526
import com.google.genai.types.Part;
2627
import io.reactivex.rxjava3.core.Flowable;
2728
import io.reactivex.rxjava3.functions.Predicate;
@@ -123,6 +124,76 @@ public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmp
123124
isEmptyResponse());
124125
}
125126

127+
/**
 * Feeds two raw text chunks, each built with its own usage metadata and no finish reason,
 * through {@code Gemini.processRawResponses} and expects two partial text responses:
 * "Hello" carrying {@code metadata1} and " world" carrying {@code metadata2}.
 *
 * <p>NOTE(review): {@code assertLlmResponses} is not fully visible here; presumably it matches
 * the emitted responses against the predicates in order -- confirm.
 */
@Test
128+
public void processRawResponses_withTextChunks_partialResponsesIncludeUsageMetadata() {
129+
GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15);
130+
GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25);
131+
Flowable<GenerateContentResponse> rawResponses =
132+
Flowable.just(
133+
toResponseWithText("Hello", metadata1), toResponseWithText(" world", metadata2));
134+
135+
Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);
136+
137+
assertLlmResponses(
138+
llmResponses,
139+
isPartialTextResponseWithUsageMetadata("Hello", metadata1),
140+
isPartialTextResponseWithUsageMetadata(" world", metadata2));
141+
}
142+
143+
/**
 * Feeds a text chunk "Hello" followed by a text chunk " world" that carries a STOP finish
 * reason and usage metadata through {@code Gemini.processRawResponses}.
 *
 * <p>Expects three responses: a partial "Hello", a partial " world" carrying the metadata, and
 * a final (non-partial) "Hello world" carrying the same metadata.
 */
@Test
144+
public void processRawResponses_textAndStopReason_finalResponseIncludesUsageMetadata() {
145+
GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30);
146+
Flowable<GenerateContentResponse> rawResponses =
147+
Flowable.just(
148+
toResponseWithText("Hello"),
149+
toResponseWithText(" world", FinishReason.Known.STOP, metadata));
150+
151+
Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);
152+
153+
assertLlmResponses(
154+
llmResponses,
155+
isPartialTextResponse("Hello"),
156+
isPartialTextResponseWithUsageMetadata(" world", metadata),
157+
isFinalTextResponseWithUsageMetadata("Hello world", metadata));
158+
}
159+
160+
/**
 * Feeds two thought chunks ("Thinking", then " deeply" with a STOP finish reason), each with
 * its own usage metadata, through {@code Gemini.processRawResponses}.
 *
 * <p>Expects a partial thought response per chunk carrying that chunk's metadata, followed by
 * a final thought response "Thinking deeply" carrying the metadata of the last chunk.
 */
@Test
161+
public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() {
162+
GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15);
163+
GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25);
164+
Flowable<GenerateContentResponse> rawResponses =
165+
Flowable.just(
166+
toResponseWithThoughtText("Thinking", metadata1),
167+
toResponseWithThoughtText(" deeply", FinishReason.Known.STOP, metadata2));
168+
169+
Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);
170+
171+
assertLlmResponses(
172+
llmResponses,
173+
isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1),
174+
isPartialThoughtResponseWithUsageMetadata(" deeply", metadata2),
175+
isFinalThoughtResponseWithUsageMetadata("Thinking deeply", metadata2));
176+
}
177+
178+
/**
 * Feeds a thought chunk "Thinking" followed by a text chunk "Answer" with a STOP finish
 * reason, each with its own usage metadata, through {@code Gemini.processRawResponses}.
 *
 * <p>Expects four responses: a partial thought with {@code metadata1}, a partial text with
 * {@code metadata2}, a final thought "Thinking" with no usage metadata, and a final text
 * "Answer" with {@code metadata2}. In other words, when both a final thought and a final text
 * are emitted, only the final text response is expected to carry usage metadata.
 */
@Test
179+
public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsageMetadata() {
180+
GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10);
181+
GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30);
182+
Flowable<GenerateContentResponse> rawResponses =
183+
Flowable.just(
184+
toResponseWithThoughtText("Thinking", metadata1),
185+
toResponseWithText("Answer", FinishReason.Known.STOP, metadata2));
186+
187+
Flowable<LlmResponse> llmResponses = Gemini.processRawResponses(rawResponses);
188+
189+
assertLlmResponses(
190+
llmResponses,
191+
isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1),
192+
isPartialTextResponseWithUsageMetadata("Answer", metadata2),
193+
isFinalThoughtResponseWithNoUsageMetadata("Thinking"),
194+
isFinalTextResponseWithUsageMetadata("Answer", metadata2));
195+
}
196+
126197
// Helper methods for assertions
127198

128199
private void assertLlmResponses(
@@ -170,6 +241,67 @@ private static Predicate<LlmResponse> isEmptyResponse() {
170241
};
171242
}
172243

244+
/**
 * Builds a predicate that checks a response is a partial text response with usage metadata.
 *
 * <p>The predicate asserts that {@code partial()} holds {@code true}, that the text of the
 * first part (empty string when absent) equals {@code expectedText}, and that
 * {@code usageMetadata()} holds {@code expectedMetadata}. It then returns {@code true}
 * unconditionally; mismatches are presumably surfaced by the assertions themselves rather
 * than by a {@code false} result.
 *
 * @param expectedText text expected in the first part of the response
 * @param expectedMetadata usage metadata expected on the response
 * @return the predicate described above
 */
private static Predicate<LlmResponse> isPartialTextResponseWithUsageMetadata(
245+
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
246+
return response -> {
247+
assertThat(response.partial()).hasValue(true);
248+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
249+
.isEqualTo(expectedText);
250+
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
251+
return true;
252+
};
253+
}
254+
255+
/**
 * Builds a predicate that checks a response is a partial thought response with usage metadata.
 *
 * <p>The predicate asserts that {@code partial()} holds {@code true}, that the text of the
 * first part (empty string when absent) equals {@code expectedText}, that the first part's
 * {@code thought} flag is true (treated as false when absent), and that
 * {@code usageMetadata()} holds {@code expectedMetadata}. It then returns {@code true}
 * unconditionally; mismatches are presumably surfaced by the assertions themselves.
 *
 * @param expectedText text expected in the first part of the response
 * @param expectedMetadata usage metadata expected on the response
 * @return the predicate described above
 */
private static Predicate<LlmResponse> isPartialThoughtResponseWithUsageMetadata(
256+
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
257+
return response -> {
258+
assertThat(response.partial()).hasValue(true);
259+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
260+
.isEqualTo(expectedText);
261+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false))
262+
.isTrue();
263+
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
264+
return true;
265+
};
266+
}
267+
268+
/**
 * Builds a predicate that checks a response is a final text response with usage metadata.
 *
 * <p>"Final" here means {@code partial()} is empty. The predicate also asserts that the text
 * of the first part (empty string when absent) equals {@code expectedText} and that
 * {@code usageMetadata()} holds {@code expectedMetadata}. It then returns {@code true}
 * unconditionally; mismatches are presumably surfaced by the assertions themselves.
 *
 * @param expectedText text expected in the first part of the response
 * @param expectedMetadata usage metadata expected on the response
 * @return the predicate described above
 */
private static Predicate<LlmResponse> isFinalTextResponseWithUsageMetadata(
269+
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
270+
return response -> {
271+
assertThat(response.partial()).isEmpty();
272+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
273+
.isEqualTo(expectedText);
274+
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
275+
return true;
276+
};
277+
}
278+
279+
/**
 * Builds a predicate that checks a response is a final thought response with usage metadata.
 *
 * <p>"Final" here means {@code partial()} is empty. The predicate also asserts that the text
 * of the first part (empty string when absent) equals {@code expectedText}, that the first
 * part's {@code thought} flag is true (treated as false when absent), and that
 * {@code usageMetadata()} holds {@code expectedMetadata}. It then returns {@code true}
 * unconditionally; mismatches are presumably surfaced by the assertions themselves.
 *
 * @param expectedText text expected in the first part of the response
 * @param expectedMetadata usage metadata expected on the response
 * @return the predicate described above
 */
private static Predicate<LlmResponse> isFinalThoughtResponseWithUsageMetadata(
280+
String expectedText, GenerateContentResponseUsageMetadata expectedMetadata) {
281+
return response -> {
282+
assertThat(response.partial()).isEmpty();
283+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
284+
.isEqualTo(expectedText);
285+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false))
286+
.isTrue();
287+
assertThat(response.usageMetadata()).hasValue(expectedMetadata);
288+
return true;
289+
};
290+
}
291+
292+
/**
 * Builds a predicate that checks a response is a final thought response without usage
 * metadata.
 *
 * <p>The predicate asserts that {@code partial()} is empty, that the text of the first part
 * (empty string when absent) equals {@code expectedText}, that the first part's
 * {@code thought} flag is true (treated as false when absent), and that
 * {@code usageMetadata()} is empty. It then returns {@code true} unconditionally; mismatches
 * are presumably surfaced by the assertions themselves.
 *
 * @param expectedText text expected in the first part of the response
 * @return the predicate described above
 */
private static Predicate<LlmResponse> isFinalThoughtResponseWithNoUsageMetadata(
293+
String expectedText) {
294+
return response -> {
295+
assertThat(response.partial()).isEmpty();
296+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
297+
.isEqualTo(expectedText);
298+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false))
299+
.isTrue();
300+
assertThat(response.usageMetadata()).isEmpty();
301+
return true;
302+
};
303+
}
304+
173305
// Helper methods to create responses for testing
174306

175307
private GenerateContentResponse toResponseWithText(String text) {
@@ -191,4 +323,63 @@ private GenerateContentResponse toResponse(Part part) {
191323
/**
 * Wraps the given candidate in a {@code GenerateContentResponse}; no other field is set.
 *
 * @param candidate the candidate to place in the response
 * @return a response built with that candidate
 */
private GenerateContentResponse toResponse(Candidate candidate) {
192324
return GenerateContentResponse.builder().candidates(candidate).build();
193325
}
326+
327+
/**
 * Builds a raw response with one candidate whose content is a single text part, plus the
 * given usage metadata. No finish reason is set on the candidate.
 *
 * @param text text for the single part
 * @param usageMetadata usage metadata to attach to the response
 * @return the constructed response
 */
private GenerateContentResponse toResponseWithText(
328+
String text, GenerateContentResponseUsageMetadata usageMetadata) {
329+
return GenerateContentResponse.builder()
330+
.candidates(
331+
Candidate.builder()
332+
.content(Content.builder().parts(Part.fromText(text)).build())
333+
.build())
334+
.usageMetadata(usageMetadata)
335+
.build();
336+
}
337+
338+
/**
 * Builds a raw response with one candidate whose content is a single text part and whose
 * finish reason is the given known value, plus the given usage metadata.
 *
 * @param text text for the single part
 * @param finishReason known finish reason, wrapped in a {@code FinishReason} on the candidate
 * @param usageMetadata usage metadata to attach to the response
 * @return the constructed response
 */
private GenerateContentResponse toResponseWithText(
339+
String text,
340+
FinishReason.Known finishReason,
341+
GenerateContentResponseUsageMetadata usageMetadata) {
342+
return GenerateContentResponse.builder()
343+
.candidates(
344+
Candidate.builder()
345+
.content(Content.builder().parts(Part.fromText(text)).build())
346+
.finishReason(new FinishReason(finishReason))
347+
.build())
348+
.usageMetadata(usageMetadata)
349+
.build();
350+
}
351+
352+
/**
 * Builds a raw response with one candidate whose content is a single text part flagged as a
 * thought ({@code thought(true)}), plus the given usage metadata. No finish reason is set.
 *
 * @param text text for the thought part
 * @param usageMetadata usage metadata to attach to the response
 * @return the constructed response
 */
private GenerateContentResponse toResponseWithThoughtText(
353+
String text, GenerateContentResponseUsageMetadata usageMetadata) {
354+
Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build();
355+
return GenerateContentResponse.builder()
356+
.candidates(
357+
Candidate.builder().content(Content.builder().parts(thoughtPart).build()).build())
358+
.usageMetadata(usageMetadata)
359+
.build();
360+
}
361+
362+
/**
 * Builds a raw response with one candidate whose content is a single text part flagged as a
 * thought ({@code thought(true)}) and whose finish reason is the given known value, plus the
 * given usage metadata.
 *
 * @param text text for the thought part
 * @param finishReason known finish reason, wrapped in a {@code FinishReason} on the candidate
 * @param usageMetadata usage metadata to attach to the response
 * @return the constructed response
 */
private GenerateContentResponse toResponseWithThoughtText(
363+
String text,
364+
FinishReason.Known finishReason,
365+
GenerateContentResponseUsageMetadata usageMetadata) {
366+
Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build();
367+
return GenerateContentResponse.builder()
368+
.candidates(
369+
Candidate.builder()
370+
.content(Content.builder().parts(thoughtPart).build())
371+
.finishReason(new FinishReason(finishReason))
372+
.build())
373+
.usageMetadata(usageMetadata)
374+
.build();
375+
}
376+
377+
/**
 * Builds a usage-metadata fixture with the three given token counts.
 *
 * @param promptTokens value passed to {@code promptTokenCount}
 * @param candidateTokens value passed to {@code candidatesTokenCount}
 * @param totalTokens value passed to {@code totalTokenCount}
 * @return the constructed usage metadata
 */
private static GenerateContentResponseUsageMetadata createUsageMetadata(
378+
int promptTokens, int candidateTokens, int totalTokens) {
379+
return GenerateContentResponseUsageMetadata.builder()
380+
.promptTokenCount(promptTokens)
381+
.candidatesTokenCount(candidateTokens)
382+
.totalTokenCount(totalTokens)
383+
.build();
384+
}
194385
}

0 commit comments

Comments
 (0)