From e9922f7cc6391097a179a2a08b55e3a686096849 Mon Sep 17 00:00:00 2001 From: Joseph Marotta Date: Mon, 23 Mar 2026 23:07:38 -0400 Subject: [PATCH 1/2] refactor(google-auth): refactor async AuthorizedSession to remove inheritance from aiohttp.ClientSession --- .../auth/transport/_aiohttp_requests.py | 179 +++++++++++------- .../transport/test_aiohttp_requests.py | 28 +++ 2 files changed, 136 insertions(+), 71 deletions(-) diff --git a/packages/google-auth/google/auth/transport/_aiohttp_requests.py b/packages/google-auth/google/auth/transport/_aiohttp_requests.py index e8321965e0db..0709a9b1942f 100644 --- a/packages/google-auth/google/auth/transport/_aiohttp_requests.py +++ b/packages/google-auth/google/auth/transport/_aiohttp_requests.py @@ -204,7 +204,7 @@ async def __call__( raise new_exc from caught_exc -class AuthorizedSession(aiohttp.ClientSession): +class AuthorizedSession: """This is an async implementation of the Authorized Session class. We utilize an aiohttp transport instance, and the interface mirrors the google.auth.transport.requests Authorized Session class, except for the change in the transport used in the async use case. @@ -253,18 +253,30 @@ def __init__( auto_decompress=False, **kwargs, ): - super(AuthorizedSession, self).__init__(**kwargs) + self._session = aiohttp.ClientSession(auto_decompress=auto_decompress, **kwargs) self.credentials = credentials self._refresh_status_codes = refresh_status_codes self._max_refresh_attempts = max_refresh_attempts self._refresh_timeout = refresh_timeout self._is_mtls = False - self._auth_request = auth_request - self._auth_request_session = None self._loop = asyncio.get_event_loop() self._refresh_lock = asyncio.Lock() self._auto_decompress = auto_decompress + # Create a new aiohttp.ClientSession and Request if one isn't provided + if auth_request is None: + self._auth_request_session = aiohttp.ClientSession( + auto_decompress=auto_decompress, + trust_env=kwargs.get("trust_env", False), + ) + auth_request = Request(self._auth_request_session) + else: + self._auth_request_session = None + + # Request instance used by internal methods (for example, + # credentials.refresh). + self._auth_request = auth_request + async def request( self, method, @@ -273,7 +285,6 @@ async def request( headers=None, max_allowed_time=None, timeout=_DEFAULT_TIMEOUT, - auto_decompress=False, **kwargs, ): """Implementation of Authorized Session aiohttp request. @@ -312,19 +323,50 @@ async def request( if type(headers[key]) is bytes: headers[key] = headers[key].decode("utf-8") - async with aiohttp.ClientSession( - auto_decompress=self._auto_decompress, - trust_env=kwargs.get("trust_env", False), - ) as self._auth_request_session: - auth_request = Request(self._auth_request_session) - self._auth_request = auth_request + # Use a kwarg for this instead of an attribute to maintain + # thread-safety. + _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) + # Make a copy of the headers. They will be modified by the credentials + # and we want to pass the original headers if we recurse. + request_headers = headers.copy() if headers is not None else {} + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + remaining_time = max_allowed_time + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + await self.credentials.before_request( + auth_request, method, url, request_headers + ) - # Use a kwarg for this instead of an attribute to maintain - # thread-safety. - _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) - # Make a copy of the headers. They will be modified by the credentials - # and we want to pass the original headers if we recurse. - request_headers = headers.copy() if headers is not None else {} + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + response = await self._session.request( + method, + url, + data=data, + headers=request_headers, + timeout=timeout, + **kwargs, + ) + + remaining_time = guard.remaining_timeout + + if ( + response.status in self._refresh_status_codes + and _credential_refresh_attempt < self._max_refresh_attempts + ): + requests._LOGGER.info( + "Refreshing credentials due to a %s response. Attempt %s/%s.", + response.status, + _credential_refresh_attempt + 1, + self._max_refresh_attempts, + ) # Do not apply the timeout unconditionally in order to not override the # _auth_request's default timeout. @@ -334,63 +376,58 @@ async def request( else functools.partial(self._auth_request, timeout=timeout) ) - remaining_time = max_allowed_time - - with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: - await self.credentials.before_request( - auth_request, method, url, request_headers - ) - with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: - response = await super(AuthorizedSession, self).request( - method, - url, - data=data, - headers=request_headers, - timeout=timeout, - **kwargs, - ) + async with self._refresh_lock: + await self._loop.run_in_executor( + None, self.credentials.refresh, auth_request + ) remaining_time = guard.remaining_timeout - if ( - response.status in self._refresh_status_codes - and _credential_refresh_attempt < self._max_refresh_attempts - ): - requests._LOGGER.info( - "Refreshing credentials due to a %s response. Attempt %s/%s.", - response.status, - _credential_refresh_attempt + 1, - self._max_refresh_attempts, - ) - - # Do not apply the timeout unconditionally in order to not override the - # _auth_request's default timeout. - auth_request = ( - self._auth_request - if timeout is None - else functools.partial(self._auth_request, timeout=timeout) - ) - - with requests.TimeoutGuard( - remaining_time, asyncio.TimeoutError - ) as guard: - async with self._refresh_lock: - await self._loop.run_in_executor( - None, self.credentials.refresh, auth_request - ) - - remaining_time = guard.remaining_timeout - - return await self.request( - method, - url, - data=data, - headers=headers, - max_allowed_time=remaining_time, - timeout=timeout, - _credential_refresh_attempt=_credential_refresh_attempt + 1, - **kwargs, - ) + return await self.request( + method, + url, + data=data, + headers=headers, + max_allowed_time=remaining_time, + timeout=timeout, + _credential_refresh_attempt=_credential_refresh_attempt + 1, + **kwargs, + ) return response + + async def get(self, url, **kwargs): + return await self.request("GET", url, **kwargs) # pragma: NO COVER + + async def post(self, url, **kwargs): + return await self.request("POST", url, **kwargs) # pragma: NO COVER + + async def patch(self, url, **kwargs): + return await self.request("PATCH", url, **kwargs) # pragma: NO COVER + + async def put(self, url, **kwargs): + return await self.request("PUT", url, **kwargs) # pragma: NO COVER + + async def delete(self, url, **kwargs): + return await self.request("DELETE", url, **kwargs) # pragma: NO COVER + + async def close(self): + if self._auth_request_session is not None: + await self._auth_request_session.close() + await self._session.close() + + def __getattr__(self, name): + """ + Pass through all other methods to the underlying aiohttp.ClientSession object. + """ + return getattr(self._session, name) + + async def __aenter__(self): + await self._session.__aenter__() + return self + + async def __aexit__(self, *exc): + if self._auth_request_session is not None: + await self._auth_request_session.close() + return await self._session.__aexit__(*exc) diff --git a/packages/google-auth/tests_async/transport/test_aiohttp_requests.py b/packages/google-auth/tests_async/transport/test_aiohttp_requests.py index d6a24da2e302..dd861c8d2004 100644 --- a/packages/google-auth/tests_async/transport/test_aiohttp_requests.py +++ b/packages/google-auth/tests_async/transport/test_aiohttp_requests.py @@ -163,6 +163,34 @@ async def test_constructor_with_auth_request(self): assert authed_session._auth_request == auth_request + @pytest.mark.asyncio + async def test_context_manager_closes_session(self): + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + auth_request = aiohttp_requests.Request(http) + + async with aiohttp_requests.AuthorizedSession( + mock.sentinel.credentials, auth_request=auth_request + ) as session: + pass + + assert session.closed + + @pytest.mark.asyncio + async def test_explicit_close_session(self): + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + auth_request = aiohttp_requests.Request(http) + + session = aiohttp_requests.AuthorizedSession( + mock.sentinel.credentials, auth_request=auth_request + ) + + await session.close() + assert session.closed + @pytest.mark.asyncio async def test_request(self): with aioresponses() as mocked: From cf41c43f3d7196dc54ad2fa9189f383548009892 Mon Sep 17 00:00:00 2001 From: Joseph Marotta Date: Mon, 23 Mar 2026 23:32:37 -0400 Subject: [PATCH 2/2] remove redundancy in __aexit__ --- .../google-auth/google/auth/transport/_aiohttp_requests.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/google-auth/google/auth/transport/_aiohttp_requests.py b/packages/google-auth/google/auth/transport/_aiohttp_requests.py index 0709a9b1942f..e14d0eecf568 100644 --- a/packages/google-auth/google/auth/transport/_aiohttp_requests.py +++ b/packages/google-auth/google/auth/transport/_aiohttp_requests.py @@ -428,6 +428,4 @@ async def __aenter__(self): return self async def __aexit__(self, *exc): - if self._auth_request_session is not None: - await self._auth_request_session.close() - return await self._session.__aexit__(*exc) + await self.close()