From 800caa81e684e8820bdafdc80d1bea97b6c43437 Mon Sep 17 00:00:00 2001 From: fderuiter <127706008+fderuiter@users.noreply.github.com> Date: Wed, 18 Feb 2026 22:22:27 +0000 Subject: [PATCH] Refactor endpoint mixins to use explicit sync/async methods - Split `_list_impl` into `_list_sync` and `_list_async` in `ListEndpointMixin`. - Split `_get_impl` and `_get_impl_path` into explicit sync/async variants in `FilterGetEndpointMixin` and `PathGetEndpointMixin`. - Update `GenericEndpoint` to support strict client typing via `_require_sync_client` and `_require_async_client`. - Remove runtime type checking (`isawaitable`, `is_async` flags) from core endpoint logic. - Update unit tests to mock specific sync/async implementation methods. This change improves strict typing, reduces cognitive load, and eliminates implicit control flow in favor of explicit method dispatch. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- imednet/core/endpoint/base.py | 12 ++- imednet/core/endpoint/mixins/bases.py | 71 +++++++------- imednet/core/endpoint/mixins/get.py | 93 +++++++++++-------- imednet/core/endpoint/mixins/list.py | 48 ++++++---- tests/unit/endpoints/test_codings_endpoint.py | 2 +- tests/unit/endpoints/test_endpoints_async.py | 4 +- tests/unit/endpoints/test_forms_endpoint.py | 4 +- .../unit/endpoints/test_intervals_endpoint.py | 2 +- tests/unit/endpoints/test_queries_endpoint.py | 2 +- .../test_record_revisions_endpoint.py | 2 +- tests/unit/endpoints/test_records_endpoint.py | 4 +- tests/unit/endpoints/test_sites_endpoint.py | 2 +- .../unit/endpoints/test_subjects_endpoint.py | 2 +- tests/unit/endpoints/test_users_endpoint.py | 2 +- .../unit/endpoints/test_variables_endpoint.py | 2 +- tests/unit/endpoints/test_visits_endpoint.py | 2 +- 16 files changed, 146 insertions(+), 108 deletions(-) diff --git a/imednet/core/endpoint/base.py b/imednet/core/endpoint/base.py index 48f9a1b0..dc829a45 100644 --- a/imednet/core/endpoint/base.py +++ b/imednet/core/endpoint/base.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Any, Dict, TypeVar +from typing import Any, Dict, Optional, TypeVar from urllib.parse import quote from imednet.core.context import Context @@ -25,12 +25,14 @@ class GenericEndpoint(EndpointABC[T]): """ BASE_PATH = "" + _client: RequestorProtocol + _async_client: Optional[AsyncRequestorProtocol] def __init__( self, client: RequestorProtocol, ctx: Context, - async_client: AsyncRequestorProtocol | None = None, + async_client: Optional[AsyncRequestorProtocol] = None, ) -> None: self._client = client self._async_client = async_client @@ -67,6 +69,10 @@ def _build_path(self, *segments: Any) -> str: parts.append(quote(text, safe="")) return "/" + "/".join(parts) + def _require_sync_client(self) -> RequestorProtocol: + """Return the configured sync client.""" + return self._client + def _require_async_client(self) -> AsyncRequestorProtocol: """Return the configured async client or raise if missing.""" if self._async_client is None: @@ -85,7 +91,7 @@ def _get_client(self, is_async: bool) -> RequestorProtocol | AsyncRequestorProto """ if is_async: return self._require_async_client() - return self._client + return self._require_sync_client() class BaseEndpoint(EdcEndpointMixin, GenericEndpoint[T]): diff --git a/imednet/core/endpoint/mixins/bases.py b/imednet/core/endpoint/mixins/bases.py index 75425b3a..5bbbeff0 100644 --- a/imednet/core/endpoint/mixins/bases.py +++ b/imednet/core/endpoint/mixins/bases.py @@ -1,11 +1,10 @@ from __future__ import annotations -from typing import Any, Awaitable, List, Optional, cast +from typing import Any, List, Optional from imednet.core.endpoint.base import GenericEndpoint from imednet.core.endpoint.edc_mixin import EdcEndpointMixin from imednet.core.paginator import AsyncPaginator, Paginator -from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol from .get import FilterGetEndpointMixin, PathGetEndpointMixin from .list import ListEndpointMixin @@ -24,24 +23,21 @@ class GenericListEndpoint(GenericEndpoint[T], ListEndpointMixin[T]): PAGINATOR_CLS: type[Paginator] = Paginator ASYNC_PAGINATOR_CLS: type[AsyncPaginator] = AsyncPaginator - def _get_context( - self, is_async: bool - ) -> tuple[RequestorProtocol | AsyncRequestorProtocol, type[Paginator] | type[AsyncPaginator]]: - client = self._get_client(is_async) - if is_async: - return client, self.ASYNC_PAGINATOR_CLS - return client, self.PAGINATOR_CLS - - def _list_common(self, is_async: bool, **kwargs: Any) -> List[T] | Awaitable[List[T]]: - client, paginator = self._get_context(is_async) - return self._list_impl(client, paginator, **kwargs) - def list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]: - return cast(List[T], self._list_common(False, study_key=study_key, **filters)) + return self._list_sync( + self._require_sync_client(), + self.PAGINATOR_CLS, + study_key=study_key, + **filters, + ) async def async_list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]: - return await cast( - Awaitable[List[T]], self._list_common(True, study_key=study_key, **filters) + client = self._require_async_client() + return await self._list_async( + client, + self.ASYNC_PAGINATOR_CLS, + study_key=study_key, + **filters, ) @@ -60,22 +56,21 @@ class ListEndpoint(EdcListEndpoint[T]): class GenericListGetEndpoint(GenericListEndpoint[T], FilterGetEndpointMixin[T]): """Generic endpoint implementing ``list`` and ``get`` helpers.""" - def _get_common( - self, - is_async: bool, - *, - study_key: Optional[str], - item_id: Any, - ) -> T | Awaitable[T]: - client, paginator = self._get_context(is_async) - return self._get_impl(client, paginator, study_key=study_key, item_id=item_id) - def get(self, study_key: Optional[str], item_id: Any) -> T: - return cast(T, self._get_common(False, study_key=study_key, item_id=item_id)) + return self._get_sync( + self._require_sync_client(), + self.PAGINATOR_CLS, + study_key=study_key, + item_id=item_id, + ) async def async_get(self, study_key: Optional[str], item_id: Any) -> T: - return await cast( - Awaitable[T], self._get_common(True, study_key=study_key, item_id=item_id) + client = self._require_async_client() + return await self._get_async( + client, + self.ASYNC_PAGINATOR_CLS, + study_key=study_key, + item_id=item_id, ) @@ -95,14 +90,18 @@ class GenericListPathGetEndpoint(GenericListEndpoint[T], PathGetEndpointMixin[T] """Generic endpoint implementing ``list`` and ``get`` (via path) helpers.""" def get(self, study_key: Optional[str], item_id: Any) -> T: - client = self._get_client(is_async=False) - return cast(T, self._get_impl_path(client, study_key=study_key, item_id=item_id)) + return self._get_path_sync( + self._require_sync_client(), + study_key=study_key, + item_id=item_id, + ) async def async_get(self, study_key: Optional[str], item_id: Any) -> T: - client = self._get_client(is_async=True) - return await cast( - Awaitable[T], - self._get_impl_path(client, study_key=study_key, item_id=item_id, is_async=True), + client = self._require_async_client() + return await self._get_path_async( + client, + study_key=study_key, + item_id=item_id, ) diff --git a/imednet/core/endpoint/mixins/get.py b/imednet/core/endpoint/mixins/get.py index cc337534..213ba58d 100644 --- a/imednet/core/endpoint/mixins/get.py +++ b/imednet/core/endpoint/mixins/get.py @@ -1,7 +1,6 @@ from __future__ import annotations -import inspect -from typing import Any, Awaitable, Dict, Iterable, List, Optional, cast +from typing import Any, Dict, Iterable, List, Optional from imednet.core.endpoint.abc import EndpointABC from imednet.core.paginator import AsyncPaginator, Paginator @@ -15,17 +14,29 @@ class FilterGetEndpointMixin(EndpointABC[T]): # MODEL and _id_param are inherited from EndpointABC as abstract or properties - # This should be provided by ListEndpointMixin or similar implementation - def _list_impl( + # These should be provided by ListEndpointMixin or similar implementation + def _list_sync( self, - client: RequestorProtocol | AsyncRequestorProtocol, - paginator_cls: type[Paginator] | type[AsyncPaginator], + client: RequestorProtocol, + paginator_cls: type[Paginator], *, study_key: Optional[str] = None, refresh: bool = False, extra_params: Optional[Dict[str, Any]] = None, **filters: Any, - ) -> List[T] | Awaitable[List[T]]: + ) -> List[T]: + raise NotImplementedError + + async def _list_async( + self, + client: AsyncRequestorProtocol, + paginator_cls: type[AsyncPaginator], + *, + study_key: Optional[str] = None, + refresh: bool = False, + extra_params: Optional[Dict[str, Any]] = None, + **filters: Any, + ) -> List[T]: raise NotImplementedError def _validate_get_result(self, items: List[T], study_key: Optional[str], item_id: Any) -> T: @@ -35,33 +46,41 @@ def _validate_get_result(self, items: List[T], study_key: Optional[str], item_id raise ValueError(f"{self.MODEL.__name__} {item_id} not found") return items[0] - def _get_impl( + def _get_sync( self, - client: RequestorProtocol | AsyncRequestorProtocol, - paginator_cls: type[Paginator] | type[AsyncPaginator], + client: RequestorProtocol, + paginator_cls: type[Paginator], *, study_key: Optional[str], item_id: Any, - ) -> T | Awaitable[T]: + ) -> T: filters = {self._id_param: item_id} - result = self._list_impl( + result = self._list_sync( client, paginator_cls, study_key=study_key, refresh=True, **filters, ) + return self._validate_get_result(result, study_key, item_id) - if inspect.isawaitable(result): - - async def _await() -> T: - items = await result - return self._validate_get_result(items, study_key, item_id) - - return _await() - - # Sync path - return self._validate_get_result(cast(List[T], result), study_key, item_id) + async def _get_async( + self, + client: AsyncRequestorProtocol, + paginator_cls: type[AsyncPaginator], + *, + study_key: Optional[str], + item_id: Any, + ) -> T: + filters = {self._id_param: item_id} + result = await self._list_async( + client, + paginator_cls, + study_key=study_key, + refresh=True, + **filters, + ) + return self._validate_get_result(result, study_key, item_id) class PathGetEndpointMixin(ParsingMixin[T], EndpointABC[T]): @@ -91,26 +110,24 @@ def _process_response(self, response: Any, study_key: Optional[str], item_id: An self._raise_not_found(study_key, item_id) return self._parse_item(data) - def _get_impl_path( + def _get_path_sync( self, - client: RequestorProtocol | AsyncRequestorProtocol, + client: RequestorProtocol, *, study_key: Optional[str], item_id: Any, - is_async: bool = False, - ) -> T | Awaitable[T]: + ) -> T: path = self._get_path_for_id(study_key, item_id) + response = client.get(path) + return self._process_response(response, study_key, item_id) - if is_async: - - async def _await() -> T: - # We assume client is AsyncRequestorProtocol because is_async=True - aclient = cast(AsyncRequestorProtocol, client) - response = await aclient.get(path) - return self._process_response(response, study_key, item_id) - - return _await() - - sclient = cast(RequestorProtocol, client) - response = sclient.get(path) + async def _get_path_async( + self, + client: AsyncRequestorProtocol, + *, + study_key: Optional[str], + item_id: Any, + ) -> T: + path = self._get_path_for_id(study_key, item_id) + response = await client.get(path) return self._process_response(response, study_key, item_id) diff --git a/imednet/core/endpoint/mixins/list.py b/imednet/core/endpoint/mixins/list.py index 337fa3da..9cac5074 100644 --- a/imednet/core/endpoint/mixins/list.py +++ b/imednet/core/endpoint/mixins/list.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, cast +from typing import Any, Callable, Dict, Iterable, List, Optional, cast from imednet.constants import DEFAULT_PAGE_SIZE from imednet.core.endpoint.abc import EndpointABC @@ -118,17 +118,16 @@ def _prepare_list_request( cache=cache, ) - def _list_impl( + def _list_sync( self, - client: RequestorProtocol | AsyncRequestorProtocol, - paginator_cls: type[Paginator] | type[AsyncPaginator], + client: RequestorProtocol, + paginator_cls: type[Paginator], *, study_key: Optional[str] = None, refresh: bool = False, extra_params: Optional[Dict[str, Any]] = None, **filters: Any, - ) -> List[T] | Awaitable[List[T]]: - + ) -> List[T]: state = self._prepare_list_request(study_key, extra_params, filters, refresh) if state.cached_result is not None: @@ -137,17 +136,34 @@ def _list_impl( paginator = paginator_cls(client, state.path, params=state.params, page_size=self.PAGE_SIZE) parse_func = self._resolve_parse_func() - if hasattr(paginator, "__aiter__"): - return self._execute_async_list( - cast(AsyncPaginator, paginator), - parse_func, - state.study, - state.has_filters, - state.cache, - ) - return self._execute_sync_list( - cast(Paginator, paginator), + paginator, + parse_func, + state.study, + state.has_filters, + state.cache, + ) + + async def _list_async( + self, + client: AsyncRequestorProtocol, + paginator_cls: type[AsyncPaginator], + *, + study_key: Optional[str] = None, + refresh: bool = False, + extra_params: Optional[Dict[str, Any]] = None, + **filters: Any, + ) -> List[T]: + state = self._prepare_list_request(study_key, extra_params, filters, refresh) + + if state.cached_result is not None: + return state.cached_result + + paginator = paginator_cls(client, state.path, params=state.params, page_size=self.PAGE_SIZE) + parse_func = self._resolve_parse_func() + + return await self._execute_async_list( + paginator, parse_func, state.study, state.has_filters, diff --git a/tests/unit/endpoints/test_codings_endpoint.py b/tests/unit/endpoints/test_codings_endpoint.py index 4b4d32fc..d65368c6 100644 --- a/tests/unit/endpoints/test_codings_endpoint.py +++ b/tests/unit/endpoints/test_codings_endpoint.py @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(codings.CodingsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(codings.CodingsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", "x") diff --git a/tests/unit/endpoints/test_endpoints_async.py b/tests/unit/endpoints/test_endpoints_async.py index 54cfd5eb..134a7ed4 100644 --- a/tests/unit/endpoints/test_endpoints_async.py +++ b/tests/unit/endpoints/test_endpoints_async.py @@ -166,7 +166,7 @@ async def fake_impl(self, client, paginator, *, study_key=None, **filters): called["filters"] = filters return [Record(record_id=1)] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl) rec = await ep.async_get("S1", 1) @@ -181,7 +181,7 @@ async def test_async_get_record_not_found(monkeypatch, dummy_client, context, re async def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl) with pytest.raises(ValueError): await ep.async_get("S1", 1) diff --git a/tests/unit/endpoints/test_forms_endpoint.py b/tests/unit/endpoints/test_forms_endpoint.py index d0b55156..b1da238f 100644 --- a/tests/unit/endpoints/test_forms_endpoint.py +++ b/tests/unit/endpoints/test_forms_endpoint.py @@ -34,7 +34,7 @@ def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filte called["filters"] = filters return [Form(form_id=1)] - monkeypatch.setattr(forms.FormsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(forms.FormsEndpoint, "_list_sync", fake_impl) res = ep.get("S1", 1) @@ -48,7 +48,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(forms.FormsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(forms.FormsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_intervals_endpoint.py b/tests/unit/endpoints/test_intervals_endpoint.py index 00b60906..02187e90 100644 --- a/tests/unit/endpoints/test_intervals_endpoint.py +++ b/tests/unit/endpoints/test_intervals_endpoint.py @@ -27,7 +27,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(intervals.IntervalsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(intervals.IntervalsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_queries_endpoint.py b/tests/unit/endpoints/test_queries_endpoint.py index 379a74e6..8873b991 100644 --- a/tests/unit/endpoints/test_queries_endpoint.py +++ b/tests/unit/endpoints/test_queries_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(queries.QueriesEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(queries.QueriesEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_record_revisions_endpoint.py b/tests/unit/endpoints/test_record_revisions_endpoint.py index a5be9196..bc4a2eee 100644 --- a/tests/unit/endpoints/test_record_revisions_endpoint.py +++ b/tests/unit/endpoints/test_record_revisions_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(record_revisions.RecordRevisionsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(record_revisions.RecordRevisionsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_records_endpoint.py b/tests/unit/endpoints/test_records_endpoint.py index fb1d7f8b..7f742d4b 100644 --- a/tests/unit/endpoints/test_records_endpoint.py +++ b/tests/unit/endpoints/test_records_endpoint.py @@ -35,7 +35,7 @@ def fake_impl(self, client, paginator, *, study_key=None, **filters): called["filters"] = filters return [Record(record_id=1)] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_sync", fake_impl) res = ep.get("S1", 1) @@ -49,7 +49,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_sites_endpoint.py b/tests/unit/endpoints/test_sites_endpoint.py index 386f3930..653a43b5 100644 --- a/tests/unit/endpoints/test_sites_endpoint.py +++ b/tests/unit/endpoints/test_sites_endpoint.py @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(sites.SitesEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(sites.SitesEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_subjects_endpoint.py b/tests/unit/endpoints/test_subjects_endpoint.py index b2513807..a45c122b 100644 --- a/tests/unit/endpoints/test_subjects_endpoint.py +++ b/tests/unit/endpoints/test_subjects_endpoint.py @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(subjects.SubjectsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(subjects.SubjectsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", "X") diff --git a/tests/unit/endpoints/test_users_endpoint.py b/tests/unit/endpoints/test_users_endpoint.py index 4c8e76e9..2dc00e22 100644 --- a/tests/unit/endpoints/test_users_endpoint.py +++ b/tests/unit/endpoints/test_users_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(users.UsersEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(users.UsersEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_variables_endpoint.py b/tests/unit/endpoints/test_variables_endpoint.py index 7b65045f..0627737f 100644 --- a/tests/unit/endpoints/test_variables_endpoint.py +++ b/tests/unit/endpoints/test_variables_endpoint.py @@ -29,7 +29,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(variables.VariablesEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(variables.VariablesEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_visits_endpoint.py b/tests/unit/endpoints/test_visits_endpoint.py index bf3f8ccf..9caf8703 100644 --- a/tests/unit/endpoints/test_visits_endpoint.py +++ b/tests/unit/endpoints/test_visits_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(visits.VisitsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(visits.VisitsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1)