diff --git a/integration_tests/test_output.py b/integration_tests/test_output.py new file mode 100644 index 000000000..8f4a65388 --- /dev/null +++ b/integration_tests/test_output.py @@ -0,0 +1,30 @@ +import pytest + +# test cases +# 1: num_input_steps > 1, num_steps > 1 (1527) +# => what sources get written (multiple sources without targets?) => test with offset=1 +# 2: inference on jepa (1759) (jepa_wael) +# 3: jepa pretraining on era5, finetuning/inference on synop (1736) +# 4: guarantee output item is never overwritten (1575) +# 5: allow subsetting (for io) channels used as targets (1705) +# forecast offset 0/1 +# predictions without targets/sources +# allow incremental writes +# allow non continuous fsteps +# always have source at fstep=0 + + + + +@pytest.fixture +def output_items(): + pass + +def test_target_source_identity(offset): + pass + +def test_time_coordinates(output_items): + pass + +def test_spatial_coordinates(output_items): + pass \ No newline at end of file diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index f0243a717..9f331af57 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -634,9 +634,14 @@ def load_streams(streams_directory: Path) -> list[Config]: return list(streams.values()) -def get_path_run(config: Config) -> Path: +def get_path_run(config: Config | None = None, run_id: str | None = None) -> Path: """Get the current runs results_path for storing run results and logs.""" - return _get_shared_wg_path() / "results" / get_run_id_from_config(config) + if config or run_id: + run_id = run_id if run_id else get_run_id_from_config(config) + else: + msg = f"Missing run_id and cannot infer it from config: {config}" + raise ValueError(msg) + return _get_shared_wg_path() / "results" / run_id def get_path_model(config: Config | None = None, run_id: str | None = None) -> Path: diff --git 
a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 2243cdee8..f33724812 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -353,6 +353,7 @@ def __init__( self._append_dataset(self.source, "source") if self.key.with_target(forecast_offset): + # TODO requiring target data currently prevents predictions with unknowable target self._append_dataset(self.target, "target") self._append_dataset(self.prediction, "prediction") diff --git a/src/weathergen/datasets/batch.py b/src/weathergen/datasets/batch.py index d106feb08..c34521ee2 100644 --- a/src/weathergen/datasets/batch.py +++ b/src/weathergen/datasets/batch.py @@ -5,11 +5,13 @@ - Model data (StreamData objects containing tensors) - View metadata (spatial masks, strategies, relationships) """ +from __future__ import annotations # allow forward references in typehints import copy from dataclasses import dataclass import numpy as np +from weathergen.datasets.data_reader_base import DTRange import torch from weathergen.common.config import Config @@ -27,12 +29,30 @@ class SampleMetaData: class Sample: - # keys: stream name, values: SampleMetaData - meta_info: dict[str | SampleMetaData] + def __init__(self, streams: dict) -> None: + self.meta_info: dict[str, SampleMetaData] = {} + self.streams_data: dict[str, StreamData | None] = {} + + for stream_info in streams: + self.streams_data[stream_info["name"]] = None - # data for all streams - # keys: stream_name, values: StreamData - streams_data: dict[str, StreamData | None] + @property + def sample_idx(self) -> int: + if not self.is_empty(): + # choose 'first' streams stream_data + stream_data = next(iter(self.streams_data.values())) + # TODO check: is this reliable? 
+ # since set with tidx and not idx in MultiStreamDataSampler + idx = stream_data.sample_idx + + assert all( + stream_data.sample_idx == idx for stream_data in self.streams_data.values() + ), "sample_idx should be identical for all streams within one sample." + + return idx + else: + msg = "Cannot infer sample_idx without any streams added." + raise RuntimeError(msg) def pin_memory(self): """Pin all tensors in this Sample to CPU pinned memory""" @@ -53,13 +73,6 @@ def pin_memory(self): return self - def __init__(self, streams: dict) -> None: - self.meta_info = {} - - self.streams_data = {} - for stream_info in streams: - self.streams_data[stream_info["name"]] = None - def to_device(self, device) -> None: for key in self.meta_info.keys(): self.meta_info[key].mask = ( @@ -134,6 +147,10 @@ def to_device(self, device): return self + @property + def sample_idxs(self) -> list[int]: + return [sample.sample_idx for sample in self.samples] + def get_samples(self) -> list[Sample]: return self.samples @@ -149,6 +166,7 @@ def get_subset(self, subset: list | None = None): bs.tokens_lens = torch.index_select(bs.tokens_lens, 1, torch_idxs) return bs + # unused def get_num_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams @@ -160,24 +178,28 @@ def get_num_steps(self) -> int: return min(lens) + # unused def get_output_idxs(self) -> int: """ Get forecast indices """ return self.output_idxs + # unused def get_output_len(self) -> int: """ Get length of output """ return self.output_steps + # unused def get_device(self) -> str | torch.device: """ Get device of tensors in the batch """ return self.device + # unused def pin_memory(self): """Pin all tensors in this batch to CPU pinned memory""" @@ -195,40 +217,56 @@ def pin_memory(self): class ModelBatch: """ Container for all data and metadata for one training batch. 
- """ - - # source samples (for model) - source_samples: BatchSamples - # target samples (for TargetAuxCalculator) - target_samples: BatchSamples - - # index of corresponding target (for source samples) or source (for target samples) - # these are in 1-to-1 corresponding for classical training modes (e.g. MTM, forecasting) but - # can be more complex for strategies like student-teacher training - source2target_matching_idxs: np.typing.NDArray[np.int32] - target2source_matching_idxs: np.typing.NDArray[np.int32] - - # indices of valid outputs - output_idxs: list[int] - - # device of the tensors in the batch - device: str | torch.device + The data an instance contains is associated with one particular initial sampling window. + From this initial data multiple forecast windows are derived by offsetting the initial window. + Multiple samples for one forecast window can be generated + by sampling differently masked versions of the data. + Note that `output_offset` is the difference in forecast steps + between corresponding source and target samples. + Source and target samples are in 1-to-1 correspondence for classical training modes + (e.g. MTM, forecasting), but can be more complex for strategies like student-teacher training. + This relationship is expressed in `source2target_matching_idxs` + and `target2source_matching_idxs`. + + + Attributes: + init_time: Initial sampling window used to construct samples. + source_samples: sources for model + target_samples: targets for TargetAuxCalculator + source2target_matching_idxs: index of corresponding target indices. + target2source_matching_idxs: index of corresponding source indices. + output_idxs: Forecast step indices (including offset) for this batch. The number of + forecast steps is constant per batch, this is amortized through data parallel training. + device: device of the tensors in this batch. 
+ """ def __init__( self, + init_time: DTRange, streams: dict, num_source_samples: int, num_target_samples: int, - output_offset, - output_steps, + output_offset: int, + output_steps: int, ) -> None: - """ """ + """ + Initialize new ModelBatch. + + Args: + init_time: Initial sampling window used to construct samples. + streams: global StreamConfig. + num_source_samples: Number of differently masked source samples for one input window. + num_target_samples: Number of differently masked target samples for one input window. + output_offset: forecast offset for this batch. + output_steps: number of forecast steps for this batch. + """ + self.init_time: DTRange = init_time # define forecast indices - self.output_offset = output_offset - self.output_steps = output_steps - self.output_idxs = list(range(output_offset, output_steps)) + self.output_offset: int = output_offset + self.output_steps: int = output_steps + self.output_idxs: list[int] = list(range(output_offset, output_steps)) self.source_samples = BatchSamples( streams, num_source_samples, output_steps, self.output_idxs @@ -237,10 +275,15 @@ def __init__( streams, num_target_samples, output_steps, self.output_idxs ) - self.source2target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32) - self.target2source_matching_idxs = [[] for _ in range(num_target_samples)] + self.source2target_matching_idxs: np.typing.NDArray[np.int32] = np.full( + num_source_samples, -1, dtype=np.int32 + ) + self.target2source_matching_idxs: list[list[int]] = [ + [] for _ in range(num_target_samples) + ] + self.device: torch.device | None = None - def pin_memory(self): + def pin_memory(self) -> ModelBatch: """Pin all tensors in this batch to CPU pinned memory""" # pin source samples @@ -251,18 +294,18 @@ def pin_memory(self): return self - def to_device(self, device): # -> ModelBatch + def to_device(self, device: str | torch.device) -> ModelBatch: """ Move batch to device """ + self.device = 
torch.device(device) - self.source_samples.to_device(device) - self.target_samples.to_device(device) - - self.device = device + self.source_samples.to_device(self.device) + self.target_samples.to_device(self.device) return self + # used once in MultiStreamDataSampler def add_source_stream( self, source_sample_idx: int, @@ -282,6 +325,7 @@ def add_source_stream( assert target_sample_idx < len(self.target_samples), "invalid value for target_sample_idx" self.source2target_matching_idxs[source_sample_idx] = target_sample_idx + # used once in MultiStreamDataSampler def add_target_stream( self, target_sample_idx: int, @@ -308,6 +352,7 @@ def add_target_stream( ) self.target2source_matching_idxs[target_sample_idx] = source_sample_idx + # used once in MultiStreamDataSampler def is_empty(self): """ Check if batch is empty @@ -320,72 +365,84 @@ def is_empty(self): ) return source_empty or target_empty + # unused def len_sources(self) -> int: """ Number of source samples """ return len(self.source_samples) + # unused def len_targets(self) -> int: """ Number of target samples """ return len(self.target_samples) + # unused def get_source_sample(self, idx: int) -> Sample: """ Get a source sample """ return self.source_samples.samples[idx] + # used in validation_io only def get_source_samples(self, subset: list | None = None) -> BatchSamples: """ Get source samples """ return self.source_samples.get_subset(subset) + # unused def get_target_sample(self, idx: int) -> Sample: """ Get a target sample """ return self.target_samples.samples[idx] + # unused def get_target_samples(self, subset: list | None = None) -> BatchSamples: """ Get target samples """ return self.target_samples.get_subset(subset) + # unsused def get_source_idx_for_target(self, target_idx: int) -> int: """ Get index of source sample for a given target sample index """ return int(self.target2source_matching_idxs[target_idx]) + # unused def get_target_idx_for_source(self, source_idx: int) -> int: """ Get index of 
target sample for a given source sample index """ return int(self.source2target_matching_idxs[source_idx]) + # unused def get_output_idxs(self) -> int: """ Get valid output steps """ return self.output_idxs + # unused def get_output_len(self) -> int: """ Get length of output """ return self.output_steps + # unused def get_device(self) -> str | torch.device: """ Get device of tensors in the batch """ return self.device + # unused def get_num_source_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams @@ -398,6 +455,7 @@ def get_num_source_steps(self) -> int: return min(lens) + # unused def get_num_target_steps(self) -> int: """ Get number of input/source steps from smallest of all available streams diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index f84111541..fb8615670 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -259,7 +259,7 @@ def build_samples_for_stream( num_cells: int, stage_cfg: dict, stream_cfg: dict, - ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]: + ) -> tuple[MaskData, MaskData, np.typing.NDArray[np.int32]]: """ Construct teacher/student keep masks for a stream. SampleMetaData is currently just a dict with the masking params used. 
@@ -355,7 +355,7 @@ def build_samples_for_stream( source_target_mapping += [target_idx] i_source += 1 - source_target_mapping = np.array(source_target_mapping, dtype=np.int32) + source_target_mapping: NDArray[np.int32] = np.array(source_target_mapping, dtype=np.int32) return (target_masks, source_masks, source_target_mapping) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 179333b8e..733113802 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -9,6 +9,7 @@ import logging import pathlib +import typing import numpy as np import torch @@ -24,7 +25,7 @@ ) from weathergen.datasets.data_reader_fesom import DataReaderFesom from weathergen.datasets.data_reader_obs import DataReaderObs -from weathergen.datasets.masking import Masker +from weathergen.datasets.masking import MaskData, Masker from weathergen.datasets.stream_data import StreamData, spoof from weathergen.datasets.tokenizer_masking import TokenizerMasking from weathergen.datasets.utils import ( @@ -197,6 +198,7 @@ def __init__( if ds.target_channel_weights is not None else [1.0 for _ in ds.target_channels] ) + stream_info["geoinfo_channels"] = ds.geoinfo_channels self.streams_datasets[stream_info["name"]] += [ds] @@ -477,6 +479,8 @@ def _build_stream_data( modes : stream_data : base_idx: Time index for this sample + => no its tidx (index of n-th target mask, no info of the time index for this sample + was used in creating tidx) num_forecast_steps: Number of forecast steps stream_info: Stream configuration dict stream_ds: List of dataset readers for this stream @@ -521,7 +525,9 @@ def _build_stream_data( return stream_data - def _get_data_windows(self, base_idx, num_forecast_steps, num_steps_input_max, stream_ds): + def _get_data_windows( + self, base_idx, num_forecast_steps, num_steps_input_max, stream_ds + ) -> tuple[list[IOReaderData], list[IOReaderData]]: 
""" Collect all data needed for current stream to potentially amortize costs by generating multiple samples @@ -578,7 +584,7 @@ def _get_source_target_masks(self, training_mode): Generate source and target masks for all streams """ - masks = {} + masks: dict[str, tuple[MaskData, MaskData, np.NDArray[np.int32]]] = {} for stream_info in self.streams: # Build source and target sample masks masks[stream_info["name"]] = self.tokenizer.build_samples_for_stream( @@ -587,7 +593,7 @@ def _get_source_target_masks(self, training_mode): self.mode_cfg, stream_info, ) - # identical for all streams + # shape identical for all streams: #target/source config x num_samples num_target_samples = len(masks[stream_info["name"]][0]) num_source_samples = len(masks[stream_info["name"]][1]) @@ -619,12 +625,15 @@ def _get_batch(self, idx: int, num_forecast_steps: int): mode = self.mode_cfg.get("training_mode") source_cfgs = self.mode_cfg.get("model_input") - target_cfgs = self.mode_cfg.get("target_input", {}) + target_cfgs: typing.Mapping = self.mode_cfg.get("target_input", {}) # get/coordinate masks masks_streams, num_source_samples, num_target_samples = self._get_source_target_masks(mode) - source_select, target_select = [], [] + # contain string flags + source_select: list[str] = [] + target_select: list[str] = [] + if "masking" in mode: source_select += ["network_input", "target_coords"] target_select += ["target_values"] @@ -638,6 +647,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): num_output_steps = self._get_output_length(num_forecast_steps) batch = ModelBatch( + self.time_window_handler.window(idx), self.streams, num_source_samples, num_target_samples, @@ -649,6 +659,7 @@ def _get_batch(self, idx: int, num_forecast_steps: int): for stream_info, (stream_name, stream_ds) in zip( self.streams, self.streams_datasets.items(), strict=True ): + stream_info: Config (target_masks, source_masks, source_to_target) = masks_streams[stream_name] # max number of input steps @@ 
-672,10 +683,10 @@ def _get_batch(self, idx: int, num_forecast_steps: int): for sidx, source_mask in enumerate(source_masks.masks): # Map each source to its target - tidx = source_to_target[sidx].item() + tidx: int = source_to_target[sidx].item() sdata = self._build_stream_data( source_select, - tidx, + tidx, # Why use not idx here? num_forecast_steps, stream_info, source_masks.metadata[sidx].params.get("num_steps_input", 1), @@ -715,14 +726,22 @@ def _get_batch(self, idx: int, num_forecast_steps: int): ] batch.add_target_stream(tidx, student_indices, stream_name, sdata, target_metadata) - source_in_steps = input_steps.max().item() - target_in_steps = np.array([tc.get("num_steps_input", 1) for _, tc in target_cfgs.items()]) - target_in_steps = 1 if len(target_in_steps) == 0 else target_in_steps.max().item() - batch = self._preprocess_model_batch(batch, source_in_steps, target_in_steps) + source_input_steps: int = input_steps.max().item() + + try: + target_input_steps: int = max( + target_cfg.get("num_steps_input", 1) for target_cfg in target_cfgs.values() + ) + except ValueError: + # empty list produces value error when taking max + target_input_steps: int = 1 + + # only adds tokens lens as attribute to batch + batch = self._preprocess_model_batch(batch, source_input_steps, target_input_steps) return batch - def __iter__(self) -> ModelBatch: + def __iter__(self) -> typing.Generator[ModelBatch, typing.Any, typing.Any]: """ Return one batch of data @@ -733,6 +752,7 @@ def __iter__(self) -> ModelBatch: logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={self.len}") # create new shuffeling + # should return permutations instead of relying on self.perms self.reset() # bidx is used to count the #batches that have been emitted diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 152c092b2..31549644f 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -196,7 +196,7 @@ 
def add_source( idx = torch.isnan(self.source_tokens_cells[step]) self.source_tokens_cells[step][idx] = self.mask_value - def add_target( + def add_target( # TODO what is the diffference to add_target_values??? self, fstep: int, targets: list, diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 6dfe71c89..a85b1a37d 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -11,8 +11,8 @@ import numpy as np import torch +from weathergen.common.config import Config from weathergen.common.io import IOReaderData -from weathergen.datasets.batch import SampleMetaData from weathergen.datasets.masking import Masker from weathergen.datasets.tokenizer import Tokenizer from weathergen.datasets.tokenizer_utils import ( @@ -53,26 +53,34 @@ def reset_rng(self, rng) -> None: self.masker.reset_rng(rng) self.rng = rng - def get_tokens_windows(self, stream_info, data, pad_tokens): + def get_tokens_windows( + self, stream_info: Config, data: list[IOReaderData], pad_tokens: bool + ) -> list[tuple[list[list[torch.Tensor | None]], list[list[int]]]]: """ Tokenize data (to amortize over the different views that are generated) """ - - tok_spacetime = stream_info.get("tokenize_spacetime", False) - tok = tokenize_spacetime if tok_spacetime else tokenize_space - hl = self.healpix_level - token_size = stream_info["token_size"] - - tokens = [] + tokens: list[tuple[list[list[torch.Tensor | None]], list[list[int]]]] = [] for rdata in data: # skip empty data if rdata.is_empty(): continue # tokenize data - idxs_cells, idxs_cells_lens = tok( - readerdata_to_torch(rdata), token_size, hl, pad_tokens - ) + if stream_info.get("tokenize_spacetime", False): + idxs_cells, idxs_cells_lens = tokenize_spacetime( + readerdata_to_torch(rdata), + stream_info["token_size"], + self.healpix_level, + pad_tokens, + ) + else: + idxs_cells, idxs_cells_lens = tokenize_space( + readerdata_to_torch(rdata), + 
stream_info["token_size"], + self.healpix_level, + pad_tokens, + ) + tokens += [(idxs_cells, idxs_cells_lens)] return tokens @@ -83,7 +91,7 @@ def build_samples_for_stream( num_cells: int, stage_cfg: dict, stream_cfg: dict, - ) -> tuple[np.typing.NDArray, list[np.typing.NDArray], list[SampleMetaData]]: + ): """ Create masks for samples """ diff --git a/src/weathergen/datasets/tokenizer_utils.py b/src/weathergen/datasets/tokenizer_utils.py index 3f8f11b9d..1de41ae5e 100644 --- a/src/weathergen/datasets/tokenizer_utils.py +++ b/src/weathergen/datasets/tokenizer_utils.py @@ -1,3 +1,5 @@ +import typing + import numpy as np import pandas as pd import torch @@ -5,6 +7,7 @@ from torch import Tensor from weathergen.common.io import IOReaderData +from weathergen.datasets.data_reader_base import NPDT64 from weathergen.datasets.utils import ( locs_to_cell_coords_ctrs, locs_to_ctr_coords, @@ -117,7 +120,7 @@ def hpy_cell_splits(coords: torch.tensor, hl: int): def hpy_splits( coords: torch.Tensor, hl: int, token_size: int, pad_tokens: bool -) -> tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor]: +) -> tuple[list[list[torch.Tensor | None]], list[list[int]]]: """Compute healpix cell for each data point and splitting information per cell; when the token_size is exceeded then splitting based on lat is used; tokens can be padded @@ -168,10 +171,10 @@ def hpy_splits( def tokenize_space( - rdata, - token_size, - hl, - pad_tokens=True, + rdata: IOReaderData, + token_size: int, + hl: int, + pad_tokens: bool = True, ): """Process one window into tokens""" @@ -182,7 +185,7 @@ def tokenize_space( def tokenize_spacetime( - rdata, + rdata: IOReaderData, token_size, hl, pad_tokens=True, @@ -192,11 +195,12 @@ def tokenize_spacetime( """ num_healpix_cells = 12 * 4**hl - idxs_cells = [[] for _ in range(num_healpix_cells)] - idxs_cells_lens = [[] for _ in range(num_healpix_cells)] + idxs_cells: list[list[Tensor | None]] = [[] for _ in range(num_healpix_cells)] + idxs_cells_lens: 
list[list[int]] = [[] for _ in range(num_healpix_cells)] - t_unique = np.unique(rdata.datetimes) + t_unique: typing.iterable[NPDT64] = np.unique(rdata.datetimes) for _, t in enumerate(t_unique): + t: NPDT64 # data for current time step mask = t == rdata.datetimes rdata_cur = IOReaderData( @@ -204,9 +208,12 @@ def tokenize_spacetime( ) idxs_cur, idxs_cur_lens = tokenize_space(rdata_cur, token_size, hl, pad_tokens) - # collect data for all time steps - idxs_cells = [t + tc for t, tc in zip(idxs_cells, idxs_cur, strict=True)] - idxs_cells_lens = [t + tc_l for t, tc_l in zip(idxs_cells_lens, idxs_cur_lens, strict=True)] + # append tokens/n_tokens for current time step for each healpix cell to existing tokens + for cell_steps_tokens, cell_steps_tokens_cur, cell_steps_lens, cell_steps_lens_cur in zip( + idxs_cells, idxs_cur, idxs_cells_lens, idxs_cur_lens, strict=True + ): + cell_steps_tokens.extend(cell_steps_tokens_cur) + cell_steps_lens.extend(cell_steps_lens_cur) return idxs_cells, idxs_cells_lens @@ -382,6 +389,7 @@ def return_empty(rdata, idxs_cells_lens): else: coords_local = torch.tensor([]) + # geoinfos information is empedded into coords_local here return data, datetimes, coords, coords_local, masked_points_per_cell diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 273c28838..f51b8f42a 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -11,6 +11,7 @@ import logging import math +import typing import warnings import astropy_healpix as hp @@ -20,6 +21,7 @@ import torch.nn as nn from weathergen.common.config import Config +from weathergen.common.io import ItemKey from weathergen.datasets.batch import ModelBatch from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.encoder import EncoderModule @@ -50,7 +52,7 @@ class ModelOutput: Representation of model output """ - physical: list[dict[StreamName, torch.Tensor]] + physical: list[dict[StreamName, tuple[torch.Tensor]]] 
latent: list[dict[str, torch.Tensor | LatentState]] def __init__(self, len_output: int) -> None: @@ -76,6 +78,28 @@ def get_physical_prediction( pred = pred[sample_idx] return pred + def get_physical_prediction_normalized( + self, key: ItemKey, normalizer: typing.Callable + ) -> np.typing.NDArray: + try: + # TODO why is there a tuple => what index should be used + pred = self.physical[key.forecast_step][key.stream][0][key.sample] + except (KeyError, IndexError) as e: + msg = f"Cannot find prediction data for key: {key}" + raise ValueError(msg) from e + + # is it a performance issue if I dont convert/move the entire tensor at once? + pred = pred.to(torch.float32).detach().cpu().numpy() + pred = normalizer(key.stream, pred) + + assert isinstance(pred, np.ndarray), "Invalid buffer type." + # TODO What to do when preds are empty,when does it occur, + # how does it show (empty tensor or missing key?) + assert len(pred) > 0 + + # breakpoint() + return pred + def get_latent_prediction(self, fstep: int): return self.latent[fstep] diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 2f81940a3..5c9a38409 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -80,7 +80,7 @@ def __init__( def compute_loss( self, preds: ModelOutput, - targets_and_aux: TargetAuxOutput, + targets_and_aux: dict[str, TargetAuxOutput], metadata: dict, ): losses_all = defaultdict(dict) diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py index bb39d1b17..e399ada07 100644 --- a/src/weathergen/train/target_and_aux_module_base.py +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -10,40 +10,52 @@ from __future__ import annotations import dataclasses +import typing +import numpy as np import torch +from weathergen.datasets.batch import SampleMetaData from weathergen.model.engines import LatentState type StreamName = str +BufferType = 
np.float32 +DateTimeType = np.datetime64 @dataclasses.dataclass +class PhysicalTarget: + data: np.typing.NDArray[BufferType] + coords: np.typing.NDArray[BufferType] + datetimes: np.typing.NDArray[DateTimeType] + + class TargetAuxOutput: """ A dataclass to encapsulate the TargetAndAuxCalculator output and give a clear API. """ - output_idxs: list[int] - - physical: list[dict[StreamName, torch.Tensor]] - latent: list[dict[str, torch.Tensor | LatentState]] - aux_outputs: dict[str, torch.Tensor] - def __init__(self, len_target: int, output_idxs: list) -> None: - self.output_idxs = output_idxs - self.physical = [{} for _ in range(len_target)] - self.latent = [{} for _ in range(len_target)] - self.aux_outputs = {} + self.output_idxs: list[int] = output_idxs + self.physical: list[ + dict[ + StreamName, + dict[str, list[torch.Tensor] | list[dict[str, SampleMetaData]] | list[bool]], + ] + ] = [{} for _ in range(len_target)] + self.latent: list[dict[str, torch.Tensor | LatentState]] = [{} for _ in range(len_target)] + self.aux_outputs: dict[str, torch.Tensor] = {} def add_physical_target( self, timestep_idx: int, stream_name: StreamName, pred: torch.Tensor ) -> None: self.physical[timestep_idx][stream_name] = pred + # currently unused? 
def add_latent_target(self, timestep_idx: int, latent_name: str, pred: torch.Tensor) -> None: self.latent[timestep_idx][latent_name] = pred + # is broken => pred[sample_idx] must be pred["target"][sample_idx] def get_physical_target( self, timestep_idx: int, @@ -58,6 +70,39 @@ def get_physical_target( pred = pred[sample_idx] return pred + # TODO guarantee every buffer retrieved only once + def get_physical_target_normalized( + self, key: ItemKey, normalizer: typing.Callable + ) -> PhysicalTarget: + try: + stream_targets = self.physical[key.forecast_step][key.stream] + except (KeyError, IndexError) as e: + msg = f"Cannot find physical target data for key: {key}" + raise ValueError(msg) from e + + # is it a performance issue if I dont convert/move the entire tensor at once? + coords = ( + stream_targets["target_coords"][key.sample].to(torch.float32).detach().cpu().numpy() + ) + times = stream_targets["target_times"][key.sample] + + data = stream_targets["target"][key.sample].to(torch.float32).detach().cpu().numpy() + data = normalizer(key.stream, data) + + assert isinstance(data, np.ndarray), "Invalid data buffer type." + assert isinstance(coords, np.ndarray), "Invalid coords buffer type." + assert isinstance(times, np.ndarray), "Invalid datetimes buffer type." + + if len(coords) == 0: # TODO can this be removed? + coords = np.zeros((0, 2), dtype=np.float32) + assert len(coords.shape) >= 2 and coords.shape[-1] >= 2, ( + "invalid shape for coordinate buffer." + ) + assert data.shape[0] == coords.shape[0] == times.shape[0], "buffer shapes should align." 
+ + # breakpoint() + return PhysicalTarget(data, coords, times) + def get_latent_target(self, timestep_idx: int): return self.latent[timestep_idx] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index bfa289c7b..f58ed7dd1 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -11,6 +11,7 @@ import copy import logging import time +from collections.abc import Iterable import numpy as np import torch @@ -22,6 +23,7 @@ import weathergen.common.config as config from weathergen.common.config import Config +from weathergen.datasets.batch import ModelBatch from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler from weathergen.model.ema import EMAModel from weathergen.model.model_interface import ( @@ -32,6 +34,7 @@ from weathergen.train.collapse_monitor import CollapseMonitor from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( TRAIN, @@ -45,6 +48,7 @@ get_target_idxs_from_cfg, ) from weathergen.utils.distributed import is_root +from weathergen.utils.output import Writer from weathergen.utils.train_logger import TrainLogger, prepare_losses_for_logging from weathergen.utils.utils import get_dtype from weathergen.utils.validation_io import write_output @@ -52,6 +56,7 @@ logger = logging.getLogger(__name__) # cfg_keys_to_filter = ["losses", "model_input", "target_input"] +PHYSICAL_LOSS_KEY = "physical" class Trainer(TrainerBase): @@ -64,7 +69,8 @@ def __init__(self, train_log_freq: Config): self.data_loader_validation: torch.utils.data.DataLoader | None = None self.dataset: MultiStreamDataSampler | None = None self.dataset_val: MultiStreamDataSampler | None = None - self.device: torch.device = None + self.output_writer: Writer | None = 
None + self.device: torch.device | None = None self.ema_model = None self.grad_scaler: torch.amp.GradScaler | None = None self.last_grad_norm = None @@ -75,8 +81,8 @@ def __init__(self, train_log_freq: Config): self.model_params = None self.optimizer: torch.optim.Optimizer | None = None self.t_start: float = 0 - self.target_and_aux_calculators = None - self.target_and_aux_calculators_val = None + self.target_and_aux_calculators: dict[str, TargetAndAuxModuleBase] = None + self.target_and_aux_calculators_val: dict[str, TargetAndAuxModuleBase] = None self.validate_with_ema_cfg = None self.validate_with_ema: bool = False self.batch_size_per_gpu = -1 @@ -90,7 +96,7 @@ def get_batch_size_total(self, batch_size_per_gpu) -> int: """ return self.world_size_original * batch_size_per_gpu - def init(self, cf: Config, devices): + def init(self, cf: Config, devices, use_test_config=False): # pylint: disable=attribute-defined-outside-init self.cf = OmegaConf.merge( OmegaConf.create( @@ -153,6 +159,10 @@ def init(self, cf: Config, devices): config.get_path_model(cf).mkdir(exist_ok=True, parents=True) self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + if use_test_config: + self.output_writer = Writer(cf, self.test_cfg, cf.streams) + else: + self.output_writer = Writer(cf, self.validation_cfg, cf.streams) # Initialize collapse monitor for SSL training collapse_config = cf.train_logging.get("collapse_monitoring", {}) @@ -168,7 +178,7 @@ def get_target_aux_calculators(self, mode_cfg): # get target_aux calculators for different loss terms # del self.cf.training_config.losses["student-teacher"]["loss_fcts"]["JEPA"] # del mode_cfg.losses["student-teacher"]["loss_fcts"]["JEPA"] - target_and_aux_calculators = {} + target_and_aux_calculators: dict[str, TargetAndAuxModuleBase] = {} for loss_name, loss_cfg in mode_cfg.losses.items(): target_and_aux_calculators[loss_name] = get_target_aux_calculator( self.cf, loss_cfg, self.dataset, self.model, self.device, batch_size @@ 
-178,7 +188,7 @@ def get_target_aux_calculators(self, mode_cfg): def inference(self, cf, devices, run_id_contd, mini_epoch_contd): # general initalization - self.init(cf, devices) + self.init(cf, devices, True) cf = self.cf device_type = torch.accelerator.current_accelerator() @@ -549,7 +559,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): cf = self.cf self.model.eval() - dataset_val_iter = iter(self.data_loader_validation) + dataset_val_iter: Iterable[ModelBatch] = iter(self.data_loader_validation) num_samples_write = mode_cfg.get("output", {}).get("num_samples", 0) * batch_size @@ -580,7 +590,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch.get_source_samples(), ) - targets_and_auxs = {} + targets_and_auxs: dict[str, TargetAuxOutput] = {} for loss_name, target_aux in self.target_and_aux_calculators_val.items(): target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) targets_and_auxs[loss_name] = target_aux.compute( @@ -604,17 +614,23 @@ def validate(self, mini_epoch, mode_cfg, batch_size): if mode_cfg.get("output", {}).get("normalized_samples", False) else self.dataset_val.denormalize_target_channels ) - # write output - write_output( - self.cf, - mode_cfg, - batch_size, - mini_epoch, - bidx, - denormalize_data_fct, - batch, - preds, - targets_and_auxs, + + fstep_start = 0 + targets = [ + targets + for loss_name, targets in targets_and_auxs.items() + if loss_name == PHYSICAL_LOSS_KEY + ] + try: + # targets for all physical outputs should be the same + targets = targets[0] + except IndexError as e: + raise ValueError( + f"No physical outputs under key {PHYSICAL_LOSS_KEY} configured" + ) from e + + self.output_writer.write_batch( + mini_epoch, batch, targets, preds, denormalize_data_fct, fstep_start ) pbar.update(batch_size) diff --git a/src/weathergen/utils/output.py b/src/weathergen/utils/output.py new file mode 100644 index 000000000..e2c25a59d --- /dev/null +++ b/src/weathergen/utils/output.py @@ -0,0 +1,220 @@ +from __future__ 
import annotations # allow forward references in typehints + +import dataclasses +import itertools +import typing +from collections.abc import Callable + +import numpy as np +from omegaconf.errors import ConfigAttributeError + +from weathergen.common.config import Config, get_path_results +from weathergen.common.io import ( + ItemKey, + OutputDataset, + OutputItem, + TimeRange, + zarrio_writer, +) +from weathergen.datasets.batch import ModelBatch +from weathergen.model.model import ModelOutput +from weathergen.train.target_and_aux_module_base import PhysicalTarget, TargetAuxOutput + + +class Writer: + def __init__( + self, + config: Config, + val_cfg: Config, + streams: list[Config], + ): + streams = {stream.name: stream for stream in streams} + # TODO: nice: dont store all config + + _all_streams = list(streams.keys()) + _output_streams = val_cfg.get("output", None).get("streams", None) + _output_streams = val_cfg.output.streams if val_cfg.output.streams else _all_streams + + self._streams = { + name: config for name, config in streams.items() if name in _output_streams + } + self._forecast_offset = val_cfg.forecast.offset + self._cf: Config = config # used for zarr output path lookup => improve + self._sample_start = 0 + + def write_batch( + self, + mini_epoch: int, # TODO: nice: use iterstep for better consistency? 
+ batch: ModelBatch, + targets: TargetAuxOutput, + predictions: ModelOutput, + normalizer: typing.Callable, + fstep_start: int, + ) -> None: + data = _BatchOutputData(batch, targets, predictions, normalizer) + fstep_range = range(fstep_start, len(predictions.physical)) + # TODO: nice: get result path differently + with zarrio_writer(get_path_results(self._cf, mini_epoch)) as zio: + for subset in self.itemize(data, fstep_range): + zio.write_zarr(subset) + + def itemize( + self, data: _BatchOutputData, fstep_range: range + ) -> typing.Generator[OutputItem, None, None]: + """Iterate over possible output items""" + fstep_start = min(fstep_range) + + # TODO: check: filter for empty items? + for key in self.keys(fstep_range, data.samples): + yield self.extract(data, key, fstep_start) + + def keys( + self, forecast_steps: range, samples: list[int] + ) -> typing.Generator[ItemKey, None, None]: + """Iterate over possible output items""" + streams: list[str] = self._streams.keys() + samples = (self._sample_start + sample for sample in samples) + + # The order of iteration is important here: + # streams is the outermost and samples the innermost loop variable + # This is important since normalization is best done at a per stream/fstep basis + for stream, forecast_step, sample_idx in itertools.product( + streams, forecast_steps, samples + ): + yield ItemKey(sample_idx, forecast_step, stream) + + self._sample_start += len(samples) + + def extract(self, data: _BatchOutputData, key: ItemKey, fstep_start: int) -> OutputItem: + raw_key = ItemKey( + key.sample - self._sample_start, key.forecast_step - fstep_start, key.stream + ) + data_invariants = self._get_invariants(raw_key) + data_invariants.source_interval = data.source_interval + + source, target, prediction = None, None, None + if key.with_source: + source = data.extract_source(raw_key).as_dataset(key, data_invariants) + if key.with_target(self._forecast_offset): + target = data.extract_target(raw_key).as_dataset(key, 
data_invariants) + prediction = data.extract_prediction(raw_key).as_dataset(key, data_invariants) + + return OutputItem( + key, + self._forecast_offset, # TODO nice: maybe drop it? + target, + prediction, + source, + ) + + def _get_invariants(self, key: ItemKey) -> _DataInvariants: + # TODO unify DTRange and TimeRange classes + return _DataInvariants( + # val_source_channels are ListConfig[str] objects -> convert to list[str] + source_channels=list(self._streams[key.stream].val_source_channels), + target_channels=list(self._streams[key.stream].val_target_channels), + geoinfo_channels=list(self._streams[key.stream].geoinfo_channels), + ) + + +@dataclasses.dataclass +class _BatchOutputData: + _batch: ModelBatch + _targets: TargetAuxOutput + _predictions: ModelOutput + _normalizer: Callable + + def extract_source(self, key: ItemKey) -> _ExtractedData: + # TODO check this? + # breakpoint() + source_sample_idx = self._batch.target2source_matching_idxs(key.sample) + source = ( + self._batch.source_samples.samples[key.forecast_step] + .streams_data[key.stream] + .source_raw[source_sample_idx] + ) + + return _ExtractedData( + "prediction", + np.asarray(source.data), + np.asarray(source.datetimes), + np.asarray(source.coords), + np.asarray(source.geoinfos), + ) + + @property + def source_interval(self) -> TimeRange: + window = self._batch.init_time + return TimeRange(window.start, window.end) + + @property + def samples(self) -> list[int]: + # TODO check: data._batch.source_samples.samples + sampels = list(self._batch.target_samples.sample_idxs) + + assert len(set(sampels)) == len(sampels), "samples are not unique" + + return self.samples + + def extract_target(self, key: ItemKey) -> _ExtractedData: + target = self._target(key) + coords = self._target_coordinates(target) + + return _ExtractedData("target", target.data, coords.times, coords.coords, coords.geoinfos) + + def extract_prediction(self, key: ItemKey) -> _ExtractedData: + try: + data = 
self._predictions.get_physical_prediction_normalized(key, self._normalizer) + except Exception as e: + # TODO: if preds are empty so create copy of target and add ensemble dimension + # preds = [targets[0].clone().unsqueeze(0)] + raise ValueError("not handled yet") from e + data = self._target(key) + target = self._target(key) + coords = self._target_coordinates(target) + + return _ExtractedData("prediction", data, coords.times, coords.coords, coords.geoinfos) + + # TODO guarantee this method is only called once per OutputItem + # TODO try getting targets from batch directly + def _target(self, key: ItemKey) -> PhysicalTarget: + return self._targets.get_physical_target_normalized(key, self._normalizer) + + def _target_coordinates(self, target: PhysicalTarget) -> _ExtractedData: + coords = target.coords[..., :2] # first two columns are lat,lon + geoinfos = target.coords[..., 2:] # the rest is geoinfo => potentially empty + + return _ExtractedData("", None, target.datetimes, coords, geoinfos) + + +@dataclasses.dataclass +class _DataInvariants: + source_channels: list[str] + target_channels: list[str] + geoinfo_channels: list[str] + source_interval: TimeRange | None = None + + +@dataclasses.dataclass +class _ExtractedData: + name: str # TODO make enum + data: typing.Any + times: typing.Any + coords: typing.Any + geoinfos: typing.Any + + def as_dataset(self, key: ItemKey, invariants: _DataInvariants) -> OutputDataset: + if self.data is None or self.data.shape == (0, 0): + return None + else: + return OutputDataset( + name=self.name, + item_key=key, + data=self.data, + times=self.times, + coords=self.coords, + geoinfo=self.geoinfos, + source_interval=invariants.source_interval, + channels=invariants.target_channels, + geoinfo_channels=invariants.geoinfo_channels, + ) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 0e09fd38d..7150954bf 100644 --- a/src/weathergen/utils/validation_io.py +++ 
b/src/weathergen/utils/validation_io.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import logging +import typing import numpy as np import torch @@ -15,13 +16,26 @@ import weathergen.common.config as config import weathergen.common.io as io from weathergen.common.io import TimeRange, zarrio_writer +from weathergen.datasets.batch import ModelBatch from weathergen.datasets.data_reader_base import TimeWindowHandler +from weathergen.model.model import ModelOutput +from weathergen.train.target_and_aux_module_base import TargetAuxOutput _logger = logging.getLogger(__name__) def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out + cf: config.Config, # only streams const + val_cfg: config.Config, # const + batch_size: int, # const + mini_epoch: int, # get filename + batch_idx: int, # calculate sample_start + dn_data: typing.Callable, # const + batch: ModelBatch, + # contains physical/latent predictions => can ignore latent? + model_output: ModelOutput, + # contains physical/latent targets => can ignore latent? + target_aux_out: dict[str, TargetAuxOutput], ): """ Interface for writing model output