Conversation
Pull request overview
Adds “full validation” support to PyTorch training, including config argcheck validation and checkpoint rotation on best metric.
Changes:
- Introduces `FullValidator` to run periodic full-dataset validation, log `val.log`, and optionally save/rotate best checkpoints.
- Extends `deepmd.utils.argcheck.normalize()` to validate full-validation configs and supported metrics/prefactors.
- Adds unit/integration tests covering metric parsing/start-step resolution, argcheck failures, and best-checkpoint rotation.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/test_validation.py | Adds unit tests for helper functions and argcheck validation for full validation. |
| source/tests/pt/test_training.py | Adds trainer-level tests for full validation behavior and rejection paths (spin/multi-task). |
| deepmd/utils/argcheck.py | Adds validating config schema and cross-field validation for full validation. |
| deepmd/pt/train/validation.py | Implements FullValidator and full-validation metric/logging utilities. |
| deepmd/pt/train/training.py | Wires FullValidator into the training loop and enforces runtime constraints. |
```python
save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
if self.rank == 0:
```

In distributed mode, every rank will call `save_checkpoint(...)` for the same `best.ckpt-*.pt` path, which can cause concurrent writes and corrupted checkpoints (even with barriers). Make the save operation rank-0 only (and keep the broadcast only as a signal), or ensure `save_checkpoint` is explicitly rank-gated here.

Suggested change:

```diff
-save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
-if self.rank == 0:
+if self.rank == 0:
+    save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
```
```python
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```

`torch.cuda.empty_cache()` on every full validation run can significantly slow training and increase allocator churn/fragmentation. Consider removing it, or making it conditional (e.g., only after catching an OOM during validation, or behind an explicit config flag).

Suggested change:

```diff
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
```
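The OOM-conditional variant suggested in this comment can be sketched framework-agnostically. The helper name and injected parameters are hypothetical; with PyTorch one would pass `torch.cuda.empty_cache` as `clear_cache` and `torch.cuda.OutOfMemoryError` as `oom_errors`:

```python
def run_with_oom_fallback(validate, clear_cache, oom_errors):
    """Run the validation callable; only if it hits an out-of-memory
    error, clear the allocator cache once and retry, instead of paying
    the empty_cache() cost on every validation pass."""
    try:
        return validate()
    except oom_errors:
        clear_cache()  # reclaim cached blocks only after a real OOM
        return validate()
```

This keeps the common path allocator-friendly while still recovering from the occasional validation-time OOM.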
```python
def parse_validation_metric(metric: str) -> tuple[str, str]:
    """Parse the configured full validation metric."""
    normalized_metric = normalize_full_validation_metric(metric)
```

`METRIC_KEY_MAP[normalized_metric]` will raise a KeyError for unsupported metrics, which is harder to interpret than a ValueError and bypasses the more user-friendly messaging you already implemented in argcheck. Consider validating membership here and raising ValueError with the supported values so FullValidator remains robust when used outside the normalize() path (e.g., unit tests or internal callers).

Suggested change:

```diff
 normalized_metric = normalize_full_validation_metric(metric)
+if normalized_metric not in METRIC_KEY_MAP:
+    supported = ", ".join(sorted(METRIC_KEY_MAP))
+    raise ValueError(
+        f"Unsupported full validation metric {normalized_metric!r}. "
+        f"Supported metrics are: {supported}"
+    )
```
```python
assert isinstance(dataset, DeepmdDataSetForLoader)
system_metrics.append(self._evaluate_system(dataset._data_system))
```

This relies on a private attribute (`dataset._data_system`) and uses assert for runtime type enforcement. Both make this brittle (private API changes, and asserts can be stripped with -O). Prefer a public accessor on DeepmdDataSetForLoader (or add one) and raise a typed exception if an unexpected dataset is encountered.

Suggested change:

```diff
-assert isinstance(dataset, DeepmdDataSetForLoader)
-system_metrics.append(self._evaluate_system(dataset._data_system))
+if not isinstance(dataset, DeepmdDataSetForLoader):
+    raise TypeError(
+        f"Expected validation dataset of type DeepmdDataSetForLoader, "
+        f"got {type(dataset)!r} instead."
+    )
+try:
+    data_system = dataset.data_system
+except AttributeError as exc:
+    raise AttributeError(
+        "Validation dataset does not expose a public 'data_system' "
+        "attribute. Please provide a public accessor on "
+        "DeepmdDataSetForLoader instead of relying on private "
+        "attributes."
+    ) from exc
+system_metrics.append(self._evaluate_system(data_system))
```
```python
def tearDown(self) -> None:
    for f in os.listdir("."):
        if (f.startswith("model") or f.startswith("best")) and f.endswith(".pt"):
            os.remove(f)
        if f in ["lcurve.out", "val.log", "checkpoint"]:
            os.remove(f)
        if f.startswith("stat_files"):
            shutil.rmtree(f)
```
These tests manipulate and delete files in the current working directory, which can be risky/flaky when the test runner’s CWD isn’t isolated (and checkpoint may be a directory in some setups, making os.remove fail). Prefer running the test in a TemporaryDirectory() (chdir like in test_validation.py) and use Path checks (is_file()/is_dir()) to delete safely.
📝 Walkthrough

This pull request introduces a full validation system during training in DeepMD PT. It adds configuration parameters for validation control and implements a new `FullValidator` class.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant FullValidator
    participant Model as Model (eval)
    participant ValidData as Validation Data
    participant Checkpoint as Checkpoint Storage
    participant ValLog as val.log
    Trainer->>FullValidator: __init__(validation_data, model, ...)
    FullValidator->>FullValidator: Initialize best_step=-1, best_metric=inf
    loop Training steps
        Trainer->>Trainer: Training iterations
        alt Display moment (logging step)
            Trainer->>FullValidator: run(step_id, display_step, lr, save_checkpoint)
            FullValidator->>FullValidator: should_run(display_step)?
            alt Should run validation
                FullValidator->>Model: switch to eval mode
                FullValidator->>ValidData: iterate systems
                loop Per system
                    FullValidator->>Model: predict(atoms, coords, ...)
                    Model-->>FullValidator: energy, force, virial
                    FullValidator->>FullValidator: compute MAE/RMSE metrics
                end
                FullValidator->>FullValidator: aggregate metrics (weighted avg)
                FullValidator->>FullValidator: check if new best metric
                alt New best metric found
                    FullValidator->>Checkpoint: save best.ckpt-<step>.pt
                    FullValidator->>Checkpoint: prune older best checkpoints
                    FullValidator->>FullValidator: update best_step, best_metric
                end
                FullValidator->>ValLog: write/append entry with metrics
                FullValidator-->>Trainer: return FullValidationResult
            else Skip validation
                FullValidator-->>Trainer: return None
            end
        end
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
Actionable comments posted: 6
🧹 Nitpick comments (1)
deepmd/pt/train/validation.py (1)
464-470: Disable autograd for validation forwards.
`eval()` changes module behavior, but gradients are still tracked here. On a full-dataset validation pass that is unnecessary memory and latency overhead.

Proposed refactor:

```diff
-batch_output = self.model(
-    coord_input,
-    type_input,
-    box=box_input,
-    fparam=fparam_input,
-    aparam=aparam_input,
-)
+with torch.inference_mode():
+    batch_output = self.model(
+        coord_input,
+        type_input,
+        box=box_input,
+        fparam=fparam_input,
+        aparam=aparam_input,
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 464 - 470, The validation forward is running with gradients enabled; wrap the model inference call that produces batch_output = self.model(...) in a no-grad context (preferably with torch.inference_mode() or with torch.no_grad()) inside the validation routine (the method where self.model is called for validation) so autograd is disabled during validation forwards and reduces memory/latency overhead.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/train/training.py`:
- Around line 1421-1427: FullValidator.run currently receives self.save_model
directly, causing every stage-0 worker to build model_state and deepcopy
optimizer.state_dict() when a new best checkpoint is broadcast; change the call
so save_checkpoint is True only on the rank that should serialize (e.g. pass
save_checkpoint=(self.save_model and (self.global_rank == 0)) or use
torch.distributed.get_rank()==0 when self.stage == 0), i.e. gate the
save_checkpoint argument before calling self.full_validator.run to skip
serialization on nonzero ranks and avoid unnecessary deep copies.
- Around line 900-914: FullValidator.run() can deadlock other ranks if rank 0
raises during _evaluate()/save_checkpoint() because those ranks wait on
broadcast_object_list(); wrap the rank-0 evaluation/checkpoint block in a
try/except that captures any Exception, set a serializable error payload (e.g.,
tuple with True and the exception string), and immediately broadcast that
payload with broadcast_object_list() so all ranks are unblocked; on non-zero
ranks receive the payload, detect the error flag, and raise a matching exception
(or handle/clean up) so the failure is propagated instead of leaving ranks
blocked—modify deepmd/pt/train/validation.py FullValidator.run() to implement
this pattern around _evaluate() and save_checkpoint().
In `@deepmd/pt/train/validation.py`:
- Around line 255-259: The validator is disabled when start_step equals
num_steps due to a strict '<' check; update the initialization of self.enabled
(which uses self.full_validation, self.start_step, and num_steps) to allow
equality (use '<=' semantics) so full validation can run on the final training
step, and ensure the should_run() logic remains consistent with this change.
- Around line 307-328: The current code only calls self._evaluate() on rank 0
which deadlocks when self.zero_stage >= 2 because forward passes require all
ranks; change the control flow so that when self.zero_stage >= 2 you call
self._evaluate() on every rank (remove the rank==0-only guard for that case) and
still use save_path = [None] + dist.broadcast_object_list(save_path, src=0) to
propagate the chosen checkpoint; keep the existing rank-0-only actions (calling
self._prune_best_checkpoints and self._log_result) but ensure
save_checkpoint(Path(save_path[0]), ...) and the broadcast happen after every
rank has produced or received save_path; update the branches around
self._evaluate, save_path, dist.broadcast_object_list, save_checkpoint,
_prune_best_checkpoints and _log_result accordingly so distributed stage-2/3
training doesn't hang.
In `@deepmd/utils/argcheck.py`:
- Around line 4180-4194: The code currently returns early on multi_task or
non-'ener' losses which lets validating.full_validation silently pass; instead,
check validating.get("full_validation") first and if true reject unsupported
modes: if multi_task is True or loss_params.get("type","ener") != "ener" raise a
ValueError explaining that full_validation is unsupported with multi-task or
non-'ener' losses. Also only run the validation_metric check (using
validating["validation_metric"], is_valid_full_validation_metric and
FULL_VALIDATION_METRIC_PREFS) when full_validation is enabled so invalid metrics
are rejected rather than silently ignored.
In `@source/tests/pt/test_validation.py`:
- Around line 135-139: The test test_normalize_rejects_invalid_metric currently
catches the broad Exception; replace this with the concrete validation error
type that normalize() raises (e.g., ValidationError or the project-specific
ValidationError class) and update imports accordingly so the assertion uses
assertRaisesRegex(ValidationError, "validation_metric") against
normalize(config); keep the same regex and test flow but narrow the exception to
the specific validation error class.
---
Nitpick comments:
In `@deepmd/pt/train/validation.py`:
- Around line 464-470: The validation forward is running with gradients enabled;
wrap the model inference call that produces batch_output = self.model(...) in a
no-grad context (preferably with torch.inference_mode() or with torch.no_grad())
inside the validation routine (the method where self.model is called for
validation) so autograd is disabled during validation forwards and reduces
memory/latency overhead.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 98de0e8f-e132-4770-ae44-25a71d16e73e
📒 Files selected for processing (5)
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
```python
if self.zero_stage >= 2:
    raise ValueError(
        "validating.full_validation only supports single-task energy "
        "training with training.zero_stage < 2."
    )
self.full_validator = FullValidator(
    validating_params=validating_params,
    validation_data=validation_data,
    model=self.model,
    train_infos=self._get_inner_module().train_infos,
    num_steps=self.num_steps,
    rank=self.rank,
    zero_stage=self.zero_stage,
    restart_training=self.restart_training,
)
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
rg -n -C4 'def run\(|_evaluate\(|broadcast_object_list|dist\.barrier|save_checkpoint' deepmd/pt/train/validation.py
```

Repository: deepmodeling/deepmd-kit
Length of output: 1737
Distributed full validation has a critical deadlock risk that must be addressed before enabling for zero_stage < 2.
When distributed training is active (any rank > 0), FullValidator.run() synchronizes all ranks at line 308, then rank 0 alone executes _evaluate() at line 313. If rank 0 raises during evaluation, it never reaches the broadcast_object_list() call at line 317, leaving all other ranks indefinitely blocked waiting for that broadcast. This wedges the entire training job.
Harden deepmd/pt/train/validation.py::FullValidator.run() to propagate rank-0 exceptions (e.g., by broadcasting an error flag before save_checkpoint()), or restrict this feature to single-rank training until the distributed path is safe.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/train/training.py` around lines 900 - 914, FullValidator.run() can
deadlock other ranks if rank 0 raises during _evaluate()/save_checkpoint()
because those ranks wait on broadcast_object_list(); wrap the rank-0
evaluation/checkpoint block in a try/except that captures any Exception, set a
serializable error payload (e.g., tuple with True and the exception string), and
immediately broadcast that payload with broadcast_object_list() so all ranks are
unblocked; on non-zero ranks receive the payload, detect the error flag, and
raise a matching exception (or handle/clean up) so the failure is propagated
instead of leaving ranks blocked—modify deepmd/pt/train/validation.py
FullValidator.run() to implement this pattern around _evaluate() and
save_checkpoint().
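The error-propagation pattern this comment asks for can be sketched in a framework-agnostic way. `evaluate_with_error_broadcast` is a hypothetical helper (not the PR's actual code), and the injected `broadcast` callable stands in for `dist.broadcast_object_list(msg, src=0)`:

```python
def evaluate_with_error_broadcast(rank, evaluate, broadcast):
    """Rank 0 runs `evaluate` inside try/except and ALWAYS broadcasts an
    (ok, payload) pair, so peers blocked in the collective are released
    even when evaluation or checkpointing fails; non-zero ranks then
    re-raise the failure instead of hanging forever."""
    msg = [None]
    if rank == 0:
        try:
            msg[0] = (True, evaluate())
        except Exception as exc:  # serialize the failure for all peers
            msg[0] = (False, f"{type(exc).__name__}: {exc}")
    broadcast(msg)  # stand-in for dist.broadcast_object_list(msg, src=0)
    ok, payload = msg[0]
    if not ok:
        raise RuntimeError(f"rank 0 full validation failed: {payload}")
    return payload
```

The key invariant is that the broadcast happens on every code path out of the rank-0 block, so the collective can never be entered by some ranks but skipped by rank 0.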
```python
if self.full_validator is not None:
    self.full_validator.run(
        step_id=_step_id,
        display_step=display_step_id,
        lr=cur_lr,
        save_checkpoint=self.save_model,
    )
```
Skip best-checkpoint serialization on nonzero ranks for stage 0.
FullValidator.run() calls save_checkpoint() on every rank once a new best path is broadcast. Passing self.save_model directly means stage-0 workers still build model_state and deepcopy(self.optimizer.state_dict()) before the rank guard returns at Line 1602. On large jobs, every best-checkpoint update now clones the full training state on every worker for no benefit.
♻️ One way to gate the callback:

```diff
+def _save_full_validation_checkpoint(
+    self, save_path: Path, lr: float, step: int
+) -> None:
+    if self.zero_stage == 0 and self.rank != 0:
+        return
+    self.save_model(save_path, lr=lr, step=step)
 ...
-    save_checkpoint=self.save_model,
+    save_checkpoint=self._save_full_validation_checkpoint,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/train/training.py` around lines 1421 - 1427, FullValidator.run
currently receives self.save_model directly, causing every stage-0 worker to
build model_state and deepcopy optimizer.state_dict() when a new best checkpoint
is broadcast; change the call so save_checkpoint is True only on the rank that
should serialize (e.g. pass save_checkpoint=(self.save_model and
(self.global_rank == 0)) or use torch.distributed.get_rank()==0 when self.stage
== 0), i.e. gate the save_checkpoint argument before calling
self.full_validator.run to skip serialization on nonzero ranks and avoid
unnecessary deep copies.
```python
self.enabled = (
    self.full_validation
    and self.start_step is not None
    and self.start_step < num_steps
)
```
Allow full validation to start on the final training step.
If full_val_start resolves to exactly num_steps, should_run() could still fire once on the last display step, but this strict < disables the validator up front. That makes the documented “final-step only” configuration a no-op.
Proposed fix:

```diff
 self.enabled = (
     self.full_validation
     and self.start_step is not None
-    and self.start_step < num_steps
+    and self.start_step <= num_steps
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/train/validation.py` around lines 255 - 259, The validator is
disabled when start_step equals num_steps due to a strict '<' check; update the
initialization of self.enabled (which uses self.full_validation,
self.start_step, and num_steps) to allow equality (use '<=' semantics) so full
validation can run on the final training step, and ensure the should_run() logic
remains consistent with this change.
```python
if self.is_distributed:
    dist.barrier()

result: FullValidationResult | None = None
save_path = [None]
if self.rank == 0:
    result = self._evaluate(display_step)
    save_path[0] = result.saved_best_path

if self.is_distributed:
    dist.broadcast_object_list(save_path, src=0)

if save_path[0] is not None:
    save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
    if self.rank == 0:
        self._prune_best_checkpoints(keep_names={Path(save_path[0]).name})

if self.rank == 0:
    self._log_result(result)

if self.is_distributed:
    dist.barrier()
```
Guard zero_stage >= 2 before rank-0-only evaluation.
This path sends only rank 0 into _evaluate() while the other ranks block on broadcast_object_list. With sharded stage-2/3 execution, forward requires collective participation from every rank, so the first full-validation pass can hang here.
Proposed fix:

```diff
 if not self.should_run(display_step):
     return None
+if self.is_distributed and self.zero_stage >= 2:
+    raise ValueError(
+        "validating.full_validation does not support training.zero_stage >= 2."
+    )
+
 if self.is_distributed:
     dist.barrier()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/train/validation.py` around lines 307 - 328, The current code only
calls self._evaluate() on rank 0 which deadlocks when self.zero_stage >= 2
because forward passes require all ranks; change the control flow so that when
self.zero_stage >= 2 you call self._evaluate() on every rank (remove the
rank==0-only guard for that case) and still use save_path = [None] +
dist.broadcast_object_list(save_path, src=0) to propagate the chosen checkpoint;
keep the existing rank-0-only actions (calling self._prune_best_checkpoints and
self._log_result) but ensure save_checkpoint(Path(save_path[0]), ...) and the
broadcast happen after every rank has produced or received save_path; update the
branches around self._evaluate, save_path, dist.broadcast_object_list,
save_checkpoint, _prune_best_checkpoints and _log_result accordingly so
distributed stage-2/3 training doesn't hang.
```python
if multi_task:
    # Unsupported multi-task mode is rejected during trainer initialization.
    return

metric = validating["validation_metric"]
if not is_valid_full_validation_metric(metric):
    valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS)
    raise ValueError(
        "validating.validation_metric must be one of "
        f"{valid_metrics}, got {metric!r}."
    )

loss_params = data.get("loss", {})
if loss_params.get("type", "ener") != "ener":
    return
```
Reject unsupported full-validation modes instead of returning early.
These branches currently let validating.full_validation: true normalize successfully for multi-task configs and non-ener losses like ener_spin, even though this file documents those combinations as unsupported. That defers the failure to runtime or leaves the setting silently ineffective.
Proposed fix:

```diff
-if multi_task:
-    # Unsupported multi-task mode is rejected during trainer initialization.
-    return
+if multi_task:
+    raise ValueError(
+        "validating.full_validation is only supported in single-task mode."
+    )
 metric = validating["validation_metric"]
 if not is_valid_full_validation_metric(metric):
     valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS)
     raise ValueError(
         "validating.validation_metric must be one of "
         f"{valid_metrics}, got {metric!r}."
     )
 loss_params = data.get("loss", {})
-if loss_params.get("type", "ener") != "ener":
-    return
+if loss_params.get("type", "ener") != "ener":
+    raise ValueError(
+        "validating.full_validation is only supported when loss.type == 'ener'."
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/utils/argcheck.py` around lines 4180 - 4194, The code currently
returns early on multi_task or non-'ener' losses which lets
validating.full_validation silently pass; instead, check
validating.get("full_validation") first and if true reject unsupported modes: if
multi_task is True or loss_params.get("type","ener") != "ener" raise a
ValueError explaining that full_validation is unsupported with multi-task or
non-'ener' losses. Also only run the validation_metric check (using
validating["validation_metric"], is_valid_full_validation_metric and
FULL_VALIDATION_METRIC_PREFS) when full_validation is enabled so invalid metrics
are rejected rather than silently ignored.
```python
def test_normalize_rejects_invalid_metric(self) -> None:
    config = _make_single_task_config()
    config["validating"]["validation_metric"] = "X:MAE"
    with self.assertRaisesRegex(Exception, "validation_metric"):
        normalize(config)
```
Use the concrete exception type in this assertion.
Catching Exception here will also pass on unrelated regressions inside normalize(), so this test can stay green for the wrong reason. Please narrow it to the validation error invalid metrics are actually expected to raise.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@source/tests/pt/test_validation.py` around lines 135 - 139, The test
test_normalize_rejects_invalid_metric currently catches the broad Exception;
replace this with the concrete validation error type that normalize() raises
(e.g., ValidationError or the project-specific ValidationError class) and update
imports accordingly so the assertion uses assertRaisesRegex(ValidationError,
"validation_metric") against normalize(config); keep the same regex and test
flow but narrow the exception to the specific validation error class.
Codecov Report

❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5336      +/-   ##
==========================================
- Coverage   82.42%   82.39%   -0.03%
==========================================
  Files         784      785       +1
  Lines       79125    79461     +336
  Branches     3676     3676
==========================================
+ Hits        65220    65474     +254
- Misses      12732    12814      +82
  Partials     1173     1173
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit