|
6 | 6 | from typing import Mapping, NamedTuple, Optional, Union |
7 | 7 |
|
8 | 8 | import requests |
9 | | -from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo |
| 9 | +from openeo.rest.auth.oidc import ( |
| 10 | + OidcClientCredentialsAuthenticator, |
| 11 | + OidcClientInfo, |
| 12 | + OidcProviderInfo, |
| 13 | + AccessTokenResult, |
| 14 | +) |
10 | 15 | from openeo.util import str_truncate |
11 | 16 |
|
12 | 17 | _log = logging.getLogger(__name__) |
@@ -116,18 +121,25 @@ def setup_credentials(self, credentials: ClientCredentials) -> None: |
116 | 121 | client_info=client_info, requests_session=self._session |
117 | 122 | ) |
118 | 123 |
|
119 | | - def _get_access_token(self) -> str: |
| 124 | + def _get_access_token(self) -> AccessTokenResult: |
120 | 125 | """Get an access token using the configured authenticator.""" |
121 | 126 | if not self._authenticator: |
122 | 127 | raise RuntimeError("No authentication set up") |
123 | 128 | _log.debug(f"{self.__class__.__name__} getting access token") |
124 | | - tokens = self._authenticator.get_tokens() |
125 | | - return tokens.access_token |
| 129 | + access_token_response = self._authenticator.get_tokens() |
| 130 | + return access_token_response |
126 | 131 |
|
127 | 132 | def get_access_token(self) -> str: |
128 | 133 | """Get an access token using the configured authenticator.""" |
129 | 134 | if time.time() > self._cache.expires_at: |
130 | | - access_token = self._get_access_token() |
131 | | - # TODO: get expiry from access token itself? |
132 | | - self._cache = _AccessTokenCache(access_token, time.time() + self._default_ttl) |
| 135 | + access_token_response = self._get_access_token() |
| 136 | + access_token = access_token_response.access_token |
| 137 | + self._cache = _AccessTokenCache(access_token, self._get_access_token_expiry_time(access_token_response)) |
133 | 138 | return self._cache.access_token |
| 139 | + |
| 140 | + def _get_access_token_expiry_time(self, access_token_response: AccessTokenResult) -> float: |
| 141 | + if access_token_response.expires_in is None: |
| 142 | + return time.time() + self._default_ttl |
| 143 | + else: |
| 144 | + # Expire the cache entry before the entry actually expires |
| 145 | + return time.time() + access_token_response.expires_in * 0.90 |
0 commit comments