66
77import grpc
88from azure .core .credentials import TokenCredential
9+ from azure .core .credentials_async import AsyncTokenCredential
910
10- from durabletask .azuremanaged .internal .access_token_manager import AccessTokenManager
11+ from durabletask .azuremanaged .internal .access_token_manager import (
12+ AccessTokenManager ,
13+ AsyncAccessTokenManager ,
14+ )
1115from durabletask .internal .grpc_interceptor import (
16+ DefaultAsyncClientInterceptorImpl ,
1217 DefaultClientInterceptorImpl ,
18+ _AsyncClientCallDetails ,
1319 _ClientCallDetails ,
1420)
1521
@@ -32,6 +38,7 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
3238 ("x-user-agent" , user_agent )] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
3339 super ().__init__ (self ._metadata )
3440
41+ self ._token_manager = None
3542 if token_credential is not None :
3643 self ._token_credential = token_credential
3744 self ._token_manager = AccessTokenManager (token_credential = self ._token_credential )
@@ -43,12 +50,72 @@ def _intercept_call(
4350 self , client_call_details : _ClientCallDetails ) -> grpc .ClientCallDetails :
4451 """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
4552 call details."""
46- # Refresh the auth token if it is present and needed
47- if self ._metadata is not None :
48- for i , (key , _ ) in enumerate (self ._metadata ):
49- if key .lower () == "authorization" : # Ensure case-insensitive comparison
50- new_token = self ._token_manager .get_access_token () # Get the new token
51- if new_token is not None :
52- self ._metadata [i ] = ("authorization" , f"Bearer { new_token .token } " ) # Update the token
53+ # Refresh the auth token if a credential was provided. The call to
54+ # get_access_token() is generally cheap, checking the expiry time and returning
55+ # the cached value without a network call when still valid.
56+ if self ._token_manager is not None :
57+ access_token = self ._token_manager .get_access_token ()
58+ if access_token is not None :
59+ # Update the existing authorization header
60+ found = False
61+ for i , (key , _ ) in enumerate (self ._metadata ):
62+ if key .lower () == "authorization" :
63+ self ._metadata [i ] = ("authorization" , f"Bearer { access_token .token } " )
64+ found = True
65+ break
66+ if not found :
67+ self ._metadata .append (("authorization" , f"Bearer { access_token .token } " ))
5368
5469 return super ()._intercept_call (client_call_details )
70+
71+
72+ class DTSAsyncDefaultClientInterceptorImpl (DefaultAsyncClientInterceptorImpl ):
73+ """Async version of DTSDefaultClientInterceptorImpl for use with grpc.aio channels.
74+
75+ This class implements async gRPC interceptors to add DTS-specific headers
76+ (task hub name, user agent, and authentication token) to all async calls."""
77+
78+ def __init__ (self , token_credential : Optional [AsyncTokenCredential ], taskhub_name : str ):
79+ try :
80+ # Get the version of the azuremanaged package
81+ sdk_version = version ('durabletask-azuremanaged' )
82+ except Exception :
83+ # Fallback if version cannot be determined
84+ sdk_version = "unknown"
85+ user_agent = f"durabletask-python/{ sdk_version } "
86+ self ._metadata = [
87+ ("taskhub" , taskhub_name ),
88+ ("x-user-agent" , user_agent )]
89+ super ().__init__ (self ._metadata )
90+
91+ # Token acquisition is deferred to the first _intercept_call invocation
92+ # rather than happening in __init__, because get_token() on an
93+ # AsyncTokenCredential is async and cannot be awaited in a constructor.
94+ self ._token_manager = None
95+ if token_credential is not None :
96+ self ._token_credential = token_credential
97+ self ._token_manager = AsyncAccessTokenManager (token_credential = self ._token_credential )
98+
99+ async def _intercept_call (
100+ self , client_call_details : _AsyncClientCallDetails ) -> grpc .aio .ClientCallDetails :
101+ """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
102+ call details."""
103+ # Refresh the auth token if a credential was provided. The call to
104+ # get_access_token() is generally cheap, checking the expiry time and returning
105+ # the cached value without a network call when still valid.
106+ if self ._token_manager is not None :
107+ access_token = await self ._token_manager .get_access_token ()
108+ if access_token is not None :
109+ # Update the existing authorization header, or append one if this
110+ # is the first successful token acquisition (token is lazily
111+ # fetched on the first call since async constructors aren't possible).
112+ found = False
113+ for i , (key , _ ) in enumerate (self ._metadata ):
114+ if key .lower () == "authorization" :
115+ self ._metadata [i ] = ("authorization" , f"Bearer { access_token .token } " )
116+ found = True
117+ break
118+ if not found :
119+ self ._metadata .append (("authorization" , f"Bearer { access_token .token } " ))
120+
121+ return await super ()._intercept_call (client_call_details )
0 commit comments