diff --git a/README.md b/README.md index 1a4472c72..e212cbd43 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,14 @@ git clone https://github.com/dapr/python-sdk.git cd python-sdk ``` -2. Install a project in a editable mode +2. Create and activate a virtual environment + +```bash +python3 -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +``` + +3. Install a project in editable mode ```bash pip3 install -e . @@ -90,31 +97,31 @@ pip3 install -e ./ext/dapr-ext-langgraph/ pip3 install -e ./ext/dapr-ext-strands/ ``` -3. Install required packages +4. Install required packages ```bash pip3 install -r dev-requirements.txt ``` -4. Run linter and autofix +5. Run linter and autofix ```bash tox -e ruff ``` -5. Run unit-test +6. Run unit-test ```bash tox -e py311 ``` -6. Run type check +7. Run type check ```bash tox -e type ``` -7. Run examples +8. Run examples ```bash tox -e examples diff --git a/dapr/clients/grpc/client.py b/dapr/clients/grpc/client.py index ae1206c4c..fa69c359a 100644 --- a/dapr/clients/grpc/client.py +++ b/dapr/clients/grpc/client.py @@ -25,8 +25,10 @@ import grpc # type: ignore from google.protobuf.any_pb2 import Any as GrpcAny +from google.protobuf.duration_pb2 import Duration as GrpcDuration from google.protobuf.empty_pb2 import Empty as GrpcEmpty from google.protobuf.message import Message as GrpcMessage +from google.protobuf.struct_pb2 import Struct as GrpcStruct from grpc import ( # type: ignore RpcError, StatusCode, @@ -1880,6 +1882,8 @@ def converse_alpha2( temperature: Optional[float] = None, tools: Optional[List[conversation.ConversationTools]] = None, tool_choice: Optional[str] = None, + response_format: Optional[GrpcStruct] = None, + prompt_cache_retention: Optional[GrpcDuration] = None, ) -> conversation.ConversationResponseAlpha2: """Invoke an LLM using the conversation API (Alpha2) with tool calling support. @@ -1893,6 +1897,8 @@ def converse_alpha2( temperature: Optional temperature setting for the LLM to optimize for creativity or predictability tools: Optional list of tools available for the LLM to call tool_choice: Optional control over which tools can be called ('none', 'auto', 'required', or specific tool name) + response_format: Optional response format (google.protobuf.struct_pb2.Struct, ex: json_schema for structured output) + prompt_cache_retention: Optional retention for prompt cache (google.protobuf.duration_pb2.Duration) Returns: ConversationResponseAlpha2 containing the conversation results with choices and tool calls @@ -1949,6 +1955,10 @@ def converse_alpha2( request.temperature = temperature if tool_choice is not None: request.tool_choice = tool_choice + if response_format is not None and hasattr(request, 'response_format'): + request.response_format.CopyFrom(response_format) + if prompt_cache_retention is not None and hasattr(request, 'prompt_cache_retention'): + request.prompt_cache_retention.CopyFrom(prompt_cache_retention) try: response, call = self.retry_policy.run_rpc(self._stub.ConverseAlpha2.with_call, request) diff --git a/dapr/clients/grpc/conversation.py b/dapr/clients/grpc/conversation.py index d11c41979..8fc3db067 100644 --- a/dapr/clients/grpc/conversation.py +++ b/dapr/clients/grpc/conversation.py @@ -338,11 +338,46 @@ class ConversationResultAlpha2Choices: message: ConversationResultAlpha2Message +@dataclass +class ConversationResultAlpha2CompletionUsageCompletionTokensDetails: + """Breakdown of tokens used in the completion.""" + + accepted_prediction_tokens: int = 0 + audio_tokens: int = 0 + reasoning_tokens: int = 0 + rejected_prediction_tokens: int = 0 + + +@dataclass +class ConversationResultAlpha2CompletionUsagePromptTokensDetails: + """Breakdown of tokens used in the prompt.""" + + audio_tokens: int = 0 + cached_tokens: int = 0 + + +@dataclass +class ConversationResultAlpha2CompletionUsage: + """Token usage for one Alpha2 conversation result.""" + + completion_tokens: int = 0 + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens_details: Optional[ + ConversationResultAlpha2CompletionUsageCompletionTokensDetails + ] = None + prompt_tokens_details: Optional[ConversationResultAlpha2CompletionUsagePromptTokensDetails] = ( + None + ) + + @dataclass class ConversationResultAlpha2: """One of the outputs in Alpha2 response from conversation input.""" choices: List[ConversationResultAlpha2Choices] = field(default_factory=list) + model: Optional[str] = None + usage: Optional[ConversationResultAlpha2CompletionUsage] = None @dataclass @@ -657,5 +692,38 @@ def _get_outputs_from_grpc_response( ) ) - outputs.append(ConversationResultAlpha2(choices=choices)) + model: Optional[str] = None + usage: Optional[ConversationResultAlpha2CompletionUsage] = None + if hasattr(output, 'model') and getattr(output, 'model', None): + model = output.model + if hasattr(output, 'usage') and output.usage: + u = output.usage + completion_details: Optional[ + ConversationResultAlpha2CompletionUsageCompletionTokensDetails + ] = None + prompt_details: Optional[ConversationResultAlpha2CompletionUsagePromptTokensDetails] = ( + None + ) + if hasattr(u, 'completion_tokens_details') and u.completion_tokens_details: + cd = u.completion_tokens_details + completion_details = ConversationResultAlpha2CompletionUsageCompletionTokensDetails( + accepted_prediction_tokens=getattr(cd, 'accepted_prediction_tokens', 0) or 0, + audio_tokens=getattr(cd, 'audio_tokens', 0) or 0, + reasoning_tokens=getattr(cd, 'reasoning_tokens', 0) or 0, + rejected_prediction_tokens=getattr(cd, 'rejected_prediction_tokens', 0) or 0, + ) + if hasattr(u, 'prompt_tokens_details') and u.prompt_tokens_details: + pd = u.prompt_tokens_details + prompt_details = ConversationResultAlpha2CompletionUsagePromptTokensDetails( + audio_tokens=getattr(pd, 'audio_tokens', 0) or 0, + cached_tokens=getattr(pd, 'cached_tokens', 0) or 0, + ) + usage = ConversationResultAlpha2CompletionUsage( + completion_tokens=getattr(u, 'completion_tokens', 0) or 0, + prompt_tokens=getattr(u, 'prompt_tokens', 0) or 0, + total_tokens=getattr(u, 'total_tokens', 0) or 0, + completion_tokens_details=completion_details, + prompt_tokens_details=prompt_details, + ) + outputs.append(ConversationResultAlpha2(choices=choices, model=model, usage=usage)) return outputs diff --git a/dev-requirements.txt b/dev-requirements.txt index acf05f8bd..18cdd4342 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -14,7 +14,7 @@ Flask>=1.1 # needed for auto fix ruff===0.14.1 # needed for dapr-ext-workflow -durabletask-dapr >= 0.2.0a19 +durabletask-dapr >= 0.17.0 # needed for .env file loading in examples python-dotenv>=1.0.0 # needed for enhanced schema generation from function features diff --git a/examples/workflow/README.md b/examples/workflow/README.md index 22a55e868..12829f0bd 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -338,7 +338,7 @@ When you run the example, you will see output like this: ``` -### Cross-app Workflow +### Multi-app Workflows This example demonstrates how to call child workflows and activities in different apps. The multiple Dapr CLI instances can be started using the following commands: @@ -361,9 +361,9 @@ sleep: 20 --> ```sh -dapr run --app-id wfexample3 --dapr-http-port 3503 --dapr-grpc-port 50103 -- python3 cross-app3.py & -dapr run --app-id wfexample2 --dapr-http-port 3502 --dapr-grpc-port 50102 -- python3 cross-app2.py & -dapr run --app-id wfexample1 --dapr-http-port 3501 --dapr-grpc-port 50101 -- python3 cross-app1.py +dapr run --app-id wfexample3 --dapr-http-port 3503 --dapr-grpc-port 50103 -- python3 multi-app3.py & +dapr run --app-id wfexample2 --dapr-http-port 3502 --dapr-grpc-port 50102 -- python3 multi-app2.py & +dapr run --app-id wfexample1 --dapr-http-port 3501 --dapr-grpc-port 50101 -- python3 multi-app1.py ``` @@ -379,9 +379,9 @@ among others. This shows that the workflow calls are working as expected. #### Error handling on activity calls -This example demonstrates how the error handling works on activity calls across apps. +This example demonstrates how the error handling works on activity calls in multi-app workflows. -Error handling on activity calls across apps works as normal workflow activity calls. +Error handling on activity calls in multi-app workflows works as normal workflow activity calls. In this example we run `app3` in failing mode, which makes the activity call return error constantly. The activity call from `app2` will fail after the retry policy is exhausted. @@ -404,9 +404,9 @@ sleep: 20 ```sh export ERROR_ACTIVITY_MODE=true -dapr run --app-id wfexample3 --dapr-http-port 3503 --dapr-grpc-port 50103 -- python3 cross-app3.py & -dapr run --app-id wfexample2 --dapr-http-port 3502 --dapr-grpc-port 50102 -- python3 cross-app2.py & -dapr run --app-id wfexample1 --dapr-http-port 3501 --dapr-grpc-port 50101 -- python3 cross-app1.py +dapr run --app-id wfexample3 --dapr-http-port 3503 --dapr-grpc-port 50103 -- python3 multi-app3.py & +dapr run --app-id wfexample2 --dapr-http-port 3502 --dapr-grpc-port 50102 -- python3 multi-app2.py & +dapr run --app-id wfexample1 --dapr-http-port 3501 --dapr-grpc-port 50101 -- python3 multi-app1.py ``` @@ -424,9 +424,9 @@ among others. This shows that the activity calls are failing as expected, and th #### Error handling on workflow calls -This example demonstrates how the error handling works on workflow calls across apps. +This example demonstrates how the error handling works on workflow calls in multi-app workflows. -Error handling on workflow calls across apps works as normal workflow calls. +Error handling on workflow calls in multi-app workflows works as normal workflow calls. In this example we run `app2` in failing mode, which makes the workflow call return error constantly. The workflow call from `app1` will fail after the retry policy is exhausted. @@ -445,9 +445,9 @@ sleep: 20 ```sh export ERROR_WORKFLOW_MODE=true -dapr run --app-id wfexample3 --dapr-http-port 3503 --dapr-grpc-port 50103 -- python3 cross-app3.py & -dapr run --app-id wfexample2 --dapr-http-port 3502 --dapr-grpc-port 50102 -- python3 cross-app2.py & -dapr run --app-id wfexample1 --dapr-http-port 3501 --dapr-grpc-port 50101 -- python3 cross-app1.py +dapr run --app-id wfexample3 --dapr-http-port 3503 --dapr-grpc-port 50103 -- python3 multi-app3.py & +dapr run --app-id wfexample2 --dapr-http-port 3502 --dapr-grpc-port 50102 -- python3 multi-app2.py & +dapr run --app-id wfexample1 --dapr-http-port 3501 --dapr-grpc-port 50101 -- python3 multi-app1.py ``` diff --git a/examples/workflow/child_workflow.py b/examples/workflow/child_workflow.py index 57ab2fc3e..20b675ea0 100644 --- a/examples/workflow/child_workflow.py +++ b/examples/workflow/child_workflow.py @@ -10,7 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time import dapr.ext.workflow as wf @@ -40,12 +39,10 @@ def child_workflow(ctx: wf.DaprWorkflowContext): if __name__ == '__main__': wfr.start() - time.sleep(10) # wait for workflow runtime to start wf_client = wf.DaprWorkflowClient() instance_id = wf_client.schedule_new_workflow(workflow=main_workflow) - # Wait for the workflow to complete - time.sleep(5) + wf_client.wait_for_workflow_completion(instance_id) wfr.shutdown() diff --git a/examples/workflow/fan_out_fan_in.py b/examples/workflow/fan_out_fan_in.py index f625ea287..9cd1ff6cb 100644 --- a/examples/workflow/fan_out_fan_in.py +++ b/examples/workflow/fan_out_fan_in.py @@ -55,7 +55,6 @@ def process_results(ctx, final_result: int): if __name__ == '__main__': wfr.start() - time.sleep(10) # wait for workflow runtime to start wf_client = wf.DaprWorkflowClient() instance_id = wf_client.schedule_new_workflow(workflow=batch_processing_workflow, input=10) diff --git a/examples/workflow/cross-app1.py b/examples/workflow/multi-app1.py similarity index 93% rename from examples/workflow/cross-app1.py rename to examples/workflow/multi-app1.py index 1ef7b48da..9b968def3 100644 --- a/examples/workflow/cross-app1.py +++ b/examples/workflow/multi-app1.py @@ -10,7 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time from datetime import timedelta import dapr.ext.workflow as wf @@ -46,13 +45,11 @@ def app1_workflow(ctx: wf.DaprWorkflowContext): if __name__ == '__main__': wfr.start() - time.sleep(10) # wait for workflow runtime to start wf_client = wf.DaprWorkflowClient() print('app1 - triggering app1 workflow', flush=True) instance_id = wf_client.schedule_new_workflow(workflow=app1_workflow) - # Wait for the workflow to complete - time.sleep(7) + wf_client.wait_for_workflow_completion(instance_id) wfr.shutdown() diff --git a/examples/workflow/cross-app2.py b/examples/workflow/multi-app2.py similarity index 95% rename from examples/workflow/cross-app2.py rename to examples/workflow/multi-app2.py index 2af65912c..7e97b58c0 100644 --- a/examples/workflow/cross-app2.py +++ b/examples/workflow/multi-app2.py @@ -46,5 +46,5 @@ def app2_workflow(ctx: wf.DaprWorkflowContext): if __name__ == '__main__': wfr.start() - time.sleep(15) # wait for workflow runtime to start + time.sleep(15) # Keep the workflow runtime running for a while to process workflows wfr.shutdown() diff --git a/examples/workflow/cross-app3.py b/examples/workflow/multi-app3.py similarity index 93% rename from examples/workflow/cross-app3.py rename to examples/workflow/multi-app3.py index 4bcc158a0..6b72de7e4 100644 --- a/examples/workflow/cross-app3.py +++ b/examples/workflow/multi-app3.py @@ -29,5 +29,5 @@ def app3_activity(ctx: wf.DaprWorkflowContext) -> int: if __name__ == '__main__': wfr.start() - time.sleep(15) # wait for workflow runtime to start + time.sleep(15) # Keep the workflow runtime alive for a while to process requests wfr.shutdown() diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py index cc0cfe8ba..d90c72dc2 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py @@ -68,12 +68,12 @@ def call_activity( retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, ) -> task.Task[TOutput]: - # Handle string activity names for cross-app scenarios + # Handle string activity names for multi-app workflow scenarios if isinstance(activity, str): activity_name = activity if app_id is not None: self._logger.debug( - f'{self.instance_id}: Creating cross-app activity {activity_name} for app {app_id}' + f'{self.instance_id}: Creating multi-app workflow activity {activity_name} for app {app_id}' ) else: self._logger.debug(f'{self.instance_id}: Creating activity {activity_name}') @@ -106,7 +106,7 @@ def call_child_workflow( retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, ) -> task.Task[TOutput]: - # Handle string workflow names for cross-app scenarios + # Handle string workflow names for multi-app workflow scenarios if isinstance(workflow, str): workflow_name = workflow self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow_name}') diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py index b93e7074f..dd33cab86 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py @@ -32,5 +32,8 @@ def warning(self, msg, *args, **kwargs): def error(self, msg, *args, **kwargs): self._logger.error(msg, *args, **kwargs) + def exception(self, msg, *args, **kwargs): + self._logger.exception(msg, *args, **kwargs) + def critical(self, msg, *args, **kwargs): self._logger.critical(msg, *args, **kwargs) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py index 8453e16ef..d41841472 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py @@ -118,7 +118,7 @@ def call_activity( Parameters ---------- activity: Activity[TInput, TOutput] | str - A reference to the activity function to call, or a string name for cross-app activities. + A reference to the activity function to call, or a string name for multi-app workflow activities. input: TInput | None The JSON-serializable input (or None) to pass to the activity. app_id: str | None @@ -145,7 +145,7 @@ def call_child_workflow( Parameters ---------- orchestrator: Orchestrator[TInput, TOutput] | str - A reference to the orchestrator function to call, or a string name for cross-app workflows. + A reference to the orchestrator function to call, or a string name for multi-app workflows. input: TInput The optional JSON-serializable input to pass to the orchestrator function. instance_id: str diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 58b0912a0..9f5edb2b4 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -14,6 +14,7 @@ """ import inspect +import time from functools import wraps from typing import Optional, Sequence, TypeVar, Union @@ -54,8 +55,11 @@ def __init__( maximum_concurrent_activity_work_items: Optional[int] = None, maximum_concurrent_orchestration_work_items: Optional[int] = None, maximum_thread_pool_workers: Optional[int] = None, + worker_ready_timeout: Optional[float] = None, ): self._logger = Logger('WorkflowRuntime', logger_options) + self._worker_ready_timeout = 30.0 if worker_ready_timeout is None else worker_ready_timeout + metadata = tuple() if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) @@ -86,10 +90,20 @@ def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): """Responsible to call Workflow function in orchestrationWrapper""" - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - if inp is None: - return fn(daprWfContext) - return fn(daprWfContext, inp) + instance_id = getattr(ctx, 'instance_id', 'unknown') + + try: + daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) + if inp is None: + result = fn(daprWfContext) + else: + result = fn(daprWfContext, inp) + return result + except Exception as e: + self._logger.exception( + f'Workflow execution failed - instance_id: {instance_id}, error: {e}' + ) + raise if hasattr(fn, '_workflow_registered'): # whenever a workflow is registered, it has a _dapr_alternate_name attribute @@ -152,10 +166,20 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None): def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): """Responsible to call Activity function in activityWrapper""" - wfActivityContext = WorkflowActivityContext(ctx) - if inp is None: - return fn(wfActivityContext) - return fn(wfActivityContext, inp) + activity_id = getattr(ctx, 'task_id', 'unknown') + + try: + wfActivityContext = WorkflowActivityContext(ctx) + if inp is None: + result = fn(wfActivityContext) + else: + result = fn(wfActivityContext, inp) + return result + except Exception as e: + self._logger.exception( + f'Activity execution failed - task_id: {activity_id}, error: {e}' + ) + raise if hasattr(fn, '_activity_registered'): # whenever an activity is registered, it has a _dapr_alternate_name attribute @@ -174,13 +198,77 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): ) fn.__dict__['_activity_registered'] = True + def wait_for_worker_ready(self, timeout: float = 30.0) -> bool: + """ + Wait for the worker's gRPC stream to become ready to receive work items. + This method polls the worker's is_worker_ready() method until it returns True + or the timeout is reached. + + Args: + timeout: Maximum time in seconds to wait for the worker to be ready. + Defaults to 30 seconds. + + Returns: + True if the worker's gRPC stream is ready to receive work items, False if timeout. + """ + if not hasattr(self.__worker, 'is_worker_ready'): + return False + + elapsed = 0.0 + poll_interval = 0.1 # 100ms + + while elapsed < timeout: + if self.__worker.is_worker_ready(): + return True + time.sleep(poll_interval) + elapsed += poll_interval + + self._logger.warning( + f'WorkflowRuntime worker readiness check timed out after {timeout} seconds' + ) + return False + def start(self): - """Starts the listening for work items on a background thread.""" - self.__worker.start() + """Starts the listening for work items on a background thread. + This method waits for the worker's gRPC stream to be fully initialized + before returning, ensuring that workflows can be scheduled immediately + after start() completes. + """ + try: + try: + self.__worker.start() + except Exception as start_error: + self._logger.exception(f'WorkflowRuntime worker did not start: {start_error}') + raise + + # Verify the worker and its stream reader are ready + if hasattr(self.__worker, 'is_worker_ready'): + try: + is_ready = self.wait_for_worker_ready(timeout=self._worker_ready_timeout) + if not is_ready: + raise RuntimeError('WorkflowRuntime worker and its stream are not ready') + else: + self._logger.debug( + 'WorkflowRuntime worker is ready and its stream can receive work items' + ) + except Exception as ready_error: + self._logger.exception( + f'WorkflowRuntime wait_for_worker_ready() raised exception: {ready_error}' + ) + raise ready_error + else: + self._logger.warning( + 'Unable to verify stream readiness. Workflows scheduled immediately may not be received.' + ) + except Exception: + raise def shutdown(self): """Stops the listening for work items on a background thread.""" - self.__worker.stop() + try: + self.__worker.stop() + except Exception: + raise def versioned_workflow( self, diff --git a/ext/dapr-ext-workflow/setup.cfg b/ext/dapr-ext-workflow/setup.cfg index bd5a41536..f0f076bcd 100644 --- a/ext/dapr-ext-workflow/setup.cfg +++ b/ext/dapr-ext-workflow/setup.cfg @@ -25,7 +25,7 @@ packages = find_namespace: include_package_data = True install_requires = dapr >= 1.17.0 - durabletask-dapr >= 0.2.0a19 + durabletask-dapr >= 0.17.0 [options.packages.find] include = diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index bf18cd689..16eb4946f 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -26,11 +26,17 @@ class FakeTaskHubGrpcWorker: + def __init__(self): + self._orchestrator_fns = {} + self._activity_fns = {} + def add_named_orchestrator(self, name: str, fn): listOrchestrators.append(name) + self._orchestrator_fns[name] = fn def add_named_activity(self, name: str, fn): listActivities.append(name) + self._activity_fns[name] = fn class WorkflowRuntimeTest(unittest.TestCase): @@ -171,3 +177,124 @@ def test_decorator_register_optinal_name(self): wanted_activity = ['test_act'] assert listActivities == wanted_activity assert client_act._dapr_alternate_name == 'test_act' + + +class WorkflowRuntimeWorkerReadyTest(unittest.TestCase): + """Tests for wait_for_worker_ready() and start() stream readiness.""" + + def setUp(self): + listActivities.clear() + listOrchestrators.clear() + mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start() + self.runtime = WorkflowRuntime() + + def test_wait_for_worker_ready_returns_false_when_no_is_worker_ready(self): + mock_worker = mock.MagicMock(spec=['start', 'stop', '_registry']) + del mock_worker.is_worker_ready + self.runtime._WorkflowRuntime__worker = mock_worker + self.assertFalse(self.runtime.wait_for_worker_ready(timeout=0.1)) + + def test_wait_for_worker_ready_returns_true_when_ready(self): + mock_worker = mock.MagicMock() + mock_worker.is_worker_ready.return_value = True + self.runtime._WorkflowRuntime__worker = mock_worker + self.assertTrue(self.runtime.wait_for_worker_ready(timeout=1.0)) + mock_worker.is_worker_ready.assert_called() + + def test_wait_for_worker_ready_returns_true_after_poll(self): + """Worker becomes ready on second poll.""" + mock_worker = mock.MagicMock() + mock_worker.is_worker_ready.side_effect = [False, True] + self.runtime._WorkflowRuntime__worker = mock_worker + self.assertTrue(self.runtime.wait_for_worker_ready(timeout=1.0)) + self.assertEqual(mock_worker.is_worker_ready.call_count, 2) + + def test_wait_for_worker_ready_returns_false_on_timeout(self): + mock_worker = mock.MagicMock() + mock_worker.is_worker_ready.return_value = False + self.runtime._WorkflowRuntime__worker = mock_worker + self.assertFalse(self.runtime.wait_for_worker_ready(timeout=0.2)) + + def test_start_succeeds_when_worker_ready(self): + mock_worker = mock.MagicMock() + mock_worker.is_worker_ready.return_value = True + self.runtime._WorkflowRuntime__worker = mock_worker + self.runtime.start() + mock_worker.start.assert_called_once() + mock_worker.is_worker_ready.assert_called() + + def test_start_logs_debug_when_worker_stream_ready(self): + """start() logs at debug when worker and stream are ready.""" + mock_worker = mock.MagicMock() + mock_worker.is_worker_ready.return_value = True + self.runtime._WorkflowRuntime__worker = mock_worker + with mock.patch.object(self.runtime._logger, 'debug') as mock_debug: + self.runtime.start() + mock_debug.assert_called_once() + call_args = mock_debug.call_args[0][0] + self.assertIn('ready', call_args) + self.assertIn('stream', call_args) + + def test_start_logs_exception_when_worker_start_fails(self): + """start() logs exception when worker.start() raises.""" + mock_worker = mock.MagicMock() + mock_worker.start.side_effect = RuntimeError('start failed') + self.runtime._WorkflowRuntime__worker = mock_worker + with mock.patch.object(self.runtime._logger, 'exception') as mock_exception: + with self.assertRaises(RuntimeError): + self.runtime.start() + mock_exception.assert_called_once() + self.assertIn('did not start', mock_exception.call_args[0][0]) + + def test_start_raises_when_worker_not_ready(self): + listActivities.clear() + listOrchestrators.clear() + mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start() + runtime = WorkflowRuntime(worker_ready_timeout=0.2) + mock_worker = mock.MagicMock() + mock_worker.is_worker_ready.return_value = False + runtime._WorkflowRuntime__worker = mock_worker + with self.assertRaises(RuntimeError) as ctx: + runtime.start() + self.assertIn('not ready', str(ctx.exception)) + + def test_start_logs_warning_when_no_is_worker_ready(self): + mock_worker = mock.MagicMock(spec=['start', 'stop', '_registry']) + del mock_worker.is_worker_ready + self.runtime._WorkflowRuntime__worker = mock_worker + self.runtime.start() + mock_worker.start.assert_called_once() + + def test_worker_ready_timeout_init(self): + listActivities.clear() + listOrchestrators.clear() + mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start() + rt = WorkflowRuntime(worker_ready_timeout=15.0) + self.assertEqual(rt._worker_ready_timeout, 15.0) + + def test_start_raises_when_worker_start_fails(self): + mock_worker = mock.MagicMock() + mock_worker.is_worker_ready.return_value = True + mock_worker.start.side_effect = RuntimeError('start failed') + self.runtime._WorkflowRuntime__worker = mock_worker + with self.assertRaises(RuntimeError) as ctx: + self.runtime.start() + self.assertIn('start failed', str(ctx.exception)) + mock_worker.start.assert_called_once() + + def test_start_raises_when_wait_for_worker_ready_raises(self): + mock_worker = mock.MagicMock() + mock_worker.start.return_value = None + mock_worker.is_worker_ready.side_effect = ValueError('ready check failed') + self.runtime._WorkflowRuntime__worker = mock_worker + with self.assertRaises(ValueError) as ctx: + self.runtime.start() + self.assertIn('ready check failed', str(ctx.exception)) + + def test_shutdown_raises_when_worker_stop_fails(self): + mock_worker = mock.MagicMock() + mock_worker.stop.side_effect = RuntimeError('stop failed') + self.runtime._WorkflowRuntime__worker = mock_worker + with self.assertRaises(RuntimeError) as ctx: + self.runtime.shutdown() + self.assertIn('stop failed', str(ctx.exception)) diff --git a/tests/clients/fake_dapr_server.py b/tests/clients/fake_dapr_server.py index d56cf0790..2c3d9b685 100644 --- a/tests/clients/fake_dapr_server.py +++ b/tests/clients/fake_dapr_server.py @@ -636,6 +636,20 @@ def ConverseAlpha2(self, request, context): # Create result for this input result = api_v1.ConversationResultAlpha2(choices=choices) + if hasattr(result, 'model'): + result.model = 'test-llm' + if hasattr(result, 'usage'): + try: + usage_cls = getattr(api_v1, 'ConversationResultAlpha2CompletionUsage', None) + if usage_cls is not None: + u = usage_cls( + completion_tokens=10, + prompt_tokens=5, + total_tokens=15, + ) + result.usage.CopyFrom(u) + except Exception: + pass outputs.append(result) return api_v1.ConversationResponseAlpha2( diff --git a/tests/clients/test_conversation.py b/tests/clients/test_conversation.py index 50daebc64..105c7b291 100644 --- a/tests/clients/test_conversation.py +++ b/tests/clients/test_conversation.py @@ -17,7 +17,9 @@ import json import unittest import uuid +from unittest.mock import Mock, patch +from google.protobuf.struct_pb2 import Struct from google.rpc import code_pb2, status_pb2 from dapr.aio.clients import DaprClient as AsyncDaprClient @@ -37,12 +39,16 @@ ConversationResponseAlpha2, ConversationResultAlpha2, ConversationResultAlpha2Choices, + ConversationResultAlpha2CompletionUsage, + ConversationResultAlpha2CompletionUsageCompletionTokensDetails, + ConversationResultAlpha2CompletionUsagePromptTokensDetails, ConversationResultAlpha2Message, ConversationToolCalls, ConversationToolCallsOfFunction, ConversationTools, ConversationToolsFunction, FunctionBackend, + _get_outputs_from_grpc_response, create_assistant_message, create_system_message, create_tool_message, @@ -248,6 +254,14 @@ def test_basic_conversation_alpha2(self): self.assertEqual(choice.finish_reason, 'stop') self.assertIn('Hello Alpha2!', choice.message.content) + out = response.outputs[0] + if out.model is not None: + self.assertEqual(out.model, 'test-llm') + if out.usage is not None: + self.assertGreaterEqual(out.usage.total_tokens, 15) + self.assertGreaterEqual(out.usage.prompt_tokens, 5) + self.assertGreaterEqual(out.usage.completion_tokens, 10) + def test_conversation_alpha2_with_system_message(self): """Test Alpha2 conversation with system message.""" system_message = create_system_message('You are a helpful assistant.') @@ -1107,6 +1121,186 @@ def test_empty_and_none_outputs(self): self.assertEqual(response_none.to_assistant_messages(), []) +class TestConversationResultAlpha2ModelAndUsage(unittest.TestCase): + """Tests for model and usage fields on ConversationResultAlpha2 and related types.""" + + def test_result_alpha2_has_model_and_usage_attributes(self): + """ConversationResultAlpha2 accepts and exposes model and usage.""" + msg = ConversationResultAlpha2Message(content='Hi', tool_calls=[]) + choice = ConversationResultAlpha2Choices(finish_reason='stop', index=0, message=msg) + usage = ConversationResultAlpha2CompletionUsage( + completion_tokens=10, + prompt_tokens=5, + total_tokens=15, + ) + result = ConversationResultAlpha2( + choices=[choice], + model='test-model-1', + usage=usage, + ) + self.assertEqual(result.model, 'test-model-1') + self.assertIsNotNone(result.usage) + self.assertEqual(result.usage.completion_tokens, 10) + self.assertEqual(result.usage.prompt_tokens, 5) + self.assertEqual(result.usage.total_tokens, 15) + + def test_result_alpha2_model_and_usage_default_none(self): + """ConversationResultAlpha2 optional fields default to None when not provided. + + When the API returns a response, model and usage are set from the conversation + component. This test only checks that the dataclass defaults are None when + constructing with choices only. + """ + msg = ConversationResultAlpha2Message(content='Hi', tool_calls=[]) + choice = ConversationResultAlpha2Choices(finish_reason='stop', index=0, message=msg) + result = ConversationResultAlpha2(choices=[choice]) + self.assertIsNone(result.model) + self.assertIsNone(result.usage) + + def test_usage_completion_and_prompt_details(self): + """ConversationResultAlpha2CompletionUsage supports details.""" + completion_details = ConversationResultAlpha2CompletionUsageCompletionTokensDetails( + accepted_prediction_tokens=1, + audio_tokens=2, + reasoning_tokens=3, + rejected_prediction_tokens=0, + ) + prompt_details = ConversationResultAlpha2CompletionUsagePromptTokensDetails( + audio_tokens=0, + cached_tokens=4, + ) + usage = ConversationResultAlpha2CompletionUsage( + completion_tokens=10, + prompt_tokens=5, + total_tokens=15, + completion_tokens_details=completion_details, + prompt_tokens_details=prompt_details, + ) + self.assertEqual(usage.completion_tokens_details.accepted_prediction_tokens, 1) + self.assertEqual(usage.completion_tokens_details.audio_tokens, 2) + self.assertEqual(usage.completion_tokens_details.reasoning_tokens, 3) + self.assertEqual(usage.completion_tokens_details.rejected_prediction_tokens, 0) + self.assertEqual(usage.prompt_tokens_details.audio_tokens, 0) + self.assertEqual(usage.prompt_tokens_details.cached_tokens, 4) + self.assertEqual(usage.total_tokens, 15) + self.assertEqual(usage.completion_tokens, 10) + self.assertEqual(usage.prompt_tokens, 5) + + def test_get_outputs_from_grpc_response_populates_model_and_usage(self): + """_get_outputs_from_grpc_response sets model and usage when present on proto.""" + from unittest import mock + + # Build a mock proto response with one output that has model and usage + mock_usage = mock.Mock() + mock_usage.completion_tokens = 20 + mock_usage.prompt_tokens = 8 + mock_usage.total_tokens = 28 + mock_usage.completion_tokens_details = None + mock_usage.prompt_tokens_details = None + + mock_choice_msg = mock.Mock() + mock_choice_msg.content = 'Hello' + mock_choice_msg.tool_calls = [] + + mock_choice = mock.Mock() + mock_choice.finish_reason = 'stop' + mock_choice.index = 0 + mock_choice.message = mock_choice_msg + + mock_output = mock.Mock() + mock_output.model = 'gpt-4o-mini' + mock_output.usage = mock_usage + mock_output.choices = [mock_choice] + + mock_response = mock.Mock() + mock_response.outputs = [mock_output] + + outputs = _get_outputs_from_grpc_response(mock_response) + self.assertEqual(len(outputs), 1) + out = outputs[0] + self.assertEqual(out.model, 'gpt-4o-mini') + self.assertIsNotNone(out.usage) + self.assertEqual(out.usage.completion_tokens, 20) + self.assertEqual(out.usage.prompt_tokens, 8) + self.assertEqual(out.usage.total_tokens, 28) + self.assertEqual(len(out.choices), 1) + self.assertEqual(out.choices[0].message.content, 'Hello') + + def test_get_outputs_from_grpc_response_without_model_usage(self): + """_get_outputs_from_grpc_response leaves model and usage None when absent.""" + from unittest import mock + + mock_choice_msg = mock.Mock() + mock_choice_msg.content = 'Echo' + mock_choice_msg.tool_calls = [] + + mock_choice = mock.Mock() + mock_choice.finish_reason = 'stop' + mock_choice.index = 0 + mock_choice.message = mock_choice_msg + + mock_output = mock.Mock(spec=['choices']) + mock_output.choices = [mock_choice] + # No model or usage attributes + + mock_response = mock.Mock() + mock_response.outputs = [mock_output] + + outputs = _get_outputs_from_grpc_response(mock_response) + self.assertEqual(len(outputs), 1) + out = outputs[0] + self.assertIsNone(out.model) + self.assertIsNone(out.usage) + self.assertEqual(out.choices[0].message.content, 'Echo') + + +class ConverseAlpha2ResponseFormatTests(unittest.TestCase): + """Unit tests for converse_alpha2 response_format parameter.""" + + def test_converse_alpha2_passes_response_format_on_request(self): + """converse_alpha2 sets response_format on the gRPC request when provided.""" + user_message = create_user_message('Structured output please') + input_alpha2 = ConversationInputAlpha2(messages=[user_message]) + response_format = Struct() + response_format.update( + {'type': 'json_schema', 'json_schema': {'name': 'test', 'schema': {}}} + ) + + captured_requests = [] + mock_choice_msg = Mock() + mock_choice_msg.content = 'ok' + mock_choice_msg.tool_calls = [] + mock_choice = Mock() + mock_choice.finish_reason = 'stop' + mock_choice.index = 0 + mock_choice.message = mock_choice_msg + mock_output = Mock() + mock_output.choices = [mock_choice] + mock_response = Mock() + mock_response.outputs = [mock_output] + mock_response.context_id = '' + mock_call = Mock() + + def capture_run_rpc(rpc, request, *args, **kwargs): + captured_requests.append(request) + return (mock_response, mock_call) + + with patch('dapr.clients.health.DaprHealth.wait_for_sidecar'): + client = DaprClient('localhost:50011') + with patch.object(client.retry_policy, 'run_rpc', side_effect=capture_run_rpc): + client.converse_alpha2( + name='test-llm', + inputs=[input_alpha2], + response_format=response_format, + ) + + self.assertEqual(len(captured_requests), 1) + req = captured_requests[0] + self.assertTrue(hasattr(req, 'response_format')) + self.assertEqual(req.response_format['type'], 'json_schema') + self.assertEqual(req.response_format['json_schema']['name'], 'test') + + class ExecuteRegisteredToolSyncTests(unittest.TestCase): def tearDown(self): # Cleanup all tools we may have registered by name prefix