@classmethod
def from_native_dtype(cls, dtype: TBaseDType) -> Self:
    """
    Build this Zarr data type from a compatible NumPy data type.

    A NumPy ``StringDType`` that carries an ``na_object`` is refused, because
    the Zarr ``string`` data type has no way to represent that setting.

    Parameters
    ----------
    dtype : TBaseDType
        The native data type to convert.

    Returns
    -------
    Self
        A new instance of this data type.

    Raises
    ------
    DataTypeValidationError
        If ``dtype`` is not accepted by ``cls._check_native_dtype``.
    ValueError
        If ``dtype`` passes that check but exposes an ``na_object`` attribute.
    """
    # Guard clause: a dtype that is simply not ours is reported with the
    # "no match" error type.
    if not cls._check_native_dtype(dtype):
        raise DataTypeValidationError(
            f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}"
        )

    # Presence of the attribute, not its value, is what is tested here, so
    # ``na_object=None`` also counts as "set".
    na_object_configured = hasattr(dtype, "na_object")
    if na_object_configured:
        raise ValueError(
            f"Zarr data type resolution from {dtype} failed. "
            "Attempted to resolve a zarr data type from a `numpy.dtypes.StringDType` "
            "with `na_object` set, which is not supported."
        )

    return cls()
@staticmethod
@pytest.mark.skipif(not _NUMPY_SUPPORTS_VLEN_STRING, reason="requires numpy with T dtype")
def test_match_dtype_string_na_object_error(
    data_type_registry_fixture: DataTypeRegistry,
) -> None:
    """
    Check that ``match_dtype`` surfaces the ``ValueError`` for a ``StringDType``
    created with ``na_object`` instead of suppressing it as a non-match.

    ``@staticmethod`` is kept as the outermost decorator so that the ``skipif``
    mark is attached to the plain function. A mark placed above
    ``@staticmethod`` ends up on the wrapper object and may be ignored by
    older pytest releases, which would run this test on NumPy builds that
    lack ``StringDType``.
    """
    data_type_registry_fixture.register(VariableLengthUTF8._zarr_v3_name, VariableLengthUTF8)  # type: ignore[arg-type]
    # ``na_object=None`` still counts as "set": the resolution is expected to fail.
    dtype: np.dtype[Any] = np.dtypes.StringDType(na_object=None)  # type: ignore[call-arg]
    with pytest.raises(ValueError, match=r"Zarr data type resolution from StringDType.*failed"):
        data_type_registry_fixture.match_dtype(dtype)