Skip to content

Commit e2c2885

Browse files
committed
fix: do not exceed lifetime of access token in cache
1 parent 08fc806 commit e2c2885

1 file changed

Lines changed: 19 additions & 7 deletions

File tree

openeo_driver/util/auth.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from typing import Mapping, NamedTuple, Optional, Union
77

88
import requests
9-
from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo
9+
from openeo.rest.auth.oidc import (
10+
OidcClientCredentialsAuthenticator,
11+
OidcClientInfo,
12+
OidcProviderInfo,
13+
AccessTokenResult,
14+
)
1015
from openeo.util import str_truncate
1116

1217
_log = logging.getLogger(__name__)
@@ -116,18 +121,25 @@ def setup_credentials(self, credentials: ClientCredentials) -> None:
116121
client_info=client_info, requests_session=self._session
117122
)
118123

119-
def _get_access_token(self) -> str:
124+
def _get_access_token(self) -> AccessTokenResult:
120125
"""Get an access token using the configured authenticator."""
121126
if not self._authenticator:
122127
raise RuntimeError("No authentication set up")
123128
_log.debug(f"{self.__class__.__name__} getting access token")
124-
tokens = self._authenticator.get_tokens()
125-
return tokens.access_token
129+
access_token_response = self._authenticator.get_tokens()
130+
return access_token_response
126131

127132
def get_access_token(self) -> str:
128133
"""Get an access token using the configured authenticator."""
129134
if time.time() > self._cache.expires_at:
130-
access_token = self._get_access_token()
131-
# TODO: get expiry from access token itself?
132-
self._cache = _AccessTokenCache(access_token, time.time() + self._default_ttl)
135+
access_token_response = self._get_access_token()
136+
access_token = access_token_response.access_token
137+
self._cache = _AccessTokenCache(access_token, self._get_access_token_expiry_time(access_token_response))
133138
return self._cache.access_token
139+
140+
def _get_access_token_expiry_time(self, access_token_response: AccessTokenResult) -> float:
141+
if access_token_response.expires_in is None:
142+
return time.time() + self._default_ttl
143+
else:
144+
# Expire the cache entry before the entry actually expires
145+
return time.time() + access_token_response.expires_in * 0.90

0 commit comments

Comments
 (0)