Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Credential HTTP pipeline policies can now be overridden via the `headers_policy`, `logging_policy`, `http_logging_policy`, `proxy_policy`, `user_agent_policy`, and `custom_hook_policy` keyword arguments when constructing credentials. The `per_retry_policies` and `per_call_policies` are also now supported. This allows users to inject custom policies or override settings of built-in policies. ([#46072](https://github.com/Azure/azure-sdk-for-python/pull/46072))

### Breaking Changes

### Bugs Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def __init__(self, **kwargs: Any) -> None:
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
# the credential probes only if it's part of a ChainedTokenCredential chain.
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
super().__init__(retry_policy_class=ImdsRetryPolicy, **dict(PIPELINE_SETTINGS, **kwargs))
merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs)
retry_policy = merged_kwargs.pop("retry_policy", None) or ImdsRetryPolicy(**merged_kwargs)
super().__init__(retry_policy=retry_policy, **merged_kwargs)
self._config = kwargs

if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:
Expand Down
37 changes: 23 additions & 14 deletions sdk/identity/azure-identity/azure/identity/_internal/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,37 @@ def _get_config(**kwargs) -> Configuration:
:rtype: ~azure.core.configuration.Configuration
"""
config: Configuration = Configuration(**kwargs)
config.custom_hook_policy = CustomHookPolicy(**kwargs)
config.headers_policy = HeadersPolicy(**kwargs)
config.http_logging_policy = HttpLoggingPolicy(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
config.custom_hook_policy = kwargs.get("custom_hook_policy") or CustomHookPolicy(**kwargs)
config.headers_policy = kwargs.get("headers_policy") or HeadersPolicy(**kwargs)
config.http_logging_policy = kwargs.get("http_logging_policy") or HttpLoggingPolicy(**kwargs)
config.logging_policy = kwargs.get("logging_policy") or NetworkTraceLoggingPolicy(**kwargs)
config.proxy_policy = kwargs.get("proxy_policy") or ProxyPolicy(**kwargs)
config.user_agent_policy = kwargs.get("user_agent_policy") or UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
return config


def _get_policies(config, _per_retry_policies=None, **kwargs):
def _get_policies(config, **kwargs):
per_call_policies = kwargs.get("per_call_policies", [])
per_retry_policies = kwargs.get("per_retry_policies", [])

policies = [
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
ContentDecodePolicy(**kwargs),
config.retry_policy,
]

if _per_retry_policies:
policies.extend(_per_retry_policies)
if isinstance(per_call_policies, list):
policies.extend(per_call_policies)
else:
policies.append(per_call_policies)

policies.append(config.retry_policy)

if isinstance(per_retry_policies, list):
policies.extend(per_retry_policies)
else:
policies.append(per_retry_policies)

policies.extend(
[
Expand All @@ -63,8 +74,7 @@ def _get_policies(config, _per_retry_policies=None, **kwargs):
def build_pipeline(transport=None, policies=None, **kwargs):
if not policies:
config = _get_config(**kwargs)
retry_policy_class = kwargs.pop("retry_policy_class", None)
config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else RetryPolicy(**kwargs)
config.retry_policy = kwargs.pop("retry_policy", None) or RetryPolicy(**kwargs)
policies = _get_policies(config, **kwargs)
if not transport:
from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module
Expand All @@ -83,8 +93,7 @@ def build_async_pipeline(transport=None, policies=None, **kwargs):
from azure.core.pipeline.policies import AsyncRetryPolicy

config = _get_config(**kwargs)
retry_policy_class = kwargs.pop("retry_policy_class", None)
config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else AsyncRetryPolicy(**kwargs)
config.retry_policy = kwargs.pop("retry_policy", None) or AsyncRetryPolicy(**kwargs)
policies = _get_policies(config, **kwargs)
if not transport:
from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]:
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
if url and imds:
return AsyncManagedIdentityClient(
_per_retry_policies=[ArcChallengeAuthPolicy()],
per_retry_policies=[ArcChallengeAuthPolicy()],
request_factory=functools.partial(_get_request, url),
**kwargs,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def __init__(self, **kwargs: Any) -> None:
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
# the credential probes only if it's part of a ChainedTokenCredential chain.
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
kwargs["retry_policy_class"] = AsyncImdsRetryPolicy
self._client = AsyncManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs))
merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs)
retry_policy = merged_kwargs.pop("retry_policy", None) or AsyncImdsRetryPolicy(**merged_kwargs)
self._client = AsyncManagedIdentityClient(_get_request, retry_policy=retry_policy, **merged_kwargs)
if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:
self._endpoint_available: Optional[bool] = True
else:
Expand Down
141 changes: 141 additions & 0 deletions sdk/identity/azure-identity/tests/test_pipeline_policy_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Tests for policy override support in azure-identity pipelines."""

from unittest.mock import Mock

import pytest

from azure.core.pipeline.policies import (
ContentDecodePolicy,
CustomHookPolicy,
DistributedTracingPolicy,
HeadersPolicy,
HttpLoggingPolicy,
NetworkTraceLoggingPolicy,
ProxyPolicy,
RetryPolicy,
SansIOHTTPPolicy,
UserAgentPolicy,
)

from azure.identity._internal.pipeline import (
_get_config,
_get_policies,
build_pipeline,
build_async_pipeline,
)

CONFIG_POLICIES = [
("custom_hook_policy", CustomHookPolicy),
("headers_policy", HeadersPolicy),
("http_logging_policy", HttpLoggingPolicy),
("logging_policy", NetworkTraceLoggingPolicy),
("proxy_policy", ProxyPolicy),
("user_agent_policy", UserAgentPolicy),
]


class TestGetConfigPolicyOverrides:
"""Tests that _get_config respects policy override kwargs."""

def test_default_policies_created_when_no_overrides(self):
config = _get_config()
for attr, cls in CONFIG_POLICIES:
assert isinstance(getattr(config, attr), cls)

@pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES)
def test_single_policy_override(self, kwarg, cls):
custom = Mock(spec=cls)
config = _get_config(**{kwarg: custom})
assert getattr(config, kwarg) is custom

@pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES)
def test_non_overridden_policies_unaffected(self, kwarg, cls):
"""Overriding one policy should not affect others."""
custom = Mock(spec=cls)
config = _get_config(**{kwarg: custom})
for other_attr, other_cls in CONFIG_POLICIES:
if other_attr == kwarg:
assert getattr(config, other_attr) is custom
else:
assert isinstance(getattr(config, other_attr), other_cls)


class TestGetPoliciesOverrides:
"""Tests for per_call_policies and per_retry_policies in _get_policies."""

@staticmethod
def _make_config():
config = _get_config()
config.retry_policy = RetryPolicy()
return config

def test_default_policy_order(self):
policies = _get_policies(self._make_config())

assert [type(p) for p in policies] == [
HeadersPolicy,
UserAgentPolicy,
ProxyPolicy,
ContentDecodePolicy,
RetryPolicy,
CustomHookPolicy,
NetworkTraceLoggingPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
]

@pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"])
def test_per_call_policies_inserted_before_retry(self, as_list):
custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)]
arg = custom_policies if as_list else custom_policies[0]

policies = _get_policies(self._make_config(), per_call_policies=arg)
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
for custom in custom_policies:
assert policies.index(custom) < retry_idx

@pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"])
def test_per_retry_policies_inserted_after_retry(self, as_list):
custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)]
arg = custom_policies if as_list else custom_policies[0]

policies = _get_policies(self._make_config(), per_retry_policies=arg)
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
for custom in custom_policies:
assert policies.index(custom) > retry_idx

def test_both_per_call_and_per_retry(self):
per_call = Mock(spec=SansIOHTTPPolicy)
per_retry = Mock(spec=SansIOHTTPPolicy)

policies = _get_policies(self._make_config(), per_call_policies=per_call, per_retry_policies=per_retry)
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
assert policies.index(per_call) < retry_idx
assert policies.index(per_retry) > retry_idx


class TestBuildPipelineOverrides:
"""Tests for policy overrides in build_pipeline and build_async_pipeline."""

@pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline])
def test_retry_policy_override(self, builder):
custom_retry = Mock(spec=RetryPolicy)
pipeline = builder(retry_policy=custom_retry, transport=Mock())
assert custom_retry in pipeline._impl_policies

def test_default_retry_policy_when_no_override(self):
pipeline = build_pipeline(transport=Mock())
retry_policies = [p for p in pipeline._impl_policies if isinstance(p, RetryPolicy)]
assert len(retry_policies) == 1

@pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline])
def test_policy_override_flows_through(self, builder):
"""Verify that config policy overrides reach the pipeline."""
custom_headers = Mock(spec=HeadersPolicy)
pipeline = builder(headers_policy=custom_headers, transport=Mock())
wrapped_policies = [p._policy for p in pipeline._impl_policies if hasattr(p, "_policy")]
assert custom_headers in wrapped_policies