Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions imednet/core/endpoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,6 @@ def _require_async_client(self) -> AsyncRequestorProtocol:
raise RuntimeError("Async client not configured")
return self._async_client

def _get_client(self, is_async: bool) -> RequestorProtocol | AsyncRequestorProtocol:
"""
Get the appropriate client for the execution context.

Args:
is_async: Whether an async client is required.

Returns:
The sync or async client instance.
"""
if is_async:
return self._require_async_client()
return self._require_sync_client()


class BaseEndpoint(EdcEndpointMixin, GenericEndpoint[T]):
"""
Expand Down
54 changes: 17 additions & 37 deletions imednet/endpoints/records.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Endpoint for managing records (eCRF instances) in a study."""

from typing import Any, Awaitable, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union

from imednet.constants import HEADER_EMAIL_NOTIFY
from imednet.core.endpoint.mixins import CreateEndpointMixin, EdcListGetEndpoint
from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol
from imednet.models.jobs import Job
from imednet.models.records import Record
from imednet.validation.cache import SchemaCache, validate_record_data
Expand Down Expand Up @@ -83,35 +82,6 @@ def _build_headers(self, email_notify: Union[bool, str, None]) -> Dict[str, str]
headers[HEADER_EMAIL_NOTIFY] = str(email_notify).lower()
return headers

def _create_impl(
self,
study_key: str,
records_data: List[Dict[str, Any]],
email_notify: Union[bool, str, None] = None,
*,
schema: Optional[SchemaCache] = None,
is_async: bool = False,
) -> Job | Awaitable[Job]:
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
client = self._get_client(is_async)

if is_async:
return self._create_async(
cast(AsyncRequestorProtocol, client),
path,
json=records_data,
headers=headers,
parse_func=Job.from_json,
)

return self._create_sync(
cast(RequestorProtocol, client),
path,
json=records_data,
headers=headers,
parse_func=Job.from_json,
)

def create(
self,
study_key: str,
Expand All @@ -137,9 +107,14 @@ def create(
Raises:
ValueError: If email_notify contains invalid characters
"""
return cast(
Job,
self._create_impl(study_key, records_data, email_notify, schema=schema, is_async=False),
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
client = self._require_sync_client()
return self._create_sync(
client,
path,
json=records_data,
headers=headers,
parse_func=Job.from_json,
)

async def async_create(
Expand Down Expand Up @@ -169,7 +144,12 @@ async def async_create(
Raises:
ValueError: If email_notify contains invalid characters
"""
return await cast(
Awaitable[Job],
self._create_impl(study_key, records_data, email_notify, schema=schema, is_async=True),
path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema)
client = self._require_async_client()
return await self._create_async(
client,
path,
json=records_data,
headers=headers,
parse_func=Job.from_json,
)