Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
601 changes: 542 additions & 59 deletions src/zarr/core/array.py

Large diffs are not rendered by default.

46 changes: 46 additions & 0 deletions src/zarr/core/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Composable, lazy coordinate transforms for zarr array indexing.

This package implements TensorStore-inspired index transforms. The core idea:
every indexing operation (slicing, fancy indexing, etc.) produces a coordinate
mapping from user space to storage space. These mappings compose lazily — no
I/O until you explicitly read or write.

Key types:

- ``IndexDomain`` — a rectangular region of integer coordinates
- ``IndexTransform`` — maps input coordinates to storage coordinates
- ``ConstantMap``, ``DimensionMap``, ``ArrayMap`` — the three ways a single
output dimension can depend on the input (see ``output_map.py``)
- ``compose`` — chain two transforms into one
"""

from zarr.core.transforms.composition import compose
from zarr.core.transforms.domain import IndexDomain
from zarr.core.transforms.json import (
IndexDomainJSON,
IndexTransformJSON,
OutputIndexMapJSON,
index_domain_from_json,
index_domain_to_json,
index_transform_from_json,
index_transform_to_json,
)
from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap
from zarr.core.transforms.transform import IndexTransform

__all__ = [
"ArrayMap",
"ConstantMap",
"DimensionMap",
"IndexDomain",
"IndexDomainJSON",
"IndexTransform",
"IndexTransformJSON",
"OutputIndexMap",
"OutputIndexMapJSON",
"compose",
"index_domain_from_json",
"index_domain_to_json",
"index_transform_from_json",
"index_transform_to_json",
]
207 changes: 207 additions & 0 deletions src/zarr/core/transforms/chunk_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""Chunk resolution — mapping transforms to chunk-level I/O.

Given an ``IndexTransform`` (which coordinates a user wants to access) and a
``ChunkGrid`` (how storage is divided into chunks), chunk resolution answers:

For each chunk, which storage coordinates does this transform touch,
and where do those values land in the output buffer?

The algorithm is:

1. **Enumerate candidate chunks** — determine which chunks could possibly
be touched by the transform's output coordinate ranges.

2. **Intersect** — for each candidate chunk, call
``transform.intersect(chunk_domain)`` to restrict the transform to
coordinates within that chunk. If the intersection is empty, skip it.

3. **Translate** — shift the restricted transform to chunk-local coordinates
via ``transform.translate(-chunk_origin)``.

4. **Yield** — produce ``(chunk_coords, local_transform, surviving_indices)``
triples that the codec pipeline consumes.

``sub_transform_to_selections`` bridges from the transform representation
back to the raw ``(chunk_selection, out_selection, drop_axes)`` tuples that
the current codec pipeline expects. This bridge will go away when the codec
pipeline accepts transforms natively.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from zarr.core.transforms.domain import IndexDomain
from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap
from zarr.core.transforms.transform import IndexTransform

if TYPE_CHECKING:
from collections.abc import Iterator

from zarr.core.chunk_grids import ChunkGrid

ChunkTransformResult = tuple[
tuple[int, ...],
IndexTransform,
np.ndarray[Any, np.dtype[np.intp]] | None,
]


def iter_chunk_transforms(
transform: IndexTransform,
chunk_grid: ChunkGrid,
) -> Iterator[ChunkTransformResult]:
"""Resolve a composed IndexTransform against a ChunkGrid.

Yields ``(chunk_coords, sub_transform, out_indices)`` triples:

- ``chunk_coords``: which chunk to access.
- ``sub_transform``: maps output buffer coords to chunk-local coords.
- ``out_indices``: for vectorized/array indexing, the output scatter
indices (integer array). ``None`` for basic/slice indexing.
"""
dim_grids = chunk_grid._dimensions

# Enumerate all possible chunks via cartesian product of per-dim chunk ranges
# For each candidate chunk, intersect the transform with the chunk domain.
# The transform.intersect method handles both orthogonal and vectorized cases.
chunk_ranges: list[range] = []
for out_dim, m in enumerate(transform.output):
dg = dim_grids[out_dim]
if isinstance(m, ConstantMap):
# Single chunk
c = dg.index_to_chunk(m.offset)
chunk_ranges.append(range(c, c + 1))
elif isinstance(m, DimensionMap):
d = m.input_dimension
dim_lo = transform.domain.inclusive_min[d]
dim_hi = transform.domain.exclusive_max[d]
if dim_lo >= dim_hi:
return # empty domain
if m.stride > 0:
s_min = m.offset + m.stride * dim_lo
s_max = m.offset + m.stride * (dim_hi - 1)
else:
s_min = m.offset + m.stride * (dim_hi - 1)
s_max = m.offset + m.stride * dim_lo
first = dg.index_to_chunk(s_min)
last = dg.index_to_chunk(s_max)
chunk_ranges.append(range(first, last + 1))
elif isinstance(m, ArrayMap):
storage = m.offset + m.stride * m.index_array
flat = storage.ravel().astype(np.intp)
chunk_ids = dg.indices_to_chunks(flat)
first = int(chunk_ids.min())
last = int(chunk_ids.max())
chunk_ranges.append(range(first, last + 1))

import itertools

for chunk_coords_tuple in itertools.product(*chunk_ranges):
chunk_coords = tuple(int(c) for c in chunk_coords_tuple)

# Build the chunk domain in storage space
chunk_min: list[int] = []
chunk_max: list[int] = []
chunk_shift: list[int] = []
for out_dim, c in enumerate(chunk_coords):
dg = dim_grids[out_dim]
c_start = dg.chunk_offset(c)
c_size = dg.chunk_size(c)
chunk_min.append(c_start)
chunk_max.append(c_start + c_size)
chunk_shift.append(-c_start)

chunk_domain = IndexDomain(
inclusive_min=tuple(chunk_min),
exclusive_max=tuple(chunk_max),
)

# Intersect transform with chunk domain
result = transform.intersect(chunk_domain)
if result is None:
continue

restricted, surviving = result

# Translate to chunk-local coordinates
local = restricted.translate(tuple(chunk_shift))

yield (chunk_coords, local, surviving)


def sub_transform_to_selections(
sub_transform: IndexTransform,
out_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None,
) -> tuple[
tuple[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
tuple[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
tuple[int, ...],
]:
"""Convert a chunk-local sub-transform to raw selections for the codec pipeline.

Parameters
----------
sub_transform
A chunk-local IndexTransform (output maps already translated to
chunk-local coordinates).
out_indices
For vectorized indexing: the output scatter indices for this chunk.
None for orthogonal/basic indexing.

Returns
-------
tuple
``(chunk_selection, out_selection, drop_axes)``
"""
chunk_sel: list[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []
drop_axes: list[int] = []

for m in sub_transform.output:
if isinstance(m, ConstantMap):
chunk_sel.append(m.offset)
elif isinstance(m, DimensionMap):
dim_lo = sub_transform.domain.inclusive_min[m.input_dimension]
dim_hi = sub_transform.domain.exclusive_max[m.input_dimension]
start = m.offset + m.stride * dim_lo
stop = m.offset + m.stride * dim_hi
if m.stride < 0:
start, stop = stop + 1, start + 1
chunk_sel.append(slice(start, stop, m.stride))
elif isinstance(m, ArrayMap):
if m.offset == 0 and m.stride == 1:
chunk_sel.append(m.index_array)
else:
storage_coords = m.offset + m.stride * m.index_array
chunk_sel.append(storage_coords.astype(np.intp))

# Build out_sel: one entry per non-dropped output dim.
out_sel: list[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []

# Vectorized: multiple correlated ArrayMaps share one scatter index
is_vectorized = (
out_indices is not None
and sum(1 for m in sub_transform.output if isinstance(m, ArrayMap)) >= 2
)

if is_vectorized:
assert out_indices is not None
out_sel.append(out_indices)
else:
for m in sub_transform.output:
if isinstance(m, ConstantMap):
continue
if isinstance(m, DimensionMap):
lo = sub_transform.domain.inclusive_min[m.input_dimension]
hi = sub_transform.domain.exclusive_max[m.input_dimension]
out_sel.append(slice(lo, hi))
elif isinstance(m, ArrayMap):
if out_indices is not None:
# Orthogonal ArrayMap: out_indices has the surviving positions
out_sel.append(out_indices)
else:
out_sel.append(slice(0, len(m.index_array)))

return tuple(chunk_sel), tuple(out_sel), tuple(drop_axes)
113 changes: 113 additions & 0 deletions src/zarr/core/transforms/composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import numpy as np

from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap
from zarr.core.transforms.transform import IndexTransform


def compose(outer: IndexTransform, inner: IndexTransform) -> IndexTransform:
"""Compose two IndexTransforms.

``outer`` maps user coords (rank m) to intermediate coords (rank n).
``inner`` maps intermediate coords (rank n) to storage coords (rank p).
The result maps user coords (rank m) to storage coords (rank p).

Precondition: ``outer.output_rank == inner.domain.ndim``.
"""
if outer.output_rank != inner.domain.ndim:
raise ValueError(
f"outer output rank ({outer.output_rank}) must match inner input rank "
f"({inner.domain.ndim})"
)

result_output = [_compose_single(outer, inner_map) for inner_map in inner.output]

return IndexTransform(domain=outer.domain, output=tuple(result_output))


def _compose_single(outer: IndexTransform, inner_map: OutputIndexMap) -> OutputIndexMap:
"""Compose a single inner output map with the full outer transform."""
if isinstance(inner_map, ConstantMap):
return ConstantMap(offset=inner_map.offset)

if isinstance(inner_map, DimensionMap):
return _compose_dimension(outer, inner_map)

if isinstance(inner_map, ArrayMap):
return _compose_array(outer, inner_map)

raise TypeError(f"Unknown output map type: {type(inner_map)}") # pragma: no cover


def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> OutputIndexMap:
"""Compose when inner is a DimensionMap.

storage = offset_i + stride_i * intermediate[dim_i]
where intermediate[dim_i] = outer.output[dim_i](user_input)
"""
dim_i = inner_map.input_dimension
offset_i = inner_map.offset
stride_i = inner_map.stride
outer_map = outer.output[dim_i]

if isinstance(outer_map, ConstantMap):
return ConstantMap(offset=offset_i + stride_i * outer_map.offset)

if isinstance(outer_map, DimensionMap):
return DimensionMap(
input_dimension=outer_map.input_dimension,
offset=offset_i + stride_i * outer_map.offset,
stride=stride_i * outer_map.stride,
)

if isinstance(outer_map, ArrayMap):
return ArrayMap(
index_array=outer_map.index_array,
offset=offset_i + stride_i * outer_map.offset,
stride=stride_i * outer_map.stride,
)

raise TypeError(f"Unknown output map type: {type(outer_map)}") # pragma: no cover


def _compose_array(outer: IndexTransform, inner_map: ArrayMap) -> OutputIndexMap:
"""Compose when inner is an ArrayMap.

storage = offset_i + stride_i * arr_i[intermediate]
We need to evaluate arr_i at the intermediate coordinates produced by outer.
"""
arr_i = inner_map.index_array
offset_i = inner_map.offset
stride_i = inner_map.stride

# Check if all outer outputs are constant
all_constant = all(isinstance(m, ConstantMap) for m in outer.output)

if all_constant:
# Evaluate arr_i at the single constant point
idx = tuple(m.offset for m in outer.output if isinstance(m, ConstantMap))
value = int(arr_i[idx])
return ConstantMap(offset=offset_i + stride_i * value)

# For 1D inner array with a single outer output (simple case)
if arr_i.ndim == 1 and len(outer.output) == 1:
outer_map = outer.output[0]

if isinstance(outer_map, DimensionMap):
dim_size = outer.domain.shape[outer_map.input_dimension]
user_indices = np.arange(dim_size, dtype=np.intp)
intermediate_vals = outer_map.offset + outer_map.stride * user_indices
new_arr = arr_i[intermediate_vals]
return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i)

if isinstance(outer_map, ArrayMap):
intermediate_vals = outer_map.offset + outer_map.stride * outer_map.index_array
new_arr = arr_i[intermediate_vals]
return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i)

# General multi-dim case: not yet implemented
raise NotImplementedError(
"Composing a multi-dimensional inner array map with non-constant outer maps "
"is not yet supported."
)
Loading
Loading