Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions imednet/core/endpoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def _require_async_client(self) -> AsyncRequestorProtocol:
raise RuntimeError("Async client not configured")
return self._async_client

def _get_client(self, is_async: bool) -> RequestorProtocol | AsyncRequestorProtocol:
"""
Get the appropriate client for the execution context.

Args:
is_async: Whether an async client is required.

Returns:
The sync or async client instance.
"""
if is_async:
return self._require_async_client()
return self._client


class BaseEndpoint(EdcEndpointMixin, GenericEndpoint[T]):
"""
Expand Down
19 changes: 0 additions & 19 deletions imednet/core/endpoint/edc_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict
from urllib.parse import quote

if TYPE_CHECKING:
from imednet.core.context import Context
Expand Down Expand Up @@ -35,21 +34,3 @@ def _auto_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
if "studyKey" not in filters and self._ctx.default_study_key:
filters["studyKey"] = self._ctx.default_study_key
return filters

def _build_path(self, *segments: Any) -> str:
"""
Return an API path joined with :data:`BASE_PATH`.

Args:
*segments: URL path segments to append.

Returns:
The full API path string.
"""
parts = [self.BASE_PATH.strip("/")]
for seg in segments:
text = str(seg).strip("/")
if text:
# Encode path segments to prevent traversal and injection
parts.append(quote(text, safe=""))
return "/" + "/".join(parts)
10 changes: 6 additions & 4 deletions imednet/core/endpoint/mixins/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ class GenericListEndpoint(GenericEndpoint[T], ListEndpointMixin[T]):
def _get_context(
self, is_async: bool
) -> tuple[RequestorProtocol | AsyncRequestorProtocol, type[Paginator] | type[AsyncPaginator]]:
client = self._get_client(is_async)
if is_async:
return self._require_async_client(), self.ASYNC_PAGINATOR_CLS
return self._client, self.PAGINATOR_CLS
return client, self.ASYNC_PAGINATOR_CLS
return client, self.PAGINATOR_CLS

def _list_common(self, is_async: bool, **kwargs: Any) -> List[T] | Awaitable[List[T]]:
client, paginator = self._get_context(is_async)
Expand Down Expand Up @@ -94,10 +95,11 @@ class GenericListPathGetEndpoint(GenericListEndpoint[T], PathGetEndpointMixin[T]
"""Generic endpoint implementing ``list`` and ``get`` (via path) helpers."""

def get(self, study_key: Optional[str], item_id: Any) -> T:
return cast(T, self._get_impl_path(self._client, study_key=study_key, item_id=item_id))
client = self._get_client(is_async=False)
return cast(T, self._get_impl_path(client, study_key=study_key, item_id=item_id))

async def async_get(self, study_key: Optional[str], item_id: Any) -> T:
client = self._require_async_client()
client = self._get_client(is_async=True)
return await cast(
Awaitable[T],
self._get_impl_path(client, study_key=study_key, item_id=item_id, is_async=True),
Expand Down
49 changes: 37 additions & 12 deletions imednet/endpoints/records.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Endpoint for managing records (eCRF instances) in a study."""

from typing import Any, Dict, List, Optional, Union
from typing import Any, Awaitable, Dict, List, Optional, Union, cast

from imednet.constants import HEADER_EMAIL_NOTIFY
from imednet.core.endpoint.mixins import CreateEndpointMixin, EdcListGetEndpoint
from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol
from imednet.models.jobs import Job
from imednet.models.records import Record
from imednet.validation.cache import SchemaCache, validate_record_data
Expand Down Expand Up @@ -82,6 +83,35 @@ def _build_headers(self, email_notify: Union[bool, str, None]) -> Dict[str, str]
headers[HEADER_EMAIL_NOTIFY] = str(email_notify).lower()
return headers

def _create_impl(
self,
study_key: str,
records_data: List[Dict[str, Any]],
email_notify: Union[bool, str, None] = None,
*,
schema: Optional[SchemaCache] = None,
is_async: bool = False,
) -> Job | Awaitable[Job]:
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
client = self._get_client(is_async)

if is_async:
return self._create_async(
cast(AsyncRequestorProtocol, client),
path,
json=records_data,
headers=headers,
parse_func=Job.from_json,
)

return self._create_sync(
cast(RequestorProtocol, client),
path,
json=records_data,
headers=headers,
parse_func=Job.from_json,
)

def create(
self,
study_key: str,
Expand All @@ -107,13 +137,9 @@ def create(
Raises:
ValueError: If email_notify contains invalid characters
"""
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
return self._create_sync(
self._client,
path,
json=records_data,
headers=headers,
parse_func=Job.from_json,
return cast(
Job,
self._create_impl(study_key, records_data, email_notify, schema=schema, is_async=False),
)

async def async_create(
Expand Down Expand Up @@ -143,8 +169,7 @@ async def async_create(
Raises:
ValueError: If email_notify contains invalid characters
"""
client = self._require_async_client()
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
return await self._create_async(
client, path, json=records_data, headers=headers, parse_func=Job.from_json
return await cast(
Awaitable[Job],
self._create_impl(study_key, records_data, email_notify, schema=schema, is_async=True),
)