From 7efa3677b2b5a212d820fa61f2b1a28148a5339f Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 5 Feb 2026 08:05:54 +0100 Subject: [PATCH 01/20] Add/improve/correct typehints --- src/weathergen/datasets/masking.py | 4 +-- .../datasets/multi_stream_data_sampler.py | 27 ++++++++++------ src/weathergen/datasets/tokenizer_masking.py | 8 +++-- src/weathergen/datasets/tokenizer_utils.py | 32 ++++++++++++------- src/weathergen/model/model.py | 2 +- src/weathergen/train/loss_calculator.py | 2 +- .../train/target_and_aux_module_base.py | 11 +++++-- src/weathergen/train/trainer.py | 13 +++++--- src/weathergen/utils/validation_io.py | 16 +++++++++- 9 files changed, 77 insertions(+), 38 deletions(-) diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index f84111541..fb8615670 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -259,7 +259,7 @@ def build_samples_for_stream( num_cells: int, stage_cfg: dict, stream_cfg: dict, - ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]: + ) -> tuple[MaskData, MaskData, np.typing.NDArray[np.int32]]: """ Construct teacher/student keep masks for a stream. SampleMetaData is currently just a dict with the masking params used. @@ -355,7 +355,7 @@ def build_samples_for_stream( source_target_mapping += [target_idx] i_source += 1 - source_target_mapping = np.array(source_target_mapping, dtype=np.int32) + source_target_mapping: NDArray[np.int32] = np.array(source_target_mapping, dtype=np.int32) return (target_masks, source_masks, source_target_mapping) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 179333b8e..556f6942f 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -9,6 +9,7 @@ import logging import pathlib +import typing import numpy as np import torch @@ -24,7 +25,7 @@ ) from weathergen.datasets.data_reader_fesom import DataReaderFesom from weathergen.datasets.data_reader_obs import DataReaderObs -from weathergen.datasets.masking import Masker +from weathergen.datasets.masking import MaskData, Masker from weathergen.datasets.stream_data import StreamData, spoof from weathergen.datasets.tokenizer_masking import TokenizerMasking from weathergen.datasets.utils import ( @@ -521,7 +522,9 @@ def _build_stream_data( return stream_data - def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, stream_ds): + def _get_data_windows( + self, base_idx, num_forecast_steps, num_steps_input_max, stream_ds + ) -> tuple[list[IOReaderData], list[IOReaderData]]: """ Collect all data needed for current stream to potentially amortize costs by generating multiple samples @@ -578,7 +581,7 @@ def _get_source_target_masks(self, training_mode): Generate source and target masks for all streams """ - masks = {} + masks: dict[str, tuple[MaskData, MaskData, np.NDArray[np.int32]]] = {} for stream_info in self.streams: # Build source and target sample masks masks[stream_info["name"]] = self.tokenizer.build_samples_for_stream( @@ -587,7 +590,7 @@ def _get_source_target_masks(self, training_mode): self.mode_cfg, stream_info, ) - # identical for all streams + # shape identical for all streams: #target/source config x num_samples num_target_samples = len(masks[stream_info["name"]][0]) num_source_samples = len(masks[stream_info["name"]][1]) @@ -619,12 +622,15 @@ def _get_batch(self, idx: int, num_forecast_steps: int): mode = self.mode_cfg.get("training_mode") source_cfgs = self.mode_cfg.get("model_input") - target_cfgs = self.mode_cfg.get("target_input", {}) + target_cfgs: typing.Mapping = self.mode_cfg.get("target_input", {}) # get/coordinate masks masks_streams, num_source_samples, num_target_samples = self._get_source_target_masks(mode) - source_select, target_select = [], [] + # contain string flags + source_select: list[str] = [] + target_select: list[str] = [] + if "masking" in mode: source_select += ["network_input", "target_coords"] target_select += ["target_values"] @@ -649,6 +655,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): for stream_info, (stream_name, stream_ds) in zip( self.streams, self.streams_datasets.items(), strict=True ): + stream_info: Config (target_masks, source_masks, source_to_target) = masks_streams[stream_name] # max number of input steps @@ -672,7 +679,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): for sidx, source_mask in enumerate(source_masks.masks): # Map each source to its target - tidx = source_to_target[sidx].item() + tidx: int = source_to_target[sidx].item() sdata = self._build_stream_data( source_select, tidx, @@ -715,14 +722,14 @@ def _get_batch(self, idx: int, num_forecast_steps: int): ] batch.add_target_stream(tidx, student_indices, stream_name, sdata, target_metadata) - source_in_steps = input_steps.max().item() + source_input_steps: int = input_steps.max().item() target_in_steps = np.array([tc.get("num_steps_input", 1) for _, tc in target_cfgs.items()]) - target_in_steps = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() + target_in_steps: int = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) return batch - def __iter__(self) -> ModelBatch: + def __iter__(self) -> typing.Generator[ModelBatch, typing.Any, typing.Any]: """ Return one batch of data diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 6dfe71c89..375143c21 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -11,8 +11,8 @@ import numpy as np import torch +from weathergen.common.config import Config from weathergen.common.io import IOReaderData -from weathergen.datasets.batch import SampleMetaData from weathergen.datasets.masking import Masker from weathergen.datasets.tokenizer import Tokenizer from weathergen.datasets.tokenizer_utils import ( @@ -53,7 +53,9 @@ def reset_rng(self, rng) -> None: self.masker.reset_rng(rng) self.rng = rng - def get_tokens_windows(self, stream_info, data, pad_tokens): + def get_tokens_windows( + self, stream_info: Config, data: list[IOReaderData], pad_tokens: bool + ) -> list[tuple[list[list[torch.Tensor | None]], list[list[int]]]]: """ Tokenize data (to amortize over the different views that are generated) @@ -83,7 +85,7 @@ def build_samples_for_stream( num_cells: int, stage_cfg: dict, stream_cfg: dict, - ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]: + ): """ Create masks for samples """ diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 3f8f11b9d..1de41ae5e 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -1,3 +1,5 @@ +import typing + import numpy as np import pandas as pd import torch @@ -5,6 +7,7 @@ from torch import Tensor from weathergen.common.io import IOReaderData +from weathergen.datasets.data_reader_base import NPDT64 from weathergen.datasets.utils import ( locs_to_cell_coords_ctrs, locs_to_ctr_coords, @@ -117,7 +120,7 @@ def hpy_cell_splits(coords: torch.tensor, hl: int): def hpy_splits( coords: torch.Tensor, hl: int, token_size: int, pad_tokens: bool -) -> tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor]: +) -> tuple[list[list[torch.Tensor | None]], list[list[int]]]: """Compute healpix cell for each data point and splitting information per cell; when the token_size is exceeded then splitting based on lat is used; tokens can be padded @@ -168,10 +171,10 @@ def hpy_splits( def tokenize_space( - rdata, - token_size, - hl, - pad_tokens=True, + rdata: IOReaderData, + token_size: int, + hl: int, + pad_tokens: bool = True, ): """Process one window into tokens""" @@ -182,7 +185,7 @@ def tokenize_space( def tokenize_spacetime( - rdata, + rdata: IOReaderData, token_size, hl, pad_tokens=True, @@ -192,11 +195,12 @@ def tokenize_spacetime( """ num_healpix_cells = 12 * 4**hl - idxs_cells = [[] for _ in range(num_healpix_cells)] - idxs_cells_lens = [[] for _ in range(num_healpix_cells)] + idxs_cells: list[list[Tensor | None]] = [[] for _ in range(num_healpix_cells)] + idxs_cells_lens: list[list[int]] = [[] for _ in range(num_healpix_cells)] - t_unique = np.unique(rdata.datetimes) + t_unique: typing.iterable[NPDT64] = np.unique(rdata.datetimes) for _, t in enumerate(t_unique): + t: NPDT64 # data for current time step mask = t == rdata.datetimes rdata_cur = IOReaderData( @@ -204,9 +208,12 @@ def tokenize_spacetime( ) idxs_cur, idxs_cur_lens = tokenize_space(rdata_cur, token_size, hl, pad_tokens) - # collect data for all time steps - idxs_cells = [t + tc for t, tc in zip(idxs_cells, idxs_cur, strict=True)] - idxs_cells_lens = [t + tc_l for t, tc_l in zip(idxs_cells_lens, idxs_cur_lens, strict=True)] + # append tokens/n_tokens for current time step for each healpix cell to existing tokens + for cell_steps_tokens, cell_steps_tokens_cur, cell_steps_lens, cell_steps_lens_cur in zip( + idxs_cells, idxs_cur, idxs_cells_lens, idxs_cur_lens, strict=True + ): + cell_steps_tokens.extend(cell_steps_tokens_cur) + cell_steps_lens.extend(cell_steps_lens_cur) return idxs_cells, idxs_cells_lens @@ -382,6 +389,7 @@ def return_empty(rdata, idxs_cells_lens): else: coords_local = torch.tensor([]) + # geoinfos information is empedded into coords_local here return data, datetimes, coords, coords_local, masked_points_per_cell diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 273c28838..e4271bcc3 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -50,7 +50,7 @@ class ModelOutput: Representation of model output """ - physical: list[dict[StreamName, torch.Tensor]] + physical: list[dict[StreamName, tuple[torch.Tensor]]] latent: list[dict[str, torch.Tensor | LatentState]] def __init__(self, len_output: int) -> None: diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 2f81940a3..5c9a38409 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -80,7 +80,7 @@ def __init__( def compute_loss( self, preds: ModelOutput, - targets_and_aux: TargetAuxOutput, + targets_and_aux: dict[str, TargetAuxOutput], metadata: dict, ): losses_all = defaultdict(dict) diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index bb39d1b17..b4733a983 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -13,7 +13,7 @@ import torch -from weathergen.model.engines import LatentState +from weathergen.datasets.batch import SampleMetaData type StreamName = str @@ -26,8 +26,13 @@ class TargetAuxOutput: output_idxs: list[int] - physical: list[dict[StreamName, torch.Tensor]] - latent: list[dict[str, torch.Tensor | LatentState]] + physical: list[dict[StreamName, dict[str, torch.Tensor]]] + latent: list[ + dict[ + StreamName, + dict[str, list[torch.Tensor] | list[dict[str, SampleMetaData]] | list[bool]], + ] + ] aux_outputs: dict[str, torch.Tensor] def __init__(self, len_target: int, output_idxs: list) -> None: diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index bfa289c7b..888117fd8 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -11,6 +11,7 @@ import copy import logging import time +from collections.abc import Iterable import numpy as np import torch @@ -22,6 +23,7 @@ import weathergen.common.config as config from weathergen.common.config import Config +from weathergen.datasets.batch import ModelBatch from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler from weathergen.model.ema import EMAModel from weathergen.model.model_interface import ( @@ -32,6 +34,7 @@ from weathergen.train.collapse_monitor import CollapseMonitor from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( TRAIN, @@ -75,8 +78,8 @@ def __init__(self, train_log_freq: Config): self.model_params = None self.optimizer: torch.optim.Optimizer | None = None self.t_start: float = 0 - self.target_and_aux_calculators = None - self.target_and_aux_calculators_val = None + self.target_and_aux_calculators: dict[str, TargetAndAuxModuleBase] = None + self.target_and_aux_calculators_val: dict[str, TargetAndAuxModuleBase] = None self.validate_with_ema_cfg = None self.validate_with_ema: bool = False self.batch_size_per_gpu = -1 @@ -168,7 +171,7 @@ def get_target_aux_calculators(self, mode_cfg): # get target_aux calculators for different loss terms # del self.cf.training_config.losses["student-teacher"]["loss_fcts"]["JEPA"] # del mode_cfg.losses["student-teacher"]["loss_fcts"]["JEPA"] - target_and_aux_calculators = {} + target_and_aux_calculators: dict[str, TargetAndAuxModuleBase] = {} for loss_name, loss_cfg in mode_cfg.losses.items(): target_and_aux_calculators[loss_name] = get_target_aux_calculator( self.cf, loss_cfg, self.dataset, self.model, self.device, batch_size @@ -549,7 +552,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): cf = self.cf self.model.eval() - dataset_val_iter = iter(self.data_loader_validation) + dataset_val_iter: Iterable[ModelBatch] = iter(self.data_loader_validation) num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size @@ -580,7 +583,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch.get_source_samples(), ) - targets_and_auxs = {} + targets_and_auxs: dict[str, TargetAuxOutput] = {} for loss_name, target_aux in self.target_and_aux_calculators_val.items(): target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) targets_and_auxs[loss_name] = target_aux.compute( diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 0e09fd38d..7150954bf 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import logging +import typing import numpy as np import torch @@ -15,13 +16,26 @@ import weathergen.common.config as config import weathergen.common.io as io from weathergen.common.io import TimeRange, zarrio_writer +from weathergen.datasets.batch import ModelBatch from weathergen.datasets.data_reader_base import TimeWindowHandler +from weathergen.model.model import ModelOutput +from weathergen.train.target_and_aux_module_base import TargetAuxOutput _logger = logging.getLogger(__name__) def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out + cf: config.Config, # only streams const + val_cfg: config.Config, # const + batch_size: int, # const + mini_epoch: int, # get filename + batch_idx: int, # calculate sample_start + dn_data: typing.Callable, # const + batch: ModelBatch, + # contains physical/latent predictions => can ignore latent? + model_output: ModelOutput, + # contains physical/latent targets => can ignore latent? + target_aux_out: dict[str, TargetAuxOutput], ): """ Interface for writing model output From f756e5bab690c8877d4253a7f00f2d01c25563e5 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 5 Feb 2026 08:12:11 +0100 Subject: [PATCH 02/20] remove class attributes from Sample and correctly initialize --- src/weathergen/datasets/batch.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index d106feb08..7380b11f9 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -27,12 +27,13 @@ class SampleMetaData: class Sample: - # keys: stream name, values: SampleMetaData - meta_info: dict[str | SampleMetaData] - # data for all streams - # keys: stream_name, values: StreamData - streams_data: dict[str, StreamData | None] + def __init__(self, streams: dict) -> None: + self.meta_info: dict[str, SampleMetaData] = {} + self.streams_data: dict[str, StreamData | None] = {} + + for stream_info in streams: + self.streams_data[stream_info["name"]] = None def pin_memory(self): """Pin all tensors in this Sample to CPU pinned memory""" @@ -53,13 +54,6 @@ def pin_memory(self): return self - def __init__(self, streams: dict) -> None: - self.meta_info = {} - - self.streams_data = {} - for stream_info in streams: - self.streams_data[stream_info["name"]] = None - def to_device(self, device) -> None: for key in self.meta_info.keys(): self.meta_info[key].mask = ( From 05a0c0c39beefcd5c497cff9b2f728a0031f4b29 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Fri, 27 Feb 2026 11:23:16 +0100 Subject: [PATCH 03/20] refactor batch.py: constructor, typehints docstrings. --- src/weathergen/datasets/batch.py | 81 +++++++++++++++++++------------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 7380b11f9..55d150f87 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -5,6 +5,7 @@ - Model data (StreamData objects containing tensors) - View metadata (spatial masks, strategies, relationships) """ +from __future__ import annotations # allow forward references in typehints import copy from dataclasses import dataclass @@ -189,40 +190,52 @@ def pin_memory(self): class ModelBatch: """ Container for all data and metadata for one training batch. - """ - - # source samples (for model) - source_samples: BatchSamples - - # target samples (for TargetAuxCalculator) - target_samples: BatchSamples - - # index of corresponding target (for source samples) or source (for target samples) - # these are in 1-to-1 corresponding for classical training modes (e.g. MTM, forecasting) but - # can be more complex for strategies like student-teacher training - source2target_matching_idxs: np.typing.NDArray[np.int32] - target2source_matching_idxs: np.typing.NDArray[np.int32] - # indices of valid outputs - output_idxs: list[int] - - # device of the tensors in the batch - device: str | torch.device + The data a instance contains is associated with one particular inital sampling window. + From this initial data multiple forecast windows are derived by offsetting the initial window. + Multiple samples for one forecast window can be generated + by sampling differently masked versions of the data. + Note that `output_offset` is the difference in forecast steps + between corresponding source and target samples. + Source and target samples are in 1-to-1 correspondance for classical training modes + (e.g. MTM, forecasting), but can be more complex for strategies like student-teacher training. + This relationship is expressed in `source2target_matching_idxs` + and `target2source_matching_idxs`. + + + Attributes: + source_samples: sources for model + target_samples: targets for TargetAuxCalculator + source2target_matching_idxs: index of corresponding target indices. + target2source_matching_idxs: index of corresponding source indices. + output_idxs: Forecast step indices (including offset) for this batch. The number of + forecast steps is constant per batch, this amortized through data parallel training. + device: device of the tensors in this batch. + """ def __init__( self, streams: dict, num_source_samples: int, num_target_samples: int, - output_offset, - output_steps, + output_offset: int, + output_steps: int, ) -> None: - """ """ + """ + Initialize new ModelBatch. + + Args: + streams: global StreamConfig. + num_source_samples: Number of differently masked source samples for one input window. + num_target_samples: Number of differently masked target samples for one input window. + output_offset: forecast offset for this batch. + output_steps: number of forecast steps for this batch. + """ # define forecast indices - self.output_offset = output_offset - self.output_steps = output_steps - self.output_idxs = list(range(output_offset, output_steps)) + self.output_offset: BatchSamples = output_offset + self.output_steps: BatchSamples = output_steps + self.output_idxs: list[int] = list(range(output_offset, output_steps)) self.source_samples = BatchSamples( streams, num_source_samples, output_steps, self.output_idxs @@ -231,10 +244,15 @@ def __init__( streams, num_target_samples, output_steps, self.output_idxs ) - self.source2target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) - self.target2source_matching_idxs = [[] for _ in range(num_target_samples)] + self.source2target_matching_idxs: np.typing.NDArray[np.int32] = np.full( + num_source_samples, -1, dtype=np.int32 + ) + self.target2source_matching_idxs: np.typing.NDArray[np.int32] = [ + [] for _ in range(num_target_samples) + ] + self.device: torch.device | None = None - def pin_memory(self): + def pin_memory(self) -> ModelBatch: """Pin all tensors in this batch to CPU pinned memory""" # pin source samples @@ -245,15 +263,14 @@ def pin_memory(self): return self - def to_device(self, device): # -> ModelBatch + def to_device(self, device: str | torch.device) -> ModelBatch: """ Move batch to device """ + self.device = torch.device(device) - self.source_samples.to_device(device) - self.target_samples.to_device(device) - - self.device = device + self.source_samples.to_device(self.device) + self.target_samples.to_device(self.device) return self From 6d1f3598e7babc35bd81fc5eb34dd4f113f61fb6 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 14 Feb 2026 22:11:09 +0100 Subject: [PATCH 04/20] Improve error handling --- .../datasets/multi_stream_data_sampler.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 556f6942f..0a9a9de76 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -723,9 +723,17 @@ def _get_batch(self, idx: int, num_forecast_steps: int): batch.add_target_stream(tidx, student_indices, stream_name, sdata, target_metadata) source_input_steps: int = input_steps.max().item() - target_in_steps = np.array([tc.get("num_steps_input", 1) for _, tc in target_cfgs.items()]) - target_in_steps: int = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() - batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) + + try: + target_input_steps: int = max( + target_cfg.get("num_steps_input", 1) for target_cfg in target_cfgs.values() + ) + except ValueError: + # empty list produces value error when taking max + target_input_steps: int = 1 + + # only adds tokens lens as attribute to batch + batch = self._preprocess_model_batch(batch, source_input_steps, target_input_steps) return batch From bdf7bd6fac56f694e66886d2eb7b081b2ccb1eab Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 19 Feb 2026 11:42:14 +0100 Subject: [PATCH 05/20] cleanup contructor of TargetAuxOutput --- .../train/target_and_aux_module_base.py | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index b4733a983..a1f5d5e14 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -9,37 +9,28 @@ from __future__ import annotations -import dataclasses - import torch from weathergen.datasets.batch import SampleMetaData +from weathergen.model.engines import LatentState type StreamName = str -@dataclasses.dataclass class TargetAuxOutput: """ A dataclass to encapsulate the TargetAndAuxCalculator output and give a clear API. """ - - output_idxs: list[int] - - physical: list[dict[StreamName, dict[str, torch.Tensor]]] - latent: list[ - dict[ - StreamName, - dict[str, list[torch.Tensor] | list[dict[str, SampleMetaData]] | list[bool]], - ] - ] - aux_outputs: dict[str, torch.Tensor] - def __init__(self, len_target: int, output_idxs: list) -> None: - self.output_idxs = output_idxs - self.physical = [{} for _ in range(len_target)] - self.latent = [{} for _ in range(len_target)] - self.aux_outputs = {} + self.output_idxs: list[int] = output_idxs + self.physical: list[ + dict[ + StreamName, + dict[str, list[torch.Tensor] | list[dict[str, SampleMetaData]] | list[bool]], + ] + ] = [{} for _ in range(len_target)] + self.latent: list[dict[str, torch.Tensor | LatentState]] = [{} for _ in range(len_target)] + self.aux_outputs: dict[str, torch.Tensor] = {} def add_physical_target( self, timestep_idx: int, stream_name: StreamName, pred: torch.Tensor From b3ee0f8d196f7fc800e8312f0be9687ced79f3c8 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 19 Feb 2026 11:43:36 +0100 Subject: [PATCH 06/20] Make logic clearer by using simpler control flow --- src/weathergen/datasets/tokenizer_masking.py | 26 ++++++++++++-------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 375143c21..a85b1a37d 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -60,21 +60,27 @@ def get_tokens_windows( Tokenize data (to amortize over the different views that are generated) """ - - tok_spacetime = stream_info.get("tokenize_spacetime", False) - tok = tokenize_spacetime if tok_spacetime else tokenize_space - hl = self.healpix_level - token_size = stream_info["token_size"] - - tokens = [] + tokens: list[tuple[list[list[torch.Tensor | None]], list[list[int]]]] = [] for rdata in data: # skip empty data if rdata.is_empty(): continue # tokenize data - idxs_cells, idxs_cells_lens = tok( - readerdata_to_torch(rdata), token_size, hl, pad_tokens - ) + if stream_info.get("tokenize_spacetime", False): + idxs_cells, idxs_cells_lens = tokenize_spacetime( + readerdata_to_torch(rdata), + stream_info["token_size"], + self.healpix_level, + pad_tokens, + ) + else: + idxs_cells, idxs_cells_lens = tokenize_space( + readerdata_to_torch(rdata), + stream_info["token_size"], + self.healpix_level, + pad_tokens, + ) + tokens += [(idxs_cells, idxs_cells_lens)] return tokens From abd2b72a8fc0e55777375c0ad247e4e55be62b17 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 19 Feb 2026 15:29:48 +0100 Subject: [PATCH 07/20] get path run with run_id only --- packages/common/src/weathergen/common/config.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index f0243a717..9f331af57 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -634,9 +634,14 @@ def load_streams(streams_directory: Path) -> list[Config]: return list(streams.values()) -def get_path_run(config: Config) -> Path: +def get_path_run(config: Config | None = None, run_id: str | None = None) -> Path: """Get the current runs results_path for storing run results and logs.""" - return _get_shared_wg_path() / "results" / get_run_id_from_config(config) + if config or run_id: + run_id = run_id if run_id else get_run_id_from_config(config) + else: + msg = f"Missing run_id and cannot infer it from config: {config}" + raise ValueError(msg) + return _get_shared_wg_path() / "results" / run_id def get_path_model(config: Config | None = None, run_id: str | None = None) -> Path: From 836a1884154f9bf27e6995fd6472b1add0c03e9b Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 5 Feb 2026 08:13:54 +0100 Subject: [PATCH 08/20] add geoinfo information to streams config during stream initialization --- src/weathergen/datasets/multi_stream_data_sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 0a9a9de76..e92e18d07 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -198,6 +198,7 @@ def __init__( if ds.target_channel_weights is not None else [1.0 for _ in ds.target_channels] ) + stream_info["geoinfo_channels"] = ds.geoinfo_channels self.streams_datasets[stream_info["name"]] += [ds] @@ -682,7 +683,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): tidx: int = source_to_target[sidx].item() sdata = self._build_stream_data( source_select, - tidx, + tidx, # Why use not idx here? num_forecast_steps, stream_info, source_masks.metadata[sidx].params.get("num_steps_input", 1), From 7cd13591db996f3c9c9e7cf8c01871e26d139f5e Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 5 Feb 2026 08:57:12 +0100 Subject: [PATCH 09/20] add property to sample to easily retrieve sample_idx --- src/weathergen/datasets/batch.py | 43 ++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index 55d150f87..a0de474b6 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -36,6 +36,24 @@ def __init__(self, streams: dict) -> None: for stream_info in streams: self.streams_data[stream_info["name"]] = None + @property + def sample_idx(self) -> int: + if not self.is_empty(): + # choose 'first' streams stream_data + stream_data = next(iter(self.streams_data.values())) + # TODO check: is reliable sinsce? + # since set with tidx and not idx in MultiStreamDataSampler + idx = stream_data.sample_idx + + assert all( + stream_data.sample_idx == idx for stream_data in self.streams_data.values() + ), "sample_idx should be identical for all streams within one sample." + + return idx + else: + msg = "Cannot infer sample_idx without any streams added." + raise RuntimeError(msg) + def pin_memory(self): """Pin all tensors in this Sample to CPU pinned memory""" @@ -129,6 +147,10 @@ def to_device(self, device): return self + @property + def sample_idxs(self) -> list[int]: + return [sample.sample_idx for sample in self.samples] + def get_samples(self) -> list[Sample]: return self.samples @@ -144,6 +166,7 @@ def get_subset(self, subset: list | None = None): bs.tokens_lens = torch.index_select(bs.tokens_lens, 1, torch_idxs) return bs + # unused def get_num_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams @@ -155,24 +178,28 @@ def get_num_steps(self) -> int: return min(lens) + # unused def get_output_idxs(self) -> int: """ Get forecast indices """ return self.output_idxs + # unused def get_output_len(self) -> int: """ Get length of output """ return self.output_steps + # unused def get_device(self) -> str | torch.device: """ Get device of tensors in the batch """ return self.device + # unused def pin_memory(self): """Pin all tensors in this batch to CPU pinned memory""" @@ -274,6 +301,7 @@ def to_device(self, device: str | torch.device) -> ModelBatch: return self + # used once in MultiStreamDataSampler def add_source_stream( self, source_sample_idx: int, @@ -293,6 +321,7 @@ def add_source_stream( assert target_sample_idx < len(self.target_samples), "invalid value for target_sample_idx" self.source2target_matching_idxs[source_sample_idx] = target_sample_idx + # used once in MultiStreamDataSampler def add_target_stream( self, target_sample_idx: int, @@ -319,6 +348,7 @@ def add_target_stream( ) self.target2source_matching_idxs[target_sample_idx] = source_sample_idx + # used once in MultiStreamDataSampler def is_empty(self): """ Check if batch is empty @@ -331,72 +361,84 @@ def is_empty(self): ) return source_empty or target_empty + # unused def len_sources(self) -> int: """ Number of source samples """ return len(self.source_samples) + # unused def len_targets(self) -> int: """ Number of target samples """ return len(self.target_samples) + # unused def get_source_sample(self, idx: int) -> Sample: """ Get a source sample """ return self.source_samples.samples[idx] + # used in validation_io only def get_source_samples(self, subset: list | None = None) -> BatchSamples: """ Get source samples """ return self.source_samples.get_subset(subset) + # unused def get_target_sample(self, idx: int) -> Sample: """ Get a target sample """ return self.target_samples.samples[idx] + # unused def get_target_samples(self, subset: list | None = None) -> BatchSamples: """ Get target samples """ return self.target_samples.get_subset(subset) + # unsused def get_source_idx_for_target(self, target_idx: int) -> int: """ Get index of source sample for a given target sample index """ return int(self.target2source_matching_idxs[target_idx]) + # unused def get_target_idx_for_source(self, source_idx: int) -> int: """ Get index of target sample for a given source sample index """ return int(self.source2target_matching_idxs[source_idx]) + # unused def get_output_idxs(self) -> int: """ Get valid output steps """ return self.output_idxs + # unused def get_output_len(self) -> int: """ Get length of output """ return self.output_steps + # unused def get_device(self) -> str | torch.device: """ Get device of tensors in the batch """ return self.device + # unused def get_num_source_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams @@ -409,6 +451,7 @@ def get_num_source_steps(self) -> int: return min(lens) + # unused def get_num_target_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams From 128b09004863f1fcd0936f2bbf1a16f4fce35686 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 19 Feb 2026 13:32:27 +0100 Subject: [PATCH 10/20] notes --- integration_tests/test_output.py | 30 +++++++++++++++++++ .../datasets/multi_stream_data_sampler.py | 5 +++- src/weathergen/datasets/stream_data.py | 2 +- .../train/target_and_aux_module_base.py | 2 ++ 4 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 integration_tests/test_output.py diff --git a/integration_tests/test_output.py b/integration_tests/test_output.py new file mode 100644 index 000000000..8f4a65388 --- /dev/null +++ b/integration_tests/test_output.py @@ -0,0 +1,30 @@ +import pytest + +# test cases +# 1: num_input_steps > 1, num_steps > 1 (1527) +# => what sources get written (multiple sources without targets?) => test with offset=1 +# 2: inference on jepa (1759) (jepa_wael) +# 3: jepa pretraining on era5, finetuning/inference on synop (1736) +# 4: guarantee outputput item is never overwritten (1575) +# 5: allow subsetting (for io) channels used as targets (1705) +# forecast offset 0/1 +# predictions without targets/sources +# allow incremental writes +# allow non continouos fsteps +# always have source at fstep=0 + + + + +@pytest.fixture +def output_items(): + pass + +def test_target_source_identity(offset): + pass + +def test_time_coordinates(output_items): + pass + +def test_spatial_coordinates(output_items): + pass \ No newline at end of file diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index e92e18d07..c8b93192c 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -479,6 +479,8 @@ def _build_stream_data( modes : stream_data : base_idx: Time index for this sample + => no its tidx (index of n-th target mask, no info of the time index for this sample + was used in creating tidx) num_forecast_steps: Number of forecast steps stream_info: Stream configuration dict stream_ds: List of dataset readers for this stream @@ -683,7 +685,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): tidx: int = source_to_target[sidx].item() sdata = self._build_stream_data( source_select, - tidx, # Why use not idx here? + tidx, # Why use not idx here? num_forecast_steps, stream_info, source_masks.metadata[sidx].params.get("num_steps_input", 1), @@ -749,6 +751,7 @@ def __iter__(self) -> typing.Generator[ModelBatch, typing.Any, typing.Any]: logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={self.len}") # create new shuffeling + # should return permutations instead of relying on self.perms self.reset() # bidx is used to count the #batches that have been emitted diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 152c092b2..31549644f 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -196,7 +196,7 @@ def add_source( idx = torch.isnan(self.source_tokens_cells[step]) self.source_tokens_cells[step][idx] = self.mask_value - def add_target( + def add_target( # TODO what is the diffference to add_target_values??? self, fstep: int, targets: list, diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index a1f5d5e14..920f6e807 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -37,9 +37,11 @@ def add_physical_target( ) -> None: self.physical[timestep_idx][stream_name] = pred + # currently unused? def add_latent_target(self, timestep_idx: int, latent_name: str, pred: torch.Tensor) -> None: self.latent[timestep_idx][latent_name] = pred + # is broken => pred[sample_idx] must be pred["target"][sample_idx] def get_physical_target( self, timestep_idx: int, From 5b88e0d979194bab125574aa9b53fdd43160befe Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 19 Feb 2026 13:34:00 +0100 Subject: [PATCH 11/20] Add methods to retrieve normalized targets/predictios via ItemKeys --- src/weathergen/model/model.py | 24 ++++++++++ .../train/target_and_aux_module_base.py | 47 +++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index e4271bcc3..f51b8f42a 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -11,6 +11,7 @@ import logging import math +import typing import warnings import astropy_healpix as hp @@ -20,6 +21,7 @@ import torch.nn as nn from weathergen.common.config import Config +from weathergen.common.io import ItemKey from weathergen.datasets.batch import ModelBatch from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.encoder import EncoderModule @@ -76,6 +78,28 @@ def get_physical_prediction( pred = pred[sample_idx] return pred + def get_physical_prediction_normalized( + self, key: ItemKey, normalizer: typing.Callable + ) -> np.typing.NDArray: + try: + # TODO why is there a tuple => what index should be used + pred = self.physical[key.forecast_step][key.stream][0][key.sample] + except (KeyError, IndexError) as e: + msg = f"Cannot find prediction data for key: {key}" + raise ValueError(msg) from e + + # is it a performance issue if I dont convert/move the entire tensor at once? + pred = pred.to(torch.float32).detach().cpu().numpy() + pred = normalizer(key.stream, pred) + + assert isinstance(pred, np.ndarray), "Invalid buffer type." + # TODO What to do when preds are empty,when does it occur, + # how does it show (empty tensor or missing key?) + assert len(pred) > 0 + + # breakpoint() + return pred + def get_latent_prediction(self, fstep: int): return self.latent[fstep] diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index 920f6e807..e399ada07 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -9,18 +9,32 @@ from __future__ import annotations +import dataclasses +import typing + +import numpy as np import torch from weathergen.datasets.batch import SampleMetaData from weathergen.model.engines import LatentState type StreamName = str +BufferType = np.float32 +DateTimeType = np.datetime64 + + +@dataclasses.dataclass +class PhysicalTarget: + data: np.typing.NDArray[BufferType] + coords: np.typing.NDArray[BufferType] + datetimes: np.typing.NDArray[DateTimeType] class TargetAuxOutput: """ A dataclass to encapsulate the TargetAndAuxCalculator output and give a clear API. """ + def __init__(self, len_target: int, output_idxs: list) -> None: self.output_idxs: list[int] = output_idxs self.physical: list[ @@ -56,6 +70,39 @@ def get_physical_target( pred = pred[sample_idx] return pred + # TODO guarantee every buffer retrieved only once + def get_physical_target_normalized( + self, key: ItemKey, normalizer: typing.Callable + ) -> PhysicalTarget: + try: + stream_targets = self.physical[key.forecast_step][key.stream] + except (KeyError, IndexError) as e: + msg = f"Cannot find physical target data for key: {key}" + raise ValueError(msg) from e + + # is it a performance issue if I dont convert/move the entire tensor at once? + coords = ( + stream_targets["target_coords"][key.sample].to(torch.float32).detach().cpu().numpy() + ) + times = stream_targets["target_times"][key.sample] + + data = stream_targets["target"][key.sample].to(torch.float32).detach().cpu().numpy() + data = normalizer(key.stream, data) + + assert isinstance(data, np.ndarray), "Invalid data buffer type." + assert isinstance(coords, np.ndarray), "Invalid coords buffer type." + assert isinstance(times, np.ndarray), "Invalid datetimes buffer type." + + if len(coords) == 0: # TODO can this be removed? + coords = np.zeros((0, 2), dtype=np.float32) + assert len(coords.shape) >= 2 and coords.shape[-1] >= 2, ( + "invalid shape for coordinate buffer." + ) + assert data.shape[0] == coords.shape[0] == times.shape[0], "buffer shapes should align." + + # breakpoint() + return PhysicalTarget(data, coords, times) + def get_latent_target(self, timestep_idx: int): return self.latent[timestep_idx] From b13962c5a406dd51a9b82aab221d25e69f087510 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 19 Feb 2026 13:40:46 +0100 Subject: [PATCH 12/20] implement new output.Writer class replacing io.OutputBatchData --- src/weathergen/utils/output.py | 210 +++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 src/weathergen/utils/output.py diff --git a/src/weathergen/utils/output.py b/src/weathergen/utils/output.py new file mode 100644 index 000000000..1e39aeecc --- /dev/null +++ b/src/weathergen/utils/output.py @@ -0,0 +1,210 @@ +from __future__ import annotations # allow forward references in typehints + +import dataclasses +import itertools +import typing +from collections.abc import Callable + +import numpy as np +from omegaconf.errors import ConfigAttributeError + +from weathergen.common.config import Config, get_path_results +from weathergen.common.io import ( + ItemKey, + OutputDataset, + OutputItem, + TimeRange, + zarrio_writer, +) +from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.data_reader_base import TimeWindowHandler +from weathergen.model.model import ModelOutput +from weathergen.train.target_and_aux_module_base import PhysicalTarget, TargetAuxOutput + + +class Writer: + def __init__( + self, + config: Config, + val_cfg: Config, + streams: list[Config], + ): + streams = {stream.name: stream for stream in streams} + # TODO: nice: dont store all config + self._twh = TimeWindowHandler( + val_cfg.start_date, + val_cfg.end_date, + val_cfg.time_window_len, + val_cfg.time_window_step, + ) + + _all_streams = list(streams.keys()) + _output_streams = val_cfg.get("output", None).get("streams", None) + _output_streams = val_cfg.output.streams if val_cfg.output.streams else _all_streams + + self._streams = { + name: config for name, config in streams.items() if name in _output_streams + } + self._forecast_offset = val_cfg.forecast.offset + self._cf: Config = config # used for zarr output path lookup => improve + + def write_batch( + self, + mini_epoch: int, # TODO: nice: use iterstep for better consistency? + batch: ModelBatch, + targets: TargetAuxOutput, + predictions: ModelOutput, + normalizer: typing.Callable, + fsteps: range, + ) -> None: + data = _BatchOutputData(batch, targets, predictions, normalizer) + # TODO: nice: get result path differently + with zarrio_writer(get_path_results(self._cf, mini_epoch)) as zio: + for subset in self.itemize(data, fsteps): + zio.write_zarr(subset) + + def itemize( + self, data: _BatchOutputData, fstep_range: range + ) -> typing.Generator[OutputItem, None, None]: + """Iterate over possible output items""" + + # TODO: check: filter for empty items? + for key in self.keys(fstep_range, data.samples): + yield self.extract(data, key) + + def keys( + self, forecast_steps: range, samples: list[int] + ) -> typing.Generator[ItemKey, None, None]: + """Iterate over possible output items""" + streams: list[str] = self._streams.keys() + + # The order of iteration is important here: + # streams is the outermost and samples the innermost loop variable + # This is important since normalization is best done at a per stream/fstep basis + for stream, forecast_step, sample_idx in itertools.product( + streams, forecast_steps, samples + ): + yield ItemKey(sample_idx, forecast_step, stream) + + def extract(self, data: _BatchOutputData, key: ItemKey) -> OutputItem: + data_invariants = self._get_invariants(key) + + source, target, prediction = None, None, None + if key.with_source: + source = data.extract_source(key).as_dataset(key, data_invariants) + if key.with_target(self._forecast_offset): + target = data.extract_target(key).as_dataset(key, data_invariants) + prediction = data.extract_prediction(key).as_dataset(key, data_invariants) + + return OutputItem( + key, + self._forecast_offset, # TODO nice: maybe drop it? + target, + prediction, + source + ) + + def _get_invariants(self, key: ItemKey) -> _DataInvariants: + # TODO unify DTRange and TimeRange classes + window = self._twh.window(key.sample) + return _DataInvariants( + source_interval=TimeRange(window.start, window.end), + # val_source_channels are ListConfig[str] objects -> convert to list[str] + source_channels=list(self._streams[key.stream].val_source_channels), + target_channels=list(self._streams[key.stream].val_target_channels), + geoinfo_channels=list(self._streams[key.stream].geoinfo_channels), + ) + + +@dataclasses.dataclass +class _BatchOutputData: + _batch: ModelBatch + _targets: TargetAuxOutput + _predictions: ModelOutput + _normalizer: Callable + + def extract_source(self, key: ItemKey) -> _ExtractedData: + # TODO check this? + # breakpoint() + READER_DATA_INDEX_MYSTERY = 0 + source = ( + self._batch.source_samples.samples[key.forecast_step] + .streams_data[key.stream] + .source_raw[READER_DATA_INDEX_MYSTERY] + ) + + return _ExtractedData( + "prediction", + np.asarray(source.data), + np.asarray(source.datetimes), + np.asarray(source.coords), + np.asarray(source.geoinfos), + ) + + @property + def samples(self): + # TODO check: data._batch.source_samples.samples + return self._batch.source_samples.sample_idxs + + def extract_target(self, key: ItemKey) -> _ExtractedData: + target = self._target(key) + coords = self._target_coordinates(target) + + return _ExtractedData("target", target.data, coords.times, coords.coords, coords.geoinfos) + + def extract_prediction(self, key: ItemKey) -> _ExtractedData: + try: + data = self._predictions.get_physical_prediction_normalized(key, self._normalizer) + except Exception as e: + # TODO: if preds are empty so create copy of target and add ensemble dimension + # preds = [targets[0].clone().unsqueeze(0)] + raise ValueError("not handled yet") from e + data = self._target(key) + target = self._target(key) + coords = self._target_coordinates(target) + + return _ExtractedData("prediction", data, coords.times, coords.coords, coords.geoinfos) + + # TODO guarantee this method is only called once per OutputItem + # TODO try getting targets from batch directly + def _target(self, key: ItemKey) -> PhysicalTarget: + return self._targets.get_physical_target_normalized(key, self._normalizer) + + def _target_coordinates(self, target: PhysicalTarget) -> _ExtractedData: + coords = target.coords[..., :2] # first two columns are lat,lon + geoinfos = target.coords[..., 2:] # the rest is geoinfo => potentially empty + + return _ExtractedData("", None, target.datetimes, coords, geoinfos) + + +@dataclasses.dataclass +class _DataInvariants: + source_interval: TimeRange + source_channels: list[str] + target_channels: list[str] + geoinfo_channels: list[str] + + +@dataclasses.dataclass +class _ExtractedData: + name: str # TODO make enum + data: typing.Any + times: typing.Any + coords: typing.Any + geoinfos: typing.Any + + def as_dataset(self, key: ItemKey, invariants: _DataInvariants) -> OutputDataset: + if self.data is None or self.data.shape == (0, 0): + return None + else: + return OutputDataset( + name=self.name, + item_key=key, + data=self.data, + times=self.times, + coords=self.coords, + geoinfo=self.geoinfos, + source_interval=invariants.source_interval, + channels=invariants.target_channels, + geoinfo_channels=invariants.geoinfo_channels, + ) From 5c351e5da80d50b5532a949506177055aa196c4c Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 19 Feb 2026 14:04:08 +0100 Subject: [PATCH 13/20] Use new output.Writer for writing output --- src/weathergen/train/trainer.py | 41 ++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 888117fd8..9bdd2f86d 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -48,6 +48,7 @@ get_target_idxs_from_cfg, ) from weathergen.utils.distributed import is_root +from weathergen.utils.output import Writer from weathergen.utils.train_logger import TrainLogger, prepare_losses_for_logging from weathergen.utils.utils import get_dtype from weathergen.utils.validation_io import write_output @@ -55,6 +56,7 @@ logger = logging.getLogger(__name__) # cfg_keys_to_filter = ["losses", "model_input", "target_input"] +PHYSICAL_LOSS_KEY = "physical" class Trainer(TrainerBase): @@ -67,7 +69,8 @@ def __init__(self, train_log_freq: Config): self.data_loader_validation: torch.utils.data.DataLoader | None = None self.dataset: MultiStreamDataSampler | None = None self.dataset_val: MultiStreamDataSampler | None = None - self.device: torch.device = None + self.output_writer: Writer | None = None + self.device: torch.device | None = None self.ema_model = None self.grad_scaler: torch.amp.GradScaler | None = None self.last_grad_norm = None @@ -93,7 +96,7 @@ def get_batch_size_total(self, batch_size_per_gpu) -> int: """ return self.world_size_original * batch_size_per_gpu - def init(self, cf: Config, devices): + def init(self, cf: Config, devices, use_test_config=False): # pylint: disable=attribute-defined-outside-init self.cf = OmegaConf.merge( OmegaConf.create( @@ -156,6 +159,10 @@ def init(self, cf: Config, devices): config.get_path_model(cf).mkdir(exist_ok=True, parents=True) self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + if use_test_config: + self.output_writer = Writer(cf, self.test_cfg, cf.streams) + else: + self.output_writer = Writer(cf, self.validation_cfg, cf.streams) # Initialize collapse monitor for SSL training collapse_config = cf.train_logging.get("collapse_monitoring", {}) @@ -181,7 +188,7 @@ def get_target_aux_calculators(self, mode_cfg): def inference(self, cf, devices, run_id_contd, mini_epoch_contd): # general initalization - self.init(cf, devices) + self.init(cf, devices, True) cf = self.cf device_type = torch.accelerator.current_accelerator() @@ -607,17 +614,23 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if mode_cfg.get("output", {}).get("normalized_samples", False) else self.dataset_val.denormalize_target_channels ) - # write output - write_output( - self.cf, - mode_cfg, - batch_size, - mini_epoch, - bidx, - denormalize_data_fct, - batch, - preds, - targets_and_auxs, + + fsteps = range(0, 1) # TODO + targets = [ + targets + for loss_name, targets in targets_and_auxs.items() + if loss_name == PHYSICAL_LOSS_KEY + ] + try: + # targets for all physical outputs should be the same + targets = targets[0] + except IndexError as e: + raise ValueError( + f"No physical outputs under key {PHYSICAL_LOSS_KEY} configured" + ) from e + + self.output_writer.write_batch( + mini_epoch, batch, targets, preds, denormalize_data_fct, fsteps ) pbar.update(batch_size) From 155c955ef93e4f4663136f82d569429bd2cbb004 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 25 Feb 2026 00:52:39 +0100 Subject: [PATCH 14/20] handle incremental sample writes --- src/weathergen/utils/output.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/weathergen/utils/output.py b/src/weathergen/utils/output.py index 1e39aeecc..65093ce4c 100644 --- a/src/weathergen/utils/output.py +++ b/src/weathergen/utils/output.py @@ -47,6 +47,7 @@ def __init__( } self._forecast_offset = val_cfg.forecast.offset self._cf: Config = config # used for zarr output path lookup => improve + self._sample_start = 0 def write_batch( self, @@ -77,6 +78,8 @@ def keys( ) -> typing.Generator[ItemKey, None, None]: """Iterate over possible output items""" streams: list[str] = self._streams.keys() + len_samples = len(samples) + samples = (self._sample_start + sample for sample in samples) # The order of iteration is important here: # streams is the outermost and samples the innermost loop variable @@ -85,16 +88,19 @@ def keys( streams, forecast_steps, samples ): yield ItemKey(sample_idx, forecast_step, stream) + + self._sample_start += len_samples def extract(self, data: _BatchOutputData, key: ItemKey) -> OutputItem: - data_invariants = self._get_invariants(key) + raw_key = ItemKey(key.sample-self._sample_start, key.forecast_step, key.stream) + data_invariants = self._get_invariants(raw_key) source, target, prediction = None, None, None if key.with_source: - source = data.extract_source(key).as_dataset(key, data_invariants) + source = data.extract_source(raw_key).as_dataset(key, data_invariants) if key.with_target(self._forecast_offset): - target = data.extract_target(key).as_dataset(key, data_invariants) - prediction = data.extract_prediction(key).as_dataset(key, data_invariants) + target = data.extract_target(raw_key).as_dataset(key, data_invariants) + prediction = data.extract_prediction(raw_key).as_dataset(key, data_invariants) return OutputItem( key, From 9ee68ac34e0f8e836198c4cdae13730605c6fd60 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 25 Feb 2026 00:51:06 +0100 Subject: [PATCH 15/20] handle incremental fstep writes --- src/weathergen/train/trainer.py | 4 ++-- src/weathergen/utils/output.py | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 9bdd2f86d..f58ed7dd1 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -615,7 +615,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): else self.dataset_val.denormalize_target_channels ) - fsteps = range(0, 1) # TODO + fstep_start = 0 targets = [ targets for loss_name, targets in targets_and_auxs.items() @@ -630,7 +630,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): ) from e self.output_writer.write_batch( - mini_epoch, batch, targets, preds, denormalize_data_fct, fsteps + mini_epoch, batch, targets, preds, denormalize_data_fct, fstep_start ) pbar.update(batch_size) diff --git a/src/weathergen/utils/output.py b/src/weathergen/utils/output.py index 65093ce4c..40f401563 100644 --- a/src/weathergen/utils/output.py +++ b/src/weathergen/utils/output.py @@ -56,22 +56,24 @@ def write_batch( targets: TargetAuxOutput, predictions: ModelOutput, normalizer: typing.Callable, - fsteps: range, + fstep_start: int, ) -> None: data = _BatchOutputData(batch, targets, predictions, normalizer) + fstep_range = range(fstep_start, len(predictions.physical)) # TODO: nice: get result path differently with zarrio_writer(get_path_results(self._cf, mini_epoch)) as zio: - for subset in self.itemize(data, fsteps): + for subset in self.itemize(data, fstep_range): zio.write_zarr(subset) def itemize( self, data: _BatchOutputData, fstep_range: range ) -> typing.Generator[OutputItem, None, None]: """Iterate over possible output items""" + fstep_start = min(fstep_range) # TODO: check: filter for empty items? for key in self.keys(fstep_range, data.samples): - yield self.extract(data, key) + yield self.extract(data, key, fstep_start) def keys( self, forecast_steps: range, samples: list[int] @@ -91,8 +93,8 @@ def keys( self._sample_start += len_samples - def extract(self, data: _BatchOutputData, key: ItemKey) -> OutputItem: - raw_key = ItemKey(key.sample-self._sample_start, key.forecast_step, key.stream) + def extract(self, data: _BatchOutputData, key: ItemKey, fstep_start: int) -> OutputItem: + raw_key = ItemKey(key.sample-self._sample_start, key.forecast_step-fstep_start, key.stream) data_invariants = self._get_invariants(raw_key) source, target, prediction = None, None, None From 0e2098932fc6e3ac3b475a820b3744ede105fd59 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 25 Feb 2026 00:21:15 +0100 Subject: [PATCH 16/20] comment --- packages/common/src/weathergen/common/io.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 2243cdee8..f33724812 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -353,6 +353,7 @@ def __init__( self._append_dataset(self.source, "source") if self.key.with_target(forecast_offset): + # TODO requiring target data currently prevents predictions with unknowable target self._append_dataset(self.target, "target") self._append_dataset(self.prediction, "prediction") From 240a1961f44322ddb9bdb0f8894c27821d06b0e7 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 25 Feb 2026 00:57:34 +0100 Subject: [PATCH 17/20] formatting --- src/weathergen/datasets/batch.py | 3 +-- src/weathergen/utils/output.py | 12 +++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index a0de474b6..c3b71902b 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -28,9 +28,8 @@ class SampleMetaData: class Sample: - def __init__(self, streams: dict) -> None: - self.meta_info: dict[str, SampleMetaData] = {} + self.meta_info: dict[str, SampleMetaData] = {} self.streams_data: dict[str, StreamData | None] = {} for stream_info in streams: diff --git a/src/weathergen/utils/output.py b/src/weathergen/utils/output.py index 40f401563..874826fa4 100644 --- a/src/weathergen/utils/output.py +++ b/src/weathergen/utils/output.py @@ -90,11 +90,13 @@ def keys( streams, forecast_steps, samples ): yield ItemKey(sample_idx, forecast_step, stream) - + self._sample_start += len_samples def extract(self, data: _BatchOutputData, key: ItemKey, fstep_start: int) -> OutputItem: - raw_key = ItemKey(key.sample-self._sample_start, key.forecast_step-fstep_start, key.stream) + raw_key = ItemKey( + key.sample - self._sample_start, key.forecast_step - fstep_start, key.stream + ) data_invariants = self._get_invariants(raw_key) source, target, prediction = None, None, None @@ -109,8 +111,8 @@ def extract(self, data: _BatchOutputData, key: ItemKey, fstep_start: int) -> Out self._forecast_offset, # TODO nice: maybe drop it? target, prediction, - source - ) + source, + ) def _get_invariants(self, key: ItemKey) -> _DataInvariants: # TODO unify DTRange and TimeRange classes @@ -148,7 +150,7 @@ def extract_source(self, key: ItemKey) -> _ExtractedData: np.asarray(source.coords), np.asarray(source.geoinfos), ) - + @property def samples(self): # TODO check: data._batch.source_samples.samples From f9ecf5009df01cf91c4b8a554bb3f2d7911e6f07 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Fri, 27 Feb 2026 11:34:02 +0100 Subject: [PATCH 18/20] Add initial time window to batch. --- src/weathergen/datasets/batch.py | 5 +++++ src/weathergen/datasets/multi_stream_data_sampler.py | 1 + 2 files changed, 6 insertions(+) diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index c3b71902b..c34521ee2 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -11,6 +11,7 @@ from dataclasses import dataclass import numpy as np +from src.weathergen.datasets.data_reader_base import DTRange import torch from weathergen.common.config import Config @@ -230,6 +231,7 @@ class ModelBatch: Attributes: + init_time: Initial sampling window used to costruct samples. source_samples: sources for model target_samples: targets for TargetAuxCalculator source2target_matching_idxs: index of corresponding target indices. @@ -241,6 +243,7 @@ class ModelBatch: def __init__( self, + init_time: DTRange, streams: dict, num_source_samples: int, num_target_samples: int, @@ -251,12 +254,14 @@ def __init__( Initialize new ModelBatch. Args: + init_time: Initial sampling window used to costruct samples. streams: global StreamConfig. num_source_samples: Number of differently masked source samples for one input window. num_target_samples: Number of differently masked target samples for one input window. output_offset: forecast offset for this batch. output_steps: number of forecast steps for this batch. """ + self.init_time: DTRange = init_time # define forecast indices self.output_offset: BatchSamples = output_offset diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index c8b93192c..733113802 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -647,6 +647,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): num_output_steps = self._get_output_length(num_forecast_steps) batch = ModelBatch( + self.time_window_handler.window(idx), self.streams, num_source_samples, num_target_samples, From ea132bdef69f7d1692b2adc8e387fccc29e74a20 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Fri, 27 Feb 2026 12:06:37 +0100 Subject: [PATCH 19/20] get source window from batch instead of separate retriaval using TimeWindowHandler --- src/weathergen/utils/output.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/weathergen/utils/output.py b/src/weathergen/utils/output.py index 874826fa4..923bed1c5 100644 --- a/src/weathergen/utils/output.py +++ b/src/weathergen/utils/output.py @@ -17,7 +17,6 @@ zarrio_writer, ) from weathergen.datasets.batch import ModelBatch -from weathergen.datasets.data_reader_base import TimeWindowHandler from weathergen.model.model import ModelOutput from weathergen.train.target_and_aux_module_base import PhysicalTarget, TargetAuxOutput @@ -31,12 +30,6 @@ def __init__( ): streams = {stream.name: stream for stream in streams} # TODO: nice: dont store all config - self._twh = TimeWindowHandler( - val_cfg.start_date, - val_cfg.end_date, - val_cfg.time_window_len, - val_cfg.time_window_step, - ) _all_streams = list(streams.keys()) _output_streams = val_cfg.get("output", None).get("streams", None) @@ -98,6 +91,7 @@ def extract(self, data: _BatchOutputData, key: ItemKey, fstep_start: int) -> Out key.sample - self._sample_start, key.forecast_step - fstep_start, key.stream ) data_invariants = self._get_invariants(raw_key) + data_invariants.source_interval = data.source_interval source, target, prediction = None, None, None if key.with_source: @@ -116,9 +110,7 @@ def extract(self, data: _BatchOutputData, key: ItemKey, fstep_start: int) -> Out def _get_invariants(self, key: ItemKey) -> _DataInvariants: # TODO unify DTRange and TimeRange classes - window = self._twh.window(key.sample) return _DataInvariants( - source_interval=TimeRange(window.start, window.end), # val_source_channels are ListConfig[str] objects -> convert to list[str] source_channels=list(self._streams[key.stream].val_source_channels), target_channels=list(self._streams[key.stream].val_target_channels), @@ -150,6 +142,11 @@ def extract_source(self, key: ItemKey) -> _ExtractedData: np.asarray(source.coords), np.asarray(source.geoinfos), ) + + @property + def source_interval(self) -> TimeRange: + window = self._batch.init_time + return TimeRange(window.start, window.end) @property def samples(self): @@ -189,10 +186,10 @@ def _target_coordinates(self, target: PhysicalTarget) -> _ExtractedData: @dataclasses.dataclass class _DataInvariants: - source_interval: TimeRange source_channels: list[str] target_channels: list[str] geoinfo_channels: list[str] + source_interval: TimeRange | None = None @dataclasses.dataclass From 8ed6d54c991c6782a46be338d230f51861871256 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Fri, 27 Feb 2026 12:19:51 +0100 Subject: [PATCH 20/20] correct usage of sample_idx --- src/weathergen/utils/output.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/weathergen/utils/output.py b/src/weathergen/utils/output.py index 923bed1c5..e2c25a59d 100644 --- a/src/weathergen/utils/output.py +++ b/src/weathergen/utils/output.py @@ -73,7 +73,6 @@ def keys( ) -> typing.Generator[ItemKey, None, None]: """Iterate over possible output items""" streams: list[str] = self._streams.keys() - len_samples = len(samples) samples = (self._sample_start + sample for sample in samples) # The order of iteration is important here: @@ -84,7 +83,7 @@ def keys( ): yield ItemKey(sample_idx, forecast_step, stream) - self._sample_start += len_samples + self._sample_start += len(samples) def extract(self, data: _BatchOutputData, key: ItemKey, fstep_start: int) -> OutputItem: raw_key = ItemKey( @@ -128,11 +127,11 @@ class _BatchOutputData: def extract_source(self, key: ItemKey) -> _ExtractedData: # TODO check this? # breakpoint() - READER_DATA_INDEX_MYSTERY = 0 + source_sample_idx = self._batch.target2source_matching_idxs(key.sample) source = ( self._batch.source_samples.samples[key.forecast_step] .streams_data[key.stream] - .source_raw[READER_DATA_INDEX_MYSTERY] + .source_raw[source_sample_idx] ) return _ExtractedData( @@ -149,9 +148,13 @@ def source_interval(self) -> TimeRange: return TimeRange(window.start, window.end) @property - def samples(self): + def samples(self) -> list[int]: # TODO check: data._batch.source_samples.samples - return self._batch.source_samples.sample_idxs + sampels = list(self._batch.target_samples.sample_idxs) + + assert len(set(sampels)) == len(sampels), "samples are not unique" + + return self.samples def extract_target(self, key: ItemKey) -> _ExtractedData: target = self._target(key)