Skip to content

Commit 9fc2fa9

Browse files
authored
CI-589 - Add client credentials auth method (#190)
* initial client creds code * client creds implementation * remove cognito endpoint * some fixes * comments
1 parent aa4a824 commit 9fc2fa9

2 files changed

Lines changed: 124 additions & 3 deletions

File tree

cirro/auth/__init__.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,25 @@
22
from typing import Optional
33

44
from cirro.auth.access_token import AccessTokenAuth
5+
from cirro.auth.base import AuthInfo
6+
from cirro.auth.client_creds import ClientCredentialsAuth
57
from cirro.auth.device_code import DeviceCodeAuth
68

79
__all__ = [
810
'get_auth_info_from_config',
9-
"DeviceCodeAuth",
10-
"AccessTokenAuth",
11+
'AuthInfo',
12+
'DeviceCodeAuth',
13+
'AccessTokenAuth',
14+
'ClientCredentialsAuth'
1115
]
1216

1317
from cirro.config import AppConfig
1418

1519

1620
def get_auth_info_from_config(app_config: AppConfig, auth_io: Optional[StringIO] = None):
21+
"""
22+
Generates the AuthInfo object from the user's saved configuration
23+
"""
1724
user_config = app_config.user_config
1825
if not user_config or not user_config.auth_method:
1926
return DeviceCodeAuth(region=app_config.region,
@@ -22,7 +29,8 @@ def get_auth_info_from_config(app_config: AppConfig, auth_io: Optional[StringIO]
2229
auth_io=auth_io)
2330

2431
auth_methods = [
25-
DeviceCodeAuth
32+
DeviceCodeAuth,
33+
ClientCredentialsAuth
2634
]
2735
matched_auth_method = next((m for m in auth_methods if m.__name__ == user_config.auth_method), None)
2836
if not matched_auth_method:
@@ -40,3 +48,12 @@ def get_auth_info_from_config(app_config: AppConfig, auth_io: Optional[StringIO]
4048
auth_endpoint=app_config.auth_endpoint,
4149
enable_cache=auth_config.get('enable_cache') == 'True',
4250
auth_io=auth_io)
51+
52+
if matched_auth_method == ClientCredentialsAuth:
53+
return ClientCredentialsAuth(
54+
auth_config.get('client_id'),
55+
auth_config.get('client_secret'),
56+
auth_endpoint=app_config.auth_endpoint
57+
)
58+
59+
return None

cirro/auth/client_creds.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import base64
2+
import logging
3+
import threading
4+
import time
5+
from typing import Optional
6+
7+
import jwt
8+
import requests
9+
from cirro_api_client import RefreshableTokenAuth
10+
from cirro_api_client.cirro_auth import AuthMethod
11+
12+
from cirro.auth.base import AuthInfo
13+
from cirro.auth.oauth_models import OAuthTokenResponse
14+
15+
logger = logging.getLogger()
16+
17+
18+
class ClientCredentialsAuth(AuthInfo):
19+
"""
20+
Authenticates to Cirro with OAuth client credentials
21+
22+
Args:
23+
client_id (str): Client ID
24+
client_secret (str): Client Secret
25+
auth_endpoint (str): Auth Endpoint
26+
27+
```python
28+
import os
29+
from cirro import CirroApi
30+
from cirro.auth.client_creds import ClientCredentialsAuth
31+
from cirro.config import AppConfig
32+
33+
client_id = os.getenv('CIRRO_CLIENT_ID')
34+
client_secret = os.getenv('CIRRO_CLIENT_SECRET')
35+
36+
config = AppConfig(base_url="app.cirro.bio")
37+
auth_info = ClientCredentialsAuth(client_id, client_secret, auth_endpoint=config.auth_endpoint)
38+
cirro = CirroApi(auth_info=auth_info)
39+
```
40+
"""
41+
42+
def __init__(
43+
self,
44+
client_id: str,
45+
client_secret: str,
46+
auth_endpoint: str
47+
):
48+
self._client_id = client_id
49+
self._client_secret = client_secret
50+
self._auth_endpoint = auth_endpoint
51+
self._token_info: Optional[OAuthTokenResponse] = None
52+
self._token_expiry = None
53+
self._username = None
54+
self._get_token_lock = threading.Lock()
55+
56+
def get_current_user(self) -> str:
57+
return self._username
58+
59+
def get_auth_method(self) -> AuthMethod:
60+
return RefreshableTokenAuth(token_getter=lambda: self._get_token()["access_token"])
61+
62+
def _get_token(self) -> OAuthTokenResponse:
63+
with self._get_token_lock:
64+
# Refresh access token if expired
65+
if not self._token_expiry or time.time() > self._token_expiry:
66+
self._refresh_token()
67+
68+
return self._token_info
69+
70+
def _refresh_token(self):
71+
logger.debug("Refreshing token")
72+
basic_auth = base64.b64encode(
73+
f"{self._client_id}:{self._client_secret}".encode()
74+
).decode()
75+
76+
headers = {
77+
"Authorization": f"Basic {basic_auth}",
78+
"Content-Type": "application/x-www-form-urlencoded",
79+
}
80+
81+
data = {
82+
"grant_type": "client_credentials",
83+
}
84+
85+
response = requests.post(
86+
f"{self._auth_endpoint}/token",
87+
headers=headers,
88+
data=data,
89+
)
90+
token_info: OAuthTokenResponse = response.json()
91+
92+
self._token_info = token_info
93+
94+
if "access_token" not in token_info:
95+
raise RuntimeError(f"Error authenticating {token_info}")
96+
97+
self._update_token_metadata()
98+
99+
def _update_token_metadata(self):
100+
decoded_access_token = jwt.decode(self._token_info["access_token"],
101+
options={"verify_signature": False})
102+
expires_in = self._token_info.get("expires_in", 3600)
103+
self._token_expiry = time.time() + expires_in - 30
104+
self._username = decoded_access_token["appUsername"]

0 commit comments

Comments
 (0)