Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 106 additions & 71 deletions packages/google-auth/google/auth/transport/_aiohttp_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def __call__(
raise new_exc from caught_exc


class AuthorizedSession(aiohttp.ClientSession):
class AuthorizedSession:
"""This is an async implementation of the Authorized Session class. We utilize an
aiohttp transport instance, and the interface mirrors the google.auth.transport.requests
Authorized Session class, except for the change in the transport used in the async use case.
Expand Down Expand Up @@ -253,18 +253,30 @@ def __init__(
auto_decompress=False,
**kwargs,
):
super(AuthorizedSession, self).__init__(**kwargs)
self._session = aiohttp.ClientSession(auto_decompress=auto_decompress, **kwargs)
self.credentials = credentials
self._refresh_status_codes = refresh_status_codes
self._max_refresh_attempts = max_refresh_attempts
self._refresh_timeout = refresh_timeout
self._is_mtls = False
self._auth_request = auth_request
self._auth_request_session = None
self._loop = asyncio.get_event_loop()
self._refresh_lock = asyncio.Lock()
self._auto_decompress = auto_decompress

# Create a new aiohttp.ClientSession and Request if one isn't provided
if auth_request is None:
self._auth_request_session = aiohttp.ClientSession(
auto_decompress=auto_decompress,
trust_env=kwargs.get("trust_env", False),
)
auth_request = Request(self._auth_request_session)
else:
self._auth_request_session = None

# Request instance used by internal methods (for example,
# credentials.refresh).
self._auth_request = auth_request

async def request(
self,
method,
Expand All @@ -273,7 +285,6 @@ async def request(
headers=None,
max_allowed_time=None,
timeout=_DEFAULT_TIMEOUT,
auto_decompress=False,
**kwargs,
):
"""Implementation of Authorized Session aiohttp request.
Expand Down Expand Up @@ -312,19 +323,50 @@ async def request(
if type(headers[key]) is bytes:
headers[key] = headers[key].decode("utf-8")

async with aiohttp.ClientSession(
auto_decompress=self._auto_decompress,
trust_env=kwargs.get("trust_env", False),
) as self._auth_request_session:
auth_request = Request(self._auth_request_session)
self._auth_request = auth_request
# Use a kwarg for this instead of an attribute to maintain
# thread-safety.
_credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0)
# Make a copy of the headers. They will be modified by the credentials
# and we want to pass the original headers if we recurse.
request_headers = headers.copy() if headers is not None else {}

# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)

remaining_time = max_allowed_time

with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
await self.credentials.before_request(
auth_request, method, url, request_headers
)

# Use a kwarg for this instead of an attribute to maintain
# thread-safety.
_credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0)
# Make a copy of the headers. They will be modified by the credentials
# and we want to pass the original headers if we recurse.
request_headers = headers.copy() if headers is not None else {}
with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
response = await self._session.request(
method,
url,
data=data,
headers=request_headers,
timeout=timeout,
**kwargs,
)

remaining_time = guard.remaining_timeout

if (
response.status in self._refresh_status_codes
and _credential_refresh_attempt < self._max_refresh_attempts
):
requests._LOGGER.info(
"Refreshing credentials due to a %s response. Attempt %s/%s.",
response.status,
_credential_refresh_attempt + 1,
self._max_refresh_attempts,
)

# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
Expand All @@ -334,63 +376,56 @@ async def request(
else functools.partial(self._auth_request, timeout=timeout)
)

remaining_time = max_allowed_time

with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
await self.credentials.before_request(
auth_request, method, url, request_headers
)

with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
response = await super(AuthorizedSession, self).request(
method,
url,
data=data,
headers=request_headers,
timeout=timeout,
**kwargs,
)
async with self._refresh_lock:
await self._loop.run_in_executor(
None, self.credentials.refresh, auth_request
)

remaining_time = guard.remaining_timeout

if (
response.status in self._refresh_status_codes
and _credential_refresh_attempt < self._max_refresh_attempts
):
requests._LOGGER.info(
"Refreshing credentials due to a %s response. Attempt %s/%s.",
response.status,
_credential_refresh_attempt + 1,
self._max_refresh_attempts,
)

# Do not apply the timeout unconditionally in order to not override the
# _auth_request's default timeout.
auth_request = (
self._auth_request
if timeout is None
else functools.partial(self._auth_request, timeout=timeout)
)

with requests.TimeoutGuard(
remaining_time, asyncio.TimeoutError
) as guard:
async with self._refresh_lock:
await self._loop.run_in_executor(
None, self.credentials.refresh, auth_request
)

remaining_time = guard.remaining_timeout

return await self.request(
method,
url,
data=data,
headers=headers,
max_allowed_time=remaining_time,
timeout=timeout,
_credential_refresh_attempt=_credential_refresh_attempt + 1,
**kwargs,
)
return await self.request(
method,
url,
data=data,
headers=headers,
max_allowed_time=remaining_time,
timeout=timeout,
_credential_refresh_attempt=_credential_refresh_attempt + 1,
**kwargs,
)

return response

async def get(self, url, **kwargs):
return await self.request("GET", url, **kwargs) # pragma: NO COVER

async def post(self, url, **kwargs):
return await self.request("POST", url, **kwargs) # pragma: NO COVER

async def patch(self, url, **kwargs):
return await self.request("PATCH", url, **kwargs) # pragma: NO COVER

async def put(self, url, **kwargs):
return await self.request("PUT", url, **kwargs) # pragma: NO COVER

async def delete(self, url, **kwargs):
return await self.request("DELETE", url, **kwargs) # pragma: NO COVER

async def close(self):
if self._auth_request_session is not None:
await self._auth_request_session.close()
await self._session.close()

def __getattr__(self, name):
"""
Pass through all other methods to the underlying aiohttp.ClientSession object.
"""
return getattr(self._session, name)

async def __aenter__(self):
await self._session.__aenter__()
return self

async def __aexit__(self, *exc):
await self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,34 @@ async def test_constructor_with_auth_request(self):

assert authed_session._auth_request == auth_request

@pytest.mark.asyncio
async def test_context_manager_closes_session(self):
http = mock.create_autospec(
aiohttp.ClientSession, instance=True, _auto_decompress=False
)
auth_request = aiohttp_requests.Request(http)

async with aiohttp_requests.AuthorizedSession(
mock.sentinel.credentials, auth_request=auth_request
) as session:
pass

assert session.closed

@pytest.mark.asyncio
async def test_explicit_close_session(self):
http = mock.create_autospec(
aiohttp.ClientSession, instance=True, _auto_decompress=False
)
auth_request = aiohttp_requests.Request(http)

session = aiohttp_requests.AuthorizedSession(
mock.sentinel.credentials, auth_request=auth_request
)

await session.close()
assert session.closed

@pytest.mark.asyncio
async def test_request(self):
with aioresponses() as mocked:
Expand Down
Loading