diff --git a/changes/3705.bugfix.md b/changes/3705.bugfix.md new file mode 100644 index 0000000000..2abcb4ee7c --- /dev/null +++ b/changes/3705.bugfix.md @@ -0,0 +1 @@ +Fix a performance bug in morton curve generation. \ No newline at end of file diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 7f704bf2b7..beffa99cfa 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -7,7 +7,7 @@ from collections.abc import Iterator, Sequence from dataclasses import dataclass from enum import Enum -from functools import reduce +from functools import lru_cache, reduce from types import EllipsisType from typing import ( TYPE_CHECKING, @@ -1467,16 +1467,21 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]: return tuple(out) -def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - i = 0 +@lru_cache +def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: + n_total = product(chunk_shape) order: list[tuple[int, ...]] = [] - while len(order) < product(chunk_shape): + i = 0 + while len(order) < n_total: m = decode_morton(i, chunk_shape) - if m not in order and all(x < y for x, y in zip(m, chunk_shape, strict=False)): + if all(x < y for x, y in zip(m, chunk_shape, strict=False)): order.append(m) i += 1 - for j in range(product(chunk_shape)): - yield order[j] + return tuple(order) + + +def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: + return iter(_morton_order(tuple(chunk_shape))) def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]: diff --git a/tests/test_codecs/test_codecs.py b/tests/test_codecs/test_codecs.py index eae7168d49..fa2017876e 100644 --- a/tests/test_codecs/test_codecs.py +++ b/tests/test_codecs/test_codecs.py @@ -18,7 +18,7 @@ TransposeCodec, ) from zarr.core.buffer import default_buffer_prototype -from zarr.core.indexing import BasicSelection, morton_order_iter +from zarr.core.indexing import BasicSelection, decode_morton, morton_order_iter from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.dtype import UInt8 from zarr.errors import ZarrUserWarning @@ -171,7 +171,8 @@ def test_open(store: Store) -> None: assert a.metadata == b.metadata -def test_morton() -> None: +def test_morton_exact_order() -> None: + """Test exact morton ordering for power-of-2 shapes.""" assert list(morton_order_iter((2, 2))) == [(0, 0), (1, 0), (0, 1), (1, 1)] assert list(morton_order_iter((2, 2, 2))) == [ (0, 0, 0), @@ -206,21 +207,58 @@ def test_morton() -> None: @pytest.mark.parametrize( "shape", [ - [2, 2, 2], - [5, 2], - [2, 5], - [2, 9, 2], - [3, 2, 12], - [2, 5, 1], - [4, 3, 6, 2, 7], - [3, 2, 1, 6, 4, 5, 2], + (2, 2, 2), + (5, 2), + (2, 5), + (2, 9, 2), + (3, 2, 12), + (2, 5, 1), + (4, 3, 6, 2, 7), + (3, 2, 1, 6, 4, 5, 2), + (1,), + (1, 1), + (5, 1, 3), + (1, 4, 1, 2), ], ) -def test_morton2(shape: tuple[int, ...]) -> None: +def test_morton_is_permutation(shape: tuple[int, ...]) -> None: + """Test that morton_order_iter produces every valid coordinate exactly once.""" + import itertools + + from zarr.core.common import product + + order = list(morton_order_iter(shape)) + expected_len = product(shape) + # completeness: every valid coordinate is present + assert len(order) == expected_len + # no duplicates + assert len(set(order)) == expected_len + # all coordinates are within bounds + assert all(all(c < s for c, s in zip(coord, shape, strict=True)) for coord in order) + # the set of coordinates equals the full cartesian product + assert set(order) == set(itertools.product(*(range(s) for s in shape))) + + +@pytest.mark.parametrize( + "shape", + [ + (2, 2), + (4, 4), + (2, 2, 2), + (4, 4, 4), + (2, 2, 2, 2), + ], +) +def test_morton_ordering(shape: tuple[int, ...]) -> None: + """Test that the iteration order matches consecutive decode_morton outputs. + + For power-of-2 shapes, every decode_morton output is in-bounds, + so the ordering should be exactly decode_morton(0), decode_morton(1), ... + """ + order = list(morton_order_iter(shape)) - for i, x in enumerate(order): - assert x not in order[:i] # no duplicates - assert all(x[j] < shape[j] for j in range(len(shape))) # all indices are within bounds + for i, coord in enumerate(order): + assert coord == decode_morton(i, shape) @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"])