Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cytetype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Package version string. NOTE(review): this view is a rendered diff with the
# +/- markers lost; the first assignment appears to be the removed line and the
# second the added line (the header reports 1 addition & 1 deletion).
__version__ = "0.14.0"
__version__ = "0.14.1"

import requests

Expand Down
113 changes: 112 additions & 1 deletion cytetype/core/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,123 @@
import re
from typing import Any
from typing import Any, cast
from collections.abc import Sequence

import duckdb
import h5py
import hdf5plugin
import numpy as np
import pandas as pd
from pandas.api import types as ptypes
import scipy.sparse as sp
from anndata.abc import CSCDataset, CSRDataset

from ..config import logger


def _safe_column_dataset_name(
    source_name: str,
    column_index: int,
    group: h5py.Group,
) -> str:
    """Return a sanitised dataset name for a column that is unused in *group*.

    Every character outside ``[A-Za-z0-9_.-]`` is replaced by ``_`` and
    leading/trailing underscores are stripped. If nothing usable remains, the
    positional fallback ``column_<column_index>`` is used. When the resulting
    name is already present in *group*, ``_1``, ``_2``, ... is appended until
    a free name is found.

    Args:
        source_name: Original column label, already converted to ``str``.
        column_index: Position of the column, used for the fallback name.
        group: Container the dataset will be created in; only membership
            (``name in group``) is queried.

    Returns:
        A name that is not currently a member of *group*.
    """
    base = re.sub(r"[^A-Za-z0-9_.-]", "_", source_name).strip("_")
    # The character class above keeps ".", but HDF5 reserves "." as a path
    # component meaning "this group", so a name made only of dots cannot be
    # used as a link name. Treat it like an empty name.
    if not base.strip("."):
        base = f"column_{column_index}"

    candidate = base
    suffix = 1
    while candidate in group:
        candidate = f"{base}_{suffix}"
        suffix += 1
    return candidate


def _as_string_values(values: pd.Series | pd.Index | Sequence[Any]) -> np.ndarray:
series = pd.Series(values, copy=False)
return cast(np.ndarray, series.astype("string").fillna("").to_numpy(dtype=object))


def _numeric_column_values(series: pd.Series) -> np.ndarray:
    """Return the numeric column *series* as an array with a native NumPy dtype.

    ``Series.to_numpy()`` on nullable extension dtypes (e.g. ``Int64``,
    ``Float64``) can yield an ``object`` array, depending on the pandas
    version and on whether values are missing. Object arrays have no native
    HDF5 equivalent, so those cases are converted explicitly here.
    """
    coerced = pd.to_numeric(series, errors="coerce")
    values = np.asarray(coerced.to_numpy())
    if values.dtype != object:
        return values

    if coerced.isna().any():
        # Missing values need a float container so they can be stored as NaN.
        return np.asarray(coerced.to_numpy(dtype=np.float64, na_value=np.nan))

    # No missing values: keep the underlying NumPy dtype when the extension
    # dtype exposes one, otherwise fall back to float64.
    target = np.dtype(getattr(coerced.dtype, "numpy_dtype", np.float64))
    if target == object:
        target = np.dtype(np.float64)
    return np.asarray(coerced.to_numpy(dtype=target))


def _write_var_metadata(
    out_file_group: h5py.File,
    n_cols: int,
    var_df: pd.DataFrame,
    var_names: pd.Index | Sequence[Any] | None,
) -> None:
    """Write variable (column) metadata under ``info/var`` of *out_file_group*.

    Layout written:

    * ``info/var/var_names`` - UTF-8 strings from *var_names*, or from
      ``var_df.index`` when *var_names* is ``None``.
    * ``info/var/index`` - UTF-8 strings from ``var_df.index``.
    * ``info/var/columns/<name>`` - one dataset per column of *var_df*.
      Categorical and other non-numeric columns are stored as UTF-8 strings,
      boolean columns as ``bool`` (or ``int8`` with the attribute
      ``missing_sentinel = -1`` when values are missing), and numeric columns
      as native numeric arrays. Each dataset carries ``source_name`` and
      ``source_dtype`` attributes describing the original column.

    Args:
        out_file_group: Open, writable HDF5 file (or group) to write into.
        n_cols: Number of matrix columns the metadata must describe.
        var_df: Per-variable metadata, one row per matrix column.
        var_names: Optional explicit variable names.

    Raises:
        ValueError: If ``var_df`` or the variable names do not have exactly
            *n_cols* entries.
    """
    if len(var_df) != n_cols:
        raise ValueError(
            f"`adata.var` row count ({len(var_df)}) does not match matrix columns ({n_cols})."
        )

    names_source: pd.Index | Sequence[Any] = (
        var_names if var_names is not None else var_df.index
    )
    if len(names_source) != n_cols:
        raise ValueError(
            f"`var_names` length ({len(names_source)}) does not match matrix columns ({n_cols})."
        )

    info_group = out_file_group.create_group("info")
    var_group = info_group.create_group("var")
    text_dtype = h5py.string_dtype(encoding="utf-8")

    var_group.create_dataset(
        "var_names",
        data=_as_string_values(names_source),
        dtype=text_dtype,
    )
    var_group.create_dataset(
        "index",
        data=_as_string_values(var_df.index),
        dtype=text_dtype,
    )

    columns_group = var_group.create_group("columns")
    for i, col_name in enumerate(var_df.columns):
        source_name = str(col_name)
        dataset_name = _safe_column_dataset_name(source_name, i, columns_group)
        series = var_df[col_name]

        if isinstance(series.dtype, pd.CategoricalDtype):
            # Categories are stored by label; `_as_string_values` already
            # performs the cast to the "string" dtype.
            dataset = columns_group.create_dataset(
                dataset_name,
                data=_as_string_values(series),
                dtype=text_dtype,
            )
        elif ptypes.is_bool_dtype(series.dtype):
            if series.isna().any():
                # A bool array cannot hold a missing marker, so widen to int8
                # and record the sentinel used for missing entries.
                bool_with_missing = series.astype("Int8").to_numpy(
                    dtype=np.int8,
                    na_value=-1,
                )
                dataset = columns_group.create_dataset(
                    dataset_name,
                    data=bool_with_missing,
                    dtype=np.int8,
                )
                dataset.attrs["missing_sentinel"] = -1
            else:
                dataset = columns_group.create_dataset(
                    dataset_name,
                    data=series.to_numpy(dtype=np.bool_),
                    dtype=np.bool_,
                )
        elif ptypes.is_numeric_dtype(series.dtype):
            dataset = columns_group.create_dataset(
                dataset_name,
                data=_numeric_column_values(series),
            )
        else:
            dataset = columns_group.create_dataset(
                dataset_name,
                data=_as_string_values(series),
                dtype=text_dtype,
            )

        # Keep the original label and dtype so readers can map the sanitised
        # dataset name back to the source column.
        dataset.attrs["source_name"] = source_name
        dataset.attrs["source_dtype"] = str(series.dtype)


def save_features_matrix(
out_file: str,
mat: Any,
var_df: pd.DataFrame | None = None,
var_names: pd.Index | Sequence[Any] | None = None,
min_chunk_size: int = 10_000_000,
col_batch: int | None = None,
) -> None:
Expand Down Expand Up @@ -82,6 +185,14 @@ def save_features_matrix(

group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64))

if var_df is not None:
_write_var_metadata(
out_file_group=f,
n_cols=n_cols,
var_df=var_df,
var_names=var_names,
)


def save_obs_duckdb(
out_file: str,
Expand Down
2 changes: 2 additions & 0 deletions cytetype/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def _build_and_upload_artifacts(
save_features_matrix(
out_file=vars_h5_path,
mat=self.adata.X,
var_df=self.adata.var,
var_names=self.adata.var_names,
)

logger.info("Saving obs.duckdb artifact from observation metadata...")
Expand Down
34 changes: 34 additions & 0 deletions tests/test_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import h5py
import anndata
from pathlib import Path

from cytetype.core.artifacts import save_features_matrix


def test_save_features_matrix_writes_var_metadata(
    tmp_path: Path,
    mock_adata: anndata.AnnData,
) -> None:
    """Check that ``save_features_matrix`` emits the ``info/var`` metadata tree.

    The matrix and ``var`` table of the ``mock_adata`` fixture are written to
    a temporary HDF5 file; the test then verifies the expected groups exist,
    that the name/index datasets have one entry per variable, and that every
    column dataset carries its provenance attributes.
    """
    h5_path = tmp_path / "vars.h5"

    save_features_matrix(
        out_file=str(h5_path),
        mat=mock_adata.X,
        var_df=mock_adata.var,
        var_names=mock_adata.var_names,
        col_batch=10,
    )

    with h5py.File(h5_path, "r") as handle:
        # Top-level layout, then the nested metadata groups.
        for top_level in ("vars", "info"):
            assert top_level in handle
        assert "var" in handle["info"]
        assert "columns" in handle["info/var"]

        # One name and one index entry per variable.
        for dataset_path in ("info/var/var_names", "info/var/index"):
            assert len(handle[dataset_path]) == mock_adata.n_vars

        column_datasets = handle["info/var/columns"]
        expected_column_count = mock_adata.var.shape[1]
        assert len(column_datasets) == expected_column_count

        # Every stored column must record where it came from.
        for stored in column_datasets.values():
            for attr_name in ("source_name", "source_dtype"):
                assert attr_name in stored.attrs