diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 686480a47a00..ac557fd7c153 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Credential HTTP pipeline policies can now be overridden via the `headers_policy`, `logging_policy`, `http_logging_policy`, `proxy_policy`, `user_agent_policy`, and `custom_hook_policy` keyword arguments when constructing credentials. The `per_retry_policies` and `per_call_policies` are also now supported. This allows users to inject custom policies or override settings of built-in policies. ([#46072](https://github.com/Azure/azure-sdk-for-python/pull/46072)) + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index ece23fbc4419..3ce5e0026a8f 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -84,7 +84,9 @@ def __init__(self, **kwargs: Any) -> None: # probes for the IMDS endpoint before attempting to get a token. If None (the default), # the credential probes only if it's part of a ChainedTokenCredential chain. self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None) - super().__init__(retry_policy_class=ImdsRetryPolicy, **dict(PIPELINE_SETTINGS, **kwargs)) + merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs) + retry_policy = merged_kwargs.pop("retry_policy", None) or ImdsRetryPolicy(**merged_kwargs) + super().__init__(retry_policy=retry_policy, **merged_kwargs) self._config = kwargs if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ: diff --git a/sdk/identity/azure-identity/azure/identity/_internal/pipeline.py b/sdk/identity/azure-identity/azure/identity/_internal/pipeline.py index 3dea7bd66423..c1231c38ee4d 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/pipeline.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/pipeline.py @@ -27,26 +27,37 @@ def _get_config(**kwargs) -> Configuration: :rtype: ~azure.core.configuration.Configuration """ config: Configuration = Configuration(**kwargs) - config.custom_hook_policy = CustomHookPolicy(**kwargs) - config.headers_policy = HeadersPolicy(**kwargs) - config.http_logging_policy = HttpLoggingPolicy(**kwargs) - config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) - config.proxy_policy = ProxyPolicy(**kwargs) - config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) + config.custom_hook_policy = kwargs.get("custom_hook_policy") or CustomHookPolicy(**kwargs) + config.headers_policy = kwargs.get("headers_policy") or HeadersPolicy(**kwargs) + config.http_logging_policy = kwargs.get("http_logging_policy") or HttpLoggingPolicy(**kwargs) + config.logging_policy = kwargs.get("logging_policy") or NetworkTraceLoggingPolicy(**kwargs) + config.proxy_policy = kwargs.get("proxy_policy") or ProxyPolicy(**kwargs) + config.user_agent_policy = kwargs.get("user_agent_policy") or UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) return config -def _get_policies(config, _per_retry_policies=None, **kwargs): +def _get_policies(config, **kwargs): + per_call_policies = kwargs.get("per_call_policies", []) + per_retry_policies = kwargs.get("per_retry_policies", []) + policies = [ config.headers_policy, config.user_agent_policy, config.proxy_policy, ContentDecodePolicy(**kwargs), - config.retry_policy, ] - if _per_retry_policies: - policies.extend(_per_retry_policies) + if isinstance(per_call_policies, list): + policies.extend(per_call_policies) + else: + policies.append(per_call_policies) + + policies.append(config.retry_policy) + + if isinstance(per_retry_policies, list): + policies.extend(per_retry_policies) + else: + policies.append(per_retry_policies) policies.extend( [ @@ -63,8 +74,7 @@ def _get_policies(config, _per_retry_policies=None, **kwargs): def build_pipeline(transport=None, policies=None, **kwargs): if not policies: config = _get_config(**kwargs) - retry_policy_class = kwargs.pop("retry_policy_class", None) - config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else RetryPolicy(**kwargs) + config.retry_policy = kwargs.pop("retry_policy", None) or RetryPolicy(**kwargs) policies = _get_policies(config, **kwargs) if not transport: from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module @@ -83,8 +93,7 @@ def build_async_pipeline(transport=None, policies=None, **kwargs): from azure.core.pipeline.policies import AsyncRetryPolicy config = _get_config(**kwargs) - retry_policy_class = kwargs.pop("retry_policy_class", None) - config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else AsyncRetryPolicy(**kwargs) + config.retry_policy = kwargs.pop("retry_policy", None) or AsyncRetryPolicy(**kwargs) policies = _get_policies(config, **kwargs) if not transport: from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py index d9c5690d03d1..c452061ef2a1 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py @@ -20,7 +20,7 @@ def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]: imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT) if url and imds: return AsyncManagedIdentityClient( - _per_retry_policies=[ArcChallengeAuthPolicy()], + per_retry_policies=[ArcChallengeAuthPolicy()], request_factory=functools.partial(_get_request, url), **kwargs, ) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 0fb758c3026c..d66188502031 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -49,8 +49,9 @@ def __init__(self, **kwargs: Any) -> None: # probes for the IMDS endpoint before attempting to get a token. If None (the default), # the credential probes only if it's part of a ChainedTokenCredential chain. self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None) - kwargs["retry_policy_class"] = AsyncImdsRetryPolicy - self._client = AsyncManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs)) + merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs) + retry_policy = merged_kwargs.pop("retry_policy", None) or AsyncImdsRetryPolicy(**merged_kwargs) + self._client = AsyncManagedIdentityClient(_get_request, retry_policy=retry_policy, **merged_kwargs) if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ: self._endpoint_available: Optional[bool] = True else: diff --git a/sdk/identity/azure-identity/tests/test_pipeline_policy_overrides.py b/sdk/identity/azure-identity/tests/test_pipeline_policy_overrides.py new file mode 100644 index 000000000000..b6e3ec6451ef --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_pipeline_policy_overrides.py @@ -0,0 +1,141 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Tests for policy override support in azure-identity pipelines.""" + +from unittest.mock import Mock + +import pytest + +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + CustomHookPolicy, + DistributedTracingPolicy, + HeadersPolicy, + HttpLoggingPolicy, + NetworkTraceLoggingPolicy, + ProxyPolicy, + RetryPolicy, + SansIOHTTPPolicy, + UserAgentPolicy, +) + +from azure.identity._internal.pipeline import ( + _get_config, + _get_policies, + build_pipeline, + build_async_pipeline, +) + +CONFIG_POLICIES = [ + ("custom_hook_policy", CustomHookPolicy), + ("headers_policy", HeadersPolicy), + ("http_logging_policy", HttpLoggingPolicy), + ("logging_policy", NetworkTraceLoggingPolicy), + ("proxy_policy", ProxyPolicy), + ("user_agent_policy", UserAgentPolicy), +] + + +class TestGetConfigPolicyOverrides: + """Tests that _get_config respects policy override kwargs.""" + + def test_default_policies_created_when_no_overrides(self): + config = _get_config() + for attr, cls in CONFIG_POLICIES: + assert isinstance(getattr(config, attr), cls) + + @pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES) + def test_single_policy_override(self, kwarg, cls): + custom = Mock(spec=cls) + config = _get_config(**{kwarg: custom}) + assert getattr(config, kwarg) is custom + + @pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES) + def test_non_overridden_policies_unaffected(self, kwarg, cls): + """Overriding one policy should not affect others.""" + custom = Mock(spec=cls) + config = _get_config(**{kwarg: custom}) + for other_attr, other_cls in CONFIG_POLICIES: + if other_attr == kwarg: + assert getattr(config, other_attr) is custom + else: + assert isinstance(getattr(config, other_attr), other_cls) + + +class TestGetPoliciesOverrides: + """Tests for per_call_policies and per_retry_policies in _get_policies.""" + + @staticmethod + def _make_config(): + config = _get_config() + config.retry_policy = RetryPolicy() + return config + + def test_default_policy_order(self): + policies = _get_policies(self._make_config()) + + assert [type(p) for p in policies] == [ + HeadersPolicy, + UserAgentPolicy, + ProxyPolicy, + ContentDecodePolicy, + RetryPolicy, + CustomHookPolicy, + NetworkTraceLoggingPolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, + ] + + @pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"]) + def test_per_call_policies_inserted_before_retry(self, as_list): + custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)] + arg = custom_policies if as_list else custom_policies[0] + + policies = _get_policies(self._make_config(), per_call_policies=arg) + retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy)) + for custom in custom_policies: + assert policies.index(custom) < retry_idx + + @pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"]) + def test_per_retry_policies_inserted_after_retry(self, as_list): + custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)] + arg = custom_policies if as_list else custom_policies[0] + + policies = _get_policies(self._make_config(), per_retry_policies=arg) + retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy)) + for custom in custom_policies: + assert policies.index(custom) > retry_idx + + def test_both_per_call_and_per_retry(self): + per_call = Mock(spec=SansIOHTTPPolicy) + per_retry = Mock(spec=SansIOHTTPPolicy) + + policies = _get_policies(self._make_config(), per_call_policies=per_call, per_retry_policies=per_retry) + retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy)) + assert policies.index(per_call) < retry_idx + assert policies.index(per_retry) > retry_idx + + +class TestBuildPipelineOverrides: + """Tests for policy overrides in build_pipeline and build_async_pipeline.""" + + @pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline]) + def test_retry_policy_override(self, builder): + custom_retry = Mock(spec=RetryPolicy) + pipeline = builder(retry_policy=custom_retry, transport=Mock()) + assert custom_retry in pipeline._impl_policies + + def test_default_retry_policy_when_no_override(self): + pipeline = build_pipeline(transport=Mock()) + retry_policies = [p for p in pipeline._impl_policies if isinstance(p, RetryPolicy)] + assert len(retry_policies) == 1 + + @pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline]) + def test_policy_override_flows_through(self, builder): + """Verify that config policy overrides reach the pipeline.""" + custom_headers = Mock(spec=HeadersPolicy) + pipeline = builder(headers_policy=custom_headers, transport=Mock()) + wrapped_policies = [p._policy for p in pipeline._impl_policies if hasattr(p, "_policy")] + assert custom_headers in wrapped_policies