Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ocf_data_sampler/numpy_sample/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@

from .convert import convert_to_numpy_sample
from .datetime_features import encode_datetimes, get_t0_embedding
from .collate import stack_np_samples_into_batch
from .common_types import NumpySample, NumpyBatch, TensorBatch
from .sun_position import make_sun_position_numpy_sample
66 changes: 0 additions & 66 deletions ocf_data_sampler/numpy_sample/collate.py

This file was deleted.

6 changes: 3 additions & 3 deletions ocf_data_sampler/numpy_sample/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
import numpy as np
import torch

NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
NumpySample: TypeAlias = dict[str, np.ndarray]
NumpyBatch: TypeAlias = dict[str, np.ndarray]
TensorBatch: TypeAlias = dict[str, torch.Tensor]
85 changes: 51 additions & 34 deletions ocf_data_sampler/numpy_sample/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,61 +7,78 @@


def convert_to_numpy_sample(
    datasets_dict: "dict[str, xr.DataArray | dict[str, xr.DataArray]]",
    t0_idx: int,
    include_extra_metadata: bool = False,
) -> "NumpySample":
    """Convert a dictionary of xarray objects to a NumpySample.

    Args:
        datasets_dict: Dictionary of xarray DataArrays, with same structure as used inside
            PVNetDataset classes. Expected keys are any of following:
            - "generation": DataArray of generation data
            - "sat": DataArray of satellite data
            - "nwp": dict of DataArrays by provider name (e.g. {"ukv": da, "ecmwf": da})
        t0_idx: Index of t0 within generation
        include_extra_metadata: Whether to add additional non-essential metadata to the batch

    Returns:
        NumpySample dictionary with all modalities merged
    """
    numpy_sample: NumpySample = {}

    if "generation" in datasets_dict:
        da = datasets_dict["generation"]

        generation_values = da.sel(gen_param="generation_mw").values
        capacity_value = da.sel(gen_param="capacity_mwp").values[0]

        # Normalise generation by capacity; skip when capacity is zero to
        # avoid a division by zero (values are then left un-normalised).
        if capacity_value != 0:
            generation_values = generation_values / capacity_value

        numpy_sample.update(
            {
                "generation": generation_values,
                "capacity_mwp": capacity_value,
                "generation_t0_idx": int(t0_idx),
                "generation_time_utc": da.time_utc.values.astype(float),
            },
        )

        if include_extra_metadata:
            numpy_sample.update(
                {
                    "location_longitude": float(da.longitude.values),
                    "location_latitude": float(da.latitude.values),
                },
            )

    if "sat" in datasets_dict:
        da = datasets_dict["sat"]
        numpy_sample.update({"satellite": da.values})

        if include_extra_metadata:
            numpy_sample.update(
                {
                    "satellite_time_utc": da.time_utc.values.astype(float),
                    "satellite_x_geostationary": da.x_geostationary.values,
                    "satellite_y_geostationary": da.y_geostationary.values,
                },
            )

    if "nwp" in datasets_dict:
        # Each NWP provider is flattened into the sample under its own
        # "nwp_<provider>" key rather than a nested dict.
        for provider, da in datasets_dict["nwp"].items():
            nwp_key = f"nwp_{provider}"
            numpy_sample.update({nwp_key: da.values})

            if include_extra_metadata:
                # Steps are forecast lead times; target time = init time + step.
                step_hours = (da.step.values / np.timedelta64(1, "h")).astype(float)
                target_times = (da.init_time_utc.values + da.step.values).astype(float)

                numpy_sample.update({
                    f"{nwp_key}_init_time_utc": da.init_time_utc.values.astype(float),
                    f"{nwp_key}_step_hours": step_hours,
                    f"{nwp_key}_target_time_utc": target_times,
                })

    return numpy_sample
2 changes: 1 addition & 1 deletion ocf_data_sampler/numpy_sample/datetime_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def encode_datetimes(datetimes: NDArray[np.datetime64]) -> NumpySample:
def get_t0_embedding(
t0: np.datetime64,
embeddings: list[tuple[str, Literal["cyclic", "linear"]]],
) -> dict[str, np.ndarray]:
) -> NumpySample:
"""Creates dictionary of t0 time embeddings.

Args:
Expand Down
30 changes: 17 additions & 13 deletions ocf_data_sampler/torch_datasets/pvnet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import xarray as xr
from numpy.typing import NDArray
from pydantic.warnings import UnsupportedFieldAttributeWarning
from torch.utils.data import Dataset
from torch.utils.data import Dataset, default_collate
from typing_extensions import override

from ocf_data_sampler.config import Configuration, load_yaml_configuration
Expand All @@ -22,8 +22,7 @@
get_t0_embedding,
make_sun_position_numpy_sample,
)
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch
from ocf_data_sampler.select import (
Location,
fill_time_periods,
Expand All @@ -35,7 +34,7 @@
add_alterate_coordinate_projections,
config_normalization_values_to_dicts,
diff_nwp_data,
fill_nans_in_arrays,
fill_nans_in_dataset_dicts,
find_valid_time_periods,
slice_datasets_by_space,
slice_datasets_by_time,
Expand Down Expand Up @@ -124,6 +123,7 @@ def __init__(
config_filename: str,
start_time: str | None = None,
end_time: str | None = None,
include_extra_metadata: bool = False,
use_xarray: bool = True,
) -> None:
"""A generic torch Dataset for creating PVNet samples.
Expand All @@ -136,6 +136,8 @@ def __init__(
config_filename: Path to the configuration file
start_time: Limit the init-times to be after this
end_time: Limit the init-times to be before this
include_extra_metadata: Whether to include non-essential metadata for each sample in the
sample dict.
use_xarray: Whether to use xarray.DataArray or LightDataArray as the underlying data
structure when sampling
"""
Expand Down Expand Up @@ -183,8 +185,8 @@ def __init__(

self.locations = add_alterate_coordinate_projections(locations, datasets_dict)

# Assign config and input data to self
self.config = config
self.include_extra_metadata = include_extra_metadata

if use_xarray:
self.datasets_dict = datasets_dict
Expand Down Expand Up @@ -228,8 +230,11 @@ def process_and_combine_datasets(
channel_stds = self.stds_dict["sat"]
dataset_dict["sat"].data = (dataset_dict["sat"].data - channel_means) / channel_stds

# Fill NaNs
dataset_dict = fill_nans_in_dataset_dicts(dataset_dict, config=self.config)

# Convert all xarray modalities to a single NumpySample
sample = convert_to_numpy_sample(dataset_dict, self.t0_idx)
sample = convert_to_numpy_sample(dataset_dict, self.t0_idx, self.include_extra_metadata)

# Add location metadata not present on the DataArray
if "generation" in dataset_dict:
Expand Down Expand Up @@ -268,8 +273,6 @@ def process_and_combine_datasets(

sample["t0"] = get_posix_timestamp(t0)

sample = fill_nans_in_arrays(sample, config=self.config)

return sample

@staticmethod
Expand Down Expand Up @@ -352,9 +355,10 @@ def __init__(
config_filename: str,
start_time: str | None = None,
end_time: str | None = None,
include_extra_metadata: bool = False,
use_xarray: bool = True,
) -> None:
super().__init__(config_filename, start_time, end_time, use_xarray)
super().__init__(config_filename, start_time, end_time, include_extra_metadata, use_xarray)

# Construct a lookup for locations - useful for users to construct sample by location ID
self.location_lookup = {loc.id: loc for loc in self.locations}
Expand Down Expand Up @@ -445,7 +449,7 @@ class PVNetConcurrentDataset(AbstractPVNetDataset):
def __len__(self) -> int:
return len(self.valid_t0_times)

def _get_sample(self, t0: np.datetime64) -> NumpyBatch:
def _get_sample(self, t0: np.datetime64) -> TensorBatch:
"""Generate a concurrent PVNet sample for given init-time.

Args:
Expand All @@ -469,13 +473,13 @@ def _get_sample(self, t0: np.datetime64) -> NumpyBatch:
samples.append(numpy_sample)

# Stack samples
return stack_np_samples_into_batch(samples)
return default_collate(samples)

@override
def __getitem__(self, idx: int) -> NumpyBatch:
def __getitem__(self, idx: int) -> TensorBatch:
return self._get_sample(self.valid_t0_times[idx])

def get_sample(self, t0: np.datetime64) -> NumpyBatch:
def get_sample(self, t0: np.datetime64) -> TensorBatch:
"""Generate a sample for the given init-time.

Useful for users to generate specific samples.
Expand Down
2 changes: 1 addition & 1 deletion ocf_data_sampler/torch_datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .config_normalization_values_to_dicts import config_normalization_values_to_dicts
from .merge_and_fill_utils import fill_nans_in_arrays
from .fill_nans import fill_nans_in_dataset_dicts
from .valid_time_periods import find_valid_time_periods
from .spatial_slice_for_dataset import slice_datasets_by_space
from .time_slice_for_dataset import slice_datasets_by_time
Expand Down
34 changes: 34 additions & 0 deletions ocf_data_sampler/torch_datasets/utils/fill_nans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Utility function for filling NaNs in DataArrays."""

import numpy as np
import xarray as xr

from ocf_data_sampler.config.model import Configuration, DropoutMixin


def fill_nans(da: "xr.DataArray", source_config: "DropoutMixin") -> "xr.DataArray":
    """Replace NaNs in `da` with the source's configured dropout value.

    The underlying array is modified in-place where possible, and `da` is
    also returned for convenience.

    Args:
        da: DataArray whose NaNs should be filled.
        source_config: Config for this data source; supplies `dropout_value`.

    Returns:
        The same DataArray with NaNs replaced by `source_config.dropout_value`.
    """
    if np.isnan(da.data).any():
        # Only NaNs are filled: infinities are explicitly preserved, since
        # np.nan_to_num would otherwise clip them to large finite values.
        da.data = np.nan_to_num(
            da.data,
            copy=False,
            nan=source_config.dropout_value,
            posinf=np.inf,
            neginf=-np.inf,
        )
    return da


def fill_nans_in_dataset_dicts(datasets_dict: dict, config: "Configuration") -> dict:
    """Fill NaNs in every DataArray in `datasets_dict` with its configured dropout value.

    The DataArrays are modified in-place where possible; the same dict is
    also returned for convenience.

    Args:
        datasets_dict: Dictionary of the input data sources. Recognised keys
            are "generation", "sat" and "nwp" (a sub-dict keyed by provider).
        config: Configuration object supplying per-source dropout values.

    Returns:
        The same dictionary with NaNs filled.
    """
    conf_in = config.input_data

    if "generation" in datasets_dict:
        datasets_dict["generation"] = fill_nans(datasets_dict["generation"], conf_in.generation)

    if "sat" in datasets_dict:
        datasets_dict["sat"] = fill_nans(datasets_dict["sat"], conf_in.satellite)

    if "nwp" in datasets_dict:
        # Iterate the providers actually present in the data, so a provider
        # that is configured but not loaded cannot raise a KeyError here.
        for nwp_key, nwp_da in datasets_dict["nwp"].items():
            datasets_dict["nwp"][nwp_key] = fill_nans(nwp_da, conf_in.nwp[nwp_key])

    return datasets_dict
35 changes: 0 additions & 35 deletions ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py

This file was deleted.

Loading
Loading