diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 09645ed..f676332 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -50,6 +50,14 @@ def get_default_host_address() -> str: return "localhost:4001" +DEFAULT_GRPC_KEEPALIVE_OPTIONS: tuple[tuple[str, int], ...] = ( + ("grpc.keepalive_time_ms", 30_000), + ("grpc.keepalive_timeout_ms", 10_000), + ("grpc.http2.max_pings_without_data", 0), + ("grpc.keepalive_permit_without_calls", 1), +) + + def get_grpc_channel( host_address: Optional[str], secure_channel: bool = False, @@ -81,10 +89,16 @@ def get_grpc_channel( host_address = host_address[len(protocol) :] break + merged = dict(DEFAULT_GRPC_KEEPALIVE_OPTIONS) + if options: + merged.update(dict(options)) + merged_options = list(merged.items()) if secure_channel: - channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options) + channel = grpc.secure_channel( + host_address, grpc.ssl_channel_credentials(), options=merged_options + ) else: - channel = grpc.insecure_channel(host_address, options=options) + channel = grpc.insecure_channel(host_address, options=merged_options) # Apply interceptors ONLY if they exist if interceptors: diff --git a/durabletask/worker.py b/durabletask/worker.py index 13f13d8..cc54c6a 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -307,6 +307,7 @@ def __init__( self._channel_options = channel_options self._stop_timeout = stop_timeout self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup + self._channel_cleanup_threads: list[threading.Thread] = [] # Deferred channel close threads self._stream_ready = threading.Event() # Use provided concurrency options or create default ones self._concurrency_options = ( @@ -384,15 +385,16 @@ async def _async_run_loop(self): current_stub = None current_reader_thread = None conn_retry_count = 0 - conn_max_retry_delay = 60 + conn_max_retry_delay = 15 def create_fresh_connection(): nonlocal current_channel, current_stub, conn_retry_count - if current_channel: - try: - current_channel.close() - except Exception: - pass + # Schedule deferred close of old channel to avoid orphaned TCP + # connections. In-flight RPCs on the old stub may still reference + # the channel from another thread, so we wait a grace period + # before closing instead of closing immediately. + if current_channel is not None: + self._schedule_deferred_channel_close(current_channel) current_channel = None current_stub = None try: @@ -417,31 +419,20 @@ def create_fresh_connection(): def invalidate_connection(): nonlocal current_channel, current_stub, current_reader_thread - # Cancel the response stream first to signal the reader thread to stop - if self._response_stream is not None: - try: - if hasattr(self._response_stream, "call"): - self._response_stream.call.cancel() # type: ignore - else: - self._response_stream.cancel() # type: ignore - except Exception as e: - self._logger.warning(f"Error cancelling response stream: {e}") - self._response_stream = None - - # Wait for the reader thread to finish - if current_reader_thread is not None: - current_reader_thread.join(timeout=1) - current_reader_thread = None - - # Close the channel - if current_channel: - try: - current_channel.close() - except Exception: - pass + # Schedule deferred close of old channel to avoid orphaned TCP + # connections. In-flight RPCs (e.g. CompleteActivityTask) may still + # be using the stub on another thread, so we defer the close by a + # grace period instead of closing immediately. + if current_channel is not None: + self._schedule_deferred_channel_close(current_channel) current_channel = None self._current_channel = None current_stub = None + self._response_stream = None + + if current_reader_thread is not None: + current_reader_thread.join(timeout=5) + current_reader_thread = None def should_invalidate_connection(rpc_error): error_code = rpc_error.code() # type: ignore @@ -451,6 +442,7 @@ def should_invalidate_connection(rpc_error): grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.ABORTED, + grpc.StatusCode.INTERNAL, # RST_STREAM from proxy means connection is dead } return error_code in connection_level_errors @@ -532,7 +524,11 @@ def stream_reader(): break # Other RPC errors - put in queue for async loop to handle self._logger.warning( - f"Stream reader: RPC error (code={rpc_error.code()}): {rpc_error}" + "Stream reader: RPC error (code=%s): %s", + rpc_error.code(), + rpc_error.details() + if hasattr(rpc_error, "details") + else rpc_error, ) break except Exception as stream_error: @@ -654,32 +650,19 @@ def stream_reader(): if should_invalidate: invalidate_connection() error_code = rpc_error.code() # type: ignore - error_details = str(rpc_error) + error_detail = ( + rpc_error.details() if hasattr(rpc_error, "details") else str(rpc_error) + ) if error_code == grpc.StatusCode.CANCELLED: self._logger.info(f"Disconnected from {self._host_address}") break - elif error_code == grpc.StatusCode.UNAVAILABLE: - # Check if this is a connection timeout scenario - if ( - "Timeout occurred" in error_details - or "Failed to connect to remote host" in error_details - ): - self._logger.warning( - f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection" - ) - else: - self._logger.warning( - f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying" - ) elif should_invalidate: self._logger.warning( - f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection" + f"Connection error ({error_code}): {error_detail} — resetting connection" ) else: - self._logger.warning( - f"Application-level gRPC error ({error_code}): {rpc_error}" - ) + self._logger.warning(f"gRPC error ({error_code}): {error_detail}") except RuntimeError as ex: # RuntimeError often indicates asyncio loop issues (e.g., "cannot schedule new futures after shutdown") # Check shutdown state first @@ -738,22 +721,46 @@ def stream_reader(): except Exception as e: self._logger.warning(f"Error while waiting for worker task shutdown: {e}") + def _schedule_deferred_channel_close( + self, old_channel: grpc.Channel, grace_timeout: float = 10.0 + ): + """Schedule a deferred close of an old gRPC channel. + + Waits up to *grace_timeout* seconds for in-flight RPCs to complete + before closing the channel. This prevents orphaned TCP connections + while still allowing in-flight work (e.g. ``CompleteActivityTask`` + calls on another thread) to finish gracefully. + + During ``stop()``, ``_shutdown`` is already set so the wait returns + immediately and the channel is closed at once. + """ + # Prune already-finished cleanup threads to avoid unbounded growth + self._channel_cleanup_threads = [t for t in self._channel_cleanup_threads if t.is_alive()] + + def _deferred_close(): + try: + # Normal reconnect: wait grace period for RPCs to drain. + # Shutdown: _shutdown is already set, returns immediately. + self._shutdown.wait(timeout=grace_timeout) + finally: + try: + old_channel.close() + self._logger.debug("Deferred channel close completed") + except Exception as e: + self._logger.debug(f"Error during deferred channel close: {e}") + + thread = threading.Thread(target=_deferred_close, daemon=True, name="ChannelCleanup") + thread.start() + self._channel_cleanup_threads.append(thread) + def stop(self): """Stops the worker and waits for any pending work items to complete.""" if not self._is_running: return self._logger.info("Stopping gRPC worker...") - if self._response_stream is not None: - try: - if hasattr(self._response_stream, "call"): - self._response_stream.call.cancel() # type: ignore - else: - self._response_stream.cancel() # type: ignore - except Exception as e: - self._logger.warning(f"Error cancelling response stream: {e}") self._shutdown.set() - # Explicitly close the gRPC channel to ensure OTel interceptors and other resources are cleaned up + # Close the channel — propagates cancellation to all streams and cleans up resources if self._current_channel is not None: try: self._current_channel.close() @@ -772,38 +779,39 @@ def stop(self): else: self._logger.debug("Worker thread completed successfully") + # Wait for any deferred channel-cleanup threads to finish + for t in self._channel_cleanup_threads: + t.join(timeout=5) + self._channel_cleanup_threads.clear() + self._async_worker_manager.shutdown() self._logger.info("Worker shutdown completed") self._is_running = False # TODO: This should be removed in the future as we do handle grpc errs def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: str): - """Handle a gRPC execution error during shutdown or benign condition.""" - # During shutdown or if the instance was terminated, the channel may be close - # or the instance may no longer be recognized by the sidecar. Treat these as benign - # to reduce noisy logging when shutting down. + """Handle a gRPC execution error during shutdown or connection reset.""" details = str(rpc_error).lower() - benign_errors = { + # These errors are transient — the sidecar will re-dispatch the work item. + transient_errors = { grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN, + grpc.StatusCode.INTERNAL, } - if ( - self._shutdown.is_set() - and rpc_error.code() in benign_errors - or ( - "unknown instance id/task id combo" in details - or "channel closed" in details - or "locally cancelled by application" in details - ) - ): - self._logger.debug( - f"Ignoring gRPC {request_type} execution error during shutdown/benign condition: {rpc_error}" + is_transient = rpc_error.code() in transient_errors + is_benign = ( + "unknown instance id/task id combo" in details + or "channel closed" in details + or "locally cancelled by application" in details + ) + if is_transient or is_benign or self._shutdown.is_set(): + self._logger.warning( + f"Could not deliver {request_type} result ({rpc_error.code()}): " + f"{rpc_error.details() if hasattr(rpc_error, 'details') else rpc_error} — sidecar will re-dispatch" ) else: - self._logger.exception( - f"Failed to execute gRPC {request_type} execution error: {rpc_error}" - ) + self._logger.exception(f"Failed to deliver {request_type} result: {rpc_error}") def _execute_orchestrator( self, diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index c74ba17..5428cc0 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -2,7 +2,13 @@ from durabletask import client from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -from durabletask.internal.shared import get_default_host_address, get_grpc_channel +from durabletask.internal.shared import ( + DEFAULT_GRPC_KEEPALIVE_OPTIONS, + get_default_host_address, + get_grpc_channel, +) + +EXPECTED_DEFAULT_OPTIONS = list(DEFAULT_GRPC_KEEPALIVE_OPTIONS) HOST_ADDRESS = "localhost:50051" METADATA = [("key1", "value1"), ("key2", "value2")] @@ -14,7 +20,7 @@ def test_get_grpc_channel_insecure(): get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_get_grpc_channel_secure(): @@ -26,7 +32,7 @@ def test_get_grpc_channel_secure(): args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS assert args[1] == mock_credentials.return_value - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_get_grpc_channel_default_host_address(): @@ -34,7 +40,7 @@ def test_get_grpc_channel_default_host_address(): get_grpc_channel(None, False, interceptors=INTERCEPTORS) args, kwargs = mock_channel.call_args assert args[0] == get_default_host_address() - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_get_grpc_channel_with_metadata(): @@ -45,7 +51,7 @@ def test_get_grpc_channel_with_metadata(): get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS mock_intercept_channel.assert_called_once() # Capture and check the arguments passed to intercept_channel() @@ -66,61 +72,61 @@ def test_grpc_channel_with_host_name_protocol_stripping(): get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "http://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "HTTP://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "GRPC://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "grpcs://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "https://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "HTTPS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "GRPCS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "" get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_sync_channel_passes_base_options_and_max_lengths(): diff --git a/tests/durabletask/test_worker_stop.py b/tests/durabletask/test_worker_stop.py index c9cb70f..d618744 100644 --- a/tests/durabletask/test_worker_stop.py +++ b/tests/durabletask/test_worker_stop.py @@ -15,35 +15,41 @@ def _make_running_worker(): def test_stop_with_grpc_future(): + """stop() closes the channel, which propagates cancellation to all streams.""" worker = _make_running_worker() - mock_future = MagicMock(spec=grpc.Future) - worker._response_stream = mock_future + mock_channel = MagicMock() + worker._current_channel = mock_channel + worker._response_stream = MagicMock(spec=grpc.Future) worker.stop() - mock_future.cancel.assert_called_once() + mock_channel.close.assert_called_once() def test_stop_with_generator_call(): + """stop() closes the channel even when response stream has a .call attribute.""" worker = _make_running_worker() - mock_call = MagicMock() + mock_channel = MagicMock() + worker._current_channel = mock_channel mock_stream = MagicMock() - mock_stream.call = mock_call + mock_stream.call = MagicMock() worker._response_stream = mock_stream worker.stop() - mock_call.cancel.assert_called_once() + mock_channel.close.assert_called_once() -def test_stop_with_unknown_stream_type(caplog): +def test_stop_with_unknown_stream_type(): + """stop() closes the channel regardless of response stream type.""" worker = _make_running_worker() - # Not a grpc.Future, no 'call' attribute + mock_channel = MagicMock() + worker._current_channel = mock_channel worker._response_stream = object() - with caplog.at_level("WARNING"): - worker.stop() - assert any("Error cancelling response stream: " in m for m in caplog.text.splitlines()) + worker.stop() + mock_channel.close.assert_called_once() def test_stop_with_none_stream(): worker = _make_running_worker() worker._response_stream = None + worker._current_channel = None # Should not raise worker.stop() @@ -55,3 +61,89 @@ def test_stop_when_not_running(): with patch.object(worker._shutdown, "set") as shutdown_set: worker.stop() shutdown_set.assert_not_called() + + +def test_stop_channel_close_handles_exception(caplog): + """stop() handles exceptions from channel.close() gracefully.""" + worker = _make_running_worker() + mock_channel = MagicMock() + mock_channel.close.side_effect = Exception("close failed") + worker._current_channel = mock_channel + # Should not raise + worker.stop() + assert worker._current_channel is None + + +def test_deferred_channel_close_waits_then_closes(): + """_schedule_deferred_channel_close waits grace period, then closes old channel.""" + worker = TaskHubGrpcWorker() + old_channel = MagicMock() + + worker._schedule_deferred_channel_close(old_channel, grace_timeout=0.1) + # Thread should be tracked + assert len(worker._channel_cleanup_threads) == 1 + + # Wait for the grace period to expire and the thread to finish + worker._channel_cleanup_threads[0].join(timeout=2) + old_channel.close.assert_called_once() + + +def test_deferred_channel_close_fires_immediately_on_shutdown(): + """Deferred close returns immediately when shutdown is already set.""" + worker = TaskHubGrpcWorker() + worker._shutdown.set() + old_channel = MagicMock() + + worker._schedule_deferred_channel_close(old_channel, grace_timeout=60) + # Even with a 60s grace, shutdown makes it return immediately + worker._channel_cleanup_threads[0].join(timeout=2) + old_channel.close.assert_called_once() + + +def test_deferred_channel_close_handles_close_exception(): + """Deferred close handles exceptions from channel.close() gracefully.""" + worker = TaskHubGrpcWorker() + worker._shutdown.set() + old_channel = MagicMock() + old_channel.close.side_effect = Exception("already closed") + + # Should not raise + worker._schedule_deferred_channel_close(old_channel, grace_timeout=0) + worker._channel_cleanup_threads[0].join(timeout=2) + old_channel.close.assert_called_once() + + +def test_stop_joins_deferred_cleanup_threads(): + """stop() joins all deferred channel cleanup threads.""" + worker = _make_running_worker() + mock_channel = MagicMock() + worker._current_channel = mock_channel + + # Pre-populate a cleanup thread (simulating a prior reconnection) + old_channel = MagicMock() + worker._schedule_deferred_channel_close(old_channel, grace_timeout=60) + assert len(worker._channel_cleanup_threads) == 1 + + worker.stop() + # stop() sets shutdown, which unblocks deferred close threads + # After stop(), cleanup threads list should be cleared + assert len(worker._channel_cleanup_threads) == 0 + old_channel.close.assert_called_once() + mock_channel.close.assert_called_once() + + +def test_deferred_close_prunes_finished_threads(): + """_schedule_deferred_channel_close prunes already-finished threads.""" + worker = TaskHubGrpcWorker() + worker._shutdown.set() # Make threads complete immediately + + ch1 = MagicMock() + ch2 = MagicMock() + worker._schedule_deferred_channel_close(ch1, grace_timeout=0) + worker._channel_cleanup_threads[0].join(timeout=2) + + # ch1's thread is finished; scheduling ch2 should prune it + worker._schedule_deferred_channel_close(ch2, grace_timeout=0) + worker._channel_cleanup_threads[-1].join(timeout=2) + # Only the still-alive (or just-finished ch2) thread remains; ch1's was pruned + assert len(worker._channel_cleanup_threads) <= 1