2 changes: 2 additions & 0 deletions packages/bundled_models/persistence/.gitattributes
@@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true -diff
4 changes: 4 additions & 0 deletions packages/bundled_models/persistence/.gitignore
@@ -0,0 +1,4 @@
# pixi environments
.pixi/*
!.pixi/config.toml
report.xml
45 changes: 45 additions & 0 deletions packages/bundled_models/persistence/README.md
@@ -0,0 +1,45 @@
# Persistence Model for use with the PyEarthTools Package

**TODO: description**

## Installation

Clone the repository, then run
```shell
pip install -e .
```

## Training

No training is required for this model. It computes persistence on-the-fly using historical data loaded via the PET pipeline.
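
For context, a persistence forecast simply carries the most recent observed state forward to every lead time. The following is a minimal sketch of the idea (illustrative only; the function and argument names are assumptions, not the PET pipeline's API):

```python
import xarray as xr

def persistence_forecast(history: xr.Dataset, lead_times: list) -> xr.Dataset:
    """Carry the last observed time step forward to every lead time."""
    last_state = history.isel(time=-1, drop=True)
    # Repeat the final observed state once per requested lead time.
    forecast = xr.concat([last_state] * len(lead_times), dim="lead_time")
    return forecast.assign_coords(lead_time=lead_times)
```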

## Predictions / Inference

You can generate persistence values out of the box using the `pet predict` command-line API, or from a Jupyter notebook as demonstrated in the tutorial gallery.

First, check that the model is registered:

```shell
pet predict
```

`Development/Persistence` should be visible in the output. If so, you can now run inference:

```shell
pet predict --model Development/Persistence
```

When run, the command will prompt for any other required arguments.

**TODO: description of required arguments**


### Example

```shell
pet predict --model Development/Persistence # TODO
```

## Acknowledgments

Not applicable; the model was developed heuristically.
3,369 changes: 3,369 additions & 0 deletions packages/bundled_models/persistence/pixi.lock

Large diffs are not rendered by default.

92 changes: 92 additions & 0 deletions packages/bundled_models/persistence/pyproject.toml
@@ -0,0 +1,92 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "pyearthtools-bundled-persistence"
version = "0.6.0"
description = "Persistence Bundled Model"
readme = "README.md"
requires-python = ">=3.11, <3.14"
keywords = ["persistence", "pyearthtools", "models"]
maintainers = [
{name = "Tennessee Leeuwenburg", email = "tennessee.leeuwenburg@bom.gov.au"},
{name = "Nikeeth Ramanathan", email = "nikeeth.ramanathan@gmail.com"},
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dependencies = [
'pyearthtools.zoo>=0.5.0',
'pyearthtools.data>=0.5.0',
'pyearthtools.pipeline>=0.5.0',
'hydra-core',
]

[dependency-groups]
dev = [
"pytest>=8.4.2",
"ruff",
"pytest-cov",
"pytest-xdist",
]

[project.urls]
homepage = "https://pyearthtools.readthedocs.io/"
documentation = "https://pyearthtools.readthedocs.io/"
repository = "https://github.com/ACCESS-Community-Hub/PyEarthTools"

[project.entry-points."pyearthtools.zoo.model"]
Global_PERSIST = "persistence.registered_model:Persistence"

[tool.isort]
profile = "black"

[tool.black]
line-length = 120

[tool.mypy]
warn_return_any = true
warn_unused_configs = true
# Applied globally; a [[tool.mypy.overrides]] entry must name the `module` it targets,
# so `ignore_missing_imports` lives at the top level instead.
ignore_missing_imports = true

# NOTE: the build backend above is setuptools, so hatch-specific sections would be
# ignored; point setuptools at the src/ layout instead.
[tool.setuptools.packages.find]
where = ["src"]

[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]

[tool.pixi.pypi-dependencies]
pyearthtools-bundled-persistence = { path = ".", editable = true }

[tool.pixi.tasks]

[tool.pixi.dependencies]
python = ">=3.11,<3.14"
xarray = ">=2026.1.0,<2027"

[tool.pixi.feature.testing.dependencies]
pytest = ">=9.0.2,<10"
pytest-cov = ">=7.0.0,<8"
pytest-xdist = ">=3.8.0,<4"
ruff = ">=0.15.0,<0.16"
ipython = ">=9.10.0,<10"

[tool.pixi.feature.dask.dependencies]
dask-core = "*"
distributed = "*"
pyarrow = ">=23.0.0,<24"

[tool.pixi.environments]
dask = ["dask"]
dev = ["dask", "testing"]
15 changes: 15 additions & 0 deletions packages/bundled_models/persistence/src/persistence/__init__.py
@@ -0,0 +1,15 @@
from persistence._interface import (
PersistenceMethod,
PersistenceDataChunk,
PersistenceChunker,
)

from persistence._impute import SimpleImpute
from persistence._datatypes import PetDataset

__all__ = [
    "PersistenceMethod",
    "PersistenceDataChunk",
    "PersistenceChunker",
    "SimpleImpute",
    "PetDataset",
]
41 changes: 41 additions & 0 deletions packages/bundled_models/persistence/src/persistence/_daskconfig.py
@@ -0,0 +1,41 @@
from contextlib import contextmanager


# default scheduler string to set "single-threaded" mode.
_STR_DASK_SYNC_SCHEDULER = "synchronous"


@contextmanager
def _set_synchronous_dask():
"""
Wrapper to set `dask` to single-threaded mode. Note: "single-threaded" in `dask`-land
(specifically) is the same as "synchronous".

This handles the case where dask is _not_ installed. In which case it does a pass-through.

IMPORTANT: never nest this context manager or call dask.config.reset() or attempt to update any
configs inside this context. Doing so may invalidate the "synchronous" setting.

Example:
def do_stuff(...):
# I can now (optionally) fork other processes here - without confusing dask.
# IMPORTANT: I shouldn't try to reintroduce parallelism using dask here
...

with _set_synchronous_dask():
do_stuff(...)
"""
try:
# this import order is important for the "distributed" configs to be recognized
import dask
import dask.config

        # NOTE: without dask.distributed installed, this setting may not work as
        # intended, so you will have to handle it manually at the compute level.
import dask.distributed

# set state to desired config
with dask.config.set(scheduler=_STR_DASK_SYNC_SCHEDULER):
yield
except ImportError:
yield
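
# A minimal concrete usage sketch (illustrative only; assumes dask is installed):
#
#     import dask.array as da
#
#     with _set_synchronous_dask():
#         # .compute() now runs single-threaded on the synchronous scheduler.
#         result = da.ones((1000, 1000)).mean().compute()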
90 changes: 90 additions & 0 deletions packages/bundled_models/persistence/src/persistence/_datatypes.py
@@ -0,0 +1,90 @@
"""
Common data array/set transformations supported by the persistence model, the main usecase is to map
a function to each data variable independently. This is a common pattern as more often than not we
wouldn't be intermixing variables in basic pre-processing steps.

TODO: this should be somewhere more common
"""

from typing import Union
from collections.abc import Callable
import xarray as xr
import numpy as np
import numpy.typing as npt

PetDataArrayLike = Union[xr.DataArray, xr.Dataset, npt.ArrayLike]


class PetDataset:
def __init__(
self,
arraylike: PetDataArrayLike,
dummy_varname="_dummyvarname", # used for xarray dataarrays and numpy arrays
dimnames: list[str] = None, # used only for numpy arrays
):
"""
        Takes a PetDataArrayLike and converts it to a PetDataset, which is compatible with
        the `map_each_var` computation.

        `dimnames` is only relevant for numpy arrays, and only when using name-based
        indexing to retrieve e.g. the time dimension.
"""
self.ds = PetDataset.from_arrlike(arraylike, dummy_varname, dimnames)

@staticmethod
def from_np_array(arraylike: npt.ArrayLike, dummy_varname, dimnames) -> xr.Dataset:
return PetDataset.from_xr_dataarray(
xr.DataArray(np.asarray(arraylike), dims=dimnames), dummy_varname
)

@staticmethod
def from_xr_dataarray(arraylike: xr.DataArray, dummy_varname) -> xr.Dataset:
return xr.Dataset({dummy_varname: arraylike})

@staticmethod
def from_xr_dataset(arraylike: xr.Dataset) -> xr.Dataset:
return arraylike

@staticmethod
def from_arrlike(arraylike, dummy_varname, dimnames) -> xr.Dataset:
        # Order matters here. For example, an xr.DataArray is also an npt.ArrayLike, but not
        # the other way around; if these checks were reordered, the xr.DataArray branch
        # would never be reached.

msg_type_error = """
The provided data does not have a supported array type, supported array types are:
xr.DataArray, xr.Dataset and np.ndarray.
"""

if isinstance(arraylike, xr.Dataset):
return PetDataset.from_xr_dataset(arraylike)

if isinstance(arraylike, xr.DataArray):
return PetDataset.from_xr_dataarray(arraylike, dummy_varname)

if isinstance(arraylike, (np.ndarray, list, tuple)):
return PetDataset.from_np_array(arraylike, dummy_varname, dimnames)

# unsupported type
raise TypeError(msg_type_error)

def map_each_var(
        self, _fn: Callable[..., xr.DataArray], *_fn_args, **_fn_kwargs
) -> xr.Dataset:
"""
Applies a function over each data array in the dataset. The return type will be dataset.

The return type of each function operation itself will be per variable (dataarray).

Only functions that have common structure associated to the variables in the Dataset will
work properly.

IMPORTANT: global attributes and special variables may not be preserved. This operation is
destructive and for intermediate computation purposes only.
"""
dict_res = {}

for k_var, v_da in self.ds.data_vars.items():
dict_res[k_var] = _fn(v_da, *_fn_args, **_fn_kwargs)

return xr.Dataset(dict_res)
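
# A minimal usage sketch (illustrative only; the array values and dimension names below
# are assumptions, not part of the module):
#
#     import numpy as np
#
#     pds = PetDataset(np.arange(6.0).reshape(2, 3), dimnames=["time", "x"])
#     doubled = pds.map_each_var(lambda da: da * 2)  # returns an xr.Dataset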
31 changes: 31 additions & 0 deletions packages/bundled_models/persistence/src/persistence/_impute.py
@@ -0,0 +1,31 @@
"""
This module handles imputation of missing data using very simple techniques.

Only mean imputation is currently supported.
"""

from dataclasses import dataclass
import numpy as np


@dataclass(frozen=True)
class SimpleImpute:
arr: np.ndarray

def impute_mean(self) -> np.ndarray:
"""
To keep the imputation representative of the data but yet simple we can do a simple
mean interpolation over the data slab.

NOTE: This is non-deterministic depending on the data chunking strategy.
"""
nanmask = np.isnan(self.arr)
if not nanmask.any() or nanmask.all():
# if nothing is missing or everything is missing, return the original array as-is
return self.arr
else:
# otherwise, replace missing values with the mean of the slab
# NOTE: the following flattens the array by default if axis isn't specified
fillval = np.nanmean(self.arr)
arr_imputed = np.where(nanmask, fillval, self.arr)
return arr_imputed
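
# A minimal usage sketch (illustrative only):
#
#     import numpy as np
#
#     arr = np.array([1.0, np.nan, 3.0])
#     filled = SimpleImpute(arr).impute_mean()
#     # nan is replaced by the mean of the observed values -> [1.0, 2.0, 3.0]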