diff --git a/cytetype/__init__.py b/cytetype/__init__.py index 2b8ad3e..2caccb7 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.14.0" +__version__ = "0.14.1" import requests diff --git a/cytetype/core/artifacts.py b/cytetype/core/artifacts.py index a70b757..13fd3ba 100644 --- a/cytetype/core/artifacts.py +++ b/cytetype/core/artifacts.py @@ -1,20 +1,123 @@ import re -from typing import Any +from typing import Any, cast +from collections.abc import Sequence import duckdb import h5py import hdf5plugin import numpy as np import pandas as pd +from pandas.api import types as ptypes import scipy.sparse as sp from anndata.abc import CSCDataset, CSRDataset from ..config import logger +def _safe_column_dataset_name( + source_name: str, + column_index: int, + group: h5py.Group, +) -> str: + base = re.sub(r"[^A-Za-z0-9_.-]", "_", source_name).strip("_") + if not base: + base = f"column_{column_index}" + candidate = base + suffix = 1 + while candidate in group: + candidate = f"{base}_{suffix}" + suffix += 1 + return candidate + + +def _as_string_values(values: pd.Series | pd.Index | Sequence[Any]) -> np.ndarray: + series = pd.Series(values, copy=False) + return cast(np.ndarray, series.astype("string").fillna("").to_numpy(dtype=object)) + + +def _write_var_metadata( + out_file_group: h5py.File, + n_cols: int, + var_df: pd.DataFrame, + var_names: pd.Index | Sequence[Any] | None, +) -> None: + if len(var_df) != n_cols: + raise ValueError( + f"`adata.var` row count ({len(var_df)}) does not match matrix columns ({n_cols})." + ) + + names_source: pd.Index | Sequence[Any] = ( + var_names if var_names is not None else var_df.index + ) + if len(names_source) != n_cols: + raise ValueError( + f"`var_names` length ({len(names_source)}) does not match matrix columns ({n_cols})." + ) + + info_group = out_file_group.create_group("info") + var_group = info_group.create_group("var") + text_dtype = h5py.string_dtype(encoding="utf-8") + + var_group.create_dataset( + "var_names", + data=_as_string_values(names_source), + dtype=text_dtype, + ) + var_group.create_dataset( + "index", + data=_as_string_values(var_df.index), + dtype=text_dtype, + ) + + columns_group = var_group.create_group("columns") + for i, col_name in enumerate(var_df.columns): + source_name = str(col_name) + dataset_name = _safe_column_dataset_name(source_name, i, columns_group) + series = var_df[col_name] + + if isinstance(series.dtype, pd.CategoricalDtype): + dataset = columns_group.create_dataset( + dataset_name, + data=_as_string_values(series.astype("string")), + dtype=text_dtype, + ) + elif ptypes.is_bool_dtype(series.dtype): + if series.isna().any(): + bool_with_missing = series.astype("Int8").to_numpy( + dtype=np.int8, + na_value=-1, + ) + dataset = columns_group.create_dataset( + dataset_name, + data=bool_with_missing, + dtype=np.int8, + ) + dataset.attrs["missing_sentinel"] = -1 + else: + dataset = columns_group.create_dataset( + dataset_name, + data=series.to_numpy(dtype=np.bool_), + dtype=np.bool_, + ) + elif ptypes.is_numeric_dtype(series.dtype): + numeric_data = pd.to_numeric(series, errors="coerce").to_numpy() + dataset = columns_group.create_dataset(dataset_name, data=numeric_data) + else: + dataset = columns_group.create_dataset( + dataset_name, + data=_as_string_values(series), + dtype=text_dtype, + ) + + dataset.attrs["source_name"] = source_name + dataset.attrs["source_dtype"] = str(series.dtype) + + def save_features_matrix( out_file: str, mat: Any, + var_df: pd.DataFrame | None = None, + var_names: pd.Index | Sequence[Any] | None = None, min_chunk_size: int = 10_000_000, col_batch: int | None = None, ) -> None: @@ -82,6 +185,14 @@ def save_features_matrix( group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) + if var_df is not None: + _write_var_metadata( + out_file_group=f, + n_cols=n_cols, + var_df=var_df, + var_names=var_names, + ) + def save_obs_duckdb( out_file: str, diff --git a/cytetype/main.py b/cytetype/main.py index af606f3..982b8b1 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -204,6 +204,8 @@ def _build_and_upload_artifacts( save_features_matrix( out_file=vars_h5_path, mat=self.adata.X, + var_df=self.adata.var, + var_names=self.adata.var_names, ) logger.info("Saving obs.duckdb artifact from observation metadata...") diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py new file mode 100644 index 0000000..a041a62 --- /dev/null +++ b/tests/test_artifacts.py @@ -0,0 +1,34 @@ +import h5py +import anndata +from pathlib import Path + +from cytetype.core.artifacts import save_features_matrix + + +def test_save_features_matrix_writes_var_metadata( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + out_path = tmp_path / "vars.h5" + + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + var_df=mock_adata.var, + var_names=mock_adata.var_names, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "vars" in f + assert "info" in f + assert "var" in f["info"] + assert "columns" in f["info/var"] + assert len(f["info/var/var_names"]) == mock_adata.n_vars + assert len(f["info/var/index"]) == mock_adata.n_vars + + columns_group = f["info/var/columns"] + assert len(columns_group.keys()) == mock_adata.var.shape[1] + for dataset in columns_group.values(): + assert "source_name" in dataset.attrs + assert "source_dtype" in dataset.attrs