Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ Changed
:class:`imod.mf6.Evapotranspiration` are now optional variables. If provided,
now require ``"segment"`` dimension when ``proportion_depth`` and
``proportion_rate``.
- :meth:`imod.msw.GridData.generate_index_array` is now deprecated, use
:meth:`imod.msw.GridData.generate_isactive_svat_arrays` instead.


[1.0.0] - 2025-11-11
Expand Down
1 change: 1 addition & 0 deletions docs/api/msw.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Grid packages
GridData.from_imod5_data
GridData.get_regrid_methods
GridData.generate_index_array
GridData.generate_isactive_svat_arrays
GridData.write
Infiltration
Infiltration.regrid_like
Expand Down
54 changes: 33 additions & 21 deletions imod/msw/grid_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import xarray as xr

Expand Down Expand Up @@ -83,15 +85,14 @@ def __init__(

self._pkgcheck()

def generate_index_array(self) -> np.ndarray:
def _generate_isactive_array(self) -> xr.DataArray:
"""
Generate index array to be used on other packages.
Generate a 1D array of active cells to be used on other packages.

Returns
-------
np.ndarray
Index array and svat grid.
The index array is a 1D array with the index of the active cells.
A 1D array with which cells are active.
"""
area = self.dataset["area"]
active = self.dataset["active"]
Expand All @@ -101,36 +102,47 @@ def generate_index_array(self) -> np.ndarray:
# https://github.com/dask/dask/issues/11753
isactive.load()

index = isactive.values.ravel()

return index
return isactive

def generate_index_svat_array(self) -> tuple[np.ndarray, xr.DataArray]:
def generate_isactive_svat_arrays(self) -> tuple[np.ndarray, xr.DataArray]:
"""
Generate index array and svat grid to be used on other packages.

Returns
-------
tuple[np.ndarray, xr.DataArray]
Index array and svat grid.
The index array is a 1D array with the index of the active cells.
isactive array and svat grid.
The isactive array is a 1D array with which cells are active.
The svat grid is a 2D array with the SVAT numbers for each cell.
"""
index = self.generate_index_array()
isactive = self._generate_isactive_array()
isactive_1d = isactive.values.ravel()

area = self.dataset["area"]
active = self.dataset["active"]
isactive = area.where(active).notnull()
svat = xr.full_like(isactive, fill_value=0, dtype=np.int64).rename("svat")
svat.data[isactive.data] = np.arange(1, isactive_1d.sum() + 1)

svat = xr.full_like(area, fill_value=0, dtype=np.int64).rename("svat")
# Load into memory to avoid dask issue
# https://github.com/dask/dask/issues/11753
isactive.load()
svat.load()
return isactive_1d, svat

svat.data[isactive.data] = np.arange(1, index.sum() + 1)
def generate_index_array(self) -> tuple[np.ndarray, xr.DataArray]:
"""
This method is kept for backward compatibility, but will be removed in
future versions and will thus throw a deprecation warning. Use
:meth:`imod.msw.GridData.generate_isactive_svat_arrays` instead.

Generate index array and svat grid to be used on other packages.

return index, svat
Returns
-------
tuple[np.ndarray, xr.DataArray]
isactive array and svat grid.
The isactive array is a 1D array with which cells are active.
The svat grid is a 2D array with the SVAT numbers for each cell.
"""
warnings.warn(
"Method 'generate_index_array' is deprecated and will be removed in the future, use 'generate_isactive_svat_arrays' instead.",
DeprecationWarning,
)
return self.generate_isactive_svat_arrays()

def _pkgcheck(self):
super()._pkgcheck()
Expand Down
10 changes: 6 additions & 4 deletions imod/msw/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def write(
# Get index and svat
grid_key = self.get_pkgkey(GridData)
grid_pkg = cast(GridData, self[grid_key])
index, svat = grid_pkg.generate_index_svat_array()
index, svat = grid_pkg.generate_isactive_svat_arrays()

# write package contents
for pkgname in self:
Expand Down Expand Up @@ -565,11 +565,13 @@ def split(

for submodel_name, submodel in partitioned_submodels.items():
partition_info = submodel_to_partition[submodel_name]
sliced_grid_pkg = clip_by_grid(grid_pkg, partition_info.active_domain)
sliced_index = sliced_grid_pkg.generate_index_array()
sliced_grid_pkg = cast(
GridData, clip_by_grid(grid_pkg, partition_info.active_domain)
)
sliced_isactive = sliced_grid_pkg._generate_isactive_array().values

# Add package to model if it has data in the active domain.
if bool(sliced_index.any()):
if bool(sliced_isactive.any()):
is_in_active_domain[submodel_name] = True
submodel[grid_key] = sliced_grid_pkg
else:
Expand Down
10 changes: 5 additions & 5 deletions imod/tests/test_msw/test_grid_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_write(
xr.full_like(like, True, dtype=bool),
)

index, svat = grid_data.generate_index_svat_array()
index, svat = grid_data.generate_isactive_svat_arrays()

with tempfile.TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
Expand Down Expand Up @@ -293,12 +293,12 @@ def case_grid_data_two_subunits__dask(


@parametrize_with_cases("grid_data_dict", cases=".", has_tag="two_subunit")
def test_generate_index_svat_array(
def test_generate_isactive_svat_arrays(
grid_data_dict: dict[str, xr.DataArray], coords_two_subunit: dict
):
grid_data = GridData(**grid_data_dict)

index, svat = grid_data.generate_index_svat_array()
index, svat = grid_data.generate_isactive_svat_arrays()

index_expected = [
False,
Expand Down Expand Up @@ -346,7 +346,7 @@ def test_generate_index_svat_array(
def test_simple_model(fixed_format_parser, grid_data_dict: dict[str, xr.DataArray]):
grid_data = GridData(**grid_data_dict)

index, svat = grid_data.generate_index_svat_array()
index, svat = grid_data.generate_isactive_svat_arrays()

with tempfile.TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
Expand Down Expand Up @@ -388,7 +388,7 @@ def test_simple_model_1_subunit(
):
grid_data = GridData(**grid_data_dict)

index, svat = grid_data.generate_index_svat_array()
index, svat = grid_data.generate_isactive_svat_arrays()

with tempfile.TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
Expand Down
Loading