diff --git a/imednet/core/endpoint/base.py b/imednet/core/endpoint/base.py index 88d2b267..48f9a1b0 100644 --- a/imednet/core/endpoint/base.py +++ b/imednet/core/endpoint/base.py @@ -73,6 +73,20 @@ def _require_async_client(self) -> AsyncRequestorProtocol: raise RuntimeError("Async client not configured") return self._async_client + def _get_client(self, is_async: bool) -> RequestorProtocol | AsyncRequestorProtocol: + """ + Get the appropriate client for the execution context. + + Args: + is_async: Whether an async client is required. + + Returns: + The sync or async client instance. + """ + if is_async: + return self._require_async_client() + return self._client + class BaseEndpoint(EdcEndpointMixin, GenericEndpoint[T]): """ diff --git a/imednet/core/endpoint/edc_mixin.py b/imednet/core/endpoint/edc_mixin.py index b3865d2e..8f87fd58 100644 --- a/imednet/core/endpoint/edc_mixin.py +++ b/imednet/core/endpoint/edc_mixin.py @@ -3,7 +3,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict -from urllib.parse import quote if TYPE_CHECKING: from imednet.core.context import Context @@ -35,21 +34,3 @@ def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]: if "studyKey" not in filters and self._ctx.default_study_key: filters["studyKey"] = self._ctx.default_study_key return filters - - def _build_path(self, *segments: Any) -> str: - """ - Return an API path joined with :data:`BASE_PATH`. - - Args: - *segments: URL path segments to append. - - Returns: - The full API path string. - """ - parts = [self.BASE_PATH.strip("/")] - for seg in segments: - text = str(seg).strip("/") - if text: - # Encode path segments to prevent traversal and injection - parts.append(quote(text, safe="")) - return "/" + "/".join(parts) diff --git a/imednet/core/endpoint/mixins/bases.py b/imednet/core/endpoint/mixins/bases.py index f321b0c2..75425b3a 100644 --- a/imednet/core/endpoint/mixins/bases.py +++ b/imednet/core/endpoint/mixins/bases.py @@ -27,9 +27,10 @@ class GenericListEndpoint(GenericEndpoint[T], ListEndpointMixin[T]): def _get_context( self, is_async: bool ) -> tuple[RequestorProtocol | AsyncRequestorProtocol, type[Paginator] | type[AsyncPaginator]]: + client = self._get_client(is_async) if is_async: - return self._require_async_client(), self.ASYNC_PAGINATOR_CLS - return self._client, self.PAGINATOR_CLS + return client, self.ASYNC_PAGINATOR_CLS + return client, self.PAGINATOR_CLS def _list_common(self, is_async: bool, **kwargs: Any) -> List[T] | Awaitable[List[T]]: client, paginator = self._get_context(is_async) @@ -94,10 +95,11 @@ class GenericListPathGetEndpoint(GenericListEndpoint[T], PathGetEndpointMixin[T] """Generic endpoint implementing ``list`` and ``get`` (via path) helpers.""" def get(self, study_key: Optional[str], item_id: Any) -> T: - return cast(T, self._get_impl_path(self._client, study_key=study_key, item_id=item_id)) + client = self._get_client(is_async=False) + return cast(T, self._get_impl_path(client, study_key=study_key, item_id=item_id)) async def async_get(self, study_key: Optional[str], item_id: Any) -> T: - client = self._require_async_client() + client = self._get_client(is_async=True) return await cast( Awaitable[T], self._get_impl_path(client, study_key=study_key, item_id=item_id, is_async=True), diff --git a/imednet/endpoints/records.py b/imednet/endpoints/records.py index d4924067..d1c7c48e 100644 --- a/imednet/endpoints/records.py +++ b/imednet/endpoints/records.py @@ -1,9 +1,10 @@ """Endpoint for managing records (eCRF instances) in a study.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Dict, List, Optional, Union, cast from imednet.constants import HEADER_EMAIL_NOTIFY from imednet.core.endpoint.mixins import CreateEndpointMixin, EdcListGetEndpoint +from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol from imednet.models.jobs import Job from imednet.models.records import Record from imednet.validation.cache import SchemaCache, validate_record_data @@ -82,6 +83,35 @@ def _build_headers(self, email_notify: Union[bool, str, None]) -> Dict[str, str] headers[HEADER_EMAIL_NOTIFY] = str(email_notify).lower() return headers + def _create_impl( + self, + study_key: str, + records_data: List[Dict[str, Any]], + email_notify: Union[bool, str, None] = None, + *, + schema: Optional[SchemaCache] = None, + is_async: bool = False, + ) -> Job | Awaitable[Job]: + path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema) + client = self._get_client(is_async) + + if is_async: + return self._create_async( + cast(AsyncRequestorProtocol, client), + path, + json=records_data, + headers=headers, + parse_func=Job.from_json, + ) + + return self._create_sync( + cast(RequestorProtocol, client), + path, + json=records_data, + headers=headers, + parse_func=Job.from_json, + ) + def create( self, study_key: str, @@ -107,13 +137,9 @@ def create( Raises: ValueError: If email_notify contains invalid characters """ - path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema) - return self._create_sync( - self._client, - path, - json=records_data, - headers=headers, - parse_func=Job.from_json, + return cast( + Job, + self._create_impl(study_key, records_data, email_notify, schema=schema, is_async=False), ) async def async_create( @@ -143,8 +169,7 @@ async def async_create( Raises: ValueError: If email_notify contains invalid characters """ - client = self._require_async_client() - path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema) - return await self._create_async( - client, path, json=records_data, headers=headers, parse_func=Job.from_json + return await cast( + Awaitable[Job], + self._create_impl(study_key, records_data, email_notify, schema=schema, is_async=True), )