Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions google/auth/aio/transport/mtls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helper functions for mTLS in asyncio.
"""

import asyncio
import contextlib
import logging
Comment on lines +19 to +21
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The asyncio and logging modules are imported but not used in this file. It's good practice to remove unused imports to keep the code clean.

Suggested change
import asyncio
import contextlib
import logging
import contextlib

import os
import ssl
import tempfile
from typing import Optional

from google.auth import exceptions



@contextlib.contextmanager
def _create_temp_file(content: bytes):
"""Creates a temporary file with the given content.

Args:
content (bytes): The content to write to the file.

Yields:
str: The path to the temporary file.
"""
# Create a temporary file that is readable only by the owner.
fd, path = tempfile.mkstemp()
try:
with os.fdopen(fd, "wb") as f:
f.write(content)
yield path
finally:
# Securely delete the file after use.
if os.path.exists(path):
os.remove(path)


def make_client_cert_ssl_context(
cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None
) -> ssl.SSLContext:
"""Creates an SSLContext with the given client certificate and key.

This function writes the certificate and key to temporary files so that
ssl.create_default_context can load them, as the ssl module requires
file paths for client certificates.

Args:
cert_bytes (bytes): The client certificate content in PEM format.
key_bytes (bytes): The client private key content in PEM format.
passphrase (Optional[bytes]): The passphrase for the private key, if any.

Returns:
ssl.SSLContext: The configured SSL context with client certificate.

Raises:
google.auth.exceptions.TransportError: If there is an error loading the certificate.
"""
try:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)

# Write cert and key to temp files because ssl.load_cert_chain requires paths
with _create_temp_file(cert_bytes) as cert_path:
with _create_temp_file(key_bytes) as key_path:
context.load_cert_chain(
certfile=cert_path,
keyfile=key_path,
password=passphrase
)
return context
except (ssl.SSLError, OSError) as exc:
raise exceptions.TransportError(
"Failed to load client certificate and key for mTLS."
) from exc
83 changes: 83 additions & 0 deletions google/auth/aio/transport/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,30 @@
from google.auth.aio import transport
from google.auth.aio.credentials import Credentials
from google.auth.exceptions import TimeoutError
import google.auth.transport._mtls_helper
from google.auth.aio.transport import mtls

try:
from google.auth.aio.transport.aiohttp import Request as AiohttpRequest
import aiohttp

AIOHTTP_INSTALLED = True
except ImportError: # pragma: NO COVER
AIOHTTP_INSTALLED = False




async def _run_in_executor(func, *args):
"""Run a blocking function in an executor."""
try:
return await asyncio.to_thread(func, *args)
except AttributeError:
# Fallback for Python < 3.9
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, func, *args)


@asynccontextmanager
async def timeout_guard(timeout):
"""
Expand Down Expand Up @@ -124,12 +139,73 @@ def __init__(
_auth_request = auth_request
if not _auth_request and AIOHTTP_INSTALLED:
_auth_request = AiohttpRequest()
self._is_mtls = False
self._cached_cert = None
if _auth_request is None:
raise exceptions.TransportError(
"`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value."
)
self._auth_request = _auth_request

async def configure_mtls_channel(self, client_cert_callback=None):
"""Configure the client certificate and key for SSL connection.

The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is
explicitly set to `true`. In this case if client certificate and key are
successfully obtained (from the given client_cert_callback or from application
default SSL credentials), the underlying transport will be reconfigured
to use mTLS.

Args:
client_cert_callback (Optional[Callable[[], (bytes, bytes)]]):
The optional callback returns the client certificate and private
key bytes both in PEM format.
If the callback is None, application default SSL credentials
will be used.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
creation failed for any reason.
"""
# Run the blocking check in an executor
use_client_cert = await _run_in_executor(
google.auth.transport._mtls_helper.check_use_client_cert
)
if not use_client_cert:
self._is_mtls = False
return

try:
(
self._is_mtls,
cert,
key,
) = await _run_in_executor(
google.auth.transport._mtls_helper.get_client_cert_and_key,
client_cert_callback,
)

if self._is_mtls:
self._cached_cert = cert
ssl_context = await _run_in_executor(
mtls.make_client_cert_ssl_context, cert, key
)

# Re-create the auth request with the new SSL context
if isinstance(self._auth_request, AiohttpRequest):
connector = aiohttp.TCPConnector(ssl=ssl_context)
new_session = aiohttp.ClientSession(connector=connector)
await self._auth_request.close()
self._auth_request = AiohttpRequest(session=new_session)
Comment on lines +195 to +199
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a potential issue here. If self._is_mtls is true but self._auth_request is not an instance of AiohttpRequest, the transport will not be reconfigured for mTLS. However, self.is_mtls will still report True. This could be misleading and is a potential security risk if the user believes mTLS is enabled when it is not.

Consider raising an exception for unsupported transport types to make the behavior explicit.

                if isinstance(self._auth_request, AiohttpRequest):
                    connector = aiohttp.TCPConnector(ssl=ssl_context)
                    new_session = aiohttp.ClientSession(connector=connector)
                    await self._auth_request.close()
                    self._auth_request = AiohttpRequest(session=new_session)
                else:
                    raise exceptions.TransportError("mTLS is only supported for aiohttp transport.")


except (
exceptions.ClientCertError,
ImportError,
OSError,
) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc

async def request(
self,
method: str,
Expand Down Expand Up @@ -174,6 +250,8 @@ async def request(
retries = _exponential_backoff.AsyncExponentialBackoff(
total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS
)
if headers is None:
headers = {}
async with timeout_guard(max_allowed_time) as with_timeout:
await with_timeout(
# Note: before_request will attempt to refresh credentials if expired.
Expand Down Expand Up @@ -261,6 +339,11 @@ async def delete(
"DELETE", url, data, headers, max_allowed_time, timeout, **kwargs
)

@property
def is_mtls(self):
"""Indicates if the created SSL channel is mutual TLS."""
return self._is_mtls

async def close(self) -> None:
"""
Close the underlying auth request session.
Expand Down
97 changes: 97 additions & 0 deletions samples/verify_async_mtls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@

import asyncio
import os
import logging
import google.auth
import google.auth.transport.requests
from google.auth.aio.transport.sessions import AsyncAuthorizedSession
from google.auth.aio.credentials import Credentials as AioCredentials

# Configure logging to see detailed mTLS info if available
logging.basicConfig(level=logging.INFO)

class SyncToAsyncCredentialsAdapter(AioCredentials):
"""
Adapts synchronous google.oauth2.credentials.Credentials to
google.auth.aio.credentials.Credentials.

This allows using standard ADC (User Credentials) with AsyncAuthorizedSession.
It uses a thread executor to perform blocking refresh operations.
"""
def __init__(self, sync_creds):
super().__init__()
self._sync_creds = sync_creds

async def refresh(self, request):
# We ignore the async `request` passed here and use a new sync Request
# because the underlying credentials are synchronous.
sync_request = google.auth.transport.requests.Request()
await asyncio.to_thread(self._sync_creds.refresh, sync_request)
self.token = self._sync_creds.token
self.expiry = self._sync_creds.expiry

async def before_request(self, request, method, url, headers):
sync_request = google.auth.transport.requests.Request()
# Offload the blocking refresh/apply check to a thread
await asyncio.to_thread(
self._sync_creds.before_request, sync_request, method, url, headers
)
# after before_request, token might be refreshed
self.token = self._sync_creds.token
self.expiry = self._sync_creds.expiry

async def main():
# 1. Get default credentials and project
print("Loading default credentials...")
sync_credentials, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/pubsub"]
)

print(f"Using credentials for project: {project_id}")
print(f"Credential type: {type(sync_credentials)}")

if not project_id:
print("Error: Could not determine project ID from environment.")
print("Please set GOOGLE_CLOUD_PROJECT or have it in your ADC.")
return

# 2. Adapt credentials
async_credentials = SyncToAsyncCredentialsAdapter(sync_credentials)

# 3. Create the AsyncAuthorizedSession
session = AsyncAuthorizedSession(async_credentials)

try:
# 4. Enable mTLS
# To actually force mTLS, ensure GOOGLE_API_USE_CLIENT_CERTIFICATE=true is in env
print("Configuring mTLS channel...")
await session.configure_mtls_channel()

print(f"mTLS enabled: {session.is_mtls}")

# 5. Make a request to Pub/Sub API
url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics"
print(f"Making request to: {url}")

response = await session.get(url)

print(f"Response Status: {response.status_code}")
if response.status_code == 200:
response_data = await response.json()
print("Success! Topics found.")
# print("Response Body (first 200 chars):", str(response_data)[:200])
else:
print("Request failed.")
print(await response.text())

finally:
await session.close()

if __name__ == "__main__":
# Ensure SSL cert/key env vars are set if you want to test actual mTLS,
# otherwise it might fallback to regular TLS if the check returns False.
if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE") != "true":
print("WARNING: GOOGLE_API_USE_CLIENT_CERTIFICATE is not set to 'true'.")
print("mTLS might not be attempted. Run with: export GOOGLE_API_USE_CLIENT_CERTIFICATE=true")

asyncio.run(main())
59 changes: 59 additions & 0 deletions samples/verify_async_static_creds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@

import asyncio
import os
import aiohttp
from google.auth.aio.credentials import StaticCredentials
from google.auth.aio.transport.sessions import AsyncAuthorizedSession

async def main():
# 1. Obtain a token from ADC
# We use sync credentials to get the token, then pass it to StaticCredentials
# which is compatible with AsyncAuthorizedSession.
from google.auth.transport.requests import Request
import google.auth

print("Loading default credentials...")
creds, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)

# Refresh to ensure we have a valid token
print("Refreshing credentials to get access token...")
creds.refresh(Request())
token = creds.token

print(f"Using token from ADC: {token[:10]}...")

# 2. Create StaticCredentials
# These credentials are immutable and will not be refreshed by the Async session.
# Since we just refreshed them, they should be valid for ~1 hour.
async_creds = StaticCredentials(token=token)

# 3. Create the AsyncAuthorizedSession
session = AsyncAuthorizedSession(async_creds)

try:
# 4. Make a request to Pub/Sub API (REST)
# Note: GAPIC libraries (google-cloud-pubsub) generally do not support
# google.auth.aio credentials yet. We use REST to verify the Async Session.
url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics"
print(f"Making request to: {url}")

response = await session.get(url)
print(f"Response Status: {response.status_code}")

if response.status_code == 200:
import json
body_bytes = await response.read()
data = json.loads(body_bytes)
print("Topics found (count):", len(data.get("topics", [])))

else:
print("Request failed.")
print((await response.read()).decode("utf-8"))

finally:
await session.close()

if __name__ == "__main__":
asyncio.run(main())
Loading
Loading