diff --git a/changes/3717.misc.md b/changes/3717.misc.md new file mode 100644 index 0000000000..5fed76b2b7 --- /dev/null +++ b/changes/3717.misc.md @@ -0,0 +1 @@ +Add benchmarks for Morton order computation with non-power-of-2 and near-miss shard shapes, covering both pure computation and end-to-end read/write performance. diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 454f7e2290..aaddc58e34 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -1512,54 +1512,47 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: out.flags.writeable = False return out - # Optimization: Remove singleton dimensions to enable magic number usage - # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand. - singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1) - if singleton_dims: - squeezed_shape = tuple(s for s in chunk_shape if s != 1) - if squeezed_shape: - # Compute Morton order on squeezed shape, then expand singleton dims (always 0) - squeezed_order = np.asarray(_morton_order(squeezed_shape)) - out = np.zeros((n_total, n_dims), dtype=np.intp) - squeezed_col = 0 - for full_col in range(n_dims): - if chunk_shape[full_col] != 1: - out[:, full_col] = squeezed_order[:, squeezed_col] - squeezed_col += 1 - else: - # All dimensions are singletons, just return the single point - out = np.zeros((1, n_dims), dtype=np.intp) - out.flags.writeable = False - return out - - # Find the largest power-of-2 hypercube that fits within chunk_shape. - # Within this hypercube, Morton codes are guaranteed to be in bounds. - min_dim = min(chunk_shape) - if min_dim >= 1: - power = min_dim.bit_length() - 1 # floor(log2(min_dim)) - hypercube_size = 1 << power # 2^power - n_hypercube = hypercube_size**n_dims + # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span + # all valid coordinates in chunk_shape. (c-1).bit_length() gives the number + # of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits + # is the size of this hypercube. + total_bits = sum((c - 1).bit_length() for c in chunk_shape) + n_z = 1 << total_bits if total_bits > 0 else 1 + + # Decode all Morton codes in the ceiling hypercube, then filter to valid coords. + # This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33): + # n_z=262144, n_total=35937), consider the argsort strategy below. + if n_z <= 4 * n_total: + # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds. + # Works well when the overgeneration ratio n_z/n_total is small (≤4). + z_values = np.arange(n_z, dtype=np.intp) + all_coords = decode_morton_vectorized(z_values, chunk_shape) + shape_arr = np.array(chunk_shape, dtype=np.intp) + valid_mask = np.all(all_coords < shape_arr, axis=1) + order = all_coords[valid_mask] else: - n_hypercube = 0 + # Argsort strategy: enumerate all n_total valid coordinates directly, + # encode each to a Morton code, then sort by code. Avoids the 8x or + # larger overgeneration penalty for near-miss shapes like (33,33,33). + # Cost: O(n_total * bits) encode + O(n_total log n_total) sort, + # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling. + grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij") + all_coords = np.stack([g.ravel() for g in grids], axis=1) + + # Encode all coordinates to Morton codes (vectorized). + bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape) + max_coord_bits = max(bits_per_dim) + z_codes = np.zeros(n_total, dtype=np.intp) + output_bit = 0 + for coord_bit in range(max_coord_bits): + for dim in range(n_dims): + if coord_bit < bits_per_dim[dim]: + z_codes |= ((all_coords[:, dim] >> coord_bit) & 1) << output_bit + output_bit += 1 + + sort_idx: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable") + order = all_coords[sort_idx] - # Within the hypercube, no bounds checking needed - use vectorized decoding - if n_hypercube > 0: - z_values = np.arange(n_hypercube, dtype=np.intp) - order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape) - else: - order = np.empty((0, n_dims), dtype=np.intp) - - # For remaining elements outside the hypercube, bounds checking is needed - remaining: list[tuple[int, ...]] = [] - i = n_hypercube - while len(order) + len(remaining) < n_total: - m = decode_morton(i, chunk_shape) - if all(x < y for x, y in zip(m, chunk_shape, strict=False)): - remaining.append(m) - i += 1 - - if remaining: - order = np.vstack([order, np.array(remaining, dtype=np.intp)]) order.flags.writeable = False return order diff --git a/tests/benchmarks/test_indexing.py b/tests/benchmarks/test_indexing.py index d30d731f0f..385a85b5b5 100644 --- a/tests/benchmarks/test_indexing.py +++ b/tests/benchmarks/test_indexing.py @@ -106,7 +106,10 @@ def read_with_cache_clear() -> None: # Benchmark with larger chunks_per_shard to make Morton order impact more visible large_morton_shards = ( - (32,) * 3, # With 1x1x1 chunks: 32x32x32 = 32768 chunks per shard + (32,) * 3, # With 1x1x1 chunks: 32x32x32 = 32768 chunks per shard (power-of-2) + (30,) * 3, # With 1x1x1 chunks: 30x30x30 = 27000 chunks per shard (non-power-of-2) + (33,) + * 3, # With 1x1x1 chunks: 33x33x33 = 35937 chunks per shard (near-miss: just above power-of-2) ) @@ -197,9 +200,13 @@ def read_with_cache_clear() -> None: # Benchmark for morton_order_iter directly (no I/O) morton_iter_shapes = ( - (8, 8, 8), # 512 elements - (16, 16, 16), # 4096 elements - (32, 32, 32), # 32768 elements + (8, 8, 8), # 512 elements (power-of-2) + (10, 10, 10), # 1000 elements (non-power-of-2) + (16, 16, 16), # 4096 elements (power-of-2) + (20, 20, 20), # 8000 elements (non-power-of-2) + (32, 32, 32), # 32768 elements (power-of-2) + (30, 30, 30), # 27000 elements (non-power-of-2) + (33, 33, 33), # 35937 elements (near-miss: just above power-of-2, n_z=262144) )