diff --git a/azure/durable_functions/models/DurableOrchestrationClient.py b/azure/durable_functions/models/DurableOrchestrationClient.py index fc54bda1..009001e5 100644 --- a/azure/durable_functions/models/DurableOrchestrationClient.py +++ b/azure/durable_functions/models/DurableOrchestrationClient.py @@ -791,6 +791,50 @@ async def suspend(self, instance_id: str, reason: str) -> None: if error_message: raise Exception(error_message) + async def restart(self, instance_id: str, + restart_with_new_instance_id: bool = True) -> str: + """Restart an orchestration instance with its original input. + + Parameters + ---------- + instance_id : str + The ID of the orchestration instance to restart. + restart_with_new_instance_id : bool + If True, the restarted instance will use a new instance ID. + If False, the restarted instance will reuse the original instance ID. + + Raises + ------ + Exception: + When the instance with the given ID is not found. + + Returns + ------- + str + The instance ID of the restarted orchestration. + """ + restart_with_new_instance_id_str = str(restart_with_new_instance_id).lower() + request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \ + f"restart?restartWithNewInstanceId={restart_with_new_instance_id_str}" + response = await self._post_async_request( + request_url, + None, + function_invocation_id=self._function_invocation_id) + switch_statement = { + 202: lambda: None, # instance is restarted + 410: lambda: None, # instance completed + 404: lambda: f"No instance with ID '{instance_id}' found.", + } + + has_error_message = switch_statement.get( + response[0], + lambda: f"The operation failed with an unexpected status code {response[0]}") + error_message = has_error_message() + if error_message: + raise Exception(error_message) + + return response[1] if response[1] else instance_id + async def resume(self, instance_id: str, reason: str) -> None: """Resume the specified orchestration instance. diff --git a/tests/conftest.py b/tests/conftest.py index ca65ee23..68a8f683 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,8 @@ def get_binding_string(): "resumePostUri": f"{BASE_URL}/instances/INSTANCEID/resume?reason=" "{text}&taskHub=" f"{TASK_HUB_NAME}&connection=Storage&code={AUTH_CODE}", + "restartPostUri": f"{BASE_URL}/instances/INSTANCEID/restart?taskHub=" + f"{TASK_HUB_NAME}&connection=Storage&code={AUTH_CODE}", }, "rpcBaseUrl": RPC_BASE_URL } diff --git a/tests/models/test_DurableOrchestrationClient.py b/tests/models/test_DurableOrchestrationClient.py index 7a80d461..1466587c 100644 --- a/tests/models/test_DurableOrchestrationClient.py +++ b/tests/models/test_DurableOrchestrationClient.py @@ -157,7 +157,11 @@ def test_create_check_status_response(binding_string): "resumePostUri": r"http://test_azure.net/runtime/webhooks/durabletask/instances/" r"2e2568e7-a906-43bd-8364-c81733c5891e/resume" - r"?reason={text}&taskHub=TASK_HUB_NAME&connection=Storage&code=AUTH_CODE" + r"?reason={text}&taskHub=TASK_HUB_NAME&connection=Storage&code=AUTH_CODE", + "restartPostUri": + r"http://test_azure.net/runtime/webhooks/durabletask/instances/" + r"2e2568e7-a906-43bd-8364-c81733c5891e/restart" + r"?taskHub=TASK_HUB_NAME&connection=Storage&code=AUTH_CODE" } for key, _ in http_management_payload.items(): http_management_payload[key] = replace_stand_in_bits(http_management_payload[key]) @@ -742,6 +746,49 @@ async def test_post_500_resume(binding_string): await client.resume(TEST_INSTANCE_ID, raw_reason) +@pytest.mark.asyncio +async def test_restart_with_new_instance_id(binding_string): + """Test restart calls the HTTP restart endpoint with restartWithNewInstanceId=true.""" + new_instance_id = "new-instance-id-1234" + + post_mock = MockRequest( + expected_url=f"{RPC_BASE_URL}instances/{TEST_INSTANCE_ID}/restart?restartWithNewInstanceId=true", + response=[202, new_instance_id]) + + client = DurableOrchestrationClient(binding_string) + client._post_async_request = post_mock.post + + result = await client.restart(TEST_INSTANCE_ID) + assert result == new_instance_id + + +@pytest.mark.asyncio +async def test_restart_with_same_instance_id(binding_string): + """Test restart calls the HTTP restart endpoint with restartWithNewInstanceId=false.""" + post_mock = MockRequest( + expected_url=f"{RPC_BASE_URL}instances/{TEST_INSTANCE_ID}/restart?restartWithNewInstanceId=false", + response=[202, TEST_INSTANCE_ID]) + + client = DurableOrchestrationClient(binding_string) + client._post_async_request = post_mock.post + + result = await client.restart(TEST_INSTANCE_ID, restart_with_new_instance_id=False) + assert result == TEST_INSTANCE_ID + + +@pytest.mark.asyncio +async def test_restart_instance_not_found(binding_string): + """Test restart raises exception when instance is not found.""" + post_mock = MockRequest( + expected_url=f"{RPC_BASE_URL}instances/{TEST_INSTANCE_ID}/restart?restartWithNewInstanceId=true", + response=[404, None]) + + client = DurableOrchestrationClient(binding_string) + client._post_async_request = post_mock.post + + with pytest.raises(Exception) as ex: + await client.restart(TEST_INSTANCE_ID) + assert f"No instance with ID '{TEST_INSTANCE_ID}' found." in str(ex.value) # Tests for function_invocation_id parameter def test_client_stores_function_invocation_id(binding_string): """Test that the client stores the function_invocation_id parameter."""