diff --git a/scripts/README.md b/scripts/README.md
new file mode 100644
index 0000000..87484a4
--- /dev/null
+++ b/scripts/README.md
@@ -0,0 +1,75 @@
+# Execute training tasks on SLURM
+
+1. Make a working directory
+
+```sh
+mkdir training
+cd training
+```
+
+2. Clone this repo
+```sh
+git clone git@github.com:ESMValGroup/ClimaNet.git
+```
+
+3. Install uv for dependency management. See [uv doc](https://docs.astral.sh/uv/getting-started/installation/).
+
+4. Create a venv and install Python dependencies using uv
+```sh
+cd ClimaNet
+```
+
+```
+uv sync
+```
+
+A `.venv` dir will appear
+
+5. Copy the python script and slurm script into the working dir:
+
+```sh
+cp ClimaNet/scripts/example* .
+```
+
+6. Configure `example.slurm`: in the `source ...` line, make sure the venv just created is activated.
+   Note that the account is the ESO4CLIMA project account, which is shared by multiple users.
+
+7. Configure `example_training.py`: make sure the paths of the input data and the land mask data are correct.
+
+8. Execute the SLURM job
+```sh
+sbatch example.slurm
+```
+
+## Check the efficiency of resource usage
+
+In the SLURM job output, you can find lines like this:
+
+```
+==== Slurm accounting summary 23743544 ====
+JobID|NTasks|AveCPU|AveRSS|MaxRSS|MaxVMSize|TRESUsageInAve|TRESUsageInMax
+23743544.extern|1|00:00:00|856K|3752K|641376K|cpu=00:00:00,energy=0,fs/disk=2332,mem=856K,pages=2,vmem=217160K|cpu=00:00:00,energy=0,fs/disk=2332,mem=3752K,pages=2,vmem=641376K
+23743544.batch|1|04:21:01|11964K|4102096K|37743716K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=11964K,pages=19,vmem=356724K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=4102096K,pages=7711,vmem=37743716K
+```
+Which gives some information about the resource usage at the end of the job.
+
+To have a better understanding of the efficiency of resource usage, you can run the following command after the job is finished:
+
+```sh
+sacct -j <jobid> \
+    --format=JobID,JobName%30,Partition,AllocCPUS,Elapsed,TotalCPU,MaxRSS,State,ExitCode \
+    --parsable2 >> "eso4clima_<jobid>.out"
+
+```
+
+This will output the resource usage information and add it to the slurm job output file. After running this you can find lines like these in the output file:
+
+```
+JobID|JobName|Partition|AllocCPUS|Elapsed|TotalCPU|MaxRSS|State|ExitCode
+23743544|eso4clima|compute|256|00:02:44|04:21:01||COMPLETED|0:0
+23743544.batch|batch||256|00:02:44|04:21:01|4102096K|COMPLETED|0:0
+23743544.extern|extern||256|00:02:44|00:00.001|3752K|COMPLETED|0:0
+```
+
+The efficiency of resource usage can be calculated as `TotalCPU / (AllocCPUS * Elapsed)`. In the example above, the CPU time is `04:21:01`, the allocated CPU is `256`, and the elapsed time is `00:02:44`, so the efficiency of resource usage is `04:21:01 / (256 * 00:02:44) ≈ 0.37`.
\ No newline at end of file
diff --git a/scripts/eso4clima_23743544.out b/scripts/eso4clima_23743544.out
new file mode 100644
index 0000000..1ca4c85
--- /dev/null
+++ b/scripts/eso4clima_23743544.out
@@ -0,0 +1,61 @@
+/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
+  daily_data = xr.open_mfdataset(
+/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
+  daily_data = xr.open_mfdataset(
+/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. 
This could degrade performance. Instead, consider rechunking after loading. + daily_data = xr.open_mfdataset( +/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading. + daily_data = xr.open_mfdataset( +/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading. + monthly_data = xr.open_mfdataset( +/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading. + monthly_data = xr.open_mfdataset( +/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading. + monthly_data = xr.open_mfdataset( +/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading. + monthly_data = xr.open_mfdataset( +2026-03-26 17:02:31,368 - INFO - mean: [289.19693 289.3843 ], std: [10.575894 10.624186] +2026-03-26 17:02:58,933 - INFO - Epoch 0: best_loss = inf +2026-03-26 17:04:51,712 - INFO - No improvement for 10 epochs, stopping early at epoch 9. +2026-03-26 17:04:51,712 - INFO - training done! 
+2026-03-26 17:04:51,712 - INFO - Final loss: 4.909843444824219 +2026-03-26 17:04:51,777 - INFO - Checkpoint saved to models/spatio_temporal_model.pth +==== Slurm accounting summary 23743544 ==== +JobID|NTasks|AveCPU|AveRSS|MaxRSS|MaxVMSize|TRESUsageInAve|TRESUsageInMax +23743544.extern|1|00:00:00|856K|3752K|641376K|cpu=00:00:00,energy=0,fs/disk=2332,mem=856K,pages=2,vmem=217160K|cpu=00:00:00,energy=0,fs/disk=2332,mem=3752K,pages=2,vmem=641376K +23743544.batch|1|04:21:01|11964K|4102096K|37743716K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=11964K,pages=19,vmem=356724K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=4102096K,pages=7711,vmem=37743716K + +******************************************************************************** +* * +* This is the automated job summary provided by DKRZ. * +* If you encounter problems, need assistance or have any suggestion, please * +* write an email to * +* * +* -- support@dkrz.de -- * +* * +* We hope you enjoyed the DKRZ supercomputer LEVANTE ... * +* +* JobID : 23743544 +* JobName : eso4clima +* Account : bd0854 +* User : b383704 (202985), bd0854 (1473) +* Partition : compute +* QOS : normal +* Nodelist : l40338 (1) +* Submit date : 2026-03-26T17:02:02 +* Start time : 2026-03-26T17:02:10 +* End time : 2026-03-26T17:04:54 +* Elapsed time : 00:02:44 (Timelimit=00:30:00) +* Command : /home/b/b383704/eso4clima/train_twoyears/example.slurm +* WorkDir : /home/b/b383704/eso4clima/train_twoyears +* +* StepID | JobName NodeHours MaxRSS [Byte] (@task) +* ------------------------------------------------------------------------------ +* batch | batch 0.046 +* extern | extern 0.046 3752K (0) +* ------------------------------------------------------------------------------ + +JobID|JobName|Partition|AllocCPUS|Elapsed|TotalCPU|MaxRSS|State|ExitCode +23743544|eso4clima|compute|256|00:02:44|04:21:01||COMPLETED|0:0 +23743544.batch|batch||256|00:02:44|04:21:01|4102096K|COMPLETED|0:0 
+23743544.extern|extern||256|00:02:44|00:00.001|3752K|COMPLETED|0:0
diff --git a/scripts/example.slurm b/scripts/example.slurm
new file mode 100644
index 0000000..11c0ad8
--- /dev/null
+++ b/scripts/example.slurm
@@ -0,0 +1,18 @@
+#!/bin/bash
+#SBATCH --ntasks-per-node=128
+#SBATCH --time=02:00:00
+#SBATCH --account=bd0854
+#SBATCH --output=eso4clima_%j.out
+
+# Activate the virtual environment
+# Change the path to your virtual environment
+source /home/b/b383704/eso4clima/ClimaNet/.venv/bin/activate
+
+# Run the training script
+# Change the path to your training script
+python /home/b/b383704/eso4clima/train_twoyears/example_training.py
+
+echo "==== Slurm accounting summary ${SLURM_JOB_ID} ===="
+sstat --allsteps -j "$SLURM_JOB_ID" \
+    --format=JobID,NTasks,AveCPU,AveRSS,MaxRSS,MaxVMSize,TresUsageInAve,TresUsageInMax \
+    --parsable2
\ No newline at end of file
diff --git a/scripts/example_training.py b/scripts/example_training.py
new file mode 100644
index 0000000..9db4741
--- /dev/null
+++ b/scripts/example_training.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+"""Example training script"""
+
+from pathlib import Path
+import torch
+import torch.nn.functional
+import xarray as xr
+from torch.utils.data import DataLoader
+
+from climanet import STDataset
+from climanet.st_encoder_decoder import SpatioTemporalModel
+
+import logging
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def main():
+    # Data files
+    data_folder = Path(
+        "/work/bd0854/b380103/eso4clima/output/v1.0/concatenated/"
+    )  # HPC
+    # data_folder = Path("../../data/output/")  # local
+    daily_files = sorted(data_folder.rglob("20*_day_ERA5_masked_ts.nc"))
+    monthly_files = sorted(data_folder.rglob("20*_mon_ERA5_full_ts.nc"))
+
+    # Land surface
+    lsm_file = "/home/b/b383704/eso4clima/train_twoyears/era5_lsm_bool.nc"  # HPC
+    # lsm_file = data_folder / "era5_lsm_bool.nc"  # local
+
+    # Lazily open the datasets; chunk lat/lon to twice the training patch
+    # size so a training patch rarely straddles a chunk boundary.
+    patch_size_training = 80
+    daily_data = xr.open_mfdataset(
+        daily_files,
+        combine="by_coords",
+        chunks={
+            "time": 1,
+            "lat": patch_size_training * 2,
+            "lon": patch_size_training * 2,
+        },
+        data_vars="minimal",
+        coords="minimal",
+        compat="override",
+        parallel=False,
+    )
+
+    monthly_data = xr.open_mfdataset(
+        monthly_files,
+        combine="by_coords",
+        chunks={
+            "time": 1,
+            "lat": patch_size_training * 2,
+            "lon": patch_size_training * 2,
+        },
+        data_vars="minimal",
+        coords="minimal",
+        compat="override",
+        parallel=False,
+    )
+
+    lsm_mask = xr.open_dataset(lsm_file)
+
+    # Compute monthly climatology stats without persisting the full (time, lat, lon) monthly field
+    monthly_ts = daily_data["ts"].resample(time="MS").mean(skipna=True)
+    mean = monthly_ts.mean(dim=["lat", "lon"], skipna=True).compute().values
+    std = monthly_ts.std(dim=["lat", "lon"], skipna=True).compute().values
+    logger.info(f"mean: {mean}, std: {std}")
+
+    # Make a dataset
+    dataset = STDataset(
+        daily_da=daily_data["ts"],
+        monthly_da=monthly_data["ts"],
+        land_mask=lsm_mask["lsm"],
+        patch_size=(patch_size_training, patch_size_training),
+    )
+
+    # Initialize training
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    patch_size = (1, 4, 4)
+    overlap = 1
+    model = SpatioTemporalModel(patch_size=patch_size, overlap=overlap, num_months=2).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    decoder = model.decoder
+    # Initialise the decoder's denormalisation with the climatology stats.
+    with torch.no_grad():
+        decoder.bias.copy_(torch.from_numpy(mean))
+        decoder.scale.copy_(torch.from_numpy(std) + 1e-6)
+
+    # Create dataloader
+    dataloader = DataLoader(
+        dataset,
+        batch_size=2,
+        shuffle=True,
+        pin_memory=False,
+    )
+
+    best_loss = float("inf")
+    patience = 10  # stop if no improvement for `patience` epochs
+    counter = 0
+
+    # Effective batch size = batch_size (fit in memory) * accumulation_steps
+    accumulation_steps = 2
+
+    # Training loop with DataLoader
+    model.train()
+    for epoch in range(501):
+        epoch_loss = 0.0
+
+        optimizer.zero_grad()
+
+        for i, batch in enumerate(dataloader):
+            # Get batch data
+            daily_batch = batch["daily_patch"]
+            daily_mask = batch["daily_mask_patch"]
+            monthly_target = batch["monthly_patch"]
+            land_mask = batch["land_mask_patch"]
+            padded_days_mask = batch["padded_days_mask"]
+
+            # Batch prediction
+            pred = model(daily_batch, daily_mask, land_mask, padded_days_mask)  # (B, M, H, W)
+
+            # Mask out land pixels: average the L1 loss over ocean pixels only
+            ocean = (~land_mask).to(pred.device).unsqueeze(1).float()  # (B, 1, H, W)
+            loss = torch.nn.functional.l1_loss(pred, monthly_target, reduction="none")
+            loss = loss * ocean
+
+            num = loss.sum(dim=(-2, -1))  # (B, M)
+            denom = ocean.sum(dim=(-2, -1)).clamp_min(1)  # (B, 1)
+
+            loss_per_month = num / denom
+            loss = loss_per_month.mean()
+
+            # Scale loss for gradient accumulation
+            scaled_loss = loss / accumulation_steps
+            scaled_loss.backward()
+
+            # Track unscaled loss for logging
+            epoch_loss += loss.item()
+
+            # Update weights every accumulation_steps batches
+            if (i + 1) % accumulation_steps == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+
+        # Handle remaining gradients if num_batches is not divisible by accumulation_steps
+        if (i + 1) % accumulation_steps != 0:
+            optimizer.step()
+            optimizer.zero_grad()
+
+        # Calculate average epoch loss
+        avg_epoch_loss = epoch_loss / (i + 1)
+
+        # Early stopping check
+        if avg_epoch_loss < best_loss:
+            best_loss = avg_epoch_loss
+            counter = 0
+        else:
+            counter += 1
+
+        if epoch % 20 == 0:
+            logger.info(f"Epoch {epoch}: best_loss = {best_loss:.6f}")
+
+        if counter >= patience:
+            logger.info(f"No improvement for {patience} epochs, stopping early at epoch {epoch}.")
+            break
+
+    logger.info("training done!")
+    logger.info(f"Final loss: {loss.item()}")
+
+    # Save the trained model with config
+    checkpoint = {
+        "model_state_dict": model.state_dict(),
+        "optimizer_state_dict": optimizer.state_dict(),
+        "epoch": epoch,
+        "loss": loss.item(),
+    }
+    model_save_path = Path("./models/spatio_temporal_model.pth")
+    model_save_path.parent.mkdir(parents=True, exist_ok=True)
+    torch.save(checkpoint, model_save_path)
+    logger.info(f"Checkpoint saved to {model_save_path}")
+
+
+if __name__ == "__main__":
+    main()