diff --git a/tests/test_core.py b/tests/test_core.py index 688eca6..fa40524 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -24,7 +24,7 @@ def ds(): ds = xr.Dataset( { 'x': xr.DataArray(np.arange(4) - 2, dims='x'), - 'foo': xr.DataArray(np.ones(4, dtype='i4'), dims='x'), + 'foo': xr.DataArray(np.ones(4, dtype='i4'), dims='x', attrs=dict(units='K')), 'bar': xr.DataArray(np.arange(8, dtype=np.float32).reshape(4, 2), dims=('x', 'y')), } ) @@ -222,29 +222,47 @@ def test_dataset_empty_constructor(): def test_dataset_example(ds): ds_schema = DatasetSchema( - { - 'foo': DataArraySchema(name='foo', dtype=np.int32, dims=['x']), + data_vars={ + 'foo': DataArraySchema( + name='foo', + dtype=np.int32, + dims=['x'], + attrs=AttrsSchema(attrs=dict(units=AttrSchema(value='K'))) + ), 'bar': DataArraySchema(name='bar', dtype=np.floating, dims=['x', 'y']), - } + }, + coords={'x': DataArraySchema(name='x', dtype=np.int64, dims=['x'])}, + attrs={}, ) jsonschema.validate(ds_schema.json, ds_schema._json_schema) assert list(ds_schema.json['data_vars'].keys()) == ['foo', 'bar'] + assert list(ds_schema.json['coords']['coords'].keys()) == ['x'] ds_schema.validate(ds) - ds['foo'] = ds.foo.astype('float32') + ds2 = ds.copy() + ds2['foo'] = ds2.foo.astype('float32') with pytest.raises(SchemaError, match='dtype'): - ds_schema.validate(ds) + ds_schema.validate(ds2) - ds = ds.drop_vars('foo') + ds2 = ds2.drop_vars('foo') with pytest.raises(SchemaError, match='variable foo'): - ds_schema.validate(ds) + ds_schema.validate(ds2) + + ds3 = ds.copy() + ds3['x'] = ds3.x.astype('float32') + with pytest.raises(SchemaError, match='dtype'): + ds_schema.validate(ds3) + + ds3 = ds3.drop_vars('x') + with pytest.raises(SchemaError, match='coords has missing keys'): + ds_schema.validate(ds3) # json roundtrip rt_schema = DatasetSchema.from_json(ds_schema.json) assert isinstance(rt_schema, DatasetSchema) - rt_schema.json == 
ds_schema.json + assert rt_schema.json == ds_schema.json def test_checks_ds(ds): @@ -271,7 +289,7 @@ def test_dataset_with_attrs_schema(): expected_value = 'expected_value' actual_value = 'actual_value' ds = xr.Dataset(attrs={name: actual_value}) - ds_schema = DatasetSchema(attrs={name: AttrSchema(value=expected_value)}) + ds_schema = DatasetSchema(dict(attrs={name: AttrSchema(value=expected_value)})) jsonschema.validate(ds_schema.json, ds_schema._json_schema) ds_schema_2 = DatasetSchema(attrs=AttrsSchema({name: AttrSchema(value=expected_value)})) diff --git a/xarray_schema/components.py b/xarray_schema/components.py index d212ddc..e56f71f 100644 --- a/xarray_schema/components.py +++ b/xarray_schema/components.py @@ -396,13 +396,12 @@ def __init__( @classmethod def from_json(cls, obj: dict): - attrs = {} - for key, val in obj['attrs'].items(): - attrs[key] = AttrSchema(**val) + attrs = obj.pop('attrs') if 'attrs' in obj else {} + attrs = {k: AttrSchema(**v) for k, v in attrs.items()} return cls( attrs, - require_all_keys=obj['require_all_keys'], - allow_extra_keys=obj['allow_extra_keys'], + require_all_keys=obj.get('require_all_keys'), + allow_extra_keys=obj.get('allow_extra_keys'), ) def validate(self, attrs: Any) -> None: diff --git a/xarray_schema/dataarray.py b/xarray_schema/dataarray.py index 4a71e62..7f3d8d0 100644 --- a/xarray_schema/dataarray.py +++ b/xarray_schema/dataarray.py @@ -149,7 +149,7 @@ def attrs(self, value): if value is None or isinstance(value, AttrsSchema): self._attrs = value else: - self._attrs = AttrsSchema(value) + self._attrs = AttrsSchema(**value) @property def coords(self) -> Optional[CoordsSchema]: @@ -213,7 +213,7 @@ def validate(self, da: xr.DataArray) -> None: if self.chunks is not None: self.chunks.validate(da.chunks, da.dims, da.shape) - if self.attrs: + if self.attrs is not None: self.attrs.validate(da.attrs) if self.array_type is not None: diff --git a/xarray_schema/dataset.py b/xarray_schema/dataset.py index 9be35ee..8146ef1 
100644 --- a/xarray_schema/dataset.py +++ b/xarray_schema/dataset.py @@ -47,9 +47,9 @@ def from_json(cls, obj: dict): k: DataArraySchema.from_json(v) for k, v in obj['data_vars'].items() } if 'coords' in obj: - kwargs['coords'] = {k: CoordsSchema.from_json(v) for k, v in obj['coords'].items()} - if 'attrs' in obj: - kwargs['attrs'] = {k: AttrsSchema.from_json(v) for k, v in obj['attrs'].items()} + kwargs['coords'] = CoordsSchema.from_json(obj['coords']) + if 'attrs' in obj and obj['attrs'] != {}: + kwargs['attrs'] = AttrsSchema.from_json(obj['attrs']) return cls(**kwargs) @@ -79,10 +79,10 @@ def validate(self, ds: xr.Dataset) -> None: else: da_schema.validate(ds.data_vars[key]) - if self.coords is not None: # pragma: no cover - raise NotImplementedError('coords schema not implemented yet') + if self.coords is not None: + self.coords.validate(ds.coords) - if self.attrs: + if self.attrs is not None: self.attrs.validate(ds.attrs) if self.checks: @@ -98,7 +98,7 @@ def attrs(self, value: Union[AttrsSchema, Dict[Hashable, Any], None]): if value is None or isinstance(value, AttrsSchema): self._attrs = value else: - self._attrs = AttrsSchema(value) + self._attrs = AttrsSchema(**value) @property def data_vars(self) -> Optional[Dict[Hashable, Optional[DataArraySchema]]]: