Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
ADDED

- Added `durabletask.testing` module with `InMemoryOrchestrationBackend` for testing orchestrations without a sidecar process
- Added `AsyncTaskHubGrpcClient` for asyncio-based applications using `grpc.aio`
- Added `DefaultAsyncClientInterceptorImpl` for async gRPC metadata interceptors
- Added `get_async_grpc_channel` helper for creating async gRPC channels

CHANGED

- Refactored `TaskHubGrpcClient` to share request-building and validation logic
with `AsyncTaskHubGrpcClient` via module-level helper functions

FIXED:

Expand Down
5 changes: 5 additions & 0 deletions durabletask-azuremanaged/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

- Added `AsyncDurableTaskSchedulerClient` for async/await usage with `grpc.aio`
- Added `DTSAsyncDefaultClientInterceptorImpl` async gRPC interceptor for DTS authentication

## v1.3.0

- Updates base dependency to durabletask v1.3.0
Expand Down
66 changes: 65 additions & 1 deletion durabletask-azuremanaged/durabletask/azuremanaged/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from typing import Optional

from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential

from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
DTSAsyncDefaultClientInterceptorImpl,
DTSDefaultClientInterceptorImpl,
)
from durabletask.client import TaskHubGrpcClient
from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient


# Client class used for Durable Task Scheduler (DTS)
Expand Down Expand Up @@ -39,3 +41,65 @@ def __init__(self, *,
log_formatter=log_formatter,
interceptors=interceptors,
default_version=default_version)


# Async client class used for Durable Task Scheduler (DTS)
class AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient):
"""An async client implementation for Azure Durable Task Scheduler (DTS).

This class extends AsyncTaskHubGrpcClient to provide integration with Azure's
Durable Task Scheduler service using async gRPC. It handles authentication via
Azure credentials and configures the necessary gRPC interceptors for DTS
communication.

Args:
host_address (str): The gRPC endpoint address of the DTS service.
taskhub (str): The name of the task hub. Cannot be empty.
token_credential (Optional[TokenCredential]): Azure credential for authentication.
If None, anonymous authentication will be used.
secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
Defaults to True.
default_version (Optional[str], optional): Default version string for orchestrations.
log_handler (Optional[logging.Handler], optional): Custom logging handler for client logs.
log_formatter (Optional[logging.Formatter], optional): Custom log formatter for client logs.

Raises:
ValueError: If taskhub is empty or None.

Example:
>>> from azure.identity.aio import DefaultAzureCredential
>>> from durabletask.azuremanaged import AsyncDurableTaskSchedulerClient
>>>
>>> credential = DefaultAzureCredential()
>>> async with AsyncDurableTaskSchedulerClient(
... host_address="my-dts-service.azure.com:443",
... taskhub="my-task-hub",
... token_credential=credential
... ) as client:
... instance_id = await client.schedule_new_orchestration("my_orchestrator")
Comment on lines +74 to +79
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I asked Copilot about async python stuff and it is adamant that in order to do this, you need to implement __aenter__ and __aexit__. If you don't you'll get an attribute error. Is this real?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, you learn something every day. Added, and updated the tests to use async with to avoid leaking the channel.

"""

def __init__(self, *,
host_address: str,
taskhub: str,
token_credential: Optional[AsyncTokenCredential],
secure_channel: bool = True,
default_version: Optional[str] = None,
log_handler: Optional[logging.Handler] = None,
log_formatter: Optional[logging.Formatter] = None):

if not taskhub:
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")

interceptors = [DTSAsyncDefaultClientInterceptorImpl(token_credential, taskhub)]

# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
super().__init__(
host_address=host_address,
secure_channel=secure_channel,
metadata=None,
log_handler=log_handler,
log_formatter=log_formatter,
interceptors=interceptors,
default_version=default_version)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

from azure.core.credentials import AccessToken, TokenCredential
from azure.core.credentials_async import AsyncTokenCredential

import durabletask.internal.shared as shared

Expand Down Expand Up @@ -47,3 +48,40 @@ def refresh_token(self):
# Convert UNIX timestamp to timezone-aware datetime
self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")


class AsyncAccessTokenManager:
"""Async version of AccessTokenManager that uses AsyncTokenCredential.

This avoids blocking the event loop when acquiring or refreshing tokens."""

_token: Optional[AccessToken]

def __init__(self, token_credential: Optional[AsyncTokenCredential],
refresh_interval_seconds: int = 600):
self._scope = "https://durabletask.io/.default"
self._refresh_interval_seconds = refresh_interval_seconds
self._logger = shared.get_logger("async_token_manager")

self._credential = token_credential
self._token = None
self.expiry_time = None

async def get_access_token(self) -> Optional[AccessToken]:
if self._token is None or self.is_token_expired():
await self.refresh_token()
return self._token

def is_token_expired(self) -> bool:
if self.expiry_time is None:
return True
return datetime.now(timezone.utc) >= (
self.expiry_time - timedelta(seconds=self._refresh_interval_seconds))

async def refresh_token(self):
if self._credential is not None:
self._token = await self._credential.get_token(self._scope)

# Convert UNIX timestamp to timezone-aware datetime
self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@

import grpc
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential

from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
from durabletask.azuremanaged.internal.access_token_manager import (
AccessTokenManager,
AsyncAccessTokenManager,
)
from durabletask.internal.grpc_interceptor import (
DefaultAsyncClientInterceptorImpl,
DefaultClientInterceptorImpl,
_AsyncClientCallDetails,
_ClientCallDetails,
)

Expand All @@ -32,6 +38,7 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
super().__init__(self._metadata)

self._token_manager = None
if token_credential is not None:
self._token_credential = token_credential
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
Expand All @@ -43,12 +50,72 @@ def _intercept_call(
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
call details."""
# Refresh the auth token if it is present and needed
if self._metadata is not None:
for i, (key, _) in enumerate(self._metadata):
if key.lower() == "authorization": # Ensure case-insensitive comparison
new_token = self._token_manager.get_access_token() # Get the new token
if new_token is not None:
self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token
# Refresh the auth token if a credential was provided. The call to
# get_access_token() is generally cheap, checking the expiry time and returning
# the cached value without a network call when still valid.
if self._token_manager is not None:
access_token = self._token_manager.get_access_token()
if access_token is not None:
# Update the existing authorization header
found = False
for i, (key, _) in enumerate(self._metadata):
if key.lower() == "authorization":
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
found = True
break
if not found:
self._metadata.append(("authorization", f"Bearer {access_token.token}"))

return super()._intercept_call(client_call_details)


class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
"""Async version of DTSDefaultClientInterceptorImpl for use with grpc.aio channels.

This class implements async gRPC interceptors to add DTS-specific headers
(task hub name, user agent, and authentication token) to all async calls."""

def __init__(self, token_credential: Optional[AsyncTokenCredential], taskhub_name: str):
try:
# Get the version of the azuremanaged package
sdk_version = version('durabletask-azuremanaged')
except Exception:
# Fallback if version cannot be determined
sdk_version = "unknown"
user_agent = f"durabletask-python/{sdk_version}"
self._metadata = [
("taskhub", taskhub_name),
("x-user-agent", user_agent)]
super().__init__(self._metadata)

# Token acquisition is deferred to the first _intercept_call invocation
# rather than happening in __init__, because get_token() on an
# AsyncTokenCredential is async and cannot be awaited in a constructor.
self._token_manager = None
if token_credential is not None:
self._token_credential = token_credential
self._token_manager = AsyncAccessTokenManager(token_credential=self._token_credential)

async def _intercept_call(
self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
call details."""
# Refresh the auth token if a credential was provided. The call to
# get_access_token() is generally cheap, checking the expiry time and returning
# the cached value without a network call when still valid.
if self._token_manager is not None:
access_token = await self._token_manager.get_access_token()
if access_token is not None:
# Update the existing authorization header, or append one if this
# is the first successful token acquisition (token is lazily
# fetched on the first call since async constructors aren't possible).
found = False
for i, (key, _) in enumerate(self._metadata):
if key.lower() == "authorization":
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
found = True
break
if not found:
self._metadata.append(("authorization", f"Bearer {access_token.token}"))

return await super()._intercept_call(client_call_details)
Loading