From 46f51d1559fb13d0d48a4f4502e52ae547bd84d9 Mon Sep 17 00:00:00 2001 From: Casper Nielsen Date: Wed, 11 Mar 2026 17:44:13 +0100 Subject: [PATCH 1/6] fix(bug): properly handle RST_STREAM responses by opening new stream Signed-off-by: Casper Nielsen --- durabletask/internal/shared.py | 24 ++++++- durabletask/worker.py | 118 +++++++++++++-------------------- tests/test_shared.py | 37 +++++++++++ 3 files changed, 105 insertions(+), 74 deletions(-) create mode 100644 tests/test_shared.py diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 09645ed..b73f95f 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -50,6 +50,25 @@ def get_default_host_address() -> str: return "localhost:4001" +DEFAULT_GRPC_KEEPALIVE_OPTIONS: tuple[tuple[str, int], ...] = ( + ("grpc.keepalive_time_ms", 30_000), + ("grpc.keepalive_timeout_ms", 10_000), + ("grpc.http2.max_pings_without_data", 0), + ("grpc.keepalive_permit_without_calls", 1), +) + + +def _merge_grpc_options( + user_options: Optional[Sequence[tuple[str, Any]]], + defaults: Sequence[tuple[str, Any]] = DEFAULT_GRPC_KEEPALIVE_OPTIONS, +) -> list[tuple[str, Any]]: + """Merge user gRPC options with defaults. User options take precedence.""" + merged = dict(defaults) + if user_options: + merged.update(dict(user_options)) + return list(merged.items()) + + def get_grpc_channel( host_address: Optional[str], secure_channel: bool = False, @@ -81,10 +100,11 @@ def get_grpc_channel( host_address = host_address[len(protocol) :] break + merged_options = _merge_grpc_options(options) if secure_channel: - channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options) + channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=merged_options) else: - channel = grpc.insecure_channel(host_address, options=options) + channel = grpc.insecure_channel(host_address, options=merged_options) # Apply interceptors ONLY if they exist if interceptors: diff --git a/durabletask/worker.py b/durabletask/worker.py index 13f13d8..e549237 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -384,15 +384,13 @@ async def _async_run_loop(self): current_stub = None current_reader_thread = None conn_retry_count = 0 - conn_max_retry_delay = 60 + conn_max_retry_delay = 15 def create_fresh_connection(): nonlocal current_channel, current_stub, conn_retry_count - if current_channel: - try: - current_channel.close() - except Exception: - pass + # Don't call channel.close() — in-flight activity RPCs on the + # old stub may still reference the channel from another thread. + # The old channel is GC'd once all references are released. current_channel = None current_stub = None try: @@ -409,7 +407,8 @@ def create_fresh_connection(): conn_retry_count = 0 self._logger.info(f"Created fresh connection to {self._host_address}") except Exception as e: - self._logger.warning(f"Failed to create connection: {e}") + detail = getattr(e, "details", lambda: str(e))() + self._logger.warning("Failed to create connection: %s", detail) current_channel = None self._current_channel = None current_stub = None @@ -417,31 +416,20 @@ def create_fresh_connection(): def invalidate_connection(): nonlocal current_channel, current_stub, current_reader_thread - # Cancel the response stream first to signal the reader thread to stop - if self._response_stream is not None: - try: - if hasattr(self._response_stream, "call"): - self._response_stream.call.cancel() # type: ignore - else: - self._response_stream.cancel() # type: ignore - except Exception as e: - self._logger.warning(f"Error cancelling response stream: {e}") - self._response_stream = None - - # Wait for the reader thread to finish - if current_reader_thread is not None: - current_reader_thread.join(timeout=1) - current_reader_thread = None - - # Close the channel - if current_channel: - try: - current_channel.close() - except Exception: - pass + # Null out references so the next iteration creates a fresh connection. + # Do NOT call channel.close() here — in-flight activity RPCs + # (CompleteActivityTask) may still be using the stub on another + # thread. Closing the channel concurrently causes segfaults in the + # gRPC C extension. The old channel is GC'd once all references + # (including captured stub refs in activity threads) are released. current_channel = None self._current_channel = None current_stub = None + self._response_stream = None + + if current_reader_thread is not None: + current_reader_thread.join(timeout=5) + current_reader_thread = None def should_invalidate_connection(rpc_error): error_code = rpc_error.code() # type: ignore @@ -451,6 +439,7 @@ def should_invalidate_connection(rpc_error): grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.ABORTED, + grpc.StatusCode.INTERNAL, # RST_STREAM from proxy means connection is dead } return error_code in connection_level_errors @@ -532,7 +521,9 @@ def stream_reader(): break # Other RPC errors - put in queue for async loop to handle self._logger.warning( - f"Stream reader: RPC error (code={rpc_error.code()}): {rpc_error}" + "Stream reader: RPC error (code=%s): %s", + rpc_error.code(), + rpc_error.details() if hasattr(rpc_error, "details") else rpc_error, ) break except Exception as stream_error: @@ -654,31 +645,22 @@ def stream_reader(): if should_invalidate: invalidate_connection() error_code = rpc_error.code() # type: ignore - error_details = str(rpc_error) + error_detail = rpc_error.details() if hasattr(rpc_error, "details") else str(rpc_error) if error_code == grpc.StatusCode.CANCELLED: self._logger.info(f"Disconnected from {self._host_address}") break - elif error_code == grpc.StatusCode.UNAVAILABLE: - # Check if this is a connection timeout scenario - if ( - "Timeout occurred" in error_details - or "Failed to connect to remote host" in error_details - ): - self._logger.warning( - f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection" - ) - else: - self._logger.warning( - f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying" - ) elif should_invalidate: self._logger.warning( - f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection" + "Connection error (%s): %s — resetting connection", + error_code, + error_detail, ) else: self._logger.warning( - f"Application-level gRPC error ({error_code}): {rpc_error}" + "gRPC error (%s): %s", + error_code, + error_detail, ) except RuntimeError as ex: # RuntimeError often indicates asyncio loop issues (e.g., "cannot schedule new futures after shutdown") @@ -744,16 +726,8 @@ def stop(self): return self._logger.info("Stopping gRPC worker...") - if self._response_stream is not None: - try: - if hasattr(self._response_stream, "call"): - self._response_stream.call.cancel() # type: ignore - else: - self._response_stream.cancel() # type: ignore - except Exception as e: - self._logger.warning(f"Error cancelling response stream: {e}") self._shutdown.set() - # Explicitly close the gRPC channel to ensure OTel interceptors and other resources are cleaned up + # Close the channel — propagates cancellation to all streams and cleans up resources if self._current_channel is not None: try: self._current_channel.close() @@ -778,31 +752,31 @@ def stop(self): # TODO: This should be removed in the future as we do handle grpc errs def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: str): - """Handle a gRPC execution error during shutdown or benign condition.""" - # During shutdown or if the instance was terminated, the channel may be close - # or the instance may no longer be recognized by the sidecar. Treat these as benign - # to reduce noisy logging when shutting down. + """Handle a gRPC execution error during shutdown or connection reset.""" details = str(rpc_error).lower() - benign_errors = { + # These errors are transient — the sidecar will re-dispatch the work item. + transient_errors = { grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN, + grpc.StatusCode.INTERNAL, } - if ( - self._shutdown.is_set() - and rpc_error.code() in benign_errors - or ( - "unknown instance id/task id combo" in details - or "channel closed" in details - or "locally cancelled by application" in details - ) - ): - self._logger.debug( - f"Ignoring gRPC {request_type} execution error during shutdown/benign condition: {rpc_error}" + is_transient = rpc_error.code() in transient_errors + is_benign = ( + "unknown instance id/task id combo" in details + or "channel closed" in details + or "locally cancelled by application" in details + ) + if is_transient or is_benign or self._shutdown.is_set(): + self._logger.warning( + "Could not deliver %s result (%s): %s — sidecar will re-dispatch", + request_type, + rpc_error.code(), + rpc_error.details() if hasattr(rpc_error, "details") else rpc_error, ) else: self._logger.exception( - f"Failed to execute gRPC {request_type} execution error: {rpc_error}" + f"Failed to deliver {request_type} result: {rpc_error}" ) def _execute_orchestrator( diff --git a/tests/test_shared.py b/tests/test_shared.py new file mode 100644 index 0000000..73bc9f6 --- /dev/null +++ b/tests/test_shared.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from durabletask.internal.shared import ( + DEFAULT_GRPC_KEEPALIVE_OPTIONS, + _merge_grpc_options, +) + + +class TestMergeGrpcOptions: + def test_user_options_take_precedence(self): + """User-supplied options override defaults with the same key.""" + user_options = [ + ("grpc.keepalive_time_ms", 60_000), + ("grpc.custom_option", 42), + ] + result = _merge_grpc_options(user_options) + result_dict = dict(result) + + # User override should win + assert result_dict["grpc.keepalive_time_ms"] == 60_000 + # User-only option should be present + assert result_dict["grpc.custom_option"] == 42 + # Non-overridden defaults should still be present + assert result_dict["grpc.keepalive_timeout_ms"] == 10_000 + assert result_dict["grpc.http2.max_pings_without_data"] == 0 + assert result_dict["grpc.keepalive_permit_without_calls"] == 1 + + def test_defaults_used_when_no_user_options(self): + """When user passes an empty sequence, all defaults are returned.""" + result = _merge_grpc_options([]) + assert result == list(DEFAULT_GRPC_KEEPALIVE_OPTIONS) + + def test_none_user_options(self): + """When user passes None, all defaults are returned.""" + result = _merge_grpc_options(None) + assert result == list(DEFAULT_GRPC_KEEPALIVE_OPTIONS) From d32dc65d435b671eac2f8fed8f0b883d83f3fc28 Mon Sep 17 00:00:00 2001 From: Casper Nielsen Date: Wed, 11 Mar 2026 17:50:03 +0100 Subject: [PATCH 2/6] chore(format): ruff Signed-off-by: Casper Nielsen --- durabletask/internal/shared.py | 4 +++- durabletask/worker.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index b73f95f..23661ba 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -102,7 +102,9 @@ def get_grpc_channel( merged_options = _merge_grpc_options(options) if secure_channel: - channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=merged_options) + channel = grpc.secure_channel( + host_address, grpc.ssl_channel_credentials(), options=merged_options + ) else: channel = grpc.insecure_channel(host_address, options=merged_options) diff --git a/durabletask/worker.py b/durabletask/worker.py index e549237..487d80a 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -407,8 +407,7 @@ def create_fresh_connection(): conn_retry_count = 0 self._logger.info(f"Created fresh connection to {self._host_address}") except Exception as e: - detail = getattr(e, "details", lambda: str(e))() - self._logger.warning("Failed to create connection: %s", detail) + self._logger.warning(f"Failed to create connection: {e}") current_channel = None self._current_channel = None current_stub = None @@ -523,7 +522,9 @@ def stream_reader(): self._logger.warning( "Stream reader: RPC error (code=%s): %s", rpc_error.code(), - rpc_error.details() if hasattr(rpc_error, "details") else rpc_error, + rpc_error.details() + if hasattr(rpc_error, "details") + else rpc_error, ) break except Exception as stream_error: @@ -645,7 +646,9 @@ def stream_reader(): if should_invalidate: invalidate_connection() error_code = rpc_error.code() # type: ignore - error_detail = rpc_error.details() if hasattr(rpc_error, "details") else str(rpc_error) + error_detail = ( + rpc_error.details() if hasattr(rpc_error, "details") else str(rpc_error) + ) if error_code == grpc.StatusCode.CANCELLED: self._logger.info(f"Disconnected from {self._host_address}") @@ -775,9 +778,7 @@ def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: s rpc_error.details() if hasattr(rpc_error, "details") else rpc_error, ) else: - self._logger.exception( - f"Failed to deliver {request_type} result: {rpc_error}" - ) + self._logger.exception(f"Failed to deliver {request_type} result: {rpc_error}") def _execute_orchestrator( self, From c21bebc3ff9f741e4c8abae2fa19d54348e63ab0 Mon Sep 17 00:00:00 2001 From: Casper Nielsen Date: Wed, 11 Mar 2026 18:00:36 +0100 Subject: [PATCH 3/6] fix: correct license header & move file Signed-off-by: Casper Nielsen --- tests/{ => durabletask}/test_shared.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) rename tests/{ => durabletask}/test_shared.py (70%) diff --git a/tests/test_shared.py b/tests/durabletask/test_shared.py similarity index 70% rename from tests/test_shared.py rename to tests/durabletask/test_shared.py index 73bc9f6..eecd84e 100644 --- a/tests/test_shared.py +++ b/tests/durabletask/test_shared.py @@ -1,5 +1,16 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + from durabletask.internal.shared import ( DEFAULT_GRPC_KEEPALIVE_OPTIONS, From 29818bb95c2b3fe5e6fdd8faa909fdb624e5604e Mon Sep 17 00:00:00 2001 From: Casper Nielsen Date: Wed, 11 Mar 2026 18:01:21 +0100 Subject: [PATCH 4/6] chore(format): f-strings Signed-off-by: Casper Nielsen --- durabletask/worker.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index 487d80a..d0f5654 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -655,16 +655,10 @@ def stream_reader(): break elif should_invalidate: self._logger.warning( - "Connection error (%s): %s — resetting connection", - error_code, - error_detail, + f"Connection error ({error_code}): {error_detail} — resetting connection" ) else: - self._logger.warning( - "gRPC error (%s): %s", - error_code, - error_detail, - ) + self._logger.warning(f"gRPC error ({error_code}): {error_detail}") except RuntimeError as ex: # RuntimeError often indicates asyncio loop issues (e.g., "cannot schedule new futures after shutdown") # Check shutdown state first @@ -772,10 +766,8 @@ def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: s ) if is_transient or is_benign or self._shutdown.is_set(): self._logger.warning( - "Could not deliver %s result (%s): %s — sidecar will re-dispatch", - request_type, - rpc_error.code(), - rpc_error.details() if hasattr(rpc_error, "details") else rpc_error, + f"Could not deliver {request_type} result ({rpc_error.code()}): " + f"{rpc_error.details() if hasattr(rpc_error, 'details') else rpc_error} — sidecar will re-dispatch" ) else: self._logger.exception(f"Failed to deliver {request_type} result: {rpc_error}") From bdb1a444dec14717225231bd750c8c7fbd6acd41 Mon Sep 17 00:00:00 2001 From: Casper Nielsen Date: Wed, 11 Mar 2026 18:01:44 +0100 Subject: [PATCH 5/6] chore(format): ruff Signed-off-by: Casper Nielsen --- tests/durabletask/test_shared.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/durabletask/test_shared.py b/tests/durabletask/test_shared.py index eecd84e..4761040 100644 --- a/tests/durabletask/test_shared.py +++ b/tests/durabletask/test_shared.py @@ -11,7 +11,6 @@ limitations under the License. """ - from durabletask.internal.shared import ( DEFAULT_GRPC_KEEPALIVE_OPTIONS, _merge_grpc_options, From ac301e61e93786feaf2d901de0c63b8fd5bf092e Mon Sep 17 00:00:00 2001 From: Casper Nielsen Date: Wed, 11 Mar 2026 18:33:18 +0100 Subject: [PATCH 6/6] fix: address pr comments Signed-off-by: Casper Nielsen --- durabletask/internal/shared.py | 16 +--- durabletask/worker.py | 59 +++++++++++-- tests/durabletask/test_client.py | 36 ++++---- tests/durabletask/test_shared.py | 47 ----------- tests/durabletask/test_worker_stop.py | 114 +++++++++++++++++++++++--- 5 files changed, 178 insertions(+), 94 deletions(-) delete mode 100644 tests/durabletask/test_shared.py diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 23661ba..f676332 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -58,17 +58,6 @@ def get_default_host_address() -> str: ) -def _merge_grpc_options( - user_options: Optional[Sequence[tuple[str, Any]]], - defaults: Sequence[tuple[str, Any]] = DEFAULT_GRPC_KEEPALIVE_OPTIONS, -) -> list[tuple[str, Any]]: - """Merge user gRPC options with defaults. User options take precedence.""" - merged = dict(defaults) - if user_options: - merged.update(dict(user_options)) - return list(merged.items()) - - def get_grpc_channel( host_address: Optional[str], secure_channel: bool = False, @@ -100,7 +89,10 @@ def get_grpc_channel( host_address = host_address[len(protocol) :] break - merged_options = _merge_grpc_options(options) + merged = dict(DEFAULT_GRPC_KEEPALIVE_OPTIONS) + if options: + merged.update(dict(options)) + merged_options = list(merged.items()) if secure_channel: channel = grpc.secure_channel( host_address, grpc.ssl_channel_credentials(), options=merged_options diff --git a/durabletask/worker.py b/durabletask/worker.py index d0f5654..cc54c6a 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -307,6 +307,7 @@ def __init__( self._channel_options = channel_options self._stop_timeout = stop_timeout self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup + self._channel_cleanup_threads: list[threading.Thread] = [] # Deferred channel close threads self._stream_ready = threading.Event() # Use provided concurrency options or create default ones self._concurrency_options = ( @@ -388,9 +389,12 @@ async def _async_run_loop(self): def create_fresh_connection(): nonlocal current_channel, current_stub, conn_retry_count - # Don't call channel.close() — in-flight activity RPCs on the - # old stub may still reference the channel from another thread. - # The old channel is GC'd once all references are released. + # Schedule deferred close of old channel to avoid orphaned TCP + # connections. In-flight RPCs on the old stub may still reference + # the channel from another thread, so we wait a grace period + # before closing instead of closing immediately. + if current_channel is not None: + self._schedule_deferred_channel_close(current_channel) current_channel = None current_stub = None try: @@ -415,12 +419,12 @@ def create_fresh_connection(): def invalidate_connection(): nonlocal current_channel, current_stub, current_reader_thread - # Null out references so the next iteration creates a fresh connection. - # Do NOT call channel.close() here — in-flight activity RPCs - # (CompleteActivityTask) may still be using the stub on another - # thread. Closing the channel concurrently causes segfaults in the - # gRPC C extension. The old channel is GC'd once all references - # (including captured stub refs in activity threads) are released. + # Schedule deferred close of old channel to avoid orphaned TCP + # connections. In-flight RPCs (e.g. CompleteActivityTask) may still + # be using the stub on another thread, so we defer the close by a + # grace period instead of closing immediately. + if current_channel is not None: + self._schedule_deferred_channel_close(current_channel) current_channel = None self._current_channel = None current_stub = None @@ -717,6 +721,38 @@ def stream_reader(): except Exception as e: self._logger.warning(f"Error while waiting for worker task shutdown: {e}") + def _schedule_deferred_channel_close( + self, old_channel: grpc.Channel, grace_timeout: float = 10.0 + ): + """Schedule a deferred close of an old gRPC channel. + + Waits up to *grace_timeout* seconds for in-flight RPCs to complete + before closing the channel. This prevents orphaned TCP connections + while still allowing in-flight work (e.g. ``CompleteActivityTask`` + calls on another thread) to finish gracefully. + + During ``stop()``, ``_shutdown`` is already set so the wait returns + immediately and the channel is closed at once. + """ + # Prune already-finished cleanup threads to avoid unbounded growth + self._channel_cleanup_threads = [t for t in self._channel_cleanup_threads if t.is_alive()] + + def _deferred_close(): + try: + # Normal reconnect: wait grace period for RPCs to drain. + # Shutdown: _shutdown is already set, returns immediately. + self._shutdown.wait(timeout=grace_timeout) + finally: + try: + old_channel.close() + self._logger.debug("Deferred channel close completed") + except Exception as e: + self._logger.debug(f"Error during deferred channel close: {e}") + + thread = threading.Thread(target=_deferred_close, daemon=True, name="ChannelCleanup") + thread.start() + self._channel_cleanup_threads.append(thread) + def stop(self): """Stops the worker and waits for any pending work items to complete.""" if not self._is_running: @@ -743,6 +779,11 @@ def stop(self): else: self._logger.debug("Worker thread completed successfully") + # Wait for any deferred channel-cleanup threads to finish + for t in self._channel_cleanup_threads: + t.join(timeout=5) + self._channel_cleanup_threads.clear() + self._async_worker_manager.shutdown() self._logger.info("Worker shutdown completed") self._is_running = False diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index c74ba17..5428cc0 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -2,7 +2,13 @@ from durabletask import client from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -from durabletask.internal.shared import get_default_host_address, get_grpc_channel +from durabletask.internal.shared import ( + DEFAULT_GRPC_KEEPALIVE_OPTIONS, + get_default_host_address, + get_grpc_channel, +) + +EXPECTED_DEFAULT_OPTIONS = list(DEFAULT_GRPC_KEEPALIVE_OPTIONS) HOST_ADDRESS = "localhost:50051" METADATA = [("key1", "value1"), ("key2", "value2")] @@ -14,7 +20,7 @@ def test_get_grpc_channel_insecure(): get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_get_grpc_channel_secure(): @@ -26,7 +32,7 @@ def test_get_grpc_channel_secure(): args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS assert args[1] == mock_credentials.return_value - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_get_grpc_channel_default_host_address(): @@ -34,7 +40,7 @@ def test_get_grpc_channel_default_host_address(): get_grpc_channel(None, False, interceptors=INTERCEPTORS) args, kwargs = mock_channel.call_args assert args[0] == get_default_host_address() - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_get_grpc_channel_with_metadata(): @@ -45,7 +51,7 @@ def test_get_grpc_channel_with_metadata(): get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) args, kwargs = mock_channel.call_args assert args[0] == HOST_ADDRESS - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS mock_intercept_channel.assert_called_once() # Capture and check the arguments passed to intercept_channel() @@ -66,61 +72,61 @@ def test_grpc_channel_with_host_name_protocol_stripping(): get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "http://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "HTTP://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "GRPC://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_insecure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "grpcs://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "https://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "HTTPS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "GRPCS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS prefix = "" get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) args, kwargs = mock_secure_channel.call_args assert args[0] == host_name - assert "options" in kwargs and kwargs["options"] is None + assert "options" in kwargs and kwargs["options"] == EXPECTED_DEFAULT_OPTIONS def test_sync_channel_passes_base_options_and_max_lengths(): diff --git a/tests/durabletask/test_shared.py b/tests/durabletask/test_shared.py deleted file mode 100644 index 4761040..0000000 --- a/tests/durabletask/test_shared.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Copyright 2025 The Dapr Authors -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -from durabletask.internal.shared import ( - DEFAULT_GRPC_KEEPALIVE_OPTIONS, - _merge_grpc_options, -) - - -class TestMergeGrpcOptions: - def test_user_options_take_precedence(self): - """User-supplied options override defaults with the same key.""" - user_options = [ - ("grpc.keepalive_time_ms", 60_000), - ("grpc.custom_option", 42), - ] - result = _merge_grpc_options(user_options) - result_dict = dict(result) - - # User override should win - assert result_dict["grpc.keepalive_time_ms"] == 60_000 - # User-only option should be present - assert result_dict["grpc.custom_option"] == 42 - # Non-overridden defaults should still be present - assert result_dict["grpc.keepalive_timeout_ms"] == 10_000 - assert result_dict["grpc.http2.max_pings_without_data"] == 0 - assert result_dict["grpc.keepalive_permit_without_calls"] == 1 - - def test_defaults_used_when_no_user_options(self): - """When user passes an empty sequence, all defaults are returned.""" - result = _merge_grpc_options([]) - assert result == list(DEFAULT_GRPC_KEEPALIVE_OPTIONS) - - def test_none_user_options(self): - """When user passes None, all defaults are returned.""" - result = _merge_grpc_options(None) - assert result == list(DEFAULT_GRPC_KEEPALIVE_OPTIONS) diff --git a/tests/durabletask/test_worker_stop.py b/tests/durabletask/test_worker_stop.py index c9cb70f..d618744 100644 --- a/tests/durabletask/test_worker_stop.py +++ b/tests/durabletask/test_worker_stop.py @@ -15,35 +15,41 @@ def _make_running_worker(): def test_stop_with_grpc_future(): + """stop() closes the channel, which propagates cancellation to all streams.""" worker = _make_running_worker() - mock_future = MagicMock(spec=grpc.Future) - worker._response_stream = mock_future + mock_channel = MagicMock() + worker._current_channel = mock_channel + worker._response_stream = MagicMock(spec=grpc.Future) worker.stop() - mock_future.cancel.assert_called_once() + mock_channel.close.assert_called_once() def test_stop_with_generator_call(): + """stop() closes the channel even when response stream has a .call attribute.""" worker = _make_running_worker() - mock_call = MagicMock() + mock_channel = MagicMock() + worker._current_channel = mock_channel mock_stream = MagicMock() - mock_stream.call = mock_call + mock_stream.call = MagicMock() worker._response_stream = mock_stream worker.stop() - mock_call.cancel.assert_called_once() + mock_channel.close.assert_called_once() -def test_stop_with_unknown_stream_type(caplog): +def test_stop_with_unknown_stream_type(): + """stop() closes the channel regardless of response stream type.""" worker = _make_running_worker() - # Not a grpc.Future, no 'call' attribute + mock_channel = MagicMock() + worker._current_channel = mock_channel worker._response_stream = object() - with caplog.at_level("WARNING"): - worker.stop() - assert any("Error cancelling response stream: " in m for m in caplog.text.splitlines()) + worker.stop() + mock_channel.close.assert_called_once() def test_stop_with_none_stream(): worker = _make_running_worker() worker._response_stream = None + worker._current_channel = None # Should not raise worker.stop() @@ -55,3 +61,89 @@ def test_stop_when_not_running(): with patch.object(worker._shutdown, "set") as shutdown_set: worker.stop() shutdown_set.assert_not_called() + + +def test_stop_channel_close_handles_exception(caplog): + """stop() handles exceptions from channel.close() gracefully.""" + worker = _make_running_worker() + mock_channel = MagicMock() + mock_channel.close.side_effect = Exception("close failed") + worker._current_channel = mock_channel + # Should not raise + worker.stop() + assert worker._current_channel is None + + +def test_deferred_channel_close_waits_then_closes(): + """_schedule_deferred_channel_close waits grace period, then closes old channel.""" + worker = TaskHubGrpcWorker() + old_channel = MagicMock() + + worker._schedule_deferred_channel_close(old_channel, grace_timeout=0.1) + # Thread should be tracked + assert len(worker._channel_cleanup_threads) == 1 + + # Wait for the grace period to expire and the thread to finish + worker._channel_cleanup_threads[0].join(timeout=2) + old_channel.close.assert_called_once() + + +def test_deferred_channel_close_fires_immediately_on_shutdown(): + """Deferred close returns immediately when shutdown is already set.""" + worker = TaskHubGrpcWorker() + worker._shutdown.set() + old_channel = MagicMock() + + worker._schedule_deferred_channel_close(old_channel, grace_timeout=60) + # Even with a 60s grace, shutdown makes it return immediately + worker._channel_cleanup_threads[0].join(timeout=2) + old_channel.close.assert_called_once() + + +def test_deferred_channel_close_handles_close_exception(): + """Deferred close handles exceptions from channel.close() gracefully.""" + worker = TaskHubGrpcWorker() + worker._shutdown.set() + old_channel = MagicMock() + old_channel.close.side_effect = Exception("already closed") + + # Should not raise + worker._schedule_deferred_channel_close(old_channel, grace_timeout=0) + worker._channel_cleanup_threads[0].join(timeout=2) + old_channel.close.assert_called_once() + + +def test_stop_joins_deferred_cleanup_threads(): + """stop() joins all deferred channel cleanup threads.""" + worker = _make_running_worker() + mock_channel = MagicMock() + worker._current_channel = mock_channel + + # Pre-populate a cleanup thread (simulating a prior reconnection) + old_channel = MagicMock() + worker._schedule_deferred_channel_close(old_channel, grace_timeout=60) + assert len(worker._channel_cleanup_threads) == 1 + + worker.stop() + # stop() sets shutdown, which unblocks deferred close threads + # After stop(), cleanup threads list should be cleared + assert len(worker._channel_cleanup_threads) == 0 + old_channel.close.assert_called_once() + mock_channel.close.assert_called_once() + + +def test_deferred_close_prunes_finished_threads(): + """_schedule_deferred_channel_close prunes already-finished threads.""" + worker = TaskHubGrpcWorker() + worker._shutdown.set() # Make threads complete immediately + + ch1 = MagicMock() + ch2 = MagicMock() + worker._schedule_deferred_channel_close(ch1, grace_timeout=0) + worker._channel_cleanup_threads[0].join(timeout=2) + + # ch1's thread is finished; scheduling ch2 should prune it + worker._schedule_deferred_channel_close(ch2, grace_timeout=0) + worker._channel_cleanup_threads[-1].join(timeout=2) + # Only the still-alive (or just-finished ch2) thread remains; ch1's was pruned + assert len(worker._channel_cleanup_threads) <= 1