Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions imednet/core/endpoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Any, Dict, TypeVar
from typing import Any, Dict, Optional, TypeVar
from urllib.parse import quote

from imednet.core.context import Context
Expand All @@ -25,12 +25,14 @@ class GenericEndpoint(EndpointABC[T]):
"""

BASE_PATH = ""
_client: RequestorProtocol
_async_client: Optional[AsyncRequestorProtocol]

def __init__(
self,
client: RequestorProtocol,
ctx: Context,
async_client: AsyncRequestorProtocol | None = None,
async_client: Optional[AsyncRequestorProtocol] = None,
) -> None:
self._client = client
self._async_client = async_client
Expand Down Expand Up @@ -67,6 +69,10 @@ def _build_path(self, *segments: Any) -> str:
parts.append(quote(text, safe=""))
return "/" + "/".join(parts)

def _require_sync_client(self) -> RequestorProtocol:
"""Return the configured sync client."""
return self._client

def _require_async_client(self) -> AsyncRequestorProtocol:
"""Return the configured async client or raise if missing."""
if self._async_client is None:
Expand All @@ -85,7 +91,7 @@ def _get_client(self, is_async: bool) -> RequestorProtocol | AsyncRequestorProto
"""
if is_async:
return self._require_async_client()
return self._client
return self._require_sync_client()


class BaseEndpoint(EdcEndpointMixin, GenericEndpoint[T]):
Expand Down
71 changes: 35 additions & 36 deletions imednet/core/endpoint/mixins/bases.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

from typing import Any, Awaitable, List, Optional, cast
from typing import Any, List, Optional

from imednet.core.endpoint.base import GenericEndpoint
from imednet.core.endpoint.edc_mixin import EdcEndpointMixin
from imednet.core.paginator import AsyncPaginator, Paginator
from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol

from .get import FilterGetEndpointMixin, PathGetEndpointMixin
from .list import ListEndpointMixin
Expand All @@ -24,24 +23,21 @@ class GenericListEndpoint(GenericEndpoint[T], ListEndpointMixin[T]):
PAGINATOR_CLS: type[Paginator] = Paginator
ASYNC_PAGINATOR_CLS: type[AsyncPaginator] = AsyncPaginator

def _get_context(
self, is_async: bool
) -> tuple[RequestorProtocol | AsyncRequestorProtocol, type[Paginator] | type[AsyncPaginator]]:
client = self._get_client(is_async)
if is_async:
return client, self.ASYNC_PAGINATOR_CLS
return client, self.PAGINATOR_CLS

def _list_common(self, is_async: bool, **kwargs: Any) -> List[T] | Awaitable[List[T]]:
client, paginator = self._get_context(is_async)
return self._list_impl(client, paginator, **kwargs)

def list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]:
return cast(List[T], self._list_common(False, study_key=study_key, **filters))
return self._list_sync(
self._require_sync_client(),
self.PAGINATOR_CLS,
study_key=study_key,
**filters,
)

async def async_list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]:
return await cast(
Awaitable[List[T]], self._list_common(True, study_key=study_key, **filters)
client = self._require_async_client()
return await self._list_async(
client,
self.ASYNC_PAGINATOR_CLS,
study_key=study_key,
**filters,
)


Expand All @@ -60,22 +56,21 @@ class ListEndpoint(EdcListEndpoint[T]):
class GenericListGetEndpoint(GenericListEndpoint[T], FilterGetEndpointMixin[T]):
"""Generic endpoint implementing ``list`` and ``get`` helpers."""

def _get_common(
self,
is_async: bool,
*,
study_key: Optional[str],
item_id: Any,
) -> T | Awaitable[T]:
client, paginator = self._get_context(is_async)
return self._get_impl(client, paginator, study_key=study_key, item_id=item_id)

def get(self, study_key: Optional[str], item_id: Any) -> T:
return cast(T, self._get_common(False, study_key=study_key, item_id=item_id))
return self._get_sync(
self._require_sync_client(),
self.PAGINATOR_CLS,
study_key=study_key,
item_id=item_id,
)

async def async_get(self, study_key: Optional[str], item_id: Any) -> T:
return await cast(
Awaitable[T], self._get_common(True, study_key=study_key, item_id=item_id)
client = self._require_async_client()
return await self._get_async(
client,
self.ASYNC_PAGINATOR_CLS,
study_key=study_key,
item_id=item_id,
)


Expand All @@ -95,14 +90,18 @@ class GenericListPathGetEndpoint(GenericListEndpoint[T], PathGetEndpointMixin[T]
"""Generic endpoint implementing ``list`` and ``get`` (via path) helpers."""

def get(self, study_key: Optional[str], item_id: Any) -> T:
client = self._get_client(is_async=False)
return cast(T, self._get_impl_path(client, study_key=study_key, item_id=item_id))
return self._get_path_sync(
self._require_sync_client(),
study_key=study_key,
item_id=item_id,
)

async def async_get(self, study_key: Optional[str], item_id: Any) -> T:
client = self._get_client(is_async=True)
return await cast(
Awaitable[T],
self._get_impl_path(client, study_key=study_key, item_id=item_id, is_async=True),
client = self._require_async_client()
return await self._get_path_async(
client,
study_key=study_key,
item_id=item_id,
)


Expand Down
93 changes: 55 additions & 38 deletions imednet/core/endpoint/mixins/get.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import inspect
from typing import Any, Awaitable, Dict, Iterable, List, Optional, cast
from typing import Any, Dict, Iterable, List, Optional

from imednet.core.endpoint.abc import EndpointABC
from imednet.core.paginator import AsyncPaginator, Paginator
Expand All @@ -15,17 +14,29 @@ class FilterGetEndpointMixin(EndpointABC[T]):

# MODEL and _id_param are inherited from EndpointABC as abstract or properties

# This should be provided by ListEndpointMixin or similar implementation
def _list_impl(
# These should be provided by ListEndpointMixin or similar implementation
def _list_sync(
self,
client: RequestorProtocol | AsyncRequestorProtocol,
paginator_cls: type[Paginator] | type[AsyncPaginator],
client: RequestorProtocol,
paginator_cls: type[Paginator],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
**filters: Any,
) -> List[T] | Awaitable[List[T]]:
) -> List[T]:
raise NotImplementedError

async def _list_async(
self,
client: AsyncRequestorProtocol,
paginator_cls: type[AsyncPaginator],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
**filters: Any,
) -> List[T]:
raise NotImplementedError

def _validate_get_result(self, items: List[T], study_key: Optional[str], item_id: Any) -> T:
Expand All @@ -35,33 +46,41 @@ def _validate_get_result(self, items: List[T], study_key: Optional[str], item_id
raise ValueError(f"{self.MODEL.__name__} {item_id} not found")
return items[0]

def _get_impl(
def _get_sync(
self,
client: RequestorProtocol | AsyncRequestorProtocol,
paginator_cls: type[Paginator] | type[AsyncPaginator],
client: RequestorProtocol,
paginator_cls: type[Paginator],
*,
study_key: Optional[str],
item_id: Any,
) -> T | Awaitable[T]:
) -> T:
filters = {self._id_param: item_id}
result = self._list_impl(
result = self._list_sync(
client,
paginator_cls,
study_key=study_key,
refresh=True,
**filters,
)
return self._validate_get_result(result, study_key, item_id)

if inspect.isawaitable(result):

async def _await() -> T:
items = await result
return self._validate_get_result(items, study_key, item_id)

return _await()

# Sync path
return self._validate_get_result(cast(List[T], result), study_key, item_id)
async def _get_async(
self,
client: AsyncRequestorProtocol,
paginator_cls: type[AsyncPaginator],
*,
study_key: Optional[str],
item_id: Any,
) -> T:
filters = {self._id_param: item_id}
result = await self._list_async(
client,
paginator_cls,
study_key=study_key,
refresh=True,
**filters,
)
return self._validate_get_result(result, study_key, item_id)


class PathGetEndpointMixin(ParsingMixin[T], EndpointABC[T]):
Expand Down Expand Up @@ -91,26 +110,24 @@ def _process_response(self, response: Any, study_key: Optional[str], item_id: An
self._raise_not_found(study_key, item_id)
return self._parse_item(data)

def _get_impl_path(
def _get_path_sync(
self,
client: RequestorProtocol | AsyncRequestorProtocol,
client: RequestorProtocol,
*,
study_key: Optional[str],
item_id: Any,
is_async: bool = False,
) -> T | Awaitable[T]:
) -> T:
path = self._get_path_for_id(study_key, item_id)
response = client.get(path)
return self._process_response(response, study_key, item_id)

if is_async:

async def _await() -> T:
# We assume client is AsyncRequestorProtocol because is_async=True
aclient = cast(AsyncRequestorProtocol, client)
response = await aclient.get(path)
return self._process_response(response, study_key, item_id)

return _await()

sclient = cast(RequestorProtocol, client)
response = sclient.get(path)
async def _get_path_async(
self,
client: AsyncRequestorProtocol,
*,
study_key: Optional[str],
item_id: Any,
) -> T:
path = self._get_path_for_id(study_key, item_id)
response = await client.get(path)
return self._process_response(response, study_key, item_id)
48 changes: 32 additions & 16 deletions imednet/core/endpoint/mixins/list.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast
from typing import Any, Callable, Dict, Iterable, List, Optional, cast

from imednet.constants import DEFAULT_PAGE_SIZE
from imednet.core.endpoint.abc import EndpointABC
Expand Down Expand Up @@ -118,17 +118,16 @@ def _prepare_list_request(
cache=cache,
)

def _list_impl(
def _list_sync(
self,
client: RequestorProtocol | AsyncRequestorProtocol,
paginator_cls: type[Paginator] | type[AsyncPaginator],
client: RequestorProtocol,
paginator_cls: type[Paginator],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
**filters: Any,
) -> List[T] | Awaitable[List[T]]:

) -> List[T]:
state = self._prepare_list_request(study_key, extra_params, filters, refresh)

if state.cached_result is not None:
Expand All @@ -137,17 +136,34 @@ def _list_impl(
paginator = paginator_cls(client, state.path, params=state.params, page_size=self.PAGE_SIZE)
parse_func = self._resolve_parse_func()

if hasattr(paginator, "__aiter__"):
return self._execute_async_list(
cast(AsyncPaginator, paginator),
parse_func,
state.study,
state.has_filters,
state.cache,
)

return self._execute_sync_list(
cast(Paginator, paginator),
paginator,
parse_func,
state.study,
state.has_filters,
state.cache,
)

async def _list_async(
self,
client: AsyncRequestorProtocol,
paginator_cls: type[AsyncPaginator],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
**filters: Any,
) -> List[T]:
state = self._prepare_list_request(study_key, extra_params, filters, refresh)

if state.cached_result is not None:
return state.cached_result

paginator = paginator_cls(client, state.path, params=state.params, page_size=self.PAGE_SIZE)
parse_func = self._resolve_parse_func()

return await self._execute_async_list(
paginator,
parse_func,
state.study,
state.has_filters,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_codings_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(codings.CodingsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(codings.CodingsEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", "x")
4 changes: 2 additions & 2 deletions tests/unit/endpoints/test_endpoints_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def fake_impl(self, client, paginator, *, study_key=None, **filters):
called["filters"] = filters
return [Record(record_id=1)]

monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl)

rec = await ep.async_get("S1", 1)

Expand All @@ -181,7 +181,7 @@ async def test_async_get_record_not_found(monkeypatch, dummy_client, context, re
async def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl)

with pytest.raises(ValueError):
await ep.async_get("S1", 1)
Expand Down
Loading