25 test two year data #27
Open
rogerkuou wants to merge 24 commits into main from 25_test_two_year_data (+346 −0)

Changes from all commits (24 commits):
- df1105f initial example of two year data (rogerkuou)
- fef01cf update examples notebook (rogerkuou)
- 2b3ac47 add example training scritps (rogerkuou)
- 2bdf8b5 add example slurm file (rogerkuou)
- 1159fbc update fig dir (rogerkuou)
- 3e2c4b4 add README (rogerkuou)
- 994d36b Merge branch 'main' into 25_test_two_year_data (rogerkuou)
- 8cd1c8f Apply suggestions from code review (rogerkuou)
- fe8f024 fix conflicts (rogerkuou)
- a3ba05d separate training and inference (rogerkuou)
- 3c99673 update model exportation with checkpoint (rogerkuou)
- 2b4c7c5 add inference scripts (rogerkuou)
- 9ed2e00 use logging to replace print (rogerkuou)
- 1de1cfb update example slurm scripts (rogerkuou)
- efa17a1 force example notebook to be identical as main (rogerkuou)
- 25297dc Apply suggestions from code review (rogerkuou)
- 6b04a27 revert changes in model file (rogerkuou)
- ba0408f remove inference script (rogerkuou)
- 74a180c maintain the same config in example script as example notebook (rogerkuou)
- eeb29ff update the training loop (rogerkuou)
- 8481c08 update logger and log file (rogerkuou)
- baba960 update the training script and slurm file (rogerkuou)
- e89d4e4 document the efficiency calculation in README (rogerkuou)
- 3a5dbd6 add an example slurm log (rogerkuou)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@@ -0,0 +1,75 @@ (new file: README)

# Execute training tasks on SLURM

1. Make a working directory:

```sh
mkdir training
cd training
```

2. Clone this repo:

```sh
git clone git@github.com:ESMValGroup/ClimaNet.git
```

3. Install uv for dependency management. See the [uv doc](https://docs.astral.sh/uv/getting-started/installation/).

4. Create a venv and install the Python dependencies using uv:

```sh
cd ClimaNet
uv sync
```

A `.venv` dir will appear.

5. Copy the Python script and the SLURM script into the working dir:

```sh
cp ClimaNet/scripts/example* .
```

6. Configure `example.slurm`: in the `source ...` line, make sure the venv you just created is activated. Note that the account is the ESO4CLIMA project account, which is shared by multiple users.

7. Configure `example.py`: make sure the paths of the input data and the land mask data are correct.

8. Submit the SLURM job:

```sh
sbatch example.slurm
```

## Check the efficiency of resource usage

In the SLURM job output, you will find a block like this:

```
==== Slurm accounting summary 23743544 ====
JobID|NTasks|AveCPU|AveRSS|MaxRSS|MaxVMSize|TRESUsageInAve|TRESUsageInMax
23743544.extern|1|00:00:00|856K|3752K|641376K|cpu=00:00:00,energy=0,fs/disk=2332,mem=856K,pages=2,vmem=217160K|cpu=00:00:00,energy=0,fs/disk=2332,mem=3752K,pages=2,vmem=641376K
23743544.batch|1|04:21:01|11964K|4102096K|37743716K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=11964K,pages=19,vmem=356724K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=4102096K,pages=7711,vmem=37743716K
```

This gives some information about the resource usage at the end of the job.

For a better understanding of the efficiency of resource usage, run the following command after the job has finished:

```sh
sacct -j <slurm_job_id> \
  --format=JobID,JobName%30,Partition,AllocCPUS,Elapsed,TotalCPU,MaxRSS,State,ExitCode \
  --parsable2 >> "eso4clima_<slurm_job_id>.out"
```

This appends the resource usage information to the SLURM job output file. After running it, you will find lines like these in the output file:

```
JobID|JobName|Partition|AllocCPUS|Elapsed|TotalCPU|MaxRSS|State|ExitCode
23743544|eso4clima|compute|256|00:02:44|04:21:01||COMPLETED|0:0
23743544.batch|batch||256|00:02:44|04:21:01|4102096K|COMPLETED|0:0
23743544.extern|extern||256|00:02:44|00:00.001|3752K|COMPLETED|0:0
```

The efficiency of resource usage can be calculated as `TotalCPU / (AllocCPUS * Elapsed)`. In the example above, the CPU time is `04:21:01`, the allocated CPUs are `256`, and the elapsed time is `00:02:44`, so the efficiency of resource usage is `04:21:01 / (256 * 00:02:44) ≈ 0.37`.
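As a quick check of the arithmetic, the efficiency above can be computed with a few lines of Python. Note the time parser below is a simplification: `sacct` can also emit `D-HH:MM:SS` or `MM:SS.fff` forms, which this sketch does not handle.

```python
# Compute CPU efficiency = TotalCPU / (AllocCPUS * Elapsed) from sacct fields.
def to_seconds(hms: str) -> int:
    """Parse a plain HH:MM:SS string (simplified; does not cover
    D-HH:MM:SS or fractional-second sacct formats)."""
    h, m, s = (int(x) for x in hms.split(":"))
    return h * 3600 + m * 60 + s

total_cpu = to_seconds("04:21:01")  # TotalCPU from sacct
elapsed = to_seconds("00:02:44")    # Elapsed from sacct
alloc_cpus = 256                    # AllocCPUS from sacct

efficiency = total_cpu / (alloc_cpus * elapsed)
print(f"{efficiency:.2f}")  # → 0.37
```

An efficiency well below 1.0, as here, suggests the job requested more CPUs than the workload could keep busy.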
@@ -0,0 +1,61 @@ (new file: example SLURM log)

```
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:43: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  daily_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  monthly_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  monthly_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lat" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  monthly_data = xr.open_mfdataset(
/home/b/b383704/eso4clima/train_twoyears/example_training.py:57: UserWarning: The specified chunks separate the stored chunks along dimension "lon" starting at index 160. This could degrade performance. Instead, consider rechunking after loading.
  monthly_data = xr.open_mfdataset(
2026-03-26 17:02:31,368 - INFO - mean: [289.19693 289.3843 ], std: [10.575894 10.624186]
2026-03-26 17:02:58,933 - INFO - Epoch 0: best_loss = inf
2026-03-26 17:04:51,712 - INFO - No improvement for 10 epochs, stopping early at epoch 9.
2026-03-26 17:04:51,712 - INFO - training done!
2026-03-26 17:04:51,712 - INFO - Final loss: 4.909843444824219
2026-03-26 17:04:51,777 - INFO - Checkpoint saved to models/spatio_temporal_model.pth
==== Slurm accounting summary 23743544 ====
JobID|NTasks|AveCPU|AveRSS|MaxRSS|MaxVMSize|TRESUsageInAve|TRESUsageInMax
23743544.extern|1|00:00:00|856K|3752K|641376K|cpu=00:00:00,energy=0,fs/disk=2332,mem=856K,pages=2,vmem=217160K|cpu=00:00:00,energy=0,fs/disk=2332,mem=3752K,pages=2,vmem=641376K
23743544.batch|1|04:21:01|11964K|4102096K|37743716K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=11964K,pages=19,vmem=356724K|cpu=04:21:01,energy=0,fs/disk=22293117907,mem=4102096K,pages=7711,vmem=37743716K

********************************************************************************
*                                                                              *
*             This is the automated job summary provided by DKRZ.              *
*  If you encounter problems, need assistance or have any suggestion, please   *
*                            write an email to                                 *
*                                                                              *
*                           -- support@dkrz.de --                              *
*                                                                              *
*         We hope you enjoyed the DKRZ supercomputer LEVANTE ...               *
*
* JobID        : 23743544
* JobName      : eso4clima
* Account      : bd0854
* User         : b383704 (202985), bd0854 (1473)
* Partition    : compute
* QOS          : normal
* Nodelist     : l40338 (1)
* Submit date  : 2026-03-26T17:02:02
* Start time   : 2026-03-26T17:02:10
* End time     : 2026-03-26T17:04:54
* Elapsed time : 00:02:44 (Timelimit=00:30:00)
* Command      : /home/b/b383704/eso4clima/train_twoyears/example.slurm
* WorkDir      : /home/b/b383704/eso4clima/train_twoyears
*
* StepID | JobName    NodeHours    MaxRSS [Byte] (@task)
* ------------------------------------------------------------------------------
*  batch | batch          0.046
* extern | extern         0.046            3752K (0)
* ------------------------------------------------------------------------------

JobID|JobName|Partition|AllocCPUS|Elapsed|TotalCPU|MaxRSS|State|ExitCode
23743544|eso4clima|compute|256|00:02:44|04:21:01||COMPLETED|0:0
23743544.batch|batch||256|00:02:44|04:21:01|4102096K|COMPLETED|0:0
23743544.extern|extern||256|00:02:44|00:00.001|3752K|COMPLETED|0:0
```
@@ -0,0 +1,17 @@ (new file: example SLURM script)

```sh
#SBATCH --ntasks-per-node=128
#SBATCH --time=02:00:00
#SBATCH --account=bd0854
#SBATCH --output=eso4clima_%j.out

# Activate the virtual environment
# Change the path to your virtual environment
source /home/b/b383704/eso4clima/ClimaNet/.venv/bin/activate

# Run the training script
# Change the path to your training script
python /home/b/b383704/eso4clima/train_twoyears/example_training.py

echo "==== Slurm accounting summary ${SLURM_JOB_ID} ===="
sstat --allsteps -j "$SLURM_JOB_ID" \
  --format=JobID,NTasks,AveCPU,AveRSS,MaxRSS,MaxVMSize,TresUsageInAve,TresUsageInMax \
  --parsable2
```
@@ -0,0 +1,193 @@ (new file: example training script)

```python
#!/usr/bin/env python3
"""Example training script"""

import logging
from pathlib import Path

import torch
import torch.nn.functional
import xarray as xr
from torch.utils.data import DataLoader

from climanet import STDataset
from climanet.st_encoder_decoder import SpatioTemporalModel

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def main():
    # Data files
    data_folder = Path(
        "/work/bd0854/b380103/eso4clima/output/v1.0/concatenated/"
    )  # HPC
    # data_folder = Path("../../data/output/")  # local
    daily_files = sorted(data_folder.rglob("20*_day_ERA5_masked_ts.nc"))
    monthly_files = sorted(data_folder.rglob("20*_mon_ERA5_full_ts.nc"))

    # Land surface
    lsm_file = "/home/b/b383704/eso4clima/train_twoyears/era5_lsm_bool.nc"  # HPC
    # lsm_file = data_folder / "era5_lsm_bool.nc"  # local

    # Load the full dataset, chunked to the training patch size
    patch_size_training = 80
    daily_data = xr.open_mfdataset(
        daily_files,
        combine="by_coords",
        chunks={
            "time": 1,
            "lat": patch_size_training * 2,
            "lon": patch_size_training * 2,
        },
        data_vars="minimal",
        coords="minimal",
        compat="override",
        parallel=False,
    )

    monthly_data = xr.open_mfdataset(
        monthly_files,
        combine="by_coords",
        chunks={
            "time": 1,
            "lat": patch_size_training * 2,
            "lon": patch_size_training * 2,
        },
        data_vars="minimal",
        coords="minimal",
        compat="override",
        parallel=False,
    )

    lsm_mask = xr.open_dataset(lsm_file)

    # Compute monthly climatology stats without persisting the full (time, lat, lon) monthly field
    monthly_ts = daily_data["ts"].resample(time="MS").mean(skipna=True)
    mean = monthly_ts.mean(dim=["lat", "lon"], skipna=True).compute().values
    std = monthly_ts.std(dim=["lat", "lon"], skipna=True).compute().values
    logger.info(f"mean: {mean}, std: {std}")

    # Make a dataset
    dataset = STDataset(
        daily_da=daily_data["ts"],
        monthly_da=monthly_data["ts"],
        land_mask=lsm_mask["lsm"],
        patch_size=(patch_size_training, patch_size_training),
    )

    # Initialize training
    device = "cuda" if torch.cuda.is_available() else "cpu"
    patch_size = (1, 4, 4)
    overlap = 1
    model = SpatioTemporalModel(
        patch_size=patch_size, overlap=overlap, num_months=2
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    decoder = model.decoder
    with torch.no_grad():
        decoder.bias.copy_(torch.from_numpy(mean))
        decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        pin_memory=False,
    )

    best_loss = float("inf")
    patience = 10  # stop if no improvement for <patience> epochs
    counter = 0

    # Effective batch size = batch_size (fits in memory) * accumulation_steps
    accumulation_steps = 2

    # Training loop with DataLoader
    model.train()
    for epoch in range(501):
        epoch_loss = 0.0

        optimizer.zero_grad()

        for i, batch in enumerate(dataloader):
            # Get batch data
            daily_batch = batch["daily_patch"]
            daily_mask = batch["daily_mask_patch"]
            monthly_target = batch["monthly_patch"]
            land_mask = batch["land_mask_patch"]
            padded_days_mask = batch["padded_days_mask"]

            # Batch prediction
            pred = model(daily_batch, daily_mask, land_mask, padded_days_mask)  # (B, M, H, W)

            # Mask out land pixels
            ocean = (~land_mask).to(pred.device).unsqueeze(1).float()  # (B, 1, H, W)
            loss = torch.nn.functional.l1_loss(pred, monthly_target, reduction="none")
            loss = loss * ocean

            num = loss.sum(dim=(-2, -1))  # (B, M)
            denom = ocean.sum(dim=(-2, -1)).clamp_min(1)  # (B, 1)

            loss_per_month = num / denom
            loss = loss_per_month.mean()

            # Scale loss for gradient accumulation
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()

            # Track unscaled loss for logging
            epoch_loss += loss.item()

            # Update weights every accumulation_steps batches
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Handle remaining gradients if num_batches is not divisible by accumulation_steps
        if (i + 1) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / (i + 1)

        # Early stopping check
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            counter = 0
        else:
            counter += 1

        if epoch % 20 == 0:
            logger.info(f"Epoch {epoch}: best_loss = {best_loss:.6f}")

        if counter >= patience:
            logger.info(
                f"No improvement for {patience} epochs, stopping early at epoch {epoch}."
            )
            break

    logger.info("training done!")
    logger.info(f"Final loss: {loss.item()}")

    # Save the trained model with config
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss.item(),
    }
    model_save_path = Path("./models/spatio_temporal_model.pth")
    model_save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, model_save_path)
    logger.info(f"Checkpoint saved to {model_save_path}")


if __name__ == "__main__":
    main()
```
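The gradient-accumulation pattern used in the training loop above can be sketched framework-free. This is an illustrative stand-in, not the ClimaNet model: scalar "gradients" replace the real backward pass, but the update schedule is the same.

```python
# Framework-free sketch of gradient accumulation: gradients are summed
# (pre-scaled by 1/accumulation_steps) over accumulation_steps mini-batches,
# then a single parameter update is applied, emulating a larger batch size.
def run_epoch(batch_grads, accumulation_steps=2, lr=0.1):
    weight = 0.0  # stand-in for model parameters
    grad = 0.0    # accumulated gradient buffer
    steps = 0     # number of optimizer steps taken
    for i, g in enumerate(batch_grads):
        # backward(): accumulate the scaled gradient, as in the script
        grad += g / accumulation_steps
        # optimizer.step() every accumulation_steps batches
        if (i + 1) % accumulation_steps == 0:
            weight -= lr * grad
            grad = 0.0
            steps += 1
    # flush leftover gradients when the batch count is not divisible
    # by accumulation_steps (mirrors the script's trailing step)
    if len(batch_grads) % accumulation_steps != 0:
        weight -= lr * grad
        steps += 1
    return weight, steps
```

With three batches and `accumulation_steps=2`, the optimizer steps twice: once after the first two batches and once in the trailing flush.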
This function is currently doing a lot: creating the model, training it, making predictions, and saving results. We should split these responsibilities.
If this is a "training script", it should only handle reading the data, creating the model with the correct arguments, and passing both to a separate training function (that will be added in #33).
Then, in another script (e.g. "inference script"), we can load the saved model and make predictions (see #32). This separation is needed because training and inference require different computing resources.
Any plotting or result inspection can be done in a separate script if needed.
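The proposed separation can be sketched minimally. All names here (`build_model`, `train`, `save_checkpoint`, `run_inference`) are hypothetical stand-ins for illustration, not the actual ClimaNet API; the point is only the division of responsibilities between the two scripts.

```python
# Illustrative sketch of the proposed training/inference split.
# All function names are hypothetical stand-ins, not the ClimaNet API.

def build_model(config):
    # training script: create the model with the correct arguments
    return {"config": config, "weights": 0.0}

def train(model, data):
    # separate training function (the real one would land in its own PR)
    model["weights"] = sum(data) / len(data)
    return model

def save_checkpoint(model, store):
    # persist the trained model so another script can load it
    store["checkpoint"] = dict(model)

def run_inference(store, inputs):
    # inference script: load the saved model and make predictions
    model = store["checkpoint"]
    return [model["weights"] * x for x in inputs]

# --- training script responsibilities ---
store = {}
model = train(build_model({"patch_size": 80}), [1.0, 2.0, 3.0])
save_checkpoint(model, store)

# --- inference script responsibilities (separate process, resources) ---
preds = run_inference(store, [10.0])
```

Keeping the two entry points separate lets the training script request GPU resources while the inference script runs with a lighter allocation.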
I have split this into a training script and an inference script. The plotting part has been removed.