From 5de2580a77be2bdb5d769491fdfef3379e476931 Mon Sep 17 00:00:00 2001 From: Sai Sunder Srinivasan Date: Mon, 2 Feb 2026 19:44:39 +0000 Subject: [PATCH] draft: aio mtls support --- google/auth/aio/transport/mtls.py | 88 ++++++++++++++++ google/auth/aio/transport/sessions.py | 83 +++++++++++++++ samples/verify_async_mtls.py | 97 ++++++++++++++++++ samples/verify_async_static_creds.py | 59 +++++++++++ samples/verify_sync_mtls.py | 48 +++++++++ samples/verify_vertex_async.py | 47 +++++++++ tests/aio/transport/test_mtls.py | 69 +++++++++++++ tests/aio/transport/test_sessions_mtls.py | 119 ++++++++++++++++++++++ 8 files changed, 610 insertions(+) create mode 100644 google/auth/aio/transport/mtls.py create mode 100644 samples/verify_async_mtls.py create mode 100644 samples/verify_async_static_creds.py create mode 100644 samples/verify_sync_mtls.py create mode 100644 samples/verify_vertex_async.py create mode 100644 tests/aio/transport/test_mtls.py create mode 100644 tests/aio/transport/test_sessions_mtls.py diff --git a/google/auth/aio/transport/mtls.py b/google/auth/aio/transport/mtls.py new file mode 100644 index 000000000..f43f92e0b --- /dev/null +++ b/google/auth/aio/transport/mtls.py @@ -0,0 +1,88 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper functions for mTLS in asyncio. +""" + +import asyncio +import contextlib +import logging +import os +import ssl +import tempfile +from typing import Optional + +from google.auth import exceptions + + + +@contextlib.contextmanager +def _create_temp_file(content: bytes): + """Creates a temporary file with the given content. + + Args: + content (bytes): The content to write to the file. + + Yields: + str: The path to the temporary file. + """ + # Create a temporary file that is readable only by the owner. + fd, path = tempfile.mkstemp() + try: + with os.fdopen(fd, "wb") as f: + f.write(content) + yield path + finally: + # Securely delete the file after use. + if os.path.exists(path): + os.remove(path) + + +def make_client_cert_ssl_context( + cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None +) -> ssl.SSLContext: + """Creates an SSLContext with the given client certificate and key. + + This function writes the certificate and key to temporary files so that + ssl.create_default_context can load them, as the ssl module requires + file paths for client certificates. + + Args: + cert_bytes (bytes): The client certificate content in PEM format. + key_bytes (bytes): The client private key content in PEM format. + passphrase (Optional[bytes]): The passphrase for the private key, if any. + + Returns: + ssl.SSLContext: The configured SSL context with client certificate. + + Raises: + google.auth.exceptions.TransportError: If there is an error loading the certificate. + """ + try: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + + # Write cert and key to temp files because ssl.load_cert_chain requires paths + with _create_temp_file(cert_bytes) as cert_path: + with _create_temp_file(key_bytes) as key_path: + context.load_cert_chain( + certfile=cert_path, + keyfile=key_path, + password=passphrase + ) + return context + except (ssl.SSLError, OSError) as exc: + raise exceptions.TransportError( + "Failed to load client certificate and key for mTLS." + ) from exc diff --git a/google/auth/aio/transport/sessions.py b/google/auth/aio/transport/sessions.py index 8045911cb..529770114 100644 --- a/google/auth/aio/transport/sessions.py +++ b/google/auth/aio/transport/sessions.py @@ -22,15 +22,30 @@ from google.auth.aio import transport from google.auth.aio.credentials import Credentials from google.auth.exceptions import TimeoutError +import google.auth.transport._mtls_helper +from google.auth.aio.transport import mtls try: from google.auth.aio.transport.aiohttp import Request as AiohttpRequest + import aiohttp AIOHTTP_INSTALLED = True except ImportError: # pragma: NO COVER AIOHTTP_INSTALLED = False + + +async def _run_in_executor(func, *args): + """Run a blocking function in an executor.""" + try: + return await asyncio.to_thread(func, *args) + except AttributeError: + # Fallback for Python < 3.9 + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, func, *args) + + @asynccontextmanager async def timeout_guard(timeout): """ @@ -124,12 +139,73 @@ def __init__( _auth_request = auth_request if not _auth_request and AIOHTTP_INSTALLED: _auth_request = AiohttpRequest() + self._is_mtls = False + self._cached_cert = None if _auth_request is None: raise exceptions.TransportError( "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." ) self._auth_request = _auth_request + async def configure_mtls_channel(self, client_cert_callback=None): + """Configure the client certificate and key for SSL connection. + + The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is + explicitly set to `true`. In this case if client certificate and key are + successfully obtained (from the given client_cert_callback or from application + default SSL credentials), the underlying transport will be reconfigured + to use mTLS. + + Args: + client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): + The optional callback returns the client certificate and private + key bytes both in PEM format. + If the callback is None, application default SSL credentials + will be used. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel + creation failed for any reason. + """ + # Run the blocking check in an executor + use_client_cert = await _run_in_executor( + google.auth.transport._mtls_helper.check_use_client_cert + ) + if not use_client_cert: + self._is_mtls = False + return + + try: + ( + self._is_mtls, + cert, + key, + ) = await _run_in_executor( + google.auth.transport._mtls_helper.get_client_cert_and_key, + client_cert_callback, + ) + + if self._is_mtls: + self._cached_cert = cert + ssl_context = await _run_in_executor( + mtls.make_client_cert_ssl_context, cert, key + ) + + # Re-create the auth request with the new SSL context + if isinstance(self._auth_request, AiohttpRequest): + connector = aiohttp.TCPConnector(ssl=ssl_context) + new_session = aiohttp.ClientSession(connector=connector) + await self._auth_request.close() + self._auth_request = AiohttpRequest(session=new_session) + + except ( + exceptions.ClientCertError, + ImportError, + OSError, + ) as caught_exc: + new_exc = exceptions.MutualTLSChannelError(caught_exc) + raise new_exc from caught_exc + async def request( self, method: str, @@ -174,6 +250,8 @@ async def request( retries = _exponential_backoff.AsyncExponentialBackoff( total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS ) + if headers is None: + headers = {} async with timeout_guard(max_allowed_time) as with_timeout: await with_timeout( # Note: before_request will attempt to refresh credentials if expired. @@ -261,6 +339,11 @@ async def delete( "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs ) + @property + def is_mtls(self): + """Indicates if the created SSL channel is mutual TLS.""" + return self._is_mtls + async def close(self) -> None: """ Close the underlying auth request session. diff --git a/samples/verify_async_mtls.py b/samples/verify_async_mtls.py new file mode 100644 index 000000000..4afa81066 --- /dev/null +++ b/samples/verify_async_mtls.py @@ -0,0 +1,97 @@ + +import asyncio +import os +import logging +import google.auth +import google.auth.transport.requests +from google.auth.aio.transport.sessions import AsyncAuthorizedSession +from google.auth.aio.credentials import Credentials as AioCredentials + +# Configure logging to see detailed mTLS info if available +logging.basicConfig(level=logging.INFO) + +class SyncToAsyncCredentialsAdapter(AioCredentials): + """ + Adapts synchronous google.oauth2.credentials.Credentials to + google.auth.aio.credentials.Credentials. + + This allows using standard ADC (User Credentials) with AsyncAuthorizedSession. + It uses a thread executor to perform blocking refresh operations. + """ + def __init__(self, sync_creds): + super().__init__() + self._sync_creds = sync_creds + + async def refresh(self, request): + # We ignore the async `request` passed here and use a new sync Request + # because the underlying credentials are synchronous. + sync_request = google.auth.transport.requests.Request() + await asyncio.to_thread(self._sync_creds.refresh, sync_request) + self.token = self._sync_creds.token + self.expiry = self._sync_creds.expiry + + async def before_request(self, request, method, url, headers): + sync_request = google.auth.transport.requests.Request() + # Offload the blocking refresh/apply check to a thread + await asyncio.to_thread( + self._sync_creds.before_request, sync_request, method, url, headers + ) + # after before_request, token might be refreshed + self.token = self._sync_creds.token + self.expiry = self._sync_creds.expiry + +async def main(): + # 1. Get default credentials and project + print("Loading default credentials...") + sync_credentials, project_id = google.auth.default( + scopes=["https://www.googleapis.com/auth/pubsub"] + ) + + print(f"Using credentials for project: {project_id}") + print(f"Credential type: {type(sync_credentials)}") + + if not project_id: + print("Error: Could not determine project ID from environment.") + print("Please set GOOGLE_CLOUD_PROJECT or have it in your ADC.") + return + + # 2. Adapt credentials + async_credentials = SyncToAsyncCredentialsAdapter(sync_credentials) + + # 3. Create the AsyncAuthorizedSession + session = AsyncAuthorizedSession(async_credentials) + + try: + # 4. Enable mTLS + # To actually force mTLS, ensure GOOGLE_API_USE_CLIENT_CERTIFICATE=true is in env + print("Configuring mTLS channel...") + await session.configure_mtls_channel() + + print(f"mTLS enabled: {session.is_mtls}") + + # 5. Make a request to Pub/Sub API + url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics" + print(f"Making request to: {url}") + + response = await session.get(url) + + print(f"Response Status: {response.status_code}") + if response.status_code == 200: + response_data = await response.json() + print("Success! Topics found.") + # print("Response Body (first 200 chars):", str(response_data)[:200]) + else: + print("Request failed.") + print(await response.text()) + + finally: + await session.close() + +if __name__ == "__main__": + # Ensure SSL cert/key env vars are set if you want to test actual mTLS, + # otherwise it might fallback to regular TLS if the check returns False. + if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE") != "true": + print("WARNING: GOOGLE_API_USE_CLIENT_CERTIFICATE is not set to 'true'.") + print("mTLS might not be attempted. Run with: export GOOGLE_API_USE_CLIENT_CERTIFICATE=true") + + asyncio.run(main()) diff --git a/samples/verify_async_static_creds.py b/samples/verify_async_static_creds.py new file mode 100644 index 000000000..be4663737 --- /dev/null +++ b/samples/verify_async_static_creds.py @@ -0,0 +1,59 @@ + +import asyncio +import os +import aiohttp +from google.auth.aio.credentials import StaticCredentials +from google.auth.aio.transport.sessions import AsyncAuthorizedSession + +async def main(): + # 1. Obtain a token from ADC + # We use sync credentials to get the token, then pass it to StaticCredentials + # which is compatible with AsyncAuthorizedSession. + from google.auth.transport.requests import Request + import google.auth + + print("Loading default credentials...") + creds, project_id = google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + + # Refresh to ensure we have a valid token + print("Refreshing credentials to get access token...") + creds.refresh(Request()) + token = creds.token + + print(f"Using token from ADC: {token[:10]}...") + + # 2. Create StaticCredentials + # These credentials are immutable and will not be refreshed by the Async session. + # Since we just refreshed them, they should be valid for ~1 hour. + async_creds = StaticCredentials(token=token) + + # 3. Create the AsyncAuthorizedSession + session = AsyncAuthorizedSession(async_creds) + + try: + # 4. Make a request to Pub/Sub API (REST) + # Note: GAPIC libraries (google-cloud-pubsub) generally do not support + # google.auth.aio credentials yet. We use REST to verify the Async Session. + url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics" + print(f"Making request to: {url}") + + response = await session.get(url) + print(f"Response Status: {response.status_code}") + + if response.status_code == 200: + import json + body_bytes = await response.read() + data = json.loads(body_bytes) + print("Topics found (count):", len(data.get("topics", []))) + + else: + print("Request failed.") + print((await response.read()).decode("utf-8")) + + finally: + await session.close() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/samples/verify_sync_mtls.py b/samples/verify_sync_mtls.py new file mode 100644 index 000000000..31004fa60 --- /dev/null +++ b/samples/verify_sync_mtls.py @@ -0,0 +1,48 @@ + +import os +import google.auth +import google.auth.transport.requests +from google.auth.transport.requests import AuthorizedSession + +def main(): + # 1. Get default credentials and project + print("Loading default credentials...") + credentials, project_id = google.auth.default( + scopes=["https://www.googleapis.com/auth/pubsub"] + ) + + print(f"Using credentials for project: {project_id}") + print(f"Credential type: {type(credentials)}") + + if not project_id: + print("Error: Could not determine project ID from environment.") + return + + # 2. Create the AuthorizedSession + session = AuthorizedSession(credentials) + + # 3. Enable mTLS + # To actually force mTLS, ensure GOOGLE_API_USE_CLIENT_CERTIFICATE=true is in env + print("Configuring mTLS channel...") + session.configure_mtls_channel() + + print(f"mTLS enabled: {session.is_mtls}") + + # 4. Make a request to Pub/Sub API + url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics" + print(f"Making request to: {url}") + + response = session.get(url) + + print(f"Response Status: {response.status_code}") + if response.status_code == 200: + print("Success! Topics found.") + else: + print("Request failed.") + print(response.text) + +if __name__ == "__main__": + if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE") != "true": + print("WARNING: GOOGLE_API_USE_CLIENT_CERTIFICATE is not set to 'true'.") + + main() diff --git a/samples/verify_vertex_async.py b/samples/verify_vertex_async.py new file mode 100644 index 000000000..dcd010a8a --- /dev/null +++ b/samples/verify_vertex_async.py @@ -0,0 +1,47 @@ + +import asyncio +import os +import vertexai +from google.auth.transport.requests import Request +import google.auth +from google.auth.aio.credentials import StaticCredentials +from vertexai.preview.generative_models import GenerativeModel +from google.cloud import aiplatform + +async def main(): + print("Loading default credentials...") + creds, project_id = google.auth.default() + + # Refresh to ensure we have a valid token + print("Refreshing credentials to get access token...") + creds.refresh(Request()) + token = creds.token + print(f"Using token from ADC: {token[:10]}...") + + # Create StaticCredentials for Async + async_creds = StaticCredentials(token=token) + + # Initialize Vertex AI with REST transport + print(f"Initializing Vertex AI for project: {project_id}") + vertexai.init(project=project_id, location="us-central1", api_transport="rest") + + # Inject our Async Credentials + # This uses the hidden API mentioned by the user to set the async REST credentials + print("Injecting async credentials into AI Platform initializer...") + aiplatform.initializer._set_async_rest_credentials(credentials=async_creds) + + # Generate Content + print("Generating content...") + model = GenerativeModel("gemini-2.5-flash") + + try: + response = await model.generate_content_async("Tell me a one sentence joke.") + print("\nResponse from Gemini:") + print(response.text) + except Exception as e: + print(f"\nError generating content: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/aio/transport/test_mtls.py b/tests/aio/transport/test_mtls.py new file mode 100644 index 000000000..859f63662 --- /dev/null +++ b/tests/aio/transport/test_mtls.py @@ -0,0 +1,69 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import ssl +import unittest +from unittest import mock +import asyncio +import json + +from google.auth.aio.transport import mtls +from google.auth import exceptions +from google.auth import environment_vars + +class TestMtls(unittest.TestCase): + + @mock.patch("ssl.create_default_context") + def test_make_client_cert_ssl_context(self, create_default_context): + mock_context = mock.Mock() + create_default_context.return_value = mock_context + + cert_bytes = b"cert_content" + key_bytes = b"key_content" + passphrase = b"passphrase" + + context = mtls.make_client_cert_ssl_context(cert_bytes, key_bytes, passphrase) + + # Verify context creation + create_default_context.assert_called_once_with(ssl.Purpose.SERVER_AUTH) + assert context == mock_context + + # Verify load_cert_chain was called with some file paths + mock_context.load_cert_chain.assert_called_once() + call_args = mock_context.load_cert_chain.call_args + cert_path = call_args[1]['certfile'] + key_path = call_args[1]['keyfile'] + password = call_args[1]['password'] + + # Verify passphrase passed correctly + self.assertEqual(password, passphrase) + + # Verify temporary files provided but deleted after call returns + self.assertFalse(os.path.exists(cert_path)) + self.assertFalse(os.path.exists(key_path)) + + @mock.patch("ssl.create_default_context") + def test_make_client_cert_ssl_context_error(self, create_default_context): + mock_context = mock.Mock() + create_default_context.return_value = mock_context + + # Simulate SSL error + mock_context.load_cert_chain.side_effect = ssl.SSLError("oops") + + with self.assertRaises(exceptions.TransportError): + mtls.make_client_cert_ssl_context(b"cert", b"key") + +if __name__ == '__main__': + unittest.main() diff --git a/tests/aio/transport/test_sessions_mtls.py b/tests/aio/transport/test_sessions_mtls.py new file mode 100644 index 000000000..18442775d --- /dev/null +++ b/tests/aio/transport/test_sessions_mtls.py @@ -0,0 +1,119 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock +import sys +import asyncio + +# Create a mock for aiohttp +mock_aiohttp = mock.Mock() +# Ensure it looks like a package +mock_aiohttp.__name__ = "aiohttp" + +# Mock aiohttp in sys.modules so imports work +with mock.patch.dict("sys.modules", {"aiohttp": mock_aiohttp}): + from google.auth.aio.transport import sessions + from google.auth.aio.transport import mtls + import aiohttp + +class TestSessionsMtls(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + # Reset mocks between tests + mock_aiohttp.reset_mock() + + @mock.patch("google.auth.aio.transport.sessions._run_in_executor") + async def test_configure_mtls_channel( + self, _run_in_executor + ): + # Setup mocks + mock_ssl_context = mock.Mock() + + # Side effect to return different values for different calls + async def side_effect(func, *args): + if func.__name__ == 'check_use_client_cert': + return True + if func.__name__ == 'get_client_cert_and_key': + return (True, b"cert", b"key") + if func.__name__ == 'make_client_cert_ssl_context': + return mock_ssl_context + return None + + _run_in_executor.side_effect = side_effect + + # Setup session + creds = mock.Mock(spec=sessions.Credentials) + # Mock AiohttpRequest to satisfy the isinstance check + mock_auth_request = mock.Mock(spec=sessions.AiohttpRequest) + # Mock close coroutine + mock_auth_request.close = mock.AsyncMock() + + session = sessions.AsyncAuthorizedSession(creds, auth_request=mock_auth_request) + + # Call method (now async) + await session.configure_mtls_channel() + + # Verify interactions + self.assertEqual(_run_in_executor.call_count, 3) + + # Verify aiohttp interactions via the captured mock_aiohttp + mock_aiohttp.TCPConnector.assert_called_once_with(ssl=mock_ssl_context) + mock_aiohttp.ClientSession.assert_called_once() + + # Verify the session's auth_request was updated + self.assertTrue(isinstance(session._auth_request, sessions.AiohttpRequest)) + # Verify it's a new instance (different from the mock we passed) + self.assertIsNot(session._auth_request, mock_auth_request) + # Verify old request was closed + mock_auth_request.close.assert_awaited_once() + + # Verify is_mtls property + self.assertTrue(session.is_mtls) + + # Verify the chain of objects + mock_connector = mock_aiohttp.TCPConnector.return_value + mock_client_session = mock_aiohttp.ClientSession.return_value + + # Ensure ClientSession was initialized with the correct connector + mock_aiohttp.ClientSession.assert_called_with(connector=mock_connector) + + # Ensure the session's auth_request is holding the new client session + self.assertEqual(session._auth_request._session, mock_client_session) + + @mock.patch("google.auth.aio.transport.sessions._run_in_executor") + async def test_configure_mtls_channel_disabled(self, _run_in_executor): + # Configure helper to return False for check_use_client_cert + async def side_effect(func, *args): + if func.__name__ == 'check_use_client_cert': + return False + return None + _run_in_executor.side_effect = side_effect + + creds = mock.Mock(spec=sessions.Credentials) + session = sessions.AsyncAuthorizedSession(creds) + original_request = session._auth_request + + # Verify initial state + self.assertFalse(session.is_mtls) + + await session.configure_mtls_channel() + + # Should not have changed + self.assertIs(session._auth_request, original_request) + # Verify is_mtls property remains False + self.assertFalse(session.is_mtls) + +if __name__ == '__main__': + unittest.main()