Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions ocf_data_sampler/load/nwp/nwp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Module for opening NWP data."""

from collections.abc import Callable

import numpy as np
import xarray as xr

Expand All @@ -10,6 +12,16 @@
from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu
from ocf_data_sampler.load.nwp.providers.ukv import open_ukv

_OPEN_NWP_FUNCTIONS: dict[str, Callable[..., xr.DataArray]] = {
"ukv": open_ukv,
"ecmwf": open_ifs,
"mo_global": open_ifs,
"icon-eu": open_icon_eu,
"gencast": open_gdm,
"gfs": open_gfs,
"cloudcasting": open_cloudcasting,
}


def _validate_nwp_data(data_array: xr.DataArray, provider: str) -> None:
"""Validates the structure and data types of a loaded NWP DataArray.
Expand All @@ -25,7 +37,9 @@ def _validate_nwp_data(data_array: xr.DataArray, provider: str) -> None:
ValueError: If a required coordinate is missing.
"""
if not np.issubdtype(data_array.dtype, np.number):
raise TypeError(f"NWP data for {provider} should be numeric, not {data_array.dtype}")
raise TypeError(
f"NWP data for {provider} should be numeric, not {data_array.dtype}",
)

common_expected_dtypes = {
"init_time_utc": np.datetime64,
Expand Down Expand Up @@ -90,28 +104,16 @@ def open_nwp(
"""
provider = provider.lower()

kwargs = {
"zarr_path": zarr_path,
}
if provider == "ukv":
_open_nwp = open_ukv
elif provider in ["ecmwf", "mo_global"]:
_open_nwp = open_ifs
elif provider == "icon-eu":
_open_nwp = open_icon_eu
elif provider == "gencast":
_open_nwp = open_gdm
elif provider == "gfs":
_open_nwp = open_gfs
# GFS has a public/private flag
if public:
kwargs["public"] = True
elif provider == "cloudcasting":
_open_nwp = open_cloudcasting
else:
raise ValueError(f"Unknown provider: {provider}")

data_array = _open_nwp(**kwargs)
_validate_nwp_data(data_array, provider)
if provider not in _OPEN_NWP_FUNCTIONS:
supported = ", ".join(sorted(_OPEN_NWP_FUNCTIONS.keys()))
raise ValueError(f"Unknown provider: {provider!r}. Supported: {supported}")

opener = _OPEN_NWP_FUNCTIONS[provider]

kwargs = {"zarr_path": zarr_path}
if provider == "gfs" and public:
kwargs["public"] = True

data_array = opener(**kwargs)
_validate_nwp_data(data_array, provider)
return data_array