From 85fe7160379253d73396156415ae77f242e48a2a Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:23:52 +0800 Subject: [PATCH 1/9] feat: Implement XAS (X-ray Absorption Spectroscopy) model, fitting, loss, and testing infrastructure. --- deepmd/__about__.py | 3 + deepmd/dpmodel/atomic_model/__init__.py | 4 + .../dpmodel/atomic_model/xas_atomic_model.py | 81 ++++++++ deepmd/dpmodel/fitting/__init__.py | 4 + deepmd/dpmodel/fitting/xas_fitting.py | 162 ++++++++++++++++ deepmd/dpmodel/model/model.py | 22 ++- deepmd/dpmodel/model/xas_model.py | 99 ++++++++++ deepmd/pt/loss/__init__.py | 4 + deepmd/pt/loss/xas.py | 175 ++++++++++++++++++ deepmd/pt/model/atomic_model/__init__.py | 4 + .../pt/model/atomic_model/xas_atomic_model.py | 59 ++++++ deepmd/pt/model/model/__init__.py | 27 ++- deepmd/pt/model/model/xas_model.py | 110 +++++++++++ deepmd/pt/model/task/__init__.py | 4 + deepmd/pt/model/task/xas.py | 167 +++++++++++++++++ deepmd/pt/train/training.py | 14 +- deepmd/utils/argcheck.py | 105 +++++++++++ 17 files changed, 1025 insertions(+), 19 deletions(-) create mode 100644 deepmd/__about__.py create mode 100644 deepmd/dpmodel/atomic_model/xas_atomic_model.py create mode 100644 deepmd/dpmodel/fitting/xas_fitting.py create mode 100644 deepmd/dpmodel/model/xas_model.py create mode 100644 deepmd/pt/loss/xas.py create mode 100644 deepmd/pt/model/atomic_model/xas_atomic_model.py create mode 100644 deepmd/pt/model/model/xas_model.py create mode 100644 deepmd/pt/model/task/xas.py diff --git a/deepmd/__about__.py b/deepmd/__about__.py new file mode 100644 index 0000000000..0d2f7d41b3 --- /dev/null +++ b/deepmd/__about__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# Auto-generated stub for development use +__version__ = "dev" diff --git a/deepmd/dpmodel/atomic_model/__init__.py b/deepmd/dpmodel/atomic_model/__init__.py index 4d882d5e4b..3c10834bf2 100644 --- 
a/deepmd/dpmodel/atomic_model/__init__.py +++ b/deepmd/dpmodel/atomic_model/__init__.py @@ -23,6 +23,9 @@ from .dos_atomic_model import ( DPDOSAtomicModel, ) +from .xas_atomic_model import ( + DPXASAtomicModel, +) from .dp_atomic_model import ( DPAtomicModel, ) @@ -50,6 +53,7 @@ "BaseAtomicModel", "DPAtomicModel", "DPDOSAtomicModel", + "DPXASAtomicModel", "DPDipoleAtomicModel", "DPEnergyAtomicModel", "DPPolarAtomicModel", diff --git a/deepmd/dpmodel/atomic_model/xas_atomic_model.py b/deepmd/dpmodel/atomic_model/xas_atomic_model.py new file mode 100644 index 0000000000..63824ffdae --- /dev/null +++ b/deepmd/dpmodel/atomic_model/xas_atomic_model.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.dpmodel.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.dpmodel.fitting.xas_fitting import ( + XASFittingNet, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPXASAtomicModel(DPAtomicModel): + """Atomic model for XAS spectrum fitting. + + Automatically sets ``atom_exclude_types`` to all non-absorbing atom types + so that the intensive mean reduction in ``fit_output_to_model_output`` + computes the mean XAS over absorbing atoms only. + + Parameters + ---------- + descriptor : BaseDescriptor + fitting : BaseFitting + Must be an instance of XASFittingNet. + type_map : list[str] + Mapping from type index to element symbol. + absorbing_type : str + Element symbol of the absorbing atom type (e.g. "Fe"). + **kwargs + Passed to DPAtomicModel. 
+ """ + + def __init__( + self, + descriptor: BaseDescriptor, + fitting: BaseFitting, + type_map: list[str], + absorbing_type: str, + **kwargs: Any, + ) -> None: + if not isinstance(fitting, XASFittingNet): + raise TypeError( + "fitting must be an instance of XASFittingNet for DPXASAtomicModel" + ) + if absorbing_type not in type_map: + raise ValueError( + f"absorbing_type '{absorbing_type}' not found in type_map {type_map}" + ) + self.absorbing_type = absorbing_type + absorbing_idx = type_map.index(absorbing_type) + # Exclude all types except the absorbing type so the intensive mean + # reduction is computed only over absorbing atoms. + atom_exclude_types = [i for i in range(len(type_map)) if i != absorbing_idx] + kwargs["atom_exclude_types"] = atom_exclude_types + super().__init__(descriptor, fitting, type_map, **kwargs) + + def get_intensive(self) -> bool: + """XAS is an intensive property (mean over absorbing atoms).""" + return True + + def serialize(self) -> dict: + dd = super().serialize() + dd["absorbing_type"] = self.absorbing_type + return dd + + @classmethod + def deserialize(cls, data: dict) -> "DPXASAtomicModel": + data = data.copy() + absorbing_type = data.pop("absorbing_type") + # atom_exclude_types is already stored by base; rebuild absorbing_type param + obj = super().deserialize(data) + obj.absorbing_type = absorbing_type + return obj diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py index 5bdfff2571..9e7c7171af 100644 --- a/deepmd/dpmodel/fitting/__init__.py +++ b/deepmd/dpmodel/fitting/__init__.py @@ -5,6 +5,9 @@ from .dos_fitting import ( DOSFittingNet, ) +from .xas_fitting import ( + XASFittingNet, +) from .ener_fitting import ( EnergyFittingNet, ) @@ -23,6 +26,7 @@ __all__ = [ "DOSFittingNet", + "XASFittingNet", "DipoleFitting", "EnergyFittingNet", "InvarFitting", diff --git a/deepmd/dpmodel/fitting/xas_fitting.py b/deepmd/dpmodel/fitting/xas_fitting.py new file mode 100644 index 0000000000..514046b8bc --- 
/dev/null +++ b/deepmd/dpmodel/fitting/xas_fitting.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + TYPE_CHECKING, +) + +import numpy as np + +from deepmd.dpmodel.array_api import ( + Array, +) +from deepmd.dpmodel.common import ( + DEFAULT_PRECISION, + to_numpy_array, +) +from deepmd.dpmodel.fitting.invar_fitting import ( + InvarFitting, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) + +if TYPE_CHECKING: + from deepmd.dpmodel.fitting.general_fitting import ( + GeneralFitting, + ) + +from deepmd.utils.version import ( + check_version_compatibility, +) + + +@InvarFitting.register("xas") +class XASFittingNet(InvarFitting): + """Fitting network for X-ray Absorption Spectroscopy (XAS) spectra. + + Predicts per-atom XAS contributions in a relative energy (ΔE) space. + The global XAS is the mean over all absorbing atoms, handled by the + XAS model via ``intensive=True`` and type-selective masking. + + Parameters + ---------- + ntypes : int + Number of atom types. + dim_descrpt : int + Dimension of the descriptor. + numb_xas : int + Number of XAS energy grid points. + neuron : list[int] + Hidden layer sizes of the fitting network. + resnet_dt : bool + Whether to use residual network with time step. + numb_fparam : int + Dimension of frame parameters (e.g. edge type encoding). + numb_aparam : int + Dimension of atomic parameters. + dim_case_embd : int + Dimension of case embedding. + bias_xas : Array or None + Initial bias for XAS output, shape (ntypes, numb_xas). + rcond : float or None + Cutoff for small singular values. + trainable : bool or list[bool] + Whether the fitting parameters are trainable. + activation_function : str + Activation function for hidden layers. + precision : str + Precision for the fitting parameters. + mixed_types : bool + Whether to use a shared network for all atom types. 
+ exclude_types : list[int] + Atom types to exclude from fitting (set automatically by XASAtomicModel). + type_map : list[str] or None + Mapping from type index to element symbol. + seed : int, list[int], or None + Random seed. + default_fparam : list or None + Default frame parameter values. + """ + + def __init__( + self, + ntypes: int, + dim_descrpt: int, + numb_xas: int = 500, + neuron: list[int] = [120, 120, 120], + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + bias_xas: Array | None = None, + rcond: float | None = None, + trainable: bool | list[bool] = True, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = False, + exclude_types: list[int] = [], + type_map: list[str] | None = None, + seed: int | list[int] | None = None, + default_fparam: list | None = None, + ) -> None: + if bias_xas is not None: + self.bias_xas = bias_xas + else: + self.bias_xas = np.zeros((ntypes, numb_xas), dtype=DEFAULT_PRECISION) + super().__init__( + var_name="xas", + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out=numb_xas, + neuron=neuron, + resnet_dt=resnet_dt, + bias_atom=bias_xas, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, + rcond=rcond, + trainable=trainable, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + exclude_types=exclude_types, + type_map=type_map, + seed=seed, + default_fparam=default_fparam, + ) + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + self.var_name, + [self.dim_out], + reducible=True, + intensive=True, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + @classmethod + def deserialize(cls, data: dict) -> "GeneralFitting": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 4, 1) + data["numb_xas"] = data.pop("dim_out") + data.pop("tot_ener_zero", None) + data.pop("var_name", None) + 
data.pop("layer_name", None) + data.pop("use_aparam_as_mask", None) + data.pop("spin", None) + data.pop("atom_ener", None) + return super().deserialize(data) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + dd = { + **super().serialize(), + "type": "xas", + } + dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e) + return dd diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 8f96e965b0..08a08d225b 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -28,6 +28,9 @@ from deepmd.dpmodel.model.dos_model import ( DOSModel, ) +from deepmd.dpmodel.model.xas_model import ( + XASModel, +) from deepmd.dpmodel.model.dp_zbl_model import ( DPZBLModel, ) @@ -97,6 +100,8 @@ def get_standard_model(data: dict) -> EnergyModel: modelcls = PolarModel elif fitting_net_type == "dos": modelcls = DOSModel + elif fitting_net_type == "xas": + modelcls = XASModel elif fitting_net_type in ["ener", "direct_force_ener"]: modelcls = EnergyModel elif fitting_net_type == "property": @@ -104,13 +109,16 @@ def get_standard_model(data: dict) -> EnergyModel: else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") - model = modelcls( - descriptor=descriptor, - fitting=fitting, - type_map=data["type_map"], - atom_exclude_types=atom_exclude_types, - pair_exclude_types=pair_exclude_types, - ) + model_kwargs: dict = { + "descriptor": descriptor, + "fitting": fitting, + "type_map": data["type_map"], + "atom_exclude_types": atom_exclude_types, + "pair_exclude_types": pair_exclude_types, + } + if fitting_net_type == "xas": + model_kwargs["absorbing_type"] = data["absorbing_type"] + model = modelcls(**model_kwargs) return model diff --git a/deepmd/dpmodel/model/xas_model.py b/deepmd/dpmodel/model/xas_model.py new file mode 100644 index 0000000000..7700488049 --- /dev/null +++ b/deepmd/dpmodel/model/xas_model.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + 
Any, +) + +from deepmd.dpmodel.array_api import ( + Array, +) +from deepmd.dpmodel.atomic_model import ( + DPXASAtomicModel, +) +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_model import ( + make_model, +) + +DPXASModel_ = make_model(DPXASAtomicModel, T_Bases=(NativeOP, BaseModel)) + + +@BaseModel.register("xas") +class XASModel(DPModelCommon, DPXASModel_): + model_type = "xas" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + DPModelCommon.__init__(self) + DPXASModel_.__init__(self, *args, **kwargs) + + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_xas"] = model_ret["xas"] + model_predict["xas"] = model_ret["xas_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_xas"] = model_ret["xas"] + model_predict["xas"] = model_ret["xas_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_xas": out_def_data["xas"], + "xas": 
out_def_data["xas_redu"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py index 1d25c1e52f..4d0058d83b 100644 --- a/deepmd/pt/loss/__init__.py +++ b/deepmd/pt/loss/__init__.py @@ -5,6 +5,9 @@ from .dos import ( DOSLoss, ) +from .xas import ( + XASLoss, +) from .ener import ( EnergyHessianStdLoss, EnergyStdLoss, @@ -24,6 +27,7 @@ __all__ = [ "DOSLoss", + "XASLoss", "DenoiseLoss", "EnergyHessianStdLoss", "EnergySpinLoss", diff --git a/deepmd/pt/loss/xas.py b/deepmd/pt/loss/xas.py new file mode 100644 index 0000000000..6fc7a5c91a --- /dev/null +++ b/deepmd/pt/loss/xas.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.pt.loss.loss import ( + TaskLoss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + + +class XASLoss(TaskLoss): + """Loss for XAS spectrum fitting. + + Computes L2 loss on the reduced XAS spectrum (mean over absorbing atoms). + An optional CDF (cumulative) loss can be added to improve spectral shape. + + The labels expected in the dataset are: + + * ``xas.npy`` : shape ``[nframes, numb_xas]`` — the mean XAS spectrum over + absorbing atoms, in a relative energy (ΔE) grid. + + Parameters + ---------- + starter_learning_rate : float + Initial learning rate, used for prefactor scheduling. + numb_xas : int + Number of XAS energy grid points. + start_pref_xas : float + Starting prefactor for the XAS L2 loss. + limit_pref_xas : float + Limiting prefactor for the XAS L2 loss. + start_pref_cdf : float + Starting prefactor for the CDF L2 loss. + limit_pref_cdf : float + Limiting prefactor for the CDF L2 loss. + inference : bool + If True, output all losses regardless of prefactors. 
+ """ + + def __init__( + self, + starter_learning_rate: float, + numb_xas: int, + start_pref_xas: float = 1.0, + limit_pref_xas: float = 1.0, + start_pref_cdf: float = 0.0, + limit_pref_cdf: float = 0.0, + inference: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + self.starter_learning_rate = starter_learning_rate + self.numb_xas = numb_xas + self.inference = inference + + self.start_pref_xas = start_pref_xas + self.limit_pref_xas = limit_pref_xas + self.start_pref_cdf = start_pref_cdf + self.limit_pref_cdf = limit_pref_cdf + + assert ( + self.start_pref_xas >= 0.0 + and self.limit_pref_xas >= 0.0 + and self.start_pref_cdf >= 0.0 + and self.limit_pref_cdf >= 0.0 + ), "Loss prefactors must be non-negative" + + self.has_xas = (start_pref_xas != 0.0 and limit_pref_xas != 0.0) or inference + self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference + + assert self.has_xas or self.has_cdf, ( + "At least one of start_pref_xas or start_pref_cdf must be non-zero" + ) + + def forward( + self, + input_dict: dict[str, torch.Tensor], + model: torch.nn.Module, + label: dict[str, torch.Tensor], + natoms: int, + learning_rate: float = 0.0, + mae: bool = False, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]: + """Compute XAS loss. + + Parameters + ---------- + input_dict : dict + Model inputs. + model : torch.nn.Module + The model to evaluate. + label : dict + Label dict containing ``"xas"`` key. + natoms : int + Number of local atoms. + learning_rate : float + Current learning rate for prefactor scheduling. + mae : bool + Unused (kept for API compatibility). 
+ + Returns + ------- + model_pred : dict + loss : torch.Tensor + more_loss : dict + """ + model_pred = model(**input_dict) + + coef = learning_rate / self.starter_learning_rate + pref_xas = ( + self.limit_pref_xas + (self.start_pref_xas - self.limit_pref_xas) * coef + ) + pref_cdf = ( + self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef + ) + + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] + more_loss: dict[str, torch.Tensor] = {} + + if self.has_xas and "xas" in model_pred and "xas" in label: + find_xas = label.get("find_xas", 0.0) + pref_xas = pref_xas * find_xas + pred = model_pred["xas"].reshape([-1, self.numb_xas]) + ref = label["xas"].reshape([-1, self.numb_xas]) + diff = pred - ref + l2_loss = torch.mean(torch.square(diff)) + if not self.inference: + more_loss["l2_xas_loss"] = self.display_if_exist( + l2_loss.detach(), find_xas + ) + loss += pref_xas * l2_loss + more_loss["rmse_xas"] = self.display_if_exist( + l2_loss.sqrt().detach(), find_xas + ) + + if self.has_cdf and "xas" in model_pred and "xas" in label: + find_xas = label.get("find_xas", 0.0) + pref_cdf = pref_cdf * find_xas + pred_cdf = torch.cumsum( + model_pred["xas"].reshape([-1, self.numb_xas]), dim=-1 + ) + ref_cdf = torch.cumsum(label["xas"].reshape([-1, self.numb_xas]), dim=-1) + diff_cdf = pred_cdf - ref_cdf + l2_cdf_loss = torch.mean(torch.square(diff_cdf)) + if not self.inference: + more_loss["l2_cdf_loss"] = self.display_if_exist( + l2_cdf_loss.detach(), find_xas + ) + loss += pref_cdf * l2_cdf_loss + more_loss["rmse_cdf"] = self.display_if_exist( + l2_cdf_loss.sqrt().detach(), find_xas + ) + + return model_pred, loss, more_loss + + @property + def label_requirement(self) -> list[DataRequirementItem]: + """Return data label requirements for XAS training.""" + return [ + DataRequirementItem( + "xas", + ndof=self.numb_xas, + atomic=False, + must=False, + high_prec=False, + ) + ] diff --git a/deepmd/pt/model/atomic_model/__init__.py 
b/deepmd/pt/model/atomic_model/__init__.py index 4da9bf781b..fd3dc50efb 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -23,6 +23,9 @@ from .dos_atomic_model import ( DPDOSAtomicModel, ) +from .xas_atomic_model import ( + DPXASAtomicModel, +) from .dp_atomic_model import ( DPAtomicModel, ) @@ -47,6 +50,7 @@ "BaseAtomicModel", "DPAtomicModel", "DPDOSAtomicModel", + "DPXASAtomicModel", "DPDipoleAtomicModel", "DPEnergyAtomicModel", "DPPolarAtomicModel", diff --git a/deepmd/pt/model/atomic_model/xas_atomic_model.py b/deepmd/pt/model/atomic_model/xas_atomic_model.py new file mode 100644 index 0000000000..1a8c38e0cc --- /dev/null +++ b/deepmd/pt/model/atomic_model/xas_atomic_model.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.pt.model.task.xas import ( + XASFittingNet, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPXASAtomicModel(DPAtomicModel): + """PyTorch atomic model for XAS spectrum fitting. + + Automatically excludes all non-absorbing atom types so that + the intensive mean reduction computes the mean XAS over absorbing + atoms only. + + Parameters + ---------- + descriptor : Any + fitting : Any + Must be an instance of XASFittingNet. + type_map : Any + Mapping from type index to element symbol. + absorbing_type : str + Element symbol of the absorbing atom type (e.g. "Fe"). + **kwargs + Passed to DPAtomicModel. 
+ """ + + def __init__( + self, + descriptor: Any, + fitting: Any, + type_map: Any, + absorbing_type: str, + **kwargs: Any, + ) -> None: + if not isinstance(fitting, XASFittingNet): + raise TypeError( + "fitting must be an instance of XASFittingNet for DPXASAtomicModel" + ) + if absorbing_type not in type_map: + raise ValueError( + f"absorbing_type '{absorbing_type}' not found in type_map {type_map}" + ) + self.absorbing_type = absorbing_type + absorbing_idx = type_map.index(absorbing_type) + atom_exclude_types = [i for i in range(len(type_map)) if i != absorbing_idx] + kwargs["atom_exclude_types"] = atom_exclude_types + super().__init__(descriptor, fitting, type_map, **kwargs) + + def get_intensive(self) -> bool: + """XAS is an intensive property (mean over absorbing atoms).""" + return True diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 24075412db..0322bf3089 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -39,6 +39,9 @@ from .dos_model import ( DOSModel, ) +from .xas_model import ( + XASModel, +) from .dp_linear_model import ( LinearEnergyModel, ) @@ -266,6 +269,8 @@ def get_standard_model(model_params: dict) -> BaseModel: modelcls = PolarModel elif fitting_net_type == "dos": modelcls = DOSModel + elif fitting_net_type == "xas": + modelcls = XASModel elif fitting_net_type in ["ener", "direct_force_ener"]: modelcls = EnergyModel elif fitting_net_type == "property": @@ -273,15 +278,18 @@ def get_standard_model(model_params: dict) -> BaseModel: else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") - model = modelcls( - descriptor=descriptor, - fitting=fitting, - type_map=model_params["type_map"], - atom_exclude_types=atom_exclude_types, - pair_exclude_types=pair_exclude_types, - preset_out_bias=preset_out_bias, - data_stat_protect=data_stat_protect, - ) + model_kwargs: dict[str, Any] = { + "descriptor": descriptor, + "fitting": fitting, + "type_map": 
model_params["type_map"], + "atom_exclude_types": atom_exclude_types, + "pair_exclude_types": pair_exclude_types, + "preset_out_bias": preset_out_bias, + "data_stat_protect": data_stat_protect, + } + if fitting_net_type == "xas": + model_kwargs["absorbing_type"] = model_params["absorbing_type"] + model = modelcls(**model_kwargs) if model_params.get("hessian_mode"): model.enable_hessian() model.model_def_script = json.dumps(model_params_old) @@ -306,6 +314,7 @@ def get_model(model_params: dict) -> Any: __all__ = [ "BaseModel", "DOSModel", + "XASModel", "DPModelCommon", "DPZBLModel", "DipoleModel", diff --git a/deepmd/pt/model/model/xas_model.py b/deepmd/pt/model/model/xas_model.py new file mode 100644 index 0000000000..45b568d09e --- /dev/null +++ b/deepmd/pt/model/model/xas_model.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.pt.model.atomic_model import ( + DPXASAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_model import ( + make_model, +) + +DPXASModel_ = make_model(DPXASAtomicModel) + + +@BaseModel.register("xas") +class XASModel(DPModelCommon, DPXASModel_): + model_type = "xas" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + DPModelCommon.__init__(self) + DPXASModel_.__init__(self, *args, **kwargs) + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_xas": out_def_data["xas"], + "xas": out_def_data["xas_redu"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, 
+ atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_xas"] = model_ret["xas"] + model_predict["xas"] = model_ret["xas_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + else: + model_predict = model_ret + return model_predict + + @torch.jit.export + def get_numb_xas(self) -> int: + """Get the number of XAS grid points.""" + return self.get_fitting_net().dim_out + + @torch.jit.export + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_xas"] = model_ret["xas"] + model_predict["xas"] = model_ret["xas_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 37ffec2725..c10c2fcf1b 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -11,6 +11,9 @@ from .dos import ( DOSFittingNet, ) +from .xas import ( + XASFittingNet, +) from .ener import ( EnergyFittingNet, EnergyFittingNetDirect, @@ -31,6 +34,7 @@ __all__ = [ "BaseFitting", "DOSFittingNet", + "XASFittingNet", "DenoiseNet", "DipoleFittingNet", "EnergyFittingNet", diff --git a/deepmd/pt/model/task/xas.py b/deepmd/pt/model/task/xas.py new file mode 100644 
index 0000000000..c8e9d752ea --- /dev/null +++ b/deepmd/pt/model/task/xas.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.task.ener import ( + InvarFitting, +) +from deepmd.pt.model.task.fitting import ( + Fitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + +log = logging.getLogger(__name__) + + +@Fitting.register("xas") +class XASFittingNet(InvarFitting): + """PyTorch fitting network for XAS spectra. + + Parameters + ---------- + ntypes : int + Number of atom types. + dim_descrpt : int + Dimension of the descriptor. + numb_xas : int + Number of XAS energy grid points. + neuron : list[int] + Hidden layer sizes. + resnet_dt : bool + Whether to use ResNet time step. + numb_fparam : int + Dimension of frame parameters (e.g. edge type encoding). + numb_aparam : int + Dimension of atomic parameters. + dim_case_embd : int + Dimension of case embedding. + rcond : float or None + Cutoff for small singular values in bias init. + bias_xas : torch.Tensor or None + Initial bias, shape (ntypes, numb_xas). + trainable : bool or list[bool] + Whether parameters are trainable. + seed : int, list[int], or None + Random seed. + activation_function : str + Activation function. + precision : str + Float precision. + exclude_types : list[int] + Atom types to exclude (set by XASAtomicModel automatically). + mixed_types : bool + Whether to use a shared network across types. + type_map : list[str] or None + Mapping from type index to element symbol. + default_fparam : list or None + Default frame parameter values. 
+ """ + + def __init__( + self, + ntypes: int, + dim_descrpt: int, + numb_xas: int = 500, + neuron: list[int] = [128, 128, 128], + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + rcond: float | None = None, + bias_xas: torch.Tensor | None = None, + trainable: bool | list[bool] = True, + seed: int | list[int] | None = None, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + exclude_types: list[int] = [], + mixed_types: bool = False, + type_map: list[str] | None = None, + default_fparam: list | None = None, + ) -> None: + if bias_xas is not None: + self.bias_xas = bias_xas + else: + self.bias_xas = torch.zeros( + (ntypes, numb_xas), dtype=dtype, device=env.DEVICE + ) + super().__init__( + var_name="xas", + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out=numb_xas, + neuron=neuron, + bias_atom_e=bias_xas, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + rcond=rcond, + seed=seed, + exclude_types=exclude_types, + trainable=trainable, + type_map=type_map, + default_fparam=default_fparam, + ) + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + self.var_name, + [self.dim_out], + reducible=True, + intensive=True, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + @classmethod + def deserialize(cls, data: dict) -> "XASFittingNet": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 4, 1) + data.pop("@class", None) + data.pop("var_name", None) + data.pop("tot_ener_zero", None) + data.pop("layer_name", None) + data.pop("use_aparam_as_mask", None) + data.pop("spin", None) + data.pop("atom_ener", None) + data["numb_xas"] = data.pop("dim_out") + return super().deserialize(data) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + dd = { + 
**InvarFitting.serialize(self), + "type": "xas", + "dim_out": self.dim_out, + } + dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e) + return dd + + # make jit happy with torch 2.0.0 + exclude_types: list[int] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 8d16e1c7ea..ff573c1932 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -43,6 +43,7 @@ PropertyLoss, TaskLoss, TensorLoss, + XASLoss, ) from deepmd.pt.model.model import ( get_model, @@ -94,9 +95,12 @@ get_optimizer_state_dict, set_optimizer_state_dict, ) -from torch.distributed.fsdp import ( - fully_shard, -) +try: + from torch.distributed.fsdp import ( + fully_shard, + ) +except ImportError: + fully_shard = None from torch.distributed.optim import ( ZeroRedundancyOptimizer, ) @@ -1727,6 +1731,10 @@ def get_loss( loss_params["starter_learning_rate"] = start_lr loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size return DOSLoss(**loss_params) + elif loss_type == "xas": + loss_params["starter_learning_rate"] = start_lr + loss_params["numb_xas"] = _model.model_output_def()["xas"].output_size + return XASLoss(**loss_params) elif loss_type == "ener_spin": loss_params["starter_learning_rate"] = start_lr return EnergySpinLoss(**loss_params) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b12bc7ef6f..770f551b5f 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1912,6 +1912,65 @@ def fitting_dos() -> list[Argument]: ] +@fitting_args_plugin.register("xas", doc=doc_only_pt_supported) +def fitting_xas() -> list[Argument]: + doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams. Can be used to encode edge type information." + doc_numb_aparam = "The dimension of the atomic parameter." + doc_default_fparam = "The default frame parameter value." + doc_dim_case_embd = "The dimension of the case embedding." 
+ doc_neuron = "The number of neurons in each hidden layer of the fitting net." + doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}.' + doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection.' + doc_trainable = "Whether the parameters in the fitting net are trainable." + doc_rcond = "The condition number for the initial bias fitting." + doc_seed = "Random seed for parameter initialization." + doc_numb_xas = "The number of grid points on the XAS energy axis (ΔE space)." + + return [ + Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), + Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), + Argument( + "default_fparam", + list[float], + optional=True, + default=None, + doc=doc_default_fparam, + ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_dim_case_embd, + ), + Argument( + "neuron", list[int], optional=True, default=[120, 120, 120], doc=doc_neuron + ), + Argument( + "activation_function", + str, + optional=True, + default="tanh", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="float64", doc=doc_precision), + Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_trainable, + ), + Argument( + "rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond + ), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument("numb_xas", int, optional=True, default=500, doc=doc_numb_xas), + ] + + @fitting_args_plugin.register("property", doc=doc_only_pt_supported) def fitting_property() -> list[Argument]: doc_numb_fparam = "The dimension of the 
frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." @@ -2231,6 +2290,7 @@ def model_args(exclude_hybrid: bool = False) -> list[Argument]: doc_compress_config = "Model compression configurations" doc_spin = "The settings for systems with spin." doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types" + doc_absorbing_type = "The element symbol of the absorbing atom type for XAS fitting (e.g. 'Fe'). Only used when fitting_net type is 'xas'." doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other." doc_preset_out_bias = "The preset bias of the atomic output. Note that the set_davg_zero should be set to true. The bias is provided as a dict. Taking the energy model that has three atom types for example, the `preset_out_bias` may be given as `{ 'energy': [null, 0., 1.] }`. In this case the energy bias of type 1 and 2 are set to 0. and 1., respectively. 
A dipole model with two atom types may set `preset_out_bias` as `{ 'dipole': [null, [0., 1., 2.]] }`" doc_finetune_head = ( @@ -2295,6 +2355,13 @@ def model_args(exclude_hybrid: bool = False) -> list[Argument]: default=[], doc=doc_only_pt_supported + doc_atom_exclude_types, ), + Argument( + "absorbing_type", + str, + optional=True, + default=None, + doc=doc_only_pt_supported + doc_absorbing_type, + ), Argument( "preset_out_bias", dict[str, list[float | list[float] | None]], @@ -3465,6 +3532,44 @@ def loss_dos() -> list[Argument]: ] +@loss_args_plugin.register("xas", doc=doc_only_pt_supported) +def loss_xas() -> list[Argument]: + doc_start_pref_xas = start_pref("XAS spectrum (mean over absorbing atoms)") + doc_limit_pref_xas = limit_pref("XAS spectrum (mean over absorbing atoms)") + doc_start_pref_cdf = start_pref("Cumulative Distribution Function of XAS") + doc_limit_pref_cdf = limit_pref("Cumulative Distribution Function of XAS") + return [ + Argument( + "start_pref_xas", + [float, int], + optional=True, + default=1.0, + doc=doc_start_pref_xas, + ), + Argument( + "limit_pref_xas", + [float, int], + optional=True, + default=1.0, + doc=doc_limit_pref_xas, + ), + Argument( + "start_pref_cdf", + [float, int], + optional=True, + default=0.0, + doc=doc_start_pref_cdf, + ), + Argument( + "limit_pref_cdf", + [float, int], + optional=True, + default=0.0, + doc=doc_limit_pref_cdf, + ), + ] + + @loss_args_plugin.register("property") def loss_property() -> list[Argument]: doc_loss_func = "The loss function to minimize, such as 'mae','smooth_mae'." 
From 9e9c6a32a3a8d28780ccb36ad15410c17553c1de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 05:30:34 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/atomic_model/__init__.py | 8 ++++---- deepmd/dpmodel/fitting/__init__.py | 8 ++++---- deepmd/dpmodel/model/model.py | 6 +++--- deepmd/pt/loss/__init__.py | 8 ++++---- deepmd/pt/model/atomic_model/__init__.py | 8 ++++---- deepmd/pt/model/model/__init__.py | 8 ++++---- deepmd/pt/model/task/__init__.py | 8 ++++---- deepmd/pt/train/training.py | 1 + deepmd/utils/argcheck.py | 2 +- 9 files changed, 29 insertions(+), 28 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/__init__.py b/deepmd/dpmodel/atomic_model/__init__.py index 3c10834bf2..ac349bce0a 100644 --- a/deepmd/dpmodel/atomic_model/__init__.py +++ b/deepmd/dpmodel/atomic_model/__init__.py @@ -23,9 +23,6 @@ from .dos_atomic_model import ( DPDOSAtomicModel, ) -from .xas_atomic_model import ( - DPXASAtomicModel, -) from .dp_atomic_model import ( DPAtomicModel, ) @@ -48,16 +45,19 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) +from .xas_atomic_model import ( + DPXASAtomicModel, +) __all__ = [ "BaseAtomicModel", "DPAtomicModel", "DPDOSAtomicModel", - "DPXASAtomicModel", "DPDipoleAtomicModel", "DPEnergyAtomicModel", "DPPolarAtomicModel", "DPPropertyAtomicModel", + "DPXASAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py index 9e7c7171af..336650184e 100644 --- a/deepmd/dpmodel/fitting/__init__.py +++ b/deepmd/dpmodel/fitting/__init__.py @@ -5,9 +5,6 @@ from .dos_fitting import ( DOSFittingNet, ) -from .xas_fitting import ( - XASFittingNet, -) from .ener_fitting import ( EnergyFittingNet, ) @@ -23,14 +20,17 @@ from .property_fitting import ( 
PropertyFittingNet, ) +from .xas_fitting import ( + XASFittingNet, +) __all__ = [ "DOSFittingNet", - "XASFittingNet", "DipoleFitting", "EnergyFittingNet", "InvarFitting", "PolarFitting", "PropertyFittingNet", + "XASFittingNet", "make_base_fitting", ] diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 08a08d225b..0211a0f1ba 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -28,9 +28,6 @@ from deepmd.dpmodel.model.dos_model import ( DOSModel, ) -from deepmd.dpmodel.model.xas_model import ( - XASModel, -) from deepmd.dpmodel.model.dp_zbl_model import ( DPZBLModel, ) @@ -46,6 +43,9 @@ from deepmd.dpmodel.model.spin_model import ( SpinModel, ) +from deepmd.dpmodel.model.xas_model import ( + XASModel, +) from deepmd.utils.spin import ( Spin, ) diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py index 4d0058d83b..17b2cd37c3 100644 --- a/deepmd/pt/loss/__init__.py +++ b/deepmd/pt/loss/__init__.py @@ -5,9 +5,6 @@ from .dos import ( DOSLoss, ) -from .xas import ( - XASLoss, -) from .ener import ( EnergyHessianStdLoss, EnergyStdLoss, @@ -24,10 +21,12 @@ from .tensor import ( TensorLoss, ) +from .xas import ( + XASLoss, +) __all__ = [ "DOSLoss", - "XASLoss", "DenoiseLoss", "EnergyHessianStdLoss", "EnergySpinLoss", @@ -35,4 +34,5 @@ "PropertyLoss", "TaskLoss", "TensorLoss", + "XASLoss", ] diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index fd3dc50efb..1270d5f720 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -23,9 +23,6 @@ from .dos_atomic_model import ( DPDOSAtomicModel, ) -from .xas_atomic_model import ( - DPXASAtomicModel, -) from .dp_atomic_model import ( DPAtomicModel, ) @@ -45,16 +42,19 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) +from .xas_atomic_model import ( + DPXASAtomicModel, +) __all__ = [ "BaseAtomicModel", "DPAtomicModel", "DPDOSAtomicModel", - 
"DPXASAtomicModel", "DPDipoleAtomicModel", "DPEnergyAtomicModel", "DPPolarAtomicModel", "DPPropertyAtomicModel", + "DPXASAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 0322bf3089..9bb2c76701 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -39,9 +39,6 @@ from .dos_model import ( DOSModel, ) -from .xas_model import ( - XASModel, -) from .dp_linear_model import ( LinearEnergyModel, ) @@ -76,6 +73,9 @@ SpinEnergyModel, SpinModel, ) +from .xas_model import ( + XASModel, +) def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: @@ -314,7 +314,6 @@ def get_model(model_params: dict) -> Any: __all__ = [ "BaseModel", "DOSModel", - "XASModel", "DPModelCommon", "DPZBLModel", "DipoleModel", @@ -324,6 +323,7 @@ def get_model(model_params: dict) -> Any: "PolarModel", "SpinEnergyModel", "SpinModel", + "XASModel", "get_model", "make_hessian_model", "make_model", diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index c10c2fcf1b..296935dc9e 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -11,9 +11,6 @@ from .dos import ( DOSFittingNet, ) -from .xas import ( - XASFittingNet, -) from .ener import ( EnergyFittingNet, EnergyFittingNetDirect, @@ -30,11 +27,13 @@ from .type_predict import ( TypePredictNet, ) +from .xas import ( + XASFittingNet, +) __all__ = [ "BaseFitting", "DOSFittingNet", - "XASFittingNet", "DenoiseNet", "DipoleFittingNet", "EnergyFittingNet", @@ -43,4 +42,5 @@ "PolarFittingNet", "PropertyFittingNet", "TypePredictNet", + "XASFittingNet", ] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index ff573c1932..d999b276f9 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -95,6 +95,7 @@ get_optimizer_state_dict, set_optimizer_state_dict, ) + try: from 
torch.distributed.fsdp import ( fully_shard, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 770f551b5f..25df224cf5 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1919,7 +1919,7 @@ def fitting_xas() -> list[Argument]: doc_default_fparam = "The default frame parameter value." doc_dim_case_embd = "The dimension of the case embedding." doc_neuron = "The number of neurons in each hidden layer of the fitting net." - doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}.' + doc_activation_function = f"The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection.' doc_trainable = "Whether the parameters in the fitting net are trainable." From 8fd99ad3ab47a92ad6e3e71b28f2f5bf557b1b84 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:31:02 +0800 Subject: [PATCH 3/9] feat: Reimplement XAS loss with per-atom property fitting, removing previous XAS model components and adding new tests. 
--- deepmd/dpmodel/atomic_model/__init__.py | 4 - .../dpmodel/atomic_model/xas_atomic_model.py | 81 ------- deepmd/dpmodel/fitting/__init__.py | 4 - deepmd/dpmodel/fitting/xas_fitting.py | 162 ------------- deepmd/dpmodel/model/model.py | 7 - deepmd/dpmodel/model/xas_model.py | 99 -------- deepmd/entrypoints/test.py | 29 ++- deepmd/pt/loss/xas.py | 214 +++++++----------- deepmd/pt/model/atomic_model/__init__.py | 4 - .../pt/model/atomic_model/xas_atomic_model.py | 59 ----- deepmd/pt/model/model/__init__.py | 9 - deepmd/pt/model/model/xas_model.py | 110 --------- deepmd/pt/model/task/__init__.py | 4 - deepmd/pt/model/task/xas.py | 167 -------------- deepmd/pt/train/training.py | 8 +- deepmd/utils/argcheck.py | 105 +-------- 16 files changed, 120 insertions(+), 946 deletions(-) delete mode 100644 deepmd/dpmodel/atomic_model/xas_atomic_model.py delete mode 100644 deepmd/dpmodel/fitting/xas_fitting.py delete mode 100644 deepmd/dpmodel/model/xas_model.py delete mode 100644 deepmd/pt/model/atomic_model/xas_atomic_model.py delete mode 100644 deepmd/pt/model/model/xas_model.py delete mode 100644 deepmd/pt/model/task/xas.py diff --git a/deepmd/dpmodel/atomic_model/__init__.py b/deepmd/dpmodel/atomic_model/__init__.py index ac349bce0a..4d882d5e4b 100644 --- a/deepmd/dpmodel/atomic_model/__init__.py +++ b/deepmd/dpmodel/atomic_model/__init__.py @@ -45,9 +45,6 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) -from .xas_atomic_model import ( - DPXASAtomicModel, -) __all__ = [ "BaseAtomicModel", @@ -57,7 +54,6 @@ "DPEnergyAtomicModel", "DPPolarAtomicModel", "DPPropertyAtomicModel", - "DPXASAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", diff --git a/deepmd/dpmodel/atomic_model/xas_atomic_model.py b/deepmd/dpmodel/atomic_model/xas_atomic_model.py deleted file mode 100644 index 63824ffdae..0000000000 --- a/deepmd/dpmodel/atomic_model/xas_atomic_model.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: 
LGPL-3.0-or-later -from typing import ( - Any, -) - -from deepmd.dpmodel.descriptor.base_descriptor import ( - BaseDescriptor, -) -from deepmd.dpmodel.fitting.base_fitting import ( - BaseFitting, -) -from deepmd.dpmodel.fitting.xas_fitting import ( - XASFittingNet, -) - -from .dp_atomic_model import ( - DPAtomicModel, -) - - -class DPXASAtomicModel(DPAtomicModel): - """Atomic model for XAS spectrum fitting. - - Automatically sets ``atom_exclude_types`` to all non-absorbing atom types - so that the intensive mean reduction in ``fit_output_to_model_output`` - computes the mean XAS over absorbing atoms only. - - Parameters - ---------- - descriptor : BaseDescriptor - fitting : BaseFitting - Must be an instance of XASFittingNet. - type_map : list[str] - Mapping from type index to element symbol. - absorbing_type : str - Element symbol of the absorbing atom type (e.g. "Fe"). - **kwargs - Passed to DPAtomicModel. - """ - - def __init__( - self, - descriptor: BaseDescriptor, - fitting: BaseFitting, - type_map: list[str], - absorbing_type: str, - **kwargs: Any, - ) -> None: - if not isinstance(fitting, XASFittingNet): - raise TypeError( - "fitting must be an instance of XASFittingNet for DPXASAtomicModel" - ) - if absorbing_type not in type_map: - raise ValueError( - f"absorbing_type '{absorbing_type}' not found in type_map {type_map}" - ) - self.absorbing_type = absorbing_type - absorbing_idx = type_map.index(absorbing_type) - # Exclude all types except the absorbing type so the intensive mean - # reduction is computed only over absorbing atoms. 
- atom_exclude_types = [i for i in range(len(type_map)) if i != absorbing_idx] - kwargs["atom_exclude_types"] = atom_exclude_types - super().__init__(descriptor, fitting, type_map, **kwargs) - - def get_intensive(self) -> bool: - """XAS is an intensive property (mean over absorbing atoms).""" - return True - - def serialize(self) -> dict: - dd = super().serialize() - dd["absorbing_type"] = self.absorbing_type - return dd - - @classmethod - def deserialize(cls, data: dict) -> "DPXASAtomicModel": - data = data.copy() - absorbing_type = data.pop("absorbing_type") - # atom_exclude_types is already stored by base; rebuild absorbing_type param - obj = super().deserialize(data) - obj.absorbing_type = absorbing_type - return obj diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py index 336650184e..5bdfff2571 100644 --- a/deepmd/dpmodel/fitting/__init__.py +++ b/deepmd/dpmodel/fitting/__init__.py @@ -20,9 +20,6 @@ from .property_fitting import ( PropertyFittingNet, ) -from .xas_fitting import ( - XASFittingNet, -) __all__ = [ "DOSFittingNet", @@ -31,6 +28,5 @@ "InvarFitting", "PolarFitting", "PropertyFittingNet", - "XASFittingNet", "make_base_fitting", ] diff --git a/deepmd/dpmodel/fitting/xas_fitting.py b/deepmd/dpmodel/fitting/xas_fitting.py deleted file mode 100644 index 514046b8bc..0000000000 --- a/deepmd/dpmodel/fitting/xas_fitting.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - TYPE_CHECKING, -) - -import numpy as np - -from deepmd.dpmodel.array_api import ( - Array, -) -from deepmd.dpmodel.common import ( - DEFAULT_PRECISION, - to_numpy_array, -) -from deepmd.dpmodel.fitting.invar_fitting import ( - InvarFitting, -) -from deepmd.dpmodel.output_def import ( - FittingOutputDef, - OutputVariableDef, -) - -if TYPE_CHECKING: - from deepmd.dpmodel.fitting.general_fitting import ( - GeneralFitting, - ) - -from deepmd.utils.version import ( - check_version_compatibility, -) - - 
-@InvarFitting.register("xas") -class XASFittingNet(InvarFitting): - """Fitting network for X-ray Absorption Spectroscopy (XAS) spectra. - - Predicts per-atom XAS contributions in a relative energy (ΔE) space. - The global XAS is the mean over all absorbing atoms, handled by the - XAS model via ``intensive=True`` and type-selective masking. - - Parameters - ---------- - ntypes : int - Number of atom types. - dim_descrpt : int - Dimension of the descriptor. - numb_xas : int - Number of XAS energy grid points. - neuron : list[int] - Hidden layer sizes of the fitting network. - resnet_dt : bool - Whether to use residual network with time step. - numb_fparam : int - Dimension of frame parameters (e.g. edge type encoding). - numb_aparam : int - Dimension of atomic parameters. - dim_case_embd : int - Dimension of case embedding. - bias_xas : Array or None - Initial bias for XAS output, shape (ntypes, numb_xas). - rcond : float or None - Cutoff for small singular values. - trainable : bool or list[bool] - Whether the fitting parameters are trainable. - activation_function : str - Activation function for hidden layers. - precision : str - Precision for the fitting parameters. - mixed_types : bool - Whether to use a shared network for all atom types. - exclude_types : list[int] - Atom types to exclude from fitting (set automatically by XASAtomicModel). - type_map : list[str] or None - Mapping from type index to element symbol. - seed : int, list[int], or None - Random seed. - default_fparam : list or None - Default frame parameter values. 
- """ - - def __init__( - self, - ntypes: int, - dim_descrpt: int, - numb_xas: int = 500, - neuron: list[int] = [120, 120, 120], - resnet_dt: bool = True, - numb_fparam: int = 0, - numb_aparam: int = 0, - dim_case_embd: int = 0, - bias_xas: Array | None = None, - rcond: float | None = None, - trainable: bool | list[bool] = True, - activation_function: str = "tanh", - precision: str = DEFAULT_PRECISION, - mixed_types: bool = False, - exclude_types: list[int] = [], - type_map: list[str] | None = None, - seed: int | list[int] | None = None, - default_fparam: list | None = None, - ) -> None: - if bias_xas is not None: - self.bias_xas = bias_xas - else: - self.bias_xas = np.zeros((ntypes, numb_xas), dtype=DEFAULT_PRECISION) - super().__init__( - var_name="xas", - ntypes=ntypes, - dim_descrpt=dim_descrpt, - dim_out=numb_xas, - neuron=neuron, - resnet_dt=resnet_dt, - bias_atom=bias_xas, - numb_fparam=numb_fparam, - numb_aparam=numb_aparam, - dim_case_embd=dim_case_embd, - rcond=rcond, - trainable=trainable, - activation_function=activation_function, - precision=precision, - mixed_types=mixed_types, - exclude_types=exclude_types, - type_map=type_map, - seed=seed, - default_fparam=default_fparam, - ) - - def output_def(self) -> FittingOutputDef: - return FittingOutputDef( - [ - OutputVariableDef( - self.var_name, - [self.dim_out], - reducible=True, - intensive=True, - r_differentiable=False, - c_differentiable=False, - ), - ] - ) - - @classmethod - def deserialize(cls, data: dict) -> "GeneralFitting": - data = data.copy() - check_version_compatibility(data.pop("@version", 1), 4, 1) - data["numb_xas"] = data.pop("dim_out") - data.pop("tot_ener_zero", None) - data.pop("var_name", None) - data.pop("layer_name", None) - data.pop("use_aparam_as_mask", None) - data.pop("spin", None) - data.pop("atom_ener", None) - return super().deserialize(data) - - def serialize(self) -> dict: - """Serialize the fitting to dict.""" - dd = { - **super().serialize(), - "type": "xas", - } - 
dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e) - return dd diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 0211a0f1ba..220d1f4464 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -43,9 +43,6 @@ from deepmd.dpmodel.model.spin_model import ( SpinModel, ) -from deepmd.dpmodel.model.xas_model import ( - XASModel, -) from deepmd.utils.spin import ( Spin, ) @@ -100,8 +97,6 @@ def get_standard_model(data: dict) -> EnergyModel: modelcls = PolarModel elif fitting_net_type == "dos": modelcls = DOSModel - elif fitting_net_type == "xas": - modelcls = XASModel elif fitting_net_type in ["ener", "direct_force_ener"]: modelcls = EnergyModel elif fitting_net_type == "property": @@ -116,8 +111,6 @@ def get_standard_model(data: dict) -> EnergyModel: "atom_exclude_types": atom_exclude_types, "pair_exclude_types": pair_exclude_types, } - if fitting_net_type == "xas": - model_kwargs["absorbing_type"] = data["absorbing_type"] model = modelcls(**model_kwargs) return model diff --git a/deepmd/dpmodel/model/xas_model.py b/deepmd/dpmodel/model/xas_model.py deleted file mode 100644 index 7700488049..0000000000 --- a/deepmd/dpmodel/model/xas_model.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - -from deepmd.dpmodel.array_api import ( - Array, -) -from deepmd.dpmodel.atomic_model import ( - DPXASAtomicModel, -) -from deepmd.dpmodel.common import ( - NativeOP, -) -from deepmd.dpmodel.model.base_model import ( - BaseModel, -) - -from .dp_model import ( - DPModelCommon, -) -from .make_model import ( - make_model, -) - -DPXASModel_ = make_model(DPXASAtomicModel, T_Bases=(NativeOP, BaseModel)) - - -@BaseModel.register("xas") -class XASModel(DPModelCommon, DPXASModel_): - model_type = "xas" - - def __init__( - self, - *args: Any, - **kwargs: Any, - ) -> None: - DPModelCommon.__init__(self) - DPXASModel_.__init__(self, *args, **kwargs) - - def call( - 
self, - coord: Array, - atype: Array, - box: Array | None = None, - fparam: Array | None = None, - aparam: Array | None = None, - do_atomic_virial: bool = False, - ) -> dict[str, Array]: - model_ret = self.call_common( - coord, - atype, - box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) - model_predict = {} - model_predict["atom_xas"] = model_ret["xas"] - model_predict["xas"] = model_ret["xas_redu"] - if "mask" in model_ret: - model_predict["mask"] = model_ret["mask"] - return model_predict - - def call_lower( - self, - extended_coord: Array, - extended_atype: Array, - nlist: Array, - mapping: Array | None = None, - fparam: Array | None = None, - aparam: Array | None = None, - do_atomic_virial: bool = False, - ) -> dict[str, Array]: - model_ret = self.call_common_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) - model_predict = {} - model_predict["atom_xas"] = model_ret["xas"] - model_predict["xas"] = model_ret["xas_redu"] - if "mask" in model_ret: - model_predict["mask"] = model_ret["mask"] - return model_predict - - def translated_output_def(self) -> dict[str, Any]: - out_def_data = self.model_output_def().get_data() - output_def = { - "atom_xas": out_def_data["xas"], - "xas": out_def_data["xas_redu"], - } - if "mask" in out_def_data: - output_def["mask"] = out_def_data["mask"] - return output_def diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 4a0cb27cb1..781123cb79 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -894,6 +894,9 @@ def test_property( if dp.get_dim_aparam() > 0: data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False) + # sel_type: optional per-frame type index for element-wise mean reduction (XAS) + data.add("sel_type", 1, atomic=False, must=False, high_prec=False, default=float(-1)) + test_data = data.get_test() mixed_type = data.mixed_type natoms = 
len(test_data["type"][0]) @@ -918,17 +921,39 @@ def test_property( else: aparam = None + # detect whether this system provides sel_type (XAS-style reduction) + sel_type_raw = test_data["sel_type"][:numb_test, 0] # [numb_test] + has_sel_type = bool((sel_type_raw >= 0).all()) + + # for sel_type reduction we need per-atom outputs + eval_atomic = has_atom_property or has_sel_type ret = dp.eval( coord, box, atype, fparam=fparam, aparam=aparam, - atomic=has_atom_property, + atomic=eval_atomic, mixed_type=mixed_type, ) - property = ret[0] + if has_sel_type: + # ret[1]: per-atom property [numb_test, natoms, task_dim] + atom_prop = ret[1].reshape([numb_test, natoms, dp.task_dim]) + # atype for all frames + if mixed_type: + atype_frames = atype # [numb_test, natoms] + else: + atype_frames = np.tile(atype, (numb_test, 1)) # [numb_test, natoms] + sel_type_int = sel_type_raw.astype(int) + property = np.zeros([numb_test, dp.task_dim], dtype=atom_prop.dtype) + for i in range(numb_test): + t = sel_type_int[i] + mask = (atype_frames[i] == t) # [natoms] + count = max(mask.sum(), 1) + property[i] = atom_prop[i][mask].sum(axis=0) / count + else: + property = ret[0] property = property.reshape([numb_test, dp.task_dim]) diff --git a/deepmd/pt/loss/xas.py b/deepmd/pt/loss/xas.py index 6fc7a5c91a..336425a6e0 100644 --- a/deepmd/pt/loss/xas.py +++ b/deepmd/pt/loss/xas.py @@ -1,84 +1,54 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) +import logging +from typing import Any import torch +import torch.nn.functional as F -from deepmd.pt.loss.loss import ( - TaskLoss, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.utils.data import ( - DataRequirementItem, -) +from deepmd.pt.loss.loss import TaskLoss +from deepmd.pt.utils import env +from deepmd.utils.data import DataRequirementItem +log = logging.getLogger(__name__) -class XASLoss(TaskLoss): - """Loss for XAS spectrum fitting. 
- - Computes L2 loss on the reduced XAS spectrum (mean over absorbing atoms). - An optional CDF (cumulative) loss can be added to improve spectral shape. - The labels expected in the dataset are: +class XASLoss(TaskLoss): + """Loss for XAS spectrum fitting via property fitting + sel_type reduction. - * ``xas.npy`` : shape ``[nframes, numb_xas]`` — the mean XAS spectrum over - absorbing atoms, in a relative energy (ΔE) grid. + The model outputs per-atom property vectors (atom_xas). For each frame + this loss selects the atoms of type ``sel_type`` (read from ``sel_type.npy`` + in each training system) and takes their mean, then computes a loss against + the per-frame XAS label. Parameters ---------- - starter_learning_rate : float - Initial learning rate, used for prefactor scheduling. - numb_xas : int - Number of XAS energy grid points. - start_pref_xas : float - Starting prefactor for the XAS L2 loss. - limit_pref_xas : float - Limiting prefactor for the XAS L2 loss. - start_pref_cdf : float - Starting prefactor for the CDF L2 loss. - limit_pref_cdf : float - Limiting prefactor for the CDF L2 loss. - inference : bool - If True, output all losses regardless of prefactors. + task_dim : int + Output dimension of the fitting net (e.g. 102 = E_min + E_max + 100 pts). + var_name : str + Property name, must match ``property_name`` in the fitting config. + loss_func : str + One of ``smooth_mae``, ``mae``, ``mse``, ``rmse``. + metric : list[str] + Metrics to display during training. + beta : float + Beta parameter for smooth_l1 loss. 
""" def __init__( self, - starter_learning_rate: float, - numb_xas: int, - start_pref_xas: float = 1.0, - limit_pref_xas: float = 1.0, - start_pref_cdf: float = 0.0, - limit_pref_cdf: float = 0.0, - inference: bool = False, + task_dim: int, + var_name: str = "xas", + loss_func: str = "smooth_mae", + metric: list[str] = ["mae"], + beta: float = 1.0, **kwargs: Any, ) -> None: super().__init__() - self.starter_learning_rate = starter_learning_rate - self.numb_xas = numb_xas - self.inference = inference - - self.start_pref_xas = start_pref_xas - self.limit_pref_xas = limit_pref_xas - self.start_pref_cdf = start_pref_cdf - self.limit_pref_cdf = limit_pref_cdf - - assert ( - self.start_pref_xas >= 0.0 - and self.limit_pref_xas >= 0.0 - and self.start_pref_cdf >= 0.0 - and self.limit_pref_cdf >= 0.0 - ), "Loss prefactors must be non-negative" - - self.has_xas = (start_pref_xas != 0.0 and limit_pref_xas != 0.0) or inference - self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference - - assert self.has_xas or self.has_cdf, ( - "At least one of start_pref_xas or start_pref_cdf must be non-zero" - ) + self.task_dim = task_dim + self.var_name = var_name + self.loss_func = loss_func + self.metric = metric + self.beta = beta def forward( self, @@ -89,87 +59,67 @@ def forward( learning_rate: float = 0.0, mae: bool = False, ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]: - """Compute XAS loss. - - Parameters - ---------- - input_dict : dict - Model inputs. - model : torch.nn.Module - The model to evaluate. - label : dict - Label dict containing ``"xas"`` key. - natoms : int - Number of local atoms. - learning_rate : float - Current learning rate for prefactor scheduling. - mae : bool - Unused (kept for API compatibility). 
- - Returns - ------- - model_pred : dict - loss : torch.Tensor - more_loss : dict - """ model_pred = model(**input_dict) - coef = learning_rate / self.starter_learning_rate - pref_xas = ( - self.limit_pref_xas + (self.start_pref_xas - self.limit_pref_xas) * coef - ) - pref_cdf = ( - self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef + # per-atom outputs: [nf, nloc, task_dim] + atom_prop = model_pred[f"atom_{self.var_name}"] + atype = input_dict["atype"] # [nf, nloc] + + # sel_type from label: [nf, 1] float → [nf] int + sel_type = label["sel_type"][:, 0].long() + + # element-wise mean: for each frame average over atoms of sel_type + nf, nloc, td = atom_prop.shape + pred = torch.zeros( + nf, td, dtype=atom_prop.dtype, device=atom_prop.device ) + for i in range(nf): + t = int(sel_type[i].item()) + mask = (atype[i] == t).unsqueeze(-1) # [nloc, 1] + count = mask.sum().clamp(min=1) + pred[i] = (atom_prop[i] * mask).sum(dim=0) / count - loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] - more_loss: dict[str, torch.Tensor] = {} + label_xas = label[self.var_name] # [nf, task_dim] - if self.has_xas and "xas" in model_pred and "xas" in label: - find_xas = label.get("find_xas", 0.0) - pref_xas = pref_xas * find_xas - pred = model_pred["xas"].reshape([-1, self.numb_xas]) - ref = label["xas"].reshape([-1, self.numb_xas]) - diff = pred - ref - l2_loss = torch.mean(torch.square(diff)) - if not self.inference: - more_loss["l2_xas_loss"] = self.display_if_exist( - l2_loss.detach(), find_xas - ) - loss += pref_xas * l2_loss - more_loss["rmse_xas"] = self.display_if_exist( - l2_loss.sqrt().detach(), find_xas - ) - - if self.has_cdf and "xas" in model_pred and "xas" in label: - find_xas = label.get("find_xas", 0.0) - pref_cdf = pref_cdf * find_xas - pred_cdf = torch.cumsum( - model_pred["xas"].reshape([-1, self.numb_xas]), dim=-1 - ) - ref_cdf = torch.cumsum(label["xas"].reshape([-1, self.numb_xas]), dim=-1) - diff_cdf = pred_cdf - 
ref_cdf - l2_cdf_loss = torch.mean(torch.square(diff_cdf)) - if not self.inference: - more_loss["l2_cdf_loss"] = self.display_if_exist( - l2_cdf_loss.detach(), find_xas - ) - loss += pref_cdf * l2_cdf_loss - more_loss["rmse_cdf"] = self.display_if_exist( - l2_cdf_loss.sqrt().detach(), find_xas - ) + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] + if self.loss_func == "smooth_mae": + loss += F.smooth_l1_loss(pred, label_xas, reduction="sum", beta=self.beta) + elif self.loss_func == "mae": + loss += F.l1_loss(pred, label_xas, reduction="sum") + elif self.loss_func == "mse": + loss += F.mse_loss(pred, label_xas, reduction="sum") + elif self.loss_func == "rmse": + loss += torch.sqrt(F.mse_loss(pred, label_xas, reduction="mean")) + else: + raise RuntimeError(f"Unknown loss function: {self.loss_func}") + more_loss: dict[str, torch.Tensor] = {} + if "mae" in self.metric: + more_loss["mae"] = F.l1_loss(pred, label_xas, reduction="mean").detach() + if "rmse" in self.metric: + more_loss["rmse"] = torch.sqrt( + F.mse_loss(pred, label_xas, reduction="mean") + ).detach() + + model_pred[self.var_name] = pred return model_pred, loss, more_loss @property def label_requirement(self) -> list[DataRequirementItem]: - """Return data label requirements for XAS training.""" + """Declare required data files: xas label + sel_type.""" return [ DataRequirementItem( - "xas", - ndof=self.numb_xas, + self.var_name, + ndof=self.task_dim, + atomic=False, + must=True, + high_prec=True, + ), + DataRequirementItem( + "sel_type", + ndof=1, atomic=False, - must=False, + must=True, high_prec=False, - ) + ), ] diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index 1270d5f720..4da9bf781b 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -42,9 +42,6 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) -from .xas_atomic_model import ( - DPXASAtomicModel, -) 
__all__ = [ "BaseAtomicModel", @@ -54,7 +51,6 @@ "DPEnergyAtomicModel", "DPPolarAtomicModel", "DPPropertyAtomicModel", - "DPXASAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", diff --git a/deepmd/pt/model/atomic_model/xas_atomic_model.py b/deepmd/pt/model/atomic_model/xas_atomic_model.py deleted file mode 100644 index 1a8c38e0cc..0000000000 --- a/deepmd/pt/model/atomic_model/xas_atomic_model.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - -from deepmd.pt.model.task.xas import ( - XASFittingNet, -) - -from .dp_atomic_model import ( - DPAtomicModel, -) - - -class DPXASAtomicModel(DPAtomicModel): - """PyTorch atomic model for XAS spectrum fitting. - - Automatically excludes all non-absorbing atom types so that - the intensive mean reduction computes the mean XAS over absorbing - atoms only. - - Parameters - ---------- - descriptor : Any - fitting : Any - Must be an instance of XASFittingNet. - type_map : Any - Mapping from type index to element symbol. - absorbing_type : str - Element symbol of the absorbing atom type (e.g. "Fe"). - **kwargs - Passed to DPAtomicModel. 
- """ - - def __init__( - self, - descriptor: Any, - fitting: Any, - type_map: Any, - absorbing_type: str, - **kwargs: Any, - ) -> None: - if not isinstance(fitting, XASFittingNet): - raise TypeError( - "fitting must be an instance of XASFittingNet for DPXASAtomicModel" - ) - if absorbing_type not in type_map: - raise ValueError( - f"absorbing_type '{absorbing_type}' not found in type_map {type_map}" - ) - self.absorbing_type = absorbing_type - absorbing_idx = type_map.index(absorbing_type) - atom_exclude_types = [i for i in range(len(type_map)) if i != absorbing_idx] - kwargs["atom_exclude_types"] = atom_exclude_types - super().__init__(descriptor, fitting, type_map, **kwargs) - - def get_intensive(self) -> bool: - """XAS is an intensive property (mean over absorbing atoms).""" - return True diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 9bb2c76701..06f411d007 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -73,10 +73,6 @@ SpinEnergyModel, SpinModel, ) -from .xas_model import ( - XASModel, -) - def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: if "type_embedding" in model_params: @@ -269,8 +265,6 @@ def get_standard_model(model_params: dict) -> BaseModel: modelcls = PolarModel elif fitting_net_type == "dos": modelcls = DOSModel - elif fitting_net_type == "xas": - modelcls = XASModel elif fitting_net_type in ["ener", "direct_force_ener"]: modelcls = EnergyModel elif fitting_net_type == "property": @@ -287,8 +281,6 @@ def get_standard_model(model_params: dict) -> BaseModel: "preset_out_bias": preset_out_bias, "data_stat_protect": data_stat_protect, } - if fitting_net_type == "xas": - model_kwargs["absorbing_type"] = model_params["absorbing_type"] model = modelcls(**model_kwargs) if model_params.get("hessian_mode"): model.enable_hessian() @@ -323,7 +315,6 @@ def get_model(model_params: dict) -> Any: "PolarModel", "SpinEnergyModel", "SpinModel", - 
"XASModel", "get_model", "make_hessian_model", "make_model", diff --git a/deepmd/pt/model/model/xas_model.py b/deepmd/pt/model/model/xas_model.py deleted file mode 100644 index 45b568d09e..0000000000 --- a/deepmd/pt/model/model/xas_model.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) - -import torch - -from deepmd.pt.model.atomic_model import ( - DPXASAtomicModel, -) -from deepmd.pt.model.model.model import ( - BaseModel, -) - -from .dp_model import ( - DPModelCommon, -) -from .make_model import ( - make_model, -) - -DPXASModel_ = make_model(DPXASAtomicModel) - - -@BaseModel.register("xas") -class XASModel(DPModelCommon, DPXASModel_): - model_type = "xas" - - def __init__( - self, - *args: Any, - **kwargs: Any, - ) -> None: - DPModelCommon.__init__(self) - DPXASModel_.__init__(self, *args, **kwargs) - - def translated_output_def(self) -> dict[str, Any]: - out_def_data = self.model_output_def().get_data() - output_def = { - "atom_xas": out_def_data["xas"], - "xas": out_def_data["xas_redu"], - } - if "mask" in out_def_data: - output_def["mask"] = out_def_data["mask"] - return output_def - - def forward( - self, - coord: torch.Tensor, - atype: torch.Tensor, - box: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - do_atomic_virial: bool = False, - ) -> dict[str, torch.Tensor]: - model_ret = self.forward_common( - coord, - atype, - box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) - if self.get_fitting_net() is not None: - model_predict = {} - model_predict["atom_xas"] = model_ret["xas"] - model_predict["xas"] = model_ret["xas_redu"] - if "mask" in model_ret: - model_predict["mask"] = model_ret["mask"] - else: - model_predict = model_ret - return model_predict - - @torch.jit.export - def get_numb_xas(self) -> int: - """Get the number of XAS grid points.""" - return self.get_fitting_net().dim_out - - @torch.jit.export - 
def forward_lower( - self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - do_atomic_virial: bool = False, - comm_dict: dict[str, torch.Tensor] | None = None, - ) -> dict[str, torch.Tensor]: - model_ret = self.forward_common_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - comm_dict=comm_dict, - extra_nlist_sort=self.need_sorted_nlist_for_lower(), - ) - if self.get_fitting_net() is not None: - model_predict = {} - model_predict["atom_xas"] = model_ret["xas"] - model_predict["xas"] = model_ret["xas_redu"] - if "mask" in model_ret: - model_predict["mask"] = model_ret["mask"] - else: - model_predict = model_ret - return model_predict diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 296935dc9e..37ffec2725 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -27,9 +27,6 @@ from .type_predict import ( TypePredictNet, ) -from .xas import ( - XASFittingNet, -) __all__ = [ "BaseFitting", @@ -42,5 +39,4 @@ "PolarFittingNet", "PropertyFittingNet", "TypePredictNet", - "XASFittingNet", ] diff --git a/deepmd/pt/model/task/xas.py b/deepmd/pt/model/task/xas.py deleted file mode 100644 index c8e9d752ea..0000000000 --- a/deepmd/pt/model/task/xas.py +++ /dev/null @@ -1,167 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import logging - -import torch - -from deepmd.dpmodel import ( - FittingOutputDef, - OutputVariableDef, -) -from deepmd.pt.model.task.ener import ( - InvarFitting, -) -from deepmd.pt.model.task.fitting import ( - Fitting, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.pt.utils.env import ( - DEFAULT_PRECISION, -) -from deepmd.pt.utils.utils import ( - to_numpy_array, -) -from deepmd.utils.version import ( - check_version_compatibility, -) - 
-dtype = env.GLOBAL_PT_FLOAT_PRECISION -device = env.DEVICE - -log = logging.getLogger(__name__) - - -@Fitting.register("xas") -class XASFittingNet(InvarFitting): - """PyTorch fitting network for XAS spectra. - - Parameters - ---------- - ntypes : int - Number of atom types. - dim_descrpt : int - Dimension of the descriptor. - numb_xas : int - Number of XAS energy grid points. - neuron : list[int] - Hidden layer sizes. - resnet_dt : bool - Whether to use ResNet time step. - numb_fparam : int - Dimension of frame parameters (e.g. edge type encoding). - numb_aparam : int - Dimension of atomic parameters. - dim_case_embd : int - Dimension of case embedding. - rcond : float or None - Cutoff for small singular values in bias init. - bias_xas : torch.Tensor or None - Initial bias, shape (ntypes, numb_xas). - trainable : bool or list[bool] - Whether parameters are trainable. - seed : int, list[int], or None - Random seed. - activation_function : str - Activation function. - precision : str - Float precision. - exclude_types : list[int] - Atom types to exclude (set by XASAtomicModel automatically). - mixed_types : bool - Whether to use a shared network across types. - type_map : list[str] or None - Mapping from type index to element symbol. - default_fparam : list or None - Default frame parameter values. 
- """ - - def __init__( - self, - ntypes: int, - dim_descrpt: int, - numb_xas: int = 500, - neuron: list[int] = [128, 128, 128], - resnet_dt: bool = True, - numb_fparam: int = 0, - numb_aparam: int = 0, - dim_case_embd: int = 0, - rcond: float | None = None, - bias_xas: torch.Tensor | None = None, - trainable: bool | list[bool] = True, - seed: int | list[int] | None = None, - activation_function: str = "tanh", - precision: str = DEFAULT_PRECISION, - exclude_types: list[int] = [], - mixed_types: bool = False, - type_map: list[str] | None = None, - default_fparam: list | None = None, - ) -> None: - if bias_xas is not None: - self.bias_xas = bias_xas - else: - self.bias_xas = torch.zeros( - (ntypes, numb_xas), dtype=dtype, device=env.DEVICE - ) - super().__init__( - var_name="xas", - ntypes=ntypes, - dim_descrpt=dim_descrpt, - dim_out=numb_xas, - neuron=neuron, - bias_atom_e=bias_xas, - resnet_dt=resnet_dt, - numb_fparam=numb_fparam, - numb_aparam=numb_aparam, - dim_case_embd=dim_case_embd, - activation_function=activation_function, - precision=precision, - mixed_types=mixed_types, - rcond=rcond, - seed=seed, - exclude_types=exclude_types, - trainable=trainable, - type_map=type_map, - default_fparam=default_fparam, - ) - - def output_def(self) -> FittingOutputDef: - return FittingOutputDef( - [ - OutputVariableDef( - self.var_name, - [self.dim_out], - reducible=True, - intensive=True, - r_differentiable=False, - c_differentiable=False, - ), - ] - ) - - @classmethod - def deserialize(cls, data: dict) -> "XASFittingNet": - data = data.copy() - check_version_compatibility(data.pop("@version", 1), 4, 1) - data.pop("@class", None) - data.pop("var_name", None) - data.pop("tot_ener_zero", None) - data.pop("layer_name", None) - data.pop("use_aparam_as_mask", None) - data.pop("spin", None) - data.pop("atom_ener", None) - data["numb_xas"] = data.pop("dim_out") - return super().deserialize(data) - - def serialize(self) -> dict: - """Serialize the fitting to dict.""" - dd = { - 
**InvarFitting.serialize(self), - "type": "xas", - "dim_out": self.dim_out, - } - dd["@variables"]["bias_atom_e"] = to_numpy_array(self.bias_atom_e) - return dd - - # make jit happy with torch 2.0.0 - exclude_types: list[int] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index d999b276f9..034bcdc015 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1732,10 +1732,6 @@ def get_loss( loss_params["starter_learning_rate"] = start_lr loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size return DOSLoss(**loss_params) - elif loss_type == "xas": - loss_params["starter_learning_rate"] = start_lr - loss_params["numb_xas"] = _model.model_output_def()["xas"].output_size - return XASLoss(**loss_params) elif loss_type == "ener_spin": loss_params["starter_learning_rate"] = start_lr return EnergySpinLoss(**loss_params) @@ -1761,6 +1757,10 @@ def get_loss( loss_params["var_name"] = var_name loss_params["intensive"] = intensive return PropertyLoss(**loss_params) + elif loss_type == "xas": + loss_params["task_dim"] = _model.get_task_dim() + loss_params["var_name"] = _model.get_var_name() + return XASLoss(**loss_params) else: loss_params["starter_learning_rate"] = start_lr return TaskLoss.get_class_by_type(loss_type).get_loss(loss_params) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 25df224cf5..76fec40aff 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1912,64 +1912,6 @@ def fitting_dos() -> list[Argument]: ] -@fitting_args_plugin.register("xas", doc=doc_only_pt_supported) -def fitting_xas() -> list[Argument]: - doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams. Can be used to encode edge type information." - doc_numb_aparam = "The dimension of the atomic parameter." - doc_default_fparam = "The default frame parameter value." 
- doc_dim_case_embd = "The dimension of the case embedding." - doc_neuron = "The number of neurons in each hidden layer of the fitting net." - doc_activation_function = f"The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." - doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." - doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection.' - doc_trainable = "Whether the parameters in the fitting net are trainable." - doc_rcond = "The condition number for the initial bias fitting." - doc_seed = "Random seed for parameter initialization." - doc_numb_xas = "The number of grid points on the XAS energy axis (ΔE space)." - - return [ - Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), - Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), - Argument( - "default_fparam", - list[float], - optional=True, - default=None, - doc=doc_default_fparam, - ), - Argument( - "dim_case_embd", - int, - optional=True, - default=0, - doc=doc_dim_case_embd, - ), - Argument( - "neuron", list[int], optional=True, default=[120, 120, 120], doc=doc_neuron - ), - Argument( - "activation_function", - str, - optional=True, - default="tanh", - doc=doc_activation_function, - ), - Argument("precision", str, optional=True, default="float64", doc=doc_precision), - Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), - Argument( - "trainable", - [list[bool], bool], - optional=True, - default=True, - doc=doc_trainable, - ), - Argument( - "rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond - ), - Argument("seed", [int, None], optional=True, doc=doc_seed), - Argument("numb_xas", int, optional=True, default=500, doc=doc_numb_xas), - ] - @fitting_args_plugin.register("property", doc=doc_only_pt_supported) def fitting_property() 
-> list[Argument]: @@ -2290,7 +2232,6 @@ def model_args(exclude_hybrid: bool = False) -> list[Argument]: doc_compress_config = "Model compression configurations" doc_spin = "The settings for systems with spin." doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types" - doc_absorbing_type = "The element symbol of the absorbing atom type for XAS fitting (e.g. 'Fe'). Only used when fitting_net type is 'xas'." doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other." doc_preset_out_bias = "The preset bias of the atomic output. Note that the set_davg_zero should be set to true. The bias is provided as a dict. Taking the energy model that has three atom types for example, the `preset_out_bias` may be given as `{ 'energy': [null, 0., 1.] }`. In this case the energy bias of type 1 and 2 are set to 0. and 1., respectively. A dipole model with two atom types may set `preset_out_bias` as `{ 'dipole': [null, [0., 1., 2.]] }`" doc_finetune_head = ( @@ -2355,13 +2296,6 @@ def model_args(exclude_hybrid: bool = False) -> list[Argument]: default=[], doc=doc_only_pt_supported + doc_atom_exclude_types, ), - Argument( - "absorbing_type", - str, - optional=True, - default=None, - doc=doc_only_pt_supported + doc_absorbing_type, - ), Argument( "preset_out_bias", dict[str, list[float | list[float] | None]], @@ -3532,41 +3466,16 @@ def loss_dos() -> list[Argument]: ] + @loss_args_plugin.register("xas", doc=doc_only_pt_supported) def loss_xas() -> list[Argument]: - doc_start_pref_xas = start_pref("XAS spectrum (mean over absorbing atoms)") - doc_limit_pref_xas = limit_pref("XAS spectrum (mean over absorbing atoms)") - doc_start_pref_cdf = start_pref("Cumulative Distribution Function of XAS") - doc_limit_pref_cdf = limit_pref("Cumulative Distribution Function of XAS") + doc_loss_func = "The loss function to minimize: 'smooth_mae' (default), 'mae', 'mse', 'rmse'." 
+ doc_metric = "Metrics to display during training. Supported: 'mae', 'rmse'." + doc_beta = "Beta parameter for smooth_l1 loss." return [ - Argument( - "start_pref_xas", - [float, int], - optional=True, - default=1.0, - doc=doc_start_pref_xas, - ), - Argument( - "limit_pref_xas", - [float, int], - optional=True, - default=1.0, - doc=doc_limit_pref_xas, - ), - Argument( - "start_pref_cdf", - [float, int], - optional=True, - default=0.0, - doc=doc_start_pref_cdf, - ), - Argument( - "limit_pref_cdf", - [float, int], - optional=True, - default=0.0, - doc=doc_limit_pref_cdf, - ), + Argument("loss_func", str, optional=True, default="smooth_mae", doc=doc_loss_func), + Argument("metric", list, optional=True, default=["mae"], doc=doc_metric), + Argument("beta", float, optional=True, default=1.0, doc=doc_beta), ] From 9352c4f6815f12b6176353f50d04adae30febe1b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:38:21 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/entrypoints/test.py | 6 ++++-- deepmd/pt/loss/xas.py | 20 +++++++++++++------- deepmd/pt/model/model/__init__.py | 1 + deepmd/utils/argcheck.py | 10 ++++++---- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 781123cb79..9bcdf58d8c 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -895,7 +895,9 @@ def test_property( data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False) # sel_type: optional per-frame type index for element-wise mean reduction (XAS) - data.add("sel_type", 1, atomic=False, must=False, high_prec=False, default=float(-1)) + data.add( + "sel_type", 1, atomic=False, must=False, high_prec=False, default=float(-1) + ) test_data = data.get_test() mixed_type = data.mixed_type @@ -949,7 +951,7 @@ def test_property( 
property = np.zeros([numb_test, dp.task_dim], dtype=atom_prop.dtype) for i in range(numb_test): t = sel_type_int[i] - mask = (atype_frames[i] == t) # [natoms] + mask = atype_frames[i] == t # [natoms] count = max(mask.sum(), 1) property[i] = atom_prop[i][mask].sum(axis=0) / count else: diff --git a/deepmd/pt/loss/xas.py b/deepmd/pt/loss/xas.py index 336425a6e0..15e2089ab8 100644 --- a/deepmd/pt/loss/xas.py +++ b/deepmd/pt/loss/xas.py @@ -1,13 +1,21 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -from typing import Any +from typing import ( + Any, +) import torch import torch.nn.functional as F -from deepmd.pt.loss.loss import TaskLoss -from deepmd.pt.utils import env -from deepmd.utils.data import DataRequirementItem +from deepmd.pt.loss.loss import ( + TaskLoss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.utils.data import ( + DataRequirementItem, +) log = logging.getLogger(__name__) @@ -70,9 +78,7 @@ def forward( # element-wise mean: for each frame average over atoms of sel_type nf, nloc, td = atom_prop.shape - pred = torch.zeros( - nf, td, dtype=atom_prop.dtype, device=atom_prop.device - ) + pred = torch.zeros(nf, td, dtype=atom_prop.dtype, device=atom_prop.device) for i in range(nf): t = int(sel_type[i].item()) mask = (atype[i] == t).unsqueeze(-1) # [nloc, 1] diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 06f411d007..1aa3732580 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -74,6 +74,7 @@ SpinModel, ) + def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: if "type_embedding" in model_params: raise ValueError( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 76fec40aff..674f565ee8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1912,7 +1912,6 @@ def fitting_dos() -> list[Argument]: ] - @fitting_args_plugin.register("property", doc=doc_only_pt_supported) def fitting_property() -> 
list[Argument]: doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." @@ -3466,14 +3465,17 @@ def loss_dos() -> list[Argument]: ] - @loss_args_plugin.register("xas", doc=doc_only_pt_supported) def loss_xas() -> list[Argument]: - doc_loss_func = "The loss function to minimize: 'smooth_mae' (default), 'mae', 'mse', 'rmse'." + doc_loss_func = ( + "The loss function to minimize: 'smooth_mae' (default), 'mae', 'mse', 'rmse'." + ) doc_metric = "Metrics to display during training. Supported: 'mae', 'rmse'." doc_beta = "Beta parameter for smooth_l1 loss." return [ - Argument("loss_func", str, optional=True, default="smooth_mae", doc=doc_loss_func), + Argument( + "loss_func", str, optional=True, default="smooth_mae", doc=doc_loss_func + ), Argument("metric", list, optional=True, default=["mae"], doc=doc_metric), Argument("beta", float, optional=True, default=1.0, doc=doc_beta), ] From 9bc38d7c5c9bce39d1c02c274071f765758ba389 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:45:32 +0800 Subject: [PATCH 5/9] feat: Add X-ray Absorption Spectroscopy (XAS) training examples --- examples/xas/train/README.md | 108 +++++++++++++++++++++++ examples/xas/train/gen_data.py | 154 +++++++++++++++++++++++++++++++++ examples/xas/train/input.json | 60 +++++++++++++ 3 files changed, 322 insertions(+) create mode 100644 examples/xas/train/README.md create mode 100644 examples/xas/train/gen_data.py create mode 100644 examples/xas/train/input.json diff --git a/examples/xas/train/README.md b/examples/xas/train/README.md new file mode 100644 index 0000000000..1ffb4d6ec3 --- /dev/null +++ b/examples/xas/train/README.md @@ -0,0 +1,108 @@ +# XAS Spectrum Fitting with DeePMD-kit + +This example shows how to train a model to predict X-ray absorption spectra (XAS) +from atomic structure using DeePMD-kit's `property` fitting net. 
+ +## Concept + +- The model predicts a 102-dimensional output per atom: `[E_min, E_max, I_0, …, I_99]` +- During training, per-atom outputs are averaged over atoms of the **absorbing element** + (identified by `sel_type.npy` in each training system) +- The edge type (K, L1, L2, …) is provided as a frame-level parameter `fparam` +- One training system per `(element, edge)` pair + +## Quick Start + +**1. Generate example training data** + +```bash +python gen_data.py +``` + +This creates `data/Fe_K/` and `data/O_K/` with 50 frames each. + +**2. Train the model** + +```bash +dp --pt train input.json +``` + +**3. Freeze the model** + +```bash +dp --pt freeze -o model.pth +``` + +**4. Test the model** + +```bash +dp --pt test -m model.pth -s data/Fe_K -n 10 +dp --pt test -m model.pth -s data/O_K -n 10 +``` + +`dp test` automatically detects `sel_type.npy` and applies element-wise averaging +before computing the error metrics. + +## Data Format + +Each system directory must contain: + +``` +data/Fe_K/ +├── type.raw # atom type indices, one per line (int) +├── type_map.raw # element symbols, one per line +└── set.000/ + ├── coord.npy # [nframes, natoms*3] Cartesian coordinates (Å) + ├── box.npy # [nframes, 9] cell vectors (Å), row-major + ├── fparam.npy # [nframes, nfparam] edge one-hot encoding + ├── sel_type.npy # [nframes, 1] absorbing element type index (float64) + └── xas.npy # [nframes, 102] XAS label: [E_min, E_max, I_0..I_99] +``` + +### `sel_type.npy` + +The type index of the absorbing element, stored as float64, constant per system.
+ +``` +Fe is type 0 → sel_type.npy filled with 0.0 +O is type 1 → sel_type.npy filled with 1.0 +``` + +### `xas.npy` label layout (`task_dim = 102`) + +| Column | Meaning | +|-----------|---------------------------------------------| +| `xas[i,0]` | `E_min` (eV) — lower bound of energy grid | +| `xas[i,1]` | `E_max` (eV) — upper bound of energy grid | +| `xas[i,2:]`| `I[0..99]` — 100 intensity values on `linspace(E_min, E_max, 100)` | + +### `fparam.npy` edge encoding (`nfparam = 3`) + +| Edge | Encoding | +|------|-----------| +| K | `[1,0,0]` | +| L1 | `[0,1,0]` | +| L2 | `[0,0,1]` | + +Extend with more entries for additional edges and set `numb_fparam` accordingly. + +## Input Parameters + +Key fields in `input.json`: + +| Parameter | Description | +|-----------|-------------| +| `fitting_net.type` | Must be `"property"` | +| `fitting_net.task_dim` | `102` (2 energy bounds + 100 intensities) | +| `fitting_net.intensive` | `true` — per-atom outputs are **averaged**, not summed | +| `fitting_net.numb_fparam` | Number of edge-type features (3 for K/L1/L2) | +| `loss.type` | `"xas"` — uses `sel_type.npy` for element-selective averaging | +| `loss.loss_func` | `"smooth_mae"` (recommended) or `"mse"` | + +## Extending to More Elements / Edges + +- Add a new system directory per `(element, edge)` pair +- Set `sel_type.npy` to the type index of the absorbing element in that system +- Set `fparam.npy` to the one-hot vector for the corresponding edge +- List all system paths under `training.training_data.systems` +- Increase `numb_fparam` if adding new edge types diff --git a/examples/xas/train/gen_data.py b/examples/xas/train/gen_data.py new file mode 100644 index 0000000000..22c3864f7d --- /dev/null +++ b/examples/xas/train/gen_data.py @@ -0,0 +1,154 @@ +"""Generate example XAS training data for a Fe-O system. + +This script shows the required data format for XAS spectrum fitting. 
+ +Data layout +----------- +One training system per (element, edge) pair: + + data/Fe_K/ — Fe K-edge XAS + data/O_K/ — O K-edge XAS + +Each system directory contains: + + type.raw — atom type indices (int, one per line) + type_map.raw — element symbols, one per line + set.000/ + coord.npy — [nframes, natoms*3] Cartesian coordinates (Å) + box.npy — [nframes, 9] cell vectors (Å), row-major + fparam.npy — [nframes, nfparam] edge encoding (one-hot or continuous) + sel_type.npy — [nframes, 1] type index of absorbing element (float) + xas.npy — [nframes, task_dim] XAS label: [E_min, E_max, I_0..I_99] + +Label format (task_dim = 102) +------------------------------ + xas[i, 0] = E_min (eV) — lower bound of the energy grid for frame i + xas[i, 1] = E_max (eV) — upper bound of the energy grid for frame i + xas[i, 2:] = I (arb. units) — 100 equally-spaced intensity values + on the grid linspace(E_min, E_max, 100) + +fparam encoding (nfparam = 3 for K/L1/L2 edges) +------------------------------------------------- + K-edge → [1, 0, 0] + L1-edge → [0, 1, 0] + L2-edge → [0, 0, 1] + (extend as needed; use numb_fparam in input.json accordingly) + +sel_type.npy +------------ + Integer type index of the absorbing element, stored as float64. + All frames in a system must share the same value (it is constant per system). 
+ Example: Fe is type 0 → sel_type.npy filled with 0.0 + O is type 1 → sel_type.npy filled with 1.0 +""" + +import os +import numpy as np + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- +nframes = 50 # number of frames per system +numb_pts = 100 # energy grid points +task_dim = numb_pts + 2 # E_min + E_max + 100 intensities +nfparam = 3 # K / L1 / L2 one-hot +natoms = 8 # 4 Fe (type 0) + 4 O (type 1) +box_size = 4.0 # Å + +rng = np.random.default_rng(42) + +# Equilibrium positions: simple rock-salt-like arrangement +base_pos = np.array([ + [0.0, 0.0, 0.0], [2.0, 2.0, 0.0], [2.0, 0.0, 2.0], [0.0, 2.0, 2.0], # Fe + [1.0, 1.0, 1.0], [3.0, 3.0, 1.0], [3.0, 1.0, 3.0], [1.0, 3.0, 3.0], # O +]) + +coords = base_pos[None] + rng.normal(0, 0.1, (nframes, natoms, 3)) +box = np.tile(np.diag([box_size] * 3).reshape(9), (nframes, 1)) + +type_arr = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int) # Fe Fe Fe Fe O O O O + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def gaussian_spectrum(peak_eV, e_min, e_max, npts=100, width_frac=0.10): + grid = np.linspace(e_min, e_max, npts) + width = (e_max - e_min) * width_frac + return np.exp(-0.5 * ((grid - peak_eV) / width) ** 2) + + +def write_system( + path: str, + sel_type_idx: int, + atom_slice, # slice object selecting absorbing atoms + e_min: float, + e_max: float, + peak_center: float, + peak_shift_scale: float, + fparam_vec, # 1-D array of length nfparam (one-hot edge encoding) +): + os.makedirs(f"{path}/set.000", exist_ok=True) + + # --- structure --- + np.savetxt(f"{path}/type.raw", type_arr, fmt="%d") + with open(f"{path}/type_map.raw", "w") as f: + f.write("Fe\nO\n") + np.save(f"{path}/set.000/box.npy", box.astype(np.float64)) + np.save(f"{path}/set.000/coord.npy", + 
coords.reshape(nframes, natoms * 3).astype(np.float64)) + + # --- fparam: same edge for all frames --- + fparam = np.tile(fparam_vec, (nframes, 1)).astype(np.float64) + np.save(f"{path}/set.000/fparam.npy", fparam) + + # --- sel_type: constant per system --- + sel = np.full((nframes, 1), float(sel_type_idx), dtype=np.float64) + np.save(f"{path}/set.000/sel_type.npy", sel) + + # --- xas labels --- + labels = np.zeros((nframes, task_dim), dtype=np.float64) + for i in range(nframes): + # peak position shifts slightly with mean x-coordinate of absorbing atoms + mean_x = coords[i, atom_slice, 0].mean() + peak = peak_center + mean_x * peak_shift_scale + spectrum = gaussian_spectrum(peak, e_min, e_max) + labels[i, 0] = e_min + labels[i, 1] = e_max + labels[i, 2:] = spectrum + np.save(f"{path}/set.000/xas.npy", labels) + + print(f" {path}:") + print(f" sel_type = {sel_type_idx} fparam = {fparam_vec.tolist()}") + print(f" xas.npy shape = {labels.shape}") + + +# --------------------------------------------------------------------------- +# Generate Fe K-edge and O K-edge systems +# --------------------------------------------------------------------------- +print("Generating example XAS training data...") + +write_system( + path = "data/Fe_K", + sel_type_idx = 0, # Fe is type 0 + atom_slice = slice(0, 4), # first 4 atoms are Fe + e_min = 7100.0, # Fe K-edge region (eV) + e_max = 7250.0, + peak_center = 7112.0, # Fe K-edge energy + peak_shift_scale = 2.0, # chemical shift ∝ local environment + fparam_vec = np.array([1.0, 0.0, 0.0]), # K-edge one-hot +) + +write_system( + path = "data/O_K", + sel_type_idx = 1, # O is type 1 + atom_slice = slice(4, 8), # last 4 atoms are O + e_min = 525.0, # O K-edge region (eV) + e_max = 560.0, + peak_center = 535.0, # O K-edge energy + peak_shift_scale = 0.5, + fparam_vec = np.array([1.0, 0.0, 0.0]), # also K-edge +) + +print(f"\nDone. 
{nframes} frames per system, task_dim={task_dim}, nfparam={nfparam}") +print("Data written to ./data/Fe_K/ and ./data/O_K/") diff --git a/examples/xas/train/input.json b/examples/xas/train/input.json new file mode 100644 index 0000000000..5b4e063461 --- /dev/null +++ b/examples/xas/train/input.json @@ -0,0 +1,60 @@ +{ + "_comment": "XAS spectrum fitting example — Fe-O system, Fe K-edge + O K-edge", + + "model": { + "type_map": ["Fe", "O"], + "descriptor": { + "type": "se_e2_a", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": [40, 40], + "neuron": [25, 50, 100], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1 + }, + "fitting_net": { + "_comment": "property fitting with task_dim=102: [E_min, E_max, I_0, ..., I_99]", + "type": "property", + "property_name": "xas", + "task_dim": 102, + "_comment_intensive": "intensive=true: per-atom outputs are averaged (not summed)", + "intensive": true, + "_comment_fparam": "fparam encodes edge type: 1-hot vector, e.g. [1,0,0]=K, [0,1,0]=L1, [0,0,1]=L2", + "numb_fparam": 3, + "neuron": [128, 128, 128], + "resnet_dt": true, + "seed": 1 + } + }, + "loss": { + "_comment": "xas loss: reads sel_type.npy to select which element to reduce over", + "type": "xas", + "loss_func": "smooth_mae", + "metric": ["mae", "rmse"] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 1e-8, + "decay_rate": 0.95 + }, + "training": { + "training_data": { + "_comment": "one system per (element, edge) pair", + "systems": [ + "./data/Fe_K/", + "./data/O_K/" + ], + "batch_size": "auto" + }, + "numb_steps": 200000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1000, + "save_freq": 10000, + "save_ckpt": "model.ckpt", + "stat_file": "stat_files" + } +} From c8a40051f74d31d2e27d86da8ae6d355c45f0ecb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:46:49 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/xas/train/README.md | 28 ++++---- examples/xas/train/gen_data.py | 86 ++++++++++++---------- examples/xas/train/input.json | 127 +++++++++++++++++++-------------- 3 files changed, 135 insertions(+), 106 deletions(-) diff --git a/examples/xas/train/README.md b/examples/xas/train/README.md index 1ffb4d6ec3..a4a96c757a 100644 --- a/examples/xas/train/README.md +++ b/examples/xas/train/README.md @@ -70,16 +70,16 @@ O is type 1 → sel_type.npy filled with 1.0 ### `xas.npy` label layout (`task_dim = 102`) -| Column | Meaning | -|-----------|---------------------------------------------| -| `xas[i,0]` | `E_min` (eV) — lower bound of energy grid | -| `xas[i,1]` | `E_max` (eV) — upper bound of energy grid | -| `xas[i,2:]`| `I[0..99]` — 100 intensity values on `linspace(E_min, E_max, 100)` | +| Column | Meaning | +| ----------- | ------------------------------------------------------------------ | +| `xas[i,0]` | `E_min` (eV) — lower bound of energy grid | +| `xas[i,1]` | `E_max` (eV) — upper bound of energy grid | +| `xas[i,2:]` | `I[0..99]` — 100 intensity values on `linspace(E_min, E_max, 100)` | ### `fparam.npy` edge encoding (`nfparam = 3`) | Edge | Encoding | -|------|-----------| +| ---- | --------- | | K | `[1,0,0]` | | L1 | `[0,1,0]` | | L2 | `[0,0,1]` | @@ -90,14 +90,14 @@ Extend with more entries for additional edges and set `numb_fparam` accordingly. 
Key fields in `input.json`: -| Parameter | Description | -|-----------|-------------| -| `fitting_net.type` | Must be `"property"` | -| `fitting_net.task_dim` | `102` (2 energy bounds + 100 intensities) | -| `fitting_net.intensive` | `true` — per-atom outputs are **averaged**, not summed | -| `fitting_net.numb_fparam` | Number of edge-type features (3 for K/L1/L2) | -| `loss.type` | `"xas"` — uses `sel_type.npy` for element-selective averaging | -| `loss.loss_func` | `"smooth_mae"` (recommended) or `"mse"` | +| Parameter | Description | +| ------------------------- | ------------------------------------------------------------- | +| `fitting_net.type` | Must be `"property"` | +| `fitting_net.task_dim` | `102` (2 energy bounds + 100 intensities) | +| `fitting_net.intensive` | `true` — per-atom outputs are **averaged**, not summed | +| `fitting_net.numb_fparam` | Number of edge-type features (3 for K/L1/L2) | +| `loss.type` | `"xas"` — uses `sel_type.npy` for element-selective averaging | +| `loss.loss_func` | `"smooth_mae"` (recommended) or `"mse"` | ## Extending to More Elements / Edges diff --git a/examples/xas/train/gen_data.py b/examples/xas/train/gen_data.py index 22c3864f7d..82c81319b8 100644 --- a/examples/xas/train/gen_data.py +++ b/examples/xas/train/gen_data.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later """Generate example XAS training data for a Fe-O system. This script shows the required data format for XAS spectrum fitting. 
@@ -43,28 +44,37 @@ """ import os + import numpy as np # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- -nframes = 50 # number of frames per system -numb_pts = 100 # energy grid points -task_dim = numb_pts + 2 # E_min + E_max + 100 intensities -nfparam = 3 # K / L1 / L2 one-hot -natoms = 8 # 4 Fe (type 0) + 4 O (type 1) -box_size = 4.0 # Å +nframes = 50 # number of frames per system +numb_pts = 100 # energy grid points +task_dim = numb_pts + 2 # E_min + E_max + 100 intensities +nfparam = 3 # K / L1 / L2 one-hot +natoms = 8 # 4 Fe (type 0) + 4 O (type 1) +box_size = 4.0 # Å rng = np.random.default_rng(42) # Equilibrium positions: simple rock-salt-like arrangement -base_pos = np.array([ - [0.0, 0.0, 0.0], [2.0, 2.0, 0.0], [2.0, 0.0, 2.0], [0.0, 2.0, 2.0], # Fe - [1.0, 1.0, 1.0], [3.0, 3.0, 1.0], [3.0, 1.0, 3.0], [1.0, 3.0, 3.0], # O -]) +base_pos = np.array( + [ + [0.0, 0.0, 0.0], + [2.0, 2.0, 0.0], + [2.0, 0.0, 2.0], + [0.0, 2.0, 2.0], # Fe + [1.0, 1.0, 1.0], + [3.0, 3.0, 1.0], + [3.0, 1.0, 3.0], + [1.0, 3.0, 3.0], # O + ] +) coords = base_pos[None] + rng.normal(0, 0.1, (nframes, natoms, 3)) -box = np.tile(np.diag([box_size] * 3).reshape(9), (nframes, 1)) +box = np.tile(np.diag([box_size] * 3).reshape(9), (nframes, 1)) type_arr = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=int) # Fe Fe Fe Fe O O O O @@ -73,7 +83,7 @@ # Helpers # --------------------------------------------------------------------------- def gaussian_spectrum(peak_eV, e_min, e_max, npts=100, width_frac=0.10): - grid = np.linspace(e_min, e_max, npts) + grid = np.linspace(e_min, e_max, npts) width = (e_max - e_min) * width_frac return np.exp(-0.5 * ((grid - peak_eV) / width) ** 2) @@ -81,12 +91,12 @@ def gaussian_spectrum(peak_eV, e_min, e_max, npts=100, width_frac=0.10): def write_system( path: str, sel_type_idx: int, - atom_slice, # slice object selecting absorbing atoms + 
atom_slice, # slice object selecting absorbing atoms e_min: float, e_max: float, peak_center: float, peak_shift_scale: float, - fparam_vec, # 1-D array of length nfparam (one-hot edge encoding) + fparam_vec, # 1-D array of length nfparam (one-hot edge encoding) ): os.makedirs(f"{path}/set.000", exist_ok=True) @@ -94,9 +104,11 @@ def write_system( np.savetxt(f"{path}/type.raw", type_arr, fmt="%d") with open(f"{path}/type_map.raw", "w") as f: f.write("Fe\nO\n") - np.save(f"{path}/set.000/box.npy", box.astype(np.float64)) - np.save(f"{path}/set.000/coord.npy", - coords.reshape(nframes, natoms * 3).astype(np.float64)) + np.save(f"{path}/set.000/box.npy", box.astype(np.float64)) + np.save( + f"{path}/set.000/coord.npy", + coords.reshape(nframes, natoms * 3).astype(np.float64), + ) # --- fparam: same edge for all frames --- fparam = np.tile(fparam_vec, (nframes, 1)).astype(np.float64) @@ -110,11 +122,11 @@ def write_system( labels = np.zeros((nframes, task_dim), dtype=np.float64) for i in range(nframes): # peak position shifts slightly with mean x-coordinate of absorbing atoms - mean_x = coords[i, atom_slice, 0].mean() - peak = peak_center + mean_x * peak_shift_scale + mean_x = coords[i, atom_slice, 0].mean() + peak = peak_center + mean_x * peak_shift_scale spectrum = gaussian_spectrum(peak, e_min, e_max) - labels[i, 0] = e_min - labels[i, 1] = e_max + labels[i, 0] = e_min + labels[i, 1] = e_max labels[i, 2:] = spectrum np.save(f"{path}/set.000/xas.npy", labels) @@ -129,25 +141,25 @@ def write_system( print("Generating example XAS training data...") write_system( - path = "data/Fe_K", - sel_type_idx = 0, # Fe is type 0 - atom_slice = slice(0, 4), # first 4 atoms are Fe - e_min = 7100.0, # Fe K-edge region (eV) - e_max = 7250.0, - peak_center = 7112.0, # Fe K-edge energy - peak_shift_scale = 2.0, # chemical shift ∝ local environment - fparam_vec = np.array([1.0, 0.0, 0.0]), # K-edge one-hot + path="data/Fe_K", + sel_type_idx=0, # Fe is type 0 + atom_slice=slice(0, 4), # 
first 4 atoms are Fe + e_min=7100.0, # Fe K-edge region (eV) + e_max=7250.0, + peak_center=7112.0, # Fe K-edge energy + peak_shift_scale=2.0, # chemical shift ∝ local environment + fparam_vec=np.array([1.0, 0.0, 0.0]), # K-edge one-hot ) write_system( - path = "data/O_K", - sel_type_idx = 1, # O is type 1 - atom_slice = slice(4, 8), # last 4 atoms are O - e_min = 525.0, # O K-edge region (eV) - e_max = 560.0, - peak_center = 535.0, # O K-edge energy - peak_shift_scale = 0.5, - fparam_vec = np.array([1.0, 0.0, 0.0]), # also K-edge + path="data/O_K", + sel_type_idx=1, # O is type 1 + atom_slice=slice(4, 8), # last 4 atoms are O + e_min=525.0, # O K-edge region (eV) + e_max=560.0, + peak_center=535.0, # O K-edge energy + peak_shift_scale=0.5, + fparam_vec=np.array([1.0, 0.0, 0.0]), # also K-edge ) print(f"\nDone. {nframes} frames per system, task_dim={task_dim}, nfparam={nfparam}") diff --git a/examples/xas/train/input.json b/examples/xas/train/input.json index 5b4e063461..f58417b478 100644 --- a/examples/xas/train/input.json +++ b/examples/xas/train/input.json @@ -1,60 +1,77 @@ { - "_comment": "XAS spectrum fitting example — Fe-O system, Fe K-edge + O K-edge", + "_comment": "XAS spectrum fitting example — Fe-O system, Fe K-edge + O K-edge", - "model": { - "type_map": ["Fe", "O"], - "descriptor": { - "type": "se_e2_a", - "rcut": 6.0, - "rcut_smth": 0.5, - "sel": [40, 40], - "neuron": [25, 50, 100], - "resnet_dt": false, - "axis_neuron": 16, - "seed": 1 - }, - "fitting_net": { - "_comment": "property fitting with task_dim=102: [E_min, E_max, I_0, ..., I_99]", - "type": "property", - "property_name": "xas", - "task_dim": 102, - "_comment_intensive": "intensive=true: per-atom outputs are averaged (not summed)", - "intensive": true, - "_comment_fparam": "fparam encodes edge type: 1-hot vector, e.g. 
[1,0,0]=K, [0,1,0]=L1, [0,0,1]=L2", - "numb_fparam": 3, - "neuron": [128, 128, 128], - "resnet_dt": true, - "seed": 1 - } + "model": { + "type_map": [ + "Fe", + "O" + ], + "descriptor": { + "type": "se_e2_a", + "rcut": 6.0, + "rcut_smth": 0.5, + "sel": [ + 40, + 40 + ], + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1 }, - "loss": { - "_comment": "xas loss: reads sel_type.npy to select which element to reduce over", - "type": "xas", - "loss_func": "smooth_mae", - "metric": ["mae", "rmse"] - }, - "learning_rate": { - "type": "exp", - "decay_steps": 5000, - "start_lr": 0.001, - "stop_lr": 1e-8, - "decay_rate": 0.95 - }, - "training": { - "training_data": { - "_comment": "one system per (element, edge) pair", - "systems": [ - "./data/Fe_K/", - "./data/O_K/" - ], - "batch_size": "auto" - }, - "numb_steps": 200000, - "seed": 10, - "disp_file": "lcurve.out", - "disp_freq": 1000, - "save_freq": 10000, - "save_ckpt": "model.ckpt", - "stat_file": "stat_files" + "fitting_net": { + "_comment": "property fitting with task_dim=102: [E_min, E_max, I_0, ..., I_99]", + "type": "property", + "property_name": "xas", + "task_dim": 102, + "_comment_intensive": "intensive=true: per-atom outputs are averaged (not summed)", + "intensive": true, + "_comment_fparam": "fparam encodes edge type: 1-hot vector, e.g. 
[1,0,0]=K, [0,1,0]=L1, [0,0,1]=L2", + "numb_fparam": 3, + "neuron": [ + 128, + 128, + 128 + ], + "resnet_dt": true, + "seed": 1 } + }, + "loss": { + "_comment": "xas loss: reads sel_type.npy to select which element to reduce over", + "type": "xas", + "loss_func": "smooth_mae", + "metric": [ + "mae", + "rmse" + ] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 1e-8, + "decay_rate": 0.95 + }, + "training": { + "training_data": { + "_comment": "one system per (element, edge) pair", + "systems": [ + "./data/Fe_K/", + "./data/O_K/" + ], + "batch_size": "auto" + }, + "numb_steps": 200000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1000, + "save_freq": 10000, + "save_ckpt": "model.ckpt", + "stat_file": "stat_files" + } } From e157ed712b244152924c83128e24545cf6b44f6b Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:23:40 +0800 Subject: [PATCH 7/9] feat: Implement XAS energy normalization in the XAS loss function and introduce a dedicated XAS model. 
--- deepmd/entrypoints/test.py | 44 ++-- deepmd/pt/loss/xas.py | 220 ++++++++++++++++-- deepmd/pt/model/atomic_model/__init__.py | 2 + .../atomic_model/property_atomic_model.py | 25 ++ deepmd/pt/model/model/__init__.py | 9 +- deepmd/pt/model/model/xas_model.py | 42 ++++ deepmd/pt/train/training.py | 8 + 7 files changed, 315 insertions(+), 35 deletions(-) create mode 100644 deepmd/pt/model/model/xas_model.py diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 9bcdf58d8c..8f9cacdd42 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -887,6 +887,8 @@ def test_property( high_prec=True, ) + is_xas = var_name == "xas" + if dp.get_dim_fparam() > 0: data.add( "fparam", dp.get_dim_fparam(), atomic=False, must=True, high_prec=False @@ -894,10 +896,11 @@ def test_property( if dp.get_dim_aparam() > 0: data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False) - # sel_type: optional per-frame type index for element-wise mean reduction (XAS) - data.add( - "sel_type", 1, atomic=False, must=False, high_prec=False, default=float(-1) - ) + # XAS requires sel_type.npy (per-frame absorbing element type index) + if is_xas: + data.add( + "sel_type", 1, atomic=False, must=True, high_prec=False + ) test_data = data.get_test() mixed_type = data.mixed_type @@ -923,12 +926,8 @@ def test_property( else: aparam = None - # detect whether this system provides sel_type (XAS-style reduction) - sel_type_raw = test_data["sel_type"][:numb_test, 0] # [numb_test] - has_sel_type = bool((sel_type_raw >= 0).all()) - - # for sel_type reduction we need per-atom outputs - eval_atomic = has_atom_property or has_sel_type + # XAS: per-atom outputs are needed to average over absorbing-element atoms + eval_atomic = has_atom_property or is_xas ret = dp.eval( coord, box, @@ -939,27 +938,44 @@ def test_property( mixed_type=mixed_type, ) - if has_sel_type: + if is_xas: # ret[1]: per-atom property [numb_test, natoms, task_dim] atom_prop = 
ret[1].reshape([numb_test, natoms, dp.task_dim]) - # atype for all frames if mixed_type: atype_frames = atype # [numb_test, natoms] else: atype_frames = np.tile(atype, (numb_test, 1)) # [numb_test, natoms] - sel_type_int = sel_type_raw.astype(int) + sel_type_int = test_data["sel_type"][:numb_test, 0].astype(int) property = np.zeros([numb_test, dp.task_dim], dtype=atom_prop.dtype) for i in range(numb_test): t = sel_type_int[i] mask = atype_frames[i] == t # [natoms] count = max(mask.sum(), 1) property[i] = atom_prop[i][mask].sum(axis=0) / count + + # Add back the per-(type, edge) energy reference so output is in + # absolute eV (matching label format). xas_e_ref is saved in the + # model checkpoint by XASLoss.compute_output_stats. + try: + xas_e_ref = dp.dp.model["Default"].atomic_model.xas_e_ref + except AttributeError: + xas_e_ref = None + if xas_e_ref is not None and fparam is not None: + import torch as _torch + edge_idx_all = _torch.tensor( + fparam.reshape(numb_test, -1) + ).argmax(dim=-1).numpy() + e_ref_np = xas_e_ref.cpu().numpy() # [ntypes, nfparam, 2] + for i in range(numb_test): + t = sel_type_int[i] + e = int(edge_idx_all[i]) + property[i, :2] += e_ref_np[t, e] else: property = ret[0] property = property.reshape([numb_test, dp.task_dim]) - if has_atom_property: + if has_atom_property and not is_xas: aproperty = ret[1] aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim]) diff --git a/deepmd/pt/loss/xas.py b/deepmd/pt/loss/xas.py index 15e2089ab8..984c5e7487 100644 --- a/deepmd/pt/loss/xas.py +++ b/deepmd/pt/loss/xas.py @@ -1,21 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -from typing import ( - Any, -) +from collections import defaultdict +from typing import Any +import numpy as np import torch import torch.nn.functional as F -from deepmd.pt.loss.loss import ( - TaskLoss, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.utils.data import ( - DataRequirementItem, -) +from deepmd.pt.loss.loss import TaskLoss 
+from deepmd.pt.utils import env +from deepmd.utils.data import DataRequirementItem log = logging.getLogger(__name__) @@ -28,10 +22,31 @@ class XASLoss(TaskLoss): in each training system) and takes their mean, then computes a loss against the per-frame XAS label. + Energy normalization + -------------------- + XAS labels contain absolute edge energies (E_min, E_max in eV) that vary + enormously across element-edge pairs (H_K ~14 eV, Th_K ~110000 eV). + Training directly on absolute values causes gradient instability because + the energy dimensions dwarf the intensity dimensions. + + ``compute_output_stats`` computes a reference energy ``e_ref[t, e]`` for + every ``(absorbing_type t, edge_index e)`` combination from the training + data and stores it as a registered buffer. During training, ``forward`` + normalises labels and predictions by subtracting the per-frame reference + so that the loss is computed on chemical shifts (±few eV) and normalised + intensities—quantities of comparable magnitude. + + The buffer is saved in the model checkpoint, eliminating any need for + external normalisation files. + Parameters ---------- task_dim : int Output dimension of the fitting net (e.g. 102 = E_min + E_max + 100 pts). + ntypes : int + Number of atom types in the model. + nfparam : int + Length of the fparam one-hot vector (= number of edge types). var_name : str Property name, must match ``property_name`` in the fitting config. 
loss_func : str @@ -45,6 +60,8 @@ class XASLoss(TaskLoss): def __init__( self, task_dim: int, + ntypes: int, + nfparam: int, var_name: str = "xas", loss_func: str = "smooth_mae", metric: list[str] = ["mae"], @@ -53,11 +70,141 @@ def __init__( ) -> None: super().__init__() self.task_dim = task_dim + self.ntypes = ntypes + self.nfparam = nfparam self.var_name = var_name self.loss_func = loss_func self.metric = metric self.beta = beta + # e_ref[sel_type_idx, edge_idx, 0] = mean E_min (eV) + # e_ref[sel_type_idx, edge_idx, 1] = mean E_max (eV) + # Shape: [ntypes, nfparam, 2]. Filled by compute_output_stats; zero until then. + self.register_buffer( + "e_ref", + torch.zeros(ntypes, nfparam, 2, dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) + + # ------------------------------------------------------------------ + # Stat phase: compute per-(absorbing_type, edge) reference energies + # ------------------------------------------------------------------ + def compute_output_stats( + self, + sampled: list[dict], + model: "torch.nn.Module | None" = None, + ) -> None: + """Compute ``e_ref`` and fix model energy-dim bias/std. + + Called once before training starts. Requires ``xas``, ``sel_type``, + and ``fparam`` in at least some samples. + + Parameters + ---------- + sampled : list[dict] + List of data batches from ``make_stat_input``. + model : nn.Module, optional + The full DeePMD model. When given, the per-atom property model's + ``out_bias`` and ``out_std`` for the two energy dimensions (E_min, + E_max) are reset to 0 / 1 so the NN predicts *chemical shifts* + (±few eV) instead of absolute energies (~thousands of eV). + Without this reset the stat-initialised ``out_std ≈ 26 000 eV`` + amplifies weight-update steps by 26 000×, causing immediate + gradient explosion. 
+ """ + accum: dict[tuple[int, int], list] = defaultdict(list) + + for frame in sampled: + if ( + self.var_name not in frame + or "sel_type" not in frame + or "fparam" not in frame + ): + continue + xas = frame[self.var_name] # tensor, various shapes + sel_type = frame["sel_type"] + fparam = frame["fparam"] + + # flatten to [nf, task_dim], [nf], [nf, nfparam] + xas = xas.reshape(-1, self.task_dim) + sel_type = sel_type.reshape(-1).long() + fparam = fparam.reshape(-1, self.nfparam) + edge_idx = fparam.argmax(dim=-1) + + nf = xas.shape[0] + for i in range(nf): + t = int(sel_type[i].item()) + e = int(edge_idx[i].item()) + if 0 <= t < self.ntypes and 0 <= e < self.nfparam: + accum[(t, e)].append(xas[i, :2].detach().cpu().numpy()) + + if not accum: + log.warning( + "XASLoss.compute_output_stats: no frames with xas+sel_type+fparam found; " + "e_ref remains zero. Training may be unstable." + ) + return + + e_ref = torch.zeros( + self.ntypes, self.nfparam, 2, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + for (t, e), vals in accum.items(): + e_ref[t, e] = torch.tensor( + np.mean(vals, axis=0), dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + log.info( + f"XASLoss e_ref: type={t}, edge={e} -> " + f"E_min_ref={float(e_ref[t,e,0]):.2f} eV, " + f"E_max_ref={float(e_ref[t,e,1]):.2f} eV " + f"(n={len(vals)})" + ) + + self.e_ref.copy_(e_ref) + log.info( + f"XASLoss: e_ref computed for {len(accum)} (sel_type, edge) combinations." + ) + + if model is not None: + try: + am = model.atomic_model + + # 1. Copy e_ref into the model's own buffer so it is saved + # in the checkpoint and available at inference time without + # any external reference file (analogous to out_bias). + if getattr(am, "xas_e_ref", None) is not None: + am.xas_e_ref.copy_(e_ref.to(am.xas_e_ref.dtype)) + log.info("XASLoss: copied e_ref → model.atomic_model.xas_e_ref.") + + # 2. Reset energy-dim out_bias/out_std so the NN predicts + # chemical shifts instead of absolute energies. 
+ # + # Why this is necessary + # ---------------------- + # The model stat phase initialises + # out_bias[:, :2] ≈ global_mean(E_min, E_max) ≈ 19 000 eV + # out_std[:, :2] ≈ global_std(E_min, E_max) ≈ 26 000 eV + # so atom_xas[:, 0] = NN_raw[:, 0] * 26 000 + 19 000. + # A single Adam step changes NN_raw by ~lr, which changes + # the physical output by lr × 26 000 = 2.7 eV — the same + # instability as out_bias for energy fitting if the reference + # is wrong. With out_std=1 / out_bias=0, the NN output for + # energy dims is interpreted directly as a chemical shift + # (target ≈ label − e_ref ≈ ±few eV), keeping gradient + # magnitudes O(1) and training stable. + key_idx = am.bias_keys.index(self.var_name) + with torch.no_grad(): + am.out_bias[key_idx, :, :2] = 0.0 + am.out_std[key_idx, :, :2] = 1.0 + log.info( + "XASLoss: reset out_bias[:,:2]=0 and out_std[:,:2]=1 " + "for energy dims (model predicts chemical shifts; " + "xas_e_ref restores absolute energies at inference)." + ) + except Exception as exc: + log.warning(f"XASLoss: could not update model energy-dim stats: {exc}") + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ def forward( self, input_dict: dict[str, torch.Tensor], @@ -76,7 +223,7 @@ def forward( # sel_type from label: [nf, 1] float → [nf] int sel_type = label["sel_type"][:, 0].long() - # element-wise mean: for each frame average over atoms of sel_type + # element-wise mean: average atom_prop over atoms of sel_type per frame nf, nloc, td = atom_prop.shape pred = torch.zeros(nf, td, dtype=atom_prop.dtype, device=atom_prop.device) for i in range(nf): @@ -87,27 +234,60 @@ def forward( label_xas = label[self.var_name] # [nf, task_dim] + # --- per-frame reference energy lookup --- + # edge_idx = argmax of one-hot fparam + fparam = input_dict.get("fparam") + if fparam is not None and fparam.numel() > 0: + edge_idx = fparam.reshape(nf, 
-1).argmax(dim=-1).clamp(0, self.nfparam - 1) + else: + edge_idx = torch.zeros(nf, dtype=torch.long, device=pred.device) + + # e_ref_frame: [nf, 2] (E_min_ref, E_max_ref for each frame) + e_ref_frame = self.e_ref[sel_type, edge_idx] # [nf, 2] + + # Shift the energy-dim TARGETS only. + # + # After compute_output_stats has reset out_bias[:,:2]=0 / out_std[:,:2]=1, + # the model outputs raw NN values ≈ 0 for dims 0,1. We train those + # dims against (label − e_ref), i.e. the chemical shift (±few eV), + # keeping gradient magnitudes O(1). Intensity dims (2:) are trained + # against the original label values unchanged. + # + # At inference, we add e_ref back to get the absolute edge energy. + label_shifted = label_xas.clone() + label_shifted[:, :2] = label_xas[:, :2] - e_ref_frame + + # --- loss --- loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] if self.loss_func == "smooth_mae": - loss += F.smooth_l1_loss(pred, label_xas, reduction="sum", beta=self.beta) + loss += F.smooth_l1_loss( + pred, label_shifted, reduction="sum", beta=self.beta + ) elif self.loss_func == "mae": - loss += F.l1_loss(pred, label_xas, reduction="sum") + loss += F.l1_loss(pred, label_shifted, reduction="sum") elif self.loss_func == "mse": - loss += F.mse_loss(pred, label_xas, reduction="sum") + loss += F.mse_loss(pred, label_shifted, reduction="sum") elif self.loss_func == "rmse": - loss += torch.sqrt(F.mse_loss(pred, label_xas, reduction="mean")) + loss += torch.sqrt(F.mse_loss(pred, label_shifted, reduction="mean")) else: raise RuntimeError(f"Unknown loss function: {self.loss_func}") + # --- metrics --- more_loss: dict[str, torch.Tensor] = {} if "mae" in self.metric: - more_loss["mae"] = F.l1_loss(pred, label_xas, reduction="mean").detach() + more_loss["mae"] = F.l1_loss( + pred, label_shifted, reduction="mean" + ).detach() if "rmse" in self.metric: more_loss["rmse"] = torch.sqrt( - F.mse_loss(pred, label_xas, reduction="mean") + F.mse_loss(pred, label_shifted, 
reduction="mean") ).detach() - model_pred[self.var_name] = pred + # Absolute prediction: add e_ref back to energy dims for eval / output + pred_abs = pred.clone() + pred_abs[:, :2] = pred[:, :2] + e_ref_frame + model_pred[self.var_name] = pred_abs + return model_pred, loss, more_loss @property diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index 4da9bf781b..fbf7478778 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -41,6 +41,7 @@ ) from .property_atomic_model import ( DPPropertyAtomicModel, + DPXASAtomicModel, ) __all__ = [ @@ -51,6 +52,7 @@ "DPEnergyAtomicModel", "DPPolarAtomicModel", "DPPropertyAtomicModel", + "DPXASAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", diff --git a/deepmd/pt/model/atomic_model/property_atomic_model.py b/deepmd/pt/model/atomic_model/property_atomic_model.py index baf9c5b7fc..ffe3cc5fe6 100644 --- a/deepmd/pt/model/atomic_model/property_atomic_model.py +++ b/deepmd/pt/model/atomic_model/property_atomic_model.py @@ -52,3 +52,28 @@ def apply_out_stat( for kk in self.bias_keys: ret[kk] = ret[kk] * out_std[kk][0] + out_bias[kk][0] return ret + + +class DPXASAtomicModel(DPPropertyAtomicModel): + """Atomic model for XAS spectrum fitting. + + Extends :class:`DPPropertyAtomicModel` with a per-(absorbing_type, edge) + energy reference buffer ``xas_e_ref`` [ntypes, nfparam, 2]. The buffer is + populated by :meth:`deepmd.pt.loss.xas.XASLoss.compute_output_stats` before + training starts and is saved in the model checkpoint so that absolute edge + energies can be reconstructed at inference time without any external files. 
+ """ + + def __init__( + self, descriptor: Any, fitting: Any, type_map: Any, **kwargs: Any + ) -> None: + super().__init__(descriptor, fitting, type_map, **kwargs) + nfparam: int = getattr(fitting, "numb_fparam", 0) + if nfparam > 0: + ntypes: int = len(type_map) + self.register_buffer( + "xas_e_ref", + torch.zeros(ntypes, nfparam, 2, dtype=torch.float64), + ) + else: + self.xas_e_ref: torch.Tensor | None = None diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 1aa3732580..90637c653f 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -69,6 +69,9 @@ from .property_model import ( PropertyModel, ) +from .xas_model import ( + XASModel, +) from .spin_model import ( SpinEnergyModel, SpinModel, @@ -269,7 +272,10 @@ def get_standard_model(model_params: dict) -> BaseModel: elif fitting_net_type in ["ener", "direct_force_ener"]: modelcls = EnergyModel elif fitting_net_type == "property": - modelcls = PropertyModel + property_name = model_params.get("fitting_net", {}).get( + "property_name", model_params.get("fitting_net", {}).get("var_name", "") + ) + modelcls = XASModel if property_name == "xas" else PropertyModel else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") @@ -316,6 +322,7 @@ def get_model(model_params: dict) -> Any: "PolarModel", "SpinEnergyModel", "SpinModel", + "XASModel", "get_model", "make_hessian_model", "make_model", diff --git a/deepmd/pt/model/model/xas_model.py b/deepmd/pt/model/model/xas_model.py new file mode 100644 index 0000000000..a18ce44c10 --- /dev/null +++ b/deepmd/pt/model/model/xas_model.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.pt.model.atomic_model import ( + DPXASAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + +from .property_model import ( + PropertyModel, +) + + +@BaseModel.register("xas") +class XASModel(PropertyModel): + """Model for XAS 
spectrum fitting. + + Identical to :class:`PropertyModel` but uses :class:`DPXASAtomicModel` + as the underlying atomic model, which carries the per-(absorbing_type, + edge) energy reference buffer ``xas_e_ref`` in the checkpoint. This + buffer is populated by :meth:`deepmd.pt.loss.xas.XASLoss.compute_output_stats` + before training starts and restored at inference time so that absolute + edge energies are available without any external reference files. + """ + + model_type = "xas" + + def __init__( + self, + descriptor: Any, + fitting: Any, + type_map: Any, + **kwargs: Any, + ) -> None: + xas_atomic = DPXASAtomicModel(descriptor, fitting, type_map, **kwargs) + super().__init__( + descriptor, fitting, type_map, atomic_model_=xas_atomic, **kwargs + ) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 034bcdc015..f2963b8ca7 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -379,6 +379,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: else False, preset_observed_type=model_params.get("info", {}).get("observed_type"), ) + # For XAS loss: compute per-(absorbing_type, edge) reference energies + # from training data and store as a registered buffer in the loss module. 
+ if not resuming and self.rank == 0 and hasattr(self.loss, "compute_output_stats"): + self.loss.compute_output_stats( + self.get_sample_func(), model=self.model + ) # Persist observed_type from stat into model_params and model_def_script if not resuming and self.rank == 0: observed = self.model.atomic_model.observed_type @@ -1760,6 +1766,8 @@ def get_loss( elif loss_type == "xas": loss_params["task_dim"] = _model.get_task_dim() loss_params["var_name"] = _model.get_var_name() + loss_params["ntypes"] = _ntypes + loss_params["nfparam"] = _model.get_fitting_net().numb_fparam return XASLoss(**loss_params) else: loss_params["starter_learning_rate"] = start_lr From 8c216121a311f71c968f80285bfbff46a2432082 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Mar 2026 06:34:05 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/entrypoints/test.py | 11 +++++------ deepmd/pt/loss/xas.py | 28 ++++++++++++++++++---------- deepmd/pt/model/model/__init__.py | 6 +++--- deepmd/pt/train/training.py | 10 ++++++---- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 8f9cacdd42..0773023d61 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -898,9 +898,7 @@ def test_property( # XAS requires sel_type.npy (per-frame absorbing element type index) if is_xas: - data.add( - "sel_type", 1, atomic=False, must=True, high_prec=False - ) + data.add("sel_type", 1, atomic=False, must=True, high_prec=False) test_data = data.get_test() mixed_type = data.mixed_type @@ -962,9 +960,10 @@ def test_property( xas_e_ref = None if xas_e_ref is not None and fparam is not None: import torch as _torch - edge_idx_all = _torch.tensor( - fparam.reshape(numb_test, -1) - ).argmax(dim=-1).numpy() + + edge_idx_all = ( + _torch.tensor(fparam.reshape(numb_test, 
-1)).argmax(dim=-1).numpy() + ) e_ref_np = xas_e_ref.cpu().numpy() # [ntypes, nfparam, 2] for i in range(numb_test): t = sel_type_int[i] diff --git a/deepmd/pt/loss/xas.py b/deepmd/pt/loss/xas.py index 984c5e7487..8905029dc0 100644 --- a/deepmd/pt/loss/xas.py +++ b/deepmd/pt/loss/xas.py @@ -1,15 +1,25 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -from collections import defaultdict -from typing import Any +from collections import ( + defaultdict, +) +from typing import ( + Any, +) import numpy as np import torch import torch.nn.functional as F -from deepmd.pt.loss.loss import TaskLoss -from deepmd.pt.utils import env -from deepmd.utils.data import DataRequirementItem +from deepmd.pt.loss.loss import ( + TaskLoss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.utils.data import ( + DataRequirementItem, +) log = logging.getLogger(__name__) @@ -153,8 +163,8 @@ def compute_output_stats( ) log.info( f"XASLoss e_ref: type={t}, edge={e} -> " - f"E_min_ref={float(e_ref[t,e,0]):.2f} eV, " - f"E_max_ref={float(e_ref[t,e,1]):.2f} eV " + f"E_min_ref={float(e_ref[t, e, 0]):.2f} eV, " + f"E_max_ref={float(e_ref[t, e, 1]):.2f} eV " f"(n={len(vals)})" ) @@ -275,9 +285,7 @@ def forward( # --- metrics --- more_loss: dict[str, torch.Tensor] = {} if "mae" in self.metric: - more_loss["mae"] = F.l1_loss( - pred, label_shifted, reduction="mean" - ).detach() + more_loss["mae"] = F.l1_loss(pred, label_shifted, reduction="mean").detach() if "rmse" in self.metric: more_loss["rmse"] = torch.sqrt( F.mse_loss(pred, label_shifted, reduction="mean") diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 90637c653f..f577e4f0cf 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -69,13 +69,13 @@ from .property_model import ( PropertyModel, ) -from .xas_model import ( - XASModel, -) from .spin_model import ( SpinEnergyModel, SpinModel, ) +from .xas_model import ( + XASModel, +) def 
_get_standard_model_components(model_params: dict, ntypes: int) -> tuple: diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f2963b8ca7..790cbff80e 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -381,10 +381,12 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ) # For XAS loss: compute per-(absorbing_type, edge) reference energies # from training data and store as a registered buffer in the loss module. - if not resuming and self.rank == 0 and hasattr(self.loss, "compute_output_stats"): - self.loss.compute_output_stats( - self.get_sample_func(), model=self.model - ) + if ( + not resuming + and self.rank == 0 + and hasattr(self.loss, "compute_output_stats") + ): + self.loss.compute_output_stats(self.get_sample_func(), model=self.model) # Persist observed_type from stat into model_params and model_def_script if not resuming and self.rank == 0: observed = self.model.atomic_model.observed_type From 250168bbdc3bccea4d426b9d7a2ad217b430a4ef Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:42:37 +0800 Subject: [PATCH 9/9] fix:device --- deepmd/pt/loss/xas.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/loss/xas.py b/deepmd/pt/loss/xas.py index 8905029dc0..2091f5e04c 100644 --- a/deepmd/pt/loss/xas.py +++ b/deepmd/pt/loss/xas.py @@ -253,7 +253,9 @@ def forward( edge_idx = torch.zeros(nf, dtype=torch.long, device=pred.device) # e_ref_frame: [nf, 2] (E_min_ref, E_max_ref for each frame) - e_ref_frame = self.e_ref[sel_type, edge_idx] # [nf, 2] + # Indices must be on the same device as the buffer (handles CPU/GPU mismatch) + _dev = self.e_ref.device + e_ref_frame = self.e_ref[sel_type.to(_dev), edge_idx.to(_dev)].to(pred.device) # Shift the energy-dim TARGETS only. #