From d3bb861b79cfecef65436f394b28be2f3c3f3193 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 24 Mar 2026 10:23:43 +0800 Subject: [PATCH 1/3] feat: add full validation --- deepmd/pt/train/training.py | 63 +++ deepmd/pt/train/validation.py | 610 +++++++++++++++++++++++++++++ deepmd/utils/argcheck.py | 176 +++++++++ source/tests/pt/test_training.py | 102 +++++ source/tests/pt/test_validation.py | 139 +++++++ 5 files changed, 1090 insertions(+) create mode 100644 deepmd/pt/train/validation.py create mode 100644 source/tests/pt/test_validation.py diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 8d16e1c7ea..ee4dfa87e8 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -54,6 +54,10 @@ KFOptimizerWrapper, LKFOptimizer, ) +from deepmd.pt.train.validation import ( + FullValidator, + resolve_full_validation_start_step, +) from deepmd.pt.train.wrapper import ( ModelWrapper, ) @@ -857,6 +861,57 @@ def single_model_finetune( self.enable_profiler = training_params.get("enable_profiler", False) self.profiling = training_params.get("profiling", False) self.profiling_file = training_params.get("profiling_file", "timeline.json") + self.full_validator = None + + validating_params = config.get("validating") or {} + validation_start_step = resolve_full_validation_start_step( + validating_params.get("full_val_start", 0.0), + self.num_steps, + ) + full_validation_requested = ( + bool(validating_params.get("full_validation", False)) + and validation_start_step is not None + and validation_start_step < self.num_steps + ) + if full_validation_requested: + if self.multi_task: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; multi-task training is not supported." 
+ ) + has_spin = getattr(self.model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if has_spin or isinstance(self.loss, EnergySpinLoss): + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; spin-energy training is not supported." + ) + if not isinstance(self.loss, EnergyStdLoss): + raise ValueError( + "validating.full_validation only supports single-task energy " + "training." + ) + if validation_data is None: + raise ValueError( + "validating.full_validation requires `training.validation_data` " + "to be configured." + ) + if self.zero_stage >= 2: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training with training.zero_stage < 2." + ) + self.full_validator = FullValidator( + validating_params=validating_params, + validation_data=validation_data, + model=self.model, + train_infos=self._get_inner_module().train_infos, + num_steps=self.num_steps, + rank=self.rank, + zero_stage=self.zero_stage, + restart_training=self.restart_training, + ) # Log model parameter count if self.rank == 0: @@ -1363,6 +1418,14 @@ def log_loss_valid(_task_key: str = "Default") -> dict: fout, display_step_id, cur_lr, train_results, valid_results ) + if self.full_validator is not None: + self.full_validator.run( + step_id=_step_id, + display_step=display_step_id, + lr=cur_lr, + save_checkpoint=self.save_model, + ) + if ( ( (display_step_id) % self.save_freq == 0 diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py new file mode 100644 index 0000000000..16e6ca41ab --- /dev/null +++ b/deepmd/pt/train/validation.py @@ -0,0 +1,610 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later + +from __future__ import ( + annotations, +) + +import logging +from dataclasses import ( + dataclass, +) +from pathlib import ( + Path, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np +import torch +import torch.distributed as dist + +from 
deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT +from deepmd.pt.utils.auto_batch_size import ( + AutoBatchSize, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) +from deepmd.utils.argcheck import ( + normalize_full_validation_metric, +) +from deepmd.utils.weight_avg import ( + weighted_average, +) + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from deepmd.utils.data import ( + DeepmdData, + ) + +METRIC_KEY_MAP = { + "e:mae": "mae_e_per_atom", + "e:rmse": "rmse_e_per_atom", + "f:mae": "mae_f_vector", + "f:rmse": "rmse_f_vector", + "v:mae": "mae_v_per_atom", + "v:rmse": "rmse_v_per_atom", +} +LOG_COLUMN_ORDER = [ + ("E_MAE", "mae_e_per_atom"), + ("E_RMSE", "rmse_e_per_atom"), + ("F_MAE", "mae_f_vector"), + ("F_RMSE", "rmse_f_vector"), + ("V_MAE", "mae_v_per_atom"), + ("V_RMSE", "rmse_v_per_atom"), +] + +BEST_METRIC_INFO_KEY = "full_validation_best_metric" +BEST_STEP_INFO_KEY = "full_validation_best_step" +BEST_METRIC_NAME_INFO_KEY = "full_validation_metric" +BEST_CKPT_GLOB = "best.ckpt-*.pt" +BEST_PATH_INFO_KEY_COMPAT = "full_validation_best_path" +BATCH_SIZE_LOGGER_NAME = "deepmd.utils.batch_size" +VAL_LOG_SIGNIFICANT_DIGITS = 5 +VAL_LOG_COLUMN_GAP = " " +VAL_LOG_HEADER_PREFIX = "# " +VAL_LOG_DATA_PREFIX = " " +METRIC_LOG_UNIT_MAP = { + "e": ("meV/atom", 1000.0), + "f": ("meV/Å", 1000.0), + "v": ("meV/atom", 1000.0), +} + + +@dataclass(frozen=True) +class FullValidationResult: + """Result of one full validation run.""" + + display_step: int + metrics: dict[str, float] + selected_metric_key: str + selected_metric_value: float + saved_best_path: str | None + + +def resolve_full_validation_start_step( + full_val_start: float, num_steps: int +) -> int | None: + """Resolve the first step at which full validation becomes active.""" + start_value = 
float(full_val_start) + if start_value == 1.0: + return None + if 0.0 <= start_value < 1.0: + return int(num_steps * start_value) + return int(start_value) + + +def parse_validation_metric(metric: str) -> tuple[str, str]: + """Parse the configured full validation metric.""" + normalized_metric = normalize_full_validation_metric(metric) + return normalized_metric, METRIC_KEY_MAP[normalized_metric] + + +def _compute_system_metrics( + prediction: dict[str, np.ndarray], + test_data: dict[str, np.ndarray], + natoms: int, + has_pbc: bool, +) -> dict[str, tuple[float, float]]: + """Compute per-system full validation metrics.""" + metrics: dict[str, tuple[float, float]] = {} + find_energy = bool(test_data.get("find_energy", 0.0)) + find_force = bool(test_data.get("find_force", 0.0)) + find_virial = bool(test_data.get("find_virial", 0.0)) + + if find_energy: + diff_e = prediction["energy"].reshape(-1, 1) - test_data["energy"].reshape( + -1, 1 + ) + mae_e_per_atom = float(np.mean(np.abs(diff_e)) / natoms) + rmse_e_per_atom = float(np.sqrt(np.mean(diff_e * diff_e)) / natoms) + metrics["mae_e_per_atom"] = (mae_e_per_atom, float(diff_e.size)) + metrics["rmse_e_per_atom"] = (rmse_e_per_atom, float(diff_e.size)) + + if find_force: + diff_f = prediction["force"].reshape(-1, 3) - test_data["force"].reshape(-1, 3) + diff_f_norm = np.linalg.vector_norm(diff_f, axis=1) + mae_f_vector = float(np.mean(diff_f_norm)) + rmse_f_vector = float(np.sqrt(np.mean(diff_f_norm * diff_f_norm))) + metrics["mae_f_vector"] = (mae_f_vector, float(diff_f_norm.size)) + metrics["rmse_f_vector"] = (rmse_f_vector, float(diff_f_norm.size)) + + if has_pbc and find_virial: + diff_v = prediction["virial"].reshape(-1, 9) - test_data["virial"].reshape( + -1, 9 + ) + mae_v_per_atom = float(np.mean(np.abs(diff_v)) / natoms) + rmse_v_per_atom = float(np.sqrt(np.mean(diff_v * diff_v)) / natoms) + metrics["mae_v_per_atom"] = (mae_v_per_atom, float(diff_v.size)) + metrics["rmse_v_per_atom"] = (rmse_v_per_atom, 
def format_metric_number_for_log(metric_value: float) -> str:
    """Format one metric value for `val.log` and best-save messages.

    Renders ``metric_value`` with ``VAL_LOG_SIGNIFICANT_DIGITS``
    significant digits. Non-finite values (``nan``, ``+/-inf``) are
    rendered literally rather than raising, so a diverging metric cannot
    crash the logging path.
    """
    if np.isnan(metric_value):
        return "nan"
    if not np.isfinite(metric_value):
        # +/-inf would overflow int(np.floor(np.log10(...))) below.
        return "inf" if metric_value > 0 else "-inf"
    if metric_value == 0.0:
        return "0"
    abs_value = abs(metric_value)
    # Decimal places needed to show the requested significant digits.
    decimals = VAL_LOG_SIGNIFICANT_DIGITS - int(np.floor(np.log10(abs_value))) - 1
    rounded_value = round(metric_value, decimals)
    if rounded_value == 0.0:
        # Normalize -0.0 to +0.0 so no stray minus sign reaches the log.
        rounded_value = 0.0
    if decimals > 0:
        return f"{rounded_value:.{decimals}f}"
    return f"{rounded_value:.0f}"
old_disabled + + def _adjust_batch_size(self, factor: float) -> None: + self.current_batch_size = int(self.current_batch_size * factor) + + +class FullValidator: + """Run independent full validation during training.""" + + def __init__( + self, + *, + validating_params: dict[str, Any], + validation_data: Any, + model: torch.nn.Module, + train_infos: dict[str, Any], + num_steps: int, + rank: int, + zero_stage: int, + restart_training: bool, + ) -> None: + self.validation_data = validation_data + self.model = model + self.train_infos = train_infos + self.rank = rank + self.zero_stage = zero_stage + self.is_distributed = dist.is_available() and dist.is_initialized() + + self.full_validation = bool(validating_params.get("full_validation", False)) + self.validation_freq = int(validating_params.get("validation_freq", 5000)) + self.save_best = bool(validating_params.get("save_best", True)) + self.metric_name, self.metric_key = parse_validation_metric( + str(validating_params.get("validation_metric", "E:MAE")) + ) + self.full_val_file = Path(validating_params.get("full_val_file", "val.log")) + self.start_step = resolve_full_validation_start_step( + validating_params.get("full_val_start", 0.0), + num_steps, + ) + self.enabled = ( + self.full_validation + and self.start_step is not None + and self.start_step < num_steps + ) + self.step_column_width = max(len("step"), len(str(num_steps))) + self._write_mode = "a" if restart_training else "w" + self._should_write_header = not ( + restart_training and self.full_val_file.exists() + ) + self.auto_batch_size = SilentAutoBatchSize() + self.table_column_specs = [] + for column_name, metric_key in LOG_COLUMN_ORDER: + _, metric_unit = format_metric_value_for_table(metric_key, 1.0) + header_label = f"{column_name}({metric_unit})" + self.table_column_specs.append( + (metric_key, header_label, max(len(header_label), 18)) + ) + + if self.train_infos.get(BEST_METRIC_NAME_INFO_KEY) == self.metric_name: + best_metric = 
self.train_infos.get(BEST_METRIC_INFO_KEY) + self.best_metric_value = ( + float(best_metric) if best_metric is not None else None + ) + self.best_step = self.train_infos.get(BEST_STEP_INFO_KEY) + else: + self.best_metric_value = None + self.best_step = None + self._sync_train_infos() + if self.rank == 0: + self._initialize_best_checkpoints(restart_training=restart_training) + + def should_run(self, display_step: int) -> bool: + """Check whether the current step should trigger full validation.""" + if not self.enabled or self.start_step is None: + return False + if display_step < self.start_step: + return False + return (display_step - self.start_step) % self.validation_freq == 0 + + def run( + self, + *, + step_id: int, + display_step: int, + lr: float, + save_checkpoint: Any, + ) -> FullValidationResult | None: + """Run full validation if the current step is due.""" + if not self.should_run(display_step): + return None + + if self.is_distributed: + dist.barrier() + + result: FullValidationResult | None = None + save_path = [None] + if self.rank == 0: + result = self._evaluate(display_step) + save_path[0] = result.saved_best_path + + if self.is_distributed: + dist.broadcast_object_list(save_path, src=0) + + if save_path[0] is not None: + save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) + if self.rank == 0: + self._prune_best_checkpoints(keep_names={Path(save_path[0]).name}) + + if self.rank == 0: + self._log_result(result) + + if self.is_distributed: + dist.barrier() + + return result if self.rank == 0 else None + + def _evaluate(self, display_step: int) -> FullValidationResult: + """Evaluate all validation systems and update best state.""" + # === Step 1. Switch to Evaluation Mode === + was_training = bool(getattr(self.model, "training", True)) + self.model.eval() + try: + # === Step 2. 
Evaluate All Systems === + metrics = self.evaluate_all_systems() + finally: + self.model.train(was_training) + + if self.metric_key not in metrics or np.isnan(metrics[self.metric_key]): + raise RuntimeError( + "The selected full validation metric is unavailable on the " + f"validation dataset: {self.metric_name.upper()}." + ) + + # === Step 3. Update Best Tracking === + selected_metric_value = float(metrics[self.metric_key]) + saved_best_path = self._update_best_state( + display_step=display_step, + selected_metric_value=selected_metric_value, + ) + return FullValidationResult( + display_step=display_step, + metrics=metrics, + selected_metric_key=self.metric_key, + selected_metric_value=selected_metric_value, + saved_best_path=saved_best_path, + ) + + def evaluate_all_systems(self) -> dict[str, float]: + """Evaluate every validation system and aggregate metrics.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + system_metrics = [] + for dataset in self.validation_data.systems: + assert isinstance(dataset, DeepmdDataSetForLoader) + system_metrics.append(self._evaluate_system(dataset._data_system)) + + aggregated = weighted_average([metric for metric in system_metrics if metric]) + return { + metric_key: float(aggregated[metric_key]) + for _, metric_key in LOG_COLUMN_ORDER + if metric_key in aggregated + } + + def _evaluate_system( + self, data_system: DeepmdData + ) -> dict[str, tuple[float, float]]: + """Evaluate one validation system.""" + test_data = data_system.get_test() + natoms = int(test_data["type"].shape[1]) + nframes = int(test_data["coord"].shape[0]) + prediction = self._predict_outputs( + coord=test_data["coord"].reshape(nframes, -1), + atom_types=test_data["type"], + box=test_data["box"] if data_system.pbc else None, + fparam=test_data["fparam"] + if bool(test_data.get("find_fparam", 0.0)) + else None, + aparam=test_data["aparam"] if self.model.get_dim_aparam() > 0 else None, + natoms=natoms, + nframes=nframes, + ) + return 
_compute_system_metrics( + prediction=prediction, + test_data=test_data, + natoms=natoms, + has_pbc=data_system.pbc, + ) + + def _predict_outputs( + self, + *, + coord: np.ndarray, + atom_types: np.ndarray, + box: np.ndarray | None, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + natoms: int, + nframes: int, + ) -> dict[str, np.ndarray]: + """Predict energy, force, and virial for the full validation batch.""" + + def predict_batch( + coord_batch: np.ndarray, + atom_types_batch: np.ndarray, + box_batch: np.ndarray | None, + fparam_batch: np.ndarray | None, + aparam_batch: np.ndarray | None, + ) -> dict[str, np.ndarray]: + coord_input = torch.tensor( + coord_batch.reshape(-1, natoms, 3).astype( + NP_PRECISION_DICT[ + RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION] + ] + ), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + type_input = torch.tensor( + atom_types_batch.astype(np.int64), + dtype=torch.long, + device=DEVICE, + ) + if box_batch is not None: + box_input = torch.tensor( + box_batch.reshape(-1, 3, 3).astype( + NP_PRECISION_DICT[ + RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION] + ] + ), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + else: + box_input = None + if fparam_batch is not None: + fparam_input = to_torch_tensor( + fparam_batch.reshape(-1, self.model.get_dim_fparam()) + ) + else: + fparam_input = None + if aparam_batch is not None: + aparam_input = to_torch_tensor( + aparam_batch.reshape(-1, natoms, self.model.get_dim_aparam()) + ) + else: + aparam_input = None + + batch_output = self.model( + coord_input, + type_input, + box=box_input, + fparam=fparam_input, + aparam=aparam_input, + ) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + + return { + "energy": batch_output["energy"].detach().cpu().numpy().reshape(-1, 1), + "force": batch_output["force"] + .detach() + .cpu() + .numpy() + .reshape(-1, natoms * 3), + "virial": batch_output["virial"].detach().cpu().numpy().reshape(-1, 9), + } + + 
prediction = self.auto_batch_size.execute_all( + predict_batch, + nframes, + natoms, + coord, + atom_types, + box, + fparam, + aparam, + ) + return { + "energy": prediction["energy"], + "force": prediction["force"], + "virial": prediction["virial"], + } + + def _update_best_state( + self, + *, + display_step: int, + selected_metric_value: float, + ) -> str | None: + """Update the best metric state and return the checkpoint path to save.""" + if ( + self.best_metric_value is not None + and selected_metric_value >= self.best_metric_value + ): + return None + + new_best_path = ( + self._best_checkpoint_name(display_step) if self.save_best else None + ) + self.best_metric_value = selected_metric_value + self.best_step = display_step + self._sync_train_infos() + return new_best_path + + def _sync_train_infos(self) -> None: + """Synchronize best validation state into train infos.""" + self.train_infos.pop(BEST_PATH_INFO_KEY_COMPAT, None) + self.train_infos[BEST_METRIC_NAME_INFO_KEY] = self.metric_name + self.train_infos[BEST_METRIC_INFO_KEY] = self.best_metric_value + self.train_infos[BEST_STEP_INFO_KEY] = self.best_step + + def _best_checkpoint_name(self, step: int) -> str: + """Build the best-checkpoint filename for one step.""" + return f"best.ckpt-{step}.pt" + + def _list_best_checkpoints(self) -> list[Path]: + """List all managed best checkpoints in the working directory.""" + best_checkpoints = [ + path + for path in Path(".").glob(BEST_CKPT_GLOB) + if path.is_file() and not path.is_symlink() + ] + best_checkpoints.sort(key=lambda path: path.stat().st_mtime) + return best_checkpoints + + def _prune_best_checkpoints(self, keep_names: set[str] | None = None) -> None: + """Delete managed best checkpoints except the requested ones.""" + keep_names = set() if keep_names is None else keep_names + for checkpoint_path in self._list_best_checkpoints(): + if checkpoint_path.name not in keep_names: + checkpoint_path.unlink(missing_ok=True) + + def 
_initialize_best_checkpoints(self, restart_training: bool) -> None: + """Align on-disk best checkpoints with the current training mode.""" + if restart_training and self.save_best and self.best_step is not None: + self._prune_best_checkpoints( + keep_names={self._best_checkpoint_name(int(self.best_step))} + ) + else: + self._prune_best_checkpoints() + + def _log_result(self, result: FullValidationResult | None) -> None: + """Log and persist full validation results on rank 0.""" + assert result is not None + self._write_log_file(result) + if result.saved_best_path is not None: + metric_label, metric_value, metric_unit = format_metric_for_log( + self.metric_name, result.selected_metric_value + ) + log.info( + f"Saved best model to {result.saved_best_path} " + f"with {metric_label} = {format_metric_number_for_log(metric_value)} " + f"{metric_unit}" + ) + + def _write_log_file(self, result: FullValidationResult) -> None: + """Append one full validation entry to the dedicated log file.""" + with self.full_val_file.open(self._write_mode, buffering=1) as fout: + if self._should_write_header: + header = VAL_LOG_HEADER_PREFIX + f"{'step':^{self.step_column_width}s}" + for _, header_label, column_width in self.table_column_specs: + header += VAL_LOG_COLUMN_GAP + f"{header_label:^{column_width}s}" + header += "\n" + header += ( + "# E uses per-atom energy, F uses per-atom force-vector L2 " + "norms, and V uses virial normalized by natoms.\n" + ) + fout.write(header) + self._should_write_header = False + self._write_mode = "a" + + line = ( + VAL_LOG_DATA_PREFIX + + f"{result.display_step:^{self.step_column_width}d}" + ) + for metric_key, _, column_width in self.table_column_specs: + metric_value = result.metrics.get(metric_key, float("nan")) + if not np.isnan(metric_value): + metric_value, _ = format_metric_value_for_table( + metric_key, metric_value + ) + metric_text = format_metric_number_for_log(metric_value) + line += VAL_LOG_COLUMN_GAP + f"{metric_text:^{column_width}s}" 
+ line += "\n" + fout.write(line) + if result.saved_best_path is not None: + metric_label, metric_value, metric_unit = format_metric_for_log( + self.metric_name, result.selected_metric_value + ) + fout.write( + "# saved best checkpoint: " + f"{result.saved_best_path} ({metric_label} = " + f"{format_metric_number_for_log(metric_value)} {metric_unit})\n" + ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b12bc7ef6f..0cf44c2093 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -4037,6 +4037,179 @@ def training_extra_check(data: dict | None) -> bool: ) +FULL_VALIDATION_METRIC_PREFS = { + "e:mae": ("start_pref_e", "limit_pref_e"), + "e:rmse": ("start_pref_e", "limit_pref_e"), + "f:mae": ("start_pref_f", "limit_pref_f"), + "f:rmse": ("start_pref_f", "limit_pref_f"), + "v:mae": ("start_pref_v", "limit_pref_v"), + "v:rmse": ("start_pref_v", "limit_pref_v"), +} + + +def normalize_full_validation_metric(metric: str) -> str: + """Normalize the full validation metric string.""" + return metric.strip().lower() + + +def is_valid_full_validation_metric(metric: str) -> bool: + """Check whether a full validation metric is supported.""" + return normalize_full_validation_metric(metric) in FULL_VALIDATION_METRIC_PREFS + + +def get_full_validation_metric_prefactors(metric: str) -> tuple[str, str]: + """Get the prefactor keys required by a full validation metric.""" + normalized_metric = normalize_full_validation_metric(metric) + if normalized_metric not in FULL_VALIDATION_METRIC_PREFS: + valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + raise ValueError( + f"validating.validation_metric must be one of {valid_metrics}, got {metric!r}." 
+ ) + return FULL_VALIDATION_METRIC_PREFS[normalized_metric] + + +def validating_args() -> Argument: + """Generate full validation arguments.""" + valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + doc_full_validation = ( + "Whether to run an additional full validation pass over the entire " + "validation dataset during training. This flow is independent from the " + "display-time validation controlled by `training.disp_freq`. Only " + "single-task energy training is supported. Multi-task, spin-energy, " + "and `training.zero_stage >= 2` are not supported." + ) + doc_validation_freq = ( + "The frequency, in training steps, of running the full validation pass." + ) + doc_save_best = ( + "Whether to save an extra checkpoint when the selected full validation " + "metric reaches a new best value." + ) + doc_validation_metric = ( + "Metric used to determine the best checkpoint during full validation. " + f"Supported values are {valid_metrics}. The string is case-insensitive. " + "`E` and `V` are per-atom metrics; `F` uses per-atom force-error " + "vector L2 norms. The corresponding loss prefactors must not both be 0." + ) + doc_full_val_file = ( + "The file for writing full validation results only. This file is " + "independent from `training.disp_file`." + ) + doc_full_val_start = ( + "The starting point of full validation. `0` means the feature is active " + "from the beginning and will trigger at every `validation_freq` steps. " + "A value in `(0, 1)` is interpreted as a ratio of `training.numb_steps`. " + "`1` disables the feature. A value larger than `1` is interpreted as the " + "starting step after integer conversion." 
+ ) + args = [ + Argument( + "full_validation", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_full_validation, + ), + Argument( + "validation_freq", + int, + optional=True, + default=5000, + doc=doc_only_pt_supported + doc_validation_freq, + extra_check=lambda x: x > 0, + extra_check_errmsg="must be greater than 0", + ), + Argument( + "save_best", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + doc_save_best, + ), + Argument( + "validation_metric", + str, + optional=True, + default="E:MAE", + doc=doc_only_pt_supported + doc_validation_metric, + extra_check=is_valid_full_validation_metric, + extra_check_errmsg=( + "must be one of " + + ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + ), + ), + Argument( + "full_val_file", + str, + optional=True, + default="val.log", + doc=doc_only_pt_supported + doc_full_val_file, + ), + Argument( + "full_val_start", + [int, float], + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_full_val_start, + extra_check=lambda x: x >= 0, + extra_check_errmsg="must be greater than or equal to 0", + ), + ] + return Argument( + "validating", + dict, + sub_fields=args, + sub_variants=[], + optional=True, + default={}, + doc=doc_only_pt_supported + + "Independent full validation options for single-task energy training.", + ) + + +def validate_full_validation_config( + data: dict[str, Any], multi_task: bool = False +) -> None: + """Validate cross-section constraints for full validation.""" + validating = data.get("validating") or {} + if not validating.get("full_validation", False): + return + if float(validating.get("full_val_start", 0.0)) == 1.0: + return + + if multi_task: + # Unsupported multi-task mode is rejected during trainer initialization. 
+ return + + metric = validating["validation_metric"] + if not is_valid_full_validation_metric(metric): + valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + raise ValueError( + "validating.validation_metric must be one of " + f"{valid_metrics}, got {metric!r}." + ) + + loss_params = data.get("loss", {}) + if loss_params.get("type", "ener") != "ener": + return + + if not data.get("training", {}).get("validation_data"): + raise ValueError( + "validating.full_validation requires `training.validation_data`. " + "It is only supported for single-task energy training." + ) + + pref_start_key, pref_limit_key = get_full_validation_metric_prefactors(metric) + pref_start = float(loss_params.get(pref_start_key, 0.0)) + pref_limit = float(loss_params.get(pref_limit_key, 0.0)) + if pref_start == 0.0 and pref_limit == 0.0: + raise ValueError( + f"validating.validation_metric={metric!r} requires " + f"`loss.{pref_start_key}` and `loss.{pref_limit_key}` to not both " + "be 0." 
+ ) + + def multi_model_args() -> list[Argument]: model_dict = model_args() model_dict.name = "model_dict" @@ -4111,6 +4284,7 @@ def gen_args(multi_task: bool = False) -> list[Argument]: optimizer_args(), loss_args(), training_args(multi_task=multi_task), + validating_args(), nvnmd_args(), ] else: @@ -4120,6 +4294,7 @@ def gen_args(multi_task: bool = False) -> list[Argument]: optimizer_args(fold_subdoc=True), multi_loss_args(), training_args(multi_task=multi_task), + validating_args(), nvnmd_args(fold_subdoc=True), ] @@ -4155,6 +4330,7 @@ def normalize(data: dict[str, Any], multi_task: bool = False) -> dict[str, Any]: base = Argument("base", dict, gen_args(multi_task=multi_task)) data = base.normalize_value(data, trim_pattern="_*") base.check_value(data, strict=True) + validate_full_validation_config(data, multi_task=multi_task) return data diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index b4cc926844..1c171b5add 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -9,6 +9,9 @@ from pathlib import ( Path, ) +from unittest.mock import ( + patch, +) import numpy as np import torch @@ -20,8 +23,15 @@ from deepmd.pt.utils.finetune import ( get_finetune_rules, ) +from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.utils.argcheck import ( + normalize, +) from deepmd.utils.compat import ( convert_optimizer_v31_to_v32, + update_deepmd_input, ) from .model.test_permutation import ( @@ -749,5 +759,97 @@ def test_fitting_stat_consistency(self) -> None: ) +class TestFullValidation(unittest.TestCase): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config = convert_optimizer_v31_to_v32(self.config, warning=False) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + 
self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 2 + self.config["training"]["save_freq"] = 100 + self.config["training"]["disp_training"] = False + self.config["validating"] = { + "full_validation": True, + "validation_freq": 1, + "save_best": True, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + } + + def tearDown(self) -> None: + for f in os.listdir("."): + if (f.startswith("model") or f.startswith("best")) and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "val.log", "checkpoint"]: + os.remove(f) + if f.startswith("stat_files"): + shutil.rmtree(f) + + @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") + def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None: + mocked_eval.side_effect = [ + {"mae_e_per_atom": 2.0}, + {"mae_e_per_atom": 1.0}, + ] + Path("best.ckpt-999.pt").touch() + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + + self.assertFalse(Path("best.ckpt-999.pt").exists()) + self.assertFalse(Path("best.ckpt-1.pt").exists()) + self.assertTrue(Path("best.ckpt-2.pt").exists()) + train_infos = trainer._get_inner_module().train_infos + self.assertEqual(train_infos["full_validation_best_step"], 2) + self.assertEqual(train_infos["full_validation_best_metric"], 1.0) + self.assertNotIn("full_validation_best_path", train_infos) + with open("val.log") as fp: + val_lines = [line for line in fp.readlines() if not line.startswith("#")] + self.assertEqual(val_lines[0].split()[1], "2000.0") + self.assertEqual(val_lines[1].split()[1], "1000.0") + + def test_full_validation_rejects_spin_loss(self) -> None: + config = deepcopy(self.config) + config["loss"]["type"] = "ener_spin" + with self.assertRaisesRegex(ValueError, "spin-energy"): + get_trainer(config) + + def test_full_validation_rejects_multitask(self) -> None: + multitask_json = 
str(Path(__file__).parent / "water/multitask.json") + with open(multitask_json) as f: + config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + for model_key in config["training"]["data_dict"]: + config["training"]["data_dict"][model_key]["training_data"]["systems"] = ( + data_file + ) + config["training"]["data_dict"][model_key]["validation_data"]["systems"] = ( + data_file + ) + config["training"]["data_dict"][model_key]["stat_file"] = ( + f"stat_files_{model_key}" + ) + config["training"]["numb_steps"] = 1 + config["training"]["save_freq"] = 1 + config["validating"] = { + "full_validation": True, + "validation_freq": 1, + "save_best": True, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + } + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + + with self.assertRaisesRegex(ValueError, "multi-task"): + get_trainer(config, shared_links=shared_links) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/test_validation.py b/source/tests/pt/test_validation.py new file mode 100644 index 0000000000..25d4d76c3a --- /dev/null +++ b/source/tests/pt/test_validation.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +import tempfile +import unittest +from copy import ( + deepcopy, +) + +import torch + +from deepmd.pt.train.validation import ( + FullValidator, + resolve_full_validation_start_step, +) +from deepmd.utils.argcheck import ( + normalize, +) + +from .model.test_permutation import ( + model_se_e2_a, +) + + +class _DummyValidationData: + def __init__(self) -> None: + self.systems = [] + + +class _DummyModel(torch.nn.Module): + def forward(self, *args, **kwargs): + raise NotImplementedError + + def get_dim_fparam(self) -> int: + return 0 + + def get_dim_aparam(self) -> int: + return 0 + + +def _make_single_task_config() -> 
dict: + return { + "model": deepcopy(model_se_e2_a), + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr": 1e-8, + "decay_steps": 10, + }, + "optimizer": { + "type": "Adam", + }, + "loss": { + "type": "ener", + "start_pref_e": 1.0, + "limit_pref_e": 1.0, + "start_pref_f": 1.0, + "limit_pref_f": 1.0, + "start_pref_v": 1.0, + "limit_pref_v": 1.0, + }, + "training": { + "training_data": {"systems": ["train_system"]}, + "validation_data": {"systems": ["valid_system"]}, + "numb_steps": 10, + }, + "validating": { + "full_validation": True, + "validation_freq": 2, + "save_best": True, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + }, + } + + +class TestValidationHelpers(unittest.TestCase): + def test_resolve_full_validation_start_step(self) -> None: + self.assertEqual(resolve_full_validation_start_step(0, 2000000), 0) + self.assertEqual(resolve_full_validation_start_step(0.1, 2000000), 200000) + self.assertEqual(resolve_full_validation_start_step(5000, 2000000), 5000) + self.assertIsNone(resolve_full_validation_start_step(1, 2000000)) + + def test_full_validator_rotates_best_checkpoint(self) -> None: + train_infos = {} + with tempfile.TemporaryDirectory() as tmpdir: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + validator = FullValidator( + validating_params={ + "full_validation": True, + "validation_freq": 1, + "save_best": True, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + }, + validation_data=_DummyValidationData(), + model=_DummyModel(), + train_infos=train_infos, + num_steps=10, + rank=0, + zero_stage=0, + restart_training=False, + ) + new_best_path = validator._update_best_state( + display_step=2, + selected_metric_value=1.0, + ) + finally: + os.chdir(old_cwd) + + self.assertEqual(new_best_path, "best.ckpt-2.pt") + self.assertEqual(train_infos["full_validation_best_metric"], 1.0) + self.assertEqual(train_infos["full_validation_best_step"], 2) + 
self.assertNotIn("full_validation_best_path", train_infos) + + +class TestValidationArgcheck(unittest.TestCase): + def test_normalize_rejects_missing_validation_data(self) -> None: + config = _make_single_task_config() + del config["training"]["validation_data"] + with self.assertRaisesRegex(ValueError, "training.validation_data"): + normalize(config) + + def test_normalize_rejects_zero_prefactor_metric(self) -> None: + config = _make_single_task_config() + config["validating"]["validation_metric"] = "F:RMSE" + config["loss"]["start_pref_f"] = 0.0 + config["loss"]["limit_pref_f"] = 0.0 + with self.assertRaisesRegex(ValueError, "start_pref_f"): + normalize(config) + + def test_normalize_rejects_invalid_metric(self) -> None: + config = _make_single_task_config() + config["validating"]["validation_metric"] = "X:MAE" + with self.assertRaisesRegex(Exception, "validation_metric"): + normalize(config) From ccdfedde82322fae5a79648f4506bfbb83f416e0 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 26 Mar 2026 11:50:44 +0800 Subject: [PATCH 2/3] fix ai comment --- deepmd/pt/train/validation.py | 87 ++++++++++++++++++++++++++---- deepmd/pt/utils/dataset.py | 5 ++ deepmd/utils/argcheck.py | 18 ++++--- source/tests/pt/test_training.py | 19 +++---- source/tests/pt/test_validation.py | 5 +- 5 files changed, 107 insertions(+), 27 deletions(-) diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py index 16e6ca41ab..74e56f4d13 100644 --- a/deepmd/pt/train/validation.py +++ b/deepmd/pt/train/validation.py @@ -6,6 +6,7 @@ ) import logging +import traceback from dataclasses import ( dataclass, ) @@ -110,6 +111,12 @@ def resolve_full_validation_start_step( def parse_validation_metric(metric: str) -> tuple[str, str]: """Parse the configured full validation metric.""" normalized_metric = normalize_full_validation_metric(metric) + if normalized_metric not in METRIC_KEY_MAP: + supported_metrics = ", ".join(item.upper() for item in METRIC_KEY_MAP) + raise ValueError( + 
"validating.validation_metric must be one of " + f"{supported_metrics}, got {metric!r}." + ) return normalized_metric, METRIC_KEY_MAP[normalized_metric] @@ -255,7 +262,7 @@ def __init__( self.enabled = ( self.full_validation and self.start_step is not None - and self.start_step < num_steps + and self.start_step <= num_steps ) self.step_column_width = max(len("step"), len(str(num_steps))) self._write_mode = "a" if restart_training else "w" @@ -308,21 +315,60 @@ def run( dist.barrier() result: FullValidationResult | None = None + caught_exception: Exception | None = None + error_message = None save_path = [None] if self.rank == 0: - result = self._evaluate(display_step) - save_path[0] = result.saved_best_path + try: + result = self._evaluate(display_step) + save_path[0] = result.saved_best_path + except Exception as exc: + caught_exception = exc + error_message = ( + "Full validation failed on rank 0 during evaluation:\n" + f"{traceback.format_exc()}" + ) + + self._raise_if_distributed_error(error_message, caught_exception) if self.is_distributed: dist.broadcast_object_list(save_path, src=0) if save_path[0] is not None: - save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) - if self.rank == 0: - self._prune_best_checkpoints(keep_names={Path(save_path[0]).name}) + try: + if not self.is_distributed or self.zero_stage == 0: + if self.rank == 0: + save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) + else: + save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) + if self.rank == 0: + self._prune_best_checkpoints(keep_names={Path(save_path[0]).name}) + except Exception as exc: + caught_exception = exc + error_message = ( + "Full validation failed while saving the best checkpoint:\n" + f"{traceback.format_exc()}" + ) + else: + error_message = None + caught_exception = None + + self._raise_if_distributed_error(error_message, caught_exception) if self.rank == 0: - self._log_result(result) + try: + self._log_result(result) + except Exception as exc: + 
caught_exception = exc + error_message = ( + "Full validation failed while writing logs:\n" + f"{traceback.format_exc()}" + ) + else: + error_message = None + caught_exception = None + + self._raise_if_distributed_error(error_message, caught_exception) if self.is_distributed: dist.barrier() @@ -367,8 +413,12 @@ def evaluate_all_systems(self) -> dict[str, float]: system_metrics = [] for dataset in self.validation_data.systems: - assert isinstance(dataset, DeepmdDataSetForLoader) - system_metrics.append(self._evaluate_system(dataset._data_system)) + if not isinstance(dataset, DeepmdDataSetForLoader): + raise TypeError( + "Full validation expects each dataset in validation_data.systems " + f"to be DeepmdDataSetForLoader, got {type(dataset)!r}." + ) + system_metrics.append(self._evaluate_system(dataset.data_system)) aggregated = weighted_average([metric for metric in system_metrics if metric]) return { @@ -555,6 +605,25 @@ def _initialize_best_checkpoints(self, restart_training: bool) -> None: else: self._prune_best_checkpoints() + def _raise_if_distributed_error( + self, + local_error_message: str | None, + local_exception: Exception | None = None, + ) -> None: + """Propagate a local error to all ranks and raise consistently.""" + error_message = local_error_message + if self.is_distributed: + gathered_errors = [None] * dist.get_world_size() + dist.all_gather_object(gathered_errors, local_error_message) + error_message = next( + (message for message in gathered_errors if message is not None), None + ) + if error_message is None: + return + if local_exception is not None: + raise RuntimeError(error_message) from local_exception + raise RuntimeError(error_message) + def _log_result(self, result: FullValidationResult | None) -> None: """Log and persist full validation results on rank 0.""" assert result is not None diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index ce9a6c52c6..20a76a0e87 100644 --- a/deepmd/pt/utils/dataset.py +++ 
b/deepmd/pt/utils/dataset.py @@ -49,6 +49,11 @@ def __init__( def __len__(self) -> int: return self._data_system.nframes + @property + def data_system(self) -> DeepmdData: + """Expose the underlying DeePMD data system.""" + return self._data_system + def __getitem__(self, index: int) -> dict[str, Any]: """Get a frame from the selected system.""" b_data = self._data_system.get_item_torch(index, max(1, NUM_WORKERS)) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0cf44c2093..16945589ae 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -4177,10 +4177,6 @@ def validate_full_validation_config( if float(validating.get("full_val_start", 0.0)) == 1.0: return - if multi_task: - # Unsupported multi-task mode is rejected during trainer initialization. - return - metric = validating["validation_metric"] if not is_valid_full_validation_metric(metric): valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) @@ -4189,9 +4185,19 @@ def validate_full_validation_config( f"{valid_metrics}, got {metric!r}." ) + if multi_task: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; multi-task training is not supported." + ) + loss_params = data.get("loss", {}) - if loss_params.get("type", "ener") != "ener": - return + loss_type = loss_params.get("type", "ener") + if loss_type != "ener": + raise ValueError( + "validating.full_validation only supports single-task energy " + f"training with loss.type='ener'; got loss.type={loss_type!r}." 
+ ) if not data.get("training", {}).get("validation_data"): raise ValueError( diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 1c171b5add..05434f8306 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -2,6 +2,7 @@ import json import os import shutil +import tempfile import unittest from copy import ( deepcopy, @@ -761,6 +762,9 @@ def test_fitting_stat_consistency(self) -> None: class TestFullValidation(unittest.TestCase): def setUp(self) -> None: + self._cwd = os.getcwd() + self._tmpdir = tempfile.TemporaryDirectory() + os.chdir(self._tmpdir.name) input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -782,13 +786,8 @@ def setUp(self) -> None: } def tearDown(self) -> None: - for f in os.listdir("."): - if (f.startswith("model") or f.startswith("best")) and f.endswith(".pt"): - os.remove(f) - if f in ["lcurve.out", "val.log", "checkpoint"]: - os.remove(f) - if f.startswith("stat_files"): - shutil.rmtree(f) + os.chdir(self._cwd) + self._tmpdir.cleanup() @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None: @@ -843,12 +842,10 @@ def test_full_validation_rejects_multitask(self) -> None: "full_val_file": "val.log", "full_val_start": 0.0, } - config["model"], shared_links = preprocess_shared_params(config["model"]) + config["model"], _ = preprocess_shared_params(config["model"]) config = update_deepmd_input(config, warning=False) - config = normalize(config, multi_task=True) - with self.assertRaisesRegex(ValueError, "multi-task"): - get_trainer(config, shared_links=shared_links) + normalize(config, multi_task=True) if __name__ == "__main__": diff --git a/source/tests/pt/test_validation.py b/source/tests/pt/test_validation.py index 25d4d76c3a..88e01f3844 100644 --- a/source/tests/pt/test_validation.py +++ 
b/source/tests/pt/test_validation.py @@ -7,6 +7,9 @@ ) import torch +from dargs.dargs import ( + ArgumentValueError, +) from deepmd.pt.train.validation import ( FullValidator, @@ -135,5 +138,5 @@ def test_normalize_rejects_zero_prefactor_metric(self) -> None: def test_normalize_rejects_invalid_metric(self) -> None: config = _make_single_task_config() config["validating"]["validation_metric"] = "X:MAE" - with self.assertRaisesRegex(Exception, "validation_metric"): + with self.assertRaisesRegex(ArgumentValueError, "validation_metric"): normalize(config) From ee2df8e3b603fc05547d52a4bbb8c64250339a95 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 26 Mar 2026 12:59:24 +0800 Subject: [PATCH 3/3] add topk logic --- deepmd/pt/train/validation.py | 166 +++++++++++++++++++++-------- deepmd/utils/argcheck.py | 16 ++- source/tests/pt/test_training.py | 32 ++++-- source/tests/pt/test_validation.py | 87 ++++++++++++++- 4 files changed, 242 insertions(+), 59 deletions(-) diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py index 74e56f4d13..62d319c793 100644 --- a/deepmd/pt/train/validation.py +++ b/deepmd/pt/train/validation.py @@ -6,6 +6,7 @@ ) import logging +import re import traceback from dataclasses import ( dataclass, @@ -68,11 +69,16 @@ ("V_RMSE", "rmse_v_per_atom"), ] -BEST_METRIC_INFO_KEY = "full_validation_best_metric" -BEST_STEP_INFO_KEY = "full_validation_best_step" +TOPK_RECORDS_INFO_KEY = "full_validation_topk_records" BEST_METRIC_NAME_INFO_KEY = "full_validation_metric" -BEST_CKPT_GLOB = "best.ckpt-*.pt" -BEST_PATH_INFO_KEY_COMPAT = "full_validation_best_path" +BEST_CKPT_GLOB = "best.ckpt-*.t-*.pt" +BEST_CKPT_PATTERN = re.compile(r"^best\.ckpt-(\d+)\.t-(\d+)\.pt$") +STALE_FULL_VALIDATION_INFO_KEYS = ( + "full_validation_best_metric", + "full_validation_best_step", + "full_validation_best_path", + "full_validation_best_records", +) BATCH_SIZE_LOGGER_NAME = "deepmd.utils.batch_size" VAL_LOG_SIGNIFICANT_DIGITS = 5 VAL_LOG_COLUMN_GAP = " " @@ 
-96,6 +102,14 @@ class FullValidationResult: saved_best_path: str | None +@dataclass(order=True, frozen=True) +class BestCheckpointRecord: + """One best-checkpoint record ordered by metric then step.""" + + metric: float + step: int + + def resolve_full_validation_start_step( full_val_start: float, num_steps: int ) -> int | None: @@ -251,6 +265,7 @@ def __init__( self.full_validation = bool(validating_params.get("full_validation", False)) self.validation_freq = int(validating_params.get("validation_freq", 5000)) self.save_best = bool(validating_params.get("save_best", True)) + self.max_best_ckpt = int(validating_params.get("max_best_ckpt", 1)) self.metric_name, self.metric_key = parse_validation_metric( str(validating_params.get("validation_metric", "E:MAE")) ) @@ -278,15 +293,7 @@ def __init__( (metric_key, header_label, max(len(header_label), 18)) ) - if self.train_infos.get(BEST_METRIC_NAME_INFO_KEY) == self.metric_name: - best_metric = self.train_infos.get(BEST_METRIC_INFO_KEY) - self.best_metric_value = ( - float(best_metric) if best_metric is not None else None - ) - self.best_step = self.train_infos.get(BEST_STEP_INFO_KEY) - else: - self.best_metric_value = None - self.best_step = None + self.topk_records = self._load_topk_records() self._sync_train_infos() if self.rank == 0: self._initialize_best_checkpoints(restart_training=restart_training) @@ -342,7 +349,7 @@ def run( else: save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) if self.rank == 0: - self._prune_best_checkpoints(keep_names={Path(save_path[0]).name}) + self._reconcile_best_checkpoints() except Exception as exc: caught_exception = exc error_message = ( @@ -553,31 +560,62 @@ def _update_best_state( display_step: int, selected_metric_value: float, ) -> str | None: - """Update the best metric state and return the checkpoint path to save.""" - if ( - self.best_metric_value is not None - and selected_metric_value >= self.best_metric_value - ): + """Update the top-K records and return the 
checkpoint path to save.""" + candidate = BestCheckpointRecord( + metric=selected_metric_value, + step=display_step, + ) + updated_records = [ + record for record in self.topk_records if record.step != display_step + ] + updated_records.append(candidate) + updated_records.sort() + updated_records = updated_records[: self.max_best_ckpt] + if candidate not in updated_records: return None - new_best_path = ( - self._best_checkpoint_name(display_step) if self.save_best else None - ) - self.best_metric_value = selected_metric_value - self.best_step = display_step + self.topk_records = updated_records self._sync_train_infos() - return new_best_path + if not self.save_best: + return None + candidate_rank = self.topk_records.index(candidate) + 1 + return self._best_checkpoint_name(display_step, candidate_rank) def _sync_train_infos(self) -> None: - """Synchronize best validation state into train infos.""" - self.train_infos.pop(BEST_PATH_INFO_KEY_COMPAT, None) + """Synchronize top-K validation state into train infos.""" + for key in STALE_FULL_VALIDATION_INFO_KEYS: + self.train_infos.pop(key, None) self.train_infos[BEST_METRIC_NAME_INFO_KEY] = self.metric_name - self.train_infos[BEST_METRIC_INFO_KEY] = self.best_metric_value - self.train_infos[BEST_STEP_INFO_KEY] = self.best_step + self.train_infos[TOPK_RECORDS_INFO_KEY] = [ + {"metric": record.metric, "step": record.step} + for record in self.topk_records + ] + + def _load_topk_records(self) -> list[BestCheckpointRecord]: + """Load top-K records from train infos for the current metric.""" + if self.train_infos.get(BEST_METRIC_NAME_INFO_KEY) != self.metric_name: + return [] + raw_records = self.train_infos.get(TOPK_RECORDS_INFO_KEY, []) + if not isinstance(raw_records, list): + return [] + records = [] + for raw_record in raw_records: + if not isinstance(raw_record, dict): + continue + if "metric" not in raw_record or "step" not in raw_record: + continue + records.append( + BestCheckpointRecord( + 
metric=float(raw_record["metric"]), + step=int(raw_record["step"]), + ) + ) + records.sort() + return records[: self.max_best_ckpt] - def _best_checkpoint_name(self, step: int) -> str: + def _best_checkpoint_name(self, step: int, rank: int) -> str: """Build the best-checkpoint filename for one step.""" - return f"best.ckpt-{step}.pt" + return f"best.ckpt-{step}.t-{rank}.pt" def _list_best_checkpoints(self) -> list[Path]: """List all managed best checkpoints in the working directory.""" @@ -589,21 +627,63 @@ def _list_best_checkpoints(self) -> list[Path]: best_checkpoints.sort(key=lambda path: path.stat().st_mtime) return best_checkpoints - def _prune_best_checkpoints(self, keep_names: set[str] | None = None) -> None: - """Delete managed best checkpoints except the requested ones.""" - keep_names = set() if keep_names is None else keep_names - for checkpoint_path in self._list_best_checkpoints(): - if checkpoint_path.name not in keep_names: - checkpoint_path.unlink(missing_ok=True) + def _expected_topk_checkpoint_names(self) -> dict[int, str]: + """Return the expected checkpoint filename for each retained step.""" + return { + record.step: self._best_checkpoint_name(record.step, rank) + for rank, record in enumerate(self.topk_records, start=1) + } + + def _reconcile_best_checkpoints(self) -> None: + """Rename retained best checkpoints to ranked names and delete stale ones.""" + expected_names = self._expected_topk_checkpoint_names() + current_files = self._list_best_checkpoints() + files_by_step: dict[int, list[Path]] = {} + stale_files: list[Path] = [] + for checkpoint_path in current_files: + match = BEST_CKPT_PATTERN.match(checkpoint_path.name) + if match is None: + stale_files.append(checkpoint_path) + continue + step = int(match.group(1)) + files_by_step.setdefault(step, []).append(checkpoint_path) + + temp_moves: list[tuple[Path, Path]] = [] + for step, checkpoint_paths in files_by_step.items(): + expected_name = expected_names.get(step) + if expected_name is 
None: + stale_files.extend(checkpoint_paths) + continue + + keep_path = next( + ( + checkpoint_path + for checkpoint_path in checkpoint_paths + if checkpoint_path.name == expected_name + ), + checkpoint_paths[0], + ) + for checkpoint_path in checkpoint_paths: + if checkpoint_path != keep_path: + stale_files.append(checkpoint_path) + if keep_path.name != expected_name: + temp_path = keep_path.with_name(f"{keep_path.name}.tmp") + keep_path.rename(temp_path) + temp_moves.append((temp_path, keep_path.with_name(expected_name))) + + for checkpoint_path in stale_files: + checkpoint_path.unlink(missing_ok=True) + for temp_path, final_path in temp_moves: + final_path.unlink(missing_ok=True) + temp_path.rename(final_path) def _initialize_best_checkpoints(self, restart_training: bool) -> None: """Align on-disk best checkpoints with the current training mode.""" - if restart_training and self.save_best and self.best_step is not None: - self._prune_best_checkpoints( - keep_names={self._best_checkpoint_name(int(self.best_step))} - ) - else: - self._prune_best_checkpoints() + if restart_training and self.save_best and self.topk_records: + self._reconcile_best_checkpoints() + return + for checkpoint_path in self._list_best_checkpoints(): + checkpoint_path.unlink(missing_ok=True) def _raise_if_distributed_error( self, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 16945589ae..d1fb756c2c 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -4085,6 +4085,11 @@ def validating_args() -> Argument: "Whether to save an extra checkpoint when the selected full validation " "metric reaches a new best value." ) + doc_max_best_ckpt = ( + "The maximum number of top-ranked best checkpoints to keep. The best " + "checkpoints are ranked by the selected validation metric in ascending " + "order. Default is 1." + ) doc_validation_metric = ( "Metric used to determine the best checkpoint during full validation. " f"Supported values are {valid_metrics}. 
The string is case-insensitive. " @@ -4126,6 +4131,15 @@ def validating_args() -> Argument: default=True, doc=doc_only_pt_supported + doc_save_best, ), + Argument( + "max_best_ckpt", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_max_best_ckpt, + extra_check=lambda x: x > 0, + extra_check_errmsg="must be greater than 0", + ), Argument( "validation_metric", str, @@ -4149,7 +4163,7 @@ def validating_args() -> Argument: "full_val_start", [int, float], optional=True, - default=0.0, + default=0.5, doc=doc_only_pt_supported + doc_full_val_start, extra_check=lambda x: x >= 0, extra_check_errmsg="must be greater than or equal to 0", diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 05434f8306..74ee873872 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -773,13 +773,14 @@ def setUp(self) -> None: self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_se_e2_a) - self.config["training"]["numb_steps"] = 2 + self.config["training"]["numb_steps"] = 4 self.config["training"]["save_freq"] = 100 self.config["training"]["disp_training"] = False self.config["validating"] = { "full_validation": True, "validation_freq": 1, "save_best": True, + "max_best_ckpt": 2, "validation_metric": "E:MAE", "full_val_file": "val.log", "full_val_start": 0.0, @@ -792,24 +793,33 @@ def tearDown(self) -> None: @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems") def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None: mocked_eval.side_effect = [ - {"mae_e_per_atom": 2.0}, {"mae_e_per_atom": 1.0}, + {"mae_e_per_atom": 2.0}, + {"mae_e_per_atom": 0.5}, + {"mae_e_per_atom": 1.5}, ] - Path("best.ckpt-999.pt").touch() + Path("best.ckpt-999.t-1.pt").touch() trainer = get_trainer(deepcopy(self.config)) trainer.run() - 
self.assertFalse(Path("best.ckpt-999.pt").exists()) - self.assertFalse(Path("best.ckpt-1.pt").exists()) - self.assertTrue(Path("best.ckpt-2.pt").exists()) + self.assertFalse(Path("best.ckpt-999.t-1.pt").exists()) + self.assertFalse(Path("best.ckpt-1.t-1.pt").exists()) + self.assertFalse(Path("best.ckpt-2.t-1.pt").exists()) + self.assertTrue(Path("best.ckpt-3.t-1.pt").exists()) + self.assertTrue(Path("best.ckpt-1.t-2.pt").exists()) train_infos = trainer._get_inner_module().train_infos - self.assertEqual(train_infos["full_validation_best_step"], 2) - self.assertEqual(train_infos["full_validation_best_metric"], 1.0) - self.assertNotIn("full_validation_best_path", train_infos) + self.assertEqual( + train_infos["full_validation_topk_records"], + [ + {"metric": 0.5, "step": 3}, + {"metric": 1.0, "step": 1}, + ], + ) with open("val.log") as fp: val_lines = [line for line in fp.readlines() if not line.startswith("#")] - self.assertEqual(val_lines[0].split()[1], "2000.0") - self.assertEqual(val_lines[1].split()[1], "1000.0") + self.assertEqual(len(val_lines), 4) + self.assertEqual(val_lines[0].split()[1], "1000.0") + self.assertEqual(val_lines[1].split()[1], "2000.0") def test_full_validation_rejects_spin_loss(self) -> None: config = deepcopy(self.config) diff --git a/source/tests/pt/test_validation.py b/source/tests/pt/test_validation.py index 88e01f3844..d09d78de18 100644 --- a/source/tests/pt/test_validation.py +++ b/source/tests/pt/test_validation.py @@ -5,6 +5,9 @@ from copy import ( deepcopy, ) +from pathlib import ( + Path, +) import torch from dargs.dargs import ( @@ -12,6 +15,8 @@ ) from deepmd.pt.train.validation import ( + BEST_METRIC_NAME_INFO_KEY, + TOPK_RECORDS_INFO_KEY, FullValidator, resolve_full_validation_start_step, ) @@ -70,6 +75,7 @@ def _make_single_task_config() -> dict: "full_validation": True, "validation_freq": 2, "save_best": True, + "max_best_ckpt": 1, "validation_metric": "E:MAE", "full_val_file": "val.log", "full_val_start": 0.0, @@ -95,6 
+101,7 @@ def test_full_validator_rotates_best_checkpoint(self) -> None: "full_validation": True, "validation_freq": 1, "save_best": True, + "max_best_ckpt": 2, "validation_metric": "E:MAE", "full_val_file": "val.log", "full_val_start": 0.0, @@ -107,17 +114,83 @@ def test_full_validator_rotates_best_checkpoint(self) -> None: zero_stage=0, restart_training=False, ) + new_best_path = validator._update_best_state( + display_step=1, + selected_metric_value=2.0, + ) + Path(new_best_path).touch() + validator._reconcile_best_checkpoints() + new_best_path = validator._update_best_state( display_step=2, selected_metric_value=1.0, ) + Path(new_best_path).touch() + validator._reconcile_best_checkpoints() + + new_best_path = validator._update_best_state( + display_step=3, + selected_metric_value=1.5, + ) + Path(new_best_path).touch() + validator._reconcile_best_checkpoints() + finally: + os.chdir(old_cwd) + + self.assertEqual(new_best_path, "best.ckpt-3.t-2.pt") + self.assertEqual( + sorted(path.name for path in Path(tmpdir).glob("best.ckpt-*.pt")), + ["best.ckpt-2.t-1.pt", "best.ckpt-3.t-2.pt"], + ) + self.assertEqual( + train_infos[TOPK_RECORDS_INFO_KEY], + [ + {"metric": 1.0, "step": 2}, + {"metric": 1.5, "step": 3}, + ], + ) + self.assertEqual(train_infos[BEST_METRIC_NAME_INFO_KEY], "e:mae") + + def test_full_validator_restores_top_k_checkpoints(self) -> None: + train_infos = { + BEST_METRIC_NAME_INFO_KEY: "e:mae", + TOPK_RECORDS_INFO_KEY: [ + {"metric": 1.0, "step": 20}, + {"metric": 2.0, "step": 10}, + ], + } + with tempfile.TemporaryDirectory() as tmpdir: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + Path("best.ckpt-20.t-9.pt").touch() + Path("best.ckpt-10.t-8.pt").touch() + Path("best.ckpt-999.t-1.pt").touch() + FullValidator( + validating_params={ + "full_validation": True, + "validation_freq": 1, + "save_best": True, + "max_best_ckpt": 2, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + }, + 
validation_data=_DummyValidationData(), + model=_DummyModel(), + train_infos=train_infos, + num_steps=10, + rank=0, + zero_stage=0, + restart_training=True, + ) finally: os.chdir(old_cwd) - self.assertEqual(new_best_path, "best.ckpt-2.pt") - self.assertEqual(train_infos["full_validation_best_metric"], 1.0) - self.assertEqual(train_infos["full_validation_best_step"], 2) - self.assertNotIn("full_validation_best_path", train_infos) + self.assertEqual( + sorted(path.name for path in Path(tmpdir).glob("best.ckpt-*.pt")), + ["best.ckpt-10.t-2.pt", "best.ckpt-20.t-1.pt"], + ) class TestValidationArgcheck(unittest.TestCase): @@ -140,3 +213,9 @@ def test_normalize_rejects_invalid_metric(self) -> None: config["validating"]["validation_metric"] = "X:MAE" with self.assertRaisesRegex(ArgumentValueError, "validation_metric"): normalize(config) + + def test_normalize_rejects_nonpositive_max_best_ckpt(self) -> None: + config = _make_single_task_config() + config["validating"]["max_best_ckpt"] = 0 + with self.assertRaisesRegex(ArgumentValueError, "max_best_ckpt"): + normalize(config)