diff --git a/durabletask/task.py b/durabletask/task.py index 83750ff..59f0c8f 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -70,13 +70,13 @@ def is_replaying(self) -> bool: pass @abstractmethod - def set_custom_status(self, custom_status: Any) -> None: + def set_custom_status(self, custom_status: str) -> None: """Set the orchestration instance's custom status. Parameters ---------- - custom_status: Any - A JSON-serializable custom status value to set. + custom_status: str + A custom status string to set. """ pass diff --git a/durabletask/worker.py b/durabletask/worker.py index 13f13d8..b7e997d 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -8,6 +8,7 @@ import os import random import threading +import warnings from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta from threading import Event, Thread @@ -1093,10 +1094,17 @@ def current_utc_datetime(self, value: datetime): def is_replaying(self) -> bool: return self._is_replaying - def set_custom_status(self, custom_status: Any) -> None: - self._encoded_custom_status = ( - shared.to_json(custom_status) if custom_status is not None else None - ) + def set_custom_status(self, custom_status: str) -> None: + if custom_status is not None and not isinstance(custom_status, str): + warnings.warn( + "Passing a non-str value to set_custom_status is deprecated and will be " + "removed in a future version. Serialize your value to a JSON string before calling.", + DeprecationWarning, + stacklevel=2, + ) + self._encoded_custom_status = shared.to_json(custom_status) + else: + self._encoded_custom_status = custom_status def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return self.create_timer_internal(fire_at) diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index f3cd56c..8711441 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -594,7 +594,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.runtime_status == client.OrchestrationStatus.COMPLETED assert state.serialized_input is None assert state.serialized_output is None - assert state.serialized_custom_status == '"foobaz"' + assert state.serialized_custom_status == 'foobaz' def test_now_with_sequence_ordering(): diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index b2f3003..fa09a47 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -481,4 +481,4 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.serialized_input is None assert state.serialized_output is None - assert state.serialized_custom_status == '"foobaz"' + assert state.serialized_custom_status == 'foobaz'