Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions openeo_driver/util/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from typing import Mapping, NamedTuple, Optional, Union

import requests
from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo
from openeo.rest.auth.oidc import (
OidcClientCredentialsAuthenticator,
OidcClientInfo,
OidcProviderInfo,
AccessTokenResult,
)
from openeo.util import str_truncate

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -116,18 +121,25 @@ def setup_credentials(self, credentials: ClientCredentials) -> None:
client_info=client_info, requests_session=self._session
)

def _get_access_token(self) -> str:
def _get_access_token(self) -> AccessTokenResult:
"""Get an access token using the configured authenticator."""
if not self._authenticator:
raise RuntimeError("No authentication set up")
_log.debug(f"{self.__class__.__name__} getting access token")
tokens = self._authenticator.get_tokens()
return tokens.access_token
access_token_response = self._authenticator.get_tokens()
return access_token_response

def get_access_token(self) -> str:
"""Get an access token using the configured authenticator."""
if time.time() > self._cache.expires_at:
access_token = self._get_access_token()
# TODO: get expiry from access token itself?
self._cache = _AccessTokenCache(access_token, time.time() + self._default_ttl)
access_token_response = self._get_access_token()
access_token = access_token_response.access_token
self._cache = _AccessTokenCache(access_token, self._get_access_token_expiry_time(access_token_response))
return self._cache.access_token

def _get_access_token_expiry_time(self, access_token_response: AccessTokenResult) -> float:
if access_token_response.expires_in is None:
return time.time() + self._default_ttl
else:
# Expire the cache entry before the entry actually expires
return time.time() + access_token_response.expires_in * 0.90
71 changes: 61 additions & 10 deletions tests/util/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging
import re
import time
from typing import Optional

import pytest
import time_machine
from openeo.rest.auth.testing import OidcMock

from openeo_driver.util.auth import ClientCredentials, ClientCredentialsAccessTokenHelper
Expand Down Expand Up @@ -67,27 +70,75 @@ def credentials(self) -> ClientCredentials:
return ClientCredentials(oidc_issuer="https://oidc.test", client_id="client123", client_secret="s3cr3t")

@pytest.fixture
def oidc_mock(self, requests_mock, credentials) -> OidcMock:
def access_token_expires_in(self) -> Optional[int]:
"""By default we let access tokens of the mock expire in 1 hour"""
return 3600

@pytest.fixture
def local_cache_ttl(self) -> int:
"""By default we let the local cache expire in 30 minutes"""
return 1800

@pytest.fixture
def oidc_mock(self, requests_mock, credentials, access_token_expires_in) -> OidcMock:
oidc_mock = OidcMock(
requests_mock=requests_mock,
oidc_issuer=credentials.oidc_issuer,
expected_grant_type="client_credentials",
expected_client_id=credentials.client_id,
expected_fields={"client_secret": credentials.client_secret, "scope": "openid"},
access_token_expires_in=access_token_expires_in,
)
return oidc_mock

def test_basic(self, credentials, oidc_mock: OidcMock):
helper = ClientCredentialsAccessTokenHelper(credentials=credentials)
def test_basic(self, credentials, oidc_mock: OidcMock, local_cache_ttl):
helper = ClientCredentialsAccessTokenHelper(credentials=credentials, default_ttl=local_cache_ttl)
assert helper.get_access_token() == oidc_mock.state["access_token"]

def test_caching(self, credentials, oidc_mock: OidcMock):
helper = ClientCredentialsAccessTokenHelper(credentials=credentials)
assert oidc_mock.mocks["token_endpoint"].call_count == 0
assert helper.get_access_token() == oidc_mock.state["access_token"]
assert oidc_mock.mocks["token_endpoint"].call_count == 1
assert helper.get_access_token() == oidc_mock.state["access_token"]
assert oidc_mock.mocks["token_endpoint"].call_count == 1
@pytest.mark.parametrize(
["desc", "local_cache_ttl", "access_token_expires_in", "no_cache_at_30m", "no_cache_at_50m"],
[
("Long caching", 3600, 3600, False, False),
("No/very short caching", 0, 0, True, True),
("Local cache expires after 30m but before 50m, no server expiry", 40 * 60, None, False, True),
("Server cache shortest and cause expiry after 30m but before 50m", 3600, 40 * 60, False, True),
("Local cache expires after 30m but access token does not", 40 * 60, 7200, False, False),
],
)
def test_caching(
self,
credentials,
oidc_mock: OidcMock,
local_cache_ttl,
access_token_expires_in,
no_cache_at_30m,
no_cache_at_50m,
desc: str,
):
"""
Test caching by requesting an access token at start time and at the 30 and 50 minute mark.
"""
now = time.time()
helper = ClientCredentialsAccessTokenHelper(credentials=credentials, default_ttl=local_cache_ttl)

expected_chache_misses = 0
with time_machine.travel(now):
assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses
assert helper.get_access_token() == oidc_mock.state["access_token"]
expected_chache_misses += 1 # First request is always a miss
assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses

with time_machine.travel(now + 30 * 60):
assert helper.get_access_token() == oidc_mock.state["access_token"]
if no_cache_at_30m:
expected_chache_misses += 1
assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses

with time_machine.travel(now + 50 * 60):
assert helper.get_access_token() == oidc_mock.state["access_token"]
if no_cache_at_50m:
expected_chache_misses += 1
assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses

@pytest.mark.skip(reason="Logging was removed for eu-cdse/openeo-cdse-infra#476")
def test_secret_logging(self, credentials, oidc_mock: OidcMock, caplog):
Expand Down