-
Notifications
You must be signed in to change notification settings - Fork 349
draft: aio mtls support #1954
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
draft: aio mtls support #1954
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # Copyright 2024 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Helper functions for mTLS in asyncio. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import contextlib | ||
| import logging | ||
| import os | ||
| import ssl | ||
| import tempfile | ||
| from typing import Optional | ||
|
|
||
| from google.auth import exceptions | ||
|
|
||
|
|
||
|
|
||
| @contextlib.contextmanager | ||
| def _create_temp_file(content: bytes): | ||
| """Creates a temporary file with the given content. | ||
|
|
||
| Args: | ||
| content (bytes): The content to write to the file. | ||
|
|
||
| Yields: | ||
| str: The path to the temporary file. | ||
| """ | ||
| # Create a temporary file that is readable only by the owner. | ||
| fd, path = tempfile.mkstemp() | ||
| try: | ||
| with os.fdopen(fd, "wb") as f: | ||
| f.write(content) | ||
| yield path | ||
| finally: | ||
| # Securely delete the file after use. | ||
| if os.path.exists(path): | ||
| os.remove(path) | ||
|
|
||
|
|
||
| def make_client_cert_ssl_context( | ||
| cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None | ||
| ) -> ssl.SSLContext: | ||
| """Creates an SSLContext with the given client certificate and key. | ||
|
|
||
| This function writes the certificate and key to temporary files so that | ||
| ssl.create_default_context can load them, as the ssl module requires | ||
| file paths for client certificates. | ||
|
|
||
| Args: | ||
| cert_bytes (bytes): The client certificate content in PEM format. | ||
| key_bytes (bytes): The client private key content in PEM format. | ||
| passphrase (Optional[bytes]): The passphrase for the private key, if any. | ||
|
|
||
| Returns: | ||
| ssl.SSLContext: The configured SSL context with client certificate. | ||
|
|
||
| Raises: | ||
| google.auth.exceptions.TransportError: If there is an error loading the certificate. | ||
| """ | ||
| try: | ||
| context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) | ||
|
|
||
| # Write cert and key to temp files because ssl.load_cert_chain requires paths | ||
| with _create_temp_file(cert_bytes) as cert_path: | ||
| with _create_temp_file(key_bytes) as key_path: | ||
| context.load_cert_chain( | ||
| certfile=cert_path, | ||
| keyfile=key_path, | ||
| password=passphrase | ||
| ) | ||
| return context | ||
| except (ssl.SSLError, OSError) as exc: | ||
| raise exceptions.TransportError( | ||
| "Failed to load client certificate and key for mTLS." | ||
| ) from exc | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,15 +22,30 @@ | |
| from google.auth.aio import transport | ||
| from google.auth.aio.credentials import Credentials | ||
| from google.auth.exceptions import TimeoutError | ||
| import google.auth.transport._mtls_helper | ||
| from google.auth.aio.transport import mtls | ||
|
|
||
| try: | ||
| from google.auth.aio.transport.aiohttp import Request as AiohttpRequest | ||
| import aiohttp | ||
|
|
||
| AIOHTTP_INSTALLED = True | ||
| except ImportError: # pragma: NO COVER | ||
| AIOHTTP_INSTALLED = False | ||
|
|
||
|
|
||
|
|
||
|
|
||
| async def _run_in_executor(func, *args): | ||
| """Run a blocking function in an executor.""" | ||
| try: | ||
| return await asyncio.to_thread(func, *args) | ||
| except AttributeError: | ||
| # Fallback for Python < 3.9 | ||
| loop = asyncio.get_running_loop() | ||
| return await loop.run_in_executor(None, func, *args) | ||
|
|
||
|
|
||
| @asynccontextmanager | ||
| async def timeout_guard(timeout): | ||
| """ | ||
|
|
@@ -124,12 +139,73 @@ def __init__( | |
| _auth_request = auth_request | ||
| if not _auth_request and AIOHTTP_INSTALLED: | ||
| _auth_request = AiohttpRequest() | ||
| self._is_mtls = False | ||
| self._cached_cert = None | ||
| if _auth_request is None: | ||
| raise exceptions.TransportError( | ||
| "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value." | ||
| ) | ||
| self._auth_request = _auth_request | ||
|
|
||
| async def configure_mtls_channel(self, client_cert_callback=None): | ||
| """Configure the client certificate and key for SSL connection. | ||
|
|
||
| The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is | ||
| explicitly set to `true`. In this case if client certificate and key are | ||
| successfully obtained (from the given client_cert_callback or from application | ||
| default SSL credentials), the underlying transport will be reconfigured | ||
| to use mTLS. | ||
|
|
||
| Args: | ||
| client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): | ||
| The optional callback returns the client certificate and private | ||
| key bytes both in PEM format. | ||
| If the callback is None, application default SSL credentials | ||
| will be used. | ||
|
|
||
| Raises: | ||
| google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel | ||
| creation failed for any reason. | ||
| """ | ||
| # Run the blocking check in an executor | ||
| use_client_cert = await _run_in_executor( | ||
| google.auth.transport._mtls_helper.check_use_client_cert | ||
| ) | ||
| if not use_client_cert: | ||
| self._is_mtls = False | ||
| return | ||
|
|
||
| try: | ||
| ( | ||
| self._is_mtls, | ||
| cert, | ||
| key, | ||
| ) = await _run_in_executor( | ||
| google.auth.transport._mtls_helper.get_client_cert_and_key, | ||
| client_cert_callback, | ||
| ) | ||
|
|
||
| if self._is_mtls: | ||
| self._cached_cert = cert | ||
| ssl_context = await _run_in_executor( | ||
| mtls.make_client_cert_ssl_context, cert, key | ||
| ) | ||
|
|
||
| # Re-create the auth request with the new SSL context | ||
| if isinstance(self._auth_request, AiohttpRequest): | ||
| connector = aiohttp.TCPConnector(ssl=ssl_context) | ||
| new_session = aiohttp.ClientSession(connector=connector) | ||
| await self._auth_request.close() | ||
| self._auth_request = AiohttpRequest(session=new_session) | ||
|
Comment on lines
+195
to
+199
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a potential issue here. If Consider raising an exception for unsupported transport types to make the behavior explicit. if isinstance(self._auth_request, AiohttpRequest):
connector = aiohttp.TCPConnector(ssl=ssl_context)
new_session = aiohttp.ClientSession(connector=connector)
await self._auth_request.close()
self._auth_request = AiohttpRequest(session=new_session)
else:
raise exceptions.TransportError("mTLS is only supported for aiohttp transport.") |
||
|
|
||
| except ( | ||
| exceptions.ClientCertError, | ||
| ImportError, | ||
| OSError, | ||
| ) as caught_exc: | ||
| new_exc = exceptions.MutualTLSChannelError(caught_exc) | ||
| raise new_exc from caught_exc | ||
|
|
||
| async def request( | ||
| self, | ||
| method: str, | ||
|
|
@@ -174,6 +250,8 @@ async def request( | |
| retries = _exponential_backoff.AsyncExponentialBackoff( | ||
| total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS | ||
| ) | ||
| if headers is None: | ||
| headers = {} | ||
| async with timeout_guard(max_allowed_time) as with_timeout: | ||
| await with_timeout( | ||
| # Note: before_request will attempt to refresh credentials if expired. | ||
|
|
@@ -261,6 +339,11 @@ async def delete( | |
| "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs | ||
| ) | ||
|
|
||
| @property | ||
| def is_mtls(self): | ||
| """Indicates if the created SSL channel is mutual TLS.""" | ||
| return self._is_mtls | ||
|
|
||
| async def close(self) -> None: | ||
| """ | ||
| Close the underlying auth request session. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
|
|
||
| import asyncio | ||
| import os | ||
| import logging | ||
| import google.auth | ||
| import google.auth.transport.requests | ||
| from google.auth.aio.transport.sessions import AsyncAuthorizedSession | ||
| from google.auth.aio.credentials import Credentials as AioCredentials | ||
|
|
||
| # Configure logging to see detailed mTLS info if available | ||
| logging.basicConfig(level=logging.INFO) | ||
|
|
||
| class SyncToAsyncCredentialsAdapter(AioCredentials): | ||
| """ | ||
| Adapts synchronous google.oauth2.credentials.Credentials to | ||
| google.auth.aio.credentials.Credentials. | ||
|
|
||
| This allows using standard ADC (User Credentials) with AsyncAuthorizedSession. | ||
| It uses a thread executor to perform blocking refresh operations. | ||
| """ | ||
| def __init__(self, sync_creds): | ||
| super().__init__() | ||
| self._sync_creds = sync_creds | ||
|
|
||
| async def refresh(self, request): | ||
| # We ignore the async `request` passed here and use a new sync Request | ||
| # because the underlying credentials are synchronous. | ||
| sync_request = google.auth.transport.requests.Request() | ||
| await asyncio.to_thread(self._sync_creds.refresh, sync_request) | ||
| self.token = self._sync_creds.token | ||
| self.expiry = self._sync_creds.expiry | ||
|
|
||
| async def before_request(self, request, method, url, headers): | ||
| sync_request = google.auth.transport.requests.Request() | ||
| # Offload the blocking refresh/apply check to a thread | ||
| await asyncio.to_thread( | ||
| self._sync_creds.before_request, sync_request, method, url, headers | ||
| ) | ||
| # after before_request, token might be refreshed | ||
| self.token = self._sync_creds.token | ||
| self.expiry = self._sync_creds.expiry | ||
|
|
||
| async def main(): | ||
| # 1. Get default credentials and project | ||
| print("Loading default credentials...") | ||
| sync_credentials, project_id = google.auth.default( | ||
| scopes=["https://www.googleapis.com/auth/pubsub"] | ||
| ) | ||
|
|
||
| print(f"Using credentials for project: {project_id}") | ||
| print(f"Credential type: {type(sync_credentials)}") | ||
|
|
||
| if not project_id: | ||
| print("Error: Could not determine project ID from environment.") | ||
| print("Please set GOOGLE_CLOUD_PROJECT or have it in your ADC.") | ||
| return | ||
|
|
||
| # 2. Adapt credentials | ||
| async_credentials = SyncToAsyncCredentialsAdapter(sync_credentials) | ||
|
|
||
| # 3. Create the AsyncAuthorizedSession | ||
| session = AsyncAuthorizedSession(async_credentials) | ||
|
|
||
| try: | ||
| # 4. Enable mTLS | ||
| # To actually force mTLS, ensure GOOGLE_API_USE_CLIENT_CERTIFICATE=true is in env | ||
| print("Configuring mTLS channel...") | ||
| await session.configure_mtls_channel() | ||
|
|
||
| print(f"mTLS enabled: {session.is_mtls}") | ||
|
|
||
| # 5. Make a request to Pub/Sub API | ||
| url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics" | ||
| print(f"Making request to: {url}") | ||
|
|
||
| response = await session.get(url) | ||
|
|
||
| print(f"Response Status: {response.status_code}") | ||
| if response.status_code == 200: | ||
| response_data = await response.json() | ||
| print("Success! Topics found.") | ||
| # print("Response Body (first 200 chars):", str(response_data)[:200]) | ||
| else: | ||
| print("Request failed.") | ||
| print(await response.text()) | ||
|
|
||
| finally: | ||
| await session.close() | ||
|
|
||
| if __name__ == "__main__": | ||
| # Ensure SSL cert/key env vars are set if you want to test actual mTLS, | ||
| # otherwise it might fallback to regular TLS if the check returns False. | ||
| if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE") != "true": | ||
| print("WARNING: GOOGLE_API_USE_CLIENT_CERTIFICATE is not set to 'true'.") | ||
| print("mTLS might not be attempted. Run with: export GOOGLE_API_USE_CLIENT_CERTIFICATE=true") | ||
|
|
||
| asyncio.run(main()) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
|
|
||
| import asyncio | ||
| import os | ||
| import aiohttp | ||
| from google.auth.aio.credentials import StaticCredentials | ||
| from google.auth.aio.transport.sessions import AsyncAuthorizedSession | ||
|
|
||
| async def main(): | ||
| # 1. Obtain a token from ADC | ||
| # We use sync credentials to get the token, then pass it to StaticCredentials | ||
| # which is compatible with AsyncAuthorizedSession. | ||
| from google.auth.transport.requests import Request | ||
| import google.auth | ||
|
|
||
| print("Loading default credentials...") | ||
| creds, project_id = google.auth.default( | ||
| scopes=["https://www.googleapis.com/auth/cloud-platform"] | ||
| ) | ||
|
|
||
| # Refresh to ensure we have a valid token | ||
| print("Refreshing credentials to get access token...") | ||
| creds.refresh(Request()) | ||
| token = creds.token | ||
|
|
||
| print(f"Using token from ADC: {token[:10]}...") | ||
|
|
||
| # 2. Create StaticCredentials | ||
| # These credentials are immutable and will not be refreshed by the Async session. | ||
| # Since we just refreshed them, they should be valid for ~1 hour. | ||
| async_creds = StaticCredentials(token=token) | ||
|
|
||
| # 3. Create the AsyncAuthorizedSession | ||
| session = AsyncAuthorizedSession(async_creds) | ||
|
|
||
| try: | ||
| # 4. Make a request to Pub/Sub API (REST) | ||
| # Note: GAPIC libraries (google-cloud-pubsub) generally do not support | ||
| # google.auth.aio credentials yet. We use REST to verify the Async Session. | ||
| url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics" | ||
| print(f"Making request to: {url}") | ||
|
|
||
| response = await session.get(url) | ||
| print(f"Response Status: {response.status_code}") | ||
|
|
||
| if response.status_code == 200: | ||
| import json | ||
| body_bytes = await response.read() | ||
| data = json.loads(body_bytes) | ||
| print("Topics found (count):", len(data.get("topics", []))) | ||
|
|
||
| else: | ||
| print("Request failed.") | ||
| print((await response.read()).decode("utf-8")) | ||
|
|
||
| finally: | ||
| await session.close() | ||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
asyncioandloggingmodules are imported but not used in this file. It's good practice to remove unused imports to keep the code clean.