75 changes: 75 additions & 0 deletions scripts/README.md
@@ -0,0 +1,75 @@
# Execute training tasks on SLURM

1. Make a working directory

```sh
mkdir training
cd training
```

2. Clone this repo
```sh
git clone git@github.com:ESMValGroup/ClimaNet.git
```

3. Install uv for dependency management. See the [uv documentation](https://docs.astral.sh/uv/getting-started/installation/).

4. Create a venv and install the Python dependencies using uv
```sh
cd ClimaNet
uv sync
```

A `.venv` directory will be created.

5. Copy the Python script and SLURM script into the working directory:

```sh
cp ClimaNet/scripts/example* .
```

6. Configure `example.slurm`: in the `source ...` line, make sure the venv you just created is activated.
Note that the account is the ESO4CLIMA project account, which is shared by multiple users.

7. Configure `example_training.py`: make sure the paths to the input data and the land mask data are correct.

8. Submit the SLURM job
```sh
sbatch example.slurm
```

## Check the efficiency of resource usage

In the SLURM job output, you will find lines like these:

```
==== Slurm accounting summary 23743544 ====
JobID|NTasks|AveCPU|AveRSS|MaxRSS|MaxVMSize|TRESUsageInAve|TRESUsageInMax
23743544.extern|1|00:00:00|856K|3752K|641376K|cpu=00:00:00,energy=0,fs/disk=2332,mem=856K,pages=2,vmem=217160K|cpu=00:00:00,energy=0,fs/disk=2332,mem=3752K,pages=2,vmem=641376K
23743544.batch|1|04:21:01|11964K|4102096K|37743716K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=11964K,pages=19,vmem=356724K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=4102096K,pages=7711,vmem=37743716K
```

This summary, printed at the end of the job, gives some information about resource usage.

To get a better picture of resource usage efficiency, run the following command after the job has finished:

```sh
sacct -j <slurm_job_id> \
  --format=JobID,JobName%30,Partition,AllocCPUS,Elapsed,TotalCPU,MaxRSS,State,ExitCode \
  --parsable2 >> "eso4clima_<slurm_job_id>.out"
```

This appends the resource usage information to the SLURM job output file. After running it, you will find lines like these in the output file:

```
JobID|JobName|Partition|AllocCPUS|Elapsed|TotalCPU|MaxRSS|State|ExitCode
23743544|eso4clima|compute|256|00:02:44|04:21:01||COMPLETED|0:0
23743544.batch|batch||256|00:02:44|04:21:01|4102096K|COMPLETED|0:0
23743544.extern|extern||256|00:02:44|00:00.001|3752K|COMPLETED|0:0
```

The efficiency of resource usage can be calculated as `TotalCPU / (AllocCPUS * Elapsed)`. In the example above, the CPU time is `04:21:01`, the allocated CPU count is `256`, and the elapsed time is `00:02:44`, so the efficiency is `04:21:01 / (256 * 00:02:44) ≈ 0.37`.
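The calculation above can be sketched in Python. This is a minimal helper, not part of the repo: `to_seconds` and `cpu_efficiency` are hypothetical names, and the parsing assumes sacct's `[D-]HH:MM:SS[.mmm]` time format.

```python
def to_seconds(t: str) -> int:
    """Convert a sacct time string like '04:21:01' or '1-02:03:04' to seconds."""
    days = 0
    if "-" in t:
        d, t = t.split("-")
        days = int(d)
    # Handle short forms like 'MM:SS.mmm' by left-padding to H, M, S
    parts = [int(float(p)) for p in t.split(":")]
    while len(parts) < 3:
        parts.insert(0, 0)
    h, m, s = parts
    return days * 86400 + h * 3600 + m * 60 + s


def cpu_efficiency(total_cpu: str, alloc_cpus: int, elapsed: str) -> float:
    """CPU efficiency = TotalCPU / (AllocCPUS * Elapsed)."""
    return to_seconds(total_cpu) / (alloc_cpus * to_seconds(elapsed))


# Values from the example job above
print(round(cpu_efficiency("04:21:01", 256, "00:02:44"), 2))  # 0.37
```

An efficiency well below 1 (as here) suggests the job allocated far more CPUs than it actually used.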
61 changes: 61 additions & 0 deletions scripts/eso4clima_23743544.out
@@ -0,0 +1,61 @@
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
monthly_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
monthly_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
monthly_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
monthly_data = xr.open_mfdataset(
2026-03-26 17:02:31,368 - INFO - mean: [289.19693 289.3843 ], std: [10.575894 10.624186]
2026-03-26 17:02:58,933 - INFO - Epoch 0: best_loss = inf
2026-03-26 17:04:51,712 - INFO - No improvement for 10 epochs, stopping early at epoch 9.
2026-03-26 17:04:51,712 - INFO - training done!
2026-03-26 17:04:51,712 - INFO - Final loss: 4.909843444824219
2026-03-26 17:04:51,777 - INFO - Checkpoint saved to models/spatio_temporal_model.pth
==== Slurm accounting summary 23743544 ====
JobID|NTasks|AveCPU|AveRSS|MaxRSS|MaxVMSize|TRESUsageInAve|TRESUsageInMax
23743544.extern|1|00:00:00|856K|3752K|641376K|cpu=00:00:00,energy=0,fs/disk=2332,mem=856K,pages=2,vmem=217160K|cpu=00:00:00,energy=0,fs/disk=2332,mem=3752K,pages=2,vmem=641376K
23743544.batch|1|04:21:01|11964K|4102096K|37743716K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=11964K,pages=19,vmem=356724K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=4102096K,pages=7711,vmem=37743716K

********************************************************************************
* *
* This is the automated job summary provided by DKRZ. *
* If you encounter problems, need assistance or have any suggestion, please *
* write an email to *
* *
* -- support@dkrz.de -- *
* *
* We hope you enjoyed the DKRZ supercomputer LEVANTE ... *
*
* JobID : 23743544
* JobName : eso4clima
* Account : bd0854
* User : b383704 (202985), bd0854 (1473)
* Partition : compute
* QOS : normal
* Nodelist : l40338 (1)
* Submit date : 2026-03-26T17:02:02
* Start time : 2026-03-26T17:02:10
* End time : 2026-03-26T17:04:54
* Elapsed time : 00:02:44 (Timelimit=00:30:00)
* Command : /home/b/b383704/eso4clima/train_twoyears/example.slurm
* WorkDir : /home/b/b383704/eso4clima/train_twoyears
*
* StepID | JobName NodeHours MaxRSS [Byte] (@task)
* ------------------------------------------------------------------------------
* batch | batch 0.046
* extern | extern 0.046 3752K (0)
* ------------------------------------------------------------------------------

JobID|JobName|Partition|AllocCPUS|Elapsed|TotalCPU|MaxRSS|State|ExitCode
23743544|eso4clima|compute|256|00:02:44|04:21:01||COMPLETED|0:0
23743544.batch|batch||256|00:02:44|04:21:01|4102096K|COMPLETED|0:0
23743544.extern|extern||256|00:02:44|00:00.001|3752K|COMPLETED|0:0
17 changes: 17 additions & 0 deletions scripts/example.slurm
@@ -0,0 +1,17 @@
#!/bin/bash
#SBATCH --ntasks-per-node=128
#SBATCH --time=02:00:00
#SBATCH --account=bd0854
#SBATCH --output=eso4clima_%j.out

# Activate the virtual environment
# Change the path to your virtual environment
source /home/b/b383704/eso4clima/ClimaNet/.venv/bin/activate

# Run the training script
# Change the path to your training script
python /home/b/b383704/eso4clima/train_twoyears/example_training.py

echo "==== Slurm accounting summary ${SLURM_JOB_ID} ===="
sstat --allsteps -j "$SLURM_JOB_ID" \
--format=JobID,NTasks,AveCPU,AveRSS,MaxRSS,MaxVMSize,TresUsageInAve,TresUsageInMax \
--parsable2
193 changes: 193 additions & 0 deletions scripts/example_training.py
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""Example training script"""

from pathlib import Path
import torch
import torch.nn.functional
import xarray as xr
from torch.utils.data import DataLoader

from climanet import STDataset
from climanet.st_encoder_decoder import SpatioTemporalModel

import logging

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def main():
@SarahAlidoost (Member), Mar 17, 2026:
This function is currently doing a lot: creating the model, training it, making predictions, and saving results. We should split these responsibilities.
If this is a "training script", it should only handle reading the data, creating the model with the correct arguments, and passing both to a separate training function (that will be added in #33).

Then, in another script (e.g. "inference script"), we can load the saved model and make predictions (see #32). This separation is needed because training and inference require different computing resources.

Any plotting or result inspection can be done in a separate script if needed.

@rogerkuou (Collaborator, Author), Mar 18, 2026:
I have split this into a training script and an inference script. The plotting part has been removed.

    # Data files
    data_folder = Path(
        "/work/bd0854/b380103/eso4clima/output/v1.0/concatenated/"
    )  # HPC
    # data_folder = Path("../../data/output/")  # local
    daily_files = sorted(data_folder.rglob("20*_day_ERA5_masked_ts.nc"))
    monthly_files = sorted(data_folder.rglob("20*_mon_ERA5_full_ts.nc"))

    # Land surface
    lsm_file = "/home/b/b383704/eso4clima/train_twoyears/era5_lsm_bool.nc"  # HPC
    # lsm_file = data_folder / "era5_lsm_bool.nc"  # local

    # Load full dataset
    patch_size_training = 80
    daily_data = xr.open_mfdataset(
        daily_files,
        combine="by_coords",
        chunks={
            "time": 1,
            "lat": patch_size_training * 2,
            "lon": patch_size_training * 2,
        },
        data_vars="minimal",
        coords="minimal",
        compat="override",
        parallel=False,
    )

    monthly_data = xr.open_mfdataset(
        monthly_files,
        combine="by_coords",
        chunks={
            "time": 1,
            "lat": patch_size_training * 2,
            "lon": patch_size_training * 2,
        },
        data_vars="minimal",
        coords="minimal",
        compat="override",
        parallel=False,
    )

    lsm_mask = xr.open_dataset(lsm_file)

    # Compute monthly climatology stats without persisting the full (time, lat, lon) monthly field
    monthly_ts = daily_data["ts"].resample(time="MS").mean(skipna=True)
    mean = monthly_ts.mean(dim=["lat", "lon"], skipna=True).compute().values
    std = monthly_ts.std(dim=["lat", "lon"], skipna=True).compute().values
    logger.info(f"mean: {mean}, std: {std}")

    # Make a dataset
    dataset = STDataset(
        daily_da=daily_data["ts"],
        monthly_da=monthly_data["ts"],
        land_mask=lsm_mask["lsm"],
        patch_size=(patch_size_training, patch_size_training),
    )

    # Initialize training
    device = "cuda" if torch.cuda.is_available() else "cpu"
    patch_size = (1, 4, 4)
    overlap = 1
    model = SpatioTemporalModel(
        patch_size=patch_size, overlap=overlap, num_months=2
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    decoder = model.decoder
    with torch.no_grad():
        decoder.bias.copy_(torch.from_numpy(mean))
        decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        pin_memory=False,
    )

    best_loss = float("inf")
    patience = 10  # stop if no improvement for <patience> epochs
    counter = 0

    # Effective batch size = batch_size (fit in memory) * accumulation_steps
    accumulation_steps = 2

    # Training loop with DataLoader
    model.train()
    for epoch in range(501):
        epoch_loss = 0.0

        optimizer.zero_grad()

        for i, batch in enumerate(dataloader):
            # Get batch data
            daily_batch = batch["daily_patch"]
            daily_mask = batch["daily_mask_patch"]
            monthly_target = batch["monthly_patch"]
            land_mask = batch["land_mask_patch"]
            padded_days_mask = batch["padded_days_mask"]

            # Batch prediction
            pred = model(daily_batch, daily_mask, land_mask, padded_days_mask)  # (B, M, H, W)

            # Mask out land pixels
            ocean = (~land_mask).to(pred.device).unsqueeze(1).float()  # (B, 1, H, W)
            loss = torch.nn.functional.l1_loss(pred, monthly_target, reduction="none")
            loss = loss * ocean

            num = loss.sum(dim=(-2, -1))  # (B, M)
            denom = ocean.sum(dim=(-2, -1)).clamp_min(1)  # (B, 1)

            loss_per_month = num / denom
            loss = loss_per_month.mean()

            # Scale loss for gradient accumulation
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()

            # Track unscaled loss for logging
            epoch_loss += loss.item()

            # Update weights every accumulation_steps batches
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Handle remaining gradients if num_batches is not divisible by accumulation_steps
        if (i + 1) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / (i + 1)

        # Early stopping check
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            counter = 0
        else:
            counter += 1

        if epoch % 20 == 0:
            logger.info(f"Epoch {epoch}: best_loss = {best_loss:.6f}")

        if counter >= patience:
            logger.info(
                f"No improvement for {patience} epochs, stopping early at epoch {epoch}."
            )
            break

    logger.info("training done!")
    logger.info(f"Final loss: {loss.item()}")

    # Save the trained model with config
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss.item(),
    }
    model_save_path = Path("./models/spatio_temporal_model.pth")
    model_save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, model_save_path)
    logger.info(f"Checkpoint saved to {model_save_path}")


if __name__ == "__main__":
    main()