Skip to content

Commit eed8524

Browse files
committed
[Identity] Allow policy override
Similar to SDK clients, this allows credential pipelines to be customized at the policy level. Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 952fb30 commit eed8524

6 files changed

Lines changed: 173 additions & 18 deletions

File tree

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Credential HTTP pipeline policies can now be overridden via the `headers_policy`, `logging_policy`, `http_logging_policy`, `proxy_policy`, `user_agent_policy`, and `custom_hook_policy` keyword arguments when constructing credentials. The `per_retry_policies` and `per_call_policies` are also now supported. This allows users to inject custom policies or override settings of built-in policies.
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/identity/azure-identity/azure/identity/_credentials/imds.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def __init__(self, **kwargs: Any) -> None:
8484
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
8585
# the credential probes only if it's part of a ChainedTokenCredential chain.
8686
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
87-
super().__init__(retry_policy_class=ImdsRetryPolicy, **dict(PIPELINE_SETTINGS, **kwargs))
87+
merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs)
88+
retry_policy = merged_kwargs.pop("retry_policy", None) or ImdsRetryPolicy(**merged_kwargs)
89+
super().__init__(retry_policy=retry_policy, **merged_kwargs)
8890
self._config = kwargs
8991

9092
if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:

sdk/identity/azure-identity/azure/identity/_internal/pipeline.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,37 @@ def _get_config(**kwargs) -> Configuration:
2727
:rtype: ~azure.core.configuration.Configuration
2828
"""
2929
config: Configuration = Configuration(**kwargs)
30-
config.custom_hook_policy = CustomHookPolicy(**kwargs)
31-
config.headers_policy = HeadersPolicy(**kwargs)
32-
config.http_logging_policy = HttpLoggingPolicy(**kwargs)
33-
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
34-
config.proxy_policy = ProxyPolicy(**kwargs)
35-
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
30+
config.custom_hook_policy = kwargs.get("custom_hook_policy") or CustomHookPolicy(**kwargs)
31+
config.headers_policy = kwargs.get("headers_policy") or HeadersPolicy(**kwargs)
32+
config.http_logging_policy = kwargs.get("http_logging_policy") or HttpLoggingPolicy(**kwargs)
33+
config.logging_policy = kwargs.get("logging_policy") or NetworkTraceLoggingPolicy(**kwargs)
34+
config.proxy_policy = kwargs.get("proxy_policy") or ProxyPolicy(**kwargs)
35+
config.user_agent_policy = kwargs.get("user_agent_policy") or UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
3636
return config
3737

3838

39-
def _get_policies(config, _per_retry_policies=None, **kwargs):
39+
def _get_policies(config, **kwargs):
40+
per_call_policies = kwargs.get("per_call_policies", [])
41+
per_retry_policies = kwargs.get("per_retry_policies", [])
42+
4043
policies = [
4144
config.headers_policy,
4245
config.user_agent_policy,
4346
config.proxy_policy,
4447
ContentDecodePolicy(**kwargs),
45-
config.retry_policy,
4648
]
4749

48-
if _per_retry_policies:
49-
policies.extend(_per_retry_policies)
50+
if isinstance(per_call_policies, list):
51+
policies.extend(per_call_policies)
52+
else:
53+
policies.append(per_call_policies)
54+
55+
policies.append(config.retry_policy)
56+
57+
if isinstance(per_retry_policies, list):
58+
policies.extend(per_retry_policies)
59+
else:
60+
policies.append(per_retry_policies)
5061

5162
policies.extend(
5263
[
@@ -63,8 +74,7 @@ def _get_policies(config, _per_retry_policies=None, **kwargs):
6374
def build_pipeline(transport=None, policies=None, **kwargs):
6475
if not policies:
6576
config = _get_config(**kwargs)
66-
retry_policy_class = kwargs.pop("retry_policy_class", None)
67-
config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else RetryPolicy(**kwargs)
77+
config.retry_policy = kwargs.pop("retry_policy", None) or RetryPolicy(**kwargs)
6878
policies = _get_policies(config, **kwargs)
6979
if not transport:
7080
from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module
@@ -83,8 +93,7 @@ def build_async_pipeline(transport=None, policies=None, **kwargs):
8393
from azure.core.pipeline.policies import AsyncRetryPolicy
8494

8595
config = _get_config(**kwargs)
86-
retry_policy_class = kwargs.pop("retry_policy_class", None)
87-
config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else AsyncRetryPolicy(**kwargs)
96+
config.retry_policy = kwargs.pop("retry_policy", None) or AsyncRetryPolicy(**kwargs)
8897
policies = _get_policies(config, **kwargs)
8998
if not transport:
9099
from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module

sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]:
2020
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
2121
if url and imds:
2222
return AsyncManagedIdentityClient(
23-
_per_retry_policies=[ArcChallengeAuthPolicy()],
23+
per_retry_policies=[ArcChallengeAuthPolicy()],
2424
request_factory=functools.partial(_get_request, url),
2525
**kwargs,
2626
)

sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def __init__(self, **kwargs: Any) -> None:
4949
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
5050
# the credential probes only if it's part of a ChainedTokenCredential chain.
5151
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
52-
kwargs["retry_policy_class"] = AsyncImdsRetryPolicy
53-
self._client = AsyncManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs))
52+
merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs)
53+
retry_policy = merged_kwargs.pop("retry_policy", None) or AsyncImdsRetryPolicy(**merged_kwargs)
54+
self._client = AsyncManagedIdentityClient(_get_request, retry_policy=retry_policy, **merged_kwargs)
5455
if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:
5556
self._endpoint_available: Optional[bool] = True
5657
else:
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
"""Tests for policy override support in azure-identity pipelines."""
6+
7+
from unittest.mock import Mock
8+
9+
import pytest
10+
11+
from azure.core.pipeline.policies import (
12+
ContentDecodePolicy,
13+
CustomHookPolicy,
14+
DistributedTracingPolicy,
15+
HeadersPolicy,
16+
HttpLoggingPolicy,
17+
NetworkTraceLoggingPolicy,
18+
ProxyPolicy,
19+
RetryPolicy,
20+
SansIOHTTPPolicy,
21+
UserAgentPolicy,
22+
)
23+
24+
from azure.identity._internal.pipeline import (
25+
_get_config,
26+
_get_policies,
27+
build_pipeline,
28+
build_async_pipeline,
29+
)
30+
31+
CONFIG_POLICIES = [
32+
("custom_hook_policy", CustomHookPolicy),
33+
("headers_policy", HeadersPolicy),
34+
("http_logging_policy", HttpLoggingPolicy),
35+
("logging_policy", NetworkTraceLoggingPolicy),
36+
("proxy_policy", ProxyPolicy),
37+
("user_agent_policy", UserAgentPolicy),
38+
]
39+
40+
41+
class TestGetConfigPolicyOverrides:
42+
"""Tests that _get_config respects policy override kwargs."""
43+
44+
def test_default_policies_created_when_no_overrides(self):
45+
config = _get_config()
46+
for attr, cls in CONFIG_POLICIES:
47+
assert isinstance(getattr(config, attr), cls)
48+
49+
@pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES)
50+
def test_single_policy_override(self, kwarg, cls):
51+
custom = Mock(spec=cls)
52+
config = _get_config(**{kwarg: custom})
53+
assert getattr(config, kwarg) is custom
54+
55+
@pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES)
56+
def test_non_overridden_policies_unaffected(self, kwarg, cls):
57+
"""Overriding one policy should not affect others."""
58+
custom = Mock(spec=cls)
59+
config = _get_config(**{kwarg: custom})
60+
for other_attr, other_cls in CONFIG_POLICIES:
61+
if other_attr == kwarg:
62+
assert getattr(config, other_attr) is custom
63+
else:
64+
assert isinstance(getattr(config, other_attr), other_cls)
65+
66+
67+
class TestGetPoliciesOverrides:
68+
"""Tests for per_call_policies and per_retry_policies in _get_policies."""
69+
70+
@staticmethod
71+
def _make_config():
72+
config = _get_config()
73+
config.retry_policy = RetryPolicy()
74+
return config
75+
76+
def test_default_policy_order(self):
77+
policies = _get_policies(self._make_config())
78+
79+
assert [type(p) for p in policies] == [
80+
HeadersPolicy,
81+
UserAgentPolicy,
82+
ProxyPolicy,
83+
ContentDecodePolicy,
84+
RetryPolicy,
85+
CustomHookPolicy,
86+
NetworkTraceLoggingPolicy,
87+
DistributedTracingPolicy,
88+
HttpLoggingPolicy,
89+
]
90+
91+
@pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"])
92+
def test_per_call_policies_inserted_before_retry(self, as_list):
93+
custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)]
94+
arg = custom_policies if as_list else custom_policies[0]
95+
96+
policies = _get_policies(self._make_config(), per_call_policies=arg)
97+
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
98+
for custom in custom_policies:
99+
assert policies.index(custom) < retry_idx
100+
101+
@pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"])
102+
def test_per_retry_policies_inserted_after_retry(self, as_list):
103+
custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)]
104+
arg = custom_policies if as_list else custom_policies[0]
105+
106+
policies = _get_policies(self._make_config(), per_retry_policies=arg)
107+
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
108+
for custom in custom_policies:
109+
assert policies.index(custom) > retry_idx
110+
111+
def test_both_per_call_and_per_retry(self):
112+
per_call = Mock(spec=SansIOHTTPPolicy)
113+
per_retry = Mock(spec=SansIOHTTPPolicy)
114+
115+
policies = _get_policies(self._make_config(), per_call_policies=per_call, per_retry_policies=per_retry)
116+
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
117+
assert policies.index(per_call) < retry_idx
118+
assert policies.index(per_retry) > retry_idx
119+
120+
121+
class TestBuildPipelineOverrides:
122+
"""Tests for policy overrides in build_pipeline and build_async_pipeline."""
123+
124+
@pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline])
125+
def test_retry_policy_override(self, builder):
126+
custom_retry = Mock(spec=RetryPolicy)
127+
pipeline = builder(retry_policy=custom_retry, transport=Mock())
128+
assert custom_retry in pipeline._impl_policies
129+
130+
def test_default_retry_policy_when_no_override(self):
131+
pipeline = build_pipeline(transport=Mock())
132+
retry_policies = [p for p in pipeline._impl_policies if isinstance(p, RetryPolicy)]
133+
assert len(retry_policies) == 1
134+
135+
@pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline])
136+
def test_policy_override_flows_through(self, builder):
137+
"""Verify that config policy overrides reach the pipeline."""
138+
custom_headers = Mock(spec=HeadersPolicy)
139+
pipeline = builder(headers_policy=custom_headers, transport=Mock())
140+
wrapped_policies = [p._policy for p in pipeline._impl_policies if hasattr(p, "_policy")]
141+
assert custom_headers in wrapped_policies

0 commit comments

Comments
 (0)