Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ class TaskHubGrpcWorker:
controlling worker concurrency limits. If None, default settings are used.
stop_timeout (float, optional): Maximum time in seconds to wait for the worker thread
to stop when calling stop(). Defaults to 30.0. Useful to set lower values in tests.
keepalive_interval (float, optional): Interval in seconds between application-level
keepalive Hello RPCs sent to prevent L7 load balancers (e.g. AWS ALB) from closing
idle HTTP/2 connections. Set to 0 or negative to disable. Defaults to 30.0.

Attributes:
concurrency_options (ConcurrencyOptions): The current concurrency configuration.
Expand Down Expand Up @@ -297,6 +300,7 @@ def __init__(
concurrency_options: Optional[ConcurrencyOptions] = None,
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
stop_timeout: float = 30.0,
keepalive_interval: float = 30.0,
):
self._registry = _Registry()
self._host_address = host_address if host_address else shared.get_default_host_address()
Expand All @@ -306,6 +310,7 @@ def __init__(
self._secure_channel = secure_channel
self._channel_options = channel_options
self._stop_timeout = stop_timeout
self._keepalive_interval = keepalive_interval
self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup
self._stream_ready = threading.Event()
# Use provided concurrency options or create default ones
Expand Down Expand Up @@ -368,6 +373,26 @@ def run_loop():
raise RuntimeError("Failed to establish work item stream connection within 10 seconds")
self._is_running = True

async def _keepalive_loop(self, stub):
"""Background keepalive loop to keep the gRPC connection alive through L7 load balancers."""
loop = asyncio.get_running_loop()
while not self._shutdown.is_set():
await asyncio.sleep(self._keepalive_interval)
if self._shutdown.is_set():
return
try:
await loop.run_in_executor(None, lambda: stub.Hello(empty_pb2.Empty(), timeout=10))
except Exception as e:
self._logger.debug(f"keepalive failed: {e}")

@staticmethod
async def _cancel_keepalive(keepalive_task):
"""Cancel and await the keepalive task if it exists."""
if keepalive_task is not None:
keepalive_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await keepalive_task

# TODO: refactor this to be more readable and maintainable.
async def _async_run_loop(self):
"""
Expand Down Expand Up @@ -472,6 +497,7 @@ def should_invalidate_connection(rpc_error):
if self._shutdown.wait(delay):
break
continue
keepalive_task = None
try:
assert current_stub is not None
stub = current_stub
Expand Down Expand Up @@ -584,6 +610,8 @@ def stream_reader():
raise

loop = asyncio.get_running_loop()
if self._keepalive_interval > 0:
keepalive_task = asyncio.ensure_future(self._keepalive_loop(stub))

# NOTE: This is a blocking call that will wait for a work item to become available or the shutdown sentinel
while not self._shutdown.is_set():
Expand Down Expand Up @@ -633,6 +661,7 @@ def stream_reader():
invalidate_connection()
raise e
current_reader_thread.join(timeout=1)
await self._cancel_keepalive(keepalive_task)

if self._shutdown.is_set():
self._logger.info(f"Disconnected from {self._host_address}")
Expand All @@ -646,6 +675,7 @@ def stream_reader():
# Fall through to the top of the outer loop, which will
# create a fresh connection (with retry/backoff if needed)
except grpc.RpcError as rpc_error:
await self._cancel_keepalive(keepalive_task)
# Check shutdown first - if shutting down, exit immediately
if self._shutdown.is_set():
self._logger.debug("Shutdown detected during RPC error handling, exiting")
Expand Down Expand Up @@ -681,6 +711,7 @@ def stream_reader():
f"Application-level gRPC error ({error_code}): {rpc_error}"
)
except RuntimeError as ex:
await self._cancel_keepalive(keepalive_task)
# RuntimeError often indicates asyncio loop issues (e.g., "cannot schedule new futures after shutdown")
# Check shutdown state first
if self._shutdown.is_set():
Expand All @@ -704,6 +735,7 @@ def stream_reader():
# it's likely shutdown-related. Break to prevent infinite retries.
break
except Exception as ex:
await self._cancel_keepalive(keepalive_task)
if self._shutdown.is_set():
self._logger.debug(
f"Shutdown detected during exception handling, exiting: {ex}"
Expand Down
Loading