Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cytetype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.15.0"
__version__ = "0.16.0"

import requests

Expand Down
75 changes: 66 additions & 9 deletions cytetype/api/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import math
import time
import threading
from pathlib import Path
from typing import Any
from concurrent.futures import ThreadPoolExecutor

from .transport import HTTPTransport
from .progress import ProgressDisplay
Expand All @@ -21,6 +24,7 @@ def _upload_file(
file_kind: UploadFileKind,
file_path: str,
timeout: float | tuple[float, float] = (30.0, 3600.0),
max_workers: int = 4,
) -> UploadResponse:
path_obj = Path(file_path)
if not path_obj.is_file():
Expand All @@ -34,33 +38,86 @@ def _upload_file(
)

transport = HTTPTransport(base_url, auth_token)
with path_obj.open("rb") as f:
_, response = transport.post_binary(
f"upload/{file_kind}",
data=f,

# Step 1 – Initiate chunked upload
_, init_data = transport.post_empty(f"upload/{file_kind}/initiate", timeout=timeout)
upload_id: str = init_data["upload_id"]
chunk_size: int = init_data["chunk_size_bytes"]

server_max = init_data.get("max_size_bytes")
if server_max is not None and size_bytes > server_max:
raise ValueError(
f"{file_kind} exceeds server upload limit: "
f"{size_bytes} bytes > {server_max} bytes"
)

n_chunks = math.ceil(size_bytes / chunk_size) if size_bytes > 0 else 0

# Step 2 – Upload chunks in parallel.
# Each worker thread gets its own HTTPTransport (and thus its own
# requests.Session / connection pool) for thread safety.
# Memory is bounded to ~max_workers × chunk_size because each thread
# reads its chunk on demand via seek+read.
_tls = threading.local()

def _upload_chunk(chunk_idx: int) -> None:
if not hasattr(_tls, "transport"):
_tls.transport = HTTPTransport(base_url, auth_token)
offset = chunk_idx * chunk_size
read_size = min(chunk_size, size_bytes - offset)
with path_obj.open("rb") as f:
f.seek(offset)
chunk_data = f.read(read_size)
_tls.transport.put_binary(
f"upload/{upload_id}/chunk/{chunk_idx}",
data=chunk_data,
timeout=timeout,
)
return UploadResponse(**response)

if n_chunks > 0:
effective_workers = min(max_workers, n_chunks)
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
list(pool.map(_upload_chunk, range(n_chunks)))

# Step 3 – Complete upload (returns same UploadResponse shape as before)
_, complete_data = transport.post_empty(
f"upload/{upload_id}/complete", timeout=timeout
)
return UploadResponse(**complete_data)


def upload_obs_duckdb(
    base_url: str,
    auth_token: str | None,
    file_path: str,
    timeout: float | tuple[float, float] = (30.0, 3600.0),
    max_workers: int = 4,
) -> UploadResponse:
    """Upload obs duckdb file and return upload metadata.

    Args:
        base_url: API root URL.
        auth_token: Bearer token, or None for unauthenticated requests.
        file_path: Path to the local obs DuckDB artifact.
        timeout: (connect, read) socket timeouts forwarded to each HTTP call.
        max_workers: Number of parallel threads used to upload file chunks.

    Returns:
        UploadResponse describing the completed upload.
    """
    # Fix: the old single-shot `return _upload_file(..., timeout=timeout)` line
    # left over from before the chunked-upload refactor made this call
    # unreachable and silently dropped `max_workers`. Only the chunked call
    # that forwards max_workers remains.
    return _upload_file(
        base_url,
        auth_token,
        "obs_duckdb",
        file_path,
        timeout=timeout,
        max_workers=max_workers,
    )


def upload_vars_h5(
    base_url: str,
    auth_token: str | None,
    file_path: str,
    timeout: float | tuple[float, float] = (30.0, 3600.0),
    max_workers: int = 4,
) -> UploadResponse:
    """Upload vars h5 file and return upload metadata.

    Args:
        base_url: API root URL.
        auth_token: Bearer token, or None for unauthenticated requests.
        file_path: Path to the local vars HDF5 artifact.
        timeout: (connect, read) socket timeouts forwarded to each HTTP call.
        max_workers: Number of parallel threads used to upload file chunks.

    Returns:
        UploadResponse describing the completed upload.
    """
    # Fix: a stale pre-refactor `return _upload_file(..., timeout=timeout)`
    # statement preceded this call, making it dead code and discarding
    # `max_workers`. The duplicate return is removed so the chunked upload
    # path with max_workers is actually taken.
    return _upload_file(
        base_url,
        auth_token,
        "vars_h5",
        file_path,
        timeout=timeout,
        max_workers=max_workers,
    )


def submit_annotation_job(
Expand Down
30 changes: 27 additions & 3 deletions cytetype/api/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,41 @@ def post(
self._handle_request_error(e)
raise # For type checker

def post_binary(
def post_empty(
    self,
    endpoint: str,
    timeout: float | tuple[float, float] = 30.0,
) -> tuple[int, dict[str, Any]]:
    """POST to *endpoint* with no request body.

    Builds the target URL from ``self.base_url``, sends the request on the
    shared session, and returns ``(status_code, parsed_json)``. Non-OK
    responses are routed through ``self._parse_error`` and transport-level
    failures through ``self._handle_request_error``; both are expected to
    raise, so the trailing ``raise`` only satisfies the type checker.
    """
    url = f"{self.base_url}/{endpoint.lstrip('/')}"

    try:
        resp = self.session.post(
            url,
            headers=self._build_headers(),
            timeout=timeout,
        )
        if not resp.ok:
            self._parse_error(resp)
        # NOTE: .json() stays inside the try block — a JSON decode failure
        # is a requests.RequestException subclass and must take the same
        # error path as any other transport failure.
        return resp.status_code, resp.json()
    except requests.RequestException as exc:
        self._handle_request_error(exc)
        raise  # For type checker

def put_binary(
self,
endpoint: str,
data: bytes | BinaryIO,
timeout: float | tuple[float, float] = (30.0, 3600.0),
) -> tuple[int, dict[str, Any]]:
"""Make POST request with raw binary body (application/octet-stream)."""
"""Make PUT request with raw binary body (application/octet-stream)."""
url = f"{self.base_url}/{endpoint.lstrip('/')}"

try:
response = self.session.post(
response = self.session.put(
url,
data=data,
headers=self._build_headers(content_type="application/octet-stream"),
Expand Down
7 changes: 7 additions & 0 deletions cytetype/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def _build_and_upload_artifacts(
vars_h5_path: str,
obs_duckdb_path: str,
upload_timeout_seconds: int,
upload_max_workers: int = 4,
) -> dict[str, str]:
"""Build local artifacts and upload them before annotate."""
logger.info("Saving vars.h5 artifact from normalized counts...")
Expand All @@ -220,6 +221,7 @@ def _build_and_upload_artifacts(
self.auth_token,
obs_duckdb_path,
timeout=(30.0, float(upload_timeout_seconds)),
max_workers=upload_max_workers,
)
if obs_upload.file_kind != "obs_duckdb":
raise ValueError(
Expand All @@ -232,6 +234,7 @@ def _build_and_upload_artifacts(
self.auth_token,
vars_h5_path,
timeout=(30.0, float(upload_timeout_seconds)),
max_workers=upload_max_workers,
)
if vars_upload.file_kind != "vars_h5":
raise ValueError(
Expand Down Expand Up @@ -267,6 +270,7 @@ def run(
vars_h5_path: str = "vars.h5",
obs_duckdb_path: str = "obs.duckdb",
upload_timeout_seconds: int = 3600,
upload_max_workers: int = 4,
cleanup_artifacts: bool = False,
require_artifacts: bool = True,
show_progress: bool = True,
Expand Down Expand Up @@ -309,6 +313,8 @@ def run(
Defaults to "obs.duckdb".
upload_timeout_seconds (int, optional): Socket read timeout used for each artifact upload.
Defaults to 3600.
upload_max_workers (int, optional): Number of parallel threads used to upload file
chunks. Each worker holds one chunk in memory (~100 MB). Defaults to 4.
cleanup_artifacts (bool, optional): Whether to delete generated artifact files after run
completes or fails. Defaults to False.
require_artifacts (bool, optional): Whether to raise an error if artifact building or
Expand Down Expand Up @@ -371,6 +377,7 @@ def run(
vars_h5_path=vars_h5_path,
obs_duckdb_path=obs_duckdb_path,
upload_timeout_seconds=upload_timeout_seconds,
upload_max_workers=upload_max_workers,
)
payload["uploaded_files"] = uploaded_file_refs
except Exception as exc:
Expand Down