Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cytetype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.16.0"
__version__ = "0.16.1"

import requests

Expand Down
64 changes: 55 additions & 9 deletions cytetype/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .transport import HTTPTransport
from .progress import ProgressDisplay
from .exceptions import JobFailedError, TimeoutError, APIError
from .exceptions import JobFailedError, TimeoutError, APIError, NetworkError
from .schemas import UploadResponse, UploadFileKind
from ..config import logger

Expand All @@ -17,6 +17,9 @@
"vars_h5": 10 * 1024 * 1024 * 1024, # 10GB
}

_CHUNK_RETRY_DELAYS = (1, 5, 20)
_RETRYABLE_API_ERROR_CODES = frozenset({"INTERNAL_ERROR", "HTTP_ERROR"})


def _upload_file(
base_url: str,
Expand Down Expand Up @@ -59,6 +62,8 @@ def _upload_file(
# Memory is bounded to ~max_workers × chunk_size because each thread
# reads its chunk on demand via seek+read.
_tls = threading.local()
_progress_lock = threading.Lock()
_chunks_done = [0]

def _upload_chunk(chunk_idx: int) -> None:
if not hasattr(_tls, "transport"):
Expand All @@ -68,16 +73,57 @@ def _upload_chunk(chunk_idx: int) -> None:
with path_obj.open("rb") as f:
f.seek(offset)
chunk_data = f.read(read_size)
_tls.transport.put_binary(
f"upload/{upload_id}/chunk/{chunk_idx}",
data=chunk_data,
timeout=timeout,
)

last_exc: Exception | None = None
for attempt in range(1 + len(_CHUNK_RETRY_DELAYS)):
try:
_tls.transport.put_binary(
f"upload/{upload_id}/chunk/{chunk_idx}",
data=chunk_data,
timeout=timeout,
)
with _progress_lock:
_chunks_done[0] += 1
done = _chunks_done[0]
pct = 100 * done / n_chunks
print(
f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)",
end="",
flush=True,
)
return
except (NetworkError, TimeoutError) as exc:
last_exc = exc
except APIError as exc:
if exc.error_code in _RETRYABLE_API_ERROR_CODES:
last_exc = exc
else:
raise

if attempt < len(_CHUNK_RETRY_DELAYS):
delay = _CHUNK_RETRY_DELAYS[attempt]
logger.warning(
"Chunk %d/%d upload failed (attempt %d/%d), retrying in %ds: %s",
chunk_idx + 1,
n_chunks,
attempt + 1,
1 + len(_CHUNK_RETRY_DELAYS),
delay,
last_exc,
)
time.sleep(delay)

raise last_exc # type: ignore[misc]

if n_chunks > 0:
effective_workers = min(max_workers, n_chunks)
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
list(pool.map(_upload_chunk, range(n_chunks)))
try:
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
list(pool.map(_upload_chunk, range(n_chunks)))
print(f"\r \033[92m✓\033[0m Uploaded {n_chunks}/{n_chunks} chunks (100%)")
except BaseException:
print() # ensure newline on failure
raise

# Step 3 – Complete upload (returns same UploadResponse shape as before)
_, complete_data = transport.post_empty(
Expand Down Expand Up @@ -293,7 +339,7 @@ def wait_for_completion(
if job_status == "completed":
if progress:
progress.finalize(cluster_status)
logger.info(f"Job {job_id} completed successfully.")
logger.success(f"Job {job_id} completed successfully.")
return fetch_job_results(base_url, auth_token, job_id)

elif job_status == "failed":
Expand Down
20 changes: 18 additions & 2 deletions cytetype/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
from loguru import logger
from __future__ import annotations

import sys
from typing import TYPE_CHECKING

from loguru import logger

if TYPE_CHECKING:
from loguru import Record

logger.remove()


def _log_format(record: Record) -> str:
if record["level"].name == "WARNING":
return "⚠️ {message}\n"
if record["level"].name == "SUCCESS":
return "\033[92m✓\033[0m {message}\n"
return "{message}\n"


logger.add(
sys.stdout,
level="INFO",
format="{message}",
format=_log_format,
)
12 changes: 6 additions & 6 deletions cytetype/core/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ def store_annotations(
_check_unannotated_clusters(result_data, clusters)

# Log success
logger.success(
f"Annotations successfully added to `adata.obs['{results_prefix}_annotation_{group_key}']`\n"
f"Ontology terms added to `adata.obs['{results_prefix}_cellOntologyTerm_{group_key}']`\n"
f"Ontology term IDs added to `adata.obs['{results_prefix}_ontologyTermID_{group_key}']`\n"
f"Cell states added to `adata.obs['{results_prefix}_cellState_{group_key}']`\n"
f"Full results added to `adata.uns['{results_prefix}_results']`."
logger.info(
f" Annotation labels → adata.obs['{results_prefix}_annotation_{group_key}']\n"
f" Cell Ontology terms adata.obs['{results_prefix}_cellOntologyTerm_{group_key}']\n"
f"Cell Ontology term IDs adata.obs['{results_prefix}_ontologyTermID_{group_key}']\n"
f" Cell states adata.obs['{results_prefix}_cellState_{group_key}']\n"
f" Full results adata.uns['{results_prefix}_results']"
)


Expand Down
125 changes: 73 additions & 52 deletions cytetype/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,52 +199,69 @@ def _build_and_upload_artifacts(
obs_duckdb_path: str,
upload_timeout_seconds: int,
upload_max_workers: int = 4,
) -> dict[str, str]:
"""Build local artifacts and upload them before annotate."""
logger.info("Saving vars.h5 artifact from normalized counts...")
save_features_matrix(
out_file=vars_h5_path,
mat=self.adata.X,
var_df=self.adata.var,
var_names=self.adata.var_names,
)
) -> tuple[dict[str, str], list[tuple[str, Exception]]]:
"""Build and upload each artifact as an independent unit.

logger.info("Saving obs.duckdb artifact from observation metadata...")
save_obs_duckdb_file(
out_file=obs_duckdb_path,
obs_df=self.adata.obs,
)
Returns (uploaded_ids, errors) so the caller can decide whether
partial success is acceptable.
"""
uploaded: dict[str, str] = {}
errors: list[tuple[str, Exception]] = []
timeout = (30.0, float(upload_timeout_seconds))

logger.info("Uploading obs.duckdb artifact...")
obs_upload = upload_obs_duckdb_file(
self.api_url,
self.auth_token,
obs_duckdb_path,
timeout=(30.0, float(upload_timeout_seconds)),
max_workers=upload_max_workers,
)
if obs_upload.file_kind != "obs_duckdb":
raise ValueError(
f"Unexpected upload file_kind for obs artifact: {obs_upload.file_kind}"
# --- vars.h5 (save then upload) ---
try:
logger.info("Saving vars.h5 artifact from normalized counts...")
save_features_matrix(
out_file=vars_h5_path,
mat=self.adata.X,
var_df=self.adata.var,
var_names=self.adata.var_names,
)
logger.info("Uploading vars.h5 artifact...")
vars_upload = upload_vars_h5_file(
self.api_url,
self.auth_token,
vars_h5_path,
timeout=timeout,
max_workers=upload_max_workers,
)
if vars_upload.file_kind != "vars_h5":
raise ValueError(
f"Unexpected upload file_kind for vars artifact: {vars_upload.file_kind}"
)
uploaded["vars_h5"] = vars_upload.upload_id
except Exception as exc:
logger.warning(f"vars.h5 artifact failed: {exc}")
errors.append(("vars_h5", exc))

logger.info("Uploading vars.h5 artifact...")
vars_upload = upload_vars_h5_file(
self.api_url,
self.auth_token,
vars_h5_path,
timeout=(30.0, float(upload_timeout_seconds)),
max_workers=upload_max_workers,
)
if vars_upload.file_kind != "vars_h5":
raise ValueError(
f"Unexpected upload file_kind for vars artifact: {vars_upload.file_kind}"
print()

# --- obs.duckdb (save then upload) ---
try:
logger.info("Saving obs.duckdb artifact from observation metadata...")
save_obs_duckdb_file(
out_file=obs_duckdb_path,
obs_df=self.adata.obs,
)
logger.info("Uploading obs.duckdb artifact...")
obs_upload = upload_obs_duckdb_file(
self.api_url,
self.auth_token,
obs_duckdb_path,
timeout=timeout,
max_workers=upload_max_workers,
)
if obs_upload.file_kind != "obs_duckdb":
raise ValueError(
f"Unexpected upload file_kind for obs artifact: {obs_upload.file_kind}"
)
uploaded["obs_duckdb"] = obs_upload.upload_id
except Exception as exc:
logger.warning(f"obs.duckdb artifact failed: {exc}")
errors.append(("obs_duckdb", exc))

return {
"obs_duckdb": obs_upload.upload_id,
"vars_h5": vars_upload.upload_id,
}
return uploaded, errors

@staticmethod
def _cleanup_artifact_files(paths: list[str]) -> None:
Expand Down Expand Up @@ -372,26 +389,29 @@ def run(

artifact_paths = [vars_h5_path, obs_duckdb_path]
try:
try:
uploaded_file_refs = self._build_and_upload_artifacts(
vars_h5_path=vars_h5_path,
obs_duckdb_path=obs_duckdb_path,
upload_timeout_seconds=upload_timeout_seconds,
upload_max_workers=upload_max_workers,
)
uploaded_file_refs, artifact_errors = self._build_and_upload_artifacts(
vars_h5_path=vars_h5_path,
obs_duckdb_path=obs_duckdb_path,
upload_timeout_seconds=upload_timeout_seconds,
upload_max_workers=upload_max_workers,
)
if uploaded_file_refs:
payload["uploaded_files"] = uploaded_file_refs
except Exception as exc:

if artifact_errors:
failed_names = ", ".join(name for name, _ in artifact_errors)
if require_artifacts:
logger.error(
"Artifact build/upload failed. "
f"Artifact build/upload failed for: {failed_names}. "
"Rerun with `require_artifacts=False` to skip this error.\n"
"Please report the error below in a new issue at "
"https://github.com/NygenAnalytics/CyteType\n"
f"({type(exc).__name__}: {exc})"
f"({type(artifact_errors[0][1]).__name__}: {str(artifact_errors[0][1]).strip()})"
)
raise
raise artifact_errors[0][1]
logger.warning(
"Artifact build/upload failed. Continuing without artifacts. "
f"Artifact build/upload failed for: {failed_names}. "
"Continuing without those artifacts. "
"Set `require_artifacts=True` to see the full traceback."
)

Expand All @@ -400,6 +420,7 @@ def run(
save_query_to_file(payload["input_data"], query_filename)

# Submit job and store details
print()
job_id = submit_annotation_job(self.api_url, self.auth_token, payload)
store_job_details(self.adata, job_id, self.api_url, results_prefix)

Expand Down
9 changes: 6 additions & 3 deletions tests/test_cytetype_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ def test_cytetype_run_artifact_failure_continues_when_not_required(
mock_api_response: dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test run() proceeds without uploaded_files when require_artifacts=False."""
"""Test run() proceeds with partial uploads when require_artifacts=False.

vars.h5 save fails but obs.duckdb still succeeds independently.
"""
mock_submit.return_value = "job_no_artifacts"
mock_wait.return_value = mock_api_response

Expand All @@ -197,9 +200,9 @@ def test_cytetype_run_artifact_failure_continues_when_not_required(
assert result is not None
assert mock_submit.called

# Payload must not contain uploaded_files
# obs.duckdb should have succeeded independently
payload = mock_submit.call_args.args[2]
assert "uploaded_files" not in payload
assert payload["uploaded_files"] == {"obs_duckdb": "obs_upload_123"}


@patch("cytetype.main.wait_for_completion")
Expand Down