Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions integration_tests/small1.yaml

This file was deleted.

219 changes: 219 additions & 0 deletions integration_tests/small1.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# ---------- Embedding / unembedding ----------
embed_orientation: "channels"
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

# ---------- Local autoencoder stack (ae_local_*) ----------
# NOTE(review): "ae" presumably stands for the assimilation-engine/autoencoder
# transformer blocks — confirm against the model code.
ae_local_dim_embed: 512 #1024
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

# Cross-attention adapter between local and global stacks.
ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

# ---------- Global autoencoder stack (ae_global_*) ----------
ae_global_dim_embed: 512 #1024 #2048
ae_global_num_blocks: 2
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
# See https://github.com/ecmwf/WeatherGenerator/issues/1050
ae_global_att_dense_rate: 1.0
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2
ae_global_trailing_layer_norm: False

# ---------- Aggregation stack (ae_aggregation_*) ----------
ae_aggregation_num_blocks: 2
ae_aggregation_num_heads: 32
ae_aggregation_dropout_rate: 0.1
ae_aggregation_with_qk_lnorm: True
ae_aggregation_att_dense_rate: 1.0
ae_aggregation_block_factor: 64
ae_aggregation_mlp_hidden_factor: 2

# ---------- Prediction / decoder head ----------
decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True
num_class_tokens: 1
num_register_tokens: 7

# ---------- Forecast engine (fe_*) ----------
# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
fe_num_blocks: 2
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
fe_impute_latent_noise_std: 0.0 # 1e-4
# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0

# HEALPix refinement level of the latent mesh — presumably; confirm against model code.
healpix_level: 4

# ---------- Precision / parallelism ----------
with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mixed_precision_dtype: bf16
# NOTE(review): bare "1e-5"/"1e-4" are not YAML 1.1 floats (a strict loader such as
# PyYAML reads them as strings); fine if the config is loaded via OmegaConf — confirm.
mlp_norm_eps: 1e-5
norm_eps: 1e-4

# ---------- Latent-noise regularisation ----------
latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

# Comma/space-separated module names to freeze during training; empty = train everything
# — presumably; confirm against the trainer.
freeze_modules: ""

norm_type: "LayerNorm"

# type of zarr_store
zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore


################
# Run configuration (everything above is model hyper-parameters).

streams_directory: "./integration_tests/streams/"
model_path: "./models"
results_path: "./results"

general:

  # mutable parameters
  istep: 0
  # "???" marks a mandatory value filled in at runtime (OmegaConf convention)
  # — NOTE(review): confirm the loader is OmegaConf-based.
  rank: ???
  world_size: ???

  # local_rank,
  # with_ddp,
  # data_path_*,
  # model_path,
  # run_path,
  # path_shared_

  multiprocessing_method: "fork"

  desc: ""
  run_id: ???
  run_history: []

# Logging intervals — presumably in training steps; confirm.
train_logging:
  terminal: 10
  metrics: 20
  checkpoint: 250

# parameters for data loading
data_loading :

  num_workers: 2
  rng_seed: ???
  repeat_data_in_mini_epoch : False

  # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with
  # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error.
  # If this happens, you can disable the flag, but performance will drop on GH200.
  memory_pinning: True

# config for training
training_config:

  training_mode: ["masking"]

  num_mini_epochs: 1
  samples_per_mini_epoch: 48
  shuffle: True

  start_date: 2014-01-01T00:00
  end_date: 2020-12-31T00:00

  time_window_step: 06:00:00
  time_window_len: 06:00:00

  window_offset_prediction : 1

  learning_rate_scheduling :
    lr_start: 1e-6
    lr_max: 0.00005
    lr_final_decay: 1e-6
    lr_final: 0.0
    num_steps_warmup: 4
    num_steps_cooldown: 4
    policy_warmup: "cosine"
    policy_decay: "constant"
    policy_cooldown: "linear"
    parallel_scaling_policy: "sqrt"

  optimizer:
    grad_clip: 1.0
    weight_decay: 0.1
    log_grad_norms: False
    adamw :
      # parameters are scaled by number of DDP workers
      beta1 : 0.975
      beta2 : 0.9875
      eps : 2e-08

  losses : {
    "physical": {
      type: LossPhysical,
      loss_fcts: { "mse": { }, },
    },
  }

  model_input: {
    "forecasting" : {
      masking_strategy: "forecast",
    }
  }

  forecast :
    time_step: 06:00:00
    num_steps: 2
    policy: "fixed"
    offset: 1


# validation config; full validation config is merge of training and validation config
validation_config:

  samples_per_mini_epoch: 32
  shuffle: False

  start_date: 2021-10-10T00:00
  end_date: 2022-10-11T00:00

  output:
    streams: ["ERA5"]

  validate_with_ema:
    enabled : True
    ema_ramp_up_ratio: 0.09
    ema_halflife_in_thousands: 1e-3

test_config:
  output:
    num_samples: 2


# TODO: read latent from here
inference_config:
  output:
    streams: ["ERA5"]
37 changes: 8 additions & 29 deletions integration_tests/small1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def test_train(setup, test_run_id):

main(
[
"inference",
f"--config={WEATHERGEN_HOME}/integration_tests/small1.yaml",
"train",
f"--base-config={WEATHERGEN_HOME}/integration_tests/small1.yml",
"--run-id",
test_run_id,
]
)

infer_with_missing(test_run_id)
infer(test_run_id)
evaluate_results(test_run_id)
assert_missing_metrics_file(test_run_id)
assert_train_loss_below_threshold(test_run_id)
Expand All @@ -70,28 +70,7 @@ def infer(run_id):
logger.info("run inference")
main(
[
"-start",
"2022-10-10",
"-end",
"2022-10-11",
"--samples",
"10",
"--mini-epoch",
"0",
"--from-run-id",
run_id,
"--run-id",
run_id,
"--config",
f"{WEATHERGEN_HOME}/integration_tests/small1.yaml",
]
)


def infer_with_missing(run_id):
logger.info("run inference")
main(
[
"inference",
"-start",
"2021-10-10",
"-end",
Expand All @@ -105,7 +84,7 @@ def infer_with_missing(run_id):
"--run-id",
run_id,
"--config",
f"{WEATHERGEN_HOME}/integration_tests/small1.yaml",
f"{WEATHERGEN_HOME}/integration_tests/small1.yml",
]
)

Expand Down Expand Up @@ -155,7 +134,7 @@ def evaluate_results(run_id):

def load_metrics(run_id):
"""Helper function to load metrics"""
file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id)
file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results" / run_id, run_id=run_id)
if not os.path.exists(file_path):
raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}")
with open(file_path) as f:
Expand All @@ -165,7 +144,7 @@ def load_metrics(run_id):

def assert_missing_metrics_file(run_id):
"""Test that a missing metrics file raises FileNotFoundError."""
file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id)
file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results"/ run_id, run_id=run_id)
assert os.path.exists(file_path), f"Metrics file does not exist for run_id: {run_id}"
metrics = load_metrics(run_id)
logger.info(f"Loaded metrics for run_id: {run_id}: {metrics}")
Expand Down Expand Up @@ -208,4 +187,4 @@ def assert_val_loss_below_threshold(run_id):
assert loss_metric is not None, f"'{loss_avg_name}' metric is missing in metrics file"
# Check that the loss does not explode in a single mini_epoch
# This is meant to be a quick test, not a convergence test
assert loss_metric < 0.25, f"'{loss_avg_name}' is {loss_metric}, expected to be below 0.25"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change to 0.1

assert loss_metric < 0.2, f"'{loss_avg_name}' is {loss_metric}, expected to be below 0.2"