Skip to content

Commit 21382a6

Browse files
ferponseFerran Pons Serra
authored and committed
feat: add stop_event to cancel SSE streaming mid-generation
Adds an optional asyncio.Event-based cancellation mechanism that allows consumers to stop SSE streaming mid-generation. When stop_event.set() is called, the flow stops yielding new chunks and returns cleanly. This enables "stop generating" buttons in chat UIs: the consumer passes an asyncio.Event to runner.run_async(), and sets it when the user clicks stop.
1 parent 6f0dcb3 commit 21382a6

4 files changed

Lines changed: 247 additions & 0 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
from typing import Any
1819
from typing import cast
1920
from typing import Optional
@@ -178,6 +179,14 @@ class InvocationContext(BaseModel):
178179
179180
Set to True in callbacks or tools to terminate this invocation."""
180181

182+
stop_event: Optional[asyncio.Event] = None
183+
"""An optional event that consumers can set to stop generation mid-stream.
184+
185+
When set (``stop_event.set()``), the SSE streaming flow will stop yielding
186+
new chunks and return cleanly. This is useful for implementing a "stop
187+
generating" button in chat UIs.
188+
"""
189+
181190
live_request_queue: Optional[LiveRequestQueue] = None
182191
"""The queue to receive live requests."""
183192

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,11 @@ async def run_async(
749749
) -> AsyncGenerator[Event, None]:
750750
"""Runs the flow."""
751751
while True:
752+
if (
753+
invocation_context.stop_event
754+
and invocation_context.stop_event.is_set()
755+
):
756+
break
752757
last_event = None
753758
async with Aclosing(self._run_one_step_async(invocation_context)) as agen:
754759
async for event in agen:
@@ -829,6 +834,11 @@ async def _run_one_step_async(
829834
)
830835
) as agen:
831836
async for llm_response in agen:
837+
if (
838+
invocation_context.stop_event
839+
and invocation_context.stop_event.is_set()
840+
):
841+
return
832842
# Postprocess after calling the LLM.
833843
async with Aclosing(
834844
self._postprocess_async(

src/google/adk/runners.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ async def run_async(
509509
new_message: Optional[types.Content] = None,
510510
state_delta: Optional[dict[str, Any]] = None,
511511
run_config: Optional[RunConfig] = None,
512+
stop_event: Optional[asyncio.Event] = None,
512513
) -> AsyncGenerator[Event, None]:
513514
"""Main entry method to run the agent in this runner.
514515
@@ -526,6 +527,9 @@ async def run_async(
526527
new_message: A new message to append to the session.
527528
state_delta: Optional state changes to apply to the session.
528529
run_config: The run config for the agent.
530+
stop_event: An optional ``asyncio.Event`` that, when set, causes the
531+
streaming flow to stop yielding new chunks and return cleanly. This
532+
is useful for implementing a "stop generating" button in chat UIs.
529533
530534
Yields:
531535
The events generated by the agent.
@@ -601,6 +605,9 @@ async def _run_with_trace(
601605
# already final.
602606
return
603607

608+
if stop_event is not None:
609+
invocation_context.stop_event = stop_event
610+
604611
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
605612
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
606613
async for event in agen:
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the stop_event cancellation mechanism in SSE streaming."""
16+
17+
import asyncio
18+
from typing import AsyncGenerator
19+
from typing import override
20+
21+
from google.adk.agents.llm_agent import Agent
22+
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
23+
from google.adk.models.llm_request import LlmRequest
24+
from google.adk.models.llm_response import LlmResponse
25+
from google.genai import types
26+
from pydantic import Field
27+
import pytest
28+
29+
from ... import testing_utils
30+
31+
32+
class BaseLlmFlowForTesting(BaseLlmFlow):
  """Concrete BaseLlmFlow subclass exercised directly by these tests."""
36+
37+
38+
class StreamingMockModel(testing_utils.MockModel):
  """MockModel variant that streams several chunks per LLM call."""

  # One inner list of canned chunks per expected generate_content_async call.
  chunks_per_call: list[list[LlmResponse]] = Field(default_factory=list)
  # Index of the most recent call; starts at -1 so the first call reads [0].
  call_index: int = -1

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Records the request and streams the next batch of canned chunks."""
    self.call_index += 1
    self.requests.append(llm_request)
    batch = self.chunks_per_call[self.call_index]
    for response in batch:
      yield response
52+
53+
54+
@pytest.mark.asyncio
async def test_stop_event_stops_streaming_mid_chunks():
  """Setting stop_event mid-stream should prevent further chunks from being yielded."""
  # Four partial chunks; the stop signal fires after the second one.
  chunks = [
      LlmResponse(
          content=types.Content(
              role='model', parts=[types.Part.from_text(text=text)]
          ),
          partial=True,
      )
      for text in ('Hello', ' world', ' foo', ' bar')
  ]

  mock_model = StreamingMockModel(responses=[], chunks_per_call=[chunks])

  agent = Agent(name='test_agent', model=mock_model)
  stop_event = asyncio.Event()
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )
  invocation_context.stop_event = stop_event

  flow = BaseLlmFlowForTesting()
  received = []
  async for event in flow.run_async(invocation_context):
    received.append(event)
    if len(received) == 2:
      # Simulate the user pressing "stop" right after the second chunk.
      stop_event.set()

  # Only the two chunks seen before the stop signal should have been yielded.
  assert len(received) == 2
103+
104+
105+
@pytest.mark.asyncio
async def test_stop_event_not_set_yields_all_chunks():
  """When stop_event is provided but never set, all chunks should be yielded."""
  chunks = [
      LlmResponse(
          content=types.Content(
              role='model', parts=[types.Part.from_text(text=text)]
          ),
          partial=True,
      )
      for text in ('Hello', ' world')
  ]

  mock_model = StreamingMockModel(responses=[], chunks_per_call=[chunks])

  agent = Agent(name='test_agent', model=mock_model)
  stop_event = asyncio.Event()
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )
  invocation_context.stop_event = stop_event

  flow = BaseLlmFlowForTesting()
  # The event exists but is never set, so streaming runs to completion.
  received = [event async for event in flow.run_async(invocation_context)]

  assert len(received) == 2
138+
139+
140+
@pytest.mark.asyncio
async def test_stop_event_prevents_next_llm_call():
  """Setting stop_event between LLM calls should prevent the next call."""
  # First LLM call yields a function call, which would normally trigger a
  # second LLM call carrying the tool result.
  fc_response = LlmResponse(
      content=types.Content(
          role='model',
          parts=[
              types.Part.from_function_call(name='my_tool', args={'x': '1'})
          ],
      ),
      partial=False,
  )
  # Response for the second LLM call; it must never be reached once the
  # stop signal has been raised.
  text_response = LlmResponse(
      content=types.Content(
          role='model', parts=[types.Part.from_text(text='Result')]
      ),
      partial=False,
      error_code=types.FinishReason.STOP,
  )

  mock_model = testing_utils.MockModel.create(
      responses=[fc_response, text_response]
  )

  def my_tool(x: str) -> str:
    return f'result_{x}'

  agent = Agent(name='test_agent', model=mock_model, tools=[my_tool])
  stop_event = asyncio.Event()
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )
  invocation_context.stop_event = stop_event

  flow = BaseLlmFlowForTesting()
  collected = []
  async for event in flow.run_async(invocation_context):
    collected.append(event)
    if event.get_function_calls():
      # Raise the stop signal as soon as the function-call event appears.
      stop_event.set()

  # Text from the second (cancelled) LLM call must not appear in any event.
  texts = [
      e.content.parts[0].text
      for e in collected
      if e.content and e.content.parts and e.content.parts[0].text
  ]
  assert 'Result' not in texts
194+
195+
196+
@pytest.mark.asyncio
async def test_no_stop_event_works_normally():
  """When no stop_event is provided, everything works as before."""
  final_response = LlmResponse(
      content=types.Content(
          role='model', parts=[types.Part.from_text(text='Done')]
      ),
      partial=False,
      error_code=types.FinishReason.STOP,
  )

  mock_model = testing_utils.MockModel.create(responses=[final_response])

  agent = Agent(name='test_agent', model=mock_model)
  # stop_event is intentionally left at its default of None.
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )

  flow = BaseLlmFlowForTesting()
  results = [event async for event in flow.run_async(invocation_context)]

  assert len(results) == 1
  assert results[0].content.parts[0].text == 'Done'

0 commit comments

Comments
 (0)