Skip to content

Commit 9c688e6

Browse files
committed
add customizable auth
1 parent 26f0845 commit 9c688e6

File tree

5 files changed

+99
-33
lines changed

5 files changed

+99
-33
lines changed

src/flareio/api_client.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import requests
1010

1111
from requests.adapters import HTTPAdapter
12+
from requests.auth import AuthBase
1213
from urllib3.util import Retry
1314

1415
import typing as t
@@ -34,7 +35,7 @@ def __init__(
3435
tenant_id: t.Optional[int] = None,
3536
session: t.Optional[requests.Session] = None,
3637
api_domain: t.Optional[str] = None,
37-
_disable_auth: bool = False,
38+
_auth: AuthBase | None = None,
3839
_enable_beta_features: bool = False,
3940
) -> None:
4041
if not api_key:
@@ -52,9 +53,9 @@ def __init__(
5253
self._api_key: str = api_key
5354
self._tenant_id: t.Optional[int] = tenant_id
5455

56+
self._auth: t.Optional[AuthBase] = _auth
5557
self._api_token: t.Optional[str] = None
5658
self._api_token_exp: t.Optional[datetime] = None
57-
self._disable_auth: bool = _disable_auth
5859
self._session = session or self._create_session()
5960

6061
@classmethod
@@ -135,16 +136,24 @@ def generate_token(self) -> str:
135136

136137
return token
137138

138-
def _auth_headers(self) -> dict:
139-
if self._disable_auth:
140-
return dict()
139+
def _apply_auth(
140+
self,
141+
*,
142+
request: requests.PreparedRequest,
143+
) -> requests.PreparedRequest:
144+
if self._auth:
145+
self._auth(request)
146+
return request
147+
141148
api_token: t.Optional[str] = self._api_token
142149
if not api_token or (
143150
self._api_token_exp and self._api_token_exp < datetime.now()
144151
):
145152
api_token = self.generate_token()
146153

147-
return {"Authorization": f"Bearer {api_token}"}
154+
request.headers["Authorization"] = f"Bearer {api_token}"
155+
156+
return request
148157

149158
def _request(
150159
self,
@@ -163,19 +172,20 @@ def _request(
163172
f"Client was used to access {netloc=} at {url=}. Only the domain {self._api_domain} is supported."
164173
)
165174

166-
headers = {
167-
**(headers or {}),
168-
**self._auth_headers(),
169-
}
170-
171-
return self._session.request(
175+
request = requests.Request(
172176
method=method,
173177
url=url,
174178
params=params,
175179
json=json,
176180
headers=headers,
177181
)
178182

183+
prepared = self._session.prepare_request(request)
184+
prepared = self._apply_auth(request=prepared)
185+
resp = self._session.send(prepared)
186+
187+
return resp
188+
179189
def post(
180190
self,
181191
url: str,

src/flareio/auth.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from requests import PreparedRequest
2+
from requests.auth import AuthBase
3+
4+
5+
class _StaticHeadersAuth(AuthBase):
6+
def __init__(
7+
self,
8+
*,
9+
headers: dict[str, str],
10+
) -> None:
11+
self._headers: dict[str, str] = headers
12+
13+
def __call__(
14+
self,
15+
r: PreparedRequest,
16+
) -> PreparedRequest:
17+
r.headers.update(self._headers)
18+
return r
19+
20+
21+
class _EmptyAuth(AuthBase):
22+
def __call__(
23+
self,
24+
r: PreparedRequest,
25+
) -> PreparedRequest:
26+
return r

tests/test_api_client_auth.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import requests_mock
2+
3+
from .utils import get_test_client
4+
5+
from flareio.auth import _EmptyAuth
6+
from flareio.auth import _StaticHeadersAuth
7+
8+
9+
def test_custom_auth_empty() -> None:
10+
client = get_test_client(
11+
authenticated=False,
12+
_auth=_EmptyAuth(),
13+
)
14+
with requests_mock.Mocker() as mocker:
15+
mocker.register_uri(
16+
"POST",
17+
"https://api.flare.io/hello-post",
18+
status_code=200,
19+
)
20+
client.post("https://api.flare.io/hello-post", json={"foo": "bar"})
21+
assert not mocker.last_request.headers.get("Authorization")
22+
23+
24+
def test_custom_auth_static() -> None:
25+
client = get_test_client(
26+
authenticated=False,
27+
_auth=_StaticHeadersAuth(
28+
headers={
29+
"first-header": "first-value",
30+
"Authorization": "auth-value",
31+
}
32+
),
33+
)
34+
with requests_mock.Mocker() as mocker:
35+
mocker.register_uri(
36+
"POST",
37+
"https://api.flare.io/hello-post",
38+
status_code=200,
39+
)
40+
client.post(
41+
"https://api.flare.io/hello-post",
42+
json={"foo": "bar"},
43+
headers={"second-header": "second-value"},
44+
)
45+
assert mocker.last_request.headers["Authorization"] == "auth-value"
46+
assert mocker.last_request.headers["first-header"] == "first-value"
47+
assert mocker.last_request.headers["second-header"] == "second-value"

tests/test_api_client_endpoints.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,3 @@ def test_bad_domain() -> None:
121121
match="Client was used to access netloc='bad.com' at url='https://bad.com/hello-post'. Only the domain api.flare.io is supported.",
122122
):
123123
client.post("https://bad.com/hello-post")
124-
125-
126-
def test_disable_auth_does_not_call_generate() -> None:
127-
client = get_test_client(
128-
authenticated=False,
129-
_disable_auth=True,
130-
)
131-
with requests_mock.Mocker() as mocker:
132-
mocker.register_uri(
133-
"POST",
134-
"https://api.flare.io/hello-post",
135-
status_code=200,
136-
)
137-
client.post("https://api.flare.io/hello-post", json={"foo": "bar"})
138-
assert mocker.last_request.url == "https://api.flare.io/hello-post"
139-
assert mocker.last_request.json() == {"foo": "bar"}
140-
141-
# Authorization header should not be present when auth is disabled
142-
assert not mocker.last_request.headers.get("Authorization")

tests/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import requests_mock
22

3+
from requests.auth import AuthBase
4+
35
import typing as t
46

57
from flareio import FlareApiClient
@@ -11,14 +13,14 @@ def get_test_client(
1113
authenticated: bool = True,
1214
api_domain: t.Optional[str] = None,
1315
_enable_beta_features: bool = False,
14-
_disable_auth: bool = False,
16+
_auth: t.Optional[AuthBase] = None,
1517
) -> FlareApiClient:
1618
client = FlareApiClient(
1719
api_key="test-api-key",
1820
tenant_id=tenant_id,
1921
api_domain=api_domain,
2022
_enable_beta_features=_enable_beta_features,
21-
_disable_auth=_disable_auth,
23+
_auth=_auth,
2224
)
2325

2426
if authenticated:

0 commit comments

Comments
 (0)