Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 79 additions & 54 deletions packages/evaluate/src/weathergen/evaluate/io/io_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Standard library
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass

# Third-party
Expand Down Expand Up @@ -49,6 +50,8 @@ class DataAvailability:
List of forecast steps requested
samples:
List of samples requested
ensemle:
List of ensemble member identifiers
"""

score_availability: bool
Expand All @@ -58,33 +61,30 @@ class DataAvailability:
ensemble: list[str] | None = None


class Reader:
class Reader(ABC):
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None):
"""
Generic data reader class.

Parameters
----------
eval_cfg :
config with plotting and evaluation options for that run id
run_id :
run id of the model
eval_cfg : dict
Config with plotting and evaluation options for that run id.
run_id : str
Run identifier of the model
private_paths:
dictionary of private paths for the supported HPC
Dictionary of private paths for the supported HPC
"""
self.eval_cfg = eval_cfg
self.run_id = run_id
self.private_paths = private_paths
self.streams = eval_cfg.streams.keys()
self.streams = list(eval_cfg.streams.keys())
# TODO: propagate it to the other functions using global plotting opts
self.global_plotting_options = eval_cfg.get("global_plotting_options", {})

# If results_base_dir and model_base_dir are not provided, default paths are used
self.model_base_dir = self.eval_cfg.get("model_base_dir", None)

self.results_base_dir = self.eval_cfg.get(
"results_base_dir", None
) # base directory where results will be stored
# Default paths if not provided
self.model_base_dir = eval_cfg.get("model_base_dir")
self.results_base_dir = eval_cfg.get("results_base_dir")

def get_stream(self, stream: str):
"""
Expand All @@ -102,22 +102,26 @@ def get_stream(self, stream: str):
"""
return self.eval_cfg.streams.get(stream, {})

@abstractmethod
def get_samples(self) -> set[int]:
"""Placeholder implementation of sample getter. Override in subclass."""
return set()
pass

@abstractmethod
def get_forecast_steps(self) -> set[int]:
"""Placeholder implementation forecast step getter. Override in subclass."""
return set()
pass

# TODO: get this from config
@abstractmethod
def get_channels(self, stream: str | None = None) -> list[str]:
"""Placeholder implementation channel names getter. Override in subclass."""
return list()
pass

@abstractmethod
def get_ensemble(self, stream: str | None = None) -> list[str]:
"""Placeholder implementation ensemble member names getter. Override in subclass."""
return list()
pass

def is_gridded_data(self, stream: str) -> bool:
"""
Expand All @@ -126,41 +130,46 @@ def is_gridded_data(self, stream: str) -> bool:
"""
return True

@abstractmethod
def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray:
"""Placeholder to load pre-computed scores for a given run, stream, metric"""
return None
pass

def check_availability(
self,
stream: str,
available_data: dict | None = None,
mode: str = "",
mode: str = "evaluation",
) -> DataAvailability:
"""
Check if requested channels, forecast steps and samples are
i) available in the previously saved metric file if specified (return False otherwise)
ii) available in the source file (e.g. the Zarr file, return error otherwise)
Additionally, if channels, forecast steps or samples is None/'all', it will
i) set the variable to all available vars in source file
ii) return True only if the respective variable contains the same indeces in metric file
and source file (return False otherwise)
ii) return True only if the respective variable contains the same indices in
metric file and source file (return False otherwise)

Parameters
----------
stream :
stream : str
The stream considered.
available_data :
The available data loaded from metric file.
available_data : dict or None
Available data loaded from metric file.
mode : str
Mode string. Can be 'evaluation' or 'plotting'.

Returns
-------
DataAvailability
A dataclass containing:
- channels: list of channels or None if 'all'
- fsteps: list of forecast steps or None if 'all'
- samples: list of samples or None if 'all'
- ensemble: list of ensembleor None if 'all'
"""

# fill info for requested channels, fsteps, samples
# Fill requested info for channels, fsteps, samples, ensemble
requested_data = self._get_channels_fsteps_samples(stream, mode)

channels = requested_data.channels
Expand All @@ -174,7 +183,7 @@ def check_availability(
"ensemble": set(ensemble) if ensemble is not None else None,
}

# fill info from available metric file (if provided)
# Extract available info from metric file (if provided)
available = {
"channel": (
set(available_data["channel"].values.ravel())
Expand All @@ -193,12 +202,12 @@ def check_availability(
),
"ensemble": (
set(available_data["ens"].values.ravel())
if available_data is not None and "ens" in available_data.coords
if (available_data is not None and "ens" in available_data.coords)
else set()
),
}

# fill info from reader
# Extract actual reader data (from source)
reader_data = {
"fstep": set(int(f) for f in self.get_forecast_steps()),
"sample": set(int(s) for s in self.get_samples()),
Expand All @@ -208,24 +217,27 @@ def check_availability(

check_score = True
corrected = False

for name in ["channel", "fstep", "sample", "ensemble"]:
if requested[name] is None:
# Default to all in Zarr
requested[name] = reader_data[name]
# If file with metrics exists, must exactly match
if available_data is not None and reader_data[name] != available[name]:
_logger.info(
f"Requested all {name}s for {mode}, but previous config was a "
"strict subset. Recomputation required."
f"Requested all {name}s for {mode}, but previous config "
"was a strict subset. Recomputation required."
)
check_score = False

# Must be subset of Zarr
if not requested[name] <= reader_data[name]:
missing = requested[name] - reader_data[name]

# Special handling for ensemble mean
if name == "ensemble" and "mean" in missing:
missing.remove("mean")

if missing:
_logger.info(
f"Requested {name}(s) {missing} is unavailable. "
Expand All @@ -238,8 +250,8 @@ def check_availability(
if available_data is not None and not requested[name] <= available[name]:
missing = requested[name] - available[name]
_logger.info(
f"{name.capitalize()}(s) {missing} missing in previous evaluation."
"Recomputation required."
f"{name.capitalize()}(s) {missing} missing in previous "
"evaluation. Recomputation required."
)
check_score = False

Expand Down Expand Up @@ -278,40 +290,53 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili
- fsteps: list of forecast steps or None if 'all'
- samples: list of samples or None if 'all'
"""
assert mode == "plotting" or mode == "evaluation", (
"get_channels_fsteps_samples:: Mode should be either 'plotting' or 'evaluation'"

# Helper function to process range strings like '1-3' into lists [1,2,3]
def _parse_range_list(value, name):
if isinstance(value, str) and value != "all":
assert re.match(r"^\d+-\d+$", value), (
f"String format for {name} in config must be "
f"'digit-digit' or 'all'. "
f"Got '{value}'."
)
start, end = map(int, value.split("-"))
return list(range(start, end + 1))
return value

# Normalize None vs "all"
def normalize(val):
return (
None
if (val == "all" or val is None)
else list(val)
if isinstance(val, list)
else val
)

assert mode in ("plotting", "evaluation"), (
f"Mode must be either 'plotting' or 'evaluation'. Got '{mode}' instead."
)

stream_cfg = self.get_stream(stream)
assert stream_cfg.get(mode, False), "Mode does not exist in stream config. Please add it."
assert stream_cfg.get(mode, False), (
f"Mode '{mode}' does not exist in stream config for '{stream}'. Please add it."
)

samples = stream_cfg[mode].get("sample", None)
fsteps = stream_cfg[mode].get("forecast_step", None)
channels = stream_cfg.get("channels", None)
ensemble = stream_cfg[mode].get("ensemble", None)

if ensemble == "mean":
ensemble = ["mean"]

if isinstance(fsteps, str) and fsteps != "all":
assert re.match(r"^\d+-\d+$", fsteps), (
"String format for forecast_step in config must be 'digit-digit' or 'all'"
)
fsteps = list(range(int(fsteps.split("-")[0]), int(fsteps.split("-")[1]) + 1))
if isinstance(samples, str) and samples != "all":
assert re.match(r"^\d+-\d+$", samples), (
"String format for sample in config must be 'digit-digit' or 'all'"
)
samples = list(range(int(samples.split("-")[0]), int(samples.split("-")[1]) + 1))
if isinstance(ensemble, str) and ensemble not in {"all", "mean"}:
assert re.match(r"^\d+-\d+$", ensemble), (
"String format for sample in config must be 'digit-digit' or 'all'"
)
ensemble = list(range(int(ensemble.split("-")[0]), int(ensemble.split("-")[1]) + 1))
fsteps = _parse_range_list(fsteps, "forecast_step")
samples = _parse_range_list(samples, "sample")

return DataAvailability(
score_availability=True,
channels=None if (channels == "all" or channels is None) else list(channels),
fsteps=None if (fsteps == "all" or fsteps is None) else list(fsteps),
samples=None if (samples == "all" or samples is None) else list(samples),
ensemble=None if (ensemble == "all" or ensemble is None) else list(ensemble),
channels=normalize(channels),
fsteps=normalize(fsteps),
samples=normalize(samples),
ensemble=normalize(ensemble),
)
23 changes: 12 additions & 11 deletions packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,6 @@ def get_data(
if self.is_gridded_data(stream):
vt_list = np.unique(target.valid_time.values).tolist()
valid_times_fs.append(vt_list)
else:
valid_times_fs.append(fstep)

da_tars_fs.append(target.persist())
da_preds_fs.append(pred.persist())
Expand All @@ -462,8 +460,7 @@ def get_data(
)
continue

# fsteps_final.extend(valid_times_fs)
fsteps_final.append(valid_times_fs)
fsteps_final.append(valid_times_fs if valid_times_fs else fstep)

_logger.debug(
f"Concatenating targets and predictions for stream {stream}, "
Expand All @@ -480,11 +477,15 @@ def get_data(
da_preds_fs = _force_consistent_grids(da_preds_fs)
else:
# Irregular (scatter) case. concatenate over ipoint
da_tars_fs = [xr.concat(da_tars_fs, dim="ipoint", coords="minimal")]
da_preds_fs = [xr.concat(da_preds_fs, dim="ipoint", coords="minimal")]
da_tars_fs = xr.concat(
da_tars_fs, dim="ipoint", coords="different", compat="equals"
)
da_preds_fs = xr.concat(
da_preds_fs, dim="ipoint", coords="different", compat="equals"
)

da_tars.append([da for da in da_tars_fs])
da_preds.append([da for da in da_preds_fs])
da_tars.append(da_tars_fs)
da_preds.append(da_preds_fs)

# Safer than a list
da_tars_dict, da_preds_dict = {}, {}
Expand Down Expand Up @@ -616,9 +617,9 @@ def is_gridded_data(self, stream: str) -> bool:
):
_logger.debug("Latitude and/or longitude coordinates are not regularly spaced.")
return False

_logger.debug("Latitude and longitude coordinates are regularly spaced.")
return True
else:
_logger.debug("Latitude and longitude coordinates are regularly spaced.")
return True


################### Helper functions ########################
Expand Down
Loading