From faa5ef502a364809a32656c6942d8d501091610c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 14 Apr 2026 14:27:49 +0200 Subject: [PATCH 1/2] feat: add IndexTransform library for composable, lazy coordinate mappings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a new `src/zarr/core/transforms/` package implementing TensorStore-inspired index transforms. The core idea: every indexing operation (slicing, fancy indexing, etc.) produces a coordinate mapping from user space to storage space. These mappings compose lazily — no I/O until explicitly resolved. Key types: - `IndexDomain` — rectangular region in N-dimensional integer space - `ConstantMap`, `DimensionMap`, `ArrayMap` — three representations of a set of storage coordinates (singleton, arithmetic progression, explicit enumeration) - `IndexTransform` — pairs an input domain with output maps (one per storage dim) - `compose(outer, inner)` — chain two transforms Key operations on IndexTransform: - `__getitem__`, `.oindex[]`, `.vindex[]` — indexing produces new transforms - `.intersect(domain)` — restrict to coordinates within a region (chunk resolution) - `.translate(shift)` — shift coordinates (make chunk-local) The transform library is standalone with no dependency on Array. Includes comprehensive test suite (143 tests covering all types, operations, composition, chunk resolution, and edge cases). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- src/zarr/core/array.py | 601 +++++++++-- src/zarr/core/transforms/__init__.py | 30 + src/zarr/core/transforms/chunk_resolution.py | 207 ++++ src/zarr/core/transforms/composition.py | 113 +++ src/zarr/core/transforms/domain.py | 178 ++++ src/zarr/core/transforms/output_map.py | 83 ++ src/zarr/core/transforms/transform.py | 932 ++++++++++++++++++ tests/test_array.py | 3 +- tests/test_lazy_indexing.py | 164 +++ tests/test_transforms/__init__.py | 0 .../test_transforms/test_chunk_resolution.py | 178 ++++ tests/test_transforms/test_composition.py | 166 ++++ tests/test_transforms/test_domain.py | 202 ++++ tests/test_transforms/test_output_map.py | 56 ++ tests/test_transforms/test_transform.py | 516 ++++++++++ 15 files changed, 3369 insertions(+), 60 deletions(-) create mode 100644 src/zarr/core/transforms/__init__.py create mode 100644 src/zarr/core/transforms/chunk_resolution.py create mode 100644 src/zarr/core/transforms/composition.py create mode 100644 src/zarr/core/transforms/domain.py create mode 100644 src/zarr/core/transforms/output_map.py create mode 100644 src/zarr/core/transforms/transform.py create mode 100644 tests/test_lazy_indexing.py create mode 100644 tests/test_transforms/__init__.py create mode 100644 tests/test_transforms/test_chunk_resolution.py create mode 100644 tests/test_transforms/test_composition.py create mode 100644 tests/test_transforms/test_domain.py create mode 100644 tests/test_transforms/test_output_map.py create mode 100644 tests/test_transforms/test_transform.py diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 4736805b9d..1d087a7945 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -96,12 +96,18 @@ VIndex, _iter_grid, _iter_regions, + boundscheck_indices, check_fields, check_no_multi_fields, + ensure_tuple, + is_basic_selection, + is_coordinate_selection, is_pure_fancy_indexing, is_pure_orthogonal_indexing, is_scalar, pop_fields, + replace_lists, + 
wraparound_indices, ) from zarr.core.metadata import ( ArrayMetadata, @@ -126,6 +132,13 @@ resolve_chunks, ) from zarr.core.sync import sync +from zarr.core.transforms.chunk_resolution import iter_chunk_transforms, sub_transform_to_selections +from zarr.core.transforms.output_map import ArrayMap +from zarr.core.transforms.transform import ( + IndexTransform, + _normalize_negative_indices, + selection_to_transform, +) from zarr.errors import ( ArrayNotFoundError, ChunkNotFoundError, @@ -329,6 +342,7 @@ class AsyncArray[T_ArrayMetadata: (ArrayV2Metadata, ArrayV3Metadata)]: store_path: StorePath codec_pipeline: CodecPipeline = field(init=False) _chunk_grid: ChunkGrid = field(init=False) + _transform: IndexTransform = field(init=False) config: ArrayConfig @overload @@ -365,6 +379,7 @@ def __init__( "codec_pipeline", create_codec_pipeline(metadata=metadata_parsed, store=store_path.store), ) + object.__setattr__(self, "_transform", IndexTransform.from_shape(metadata_parsed.shape)) # this overload defines the function signature when zarr_format is 2 @overload @@ -1040,6 +1055,17 @@ async def example(): _metadata_dict = cast("ArrayMetadataJSON_V3", metadata_dict) return cls(store_path=store_path, metadata=_metadata_dict) + def _with_transform(self, transform: IndexTransform) -> AsyncArray[T_ArrayMetadata]: + """Return a new AsyncArray sharing storage but with a different transform.""" + new = object.__new__(type(self)) + object.__setattr__(new, "metadata", self.metadata) + object.__setattr__(new, "store_path", self.store_path) + object.__setattr__(new, "config", self.config) + object.__setattr__(new, "_chunk_grid", self._chunk_grid) + object.__setattr__(new, "codec_pipeline", self.codec_pipeline) + object.__setattr__(new, "_transform", transform) + return new + @property def store(self) -> Store: return self.store_path.store @@ -1058,7 +1084,7 @@ def ndim(self) -> int: int The number of dimensions in the Array. 
""" - return len(self.metadata.shape) + return len(self.shape) @property def shape(self) -> tuple[int, ...]: @@ -1069,6 +1095,11 @@ def shape(self) -> tuple[int, ...]: tuple The shape of the Array. """ + return self._transform.domain.shape + + @property + def storage_shape(self) -> tuple[int, ...]: + """The shape of the underlying storage array (ignoring any view transform).""" return self.metadata.shape @property @@ -1828,6 +1859,40 @@ async def _set_selection( fields=fields, ) + async def _get_selection_t( + self, + transform: IndexTransform, + *, + prototype: BufferPrototype, + out: NDBuffer | None = None, + ) -> NDArrayLikeOrScalar: + return await _get_selection_via_transform( + self.store_path, + self.metadata, + self.config, + transform, + self.codec_pipeline, + prototype=prototype, + out=out, + ) + + async def _set_selection_t( + self, + transform: IndexTransform, + value: npt.ArrayLike, + *, + prototype: BufferPrototype, + ) -> None: + return await _set_selection_via_transform( + self.store_path, + self.metadata, + self.config, + transform, + value, + self.codec_pipeline, + prototype=prototype, + ) + async def setitem( self, selection: BasicSelection, @@ -2086,6 +2151,11 @@ def _chunk_grid(self) -> ChunkGrid: """The chunk grid for this array, bound to the array's shape.""" return self.async_array._chunk_grid + def _with_transform(self, transform: IndexTransform) -> Array[T_ArrayMetadata]: + """Return a new Array sharing storage but with a different transform.""" + new_async = self._async_array._with_transform(transform) + return type(self)(new_async) + @classmethod @deprecated("Use zarr.create_array instead.", category=ZarrDeprecationWarning) def create( @@ -3225,14 +3295,19 @@ def get_basic_selection( if prototype is None: prototype = default_buffer_prototype() - return sync( - self.async_array._get_selection( - BasicIndexer(selection, self.shape, self._chunk_grid), - out=out, - fields=fields, - prototype=prototype, + if fields is not None: + # Fall back 
to legacy path for structured dtype field selection + return sync( + self.async_array._get_selection( + BasicIndexer(selection, self.shape, self._chunk_grid), + out=out, + fields=fields, + prototype=prototype, + ) ) - ) + selection = _normalize_negative_indices(selection, self.shape) + transform = selection_to_transform(selection, self._async_array._transform, "basic") + return sync(self._async_array._get_selection_t(transform, out=out, prototype=prototype)) def set_basic_selection( self, @@ -3334,8 +3409,16 @@ def set_basic_selection( """ if prototype is None: prototype = default_buffer_prototype() - indexer = BasicIndexer(selection, self.shape, self._chunk_grid) - sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) + if fields is not None: + # Fall back to legacy path for structured dtype field selection + indexer = BasicIndexer(selection, self.shape, self._chunk_grid) + sync( + self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype) + ) + return + selection = _normalize_negative_indices(selection, self.shape) + transform = selection_to_transform(selection, self._async_array._transform, "basic") + sync(self._async_array._set_selection_t(transform, value, prototype=prototype)) def get_orthogonal_selection( self, @@ -3462,12 +3545,17 @@ def get_orthogonal_selection( """ if prototype is None: prototype = default_buffer_prototype() - indexer = OrthogonalIndexer(selection, self.shape, self._chunk_grid) - return sync( - self.async_array._get_selection( - indexer=indexer, out=out, fields=fields, prototype=prototype + if fields is not None or not is_basic_selection(selection): + # Fall back to legacy path for structured dtypes or advanced selections + indexer = OrthogonalIndexer(selection, self.shape, self._chunk_grid) + return sync( + self.async_array._get_selection( + indexer=indexer, out=out, fields=fields, prototype=prototype + ) ) - ) + selection = _normalize_negative_indices(selection, self.shape) + 
transform = selection_to_transform(selection, self._async_array._transform, "basic") + return sync(self._async_array._get_selection_t(transform, out=out, prototype=prototype)) def set_orthogonal_selection( self, @@ -3581,10 +3669,16 @@ def set_orthogonal_selection( """ if prototype is None: prototype = default_buffer_prototype() - indexer = OrthogonalIndexer(selection, self.shape, self._chunk_grid) - return sync( - self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype) - ) + if fields is not None or not is_basic_selection(selection): + # Fall back to legacy path for structured dtypes or advanced selections + indexer = OrthogonalIndexer(selection, self.shape, self._chunk_grid) + sync( + self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype) + ) + return + selection = _normalize_negative_indices(selection, self.shape) + transform = selection_to_transform(selection, self._async_array._transform, "basic") + sync(self._async_array._set_selection_t(transform, value, prototype=prototype)) def get_mask_selection( self, @@ -3669,12 +3763,28 @@ def get_mask_selection( if prototype is None: prototype = default_buffer_prototype() - indexer = MaskIndexer(mask, self.shape, self._chunk_grid) - return sync( - self.async_array._get_selection( - indexer=indexer, out=out, fields=fields, prototype=prototype + if fields is not None: + indexer = MaskIndexer(mask, self.shape, self._chunk_grid) + return sync( + self.async_array._get_selection( + indexer=indexer, out=out, fields=fields, prototype=prototype + ) ) - ) + # Unwrap if VIndex passed a tuple + if isinstance(mask, tuple) and len(mask) == 1: + mask = mask[0] + # Validate mask + mask_arr = np.asarray(mask) + if mask_arr.dtype != np.bool_: + raise IndexError("invalid mask selection; expected Boolean array") + if mask_arr.shape != self.shape: + raise IndexError( + f"invalid mask selection; expected Boolean array with shape {self.shape}, " + f"got {mask_arr.shape}" + ) + 
selection = (mask_arr,) + transform = selection_to_transform(selection, self._async_array._transform, "vectorized") + return sync(self._async_array._get_selection_t(transform, out=out, prototype=prototype)) def set_mask_selection( self, @@ -3759,8 +3869,25 @@ def set_mask_selection( """ if prototype is None: prototype = default_buffer_prototype() - indexer = MaskIndexer(mask, self.shape, self._chunk_grid) - sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) + if fields is not None: + indexer = MaskIndexer(mask, self.shape, self._chunk_grid) + sync( + self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype) + ) + return + if isinstance(mask, tuple) and len(mask) == 1: + mask = mask[0] + mask_arr = np.asarray(mask) + if mask_arr.dtype != np.bool_: + raise IndexError("invalid mask selection; expected Boolean array") + if mask_arr.shape != self.shape: + raise IndexError( + f"invalid mask selection; expected Boolean array with shape {self.shape}, " + f"got {mask_arr.shape}" + ) + selection = (mask_arr,) + transform = selection_to_transform(selection, self._async_array._transform, "vectorized") + sync(self._async_array._set_selection_t(transform, value, prototype=prototype)) def get_coordinate_selection( self, @@ -3847,16 +3974,45 @@ def get_coordinate_selection( """ if prototype is None: prototype = default_buffer_prototype() - indexer = CoordinateIndexer(selection, self.shape, self._chunk_grid) - out_array = sync( - self.async_array._get_selection( - indexer=indexer, out=out, fields=fields, prototype=prototype + if fields is not None: + indexer = CoordinateIndexer(selection, self.shape, self._chunk_grid) + out_array = sync( + self.async_array._get_selection( + indexer=indexer, out=out, fields=fields, prototype=prototype + ) ) + if hasattr(out_array, "shape"): + out_array = np.array(out_array).reshape(indexer.sel_shape) + return out_array + # Validate and normalize as coordinate selection + sel_normalized 
= ensure_tuple(selection) + sel_normalized = tuple( + np.asarray([s], dtype=np.intp) if isinstance(s, (int, np.integer)) else s + for s in sel_normalized ) - - if hasattr(out_array, "shape"): - # restore shape - out_array = np.array(out_array).reshape(indexer.sel_shape) + sel_normalized = replace_lists(sel_normalized) + if not is_coordinate_selection(sel_normalized, self.shape): + raise IndexError( + "invalid coordinate selection; expected one integer " + "(coordinate) array per dimension of the target array, " + f"got {selection!r}" + ) + # Handle wraparound and bounds checking + for dim_sel, dim_len in zip(sel_normalized, self.shape, strict=True): + wraparound_indices(dim_sel, dim_len) + boundscheck_indices(dim_sel, dim_len) + transform = selection_to_transform( + sel_normalized, self._async_array._transform, "vectorized" + ) + out_array = sync( + self._async_array._get_selection_t(transform, out=out, prototype=prototype) + ) + # Reshape to the broadcast shape of the coordinate arrays + sel_tuple = sel_normalized + sel_arrays = [np.asarray(s) for s in sel_tuple] + sel_shape = np.broadcast_shapes(*(s.shape for s in sel_arrays)) + if hasattr(out_array, "shape") and sel_shape != (): + out_array = np.array(out_array).reshape(sel_shape) return out_array def set_coordinate_selection( @@ -3939,30 +4095,46 @@ def set_coordinate_selection( """ if prototype is None: prototype = default_buffer_prototype() - # setup indexer - indexer = CoordinateIndexer(selection, self.shape, self._chunk_grid) - - # handle value - need ndarray-like flatten value - if not is_scalar(value, self.dtype): - try: - from numcodecs.compat import ensure_ndarray_like - - value = ensure_ndarray_like(value) # TODO replace with agnostic - except TypeError: - # Handle types like `list` or `tuple` - value = np.array(value) # TODO replace with agnostic - if hasattr(value, "shape") and len(value.shape) > 1: - value = np.array(value).reshape(-1) - - if not is_scalar(value, self.dtype) and ( - 
isinstance(value, NDArrayLike) and indexer.shape != value.shape - ): - raise ValueError( - f"Attempting to set a selection of {indexer.sel_shape[0]} " - f"elements with an array of {value.shape[0]} elements." + # Normalize empty fields list to None + if not fields: + fields = None + if fields is not None: + indexer = CoordinateIndexer(selection, self.shape, self._chunk_grid) + if not is_scalar(value, self.dtype): + try: + from numcodecs.compat import ensure_ndarray_like + + value = ensure_ndarray_like(value) + except TypeError: + value = np.array(value) + if hasattr(value, "shape") and len(value.shape) > 1: + value = np.array(value).reshape(-1) + sync( + self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype) ) - - sync(self.async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) + return + sel_normalized = ensure_tuple(selection) + sel_normalized = tuple( + np.asarray([s], dtype=np.intp) if isinstance(s, (int, np.integer)) else s + for s in sel_normalized + ) + sel_normalized = replace_lists(sel_normalized) + if not is_coordinate_selection(sel_normalized, self.shape): + raise IndexError( + "invalid coordinate selection; expected one integer " + "(coordinate) array per dimension of the target array, " + f"got {selection!r}" + ) + for dim_sel, dim_len in zip(sel_normalized, self.shape, strict=True): + wraparound_indices(dim_sel, dim_len) + boundscheck_indices(dim_sel, dim_len) + transform = selection_to_transform( + sel_normalized, self._async_array._transform, "vectorized" + ) + # Flatten value for coordinate selection + if not is_scalar(value, self.dtype) and hasattr(value, "shape") and len(value.shape) > 1: + value = np.asarray(value).reshape(-1) + sync(self._async_array._set_selection_t(transform, value, prototype=prototype)) def get_block_selection( self, @@ -4192,6 +4364,18 @@ def blocks(self) -> BlockIndex: examples.""" return BlockIndex(self) + @property + def z(self) -> _LazyIndexAccessor: + """Lazy 
indexing accessor. Returns a new Array with composed transform, no I/O.""" + return _LazyIndexAccessor(self) + + def resolve(self, prototype: BufferPrototype | None = None) -> NDArrayLikeOrScalar: + """Read and return the data for this array view. + + Equivalent to ``self[...]`` but more explicit for lazy views. + """ + return self[...] + def resize(self, new_shape: ShapeLike) -> None: """ Change the shape of the array by growing or shrinking one or more @@ -4295,7 +4479,11 @@ def update_attributes(self, new_attributes: dict[str, JSON]) -> Self: return type(self)(new_array) def __repr__(self) -> str: - return f"" + t = self._async_array._transform + return ( + f"" + ) @property def info(self) -> Any: @@ -4405,6 +4593,65 @@ async def _shards_initialized( type SerializerLike = dict[str, JSON] | ArrayBytesCodec | Literal["auto"] +class _LazyOIndex: + """Lazy orthogonal indexing via ``array.z.oindex[...]``.""" + + __slots__ = ("_array",) + + def __init__(self, array: Array[Any]) -> None: + self._array = array + + def __getitem__(self, selection: Any) -> Array[Any]: + new_t = selection_to_transform(selection, self._array._async_array._transform, "orthogonal") + return self._array._with_transform(new_t) + + def __setitem__(self, selection: Any, value: npt.ArrayLike) -> None: + new_t = selection_to_transform(selection, self._array._async_array._transform, "orthogonal") + self._array._with_transform(new_t)[...] 
= value + + +class _LazyVIndex: + """Lazy vectorized indexing via ``array.z.vindex[...]``.""" + + __slots__ = ("_array",) + + def __init__(self, array: Array[Any]) -> None: + self._array = array + + def __getitem__(self, selection: Any) -> Array[Any]: + new_t = selection_to_transform(selection, self._array._async_array._transform, "vectorized") + return self._array._with_transform(new_t) + + def __setitem__(self, selection: Any, value: npt.ArrayLike) -> None: + new_t = selection_to_transform(selection, self._array._async_array._transform, "vectorized") + self._array._with_transform(new_t)[...] = value + + +class _LazyIndexAccessor: + """Provides lazy indexing via ``array.z[...]``.""" + + __slots__ = ("_array",) + + def __init__(self, array: Array[Any]) -> None: + self._array = array + + def __getitem__(self, selection: Selection) -> Array[Any]: + new_t = selection_to_transform(selection, self._array._async_array._transform, "basic") + return self._array._with_transform(new_t) + + def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None: + new_t = selection_to_transform(selection, self._array._async_array._transform, "basic") + self._array._with_transform(new_t)[...] = value + + @property + def oindex(self) -> _LazyOIndex: + return _LazyOIndex(self._array) + + @property + def vindex(self) -> _LazyVIndex: + return _LazyVIndex(self._array) + + class ShardsConfigParam(TypedDict): shape: tuple[int, ...] index_location: ShardingCodecIndexLocation | None @@ -5778,6 +6025,241 @@ def _get_chunk_spec( ) +def _is_complete_chunk( + sub_transform: IndexTransform, chunk_grid: ChunkGrid, chunk_coords: tuple[int, ...] 
+) -> bool: + """Check if a sub-transform covers an entire chunk.""" + from zarr.core.transforms.output_map import ConstantMap, DimensionMap + + spec = chunk_grid[chunk_coords] + if spec is None: + return False + for out_dim, m in enumerate(sub_transform.output): + if isinstance(m, ConstantMap): + # A ConstantMap means a single element is selected along this output dimension, + # so the write does not cover the full chunk along this dimension. + chunk_dim_size = spec.shape[out_dim] + if chunk_dim_size > 1: + return False + continue # chunk dim size is 1, so selecting the single element is complete + if isinstance(m, DimensionMap): + chunk_dim_size = spec.shape[out_dim] + # Compute actual storage range: storage = offset + stride * input_coord + dim_lo = sub_transform.domain.inclusive_min[m.input_dimension] + dim_hi = sub_transform.domain.exclusive_max[m.input_dimension] + if m.stride == 1: + storage_start = m.offset + dim_lo + storage_stop = m.offset + dim_hi + if storage_start != 0 or storage_stop != chunk_dim_size: + return False + else: + return False # strided access is never a complete chunk + else: + return False # ArrayMap is never a "complete chunk" + return True + + +async def _get_selection_via_transform( + store_path: StorePath, + metadata: ArrayMetadata, + config: ArrayConfig, + transform: IndexTransform, + codec_pipeline: CodecPipeline, + *, + prototype: BufferPrototype, + out: NDBuffer | None = None, +) -> NDArrayLikeOrScalar: + """Read data using an IndexTransform.""" + chunk_grid = ChunkGrid.from_metadata(metadata) + + # Get dtype (same logic as existing _get_selection) + if metadata.zarr_format == 2: + zdtype = metadata.dtype + order = metadata.order + else: + zdtype = metadata.data_type + order = config.order + dtype = zdtype.to_native_dtype() + + out_shape = transform.domain.shape + + # When the transform has ArrayMap outputs, chunk resolution produces + # flat scatter indices (out_indices). 
The output buffer must be 1D + # during the read, then reshaped to out_shape afterwards. + # For vectorized indexing (all outputs are ArrayMaps), chunk resolution + # produces flat scatter indices. The buffer must be 1D during the read. + # For orthogonal indexing (mixed ArrayMap + DimensionMap), the buffer + # stays multi-dimensional — each dim gets its own out_sel entry. + needs_flat_buffer = all(isinstance(m, ArrayMap) for m in transform.output) + buffer_shape = (product(out_shape),) if needs_flat_buffer else out_shape + + # Setup output buffer + if out is not None: + if not isinstance(out, NDBuffer): + raise TypeError(f"out argument needs to be an NDBuffer. Got {type(out)!r}") + if out.shape != out_shape: + raise ValueError( + f"shape of out argument doesn't match. Expected {out_shape}, got {out.shape}" + ) + out_buffer = out + else: + out_buffer = prototype.nd_buffer.empty(shape=buffer_shape, dtype=dtype, order=order) + + if product(out_shape) > 0: + _config = config + if metadata.zarr_format == 2: + _config = replace(_config, order=order) + + # Build batch_info using transforms + batch_info = [] + drop_axes: tuple[int, ...] 
= () + for chunk_coords, sub_transform, out_indices in iter_chunk_transforms( + transform, chunk_grid + ): + chunk_sel, out_sel, da = sub_transform_to_selections(sub_transform, out_indices) + drop_axes = da # same for all chunks + is_complete = _is_complete_chunk(sub_transform, chunk_grid, chunk_coords) + batch_info.append( + ( + store_path / metadata.encode_chunk_key(chunk_coords), + _get_chunk_spec(metadata, chunk_grid, chunk_coords, _config, prototype), + chunk_sel, + out_sel, + is_complete, + ) + ) + + results = await codec_pipeline.read(batch_info, out_buffer, drop_axes=drop_axes) + + # Handle read_missing_chunks + if _config.read_missing_chunks is False: + missing_info = [] + for i, result in enumerate(results): + if result["status"] == "missing": + coords_path = batch_info[i][0] + missing_info.append(f" chunk at '{coords_path}'") + if missing_info: + chunks_str = "\n".join(missing_info) + raise ChunkNotFoundError( + f"{len(missing_info)} chunk(s) not found in store '{store_path}'.\n" + f"Set the 'array.read_missing_chunks' config to True to fill " + f"missing chunks with the fill value.\n" + f"Missing chunks:\n{chunks_str}" + ) + + # Return scalar for 0-d results + if out_shape == (): + return out_buffer.as_scalar() + out_result = out_buffer.as_ndarray_like() + # Reshape if we flattened for array indexing + if needs_flat_buffer and hasattr(out_result, "reshape"): + out_result = np.array(out_result).reshape(out_shape) + return out_result + + +async def _set_selection_via_transform( + store_path: StorePath, + metadata: ArrayMetadata, + config: ArrayConfig, + transform: IndexTransform, + value: npt.ArrayLike, + codec_pipeline: CodecPipeline, + *, + prototype: BufferPrototype, +) -> None: + """Write data using an IndexTransform.""" + chunk_grid = ChunkGrid.from_metadata(metadata) + + # Get dtype from metadata + if metadata.zarr_format == 2: + zdtype = metadata.dtype + else: + zdtype = metadata.data_type + dtype = zdtype.to_native_dtype() + + # check value shape 
+ if np.isscalar(value): + array_like = prototype.buffer.create_zero_length().as_array_like() + if isinstance(array_like, np._typing._SupportsArrayFunc): + array_like_ = cast("np._typing._SupportsArrayFunc", array_like) + value = np.asanyarray(value, dtype=dtype, like=array_like_) + else: + if not hasattr(value, "shape"): + value = np.asarray(value, dtype) + if not hasattr(value, "dtype") or value.dtype.name != dtype.name: + if hasattr(value, "astype"): + value = value.astype(dtype=dtype, order="A") + else: + value = np.array(value, dtype=dtype, order="A") + value = cast("NDArrayLike", value) + + # Validate value shape against selection shape + sel_shape = transform.domain.shape + needs_flat_buffer = all(isinstance(m, ArrayMap) for m in transform.output) + if hasattr(value, "shape") and value.shape != () and value.shape != sel_shape: + if needs_flat_buffer: + # For ArrayMap (coordinate/vindex), values are flattened so check total size + sel_size = product(sel_shape) + val_size = product(value.shape) + if val_size != sel_size and val_size != 1: + raise ValueError( + f"Attempting to set a selection with a value of incompatible shape. " + f"The selection has shape {sel_shape}, but the value has shape {value.shape}." + ) + else: + # Check if value is broadcastable to sel_shape + try: + np.broadcast_shapes(value.shape, sel_shape) + except ValueError: + raise ValueError( + f"Attempting to set a selection with a value of incompatible shape. " + f"The selection has shape {sel_shape}, but the value has shape {value.shape}." + ) from None + + # When the transform has ArrayMap outputs, chunk resolution produces + # flat scatter indices (out_indices). The value buffer must be 1D + # during the write, matching the flat index layout. 
+ if ( + needs_flat_buffer + and hasattr(value, "reshape") + and not np.isscalar(value) + and np.ndim(value) > 0 + ): + value = np.asarray(value).reshape(-1) + + # Convert to NDBuffer + value_buffer = prototype.nd_buffer.from_ndarray_like(value) + + # Determine memory order + if metadata.zarr_format == 2: + order = metadata.order + else: + order = config.order + + _config = config + if metadata.zarr_format == 2: + _config = replace(_config, order=order) + + # Build batch_info using transforms + batch_info = [] + drop_axes: tuple[int, ...] = () + for chunk_coords, sub_transform, out_indices in iter_chunk_transforms(transform, chunk_grid): + chunk_sel, out_sel, da = sub_transform_to_selections(sub_transform, out_indices) + drop_axes = da # same for all chunks + is_complete = _is_complete_chunk(sub_transform, chunk_grid, chunk_coords) + batch_info.append( + ( + store_path / metadata.encode_chunk_key(chunk_coords), + _get_chunk_spec(metadata, chunk_grid, chunk_coords, _config, prototype), + chunk_sel, + out_sel, + is_complete, + ) + ) + + await codec_pipeline.write(batch_info, value_buffer, drop_axes=drop_axes) + + async def _get_selection( store_path: StorePath, metadata: ArrayMetadata, @@ -6315,9 +6797,10 @@ async def _delete_key(key: str) -> None: # Write new metadata await save_metadata(array.store_path, new_metadata) - # Update metadata and chunk_grid (in place) + # Update metadata, chunk_grid, and transform (in place) object.__setattr__(array, "metadata", new_metadata) object.__setattr__(array, "_chunk_grid", new_chunk_grid) + object.__setattr__(array, "_transform", IndexTransform.from_shape(new_shape)) async def _append( diff --git a/src/zarr/core/transforms/__init__.py b/src/zarr/core/transforms/__init__.py new file mode 100644 index 0000000000..530dd39cea --- /dev/null +++ b/src/zarr/core/transforms/__init__.py @@ -0,0 +1,30 @@ +"""Composable, lazy coordinate transforms for zarr array indexing. + +This package implements TensorStore-inspired index transforms. 
The core idea: +every indexing operation (slicing, fancy indexing, etc.) produces a coordinate +mapping from user space to storage space. These mappings compose lazily — no +I/O until you explicitly read or write. + +Key types: + +- ``IndexDomain`` — a rectangular region of integer coordinates +- ``IndexTransform`` — maps input coordinates to storage coordinates +- ``ConstantMap``, ``DimensionMap``, ``ArrayMap`` — the three ways a single + output dimension can depend on the input (see ``output_map.py``) +- ``compose`` — chain two transforms into one +""" + +from zarr.core.transforms.composition import compose +from zarr.core.transforms.domain import IndexDomain +from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap +from zarr.core.transforms.transform import IndexTransform + +__all__ = [ + "ArrayMap", + "ConstantMap", + "DimensionMap", + "IndexDomain", + "IndexTransform", + "OutputIndexMap", + "compose", +] diff --git a/src/zarr/core/transforms/chunk_resolution.py b/src/zarr/core/transforms/chunk_resolution.py new file mode 100644 index 0000000000..db066a2525 --- /dev/null +++ b/src/zarr/core/transforms/chunk_resolution.py @@ -0,0 +1,207 @@ +"""Chunk resolution — mapping transforms to chunk-level I/O. + +Given an ``IndexTransform`` (which coordinates a user wants to access) and a +``ChunkGrid`` (how storage is divided into chunks), chunk resolution answers: + + For each chunk, which storage coordinates does this transform touch, + and where do those values land in the output buffer? + +The algorithm is: + +1. **Enumerate candidate chunks** — determine which chunks could possibly + be touched by the transform's output coordinate ranges. + +2. **Intersect** — for each candidate chunk, call + ``transform.intersect(chunk_domain)`` to restrict the transform to + coordinates within that chunk. If the intersection is empty, skip it. + +3. 
def iter_chunk_transforms(
    transform: IndexTransform,
    chunk_grid: ChunkGrid,
) -> Iterator[ChunkTransformResult]:
    """Resolve a composed IndexTransform against a ChunkGrid.

    For every chunk whose storage coordinates the transform can touch,
    yield a ``(chunk_coords, sub_transform, out_indices)`` triple:

    - ``chunk_coords``: grid position of the chunk to access.
    - ``sub_transform``: maps output-buffer coordinates to chunk-local
      coordinates (already translated by the chunk origin).
    - ``out_indices``: output scatter indices for vectorized/array
      indexing (integer array); ``None`` for basic/slice indexing.
    """
    import itertools

    grids = chunk_grid._dimensions

    # Step 1: per output dimension, find the inclusive range of chunk
    # indices that the transform's storage coordinates can fall into.
    per_dim_chunks: list[range] = []
    for out_dim, out_map in enumerate(transform.output):
        grid = grids[out_dim]
        if isinstance(out_map, ConstantMap):
            # A single storage coordinate touches exactly one chunk.
            only = grid.index_to_chunk(out_map.offset)
            per_dim_chunks.append(range(only, only + 1))
        elif isinstance(out_map, DimensionMap):
            lo = transform.domain.inclusive_min[out_map.input_dimension]
            hi = transform.domain.exclusive_max[out_map.input_dimension]
            if lo >= hi:
                return  # empty domain: nothing to yield
            # The extreme storage coordinates are attained at the domain
            # endpoints regardless of stride sign (stride 0 degenerates to
            # a single coordinate, which min/max also handles).
            endpoints = (
                out_map.offset + out_map.stride * lo,
                out_map.offset + out_map.stride * (hi - 1),
            )
            first = grid.index_to_chunk(min(endpoints))
            last = grid.index_to_chunk(max(endpoints))
            per_dim_chunks.append(range(first, last + 1))
        elif isinstance(out_map, ArrayMap):
            coords = out_map.offset + out_map.stride * out_map.index_array
            chunk_ids = grid.indices_to_chunks(coords.ravel().astype(np.intp))
            per_dim_chunks.append(range(int(chunk_ids.min()), int(chunk_ids.max()) + 1))

    # Step 2: for each candidate chunk, intersect, translate, and yield.
    for raw_coords in itertools.product(*per_dim_chunks):
        chunk_coords = tuple(int(c) for c in raw_coords)

        origins = [grids[d].chunk_offset(c) for d, c in enumerate(chunk_coords)]
        sizes = [grids[d].chunk_size(c) for d, c in enumerate(chunk_coords)]
        chunk_domain = IndexDomain(
            inclusive_min=tuple(origins),
            exclusive_max=tuple(o + s for o, s in zip(origins, sizes, strict=True)),
        )

        # Restrict the transform to coordinates inside this chunk; skip
        # chunks the transform never touches.
        hit = transform.intersect(chunk_domain)
        if hit is None:
            continue
        restricted, surviving = hit

        # Re-express the surviving coordinates relative to the chunk origin.
        local = restricted.translate(tuple(-o for o in origins))

        yield (chunk_coords, local, surviving)
def sub_transform_to_selections(
    sub_transform: IndexTransform,
    out_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None,
) -> tuple[
    tuple[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
    tuple[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
    tuple[int, ...],
]:
    """Convert a chunk-local sub-transform to raw selections for the codec pipeline.

    Parameters
    ----------
    sub_transform
        A chunk-local IndexTransform (output maps already translated to
        chunk-local coordinates).
    out_indices
        For vectorized indexing: the output scatter indices for this chunk.
        None for orthogonal/basic indexing.

    Returns
    -------
    tuple
        ``(chunk_selection, out_selection, drop_axes)``
    """
    chunk_sel: list[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []
    # NOTE(review): drop_axes is never populated. ConstantMap dimensions are
    # indexed with a bare int below (which NumPy squeezes implicitly) —
    # confirm the codec pipeline does not additionally need those axes
    # listed here.
    drop_axes: list[int] = []

    for m in sub_transform.output:
        if isinstance(m, ConstantMap):
            chunk_sel.append(m.offset)
        elif isinstance(m, DimensionMap):
            dim_lo = sub_transform.domain.inclusive_min[m.input_dimension]
            dim_hi = sub_transform.domain.exclusive_max[m.input_dimension]
            start = m.offset + m.stride * dim_lo
            stop = m.offset + m.stride * dim_hi
            if m.stride < 0:
                # Descending slice. slice(start, stop, stride) already
                # enumerates start, start+stride, ... in input order; the
                # previous swap (start, stop = stop + 1, start + 1) produced
                # an empty or incorrect slice. A negative stop, however,
                # would be reinterpreted by NumPy as "from the end", so
                # "run past index 0" must be spelled None.
                chunk_sel.append(slice(start, stop if stop >= 0 else None, m.stride))
            else:
                chunk_sel.append(slice(start, stop, m.stride))
        elif isinstance(m, ArrayMap):
            if m.offset == 0 and m.stride == 1:
                chunk_sel.append(m.index_array)
            else:
                storage_coords = m.offset + m.stride * m.index_array
                chunk_sel.append(storage_coords.astype(np.intp))

    # Build out_sel: one entry per non-dropped output dim.
    out_sel: list[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []

    # Vectorized: multiple correlated ArrayMaps share one scatter index
    is_vectorized = (
        out_indices is not None
        and sum(1 for m in sub_transform.output if isinstance(m, ArrayMap)) >= 2
    )

    if is_vectorized:
        assert out_indices is not None
        out_sel.append(out_indices)
    else:
        for m in sub_transform.output:
            if isinstance(m, ConstantMap):
                continue
            if isinstance(m, DimensionMap):
                lo = sub_transform.domain.inclusive_min[m.input_dimension]
                hi = sub_transform.domain.exclusive_max[m.input_dimension]
                out_sel.append(slice(lo, hi))
            elif isinstance(m, ArrayMap):
                if out_indices is not None:
                    # Orthogonal ArrayMap: out_indices has the surviving positions
                    out_sel.append(out_indices)
                else:
                    out_sel.append(slice(0, len(m.index_array)))

    return tuple(chunk_sel), tuple(out_sel), tuple(drop_axes)
+ """ + if outer.output_rank != inner.domain.ndim: + raise ValueError( + f"outer output rank ({outer.output_rank}) must match inner input rank " + f"({inner.domain.ndim})" + ) + + result_output = [_compose_single(outer, inner_map) for inner_map in inner.output] + + return IndexTransform(domain=outer.domain, output=tuple(result_output)) + + +def _compose_single(outer: IndexTransform, inner_map: OutputIndexMap) -> OutputIndexMap: + """Compose a single inner output map with the full outer transform.""" + if isinstance(inner_map, ConstantMap): + return ConstantMap(offset=inner_map.offset) + + if isinstance(inner_map, DimensionMap): + return _compose_dimension(outer, inner_map) + + if isinstance(inner_map, ArrayMap): + return _compose_array(outer, inner_map) + + raise TypeError(f"Unknown output map type: {type(inner_map)}") # pragma: no cover + + +def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> OutputIndexMap: + """Compose when inner is a DimensionMap. + + storage = offset_i + stride_i * intermediate[dim_i] + where intermediate[dim_i] = outer.output[dim_i](user_input) + """ + dim_i = inner_map.input_dimension + offset_i = inner_map.offset + stride_i = inner_map.stride + outer_map = outer.output[dim_i] + + if isinstance(outer_map, ConstantMap): + return ConstantMap(offset=offset_i + stride_i * outer_map.offset) + + if isinstance(outer_map, DimensionMap): + return DimensionMap( + input_dimension=outer_map.input_dimension, + offset=offset_i + stride_i * outer_map.offset, + stride=stride_i * outer_map.stride, + ) + + if isinstance(outer_map, ArrayMap): + return ArrayMap( + index_array=outer_map.index_array, + offset=offset_i + stride_i * outer_map.offset, + stride=stride_i * outer_map.stride, + ) + + raise TypeError(f"Unknown output map type: {type(outer_map)}") # pragma: no cover + + +def _compose_array(outer: IndexTransform, inner_map: ArrayMap) -> OutputIndexMap: + """Compose when inner is an ArrayMap. 
+ + storage = offset_i + stride_i * arr_i[intermediate] + We need to evaluate arr_i at the intermediate coordinates produced by outer. + """ + arr_i = inner_map.index_array + offset_i = inner_map.offset + stride_i = inner_map.stride + + # Check if all outer outputs are constant + all_constant = all(isinstance(m, ConstantMap) for m in outer.output) + + if all_constant: + # Evaluate arr_i at the single constant point + idx = tuple(m.offset for m in outer.output if isinstance(m, ConstantMap)) + value = int(arr_i[idx]) + return ConstantMap(offset=offset_i + stride_i * value) + + # For 1D inner array with a single outer output (simple case) + if arr_i.ndim == 1 and len(outer.output) == 1: + outer_map = outer.output[0] + + if isinstance(outer_map, DimensionMap): + dim_size = outer.domain.shape[outer_map.input_dimension] + user_indices = np.arange(dim_size, dtype=np.intp) + intermediate_vals = outer_map.offset + outer_map.stride * user_indices + new_arr = arr_i[intermediate_vals] + return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i) + + if isinstance(outer_map, ArrayMap): + intermediate_vals = outer_map.offset + outer_map.stride * outer_map.index_array + new_arr = arr_i[intermediate_vals] + return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i) + + # General multi-dim case: not yet implemented + raise NotImplementedError( + "Composing a multi-dimensional inner array map with non-constant outer maps " + "is not yet supported." + ) diff --git a/src/zarr/core/transforms/domain.py b/src/zarr/core/transforms/domain.py new file mode 100644 index 0000000000..90bcc08ace --- /dev/null +++ b/src/zarr/core/transforms/domain.py @@ -0,0 +1,178 @@ +"""Index domains — rectangular regions in N-dimensional integer space. + +An ``IndexDomain`` represents the set of valid coordinates for an array or +array view. 
It is the cartesian product of per-dimension integer ranges:: + + IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20)) + # represents {(i, j) : 2 <= i < 10, 5 <= j < 20} + +Unlike NumPy, domains can have **non-zero origins**. After slicing +``arr[5:10]``, the result has origin 5 and shape 5 — coordinates 5 through +9 are valid. This follows the TensorStore convention. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class IndexDomain: + """A rectangular region in N-dimensional index space. + + The valid coordinates are the integers in + ``[inclusive_min[d], exclusive_max[d])`` for each dimension ``d``. + """ + + inclusive_min: tuple[int, ...] + exclusive_max: tuple[int, ...] + labels: tuple[str, ...] | None = None + + def __post_init__(self) -> None: + if len(self.inclusive_min) != len(self.exclusive_max): + raise ValueError( + f"inclusive_min and exclusive_max must have the same length. " + f"Got {len(self.inclusive_min)} and {len(self.exclusive_max)}." + ) + for i, (lo, hi) in enumerate(zip(self.inclusive_min, self.exclusive_max, strict=True)): + if lo > hi: + raise ValueError( + f"inclusive_min must be <= exclusive_max for all dimensions. " + f"Dimension {i}: {lo} > {hi}" + ) + if self.labels is not None and len(self.labels) != len(self.inclusive_min): + raise ValueError( + f"labels must have the same length as dimensions. " + f"Got {len(self.labels)} labels for {len(self.inclusive_min)} dimensions." 
+ ) + + @classmethod + def from_shape(cls, shape: tuple[int, ...]) -> IndexDomain: + """Create a domain with origin at zero.""" + return cls( + inclusive_min=(0,) * len(shape), + exclusive_max=shape, + ) + + @property + def ndim(self) -> int: + return len(self.inclusive_min) + + @property + def origin(self) -> tuple[int, ...]: + return self.inclusive_min + + @property + def shape(self) -> tuple[int, ...]: + return tuple(hi - lo for lo, hi in zip(self.inclusive_min, self.exclusive_max, strict=True)) + + def contains(self, index: tuple[int, ...]) -> bool: + if len(index) != self.ndim: + return False + return all( + lo <= idx < hi + for lo, hi, idx in zip(self.inclusive_min, self.exclusive_max, index, strict=True) + ) + + def contains_domain(self, other: IndexDomain) -> bool: + if other.ndim != self.ndim: + return False + return all( + self_lo <= other_lo and other_hi <= self_hi + for self_lo, self_hi, other_lo, other_hi in zip( + self.inclusive_min, + self.exclusive_max, + other.inclusive_min, + other.exclusive_max, + strict=True, + ) + ) + + def intersect(self, other: IndexDomain) -> IndexDomain | None: + if other.ndim != self.ndim: + raise ValueError( + f"Cannot intersect domains with different ranks: {self.ndim} vs {other.ndim}" + ) + new_min = tuple( + max(a, b) for a, b in zip(self.inclusive_min, other.inclusive_min, strict=True) + ) + new_max = tuple( + min(a, b) for a, b in zip(self.exclusive_max, other.exclusive_max, strict=True) + ) + if any(lo >= hi for lo, hi in zip(new_min, new_max, strict=True)): + return None + return IndexDomain(inclusive_min=new_min, exclusive_max=new_max) + + def translate(self, offset: tuple[int, ...]) -> IndexDomain: + if len(offset) != self.ndim: + raise ValueError( + f"Offset must have same length as domain dimensions. " + f"Domain has {self.ndim} dimensions, offset has {len(offset)}." 
+ ) + new_min = tuple(lo + off for lo, off in zip(self.inclusive_min, offset, strict=True)) + new_max = tuple(hi + off for hi, off in zip(self.exclusive_max, offset, strict=True)) + return IndexDomain(inclusive_min=new_min, exclusive_max=new_max) + + def narrow(self, selection: Any) -> IndexDomain: + """Apply a basic selection and return a narrowed domain. + Indices are absolute coordinates. Integer indices produce length-1 extent. + Strided slices are not supported — use IndexTransform for strides. + """ + normalized = _normalize_selection(selection, self.ndim) + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + for dim_idx, (sel, dim_lo, dim_hi) in enumerate( + zip(normalized, self.inclusive_min, self.exclusive_max, strict=True) + ): + if isinstance(sel, int): + if sel < dim_lo or sel >= dim_hi: + raise IndexError( + f"index {sel} is out of bounds for dimension {dim_idx} " + f"with domain [{dim_lo}, {dim_hi})" + ) + new_inclusive_min.append(sel) + new_exclusive_max.append(sel + 1) + else: + start, stop, step = sel.start, sel.stop, sel.step + if step is not None and step != 1: + raise IndexError( + "IndexDomain.narrow only supports step=1 slices. " + f"Got step={step}. Use IndexTransform for strided access." 
+ ) + abs_start = dim_lo if start is None else start + abs_stop = dim_hi if stop is None else stop + abs_start = max(abs_start, dim_lo) + abs_stop = min(abs_stop, dim_hi) + abs_stop = max(abs_stop, abs_start) + new_inclusive_min.append(abs_start) + new_exclusive_max.append(abs_stop) + return IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + +def _normalize_selection(selection: Any, ndim: int) -> tuple[int | slice, ...]: + """Normalize a basic selection to a tuple of ints/slices with length ndim.""" + if not isinstance(selection, tuple): + selection = (selection,) + result: list[int | slice] = [] + ellipsis_seen = False + for sel in selection: + if sel is Ellipsis: + if ellipsis_seen: + raise IndexError("an index can only have a single ellipsis ('...')") + ellipsis_seen = True + num_missing = ndim - (len(selection) - 1) + result.extend([slice(None)] * num_missing) + else: + result.append(sel) + while len(result) < ndim: + result.append(slice(None)) + if len(result) > ndim: + raise IndexError( + f"too many indices for array: array has {ndim} dimensions, " + f"but {len(result)} were indexed" + ) + return tuple(result) diff --git a/src/zarr/core/transforms/output_map.py b/src/zarr/core/transforms/output_map.py new file mode 100644 index 0000000000..5e17a0ae82 --- /dev/null +++ b/src/zarr/core/transforms/output_map.py @@ -0,0 +1,83 @@ +"""Output index maps — three representations of a set of integer coordinates. + +An output index map describes, for one dimension of storage, which coordinates +an array access will touch. Conceptually it is a **set of integers**. 
Three +representations cover the cases that arise in practice: + +- ``ConstantMap(offset=5)`` — a singleton set: ``{5}`` +- ``DimensionMap(input_dimension=0, offset=3, stride=2)`` over input ``[0, 5)`` + — an arithmetic progression: ``{3, 5, 7, 9, 11}`` +- ``ArrayMap(index_array=[1, 5, 9])`` — an explicit enumeration: ``{1, 5, 9}`` + +Every output map supports two set-theoretic operations (defined on +``IndexTransform``, which provides the input domain context these maps lack): + +- **intersect** — restrict to coordinates within a range (e.g., a chunk). + ``{3, 5, 7, 9, 11} ∩ [4, 8) = {5, 7}`` +- **translate** — shift every coordinate by a constant (e.g., make chunk-local). + ``{5, 7} - 4 = {1, 3}`` + +These two operations are the foundation of chunk resolution: for each chunk, +intersect the map with the chunk's range, then translate to chunk-local +coordinates. + +The three types exist because they trade off generality for efficiency: + +- ``ConstantMap``: O(1) storage, O(1) intersection +- ``DimensionMap``: O(1) storage, O(1) intersection (analytical) +- ``ArrayMap``: O(n) storage, O(n) intersection (must scan the array) + +Collapsing everything to ``ArrayMap`` would be correct but wasteful — a +billion-element slice would materialize a billion coordinates just to group +them by chunk, when ``DimensionMap`` does it with three integers. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + + +@dataclass(frozen=True, slots=True) +class ConstantMap: + """A singleton set: one storage coordinate. + + Represents ``{offset}``. Arises from integer indexing (e.g., ``arr[5]`` + fixes one dimension to coordinate 5). + """ + + offset: int = 0 + + +@dataclass(frozen=True, slots=True) +class DimensionMap: + """An arithmetic progression of storage coordinates. 
+ + Represents ``{offset + stride * i : i in input_range}``, where the input + range comes from the enclosing ``IndexTransform``'s domain. Arises from + slice indexing (e.g., ``arr[2:10:3]`` gives offset=2, stride=3). + """ + + input_dimension: int + offset: int = 0 + stride: int = 1 + + +@dataclass(frozen=True, slots=True) +class ArrayMap: + """An explicit enumeration of storage coordinates. + + Represents ``{offset + stride * index_array[i] : i in input_range}``. + Arises from fancy indexing (e.g., ``arr[[1, 5, 9]]`` or boolean masks). + """ + + index_array: npt.NDArray[np.intp] + offset: int = 0 + stride: int = 1 + + +OutputIndexMap = ConstantMap | DimensionMap | ArrayMap diff --git a/src/zarr/core/transforms/transform.py b/src/zarr/core/transforms/transform.py new file mode 100644 index 0000000000..ee2f3ce4c2 --- /dev/null +++ b/src/zarr/core/transforms/transform.py @@ -0,0 +1,932 @@ +"""Index transforms — composable, lazy coordinate mappings. + +An ``IndexTransform`` pairs an **input domain** (the coordinates a user sees) +with a tuple of **output maps** (the storage coordinates those inputs map to). +One output map per storage dimension. See ``output_map.py`` for the three +output map types. + +Key operations: + +- **Indexing** (``transform[2:8]``, ``.oindex[idx]``, ``.vindex[idx]``) — + produces a new transform with a narrower input domain and adjusted output + maps. No I/O occurs. This is how lazy slicing works. + +- **intersect(output_domain)** — restrict to storage coordinates within a + region. This is chunk resolution: "which of my coordinates fall in this + chunk?" + +- **translate(shift)** — shift all output coordinates. This makes coordinates + chunk-local: "express my coordinates relative to the chunk origin." + +- **compose(outer, inner)** — chain two transforms. See ``composition.py``. + +The transform is the atomic unit that connects user-facing indexing to +chunk-level I/O. Every ``Array`` holds a transform (identity by default). 
+``Array.z[...]`` composes a new transform lazily. Reading resolves the +transform against the chunk grid via intersect + translate. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, Literal + +import numpy as np + +from zarr.core.transforms.domain import IndexDomain +from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap + + +@dataclass(frozen=True, slots=True) +class IndexTransform: + """A composable mapping from input coordinates to storage coordinates. + + An ``IndexTransform`` has: + + - ``domain``: an ``IndexDomain`` describing the valid input coordinates + (the user-facing shape, possibly with non-zero origin). + - ``output``: a tuple of output maps (one per storage dimension), each + describing which storage coordinates the inputs touch. + + For a freshly opened array, the transform is the identity: input + coordinate ``i`` maps to storage coordinate ``i``. Indexing operations + compose new transforms without I/O. + """ + + domain: IndexDomain + output: tuple[OutputIndexMap, ...] 
+ + def __post_init__(self) -> None: + for i, m in enumerate(self.output): + if isinstance(m, DimensionMap): + if m.input_dimension < 0 or m.input_dimension >= self.domain.ndim: + raise ValueError( + f"output[{i}].input_dimension = {m.input_dimension} " + f"is out of range for input rank {self.domain.ndim}" + ) + elif isinstance(m, ArrayMap) and m.index_array.ndim > self.domain.ndim: + raise ValueError( + f"output[{i}].index_array has {m.index_array.ndim} dims " + f"but input domain has {self.domain.ndim} dims" + ) + + @property + def input_rank(self) -> int: + return self.domain.ndim + + @property + def output_rank(self) -> int: + return len(self.output) + + @classmethod + def identity(cls, domain: IndexDomain) -> IndexTransform: + output = tuple(DimensionMap(input_dimension=i) for i in range(domain.ndim)) + return cls(domain=domain, output=output) + + @classmethod + def from_shape(cls, shape: tuple[int, ...]) -> IndexTransform: + return cls.identity(IndexDomain.from_shape(shape)) + + @property + def selection_repr(self) -> str: + """Compact domain string, e.g. ``'{ [2, 8), [0, 10) }'``. + + Follows TensorStore's IndexDomain notation: each dimension shown + as ``[inclusive_min, exclusive_max)`` with stride annotation if not 1. + Constant (integer-indexed) dimensions show as a single value. + Array-indexed dimensions show the set of selected coordinates. 
+ """ + parts: list[str] = [] + for m in self.output: + if isinstance(m, ConstantMap): + parts.append(str(m.offset)) + elif isinstance(m, DimensionMap): + d = m.input_dimension + lo = self.domain.inclusive_min[d] + hi = self.domain.exclusive_max[d] + start = m.offset + m.stride * lo + stop = m.offset + m.stride * hi + if m.stride == 1: + parts.append(f"[{start}, {stop})") + else: + parts.append(f"[{start}, {stop}) step {m.stride}") + elif isinstance(m, ArrayMap): + storage = m.offset + m.stride * m.index_array + n = len(storage) + if n <= 5: + vals = ", ".join(str(int(v)) for v in storage.ravel()) + parts.append("{" + vals + "}") + else: + parts.append("{" + f"array({n})" + "}") + return "{ " + ", ".join(parts) + " }" + + def __repr__(self) -> str: + maps: list[str] = [] + for i, m in enumerate(self.output): + if isinstance(m, ConstantMap): + maps.append(f"out[{i}] = {m.offset}") + elif isinstance(m, DimensionMap): + maps.append(f"out[{i}] = {m.offset} + {m.stride} * in[{m.input_dimension}]") + elif isinstance(m, ArrayMap): + maps.append(f"out[{i}] = {m.offset} + {m.stride} * arr{m.index_array.shape}[in]") + maps_str = ", ".join(maps) + return f"IndexTransform(domain={self.domain}, {maps_str})" + + def intersect( + self, output_domain: IndexDomain + ) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Restrict this transform to storage coordinates within output_domain. + + Returns ``(restricted_transform, surviving_indices)`` or None if empty. + + ``surviving_indices`` is an integer array of which input positions + survived the intersection (for ArrayMap dimensions), or None if all + positions survived (ConstantMap/DimensionMap only). 
+ """ + return _intersect(self, output_domain) + + def translate(self, shift: tuple[int, ...]) -> IndexTransform: + """Shift all output coordinates by ``shift``.""" + if len(shift) != self.output_rank: + raise ValueError(f"shift must have length {self.output_rank}, got {len(shift)}") + new_output: list[OutputIndexMap] = [] + for m, s in zip(self.output, shift, strict=True): + if isinstance(m, ConstantMap): + new_output.append(ConstantMap(offset=m.offset + s)) + elif isinstance(m, DimensionMap): + new_output.append( + DimensionMap( + input_dimension=m.input_dimension, + offset=m.offset + s, + stride=m.stride, + ) + ) + elif isinstance(m, ArrayMap): + new_output.append( + ArrayMap( + index_array=m.index_array, + offset=m.offset + s, + stride=m.stride, + ) + ) + return IndexTransform(domain=self.domain, output=tuple(new_output)) + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_basic_indexing(self, selection) + + @property + def oindex(self) -> _OIndexHelper: + return _OIndexHelper(self) + + @property + def vindex(self) -> _VIndexHelper: + return _VIndexHelper(self) + + +def _intersect( + transform: IndexTransform, output_domain: IndexDomain +) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Intersect a transform with an output domain (e.g., a chunk's bounds). + + For each output dimension, restrict to storage coordinates within + [output_domain.inclusive_min[d], output_domain.exclusive_max[d]). + + For orthogonal transforms (ConstantMap, DimensionMap, independent ArrayMaps), + each dimension is intersected independently and the input domain is narrowed. + + For vectorized transforms (correlated ArrayMaps), all array dimensions + must be checked simultaneously — a point survives only if ALL its + coordinates fall within the output domain. + + Returns None if the intersection is empty. 
+ """ + if output_domain.ndim != transform.output_rank: + raise ValueError( + f"output_domain rank ({output_domain.ndim}) != " + f"transform output rank ({transform.output_rank})" + ) + + # Check if we have correlated ArrayMaps (vectorized) + array_dims = [i for i, m in enumerate(transform.output) if isinstance(m, ArrayMap)] + if len(array_dims) >= 2: + return _intersect_vectorized(transform, output_domain, array_dims) + + # Orthogonal: intersect each output dimension independently + new_min = list(transform.domain.inclusive_min) + new_max = list(transform.domain.exclusive_max) + new_output: list[OutputIndexMap] = [] + surviving_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None + + for out_dim, m in enumerate(transform.output): + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + + if isinstance(m, ConstantMap): + if lo <= m.offset < hi: + new_output.append(m) + else: + return None + + elif isinstance(m, DimensionMap): + d = m.input_dimension + input_lo = new_min[d] + input_hi = new_max[d] + if input_lo >= input_hi: + return None + + # Find input range that produces storage coords in [lo, hi) + if m.stride > 0: + new_input_lo = max(input_lo, math.ceil((lo - m.offset) / m.stride)) + new_input_hi = min(input_hi, math.ceil((hi - m.offset) / m.stride)) + elif m.stride < 0: + new_input_lo = max(input_lo, math.ceil((hi - 1 - m.offset) / m.stride)) + new_input_hi = min(input_hi, math.ceil((lo - 1 - m.offset) / m.stride)) + else: + if lo <= m.offset < hi: + new_input_lo, new_input_hi = input_lo, input_hi + else: + return None + + if new_input_lo >= new_input_hi: + return None + + new_min[d] = new_input_lo + new_max[d] = new_input_hi + new_output.append(m) + + elif isinstance(m, ArrayMap): + storage = m.offset + m.stride * m.index_array + mask = (storage >= lo) & (storage < hi) + if not np.any(mask): + return None + surviving_indices = np.nonzero(mask.ravel())[0].astype(np.intp) + filtered = 
m.index_array.ravel()[surviving_indices] + new_output.append( + ArrayMap( + index_array=filtered, + offset=m.offset, + stride=m.stride, + ) + ) + + new_domain = IndexDomain( + inclusive_min=tuple(new_min), + exclusive_max=tuple(new_max), + ) + result = IndexTransform(domain=new_domain, output=tuple(new_output)) + return (result, surviving_indices) + + +def _intersect_vectorized( + transform: IndexTransform, + output_domain: IndexDomain, + array_dims: list[int], +) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Intersect a vectorized transform with an output domain. + + All ArrayMap outputs are correlated — a point survives only if ALL its + storage coordinates fall within the output domain. + """ + # Compute storage coords per array dim and check bounds simultaneously + n_points: int | None = None + masks: list[np.ndarray[Any, np.dtype[np.bool_]]] = [] + + for out_dim in array_dims: + m = transform.output[out_dim] + assert isinstance(m, ArrayMap) + storage = m.offset + m.stride * m.index_array + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + masks.append((storage >= lo) & (storage < hi)) + if n_points is None: + n_points = storage.size + + # A point survives only if it's in-bounds on ALL array dims + combined_mask = masks[0] + for mask in masks[1:]: + combined_mask = combined_mask & mask + + if not np.any(combined_mask): + return None + + surviving = np.nonzero(combined_mask.ravel())[0].astype(np.intp) + + # Build new output maps + new_output: list[OutputIndexMap] = [] + for out_dim, m in enumerate(transform.output): + if isinstance(m, ArrayMap): + filtered = m.index_array.ravel()[surviving] + new_output.append( + ArrayMap( + index_array=filtered, + offset=m.offset, + stride=m.stride, + ) + ) + elif isinstance(m, ConstantMap): + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + if lo <= m.offset < hi: + new_output.append(m) + else: + return None + elif 
isinstance(m, DimensionMap): + new_output.append(m) + + new_domain = IndexDomain.from_shape((len(surviving),)) + result = IndexTransform(domain=new_domain, output=tuple(new_output)) + return (result, surviving) + + +def _normalize_basic_selection(selection: Any, ndim: int) -> tuple[int | slice | None, ...]: + """Normalize a selection to a tuple of int, slice, or None (newaxis), + expanding ellipsis and padding with slice(None) as needed. + """ + if not isinstance(selection, tuple): + selection = (selection,) + + # Count non-newaxis, non-ellipsis entries to determine how many real dims are addressed + n_newaxis = sum(1 for s in selection if s is None) + has_ellipsis = any(s is Ellipsis for s in selection) + n_real = len(selection) - n_newaxis - (1 if has_ellipsis else 0) + + if n_real > ndim: + raise IndexError( + f"too many indices for array: array has {ndim} dimensions, but {n_real} were indexed" + ) + + result: list[int | slice | None] = [] + ellipsis_seen = False + for sel in selection: + if sel is Ellipsis: + if ellipsis_seen: + raise IndexError("an index can only have a single ellipsis ('...')") + ellipsis_seen = True + num_missing = ndim - n_real + result.extend([slice(None)] * num_missing) + elif isinstance(sel, (int, np.integer)): + result.append(int(sel)) + elif isinstance(sel, slice) or sel is None: + result.append(sel) + else: + raise IndexError(f"unsupported selection type for basic indexing: {type(sel)!r}") + + # Pad remaining dimensions with slice(None) + while sum(1 for s in result if s is not None) < ndim: + result.append(slice(None)) + + return tuple(result) + + +def _reindex_array( + arr: np.ndarray[Any, np.dtype[np.intp]], + normalized: tuple[int | slice | None, ...], + domain: IndexDomain, +) -> np.ndarray[Any, np.dtype[np.intp]]: + """Apply basic indexing operations to an ArrayMap's index_array. + + The array's axes correspond to the transform's input dimensions (0-indexed + over the domain shape). 
When input dimensions are dropped (int), sliced, + or inserted (newaxis), the array must be updated accordingly. + """ + # Build a numpy indexing tuple: one entry per old input dimension + idx: list[Any] = [] + old_dim = 0 + newaxis_positions: list[int] = [] + result_axis = 0 + + for sel in normalized: + if sel is None: + newaxis_positions.append(result_axis) + result_axis += 1 + elif isinstance(sel, int): + if old_dim < arr.ndim: + # Convert absolute domain coordinate to 0-based array index + array_idx = sel - domain.inclusive_min[old_dim] + idx.append(array_idx) + old_dim += 1 + elif isinstance(sel, slice): + if old_dim < arr.ndim: + dim_size = domain.shape[old_dim] + # sel.indices gives 0-based start/stop/step for the array axis + start, stop, step = sel.indices(dim_size) + idx.append(slice(start, stop, step)) + old_dim += 1 + result_axis += 1 + + result = arr[tuple(idx)] if idx else arr + + for pos in newaxis_positions: + result = np.expand_dims(result, axis=pos) + + return np.asarray(result, dtype=np.intp) + + +def _reindex_array_oindex( + arr: np.ndarray[Any, np.dtype[np.intp]], + normalized: tuple[Any, ...] | list[Any], + domain: IndexDomain, +) -> np.ndarray[Any, np.dtype[np.intp]]: + """Apply oindex/vindex selection to an existing ArrayMap's index_array. + + Each old input dimension gets either an array (fancy index that axis) + or a slice applied to the corresponding array axis. 
+ """ + idx: list[Any] = [] + for old_dim, sel in enumerate(normalized): + if old_dim >= arr.ndim: + break + if isinstance(sel, np.ndarray): + idx.append(sel) + elif isinstance(sel, slice): + dim_size = domain.shape[old_dim] + start, stop, step = sel.indices(dim_size) + idx.append(slice(start, stop, step)) + else: + idx.append(slice(None)) + + result = arr[tuple(idx)] if idx else arr + return np.asarray(result, dtype=np.intp) + + +def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply basic indexing (int, slice, ellipsis, newaxis) to an IndexTransform.""" + normalized = _normalize_basic_selection(selection, transform.domain.ndim) + + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + old_dim = 0 + new_dim_idx = 0 + old_to_new_dim: dict[int, int] = {} + dropped_dims: set[int] = set() + + # Per old-dim: the slice parameters (for computing new output maps) + dim_slice_params: dict[int, tuple[int, int, int]] = {} # old_dim -> (start, stop, step) + dim_int_val: dict[int, int] = {} # old_dim -> integer index value + + for sel in normalized: + if sel is None: + # newaxis: add a size-1 dimension + new_inclusive_min.append(0) + new_exclusive_max.append(1) + new_dim_idx += 1 + elif isinstance(sel, int): + # Integer index: drop this input dimension. + # Negative indices are literal coordinates (TensorStore convention), + # NOT "from the end" like NumPy. The Array layer handles conversion. 
+ lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + idx = sel + if idx < lo or idx >= hi: + raise IndexError( + f"index {sel} is out of bounds for dimension {old_dim} with domain [{lo}, {hi})" + ) + dropped_dims.add(old_dim) + dim_int_val[old_dim] = idx + old_dim += 1 + elif isinstance(sel, slice): + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + + # Resolve slice relative to the current domain (origin-based) + start, stop, step = sel.indices(dim_size) + # start, stop, step are now relative to a 0-based range of size dim_size + + if step <= 0: + raise IndexError("slice step must be positive") + + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + + # Absolute start in the original domain coordinates + abs_start = lo + start + dim_slice_params[old_dim] = (abs_start, stop, step) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + old_dim += 1 + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + # Now update output maps + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in dropped_dims: + # Integer index: this output becomes constant + new_offset = m.offset + m.stride * dim_int_val[d] + new_output.append(ConstantMap(offset=new_offset)) + elif d in old_to_new_dim: + # Slice: update offset and stride + abs_start, _, step = dim_slice_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = old_to_new_dim[d] + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + else: + raise RuntimeError(f"unexpected: dimension {d} not handled") + elif isinstance(m, ArrayMap): + 
new_arr = _reindex_array(m.index_array, normalized, transform.domain) + new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +class _OIndexHelper: + """Helper that provides orthogonal (outer) indexing via ``transform.oindex[...]``.""" + + def __init__(self, transform: IndexTransform) -> None: + self._transform = transform + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_oindex(self._transform, selection) + + +def _normalize_oindex_selection( + selection: Any, ndim: int +) -> tuple[np.ndarray[Any, np.dtype[np.intp]] | slice, ...]: + """Normalize an oindex selection: arrays, slices, booleans, integers.""" + if not isinstance(selection, tuple): + selection = (selection,) + + # Expand ellipsis + has_ellipsis = any(s is Ellipsis for s in selection) + n_ellipsis = 1 if has_ellipsis else 0 + n_real = len(selection) - n_ellipsis + + result: list[np.ndarray[Any, np.dtype[np.intp]] | slice] = [] + for sel in selection: + if sel is Ellipsis: + num_missing = ndim - n_real + result.extend([slice(None)] * num_missing) + elif isinstance(sel, np.ndarray) and sel.dtype == np.bool_: + # Boolean array -> integer indices + (indices,) = np.nonzero(sel) + result.append(indices.astype(np.intp)) + elif isinstance(sel, np.ndarray): + result.append(sel.astype(np.intp)) + elif isinstance(sel, slice): + result.append(sel) + elif isinstance(sel, (int, np.integer)): + # Convert integer scalars to 1-element arrays for orthogonal indexing + result.append(np.array([int(sel)], dtype=np.intp)) + elif isinstance(sel, (list, tuple)): + result.append(np.asarray(sel, dtype=np.intp)) + else: + result.append(sel) + + # Pad with slice(None) + while len(result) < ndim: + result.append(slice(None)) + + return tuple(result) + + +def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply orthogonal indexing to an IndexTransform. 
+ + Each index array is applied independently per dimension (outer product). + """ + normalized = _normalize_oindex_selection(selection, transform.domain.ndim) + + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + new_dim_idx = 0 + old_to_new_dim: dict[int, int] = {} + + # Info per old dim + dim_array: dict[int, np.ndarray[Any, np.dtype[np.intp]]] = {} + dim_slice_params: dict[int, tuple[int, int, int]] = {} + + for old_dim, sel in enumerate(normalized): + if isinstance(sel, np.ndarray): + dim_array[old_dim] = sel + new_inclusive_min.append(0) + new_exclusive_max.append(len(sel)) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + elif isinstance(sel, slice): + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + start, stop, step = sel.indices(dim_size) + if step <= 0: + raise IndexError("slice step must be positive") + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + abs_start = lo + start + dim_slice_params[old_dim] = (abs_start, stop, step) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in dim_array: + new_output.append( + ArrayMap( + index_array=dim_array[d], + offset=m.offset, + stride=m.stride, + ) + ) + elif d in dim_slice_params: + abs_start, _, step = dim_slice_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = old_to_new_dim[d] + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + else: + raise RuntimeError(f"unexpected: dimension {d} not handled") + elif 
isinstance(m, ArrayMap): + new_arr = _reindex_array_oindex(m.index_array, normalized, transform.domain) + new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +class _VIndexHelper: + """Helper that provides vectorized (fancy) indexing via ``transform.vindex[...]``.""" + + def __init__(self, transform: IndexTransform) -> None: + self._transform = transform + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_vindex(self._transform, selection) + + +def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply vectorized indexing to an IndexTransform. + + All array indices are broadcast together. Broadcast dimensions are prepended, + followed by non-array (slice) dimensions. + """ + if not isinstance(selection, tuple): + selection = (selection,) + + # Expand ellipsis and count consumed dimensions + # Boolean arrays with ndim > 1 consume ndim dims + n_consumed = 0 + for s in selection: + if s is Ellipsis: + continue + if isinstance(s, np.ndarray) and s.dtype == np.bool_ and s.ndim > 1: + n_consumed += s.ndim + else: + n_consumed += 1 + ndim = transform.domain.ndim + + expanded: list[Any] = [] + for sel in selection: + if sel is Ellipsis: + num_missing = ndim - n_consumed + expanded.extend([slice(None)] * num_missing) + else: + expanded.append(sel) + # Count dimensions already consumed by expanded entries + n_expanded_dims = 0 + for sel in expanded: + if isinstance(sel, np.ndarray) and sel.dtype == np.bool_ and sel.ndim > 1: + n_expanded_dims += sel.ndim + else: + n_expanded_dims += 1 + while n_expanded_dims < ndim: + expanded.append(slice(None)) + n_expanded_dims += 1 + + # Convert booleans, lists, ints to integer arrays + processed: list[np.ndarray[Any, np.dtype[np.intp]] | slice] = [] + for sel in expanded: + if isinstance(sel, np.ndarray) and sel.dtype == np.bool_: + indices_tuple = np.nonzero(sel) + 
processed.extend(indices.astype(np.intp) for indices in indices_tuple) + elif isinstance(sel, np.ndarray): + processed.append(sel.astype(np.intp)) + elif isinstance(sel, (list, tuple)): + processed.append(np.asarray(sel, dtype=np.intp)) + elif isinstance(sel, (int, np.integer)): + processed.append(np.array([int(sel)], dtype=np.intp)) + else: + processed.append(sel) + + # Separate array dims and slice dims + array_dims: list[int] = [] + slice_dims: list[int] = [] + arrays: list[np.ndarray[Any, np.dtype[np.intp]]] = [] + + for i, sel in enumerate(processed): + if isinstance(sel, np.ndarray): + array_dims.append(i) + arrays.append(sel) + else: + slice_dims.append(i) + + # Broadcast all arrays together + broadcast_arrays: list[np.ndarray[Any, np.dtype[np.intp]]] + if arrays: + broadcast_arrays = list(np.broadcast_arrays(*arrays)) + broadcast_shape = broadcast_arrays[0].shape + else: + broadcast_arrays = [] + broadcast_shape = () + + # Build new domain: broadcast dims first, then slice dims + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + + # Broadcast dimensions + for s in broadcast_shape: + new_inclusive_min.append(0) + new_exclusive_max.append(s) + + # Slice dimensions + slice_dim_params: dict[int, tuple[int, int, int]] = {} + for old_dim in slice_dims: + sel = processed[old_dim] + assert isinstance(sel, slice) + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + start, stop, step = sel.indices(dim_size) + if step <= 0: + raise IndexError("slice step must be positive") + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + abs_start = lo + start + slice_dim_params[old_dim] = (abs_start, stop, step) + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + # Build output maps + array_dim_to_broadcast: dict[int, np.ndarray[Any, np.dtype[np.intp]]] = {} 
+ for i, d in enumerate(array_dims): + array_dim_to_broadcast[d] = broadcast_arrays[i] + + # New dim index for slice dims starts after broadcast dims + n_broadcast_dims = len(broadcast_shape) + + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in array_dim_to_broadcast: + new_output.append( + ArrayMap( + index_array=array_dim_to_broadcast[d], + offset=m.offset, + stride=m.stride, + ) + ) + else: + # Slice dim + abs_start, _, step = slice_dim_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = n_broadcast_dims + slice_dims.index(d) + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + elif isinstance(m, ArrayMap): + new_arr = _reindex_array_oindex(m.index_array, processed, transform.domain) + new_output.append(ArrayMap(index_array=new_arr, offset=m.offset, stride=m.stride)) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +def _normalize_negative_indices(selection: Any, shape: tuple[int, ...]) -> Any: + """Convert negative indices to positive ones using the array shape. + + Only normalizes integer and array-like index components; leaves + slices, Ellipsis, None, etc. untouched. + """ + if not isinstance(selection, tuple): + selection_tuple: tuple[Any, ...] 
= (selection,)
+    else:
+        selection_tuple = selection
+
+    # Count real dimensions (non-None, non-Ellipsis) to map each entry to a shape dim
+    has_ellipsis = any(s is Ellipsis for s in selection_tuple)
+    n_non_newaxis = sum(1 for s in selection_tuple if s is not None and s is not Ellipsis)
+    # Dims the ellipsis stands for. NOTE(review): the previous expression added
+    # ``+ (1 if has_ellipsis else 0)``, over-advancing ``dim`` by one past the
+    # ellipsis so a negative index *after* ``...`` was never normalized; the
+    # sibling helpers (_normalize_basic_selection, _apply_vindex) add no +1.
+    n_ellipsis_dims = (len(shape) - n_non_newaxis) if has_ellipsis else 0
+
+    result: list[Any] = []
+    dim = 0
+
+    for sel in selection_tuple:
+        if sel is Ellipsis:
+            result.append(sel)
+            dim += max(0, n_ellipsis_dims)
+        elif sel is None:
+            result.append(sel)
+        elif isinstance(sel, (int, np.integer)) and not isinstance(sel, bool):
+            idx = int(sel)
+            if idx < 0 and dim < len(shape):
+                idx = idx + shape[dim]
+            result.append(idx)
+            dim += 1
+        elif isinstance(sel, np.ndarray) and sel.dtype != np.bool_:
+            arr = sel.copy()
+            if dim < len(shape):
+                arr = np.where(arr < 0, arr + shape[dim], arr)
+            result.append(arr)
+            dim += 1
+        elif isinstance(sel, list):
+            # Convert lists to arrays with negative index normalization
+            arr = np.asarray(sel, dtype=np.intp)
+            if dim < len(shape):
+                arr = np.where(arr < 0, arr + shape[dim], arr)
+            result.append(arr)
+            dim += 1
+        else:
+            # slice, bool array, or anything else: pass through
+            result.append(sel)
+            if sel is not None and sel is not Ellipsis:
+                dim += 1
+
+    if not isinstance(selection, tuple) and len(result) == 1:
+        return result[0]
+    return tuple(result)
+
+
+def _validate_array_selection(selection: Any, shape: tuple[int, ...], mode: str) -> None:
+    """Validate array-based selections (orthogonal, vectorized).
+
+    Rejects types that are not valid for coordinate/vectorized indexing.
+    Does not check bounds — the transform operations handle that.
+ """ + items = selection if isinstance(selection, tuple) else (selection,) + for sel in items: + if sel is Ellipsis or isinstance(sel, (int, np.integer, slice)): + continue + if isinstance(sel, (list, np.ndarray)): + continue + raise IndexError(f"unsupported selection type for {mode} indexing: {type(sel)!r}") + + +def _validate_basic_selection(selection: Any) -> None: + """Validate that a selection only contains basic indexing types (int, slice, Ellipsis). + + Rejects None (newaxis), arrays, lists, floats, strings, etc. + """ + items = selection if isinstance(selection, tuple) else (selection,) + for s in items: + if s is Ellipsis or isinstance(s, (int, np.integer, slice)): + continue + raise IndexError(f"unsupported selection type for basic indexing: {type(s)!r}") + + +def selection_to_transform( + selection: Any, + transform: IndexTransform, + mode: Literal["basic", "orthogonal", "vectorized"], +) -> IndexTransform: + """Convert a user selection into a composed IndexTransform. + + Negative indices are treated as literal coordinates (TensorStore convention). + The caller (Array layer) is responsible for converting numpy-style negative + indices before calling this function. 
+ """ + if mode == "basic": + _validate_basic_selection(selection) + return transform[selection] + elif mode == "orthogonal": + _validate_array_selection(selection, transform.domain.shape, mode) + return transform.oindex[selection] + elif mode == "vectorized": + _validate_array_selection(selection, transform.domain.shape, mode) + return transform.vindex[selection] + else: + raise ValueError(f"Unknown mode: {mode!r}") diff --git a/tests/test_array.py b/tests/test_array.py index f7f564f30e..396c09d28a 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1956,7 +1956,8 @@ def test_array_repr(store: Store) -> None: shape = (2, 3, 4) dtype = "uint8" arr = zarr.create_array(store, shape=shape, dtype=dtype) - assert str(arr) == f"" + domain = "{ [0, 2), [0, 3), [0, 4) }" + assert str(arr) == f"" class UnknownObjectDtype(UTF8Base[np.dtypes.ObjectDType]): diff --git a/tests/test_lazy_indexing.py b/tests/test_lazy_indexing.py new file mode 100644 index 0000000000..f27a242109 --- /dev/null +++ b/tests/test_lazy_indexing.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +import zarr +from zarr.storage import MemoryStore + + +@pytest.fixture +def arr() -> zarr.Array[Any]: + """Create a 2D array with known data.""" + store = MemoryStore() + a = zarr.create(shape=(20, 30), chunks=(5, 10), dtype="i4", store=store) + data = np.arange(600, dtype="i4").reshape(20, 30) + a[...] 
= data + return a + + +@pytest.fixture +def data() -> np.ndarray[Any, Any]: + return np.arange(600, dtype="i4").reshape(20, 30) + + +class TestEagerRead: + def test_basic_slice(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + result = arr[2:8, 5:15] + np.testing.assert_array_equal(result, data[2:8, 5:15]) + + def test_basic_int(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + result = arr[3] + np.testing.assert_array_equal(result, data[3]) + + def test_basic_int_scalar(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + result = arr[3, 5] + assert result == data[3, 5] + + def test_ellipsis(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + result = arr[...] + np.testing.assert_array_equal(result, data) + + def test_strided_slice(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + result = arr[::2, ::3] + np.testing.assert_array_equal(result, data[::2, ::3]) + + def test_oindex(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + idx = np.array([1, 5, 10], dtype=np.intp) + result = arr.oindex[idx, :] + np.testing.assert_array_equal(result, data[idx, :]) + + def test_vindex(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + idx0 = np.array([1, 5, 10], dtype=np.intp) + idx1 = np.array([2, 8, 15], dtype=np.intp) + result = arr.vindex[idx0, idx1] + np.testing.assert_array_equal(result, data[idx0, idx1]) + + def test_slice_across_chunks(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + """Slice that spans multiple chunks.""" + result = arr[3:17, 8:22] + np.testing.assert_array_equal(result, data[3:17, 8:22]) + + def test_single_element(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + result = arr[0:1, 0:1] + np.testing.assert_array_equal(result, data[0:1, 0:1]) + + def test_full_read(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + result = arr[:] + np.testing.assert_array_equal(result, data) + + +class 
TestEagerWrite: + def test_write_slice(self, arr: zarr.Array[Any]) -> None: + arr[2:5, 10:20] = np.ones((3, 10), dtype="i4") * 99 + result = arr[2:5, 10:20] + np.testing.assert_array_equal(result, np.ones((3, 10), dtype="i4") * 99) + + def test_write_scalar(self, arr: zarr.Array[Any]) -> None: + arr[0, 0] = 42 + assert arr[0, 0] == 42 + + def test_roundtrip(self, arr: zarr.Array[Any]) -> None: + new_data = np.random.randint(0, 100, size=(20, 30), dtype="i4") + arr[...] = new_data + np.testing.assert_array_equal(arr[...], new_data) + + def test_write_across_chunks(self, arr: zarr.Array[Any]) -> None: + """Write spanning multiple chunks.""" + val = np.ones((14, 14), dtype="i4") * 77 + arr[3:17, 8:22] = val + result = arr[3:17, 8:22] + np.testing.assert_array_equal(result, val) + + +class TestLazyRead: + def test_lazy_shape(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + v = arr.z[2:8, 5:15] + assert isinstance(v, zarr.Array) + assert v.shape == (6, 10) + + def test_lazy_resolve(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + v = arr.z[2:8, 5:15] + result = v[...] + np.testing.assert_array_equal(result, data[2:8, 5:15]) + + def test_lazy_np_asarray(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + v = arr.z[2:8] + result = np.asarray(v) + np.testing.assert_array_equal(result, data[2:8]) + + def test_lazy_composition(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + v = arr.z[2:12].z[3:8] + assert v.shape == (5, 30) + result = v[...] + np.testing.assert_array_equal(result, data[5:10]) + + def test_lazy_oindex(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + idx = np.array([1, 5, 10], dtype=np.intp) + v = arr.z.oindex[idx, :] + assert isinstance(v, zarr.Array) + assert v.shape == (3, 30) + result = v[...] 
+ np.testing.assert_array_equal(result, data[idx, :]) + + def test_lazy_vindex(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + idx0 = np.array([1, 5, 10], dtype=np.intp) + idx1 = np.array([2, 8, 15], dtype=np.intp) + v = arr.z.vindex[idx0, idx1] + assert isinstance(v, zarr.Array) + assert v.shape == (3,) + result = v[...] + np.testing.assert_array_equal(result, data[idx0, idx1]) + + def test_lazy_resolve_method(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + v = arr.z[2:8] + result = v.resolve() + np.testing.assert_array_equal(result, data[2:8]) + + def test_lazy_across_chunks(self, arr: zarr.Array[Any], data: np.ndarray[Any, Any]) -> None: + """Lazy slice spanning multiple chunks resolves correctly.""" + v = arr.z[3:17, 8:22] + result = v[...] + np.testing.assert_array_equal(result, data[3:17, 8:22]) + + +class TestLazyWrite: + def test_lazy_write(self, arr: zarr.Array[Any]) -> None: + arr.z[2:5, 10:20] = np.ones((3, 10), dtype="i4") * 99 + result = arr[2:5, 10:20] + np.testing.assert_array_equal(result, np.ones((3, 10), dtype="i4") * 99) + + def test_lazy_oindex_write(self, arr: zarr.Array[Any]) -> None: + idx = np.array([0, 5, 10], dtype=np.intp) + arr.z.oindex[idx, :] = np.zeros((3, 30), dtype="i4") + result = arr.oindex[idx, :] + np.testing.assert_array_equal(result, np.zeros((3, 30), dtype="i4")) + + def test_lazy_vindex_write(self, arr: zarr.Array[Any]) -> None: + idx0 = np.array([0, 5, 10], dtype=np.intp) + idx1 = np.array([0, 5, 10], dtype=np.intp) + arr.z.vindex[idx0, idx1] = np.array([77, 88, 99], dtype="i4") + result = arr.vindex[idx0, idx1] + np.testing.assert_array_equal(result, np.array([77, 88, 99], dtype="i4")) diff --git a/tests/test_transforms/__init__.py b/tests/test_transforms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py new file mode 100644 index 0000000000..63246a9caf --- /dev/null 
+++ b/tests/test_transforms/test_chunk_resolution.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import numpy as np + +from zarr.core.chunk_grids import ChunkGrid, FixedDimension +from zarr.core.transforms.chunk_resolution import iter_chunk_transforms, sub_transform_to_selections +from zarr.core.transforms.domain import IndexDomain +from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core.transforms.transform import IndexTransform + + +class TestChunkResolutionIdentity: + def test_single_chunk(self) -> None: + """Array fits in one chunk.""" + t = IndexTransform.from_shape((10,)) + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=10),)) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 1 + coords, sub_t, _ = results[0] + assert coords == (0,) + assert sub_t.domain.shape == (10,) + + def test_multiple_chunks_1d(self) -> None: + """1D array spanning 3 chunks.""" + t = IndexTransform.from_shape((30,)) + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=30),)) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 3 + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + assert (2,) in coords_list + + def test_multiple_chunks_2d(self) -> None: + """2D array spanning 2x3 chunks.""" + t = IndexTransform.from_shape((20, 30)) + grid = ChunkGrid( + dimensions=( + FixedDimension(size=10, extent=20), + FixedDimension(size=10, extent=30), + ) + ) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 6 + coords_list = [r[0] for r in results] + assert (0, 0) in coords_list + assert (1, 2) in coords_list + + +class TestChunkResolutionSliced: + def test_slice_within_chunk(self) -> None: + """Slice that falls within a single chunk.""" + t = IndexTransform.from_shape((100,))[5:8] + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=100),)) + results = list(iter_chunk_transforms(t, grid)) + assert 
len(results) == 1 + coords, sub_t, _ = results[0] + assert coords == (0,) + assert isinstance(sub_t.output[0], DimensionMap) + assert sub_t.output[0].offset == 5 + + def test_slice_across_chunks(self) -> None: + """Slice that spans two chunks.""" + t = IndexTransform.from_shape((100,))[8:15] + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=100),)) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 2 + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + + +class TestChunkResolutionConstant: + def test_integer_index(self) -> None: + """Integer index produces constant map — single chunk per constant dim.""" + t = IndexTransform.from_shape((100, 100))[25, :] + grid = ChunkGrid( + dimensions=( + FixedDimension(size=10, extent=100), + FixedDimension(size=10, extent=100), + ) + ) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 10 + for coords, _, _ in results: + assert coords[0] == 2 + + +class TestChunkResolutionArray: + def test_array_index(self) -> None: + """Array index map — chunks determined by array values.""" + idx = np.array([5, 15, 25], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=idx),), + ) + grid = ChunkGrid(dimensions=(FixedDimension(size=10, extent=30),)) + results = list(iter_chunk_transforms(t, grid)) + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + assert (2,) in coords_list + + +class TestSubTransformToSelections: + def test_constant_map(self) -> None: + """ConstantMap produces int selection + drop axis.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel == (5,) + assert out_sel == () + assert drop_axes == () + + def test_dimension_map_stride_1(self) -> None: + """DimensionMap with stride=1 produces 
contiguous slice.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=3, stride=1),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel == (slice(3, 13, 1),) + assert out_sel == (slice(0, 10),) + assert drop_axes == () + + def test_dimension_map_strided(self) -> None: + """DimensionMap with stride>1 produces strided slice.""" + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(DimensionMap(input_dimension=0, offset=2, stride=3),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel == (slice(2, 17, 3),) + assert out_sel == (slice(0, 5),) + assert drop_axes == () + + def test_array_map(self) -> None: + """ArrayMap produces integer array selection.""" + arr = np.array([1, 5, 9], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=0, stride=1),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], arr) + # Without chunk_mask, out_sel falls back to domain-based slices + assert out_sel == (slice(0, 3),) + assert drop_axes == () + + def test_array_map_with_offset_stride(self) -> None: + """ArrayMap with offset and stride computes storage coords.""" + arr = np.array([0, 1, 2], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=10, stride=5),), + ) + chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], np.array([10, 15, 20])) + assert drop_axes == () + + def test_mixed_maps_2d(self) -> None: + """Mix of ConstantMap and DimensionMap.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, 
stride=1), + ), + ) + chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) + assert chunk_sel[0] == 5 + assert chunk_sel[1] == slice(0, 10, 1) + # drop_axes is empty — integer in chunk_sel naturally drops the dim via numpy + assert drop_axes == () diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py new file mode 100644 index 0000000000..e96616933d --- /dev/null +++ b/tests/test_transforms/test_composition.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from zarr.core.transforms.composition import compose +from zarr.core.transforms.domain import IndexDomain +from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core.transforms.transform import IndexTransform + + +class TestComposeConstantInner: + """Inner = constant. Result is always constant.""" + + def test_constant_inner_any_outer(self) -> None: + outer = IndexTransform.from_shape((5,)) + inner = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=42),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 42 + + +class TestComposeDimensionInner: + """Inner = DimensionMap.""" + + def test_dimension_inner_constant_outer(self) -> None: + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=3),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 25 + + def test_dimension_inner_dimension_outer(self) -> None: + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10,)), + 
output=(DimensionMap(input_dimension=0, offset=10, stride=3),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 25 + assert result.output[0].stride == 6 + assert result.output[0].input_dimension == 0 + + def test_dimension_inner_array_outer(self) -> None: + arr = np.array([0, 2, 4], dtype=np.intp) + outer = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, offset=5, stride=2),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=3),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ArrayMap) + assert result.output[0].offset == 25 + assert result.output[0].stride == 6 + np.testing.assert_array_equal(result.output[0].index_array, arr) + + +class TestComposeArrayInner: + """Inner = ArrayMap.""" + + def test_array_inner_constant_outer(self) -> None: + inner_arr = np.array([10, 20, 30], dtype=np.intp) + outer = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=1),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=inner_arr, offset=0, stride=1),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 20 + + def test_array_inner_array_outer(self) -> None: + outer_arr = np.array([0, 2, 1], dtype=np.intp) + inner_arr = np.array([10, 20, 30], dtype=np.intp) + outer = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=outer_arr, offset=0, stride=1),), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=inner_arr, offset=0, stride=1),), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ArrayMap) + expected = np.array([10, 30, 20], dtype=np.intp) + 
np.testing.assert_array_equal(result.output[0].index_array, expected) + + +class TestComposeMultiDim: + def test_2d_identity_compose(self) -> None: + a = IndexTransform.from_shape((10, 20)) + b = IndexTransform.from_shape((10, 20)) + result = compose(a, b) + assert result.domain.shape == (10, 20) + for i in range(2): + m = result.output[i] + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + def test_mixed_map_types(self) -> None: + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10, 10)), + output=( + DimensionMap(input_dimension=0, offset=2, stride=3), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 17 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 0 + assert result.output[1].offset == 0 + assert result.output[1].stride == 1 + + def test_rank_mismatch_raises(self) -> None: + outer = IndexTransform.from_shape((10,)) + inner = IndexTransform.from_shape((10, 20)) + with pytest.raises(ValueError, match="rank"): + compose(outer, inner) + + +class TestComposeChain: + def test_three_transforms(self) -> None: + a = IndexTransform.from_shape((100,)) + b = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=1),), + ) + c = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), + ) + bc = compose(b, c) + abc = compose(a, bc) + assert isinstance(abc.output[0], DimensionMap) + assert abc.output[0].offset == 25 + assert abc.output[0].stride == 2 diff --git a/tests/test_transforms/test_domain.py 
b/tests/test_transforms/test_domain.py new file mode 100644 index 0000000000..5a222a548c --- /dev/null +++ b/tests/test_transforms/test_domain.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import pytest + +from zarr.core.transforms.domain import IndexDomain + + +class TestIndexDomainConstruction: + def test_from_shape(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.inclusive_min == (0, 0) + assert d.exclusive_max == (10, 20) + assert d.ndim == 2 + assert d.origin == (0, 0) + assert d.shape == (10, 20) + + def test_from_shape_0d(self) -> None: + d = IndexDomain.from_shape(()) + assert d.ndim == 0 + assert d.shape == () + + def test_non_zero_origin(self) -> None: + d = IndexDomain(inclusive_min=(5, 10), exclusive_max=(15, 30)) + assert d.origin == (5, 10) + assert d.shape == (10, 20) + assert d.ndim == 2 + + def test_validation_mismatched_lengths(self) -> None: + with pytest.raises(ValueError, match="same length"): + IndexDomain(inclusive_min=(0,), exclusive_max=(10, 20)) + + def test_validation_min_greater_than_max(self) -> None: + with pytest.raises(ValueError, match="inclusive_min must be <="): + IndexDomain(inclusive_min=(10,), exclusive_max=(5,)) + + def test_empty_domain(self) -> None: + d = IndexDomain(inclusive_min=(5,), exclusive_max=(5,)) + assert d.shape == (0,) + + def test_labels(self) -> None: + d = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 20), labels=("x", "y")) + assert d.labels == ("x", "y") + + def test_labels_none(self) -> None: + d = IndexDomain.from_shape((10,)) + assert d.labels is None + + +class TestIndexDomainContains: + def test_contains_inside(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.contains((0, 0)) is True + assert d.contains((9, 19)) is True + assert d.contains((5, 10)) is True + + def test_contains_outside(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.contains((10, 0)) is False + assert d.contains((-1, 0)) is False + assert d.contains((0, 20)) is False 
+ + def test_contains_non_zero_origin(self) -> None: + d = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + assert d.contains((5,)) is True + assert d.contains((9,)) is True + assert d.contains((4,)) is False + assert d.contains((10,)) is False + + def test_contains_wrong_ndim(self) -> None: + d = IndexDomain.from_shape((10, 20)) + assert d.contains((5,)) is False + + def test_contains_domain_inside(self) -> None: + outer = IndexDomain.from_shape((10, 20)) + inner = IndexDomain(inclusive_min=(2, 3), exclusive_max=(8, 15)) + assert outer.contains_domain(inner) is True + + def test_contains_domain_outside(self) -> None: + outer = IndexDomain.from_shape((10, 20)) + inner = IndexDomain(inclusive_min=(2, 3), exclusive_max=(11, 15)) + assert outer.contains_domain(inner) is False + + def test_contains_domain_wrong_ndim(self) -> None: + outer = IndexDomain.from_shape((10, 20)) + inner = IndexDomain.from_shape((5,)) + assert outer.contains_domain(inner) is False + + +class TestIndexDomainIntersect: + def test_overlapping(self) -> None: + a = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 10)) + b = IndexDomain(inclusive_min=(5, 5), exclusive_max=(15, 15)) + result = a.intersect(b) + assert result is not None + assert result.inclusive_min == (5, 5) + assert result.exclusive_max == (10, 10) + + def test_disjoint(self) -> None: + a = IndexDomain(inclusive_min=(0,), exclusive_max=(5,)) + b = IndexDomain(inclusive_min=(10,), exclusive_max=(15,)) + assert a.intersect(b) is None + + def test_touching_boundary(self) -> None: + a = IndexDomain(inclusive_min=(0,), exclusive_max=(5,)) + b = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + assert a.intersect(b) is None + + def test_contained(self) -> None: + a = IndexDomain.from_shape((20,)) + b = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + result = a.intersect(b) + assert result is not None + assert result.inclusive_min == (5,) + assert result.exclusive_max == (10,) + + def test_wrong_ndim(self) -> None: + 
a = IndexDomain.from_shape((10,)) + b = IndexDomain.from_shape((10, 20)) + with pytest.raises(ValueError, match="different ranks"): + a.intersect(b) + + +class TestIndexDomainTranslate: + def test_translate_positive(self) -> None: + d = IndexDomain.from_shape((10, 20)) + result = d.translate((5, 10)) + assert result.inclusive_min == (5, 10) + assert result.exclusive_max == (15, 30) + + def test_translate_negative(self) -> None: + d = IndexDomain(inclusive_min=(10, 20), exclusive_max=(30, 40)) + result = d.translate((-10, -20)) + assert result.inclusive_min == (0, 0) + assert result.exclusive_max == (20, 20) + + def test_translate_wrong_length(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(ValueError, match="same length"): + d.translate((1, 2)) + + +class TestIndexDomainNarrow: + def test_narrow_slice(self) -> None: + d = IndexDomain.from_shape((10, 20)) + result = d.narrow((slice(2, 8), slice(5, 15))) + assert result.inclusive_min == (2, 5) + assert result.exclusive_max == (8, 15) + + def test_narrow_int(self) -> None: + d = IndexDomain.from_shape((10, 20)) + result = d.narrow((3, slice(None))) + assert result.inclusive_min == (3, 0) + assert result.exclusive_max == (4, 20) + + def test_narrow_ellipsis(self) -> None: + d = IndexDomain.from_shape((10, 20, 30)) + result = d.narrow((slice(1, 5), ...)) + assert result.inclusive_min == (1, 0, 0) + assert result.exclusive_max == (5, 20, 30) + + def test_narrow_slice_none(self) -> None: + d = IndexDomain.from_shape((10,)) + result = d.narrow((slice(None),)) + assert result == d + + def test_narrow_non_zero_origin(self) -> None: + d = IndexDomain(inclusive_min=(10,), exclusive_max=(20,)) + result = d.narrow((slice(12, 18),)) + assert result.inclusive_min == (12,) + assert result.exclusive_max == (18,) + + def test_narrow_int_out_of_bounds(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(IndexError, match="out of bounds"): + d.narrow((10,)) + + def 
test_narrow_int_below_origin(self) -> None: + d = IndexDomain(inclusive_min=(5,), exclusive_max=(10,)) + with pytest.raises(IndexError, match="out of bounds"): + d.narrow((4,)) + + def test_narrow_clamps_to_domain(self) -> None: + d = IndexDomain.from_shape((10,)) + result = d.narrow((slice(-5, 100),)) + assert result.inclusive_min == (0,) + assert result.exclusive_max == (10,) + + def test_narrow_bare_slice(self) -> None: + d = IndexDomain.from_shape((10,)) + result = d.narrow(slice(2, 8)) + assert result.inclusive_min == (2,) + assert result.exclusive_max == (8,) + + def test_narrow_too_many_indices(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(IndexError, match="too many indices"): + d.narrow((1, 2)) + + def test_narrow_step_not_one(self) -> None: + d = IndexDomain.from_shape((10,)) + with pytest.raises(IndexError, match="step=1"): + d.narrow((slice(0, 10, 2),)) diff --git a/tests/test_transforms/test_output_map.py b/tests/test_transforms/test_output_map.py new file mode 100644 index 0000000000..57dbebaf44 --- /dev/null +++ b/tests/test_transforms/test_output_map.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import numpy as np + +from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap + + +class TestConstantMap: + def test_construction(self) -> None: + m = ConstantMap(offset=42) + assert m.offset == 42 + + def test_default_offset(self) -> None: + m = ConstantMap() + assert m.offset == 0 + + def test_frozen(self) -> None: + m = ConstantMap(offset=5) + assert isinstance(m, ConstantMap) + + +class TestDimensionMap: + def test_construction(self) -> None: + m = DimensionMap(input_dimension=3, offset=5, stride=2) + assert m.input_dimension == 3 + assert m.offset == 5 + assert m.stride == 2 + + def test_defaults(self) -> None: + m = DimensionMap(input_dimension=0) + assert m.offset == 0 + assert m.stride == 1 + + def test_frozen(self) -> None: + m = DimensionMap(input_dimension=0) + assert isinstance(m, 
DimensionMap) + + +class TestArrayMap: + def test_construction(self) -> None: + arr = np.array([1, 3, 5], dtype=np.intp) + m = ArrayMap(index_array=arr, offset=10, stride=2) + assert m.offset == 10 + assert m.stride == 2 + np.testing.assert_array_equal(m.index_array, arr) + + def test_defaults(self) -> None: + arr = np.array([0, 1], dtype=np.intp) + m = ArrayMap(index_array=arr) + assert m.offset == 0 + assert m.stride == 1 + + def test_frozen(self) -> None: + arr = np.array([0], dtype=np.intp) + m = ArrayMap(index_array=arr) + assert isinstance(m, ArrayMap) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py new file mode 100644 index 0000000000..f531714045 --- /dev/null +++ b/tests/test_transforms/test_transform.py @@ -0,0 +1,516 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from zarr.core.transforms.domain import IndexDomain +from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core.transforms.transform import IndexTransform, selection_to_transform + + +class TestIndexTransformConstruction: + def test_from_shape(self) -> None: + t = IndexTransform.from_shape((10, 20)) + assert t.input_rank == 2 + assert t.output_rank == 2 + assert t.domain.shape == (10, 20) + assert t.domain.origin == (0, 0) + for i, m in enumerate(t.output): + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + def test_identity(self) -> None: + domain = IndexDomain(inclusive_min=(5,), exclusive_max=(15,)) + t = IndexTransform.identity(domain) + assert t.input_rank == 1 + assert t.output_rank == 1 + assert t.domain == domain + assert isinstance(t.output[0], DimensionMap) + assert t.output[0].input_dimension == 0 + + def test_from_shape_0d(self) -> None: + t = IndexTransform.from_shape(()) + assert t.input_rank == 0 + assert t.output_rank == 0 + assert t.domain.shape == () + + def test_custom_output_maps(self) -> None: 
+ domain = IndexDomain.from_shape((10,)) + maps = (ConstantMap(offset=42), DimensionMap(input_dimension=0, offset=5, stride=2)) + t = IndexTransform(domain=domain, output=maps) + assert t.input_rank == 1 + assert t.output_rank == 2 + + def test_validation_input_dimension_out_of_range(self) -> None: + domain = IndexDomain.from_shape((10,)) + maps = (DimensionMap(input_dimension=5),) + with pytest.raises(ValueError, match="input_dimension"): + IndexTransform(domain=domain, output=maps) + + +class TestIndexTransformBasicIndexing: + def test_slice_identity(self) -> None: + """slice(None) on identity transform is a no-op.""" + t = IndexTransform.from_shape((10, 20)) + result = t[slice(None), slice(None)] + assert result.domain.shape == (10, 20) + assert result.input_rank == 2 + assert result.output_rank == 2 + + def test_slice_narrows(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[2:8, 5:15] + assert result.domain.shape == (6, 10) + assert result.domain.origin == (0, 0) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 2 + assert result.output[0].stride == 1 + assert result.output[0].input_dimension == 0 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].offset == 5 + assert result.output[1].input_dimension == 1 + + def test_strided_slice(self) -> None: + t = IndexTransform.from_shape((10,)) + result = t[::2] + assert result.domain.shape == (5,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 0 + assert result.output[0].stride == 2 + + def test_strided_slice_with_start(self) -> None: + t = IndexTransform.from_shape((10,)) + result = t[1:9:3] + # indices: 1, 4, 7 -> 3 elements + assert result.domain.shape == (3,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 1 + assert result.output[0].stride == 3 + + def test_int_drops_dimension(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[3] + 
assert result.input_rank == 1 + assert result.output_rank == 2 + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 3 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 0 + + def test_int_middle_dimension(self) -> None: + t = IndexTransform.from_shape((10, 20, 30)) + result = t[:, 5, :] + assert result.input_rank == 2 + assert result.output_rank == 3 + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].input_dimension == 0 + assert isinstance(result.output[1], ConstantMap) + assert result.output[1].offset == 5 + assert isinstance(result.output[2], DimensionMap) + assert result.output[2].input_dimension == 1 + + def test_ellipsis(self) -> None: + t = IndexTransform.from_shape((10, 20, 30)) + result = t[2:8, ...] + assert result.input_rank == 3 + assert result.domain.shape == (6, 20, 30) + + def test_newaxis(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[np.newaxis, :, :] + assert result.input_rank == 3 + assert result.domain.shape == (1, 10, 20) + assert result.output_rank == 2 + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].input_dimension == 1 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 2 + + def test_int_out_of_bounds(self) -> None: + t = IndexTransform.from_shape((10,)) + with pytest.raises(IndexError): + t[10] + + def test_negative_int_is_literal(self) -> None: + """Negative indices are literal coordinates (TensorStore convention), + not 'from the end' like NumPy.""" + t = IndexTransform.from_shape((10,)) + with pytest.raises(IndexError): + t[-1] # -1 is out of bounds for domain [0, 10) + + def test_negative_int_valid_with_negative_origin(self) -> None: + """Negative index is valid if the domain includes negative coordinates.""" + domain = IndexDomain(inclusive_min=(-5,), exclusive_max=(5,)) + t = IndexTransform.identity(domain) + result = t[-3] + assert 
isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == -3 + + def test_composition_of_slices(self) -> None: + """Slicing a sliced transform should compose offsets.""" + t = IndexTransform.from_shape((100,)) + result = t[10:50][5:20] + assert result.domain.shape == (15,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == 15 + assert result.output[0].stride == 1 + + def test_composition_of_strides(self) -> None: + t = IndexTransform.from_shape((100,)) + result = t[::2][::3] + # t[::2] -> shape (50,), offset=0, stride=2 + # [::3] -> shape ceil(50/3)=17, offset=0, stride=2*3=6 + assert result.domain.shape == (17,) + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].stride == 6 + + def test_bare_int(self) -> None: + """Non-tuple selection.""" + t = IndexTransform.from_shape((10, 20)) + result = t[3] + assert result.input_rank == 1 + + def test_bare_slice(self) -> None: + t = IndexTransform.from_shape((10, 20)) + result = t[2:8] + assert result.domain.shape == (6, 20) + + +class TestBasicIndexingOnArrayMaps: + """When a transform already has ArrayMap outputs, basic indexing must + apply the corresponding operation to the index_array's axes.""" + + def test_int_on_array_map_drops_axis(self) -> None: + """Integer index on a dimension referenced by an ArrayMap should + index into the array on that axis.""" + arr = np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp) + # 2D input domain (3, 2), one ArrayMap output + t = IndexTransform( + domain=IndexDomain.from_shape((3, 2)), + output=(ArrayMap(index_array=arr),), + ) + # Index with int on dim 0 -> pick row 1 -> arr[1, :] = [30, 40] + result = t[1] + assert result.input_rank == 1 + assert result.domain.shape == (2,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([30, 40])) + + def test_slice_on_array_map(self) -> None: + """Slice on a dimension referenced by an 
class TestBasicIndexingOnArrayMaps:
    """Basic indexing on a transform with ArrayMap outputs must apply the
    same operation to the matching axes of the stored index arrays."""

    def test_int_on_array_map_drops_axis(self) -> None:
        """An integer on a dimension an ArrayMap depends on indexes that
        axis of the index array."""
        lookup = np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp)
        # 2D input domain (3, 2) feeding a single ArrayMap output.
        tr = IndexTransform(
            domain=IndexDomain.from_shape((3, 2)),
            output=(ArrayMap(index_array=lookup),),
        )
        # Integer on dim 0 selects row 1: lookup[1, :] == [30, 40].
        res = tr[1]
        assert res.input_rank == 1
        assert res.domain.shape == (2,)
        out0 = res.output[0]
        assert isinstance(out0, ArrayMap)
        np.testing.assert_array_equal(out0.index_array, np.array([30, 40]))

    def test_slice_on_array_map(self) -> None:
        """A slice on an ArrayMap-backed dimension slices the index array."""
        lookup = np.array([10, 20, 30, 40, 50], dtype=np.intp)
        tr = IndexTransform(
            domain=IndexDomain.from_shape((5,)),
            output=(ArrayMap(index_array=lookup),),
        )
        res = tr[1:4]
        assert res.domain.shape == (3,)
        out0 = res.output[0]
        assert isinstance(out0, ArrayMap)
        np.testing.assert_array_equal(out0.index_array, np.array([20, 30, 40]))

    def test_strided_slice_on_array_map(self) -> None:
        """A strided slice strides the index array."""
        lookup = np.array([10, 20, 30, 40, 50], dtype=np.intp)
        tr = IndexTransform(
            domain=IndexDomain.from_shape((5,)),
            output=(ArrayMap(index_array=lookup),),
        )
        res = tr[::2]
        assert res.domain.shape == (3,)
        out0 = res.output[0]
        assert isinstance(out0, ArrayMap)
        np.testing.assert_array_equal(out0.index_array, np.array([10, 30, 50]))

    def test_newaxis_on_array_map(self) -> None:
        """np.newaxis inserts a matching axis into the index array."""
        lookup = np.array([10, 20, 30], dtype=np.intp)
        tr = IndexTransform(
            domain=IndexDomain.from_shape((3,)),
            output=(ArrayMap(index_array=lookup),),
        )
        res = tr[np.newaxis, :]
        assert res.input_rank == 2
        assert res.domain.shape == (1, 3)
        out0 = res.output[0]
        assert isinstance(out0, ArrayMap)
        assert out0.index_array.shape == (1, 3)
        np.testing.assert_array_equal(out0.index_array, np.array([[10, 20, 30]]))

    def test_int_drops_one_of_two_array_dims(self) -> None:
        """2D index array: integer on dim 0 combined with a slice on dim 1."""
        lookup = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.intp)
        tr = IndexTransform(
            domain=IndexDomain.from_shape((2, 3)),
            output=(ArrayMap(index_array=lookup),),
        )
        res = tr[0, 1:3]
        assert res.input_rank == 1
        assert res.domain.shape == (2,)
        out0 = res.output[0]
        assert isinstance(out0, ArrayMap)
        # lookup[0, 1:3] == [20, 30]
        np.testing.assert_array_equal(out0.index_array, np.array([20, 30]))


class TestIndexTransformOindex:
    def test_oindex_int_array(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        picks = np.array([1, 3, 5], dtype=np.intp)
        res = tr.oindex[picks, :]
        assert res.input_rank == 2
        assert res.domain.shape == (3, 20)
        out0, out1 = res.output
        assert isinstance(out0, ArrayMap)
        np.testing.assert_array_equal(out0.index_array, picks)
        assert out0.offset == 0
        assert out0.stride == 1
        assert isinstance(out1, DimensionMap)
        assert out1.input_dimension == 1

    def test_oindex_bool_array(self) -> None:
        tr = IndexTransform.from_shape((5,))
        keep = np.array([True, False, True, False, True])
        res = tr.oindex[keep]
        assert res.domain.shape == (3,)
        out0 = res.output[0]
        assert isinstance(out0, ArrayMap)
        # Boolean masks are converted to integer coordinates.
        np.testing.assert_array_equal(out0.index_array, np.array([0, 2, 4], dtype=np.intp))

    def test_oindex_mixed(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        picks = np.array([2, 4], dtype=np.intp)
        res = tr.oindex[picks, 5:15]
        assert res.input_rank == 2
        assert res.domain.shape == (2, 10)
        out0, out1 = res.output
        assert isinstance(out0, ArrayMap)
        assert isinstance(out1, DimensionMap)
        assert out1.offset == 5

    def test_oindex_multiple_arrays(self) -> None:
        tr = IndexTransform.from_shape((10, 20, 30))
        picks0 = np.array([1, 3], dtype=np.intp)
        picks2 = np.array([5, 10, 15], dtype=np.intp)
        res = tr.oindex[picks0, :, picks2]
        assert res.input_rank == 3
        # Outer indexing keeps each array on its own dimension.
        assert res.domain.shape == (2, 20, 3)
        assert isinstance(res.output[0], ArrayMap)
        assert isinstance(res.output[1], DimensionMap)
        assert isinstance(res.output[2], ArrayMap)


class TestIndexTransformVindex:
    def test_vindex_single_array(self) -> None:
        tr = IndexTransform.from_shape((10,))
        picks = np.array([1, 3, 5], dtype=np.intp)
        res = tr.vindex[picks]
        assert res.input_rank == 1
        assert res.domain.shape == (3,)
        out0 = res.output[0]
        assert isinstance(out0, ArrayMap)
        np.testing.assert_array_equal(out0.index_array, picks)

    def test_vindex_broadcast(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        rows = np.array([[1, 2], [3, 4]], dtype=np.intp)
        cols = np.array([[10, 11], [12, 13]], dtype=np.intp)
        res = tr.vindex[rows, cols]
        assert res.input_rank == 2
        assert res.domain.shape == (2, 2)
        out0, out1 = res.output
        assert isinstance(out0, ArrayMap)
        assert isinstance(out1, ArrayMap)
        np.testing.assert_array_equal(out0.index_array, rows)
        np.testing.assert_array_equal(out1.index_array, cols)

    def test_vindex_with_slice(self) -> None:
        tr = IndexTransform.from_shape((10, 20, 30))
        picks = np.array([1, 3, 5], dtype=np.intp)
        res = tr.vindex[picks, :, :]
        assert res.input_rank == 3
        assert res.domain.shape == (3, 20, 30)
        assert isinstance(res.output[0], ArrayMap)

    def test_vindex_bool_mask(self) -> None:
        tr = IndexTransform.from_shape((5,))
        keep = np.array([True, False, True, False, True])
        res = tr.vindex[keep]
        assert res.domain.shape == (3,)
        assert isinstance(res.output[0], ArrayMap)

    def test_vindex_broadcast_different_shapes(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        rows = np.array([1, 2, 3], dtype=np.intp)
        cols = np.array([[10], [11]], dtype=np.intp)
        res = tr.vindex[rows, cols]
        assert res.input_rank == 2
        # (3,) broadcast against (2, 1) gives (2, 3).
        assert res.domain.shape == (2, 3)
class TestSelectionToTransform:
    def test_basic_slice(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        res = selection_to_transform((slice(2, 8), slice(5, 15)), tr, "basic")
        assert res.domain.shape == (6, 10)
        out0 = res.output[0]
        assert isinstance(out0, DimensionMap)
        assert out0.offset == 2

    def test_basic_int(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        res = selection_to_transform((3, slice(None)), tr, "basic")
        assert res.input_rank == 1
        out0 = res.output[0]
        assert isinstance(out0, ConstantMap)
        assert out0.offset == 3

    def test_basic_ellipsis(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        res = selection_to_transform(Ellipsis, tr, "basic")
        assert res.domain.shape == (10, 20)

    def test_orthogonal(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        picks = np.array([1, 3, 5], dtype=np.intp)
        res = selection_to_transform((picks, slice(None)), tr, "orthogonal")
        assert res.domain.shape == (3, 20)
        assert isinstance(res.output[0], ArrayMap)

    def test_vectorized(self) -> None:
        tr = IndexTransform.from_shape((10, 20))
        rows = np.array([1, 3], dtype=np.intp)
        cols = np.array([5, 7], dtype=np.intp)
        res = selection_to_transform((rows, cols), tr, "vectorized")
        # Vectorized indexing pairs the arrays element-wise.
        assert res.domain.shape == (2,)
        assert isinstance(res.output[0], ArrayMap)
        assert isinstance(res.output[1], ArrayMap)

    def test_composition_with_non_identity(self) -> None:
        """Applying a selection to an already-sliced transform folds offsets."""
        sliced = IndexTransform.from_shape((100,))[10:50]
        res = selection_to_transform(slice(5, 20), sliced, "basic")
        assert res.domain.shape == (15,)
        out0 = res.output[0]
        assert isinstance(out0, DimensionMap)
        assert out0.offset == 15


class TestIndexTransformIntersect:
    def test_constant_inside(self) -> None:
        tr = IndexTransform(
            domain=IndexDomain.from_shape((10,)),
            output=(ConstantMap(offset=5),),
        )
        hit = tr.intersect(IndexDomain(inclusive_min=(0,), exclusive_max=(10,)))
        assert hit is not None
        restricted, surviving = hit
        out0 = restricted.output[0]
        assert isinstance(out0, ConstantMap)
        assert out0.offset == 5
        assert surviving is None

    def test_constant_outside(self) -> None:
        tr = IndexTransform(
            domain=IndexDomain.from_shape((10,)),
            output=(ConstantMap(offset=5),),
        )
        assert tr.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None

    def test_dimension_partial(self) -> None:
        """Identity over [0,10) against chunk [5,15) keeps inputs [5,10)."""
        tr = IndexTransform.from_shape((10,))
        hit = tr.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(15,)))
        assert hit is not None
        restricted, surviving = hit
        assert restricted.domain.inclusive_min == (5,)
        assert restricted.domain.exclusive_max == (10,)
        assert surviving is None

    def test_dimension_no_overlap(self) -> None:
        tr = IndexTransform.from_shape((10,))
        assert tr.intersect(IndexDomain(inclusive_min=(20,), exclusive_max=(30,))) is None

    def test_dimension_strided(self) -> None:
        """stride=2, offset=1 over [0,5): storage 1,3,5,7,9. Chunk [4,8)."""
        tr = IndexTransform(
            domain=IndexDomain.from_shape((5,)),
            output=(DimensionMap(input_dimension=0, offset=1, stride=2),),
        )
        hit = tr.intersect(IndexDomain(inclusive_min=(4,), exclusive_max=(8,)))
        assert hit is not None
        restricted, _surviving = hit
        # Inputs 2 -> 5 and 3 -> 7 land inside [4,8); nothing else does.
        assert restricted.domain.inclusive_min == (2,)
        assert restricted.domain.exclusive_max == (4,)

    def test_array_partial(self) -> None:
        lookup = np.array([3, 8, 15, 22], dtype=np.intp)
        tr = IndexTransform(
            domain=IndexDomain.from_shape((4,)),
            output=(ArrayMap(index_array=lookup),),
        )
        hit = tr.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(20,)))
        assert hit is not None
        restricted, surviving = hit
        out0 = restricted.output[0]
        assert isinstance(out0, ArrayMap)
        np.testing.assert_array_equal(out0.index_array, np.array([8, 15]))
        # `surviving` reports which input positions made it through.
        assert surviving is not None
        np.testing.assert_array_equal(surviving, np.array([1, 2]))

    def test_array_none_inside(self) -> None:
        tr = IndexTransform(
            domain=IndexDomain.from_shape((3,)),
            output=(ArrayMap(index_array=np.array([1, 2, 3], dtype=np.intp)),),
        )
        assert tr.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None

    def test_2d_mixed(self) -> None:
        """Rank-2 output: ConstantMap on dim 0, DimensionMap on dim 1."""
        tr = IndexTransform(
            domain=IndexDomain.from_shape((10,)),
            output=(
                ConstantMap(offset=5),
                DimensionMap(input_dimension=0, offset=0, stride=1),
            ),
        )
        chunk = IndexDomain(inclusive_min=(0, 5), exclusive_max=(10, 15))
        hit = tr.intersect(chunk)
        assert hit is not None
        restricted, _ = hit
        out0 = restricted.output[0]
        assert isinstance(out0, ConstantMap)
        assert out0.offset == 5
        assert isinstance(restricted.output[1], DimensionMap)
        assert restricted.domain.inclusive_min == (5,)
        assert restricted.domain.exclusive_max == (10,)


class TestIndexTransformTranslate:
    def test_translate_constant(self) -> None:
        tr = IndexTransform(
            domain=IndexDomain.from_shape((10,)),
            output=(ConstantMap(offset=5),),
        )
        moved = tr.translate((-5,))
        out0 = moved.output[0]
        assert isinstance(out0, ConstantMap)
        assert out0.offset == 0

    def test_translate_dimension(self) -> None:
        moved = IndexTransform.from_shape((10,)).translate((-3,))
        out0 = moved.output[0]
        assert isinstance(out0, DimensionMap)
        assert out0.offset == -3
        assert out0.stride == 1

    def test_translate_array(self) -> None:
        lookup = np.array([5, 10], dtype=np.intp)
        tr = IndexTransform(
            domain=IndexDomain.from_shape((2,)),
            output=(ArrayMap(index_array=lookup, offset=3),),
        )
        moved = tr.translate((-3,))
        out0 = moved.output[0]
        assert isinstance(out0, ArrayMap)
        # Only the affine offset moves; the index array stays put.
        assert out0.offset == 0
        np.testing.assert_array_equal(out0.index_array, lookup)

    def test_translate_2d(self) -> None:
        moved = IndexTransform.from_shape((10, 20)).translate((-5, -10))
        out0, out1 = moved.output
        assert isinstance(out0, DimensionMap)
        assert out0.offset == -5
        assert isinstance(out1, DimensionMap)
        assert out1.offset == -10
IndexDomain, OutputIndexMap, and IndexTransform to/from JSON. The JSON format follows TensorStore's conventions for interoperability: - IndexDomain: input_inclusive_min, input_exclusive_max, input_labels - OutputIndexMap: offset + optional stride/input_dimension/index_array - IndexTransform: domain fields + output array TypedDicts: IndexDomainJSON, OutputIndexMapJSON, IndexTransformJSON Functions: index_domain_to_json, index_domain_from_json, index_transform_to_json, index_transform_from_json Co-Authored-By: Claude Opus 4.6 (1M context) --- src/zarr/core/transforms/__init__.py | 16 +++ src/zarr/core/transforms/json.py | 163 ++++++++++++++++++++++ tests/test_transforms/test_json.py | 199 +++++++++++++++++++++++++++ 3 files changed, 378 insertions(+) create mode 100644 src/zarr/core/transforms/json.py create mode 100644 tests/test_transforms/test_json.py diff --git a/src/zarr/core/transforms/__init__.py b/src/zarr/core/transforms/__init__.py index 530dd39cea..ec98e02d4d 100644 --- a/src/zarr/core/transforms/__init__.py +++ b/src/zarr/core/transforms/__init__.py @@ -16,6 +16,15 @@ from zarr.core.transforms.composition import compose from zarr.core.transforms.domain import IndexDomain +from zarr.core.transforms.json import ( + IndexDomainJSON, + IndexTransformJSON, + OutputIndexMapJSON, + index_domain_from_json, + index_domain_to_json, + index_transform_from_json, + index_transform_to_json, +) from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap from zarr.core.transforms.transform import IndexTransform @@ -24,7 +33,14 @@ "ConstantMap", "DimensionMap", "IndexDomain", + "IndexDomainJSON", "IndexTransform", + "IndexTransformJSON", "OutputIndexMap", + "OutputIndexMapJSON", "compose", + "index_domain_from_json", + "index_domain_to_json", + "index_transform_from_json", + "index_transform_to_json", ] diff --git a/src/zarr/core/transforms/json.py b/src/zarr/core/transforms/json.py new file mode 100644 index 0000000000..59a81cc1e8 --- 
"""JSON serialization for index transforms.

Defines TypedDict types matching TensorStore's JSON representation of
IndexTransform and IndexDomain, plus conversion functions.

The JSON format follows TensorStore's conventions for interoperability::

    {
        "input_inclusive_min": [0, 0],
        "input_exclusive_max": [100, 200],
        "input_labels": ["x", "y"],
        "output": [
            {"offset": 5},
            {"offset": 10, "stride": 2, "input_dimension": 1},
            {"offset": 0, "stride": 1, "index_array": [[1, 2, 0]]}
        ]
    }
"""

from __future__ import annotations

from typing import Required, TypedDict

import numpy as np

from zarr.core.transforms.domain import IndexDomain
from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap
from zarr.core.transforms.transform import IndexTransform

# ---------------------------------------------------------------------------
# TypedDict definitions (JSON shapes)
# ---------------------------------------------------------------------------


class IndexDomainJSON(TypedDict, total=False):
    """JSON representation of an IndexDomain.

    Bounds are always present; ``input_labels`` is emitted only for
    labeled domains so unlabeled domains round-trip to minimal JSON.
    """

    input_inclusive_min: Required[list[int]]
    input_exclusive_max: Required[list[int]]
    input_labels: list[str]


class OutputIndexMapJSON(TypedDict, total=False):
    """JSON representation of a single output index map.

    Exactly one of three forms:
    - ``{"offset": 5}`` — constant
    - ``{"offset": 0, "stride": 1, "input_dimension": 0}`` — dimension
    - ``{"offset": 0, "stride": 1, "index_array": [...]}`` — array

    ``input_dimension`` and ``index_array`` are mutually exclusive; a map
    carrying neither is a constant map.
    """

    offset: int
    stride: int
    input_dimension: int
    index_array: list[int] | list[list[int]]


class IndexTransformJSON(TypedDict, total=False):
    """JSON representation of an IndexTransform.

    Structurally a superset of ``IndexDomainJSON`` plus the ``output``
    list (one entry per output/storage dimension).
    """

    input_inclusive_min: Required[list[int]]
    input_exclusive_max: Required[list[int]]
    input_labels: list[str]
    output: Required[list[OutputIndexMapJSON]]


# ---------------------------------------------------------------------------
# IndexDomain serialization
# ---------------------------------------------------------------------------


def index_domain_to_json(domain: IndexDomain) -> IndexDomainJSON:
    """Convert an IndexDomain to its JSON representation.

    ``input_labels`` is included only when the domain carries labels.
    """
    result: IndexDomainJSON = {
        "input_inclusive_min": list(domain.inclusive_min),
        "input_exclusive_max": list(domain.exclusive_max),
    }
    if domain.labels is not None:
        result["input_labels"] = list(domain.labels)
    return result


def index_domain_from_json(data: IndexDomainJSON) -> IndexDomain:
    """Construct an IndexDomain from its JSON representation."""
    return IndexDomain(
        inclusive_min=tuple(data["input_inclusive_min"]),
        exclusive_max=tuple(data["input_exclusive_max"]),
        labels=tuple(data["input_labels"]) if "input_labels" in data else None,
    )


# ---------------------------------------------------------------------------
# OutputIndexMap serialization
# ---------------------------------------------------------------------------


def output_index_map_to_json(m: OutputIndexMap) -> OutputIndexMapJSON:
    """Convert an output index map to its JSON representation.

    ``stride`` is omitted when it equals the default of 1, matching
    TensorStore's minimal-JSON convention.

    Raises
    ------
    TypeError
        If ``m`` is not one of the three known map types.
    """
    if isinstance(m, ConstantMap):
        result: OutputIndexMapJSON = {"offset": m.offset}
        return result

    if isinstance(m, DimensionMap):
        result = {"offset": m.offset, "input_dimension": m.input_dimension}
        if m.stride != 1:
            result["stride"] = m.stride
        return result

    if isinstance(m, ArrayMap):
        result = {"offset": m.offset, "index_array": m.index_array.tolist()}
        if m.stride != 1:
            result["stride"] = m.stride
        return result

    raise TypeError(f"Unknown output map type: {type(m)}")


def output_index_map_from_json(data: OutputIndexMapJSON) -> OutputIndexMap:
    """Construct an output index map from its JSON representation.

    The map kind is inferred from the keys present: ``index_array`` ->
    ArrayMap, ``input_dimension`` -> DimensionMap, neither -> ConstantMap.
    ``offset`` and ``stride`` default to 0 and 1 respectively when absent.

    Raises
    ------
    ValueError
        If both ``index_array`` and ``input_dimension`` are present — a map
        must be exactly one of the three forms, so such input is ambiguous.
    """
    if "index_array" in data and "input_dimension" in data:
        raise ValueError(
            "Output index map cannot specify both 'index_array' and 'input_dimension'"
        )

    if "index_array" in data:
        return ArrayMap(
            index_array=np.asarray(data["index_array"], dtype=np.intp),
            offset=data.get("offset", 0),
            stride=data.get("stride", 1),
        )

    if "input_dimension" in data:
        return DimensionMap(
            input_dimension=data["input_dimension"],
            offset=data.get("offset", 0),
            stride=data.get("stride", 1),
        )

    # Constant map: only offset present.
    return ConstantMap(offset=data.get("offset", 0))


# ---------------------------------------------------------------------------
# IndexTransform serialization
# ---------------------------------------------------------------------------


def index_transform_to_json(transform: IndexTransform) -> IndexTransformJSON:
    """Convert an IndexTransform to its JSON representation.

    The domain fields are delegated to :func:`index_domain_to_json` so the
    domain and transform serializers cannot drift apart.
    """
    domain_json = index_domain_to_json(transform.domain)
    result: IndexTransformJSON = {
        "input_inclusive_min": domain_json["input_inclusive_min"],
        "input_exclusive_max": domain_json["input_exclusive_max"],
        "output": [output_index_map_to_json(m) for m in transform.output],
    }
    if "input_labels" in domain_json:
        result["input_labels"] = domain_json["input_labels"]
    return result


def index_transform_from_json(data: IndexTransformJSON) -> IndexTransform:
    """Construct an IndexTransform from its JSON representation.

    Domain parsing is delegated to :func:`index_domain_from_json`; the
    transform-level keys are a structural superset of the domain keys.
    """
    domain_json: IndexDomainJSON = {
        "input_inclusive_min": data["input_inclusive_min"],
        "input_exclusive_max": data["input_exclusive_max"],
    }
    if "input_labels" in data:
        domain_json["input_labels"] = data["input_labels"]
    domain = index_domain_from_json(domain_json)
    output = tuple(output_index_map_from_json(m) for m in data["output"])
    return IndexTransform(domain=domain, output=output)
"""Tests for JSON (de)serialization of index transforms.

Covers round-tripping of IndexDomain, the three output-index-map kinds
(constant / dimension / array), full IndexTransform round-trips, and
compatibility with TensorStore's documented JSON layout.
"""

from __future__ import annotations

import numpy as np

from zarr.core.transforms.domain import IndexDomain
from zarr.core.transforms.json import (
    IndexTransformJSON,
    index_domain_from_json,
    index_domain_to_json,
    index_transform_from_json,
    index_transform_to_json,
    output_index_map_from_json,
    output_index_map_to_json,
)
from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap
from zarr.core.transforms.transform import IndexTransform


class TestIndexDomainJSON:
    """Round-trip and shape checks for index_domain_to_json / _from_json."""

    def test_roundtrip(self) -> None:
        # Non-zero origin: both bounds must survive the round trip.
        domain = IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20))
        json = index_domain_to_json(domain)
        assert json == {"input_inclusive_min": [2, 5], "input_exclusive_max": [10, 20]}
        restored = index_domain_from_json(json)
        assert restored == domain

    def test_with_labels(self) -> None:
        domain = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 20), labels=("x", "y"))
        json = index_domain_to_json(domain)
        assert json["input_labels"] == ["x", "y"]
        restored = index_domain_from_json(json)
        assert restored.labels == ("x", "y")

    def test_without_labels(self) -> None:
        # Unlabeled domains must serialize WITHOUT the input_labels key.
        domain = IndexDomain.from_shape((5,))
        json = index_domain_to_json(domain)
        assert "input_labels" not in json
        restored = index_domain_from_json(json)
        assert restored.labels is None

    def test_zero_origin(self) -> None:
        domain = IndexDomain.from_shape((10, 20, 30))
        json = index_domain_to_json(domain)
        assert json == {
            "input_inclusive_min": [0, 0, 0],
            "input_exclusive_max": [10, 20, 30],
        }
        assert index_domain_from_json(json) == domain


class TestOutputIndexMapJSON:
    """Round-trip checks for each of the three output-map JSON forms."""

    def test_constant(self) -> None:
        m = ConstantMap(offset=42)
        json = output_index_map_to_json(m)
        assert json == {"offset": 42}
        restored = output_index_map_from_json(json)
        assert isinstance(restored, ConstantMap)
        assert restored.offset == 42

    def test_constant_zero(self) -> None:
        # offset=0 is still emitted explicitly — it disambiguates the form.
        m = ConstantMap(offset=0)
        json = output_index_map_to_json(m)
        assert json == {"offset": 0}
        restored = output_index_map_from_json(json)
        assert isinstance(restored, ConstantMap)
        assert restored.offset == 0

    def test_dimension(self) -> None:
        m = DimensionMap(input_dimension=1, offset=10, stride=3)
        json = output_index_map_to_json(m)
        assert json == {"offset": 10, "stride": 3, "input_dimension": 1}
        restored = output_index_map_from_json(json)
        assert isinstance(restored, DimensionMap)
        assert restored.input_dimension == 1
        assert restored.offset == 10
        assert restored.stride == 3

    def test_dimension_stride_1_omitted(self) -> None:
        """stride=1 is the default and should be omitted from JSON."""
        m = DimensionMap(input_dimension=0)
        json = output_index_map_to_json(m)
        assert "stride" not in json
        assert json == {"offset": 0, "input_dimension": 0}
        restored = output_index_map_from_json(json)
        assert isinstance(restored, DimensionMap)
        assert restored.stride == 1

    def test_array(self) -> None:
        arr = np.array([1, 5, 9], dtype=np.intp)
        m = ArrayMap(index_array=arr, offset=2, stride=3)
        json = output_index_map_to_json(m)
        assert json == {"offset": 2, "stride": 3, "index_array": [1, 5, 9]}
        restored = output_index_map_from_json(json)
        assert isinstance(restored, ArrayMap)
        np.testing.assert_array_equal(restored.index_array, arr)
        assert restored.offset == 2
        assert restored.stride == 3

    def test_array_stride_1_omitted(self) -> None:
        # Same stride-omission convention as DimensionMap.
        arr = np.array([0, 1, 2], dtype=np.intp)
        m = ArrayMap(index_array=arr)
        json = output_index_map_to_json(m)
        assert "stride" not in json
        restored = output_index_map_from_json(json)
        assert isinstance(restored, ArrayMap)
        assert restored.stride == 1

    def test_array_2d(self) -> None:
        # index_array may be multi-dimensional; tolist() nests accordingly.
        arr = np.array([[1, 2], [3, 4]], dtype=np.intp)
        m = ArrayMap(index_array=arr)
        json = output_index_map_to_json(m)
        assert json["index_array"] == [[1, 2], [3, 4]]
        restored = output_index_map_from_json(json)
        assert isinstance(restored, ArrayMap)
        np.testing.assert_array_equal(restored.index_array, arr)


class TestIndexTransformJSON:
    """Whole-transform round trips plus TensorStore format compatibility."""

    def test_identity(self) -> None:
        t = IndexTransform.from_shape((10, 20))
        json = index_transform_to_json(t)
        assert json == {
            "input_inclusive_min": [0, 0],
            "input_exclusive_max": [10, 20],
            "output": [
                {"offset": 0, "input_dimension": 0},
                {"offset": 0, "input_dimension": 1},
            ],
        }
        restored = index_transform_from_json(json)
        assert restored.domain == t.domain
        assert len(restored.output) == 2
        # Each output map must deserialize back to its original kind.
        for orig, rest in zip(t.output, restored.output, strict=True):
            assert type(orig) is type(rest)

    def test_sliced(self) -> None:
        # Strided slicing should survive serialization (offset + stride).
        t = IndexTransform.from_shape((100,))[10:50:2]
        json = index_transform_to_json(t)
        restored = index_transform_from_json(json)
        assert restored.domain.shape == t.domain.shape
        assert isinstance(restored.output[0], DimensionMap)
        orig = t.output[0]
        assert isinstance(orig, DimensionMap)
        assert restored.output[0].offset == orig.offset
        assert restored.output[0].stride == orig.stride

    def test_with_constant(self) -> None:
        # Integer indexing produces a ConstantMap on the dropped dimension.
        t = IndexTransform.from_shape((10, 20))[3]
        json = index_transform_to_json(t)
        restored = index_transform_from_json(json)
        assert isinstance(restored.output[0], ConstantMap)
        assert restored.output[0].offset == 3
        assert isinstance(restored.output[1], DimensionMap)

    def test_with_array(self) -> None:
        idx = np.array([1, 5, 9], dtype=np.intp)
        t = IndexTransform.from_shape((10, 20)).oindex[idx, :]
        json = index_transform_to_json(t)
        restored = index_transform_from_json(json)
        assert isinstance(restored.output[0], ArrayMap)
        np.testing.assert_array_equal(restored.output[0].index_array, idx)
        assert isinstance(restored.output[1], DimensionMap)

    def test_with_labels(self) -> None:
        domain = IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 20), labels=("x", "y"))
        t = IndexTransform.identity(domain)
        json = index_transform_to_json(t)
        assert json["input_labels"] == ["x", "y"]
        restored = index_transform_from_json(json)
        assert restored.domain.labels == ("x", "y")

    def test_tensorstore_compatible_format(self) -> None:
        """Verify the JSON matches TensorStore's format exactly."""
        # Hand-written JSON in TensorStore's documented layout must parse.
        json: IndexTransformJSON = {
            "input_inclusive_min": [0, 0, 0],
            "input_exclusive_max": [100, 200, 3],
            "input_labels": ["x", "y", "channel"],
            "output": [
                {"offset": 5},
                {"offset": 10, "stride": 2, "input_dimension": 1},
                {"offset": 0, "stride": 1, "index_array": [1, 2, 0]},
            ],
        }
        t = index_transform_from_json(json)
        assert t.domain.shape == (100, 200, 3)
        assert t.domain.labels == ("x", "y", "channel")
        assert isinstance(t.output[0], ConstantMap)
        assert t.output[0].offset == 5
        assert isinstance(t.output[1], DimensionMap)
        assert t.output[1].offset == 10
        assert t.output[1].stride == 2
        assert t.output[1].input_dimension == 1
        assert isinstance(t.output[2], ArrayMap)
        np.testing.assert_array_equal(t.output[2].index_array, [1, 2, 0])

        # Roundtrip
        json_rt = index_transform_to_json(t)
        t_rt = index_transform_from_json(json_rt)
        assert t_rt.domain == t.domain