Skip to content

Commit ff13570

Browse files
committed
Merge branch 'main' into andystaples/add-distributed-tracing
2 parents 4f40935 + ce7c524 commit ff13570

File tree

14 files changed

+1691
-138
lines changed

14 files changed

+1691
-138
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
ADDED
1111

1212
- Added `durabletask.testing` module with `InMemoryOrchestrationBackend` for testing orchestrations without a sidecar process
13+
- Added `AsyncTaskHubGrpcClient` for asyncio-based applications using `grpc.aio`
14+
- Added `DefaultAsyncClientInterceptorImpl` for async gRPC metadata interceptors
15+
- Added `get_async_grpc_channel` helper for creating async gRPC channels
1316
- Improved distributed tracing support with full span coverage for orchestrations, activities, sub-orchestrations, timers, and events
1417

18+
CHANGED
19+
20+
- Refactored `TaskHubGrpcClient` to share request-building and validation logic
21+
with `AsyncTaskHubGrpcClient` via module-level helper functions
22+
1523
FIXED:
1624

1725
- Fix unbound variable in entity V1 processing

durabletask-azuremanaged/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## Unreleased
9+
10+
- Added `AsyncDurableTaskSchedulerClient` for async/await usage with `grpc.aio`
11+
- Added `DTSAsyncDefaultClientInterceptorImpl` async gRPC interceptor for DTS authentication
12+
813
## v1.3.0
914

1015
- Updates base dependency to durabletask v1.3.0

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
from typing import Optional
77

88
from azure.core.credentials import TokenCredential
9+
from azure.core.credentials_async import AsyncTokenCredential
910

1011
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
12+
DTSAsyncDefaultClientInterceptorImpl,
1113
DTSDefaultClientInterceptorImpl,
1214
)
13-
from durabletask.client import TaskHubGrpcClient
15+
from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient
1416

1517

1618
# Client class used for Durable Task Scheduler (DTS)
@@ -39,3 +41,65 @@ def __init__(self, *,
3941
log_formatter=log_formatter,
4042
interceptors=interceptors,
4143
default_version=default_version)
44+
45+
46+
# Async client class used for Durable Task Scheduler (DTS)
47+
class AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient):
48+
"""An async client implementation for Azure Durable Task Scheduler (DTS).
49+
50+
This class extends AsyncTaskHubGrpcClient to provide integration with Azure's
51+
Durable Task Scheduler service using async gRPC. It handles authentication via
52+
Azure credentials and configures the necessary gRPC interceptors for DTS
53+
communication.
54+
55+
Args:
56+
host_address (str): The gRPC endpoint address of the DTS service.
57+
taskhub (str): The name of the task hub. Cannot be empty.
58+
token_credential (Optional[TokenCredential]): Azure credential for authentication.
59+
If None, anonymous authentication will be used.
60+
secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
61+
Defaults to True.
62+
default_version (Optional[str], optional): Default version string for orchestrations.
63+
log_handler (Optional[logging.Handler], optional): Custom logging handler for client logs.
64+
log_formatter (Optional[logging.Formatter], optional): Custom log formatter for client logs.
65+
66+
Raises:
67+
ValueError: If taskhub is empty or None.
68+
69+
Example:
70+
>>> from azure.identity.aio import DefaultAzureCredential
71+
>>> from durabletask.azuremanaged import AsyncDurableTaskSchedulerClient
72+
>>>
73+
>>> credential = DefaultAzureCredential()
74+
>>> async with AsyncDurableTaskSchedulerClient(
75+
... host_address="my-dts-service.azure.com:443",
76+
... taskhub="my-task-hub",
77+
... token_credential=credential
78+
... ) as client:
79+
... instance_id = await client.schedule_new_orchestration("my_orchestrator")
80+
"""
81+
82+
def __init__(self, *,
83+
host_address: str,
84+
taskhub: str,
85+
token_credential: Optional[AsyncTokenCredential],
86+
secure_channel: bool = True,
87+
default_version: Optional[str] = None,
88+
log_handler: Optional[logging.Handler] = None,
89+
log_formatter: Optional[logging.Formatter] = None):
90+
91+
if not taskhub:
92+
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
93+
94+
interceptors = [DTSAsyncDefaultClientInterceptorImpl(token_credential, taskhub)]
95+
96+
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
97+
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
98+
super().__init__(
99+
host_address=host_address,
100+
secure_channel=secure_channel,
101+
metadata=None,
102+
log_handler=log_handler,
103+
log_formatter=log_formatter,
104+
interceptors=interceptors,
105+
default_version=default_version)

durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55

66
from azure.core.credentials import AccessToken, TokenCredential
7+
from azure.core.credentials_async import AsyncTokenCredential
78

89
import durabletask.internal.shared as shared
910

@@ -47,3 +48,40 @@ def refresh_token(self):
4748
# Convert UNIX timestamp to timezone-aware datetime
4849
self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
4950
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")
51+
52+
53+
class AsyncAccessTokenManager:
54+
"""Async version of AccessTokenManager that uses AsyncTokenCredential.
55+
56+
This avoids blocking the event loop when acquiring or refreshing tokens."""
57+
58+
_token: Optional[AccessToken]
59+
60+
def __init__(self, token_credential: Optional[AsyncTokenCredential],
61+
refresh_interval_seconds: int = 600):
62+
self._scope = "https://durabletask.io/.default"
63+
self._refresh_interval_seconds = refresh_interval_seconds
64+
self._logger = shared.get_logger("async_token_manager")
65+
66+
self._credential = token_credential
67+
self._token = None
68+
self.expiry_time = None
69+
70+
async def get_access_token(self) -> Optional[AccessToken]:
71+
if self._token is None or self.is_token_expired():
72+
await self.refresh_token()
73+
return self._token
74+
75+
def is_token_expired(self) -> bool:
76+
if self.expiry_time is None:
77+
return True
78+
return datetime.now(timezone.utc) >= (
79+
self.expiry_time - timedelta(seconds=self._refresh_interval_seconds))
80+
81+
async def refresh_token(self):
82+
if self._credential is not None:
83+
self._token = await self._credential.get_token(self._scope)
84+
85+
# Convert UNIX timestamp to timezone-aware datetime
86+
self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
87+
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")

durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66

77
import grpc
88
from azure.core.credentials import TokenCredential
9+
from azure.core.credentials_async import AsyncTokenCredential
910

10-
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
11+
from durabletask.azuremanaged.internal.access_token_manager import (
12+
AccessTokenManager,
13+
AsyncAccessTokenManager,
14+
)
1115
from durabletask.internal.grpc_interceptor import (
16+
DefaultAsyncClientInterceptorImpl,
1217
DefaultClientInterceptorImpl,
18+
_AsyncClientCallDetails,
1319
_ClientCallDetails,
1420
)
1521

@@ -32,6 +38,7 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
3238
("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
3339
super().__init__(self._metadata)
3440

41+
self._token_manager = None
3542
if token_credential is not None:
3643
self._token_credential = token_credential
3744
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
@@ -43,12 +50,72 @@ def _intercept_call(
4350
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
4451
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
4552
call details."""
46-
# Refresh the auth token if it is present and needed
47-
if self._metadata is not None:
48-
for i, (key, _) in enumerate(self._metadata):
49-
if key.lower() == "authorization": # Ensure case-insensitive comparison
50-
new_token = self._token_manager.get_access_token() # Get the new token
51-
if new_token is not None:
52-
self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token
53+
# Refresh the auth token if a credential was provided. The call to
54+
# get_access_token() is generally cheap, checking the expiry time and returning
55+
# the cached value without a network call when still valid.
56+
if self._token_manager is not None:
57+
access_token = self._token_manager.get_access_token()
58+
if access_token is not None:
59+
# Update the existing authorization header
60+
found = False
61+
for i, (key, _) in enumerate(self._metadata):
62+
if key.lower() == "authorization":
63+
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
64+
found = True
65+
break
66+
if not found:
67+
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
5368

5469
return super()._intercept_call(client_call_details)
70+
71+
72+
class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
73+
"""Async version of DTSDefaultClientInterceptorImpl for use with grpc.aio channels.
74+
75+
This class implements async gRPC interceptors to add DTS-specific headers
76+
(task hub name, user agent, and authentication token) to all async calls."""
77+
78+
def __init__(self, token_credential: Optional[AsyncTokenCredential], taskhub_name: str):
79+
try:
80+
# Get the version of the azuremanaged package
81+
sdk_version = version('durabletask-azuremanaged')
82+
except Exception:
83+
# Fallback if version cannot be determined
84+
sdk_version = "unknown"
85+
user_agent = f"durabletask-python/{sdk_version}"
86+
self._metadata = [
87+
("taskhub", taskhub_name),
88+
("x-user-agent", user_agent)]
89+
super().__init__(self._metadata)
90+
91+
# Token acquisition is deferred to the first _intercept_call invocation
92+
# rather than happening in __init__, because get_token() on an
93+
# AsyncTokenCredential is async and cannot be awaited in a constructor.
94+
self._token_manager = None
95+
if token_credential is not None:
96+
self._token_credential = token_credential
97+
self._token_manager = AsyncAccessTokenManager(token_credential=self._token_credential)
98+
99+
async def _intercept_call(
100+
self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
101+
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
102+
call details."""
103+
# Refresh the auth token if a credential was provided. The call to
104+
# get_access_token() is generally cheap, checking the expiry time and returning
105+
# the cached value without a network call when still valid.
106+
if self._token_manager is not None:
107+
access_token = await self._token_manager.get_access_token()
108+
if access_token is not None:
109+
# Update the existing authorization header, or append one if this
110+
# is the first successful token acquisition (token is lazily
111+
# fetched on the first call since async constructors aren't possible).
112+
found = False
113+
for i, (key, _) in enumerate(self._metadata):
114+
if key.lower() == "authorization":
115+
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
116+
found = True
117+
break
118+
if not found:
119+
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
120+
121+
return await super()._intercept_call(client_call_details)

0 commit comments

Comments
 (0)