Skip to content

Commit 36910b2

Browse files
committed
Add cancellable tasks
1 parent 9a909d2 commit 36910b2

File tree

3 files changed

+167
-10
lines changed

3 files changed

+167
-10
lines changed

durabletask/task.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def set_custom_status(self, custom_status: Any) -> None:
9898
pass
9999

100100
@abstractmethod
101-
def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
101+
def create_timer(self, fire_at: Union[datetime, timedelta]) -> CancellableTask:
102102
"""Create a Timer Task to fire after at the specified deadline.
103103
104104
Parameters
@@ -231,7 +231,7 @@ def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput
231231
# TOOD: Add a timeout parameter, which allows the task to be canceled if the event is
232232
# not received within the specified timeout. This requires support for task cancellation.
233233
@abstractmethod
234-
def wait_for_external_event(self, name: str) -> CompletableTask:
234+
def wait_for_external_event(self, name: str) -> CancellableTask:
235235
"""Wait asynchronously for an event to be raised with the name `name`.
236236
237237
Parameters
@@ -324,6 +324,10 @@ class OrchestrationStateError(Exception):
324324
pass
325325

326326

327+
class TaskCanceledError(Exception):
328+
"""Exception type for canceled orchestration tasks."""
329+
330+
327331
class Task(ABC, Generic[T]):
328332
"""Abstract base class for asynchronous tasks in a durable orchestration."""
329333
_result: T
@@ -435,6 +439,48 @@ def fail(self, message: str, details: Union[Exception, pb.TaskFailureDetails]):
435439
self._parent.on_child_completed(self)
436440

437441

442+
class CancellableTask(CompletableTask[T]):
443+
"""A completable task that can be canceled before it finishes."""
444+
445+
def __init__(self) -> None:
446+
super().__init__()
447+
self._is_cancelled = False
448+
self._cancel_handler: Optional[Callable[[], None]] = None
449+
450+
@property
451+
def is_cancelled(self) -> bool:
452+
"""Returns True if the task was canceled, False otherwise."""
453+
return self._is_cancelled
454+
455+
def get_result(self) -> T:
456+
if self._is_cancelled:
457+
raise TaskCanceledError('The task was canceled.')
458+
return super().get_result()
459+
460+
def set_cancel_handler(self, cancel_handler: Callable[[], None]) -> None:
461+
self._cancel_handler = cancel_handler
462+
463+
def cancel(self) -> bool:
464+
"""Attempts to cancel this task.
465+
466+
Returns
467+
-------
468+
bool
469+
True if cancellation was applied, False if the task had already completed.
470+
"""
471+
if self._is_complete:
472+
return False
473+
474+
if self._cancel_handler is not None:
475+
self._cancel_handler()
476+
477+
self._is_cancelled = True
478+
self._is_complete = True
479+
if self._parent is not None:
480+
self._parent.on_child_completed(self)
481+
return True
482+
483+
438484
class RetryableTask(CompletableTask[T]):
439485
"""A task that can be retried according to a retry policy."""
440486

@@ -474,7 +520,7 @@ def compute_next_delay(self) -> Optional[timedelta]:
474520
return None
475521

476522

477-
class TimerTask(CompletableTask[T]):
523+
class TimerTask(CancellableTask[T]):
478524

479525
def __init__(self) -> None:
480526
super().__init__()

durabletask/worker.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ class TaskHubGrpcWorker:
307307
activity function.
308308
"""
309309

310-
_response_stream: Optional[grpc.Future] = None
310+
_response_stream: Optional[Any] = None
311311
_interceptors: Optional[list[shared.ClientInterceptor]] = None
312312

313313
def __init__(
@@ -512,7 +512,11 @@ def should_invalidate_connection(rpc_error):
512512

513513
def stream_reader():
514514
try:
515-
for work_item in self._response_stream:
515+
response_stream = self._response_stream
516+
if response_stream is None:
517+
return
518+
519+
for work_item in response_stream:
516520
work_item_queue.put(work_item)
517521
except Exception as e:
518522
work_item_queue.put(e)
@@ -843,7 +847,7 @@ def __init__(self, instance_id: str, registry: _Registry):
843847
self._version: Optional[str] = None
844848
self._completion_status: Optional[pb.OrchestrationStatus] = None
845849
self._received_events: dict[str, list[Any]] = {}
846-
self._pending_events: dict[str, list[task.CompletableTask]] = {}
850+
self._pending_events: dict[str, list[task.CancellableTask]] = {}
847851
self._new_input: Optional[Any] = None
848852
self._save_events = False
849853
self._encoded_custom_status: Optional[str] = None
@@ -1026,7 +1030,13 @@ def create_timer_internal(
10261030
action = ph.new_create_timer_action(id, fire_at)
10271031
self._pending_actions[id] = action
10281032

1029-
timer_task: task.TimerTask = task.TimerTask()
1033+
timer_task = task.TimerTask()
1034+
1035+
def _cancel_timer() -> None:
1036+
self._pending_actions.pop(id, None)
1037+
self._pending_tasks.pop(id, None)
1038+
1039+
timer_task.set_cancel_handler(_cancel_timer)
10301040
if retryable_task is not None:
10311041
timer_task.set_retryable_parent(retryable_task)
10321042
self._pending_tasks[id] = timer_task
@@ -1234,13 +1244,13 @@ def _exit_critical_section(self) -> None:
12341244
action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
12351245
self._pending_actions[task_id] = action
12361246

1237-
def wait_for_external_event(self, name: str) -> task.CompletableTask:
1247+
def wait_for_external_event(self, name: str) -> task.CancellableTask:
12381248
# Check to see if this event has already been received, in which case we
12391249
# can return it immediately. Otherwise, record out intent to receive an
12401250
# event with the given name so that we can resume the generator when it
12411251
# arrives. If there are multiple events with the same name, we return
12421252
# them in the order they were received.
1243-
external_event_task: task.CompletableTask = task.CompletableTask()
1253+
external_event_task: task.CancellableTask = task.CancellableTask()
12441254
event_name = name.casefold()
12451255
event_list = self._received_events.get(event_name, None)
12461256
if event_list:
@@ -1254,6 +1264,19 @@ def wait_for_external_event(self, name: str) -> task.CompletableTask:
12541264
task_list = []
12551265
self._pending_events[event_name] = task_list
12561266
task_list.append(external_event_task)
1267+
1268+
def _cancel_wait() -> None:
1269+
waiting_tasks = self._pending_events.get(event_name)
1270+
if waiting_tasks is None:
1271+
return
1272+
try:
1273+
waiting_tasks.remove(external_event_task)
1274+
except ValueError:
1275+
return
1276+
if not waiting_tasks:
1277+
del self._pending_events[event_name]
1278+
1279+
external_event_task.set_cancel_handler(_cancel_wait)
12571280
return external_event_task
12581281

12591282
def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
@@ -1450,6 +1473,13 @@ def process_event(
14501473
f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
14511474
)
14521475
return
1476+
if not isinstance(timer_task, task.TimerTask):
1477+
if not ctx._is_replaying:
1478+
self._logger.warning(
1479+
f"{ctx.instance_id}: Ignoring timerFired event with non-timer task ID = {timer_id}."
1480+
)
1481+
return
1482+
14531483
timer_task.complete(None)
14541484
if timer_task._retryable_parent is not None:
14551485
activity_action = timer_task._retryable_parent._action

tests/durabletask/test_orchestration_executor.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,87 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _):
143143
assert complete_action.result.value == '"done"' # results are JSON-encoded
144144

145145

146+
def test_timer_can_be_cancelled_after_when_any_winner():
147+
"""Tests cancellation of an outstanding timer task after another task wins when_any."""
148+
149+
def orchestrator(ctx: task.OrchestrationContext, _):
150+
approval = ctx.wait_for_external_event("approval")
151+
timeout = ctx.create_timer(timedelta(hours=1))
152+
winner = yield task.when_any([approval, timeout])
153+
if winner == approval:
154+
timeout.cancel()
155+
return "approved"
156+
return "timed out"
157+
158+
registry = worker._Registry()
159+
name = registry.add_orchestrator(orchestrator)
160+
executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
161+
162+
start_time = datetime(2020, 1, 1, 12, 0, 0)
163+
timeout_fire_at = start_time + timedelta(hours=1)
164+
165+
result = executor.execute(
166+
TEST_INSTANCE_ID,
167+
[],
168+
[
169+
helpers.new_orchestrator_started_event(start_time),
170+
helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None),
171+
],
172+
)
173+
assert len(result.actions) == 1
174+
assert result.actions[0].HasField("createTimer")
175+
assert result.actions[0].createTimer.fireAt.ToDatetime() == timeout_fire_at
176+
177+
old_events = [
178+
helpers.new_orchestrator_started_event(start_time),
179+
helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None),
180+
helpers.new_timer_created_event(1, timeout_fire_at),
181+
]
182+
result = executor.execute(
183+
TEST_INSTANCE_ID,
184+
old_events,
185+
[helpers.new_event_raised_event("approval", json.dumps(True))],
186+
)
187+
complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions)
188+
assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED
189+
assert complete_action.result.value == '"approved"'
190+
191+
192+
def test_only_cancellable_tasks_expose_cancel():
193+
"""Tests that only timer and external-event tasks expose cancellation state and operations."""
194+
195+
def dummy_activity(ctx, _):
196+
pass
197+
198+
ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, worker._Registry())
199+
200+
timer_task = ctx.create_timer(timedelta(minutes=5))
201+
external_event_task = ctx.wait_for_external_event("approval")
202+
activity_task = ctx.call_activity(dummy_activity)
203+
204+
assert isinstance(timer_task, task.CancellableTask)
205+
assert isinstance(external_event_task, task.CancellableTask)
206+
assert not isinstance(activity_task, task.CancellableTask)
207+
assert hasattr(timer_task, "cancel")
208+
assert hasattr(external_event_task, "cancel")
209+
assert not hasattr(activity_task, "cancel")
210+
assert hasattr(timer_task, "is_cancelled")
211+
assert hasattr(external_event_task, "is_cancelled")
212+
assert not hasattr(activity_task, "is_cancelled")
213+
214+
215+
def test_cancelled_task_get_result_raises_task_canceled_error():
216+
"""Tests that canceled cancellable tasks raise TaskCanceledError from get_result."""
217+
218+
cancellable_task = task.CancellableTask()
219+
220+
assert cancellable_task.cancel() is True
221+
assert cancellable_task.is_cancelled is True
222+
223+
with pytest.raises(task.TaskCanceledError):
224+
cancellable_task.get_result()
225+
226+
146227
def test_schedule_activity_actions():
147228
"""Test the actions output for the call_activity orchestrator method"""
148229
def dummy_activity(ctx, _):
@@ -1313,7 +1394,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
13131394
encoded_output = json.dumps(dummy_activity(None, "Seattle"))
13141395
old_events = old_events + new_events
13151396
new_events = [helpers.new_task_completed_event(2, encoded_output),
1316-
helpers.new_timer_fired_event(4, current_timestamp)]
1397+
helpers.new_timer_fired_event(4, expected_fire_at)]
13171398
executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
13181399
result = executor.execute(TEST_INSTANCE_ID, old_events, new_events)
13191400
actions = result.actions

0 commit comments

Comments
 (0)