Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions durabletask/internal/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def get_default_host_address() -> str:
return "localhost:4001"


# Default gRPC channel arguments enabling HTTP/2 keepalive pings, so that
# connections silently dropped by proxies/NATs are detected instead of
# hanging forever. Merged with (and overridable by) caller-supplied options
# in get_grpc_channel.
DEFAULT_GRPC_KEEPALIVE_OPTIONS: tuple[tuple[str, int], ...] = (
    ("grpc.keepalive_time_ms", 30_000),  # ping the peer after 30s of inactivity
    ("grpc.keepalive_timeout_ms", 10_000),  # consider the connection dead if no ack within 10s
    ("grpc.http2.max_pings_without_data", 0),  # 0 = allow unlimited pings without payload data
    ("grpc.keepalive_permit_without_calls", 1),  # keep pinging even when no RPCs are active
)


def get_grpc_channel(
host_address: Optional[str],
secure_channel: bool = False,
Expand Down Expand Up @@ -81,10 +89,16 @@ def get_grpc_channel(
host_address = host_address[len(protocol) :]
break

merged = dict(DEFAULT_GRPC_KEEPALIVE_OPTIONS)
if options:
merged.update(dict(options))
merged_options = list(merged.items())
if secure_channel:
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options)
channel = grpc.secure_channel(
host_address, grpc.ssl_channel_credentials(), options=merged_options
)
else:
channel = grpc.insecure_channel(host_address, options=options)
channel = grpc.insecure_channel(host_address, options=merged_options)

# Apply interceptors ONLY if they exist
if interceptors:
Expand Down
158 changes: 83 additions & 75 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def __init__(
self._channel_options = channel_options
self._stop_timeout = stop_timeout
self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup
self._channel_cleanup_threads: list[threading.Thread] = [] # Deferred channel close threads
self._stream_ready = threading.Event()
# Use provided concurrency options or create default ones
self._concurrency_options = (
Expand Down Expand Up @@ -384,15 +385,16 @@ async def _async_run_loop(self):
current_stub = None
current_reader_thread = None
conn_retry_count = 0
conn_max_retry_delay = 60
conn_max_retry_delay = 15

def create_fresh_connection():
nonlocal current_channel, current_stub, conn_retry_count
if current_channel:
try:
current_channel.close()
except Exception:
pass
# Schedule deferred close of old channel to avoid orphaned TCP
# connections. In-flight RPCs on the old stub may still reference
# the channel from another thread, so we wait a grace period
# before closing instead of closing immediately.
if current_channel is not None:
self._schedule_deferred_channel_close(current_channel)
current_channel = None
current_stub = None
try:
Expand All @@ -417,31 +419,20 @@ def create_fresh_connection():

def invalidate_connection():
nonlocal current_channel, current_stub, current_reader_thread
# Cancel the response stream first to signal the reader thread to stop
if self._response_stream is not None:
try:
if hasattr(self._response_stream, "call"):
self._response_stream.call.cancel() # type: ignore
else:
self._response_stream.cancel() # type: ignore
except Exception as e:
self._logger.warning(f"Error cancelling response stream: {e}")
self._response_stream = None

# Wait for the reader thread to finish
if current_reader_thread is not None:
current_reader_thread.join(timeout=1)
current_reader_thread = None

# Close the channel
if current_channel:
try:
current_channel.close()
except Exception:
pass
# Schedule deferred close of old channel to avoid orphaned TCP
# connections. In-flight RPCs (e.g. CompleteActivityTask) may still
# be using the stub on another thread, so we defer the close by a
# grace period instead of closing immediately.
if current_channel is not None:
self._schedule_deferred_channel_close(current_channel)
current_channel = None
self._current_channel = None
current_stub = None
self._response_stream = None

if current_reader_thread is not None:
current_reader_thread.join(timeout=5)
current_reader_thread = None

def should_invalidate_connection(rpc_error):
error_code = rpc_error.code() # type: ignore
Expand All @@ -451,6 +442,7 @@ def should_invalidate_connection(rpc_error):
grpc.StatusCode.CANCELLED,
grpc.StatusCode.UNAUTHENTICATED,
grpc.StatusCode.ABORTED,
grpc.StatusCode.INTERNAL, # RST_STREAM from proxy means connection is dead
}
return error_code in connection_level_errors

Expand Down Expand Up @@ -532,7 +524,11 @@ def stream_reader():
break
# Other RPC errors - put in queue for async loop to handle
self._logger.warning(
f"Stream reader: RPC error (code={rpc_error.code()}): {rpc_error}"
"Stream reader: RPC error (code=%s): %s",
rpc_error.code(),
rpc_error.details()
if hasattr(rpc_error, "details")
else rpc_error,
)
break
except Exception as stream_error:
Expand Down Expand Up @@ -654,32 +650,19 @@ def stream_reader():
if should_invalidate:
invalidate_connection()
error_code = rpc_error.code() # type: ignore
error_details = str(rpc_error)
error_detail = (
rpc_error.details() if hasattr(rpc_error, "details") else str(rpc_error)
)

if error_code == grpc.StatusCode.CANCELLED:
self._logger.info(f"Disconnected from {self._host_address}")
break
elif error_code == grpc.StatusCode.UNAVAILABLE:
# Check if this is a connection timeout scenario
if (
"Timeout occurred" in error_details
or "Failed to connect to remote host" in error_details
):
self._logger.warning(
f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection"
)
else:
self._logger.warning(
f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying"
)
elif should_invalidate:
self._logger.warning(
f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection"
f"Connection error ({error_code}): {error_detail} — resetting connection"
)
else:
self._logger.warning(
f"Application-level gRPC error ({error_code}): {rpc_error}"
)
self._logger.warning(f"gRPC error ({error_code}): {error_detail}")
except RuntimeError as ex:
# RuntimeError often indicates asyncio loop issues (e.g., "cannot schedule new futures after shutdown")
# Check shutdown state first
Expand Down Expand Up @@ -738,22 +721,46 @@ def stream_reader():
except Exception as e:
self._logger.warning(f"Error while waiting for worker task shutdown: {e}")

def _schedule_deferred_channel_close(
    self, old_channel: grpc.Channel, grace_timeout: float = 10.0
):
    """Close *old_channel* on a background thread after a grace period.

    Waiting up to ``grace_timeout`` seconds before closing gives in-flight
    RPCs on other threads (e.g. ``CompleteActivityTask`` calls that still
    hold a stub backed by this channel) a chance to drain, while still
    guaranteeing the old TCP connection is eventually torn down. During
    ``stop()`` the ``_shutdown`` event is already set, so the wait returns
    immediately and the channel is closed at once.
    """
    # Drop cleanup threads that have already completed so the bookkeeping
    # list cannot grow without bound across repeated reconnects.
    live_threads = [t for t in self._channel_cleanup_threads if t.is_alive()]
    self._channel_cleanup_threads = live_threads

    def _close_after_grace():
        try:
            # Either the grace period elapses, or shutdown is signalled
            # and this returns immediately.
            self._shutdown.wait(timeout=grace_timeout)
        finally:
            try:
                old_channel.close()
                self._logger.debug("Deferred channel close completed")
            except Exception as e:
                self._logger.debug(f"Error during deferred channel close: {e}")

    cleanup_thread = threading.Thread(
        target=_close_after_grace, daemon=True, name="ChannelCleanup"
    )
    cleanup_thread.start()
    self._channel_cleanup_threads.append(cleanup_thread)

def stop(self):
"""Stops the worker and waits for any pending work items to complete."""
if not self._is_running:
return

self._logger.info("Stopping gRPC worker...")
if self._response_stream is not None:
try:
if hasattr(self._response_stream, "call"):
self._response_stream.call.cancel() # type: ignore
else:
self._response_stream.cancel() # type: ignore
except Exception as e:
self._logger.warning(f"Error cancelling response stream: {e}")
self._shutdown.set()
# Explicitly close the gRPC channel to ensure OTel interceptors and other resources are cleaned up
# Close the channel — propagates cancellation to all streams and cleans up resources
if self._current_channel is not None:
try:
self._current_channel.close()
Expand All @@ -772,38 +779,39 @@ def stop(self):
else:
self._logger.debug("Worker thread completed successfully")

# Wait for any deferred channel-cleanup threads to finish
for t in self._channel_cleanup_threads:
t.join(timeout=5)
self._channel_cleanup_threads.clear()

self._async_worker_manager.shutdown()
self._logger.info("Worker shutdown completed")
self._is_running = False

# TODO: This should be removed in the future as we do handle grpc errs
def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: str):
"""Handle a gRPC execution error during shutdown or benign condition."""
# During shutdown or if the instance was terminated, the channel may be close
# or the instance may no longer be recognized by the sidecar. Treat these as benign
# to reduce noisy logging when shutting down.
"""Handle a gRPC execution error during shutdown or connection reset."""
details = str(rpc_error).lower()
benign_errors = {
# These errors are transient — the sidecar will re-dispatch the work item.
transient_errors = {
grpc.StatusCode.CANCELLED,
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.UNKNOWN,
grpc.StatusCode.INTERNAL,
}
if (
self._shutdown.is_set()
and rpc_error.code() in benign_errors
or (
"unknown instance id/task id combo" in details
or "channel closed" in details
or "locally cancelled by application" in details
)
):
self._logger.debug(
f"Ignoring gRPC {request_type} execution error during shutdown/benign condition: {rpc_error}"
is_transient = rpc_error.code() in transient_errors
is_benign = (
"unknown instance id/task id combo" in details
or "channel closed" in details
or "locally cancelled by application" in details
)
if is_transient or is_benign or self._shutdown.is_set():
self._logger.warning(
f"Could not deliver {request_type} result ({rpc_error.code()}): "
f"{rpc_error.details() if hasattr(rpc_error, 'details') else rpc_error} — sidecar will re-dispatch"
)
else:
self._logger.exception(
f"Failed to execute gRPC {request_type} execution error: {rpc_error}"
)
self._logger.exception(f"Failed to deliver {request_type} result: {rpc_error}")

def _execute_orchestrator(
self,
Expand Down
Loading
Loading