diff --git a/changes/3706.misc.md b/changes/3706.misc.md new file mode 100644 index 0000000000..70a0e44c58 --- /dev/null +++ b/changes/3706.misc.md @@ -0,0 +1 @@ +Allow NumPy ints as input when declaring a shape. \ No newline at end of file diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index d38949657e..275d062eba 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -21,6 +21,7 @@ overload, ) +import numpy as np from typing_extensions import ReadOnly from zarr.core.config import config as zarr_config @@ -37,7 +38,7 @@ ZMETADATA_V2_JSON = ".zmetadata" BytesLike = bytes | bytearray | memoryview -ShapeLike = Iterable[int] | int +ShapeLike = Iterable[int | np.integer[Any]] | int | np.integer[Any] # For backwards compatibility ChunkCoords = tuple[int, ...] ZarrFormat = Literal[2, 3] @@ -185,23 +186,28 @@ def parse_named_configuration( def parse_shapelike(data: ShapeLike) -> tuple[int, ...]: - if isinstance(data, int): + """ + Parse a shape-like input into an explicit shape. + """ + if isinstance(data, int | np.integer): if data < 0: raise ValueError(f"Expected a non-negative integer. Got {data} instead") - return (data,) + return (int(data),) try: data_tuple = tuple(data) except TypeError as e: msg = f"Expected an integer or an iterable of integers. Got {data} instead." raise TypeError(msg) from e - if not all(isinstance(v, int) for v in data_tuple): + if not all(isinstance(v, int | np.integer) for v in data_tuple): msg = f"Expected an iterable of integers. Got {data} instead." raise TypeError(msg) if not all(v > -1 for v in data_tuple): msg = f"Expected all values to be non-negative. Got {data} instead." raise ValueError(msg) - return data_tuple + + # cast NumPy scalars to plain python ints + return tuple(int(x) for x in data_tuple) def parse_fill_value(data: Any) -> Any: diff --git a/tests/test_common.py b/tests/test_common.py index 0944c3375a..0dedde1d6b 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterable from typing import TYPE_CHECKING, get_args import numpy as np @@ -15,7 +16,6 @@ from zarr.core.config import parse_indexing_order if TYPE_CHECKING: - from collections.abc import Iterable from typing import Any, Literal @@ -115,9 +115,15 @@ def test_parse_shapelike_invalid_iterable_values(data: Any) -> None: parse_shapelike(data) -@pytest.mark.parametrize("data", [range(10), [0, 1, 2, 3], (3, 4, 5), ()]) -def test_parse_shapelike_valid(data: Iterable[int]) -> None: - assert parse_shapelike(data) == tuple(data) +@pytest.mark.parametrize( + "data", [range(10), [0, 1, 2, np.uint64(3)], (3, 4, 5), (), 1, np.uint8(1)] +) +def test_parse_shapelike_valid(data: Iterable[int] | int) -> None: + if isinstance(data, Iterable): + expected = tuple(data) + else: + expected = (data,) + assert parse_shapelike(data) == expected # todo: more dtypes