From 9847bf845cac15ee2875ccb5a6d78e9d1936a749 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 15:06:39 +0800 Subject: [PATCH 01/10] feat: add numpy tabulate math in deepmd/utils/tabulate_math.py Backend-agnostic DPTabulate with numpy implementations of: - Activation derivatives (grad, grad_grad) - Chain-rule propagation (unaggregated_dy_dx_s, etc.) - Embedding net forward pass (_make_data, _layer_0, _layer_1) - Network variable extraction (_get_network_variable) - Descriptor type detection from serialized data --- deepmd/utils/tabulate_math.py | 582 ++++++++++++++++++++++++++++++++++ 1 file changed, 582 insertions(+) create mode 100644 deepmd/utils/tabulate_math.py diff --git a/deepmd/utils/tabulate_math.py b/deepmd/utils/tabulate_math.py new file mode 100644 index 0000000000..a279ee67e0 --- /dev/null +++ b/deepmd/utils/tabulate_math.py @@ -0,0 +1,582 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Backend-agnostic tabulation math using numpy. + +Provides the pure-math functions for model compression tabulation: +activation derivatives, chain-rule derivative propagation, and +embedding-net forward pass. Used by both pt and pt_expt backends. +""" + +import logging +from functools import ( + cached_property, +) +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.utils.network import ( + get_activation_fn, +) +from deepmd.utils.tabulate import ( + BaseTabulate, +) + +log = logging.getLogger(__name__) + +SQRT_2_PI = np.sqrt(2 / np.pi) +GGELU = 0.044715 + +# Mapping from activation function name to integer functype +# used by grad/grad_grad for derivative computation. +ACTIVATION_TO_FUNCTYPE: dict[str, int] = { + "tanh": 1, + "gelu": 2, + "gelu_tf": 2, + "relu": 3, + "relu6": 4, + "softplus": 5, + "sigmoid": 6, + "silu": 7, +} + + +# ---- Activation derivatives (numpy) ---- + + +def grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: + """First derivative of the activation function.""" + if functype == 1: + return 1 - y * y + elif functype == 2: + var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + return ( + 0.5 * SQRT_2_PI * xbar * (1 - var**2) * (3 * GGELU * xbar**2 + 1) + + 0.5 * var + + 0.5 + ) + elif functype == 3: + return np.where(xbar > 0, np.ones_like(xbar), np.zeros_like(xbar)) + elif functype == 4: + return np.where( + (xbar > 0) & (xbar < 6), np.ones_like(xbar), np.zeros_like(xbar) + ) + elif functype == 5: + return 1.0 - 1.0 / (1.0 + np.exp(xbar)) + elif functype == 6: + return y * (1 - y) + elif functype == 7: + sig = 1.0 / (1.0 + np.exp(-xbar)) + return sig + xbar * sig * (1 - sig) + else: + raise ValueError(f"Unsupported function type: {functype}") + + +def grad_grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: + """Second derivative of the activation function.""" + if functype == 1: + return -2 * y * (1 - y * y) + elif functype == 2: + var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + var2 = SQRT_2_PI * (1 - var1**2) * (3 * GGELU * xbar**2 + 1) + return ( + 3 * GGELU * SQRT_2_PI * xbar**2 * (1 - var1**2) + - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar**2 + 1) * var1 + + var2 + ) + elif functype in [3, 4]: + return np.zeros_like(xbar) + elif functype == 5: + exp_xbar = np.exp(xbar) + return exp_xbar / ((1 + exp_xbar) * (1 + exp_xbar)) + elif functype == 6: + return y * (1 - y) * (1 - 2 * y) + elif functype == 7: + sig = 1.0 / (1.0 + np.exp(-xbar)) + d_sig = sig * (1 - sig) + return 2 * d_sig + xbar * d_sig * (1 - 2 * sig) + else: + return -np.ones_like(xbar) + + +# ---- Chain-rule derivative propagation (numpy) ---- + + +def unaggregated_dy_dx_s( + y: np.ndarray, w: np.ndarray, xbar: np.ndarray, functype: int +) -> np.ndarray: + """First derivative for the first layer (scalar input).""" + if y.ndim != 2: + raise ValueError("Dim of input y should be 2") + if w.ndim != 2: + raise ValueError("Dim of input w should be 2") + if xbar.ndim != 2: + raise ValueError("Dim of input xbar should be 2") + + grad_xbar_y = grad(xbar, y, functype) + w_flat = np.ravel(w)[: y.shape[1]] + w_rep = np.tile(w_flat, (y.shape[0], 1)) + return grad_xbar_y * w_rep + + +def unaggregated_dy2_dx_s( + y: np.ndarray, + dy: np.ndarray, + w: np.ndarray, + xbar: np.ndarray, + functype: int, +) -> np.ndarray: + """Second derivative for the first layer (scalar input).""" + if y.ndim != 2: + raise ValueError("Dim of input y should be 2") + if dy.ndim != 2: + raise ValueError("Dim of input dy should be 2") + if w.ndim != 2: + raise ValueError("Dim of input w should be 2") + if xbar.ndim != 2: + raise ValueError("Dim of input xbar should be 2") + + gg = grad_grad(xbar, y, functype) + w_flat = np.ravel(w)[: y.shape[1]] + w_rep = np.tile(w_flat, (y.shape[0], 1)) + return gg * w_rep * w_rep + + +def unaggregated_dy_dx( + z: np.ndarray, + w: np.ndarray, + dy_dx: np.ndarray, + ybar: np.ndarray, + functype: int, +) -> np.ndarray: + """First derivative for subsequent layers.""" + if z.ndim != 2: + raise ValueError("z must have 2 dimensions") + if w.ndim != 2: + raise ValueError("w must have 2 dimensions") + if dy_dx.ndim != 2: + raise ValueError("dy_dx must have 2 dimensions") + if ybar.ndim != 2: + raise ValueError("ybar must have 2 dimensions") + + length, width = z.shape + size = w.shape[0] + + grad_ybar_z = grad(ybar, z, functype) + dy_dx = np.ravel(dy_dx)[: length * size].reshape(length, size) + accumulator = dy_dx @ w + dz_drou = grad_ybar_z * accumulator + + if width == size: + dz_drou += dy_dx + if width == 2 * size: + dy_dx = np.concatenate((dy_dx, dy_dx), axis=1) + dz_drou += dy_dx + + return dz_drou + + +def unaggregated_dy2_dx( + z: np.ndarray, + w: np.ndarray, + dy_dx: np.ndarray, + dy2_dx: np.ndarray, + ybar: np.ndarray, + functype: int, +) -> np.ndarray: + """Second derivative for subsequent layers.""" + if z.ndim != 2: + raise ValueError("z must have 2 dimensions") + if w.ndim != 2: + raise ValueError("w must have 2 dimensions") + if dy_dx.ndim != 2: + raise ValueError("dy_dx must have 2 dimensions") + if dy2_dx.ndim != 2: + raise ValueError("dy2_dx must have 2 dimensions") + if ybar.ndim != 2: + raise ValueError("ybar must have 2 dimensions") + + length, width = z.shape + size = w.shape[0] + + grad_ybar_z = grad(ybar, z, functype) + gg = grad_grad(ybar, z, functype) + + dy2_dx = np.ravel(dy2_dx)[: length * size].reshape(length, size) + dy_dx = np.ravel(dy_dx)[: length * size].reshape(length, size) + + acc1 = dy2_dx @ w + acc2 = dy_dx @ w + + dz_drou = grad_ybar_z * acc1 + gg * acc2 * acc2 + + if width == size: + dz_drou += dy2_dx + if width == 2 * size: + dy2_dx = np.concatenate((dy2_dx, dy2_dx), axis=1) + dz_drou += dy2_dx + + return dz_drou + + +# ---- DPTabulate with numpy math ---- + + +class DPTabulate(BaseTabulate): + r"""Backend-agnostic tabulation using numpy. + + Compress a model by tabulating the embedding-net. The table is composed + of fifth-order polynomial coefficients assembled from two sub-tables. + + Parameters + ---------- + descrpt + Descriptor of the original model. + neuron + Number of neurons in each hidden layer of the embedding net. + type_one_side + Try to build N_types tables. Otherwise, building N_types^2 tables. + exclude_types + Excluded type pairs with no interaction. + activation_fn_name + Name of the activation function (e.g. "tanh", "gelu", "relu"). + """ + + def __init__( + self, + descrpt: Any, + neuron: list[int], + type_one_side: bool = False, + exclude_types: list[list[int]] = [], + activation_fn_name: str = "tanh", + ) -> None: + super().__init__( + descrpt, + neuron, + type_one_side, + exclude_types, + True, # is_pt flag (for _build_lower numpy int conversion) + ) + self._activation_fn = get_activation_fn(activation_fn_name) + activation_fn_name = activation_fn_name.lower() + if activation_fn_name not in ACTIVATION_TO_FUNCTYPE: + raise RuntimeError(f"Unknown activation function: {activation_fn_name}") + self.functype = ACTIVATION_TO_FUNCTYPE[activation_fn_name] + + self.descrpt_type = self._get_descrpt_type() + + supported_descrpt_type = ("Atten", "A", "T", "T_TEBD", "R") + if self.descrpt_type in supported_descrpt_type: + self.sel_a = self.descrpt.get_sel() + self.rcut = self.descrpt.get_rcut() + self.rcut_smth = self.descrpt.get_rcut_smth() + else: + raise RuntimeError("Unsupported descriptor") + + serialized = self.descrpt.serialize() + # For DPA2, use the repinit sub-block's serialized data + if self.descrpt_type == "Atten" and "repinit_variable" in serialized: + serialized = serialized["repinit_variable"] + self.davg = serialized["@variables"]["davg"] + self.dstd = serialized["@variables"]["dstd"] + self.embedding_net_nodes = serialized["embeddings"]["networks"] + + self.ntypes = self.descrpt.get_ntypes() + + self.layer_size = self._get_layer_size() + self.table_size = self._get_table_size() + + self.bias = self._get_bias() + self.matrix = self._get_matrix() + + self.data_type = self._get_data_type() + self.last_layer_size = self._get_last_layer_size() + + def _make_data(self, xx: np.ndarray, idx: int) -> Any: + """Forward pass through embedding net with derivative computation.""" + xx = xx.reshape(-1, 1) + for layer in range(self.layer_size): + if layer == 0: + xbar = ( + np.matmul(xx, self.matrix["layer_" + str(layer + 1)][idx]) + + self.bias["layer_" + str(layer + 1)][idx] + ) + if self.neuron[0] == 1: + yy = ( + self._layer_0( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + + xx + ) + dy = unaggregated_dy_dx_s( + yy - xx, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + np.ones((1, 1), dtype=yy.dtype) + dy2 = unaggregated_dy2_dx_s( + yy - xx, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + elif self.neuron[0] == 2: + tt, yy = self._layer_1( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dy = unaggregated_dy_dx_s( + yy - tt, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + np.ones((1, 2), dtype=yy.dtype) + dy2 = unaggregated_dy2_dx_s( + yy - tt, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + else: + yy = self._layer_0( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dy = unaggregated_dy_dx_s( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + dy2 = unaggregated_dy2_dx_s( + yy, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + else: + ybar = ( + np.matmul(yy, self.matrix["layer_" + str(layer + 1)][idx]) + + self.bias["layer_" + str(layer + 1)][idx] + ) + if self.neuron[layer] == self.neuron[layer - 1]: + zz = ( + self._layer_0( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + + yy + ) + dz = unaggregated_dy_dx( + zz - yy, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz - yy, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + elif self.neuron[layer] == 2 * self.neuron[layer - 1]: + tt, zz = self._layer_1( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dz = unaggregated_dy_dx( + zz - tt, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz - tt, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + else: + zz = self._layer_0( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dz = unaggregated_dy_dx( + zz, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + dy = dz + yy = zz + + vv = zz.astype(self.data_type) + dd = dy.astype(self.data_type) + d2 = dy2.astype(self.data_type) + return vv, dd, d2 + + def _layer_0(self, x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray: + return self._activation_fn(np.matmul(x, w) + b) + + def _layer_1( + self, x: np.ndarray, w: np.ndarray, b: np.ndarray + ) -> tuple[np.ndarray, np.ndarray]: + t = np.concatenate([x, x], axis=1) + return t, self._activation_fn(np.matmul(x, w) + b) + t + + def _get_descrpt_type(self) -> str: + """Determine descriptor type from serialized data.""" + data = self.descrpt.serialize() + type_str = data.get("type", "") + type_map = { + "se_e2_a": "A", + "se_r": "R", + "se_e3": "T", + "se_e3_tebd": "T_TEBD", + "dpa1": "Atten", + "se_atten_v2": "Atten", + } + descrpt_type = type_map.get(type_str) + if descrpt_type is None: + raise RuntimeError(f"Unsupported descriptor type: {type_str}") + return descrpt_type + + def _get_layer_size(self) -> int: + layer_size = 0 + basic_size = 0 + if self.type_one_side: + basic_size = len(self.embedding_net_nodes) * len(self.neuron) + else: + basic_size = ( + len(self.embedding_net_nodes) + * len(self.embedding_net_nodes[0]) + * len(self.neuron) + ) + if self.descrpt_type in ("Atten", "T_TEBD"): + layer_size = len(self.embedding_net_nodes[0]["layers"]) + elif self.descrpt_type == "A": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + if self.type_one_side: + layer_size = basic_size // (self.ntypes - self._n_all_excluded) + elif self.descrpt_type == "T": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + elif self.descrpt_type == "R": + layer_size = basic_size // ( + self.ntypes * self.ntypes - len(self.exclude_types) + ) + if self.type_one_side: + layer_size = basic_size // (self.ntypes - self._n_all_excluded) + else: + raise RuntimeError("Unsupported descriptor") + return layer_size + + def _get_network_variable(self, var_name: str) -> dict: + """Get network variables (weights or biases) for all layers.""" + result = {} + for layer in range(1, self.layer_size + 1): + result["layer_" + str(layer)] = [] + if self.descrpt_type == "Atten": + node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][ + var_name + ] + result["layer_" + str(layer)].append(node) + elif self.descrpt_type == "A": + if self.type_one_side: + for ii in range(0, self.ntypes): + if not self._all_excluded(ii): + node = self.embedding_net_nodes[ii]["layers"][layer - 1][ + "@variables" + ][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + for ii in range(0, self.ntypes * self.ntypes): + if ( + ii // self.ntypes, + ii % self.ntypes, + ) not in self.exclude_types: + node = self.embedding_net_nodes[ + (ii % self.ntypes) * self.ntypes + ii // self.ntypes + ]["layers"][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + elif self.descrpt_type == "T": + for ii in range(self.ntypes): + for jj in range(ii, self.ntypes): + node = self.embedding_net_nodes[jj * self.ntypes + ii][ + "layers" + ][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + elif self.descrpt_type == "T_TEBD": + node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][ + var_name + ] + result["layer_" + str(layer)].append(node) + elif self.descrpt_type == "R": + if self.type_one_side: + for ii in range(0, self.ntypes): + if not self._all_excluded(ii): + node = self.embedding_net_nodes[ii]["layers"][layer - 1][ + "@variables" + ][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + for ii in range(0, self.ntypes * self.ntypes): + if ( + ii // self.ntypes, + ii % self.ntypes, + ) not in self.exclude_types: + node = self.embedding_net_nodes[ + (ii % self.ntypes) * self.ntypes + ii // self.ntypes + ]["layers"][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + raise RuntimeError("Unsupported descriptor") + return result + + def _get_bias(self) -> Any: + return self._get_network_variable("b") + + def _get_matrix(self) -> Any: + return self._get_network_variable("w") + + def _convert_numpy_to_tensor(self) -> None: + """No-op: data stays as numpy arrays.""" + pass + + @cached_property + def _n_all_excluded(self) -> int: + """The number of types excluding all types.""" + return sum(int(self._all_excluded(ii)) for ii in range(0, self.ntypes)) From f52dca3d96f3dc03e2b6c12cf3d30ab7f16fab35 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 15:15:06 +0800 Subject: [PATCH 02/10] refactor(pt): DPTabulate inherits from shared numpy tabulate_math Remove ~500 lines of duplicated math. PT DPTabulate now only overrides _get_descrpt_type (isinstance checks) and _convert_numpy_to_tensor (torch tensor conversion). --- deepmd/pt/utils/tabulate.py | 604 ++---------------------------------- 1 file changed, 22 insertions(+), 582 deletions(-) diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index a308f2d36b..10f3938380 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import logging -from functools import ( - cached_property, -) +"""PyTorch-specific DPTabulate wrapper. + +Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate`` +and adds torch-specific ``_convert_numpy_to_tensor`` and +``_get_descrpt_type`` (isinstance checks against PT descriptor classes). +""" + from typing import ( Any, ) -import numpy as np import torch import deepmd @@ -17,37 +19,28 @@ from deepmd.pt.utils.utils import ( ActivationFn, ) -from deepmd.utils.tabulate import ( - BaseTabulate, -) - -log = logging.getLogger(__name__) +from deepmd.utils.tabulate_math import DPTabulate as DPTabulateBase -SQRT_2_PI = np.sqrt(2 / np.pi) -GGELU = 0.044715 +class DPTabulate(DPTabulateBase): + r"""PyTorch tabulation wrapper. -class DPTabulate(BaseTabulate): - r"""Class for tabulation. - - Compress a model, which including tabulating the embedding-net. - The table is composed of fifth-order polynomial coefficients and is assembled from two sub-tables. The first table takes the stride(parameter) as it's uniform stride, while the second table takes 10 * stride as it's uniform stride - The range of the first table is automatically detected by deepmd-kit, while the second table ranges from the first table's upper boundary(upper) to the extrapolate(parameter) * upper. + Accepts a PT ``ActivationFn`` module and delegates all math to the + numpy base class. Only overrides tensor conversion and descriptor + type detection. Parameters ---------- descrpt - Descriptor of the original model + Descriptor of the original model. neuron - Number of neurons in each hidden layers of the embedding net :math:`\\mathcal{N}` + Number of neurons in each hidden layer of the embedding net. type_one_side - Try to build N_types tables. Otherwise, building N_types^2 tables - exclude_types : List[List[int]] - The excluded pairs of types which have no interaction with each other. - For example, `[[0, 1]]` means no interaction between type 0 and type 1. - activation_function - The activation function in the embedding net. See :class:`ActivationFn` - for supported options (e.g. "tanh", "gelu", "relu", "silu"). + Try to build N_types tables. + exclude_types + Excluded type pairs. + activation_fn + The activation function (PT ActivationFn module). """ def __init__( @@ -63,237 +56,11 @@ def __init__( neuron, type_one_side, exclude_types, - True, + activation_fn_name=activation_fn.activation, ) - self.descrpt_type = self._get_descrpt_type() - - supported_descrpt_type = ("Atten", "A", "T", "T_TEBD", "R") - - if self.descrpt_type in supported_descrpt_type: - self.sel_a = self.descrpt.get_sel() - self.rcut = self.descrpt.get_rcut() - self.rcut_smth = self.descrpt.get_rcut_smth() - else: - raise RuntimeError("Unsupported descriptor") - - # functype - activation_map = { - "tanh": 1, - "gelu": 2, - "gelu_tf": 2, - "relu": 3, - "relu6": 4, - "softplus": 5, - "sigmoid": 6, - "silu": 7, - } - - activation = activation_fn.activation - if activation in activation_map: - self.functype = activation_map[activation] - else: - raise RuntimeError("Unknown activation function type!") - - self.activation_fn = activation_fn - serialized = self.descrpt.serialize() - if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA2): - serialized = serialized["repinit_variable"] - self.davg = serialized["@variables"]["davg"] - self.dstd = serialized["@variables"]["dstd"] - self.embedding_net_nodes = serialized["embeddings"]["networks"] - - self.ntypes = self.descrpt.get_ntypes() - - self.layer_size = self._get_layer_size() - self.table_size = self._get_table_size() - - self.bias = self._get_bias() - self.matrix = self._get_matrix() - - self.data_type = self._get_data_type() - self.last_layer_size = self._get_last_layer_size() - - def _make_data(self, xx: np.ndarray, idx: int) -> Any: - """Generate tabulation data for the given input. - - Parameters - ---------- - xx : np.ndarray - Input values to tabulate - idx : int - Index for accessing the correct network parameters - - Returns - ------- - tuple[np.ndarray, np.ndarray, np.ndarray] - Values, first derivatives, and second derivatives - """ - xx = torch.from_numpy(xx).view(-1, 1).to(env.DEVICE) - for layer in range(self.layer_size): - if layer == 0: - xbar = torch.matmul( - xx, - torch.from_numpy(self.matrix["layer_" + str(layer + 1)][idx]).to( - env.DEVICE - ), - ) + torch.from_numpy(self.bias["layer_" + str(layer + 1)][idx]).to( - env.DEVICE - ) - if self.neuron[0] == 1: - yy = ( - self._layer_0( - xx, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - + xx - ) - dy = unaggregated_dy_dx_s( - yy - xx, - self.matrix["layer_" + str(layer + 1)][idx], - xbar, - self.functype, - ) + torch.ones((1, 1), dtype=yy.dtype, device=yy.device) - dy2 = unaggregated_dy2_dx_s( - yy - xx, - dy, - self.matrix["layer_" + str(layer + 1)][idx], - xbar, - self.functype, - ) - elif self.neuron[0] == 2: - tt, yy = self._layer_1( - xx, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - dy = unaggregated_dy_dx_s( - yy - tt, - self.matrix["layer_" + str(layer + 1)][idx], - xbar, - self.functype, - ) + torch.ones((1, 2), dtype=yy.dtype, device=yy.device) - dy2 = unaggregated_dy2_dx_s( - yy - tt, - dy, - self.matrix["layer_" + str(layer + 1)][idx], - xbar, - self.functype, - ) - else: - yy = self._layer_0( - xx, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - dy = unaggregated_dy_dx_s( - yy, - self.matrix["layer_" + str(layer + 1)][idx], - xbar, - self.functype, - ) - dy2 = unaggregated_dy2_dx_s( - yy, - dy, - self.matrix["layer_" + str(layer + 1)][idx], - xbar, - self.functype, - ) - else: - ybar = torch.matmul( - yy, - torch.from_numpy(self.matrix["layer_" + str(layer + 1)][idx]).to( - env.DEVICE - ), - ) + torch.from_numpy(self.bias["layer_" + str(layer + 1)][idx]).to( - env.DEVICE - ) - if self.neuron[layer] == self.neuron[layer - 1]: - zz = ( - self._layer_0( - yy, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - + yy - ) - dz = unaggregated_dy_dx( - zz - yy, - self.matrix["layer_" + str(layer + 1)][idx], - dy, - ybar, - self.functype, - ) - dy2 = unaggregated_dy2_dx( - zz - yy, - self.matrix["layer_" + str(layer + 1)][idx], - dy, - dy2, - ybar, - self.functype, - ) - elif self.neuron[layer] == 2 * self.neuron[layer - 1]: - tt, zz = self._layer_1( - yy, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - dz = unaggregated_dy_dx( - zz - tt, - self.matrix["layer_" + str(layer + 1)][idx], - dy, - ybar, - self.functype, - ) - dy2 = unaggregated_dy2_dx( - zz - tt, - self.matrix["layer_" + str(layer + 1)][idx], - dy, - dy2, - ybar, - self.functype, - ) - else: - zz = self._layer_0( - yy, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - dz = unaggregated_dy_dx( - zz, - self.matrix["layer_" + str(layer + 1)][idx], - dy, - ybar, - self.functype, - ) - dy2 = unaggregated_dy2_dx( - zz, - self.matrix["layer_" + str(layer + 1)][idx], - dy, - dy2, - ybar, - self.functype, - ) - dy = dz - yy = zz - - vv = zz.detach().cpu().numpy().astype(self.data_type) - dd = dy.detach().cpu().numpy().astype(self.data_type) - d2 = dy2.detach().cpu().numpy().astype(self.data_type) - return vv, dd, d2 - - def _layer_0(self, x: torch.Tensor, w: np.ndarray, b: np.ndarray) -> torch.Tensor: - w = torch.from_numpy(w).to(env.DEVICE) - b = torch.from_numpy(b).to(env.DEVICE) - return self.activation_fn(torch.matmul(x, w) + b) - - def _layer_1(self, x: torch.Tensor, w: np.ndarray, b: np.ndarray) -> torch.Tensor: - w = torch.from_numpy(w).to(env.DEVICE) - b = torch.from_numpy(b).to(env.DEVICE) - t = torch.cat([x, x], dim=1) - return t, self.activation_fn(torch.matmul(x, w) + b) + t def _get_descrpt_type(self) -> str: + """Detect descriptor type via isinstance checks against PT classes.""" if isinstance( self.descrpt, ( @@ -312,334 +79,7 @@ def _get_descrpt_type(self) -> str: return "T_TEBD" raise RuntimeError(f"Unsupported descriptor {self.descrpt}") - def _get_layer_size(self) -> int: - # get the number of layers in EmbeddingNet - layer_size = 0 - basic_size = 0 - if self.type_one_side: - basic_size = len(self.embedding_net_nodes) * len(self.neuron) - else: - basic_size = ( - len(self.embedding_net_nodes) - * len(self.embedding_net_nodes[0]) - * len(self.neuron) - ) - if self.descrpt_type in ("Atten", "T_TEBD"): - layer_size = len(self.embedding_net_nodes[0]["layers"]) - elif self.descrpt_type == "A": - layer_size = len(self.embedding_net_nodes[0]["layers"]) - if self.type_one_side: - layer_size = basic_size // (self.ntypes - self._n_all_excluded) - elif self.descrpt_type == "T": - layer_size = len(self.embedding_net_nodes[0]["layers"]) - # layer_size = basic_size // int(comb(self.ntypes + 1, 2)) - elif self.descrpt_type == "R": - layer_size = basic_size // ( - self.ntypes * self.ntypes - len(self.exclude_types) - ) - if self.type_one_side: - layer_size = basic_size // (self.ntypes - self._n_all_excluded) - else: - raise RuntimeError("Unsupported descriptor") - return layer_size - - def _get_network_variable(self, var_name: str) -> dict: - """Get network variables (weights or biases) for all layers. - - Parameters - ---------- - var_name : str - Name of the variable to get ('w' for weights, 'b' for biases) - - Returns - ------- - dict - Dictionary mapping layer names to their variables - """ - result = {} - for layer in range(1, self.layer_size + 1): - result["layer_" + str(layer)] = [] - if self.descrpt_type == "Atten": - node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][ - var_name - ] - result["layer_" + str(layer)].append(node) - elif self.descrpt_type == "A": - if self.type_one_side: - for ii in range(0, self.ntypes): - if not self._all_excluded(ii): - node = self.embedding_net_nodes[ii]["layers"][layer - 1][ - "@variables" - ][var_name] - result["layer_" + str(layer)].append(node) - else: - result["layer_" + str(layer)].append(np.array([])) - else: - for ii in range(0, self.ntypes * self.ntypes): - if ( - ii // self.ntypes, - ii % self.ntypes, - ) not in self.exclude_types: - node = self.embedding_net_nodes[ - (ii % self.ntypes) * self.ntypes + ii // self.ntypes - ]["layers"][layer - 1]["@variables"][var_name] - result["layer_" + str(layer)].append(node) - else: - result["layer_" + str(layer)].append(np.array([])) - elif self.descrpt_type == "T": - for ii in range(self.ntypes): - for jj in range(ii, self.ntypes): - node = self.embedding_net_nodes[jj * self.ntypes + ii][ - "layers" - ][layer - 1]["@variables"][var_name] - result["layer_" + str(layer)].append(node) - elif self.descrpt_type == "T_TEBD": - # For the se_e3_tebd descriptor, a single, - # shared embedding network is used for all type pairs - node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][ - var_name - ] - result["layer_" + str(layer)].append(node) - elif self.descrpt_type == "R": - if self.type_one_side: - for ii in range(0, self.ntypes): - if not self._all_excluded(ii): - node = self.embedding_net_nodes[ii]["layers"][layer - 1][ - "@variables" - ][var_name] - result["layer_" + str(layer)].append(node) - else: - result["layer_" + str(layer)].append(np.array([])) - else: - for ii in range(0, self.ntypes * self.ntypes): - if ( - ii // self.ntypes, - ii % self.ntypes, - ) not in self.exclude_types: - node = self.embedding_net_nodes[ - (ii % self.ntypes) * self.ntypes + ii // self.ntypes - ]["layers"][layer - 1]["@variables"][var_name] - result["layer_" + str(layer)].append(node) - else: - result["layer_" + str(layer)].append(np.array([])) - else: - raise RuntimeError("Unsupported descriptor") - return result - - def _get_bias(self) -> Any: - return self._get_network_variable("b") - - def _get_matrix(self) -> Any: - return self._get_network_variable("w") - def _convert_numpy_to_tensor(self) -> None: """Convert self.data from np.ndarray to torch.Tensor.""" for ii in self.data: self.data[ii] = torch.tensor(self.data[ii], device=env.DEVICE) # pylint: disable=no-explicit-dtype - - @cached_property - def _n_all_excluded(self) -> int: - """Then number of types excluding all types.""" - return sum(int(self._all_excluded(ii)) for ii in range(0, self.ntypes)) - - -# customized op -def grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor: - if functype == 1: - return 1 - y * y - - elif functype == 2: - var = torch.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) - return ( - 0.5 * SQRT_2_PI * xbar * (1 - var**2) * (3 * GGELU * xbar**2 + 1) - + 0.5 * var - + 0.5 - ) - - elif functype == 3: - return torch.where(xbar > 0, torch.ones_like(xbar), torch.zeros_like(xbar)) - - elif functype == 4: - return torch.where( - (xbar > 0) & (xbar < 6), torch.ones_like(xbar), torch.zeros_like(xbar) - ) - - elif functype == 5: - return 1.0 - 1.0 / (1.0 + torch.exp(xbar)) - - elif functype == 6: - return y * (1 - y) - - elif functype == 7: - # silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) - sig = torch.sigmoid(xbar) - return sig + xbar * sig * (1 - sig) - - else: - raise ValueError(f"Unsupported function type: {functype}") - - -def grad_grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor: - if functype == 1: - return -2 * y * (1 - y * y) - - elif functype == 2: - var1 = torch.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) - var2 = SQRT_2_PI * (1 - var1**2) * (3 * GGELU * xbar**2 + 1) - return ( - 3 * GGELU * SQRT_2_PI * xbar**2 * (1 - var1**2) - - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar**2 + 1) * var1 - + var2 - ) - - elif functype in [3, 4]: - return torch.zeros_like(xbar) - - elif functype == 5: - exp_xbar = torch.exp(xbar) - return exp_xbar / ((1 + exp_xbar) * (1 + exp_xbar)) - - elif functype == 6: - return y * (1 - y) * (1 - 2 * y) - - elif functype == 7: - sig = torch.sigmoid(xbar) - d_sig = sig * (1 - sig) - # silu''(x) = 2 * d_sig + x * d_sig * (1 - 2 * sig) - return 2 * d_sig + xbar * d_sig * (1 - 2 * sig) - - else: - return -torch.ones_like(xbar) - - -def unaggregated_dy_dx_s( - y: torch.Tensor, w_np: np.ndarray, xbar: torch.Tensor, functype: int -) -> torch.Tensor: - w = torch.from_numpy(w_np).to(env.DEVICE) - y = y.to(env.DEVICE) - xbar = xbar.to(env.DEVICE) - if y.dim() != 2: - raise ValueError("Dim of input y should be 2") - if w.dim() != 2: - raise ValueError("Dim of input w should be 2") - if xbar.dim() != 2: - raise ValueError("Dim of input xbar should be 2") - - grad_xbar_y = grad(xbar, y, functype) - - w = torch.flatten(w)[: y.shape[1]].repeat(y.shape[0], 1) - - dy_dx = grad_xbar_y * w - - return dy_dx - - -def unaggregated_dy2_dx_s( - y: torch.Tensor, - dy: torch.Tensor, - w_np: np.ndarray, - xbar: torch.Tensor, - functype: int, -) -> torch.Tensor: - w = torch.from_numpy(w_np).to(env.DEVICE) - y = y.to(env.DEVICE) - dy = dy.to(env.DEVICE) - xbar = xbar.to(env.DEVICE) - if y.dim() != 2: - raise ValueError("Dim of input y should be 2") - if dy.dim() != 2: - raise ValueError("Dim of input dy should be 2") - if w.dim() != 2: - raise ValueError("Dim of input w should be 2") - if xbar.dim() != 2: - raise ValueError("Dim of input xbar should be 2") - - grad_grad_result = grad_grad(xbar, y, functype) - - w_flattened = torch.flatten(w)[: y.shape[1]].repeat(y.shape[0], 1) - - dy2_dx = grad_grad_result * w_flattened * w_flattened - - return dy2_dx - - -def unaggregated_dy_dx( - z: torch.Tensor, - w_np: np.ndarray, - dy_dx: torch.Tensor, - ybar: torch.Tensor, - functype: int, -) -> torch.Tensor: - w = torch.from_numpy(w_np).to(env.DEVICE) - if z.dim() != 2: - raise ValueError("z tensor must have 2 dimensions") - if w.dim() != 2: - raise ValueError("w tensor must have 2 dimensions") - if dy_dx.dim() != 2: - raise ValueError("dy_dx tensor must have 2 dimensions") - if ybar.dim() != 2: - raise ValueError("ybar tensor must have 2 dimensions") - - length, width = z.shape - size = w.shape[0] - - grad_ybar_z = grad(ybar, z, functype) - - dy_dx = dy_dx.view(-1)[: (length * size)].view(length, size) - - accumulator = dy_dx @ w - - dz_drou = grad_ybar_z * accumulator - - if width == size: - dz_drou += dy_dx - if width == 2 * size: - dy_dx = torch.cat((dy_dx, dy_dx), dim=1) - dz_drou += dy_dx - - return dz_drou - - -def unaggregated_dy2_dx( - z: torch.Tensor, - w_np: np.ndarray, - dy_dx: torch.Tensor, - dy2_dx: torch.Tensor, - ybar: torch.Tensor, - functype: int, -) -> torch.Tensor: - w = torch.from_numpy(w_np).to(env.DEVICE) - if z.dim() != 2: - raise ValueError("z tensor must have 2 dimensions") - if w.dim() != 2: - raise ValueError("w tensor must have 2 dimensions") - if dy_dx.dim() != 2: - raise ValueError("dy_dx tensor must have 2 dimensions") - if dy2_dx.dim() != 2: - raise ValueError("dy2_dx tensor must have 2 dimensions") - if ybar.dim() != 2: - raise ValueError("ybar tensor must have 2 dimensions") - - length, width = z.shape - size = w.shape[0] - - grad_ybar_z = grad(ybar, z, functype) - grad_grad_ybar_z = grad_grad(ybar, z, functype) - - dy2_dx = dy2_dx.view(-1)[: (length * size)].view(length, size) - dy_dx = dy_dx.view(-1)[: (length * size)].view(length, size) - - accumulator1 = dy2_dx @ w - accumulator2 = dy_dx @ w - - dz_drou = ( - grad_ybar_z * accumulator1 + grad_grad_ybar_z * accumulator2 * accumulator2 - ) - - if width == size: - dz_drou += dy2_dx - if width == 2 * size: - dy2_dx = torch.cat((dy2_dx, dy2_dx), dim=1) - dz_drou += dy2_dx - - return dz_drou From 5bd68f51625953fca842745bf882ef065b655802 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 15:19:48 +0800 Subject: [PATCH 03/10] refactor(pt_expt): inherit from shared tabulate_math, remove pt dependency DPTabulate now inherits from deepmd.utils.tabulate_math.DPTabulate instead of deepmd.pt.utils.tabulate.DPTabulate. Removed all imports of ActivationFn and DPTabulatePT from deepmd.pt. Descriptors now pass activation_fn_name (string) instead of ActivationFn (torch module). --- deepmd/pt_expt/descriptor/dpa1.py | 5 +-- deepmd/pt_expt/descriptor/dpa2.py | 5 +-- deepmd/pt_expt/descriptor/se_e2_a.py | 5 +-- deepmd/pt_expt/descriptor/se_r.py | 5 +-- deepmd/pt_expt/descriptor/se_t.py | 5 +-- deepmd/pt_expt/descriptor/se_t_tebd.py | 5 +-- deepmd/pt_expt/utils/tabulate.py | 51 ++++++++++++++------------ 7 files changed, 34 insertions(+), 47 deletions(-) diff --git a/deepmd/pt_expt/descriptor/dpa1.py b/deepmd/pt_expt/descriptor/dpa1.py index 7e08f99c12..d72a12267a 100644 --- a/deepmd/pt_expt/descriptor/dpa1.py +++ b/deepmd/pt_expt/descriptor/dpa1.py @@ -49,9 +49,6 @@ def enable_compression( check_frequency The overflow check frequency """ - from deepmd.pt.utils.utils import ( - ActivationFn, - ) from deepmd.pt_expt.utils.tabulate import ( DPTabulate, ) @@ -71,7 +68,7 @@ def enable_compression( data["neuron"], data["type_one_side"], data["exclude_types"], - ActivationFn(data["activation_function"]), + data["activation_function"], ) self.table_config = [ table_extrapolate, diff --git a/deepmd/pt_expt/descriptor/dpa2.py b/deepmd/pt_expt/descriptor/dpa2.py index 15f8c9e53f..0d389af070 100644 --- a/deepmd/pt_expt/descriptor/dpa2.py +++ b/deepmd/pt_expt/descriptor/dpa2.py @@ -57,9 +57,6 @@ def enable_compression( check_frequency The overflow check frequency """ - from deepmd.pt.utils.utils import ( - ActivationFn, - ) from deepmd.pt_expt.utils.tabulate import ( DPTabulate, ) @@ -105,7 +102,7 @@ def enable_compression( repinit_data["neuron"], repinit_data.get("type_one_side", False), repinit_data.get("exclude_types", []), - ActivationFn(repinit_data["activation_function"]), + repinit_data["activation_function"], ) self.table_config = [ table_extrapolate, diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 2c6c506162..38be83c46c 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -34,9 +34,6 @@ def enable_compression( table_stride_2: float = 0.1, check_frequency: int = -1, ) -> None: - from deepmd.pt.utils.utils import ( - ActivationFn, - ) from deepmd.pt_expt.utils.tabulate import ( DPTabulate, ) @@ -49,7 +46,7 @@ def enable_compression( data["neuron"], data["type_one_side"], data["exclude_types"], - ActivationFn(data["activation_function"]), + data["activation_function"], ) self.table_config = [ table_extrapolate, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index c92e9d2645..c2fd34e6b5 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -34,9 +34,6 @@ def enable_compression( table_stride_2: float = 0.1, check_frequency: int = -1, ) -> None: - from deepmd.pt.utils.utils import ( - ActivationFn, - ) from deepmd.pt_expt.utils.tabulate import ( DPTabulate, ) @@ -49,7 +46,7 @@ def enable_compression( data["neuron"], data["type_one_side"], data["exclude_types"], - ActivationFn(data["activation_function"]), + data["activation_function"], ) self.table_config = [ table_extrapolate, diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index b5e923394d..806d5eca7a 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -35,9 +35,6 @@ def enable_compression( table_stride_2: float = 0.1, check_frequency: int = -1, ) -> None: - from deepmd.pt.utils.utils import ( - ActivationFn, - ) from deepmd.pt_expt.utils.tabulate import ( DPTabulate, ) @@ -49,7 +46,7 @@ def enable_compression( self, data["neuron"], exclude_types=data["exclude_types"], - activation_fn=ActivationFn(data["activation_function"]), + activation_fn_name=data["activation_function"], ) # SE_T scales strides by 10 stride_1_scaled = table_stride_1 * 10 diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index a619ae64d5..385bf0dfb6 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -48,9 +48,6 @@ def enable_compression( check_frequency The overflow check frequency """ - from deepmd.pt.utils.utils import ( - ActivationFn, - ) from deepmd.pt_expt.utils.tabulate import ( DPTabulate, ) @@ -65,7 +62,7 @@ def enable_compression( self, data["neuron"], exclude_types=data["exclude_types"], - activation_fn=ActivationFn(data["activation_function"]), + activation_fn_name=data["activation_function"], ) # SE_T scales strides by 10 stride_1_scaled = table_stride_1 * 10 diff --git a/deepmd/pt_expt/utils/tabulate.py b/deepmd/pt_expt/utils/tabulate.py index e6c3d7ebe0..37ea0b3d26 100644 --- a/deepmd/pt_expt/utils/tabulate.py +++ b/deepmd/pt_expt/utils/tabulate.py @@ -1,22 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """DPTabulate for the pt_expt backend. -Subclasses the pt backend's DPTabulate, overriding _get_descrpt_type() to -detect descriptor types via serialized data rather than isinstance checks -against pt-specific classes. +Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate`` +and overrides ``_convert_numpy_to_tensor`` for torch tensor conversion +and ``_get_descrpt_type`` for serialization-based type detection. +No dependency on the pt backend. """ from typing import ( Any, ) -from deepmd.pt.utils.tabulate import DPTabulate as DPTabulatePT -from deepmd.pt.utils.utils import ( - ActivationFn, -) +from deepmd.utils.tabulate_math import DPTabulate as DPTabulateBase -class DPTabulate(DPTabulatePT): +class DPTabulate(DPTabulateBase): """Tabulation helper for pt_expt descriptors. The descriptor passed to this class must serialize to a dict with @@ -34,8 +32,8 @@ class DPTabulate(DPTabulatePT): Whether to use one-side type embedding. exclude_types Excluded type pairs. - activation_fn - The activation function used in the embedding net. + activation_fn_name + Name of the activation function (e.g. "tanh", "gelu"). """ def __init__( @@ -44,23 +42,20 @@ def __init__( neuron: list[int], type_one_side: bool = False, exclude_types: list[list[int]] = [], - activation_fn: ActivationFn = ActivationFn("tanh"), + activation_fn_name: str = "tanh", ) -> None: - # DPTabulatePT.__init__ works here because: - # 1. _get_descrpt_type is overridden to use serialized data (not isinstance) - # 2. The isinstance(descrpt, DescrptDPA2) check in parent just returns False - # for pt_expt descriptors — callers pass the repinit block directly. - super().__init__(descrpt, neuron, type_one_side, exclude_types, activation_fn) + super().__init__( + descrpt, + neuron, + type_one_side, + exclude_types, + activation_fn_name=activation_fn_name, + ) def _get_descrpt_type(self) -> str: - """Determine descriptor type from serialized data. - - Instead of isinstance checks against pt classes, use the "type" key - from the serialized descriptor dict. - """ + """Determine descriptor type from serialized data.""" data = self.descrpt.serialize() type_str = data.get("type", "") - type_map = { "se_e2_a": "A", "se_r": "R", @@ -69,8 +64,18 @@ def _get_descrpt_type(self) -> str: "dpa1": "Atten", "se_atten_v2": "Atten", } - descrpt_type = type_map.get(type_str) if descrpt_type is None: raise RuntimeError(f"Unsupported descriptor type: {type_str}") return descrpt_type + + def _convert_numpy_to_tensor(self) -> None: + """Convert self.data from np.ndarray to torch.Tensor.""" + import torch + + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + for ii in self.data: + self.data[ii] = torch.tensor(self.data[ii], device=DEVICE) # pylint: disable=no-explicit-dtype From 56bdcf77a96e9b6bccd7f321f807cb397b31f920 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 15:40:16 +0800 Subject: [PATCH 04/10] refactor: remove is_pt flag from BaseTabulate Normalize lower/upper to per-type scalars after _get_env_mat_range so build() no longer needs is_pt branching. Move _convert_numpy_float_to_int into subclass _convert_numpy_to_tensor. Remove is_pt parameter from BaseTabulate and all callers. --- deepmd/pt/utils/tabulate.py | 1 + deepmd/pt_expt/utils/tabulate.py | 1 + deepmd/tf/utils/tabulate.py | 1 - deepmd/utils/tabulate.py | 36 +++++++++++--------------------- deepmd/utils/tabulate_math.py | 1 - 5 files changed, 14 insertions(+), 26 deletions(-) diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index 10f3938380..796764b1d1 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -81,5 +81,6 @@ def _get_descrpt_type(self) -> str: def _convert_numpy_to_tensor(self) -> None: """Convert self.data from np.ndarray to torch.Tensor.""" + self._convert_numpy_float_to_int() for ii in self.data: self.data[ii] = torch.tensor(self.data[ii], device=env.DEVICE) # pylint: disable=no-explicit-dtype diff --git a/deepmd/pt_expt/utils/tabulate.py b/deepmd/pt_expt/utils/tabulate.py index 37ea0b3d26..49a3f4c222 100644 --- a/deepmd/pt_expt/utils/tabulate.py +++ b/deepmd/pt_expt/utils/tabulate.py @@ -77,5 +77,6 @@ def _convert_numpy_to_tensor(self) -> None: DEVICE, ) + self._convert_numpy_float_to_int() for ii in self.data: self.data[ii] = torch.tensor(self.data[ii], device=DEVICE) # pylint: disable=no-explicit-dtype diff --git a/deepmd/tf/utils/tabulate.py b/deepmd/tf/utils/tabulate.py index 98c023299f..614d18d9d8 100644 --- a/deepmd/tf/utils/tabulate.py +++ b/deepmd/tf/utils/tabulate.py @@ -78,7 +78,6 @@ def __init__( neuron, type_one_side, exclude_types, - False, ) self.descrpt_type = self._get_descrpt_type() diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py index fb40f798e2..92497a52f7 100644 --- a/deepmd/utils/tabulate.py +++ b/deepmd/utils/tabulate.py @@ -28,7 +28,6 @@ def __init__( neuron: list[int], type_one_side: bool, exclude_types: list[list[int]], - is_pt: bool, ) -> None: """Constructor.""" super().__init__() @@ -38,7 +37,6 @@ def __init__( self.neuron = neuron self.type_one_side = type_one_side self.exclude_types = exclude_types - self.is_pt = is_pt """Need to be initialized in the subclass.""" self.descrpt_type = "Base" @@ -91,6 +89,11 @@ def build( """ # tabulate range [lower, upper] with stride0 'stride0' lower, upper = self._get_env_mat_range(min_nbor_dist) + # Normalize to per-type scalars: PT serialized data produces + # multi-dimensional arrays (ntypes, nnei) while TF produces 1D. + if lower.ndim > 1: + lower = np.min(lower, axis=tuple(range(1, lower.ndim))) + upper = np.max(upper, axis=tuple(range(1, upper.ndim))) if self.descrpt_type in ("Atten", "AEbdV2"): uu = np.max(upper) ll = np.min(lower) @@ -127,12 +130,8 @@ def build( net = ( "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) ) - if self.is_pt: - uu = np.max(upper[ielement]) - ll = np.min(lower[ielement]) - else: - uu = upper[ielement] - ll = lower[ielement] + uu = upper[ielement] + ll = lower[ielement] xx = np.arange(ll, uu, stride0, dtype=self.data_type) xx = np.append( xx, @@ -150,13 +149,8 @@ def build( elif self.descrpt_type == "T": xx_all = [] for ii in range(self.ntypes): - """Pt and tf is different here. Pt version is a two-dimensional array.""" - if self.is_pt: - uu = np.max(upper[ii]) - ll = np.min(lower[ii]) - else: - ll = lower[ii] - uu = upper[ii] + ll = lower[ii] + uu = upper[ii] xx = np.arange(extrapolate * ll, ll, stride1, dtype=self.data_type) xx = np.append(xx, np.arange(ll, uu, stride0, dtype=self.data_type)) xx = np.append( @@ -176,12 +170,8 @@ def build( ).astype(int) idx = 0 for ii in range(self.ntypes): - if self.is_pt: - uu = np.max(upper[ii]) - ll = np.min(lower[ii]) - else: - ll = lower[ii] - uu = upper[ii] + ll = lower[ii] + uu = upper[ii] for jj in range(ii, self.ntypes): net = "filter_" + str(ii) + "_net_" + str(jj) self._build_lower( @@ -193,7 +183,7 @@ def build( stride0, stride1, extrapolate, - nspline[ii][0] if self.is_pt else nspline[ii], + nspline[ii], ) idx += 1 elif self.descrpt_type == "T_TEBD": @@ -279,8 +269,6 @@ def build( raise RuntimeError("Unsupported descriptor") self._convert_numpy_to_tensor() - if self.is_pt: - self._convert_numpy_float_to_int() return self.lower, self.upper # generate_spline_table diff --git a/deepmd/utils/tabulate_math.py b/deepmd/utils/tabulate_math.py index a279ee67e0..580e9b8a5f 100644 --- a/deepmd/utils/tabulate_math.py +++ b/deepmd/utils/tabulate_math.py @@ -256,7 +256,6 @@ def __init__( neuron, type_one_side, exclude_types, - True, # is_pt flag (for _build_lower numpy int conversion) ) self._activation_fn = get_activation_fn(activation_fn_name) activation_fn_name = activation_fn_name.lower() From cb4a12280e050d099af74113f14aaa5530f956e9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 18:24:21 +0800 Subject: [PATCH 05/10] fix: re-export derivative functions from pt/tabulate for test compatibility --- deepmd/pt/utils/tabulate.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index 796764b1d1..f9add53f59 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -20,6 +20,12 @@ ActivationFn, ) from deepmd.utils.tabulate_math import DPTabulate as DPTabulateBase +from deepmd.utils.tabulate_math import ( # noqa: F401 — re-export for test compatibility + unaggregated_dy2_dx, + unaggregated_dy2_dx_s, + unaggregated_dy_dx, + unaggregated_dy_dx_s, +) class DPTabulate(DPTabulateBase): From 1067ff068910a2a0a9293722e4cfabe57ade10d8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 18:34:38 +0800 Subject: [PATCH 06/10] fix(tests): update test_tabulate to use numpy arrays instead of torch tensors The derivative functions now return numpy arrays from tabulate_math. Update test inputs/outputs accordingly. Remove unused torch/env imports. --- source/tests/pt/test_tabulate.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py index d6075c3a74..364de166fa 100644 --- a/source/tests/pt/test_tabulate.py +++ b/source/tests/pt/test_tabulate.py @@ -2,14 +2,10 @@ import unittest import numpy as np -import torch from deepmd.dpmodel.utils.network import ( get_activation_fn, ) -from deepmd.pt.utils import ( - env, -) from deepmd.pt.utils.tabulate import ( unaggregated_dy2_dx, unaggregated_dy2_dx_s, @@ -90,14 +86,14 @@ def _test_single_activation( ) dy_pt = unaggregated_dy_dx_s( - torch.from_numpy(y), + y, self.w, - torch.from_numpy(self.xbar), + self.xbar, functype, ) dy_tf_numpy = dy_tf.numpy() - dy_pt_numpy = dy_pt.detach().cpu().numpy() + dy_pt_numpy = np.asarray(dy_pt) np.testing.assert_almost_equal( dy_tf_numpy, @@ -116,15 +112,15 @@ def _test_single_activation( ) dy2_pt = unaggregated_dy2_dx_s( - torch.from_numpy(y), + y, dy_pt, self.w, - torch.from_numpy(self.xbar), + self.xbar, functype, ) dy2_tf_numpy = dy2_tf.numpy() - dy2_pt_numpy = dy2_pt.detach().cpu().numpy() + dy2_pt_numpy = np.asarray(dy2_pt) np.testing.assert_almost_equal( dy2_tf_numpy, @@ -143,15 +139,15 @@ def _test_single_activation( ) dz_pt = unaggregated_dy_dx( - torch.from_numpy(y).to(env.DEVICE), + y, self.w, dy_pt, - torch.from_numpy(self.xbar).to(env.DEVICE), + self.xbar, functype, ) dz_tf_numpy = dz_tf.numpy() - dz_pt_numpy = dz_pt.detach().cpu().numpy() + dz_pt_numpy = np.asarray(dz_pt) np.testing.assert_almost_equal( dz_tf_numpy, @@ -171,16 +167,16 @@ def _test_single_activation( ) dy2_pt = unaggregated_dy2_dx( - torch.from_numpy(y).to(env.DEVICE), + y, self.w, dy_pt, dy2_pt, - torch.from_numpy(self.xbar).to(env.DEVICE), + self.xbar, functype, ) dy2_tf_numpy = dy2_tf.numpy() - dy2_pt_numpy = dy2_pt.detach().cpu().numpy() + dy2_pt_numpy = np.asarray(dy2_pt) np.testing.assert_almost_equal( dy2_tf_numpy, From 28cd4a2078c766e284df113bac77324a65287a67 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 21:57:38 +0800 Subject: [PATCH 07/10] fix: address reviewer comments on tabulate_math - Fix UnboundLocalError: use yy instead of zz for single-layer nets - Add linear/none (functype=0) to ACTIVATION_TO_FUNCTYPE and grad/grad_grad (identity: f'=1, f''=0) - All 8 activation derivatives verified against numerical differentiation --- deepmd/utils/tabulate_math.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/deepmd/utils/tabulate_math.py b/deepmd/utils/tabulate_math.py index 580e9b8a5f..a1d1393850 100644 --- a/deepmd/utils/tabulate_math.py +++ b/deepmd/utils/tabulate_math.py @@ -39,6 +39,8 @@ "softplus": 5, "sigmoid": 6, "silu": 7, + "linear": 0, + "none": 0, } @@ -47,7 +49,9 @@ def grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: """First derivative of the activation function.""" - if functype == 1: + if functype == 0: + return np.ones_like(xbar) + elif functype == 1: return 1 - y * y elif functype == 2: var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) @@ -75,7 +79,9 @@ def grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: def grad_grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: """Second derivative of the activation function.""" - if functype == 1: + if functype == 0: + return np.zeros_like(xbar) + elif functype == 1: return -2 * y * (1 - y * y) elif functype == 2: var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) @@ -435,7 +441,7 @@ def _make_data(self, xx: np.ndarray, idx: int) -> Any: dy = dz yy = zz - vv = zz.astype(self.data_type) + vv = yy.astype(self.data_type) dd = dy.astype(self.data_type) d2 = dy2.astype(self.data_type) return vv, dd, d2 From 9b3661a3300d172e2abc3c71b40c9dee0a213955 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 25 Mar 2026 23:16:01 +0800 Subject: [PATCH 08/10] test: add functype=0 (linear) test for tabulate derivatives --- source/tests/pt/test_tabulate.py | 37 ++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py index 364de166fa..70e41d0e13 100644 --- a/source/tests/pt/test_tabulate.py +++ b/source/tests/pt/test_tabulate.py @@ -27,6 +27,12 @@ 7: "silu", } +# functype=0 (linear/none) is not supported by TF custom ops, +# so we test it separately against numerical derivatives. +ACTIVATION_NAMES_NUMPY_ONLY = { + 0: "linear", +} + def get_activation_function(functype: int): """Get activation function corresponding to functype.""" @@ -185,6 +191,37 @@ def _test_single_activation( err_msg=f"unaggregated_dy2_dx failed for {activation_name}", ) + def test_linear_activation(self) -> None: + """Test functype=0 (linear/none) against numerical derivatives. + + TF custom ops don't support functype=0, so we validate against + finite-difference derivatives instead. + """ + from deepmd.utils.tabulate_math import ( + grad, + grad_grad, + ) + + fn = get_activation_fn("linear") + y = fn(self.xbar) + h = 1e-7 + + # grad: f'(x) = 1 for identity + dy_ana = grad(self.xbar, y, 0) + np.testing.assert_allclose(dy_ana, np.ones_like(self.xbar), atol=1e-12) + + # grad_grad: f''(x) = 0 for identity + dy2_ana = grad_grad(self.xbar, y, 0) + np.testing.assert_allclose(dy2_ana, np.zeros_like(self.xbar), atol=1e-12) + + # Also verify unaggregated functions work with functype=0 + dy = unaggregated_dy_dx_s(y, self.w, self.xbar, 0) + self.assertEqual(dy.shape, (4, 4)) + + dy2 = unaggregated_dy2_dx_s(y, dy, self.w, self.xbar, 0) + # Second derivative of identity is zero everywhere + np.testing.assert_allclose(dy2, np.zeros_like(dy2), atol=1e-12) + if __name__ == "__main__": unittest.main() From 3686686c915ee26ba92352c141aeee036a705280 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 26 Mar 2026 00:19:10 +0800 Subject: [PATCH 09/10] fix(tabulate): address verified reviewer feedback --- deepmd/pt/utils/tabulate.py | 14 ++++++++++---- deepmd/pt_expt/utils/tabulate.py | 3 ++- deepmd/utils/tabulate_math.py | 5 +++-- source/tests/pt/test_tabulate.py | 13 +++---------- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index f9add53f59..930292db58 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -46,7 +46,7 @@ class DPTabulate(DPTabulateBase): exclude_types Excluded type pairs. activation_fn - The activation function (PT ActivationFn module). + The activation function name or PT ``ActivationFn`` module. """ def __init__( @@ -54,15 +54,21 @@ def __init__( descrpt: Any, neuron: list[int], type_one_side: bool = False, - exclude_types: list[list[int]] = [], - activation_fn: ActivationFn = ActivationFn("tanh"), + exclude_types: list[list[int]] | None = None, + activation_fn: str | ActivationFn = "tanh", ) -> None: + exclude_types = [] if exclude_types is None else exclude_types + activation_fn_name = ( + activation_fn.activation + if isinstance(activation_fn, ActivationFn) + else str(activation_fn) + ) super().__init__( descrpt, neuron, type_one_side, exclude_types, - activation_fn_name=activation_fn.activation, + activation_fn_name=activation_fn_name, ) def _get_descrpt_type(self) -> str: diff --git a/deepmd/pt_expt/utils/tabulate.py b/deepmd/pt_expt/utils/tabulate.py index 49a3f4c222..4cda06d9b9 100644 --- a/deepmd/pt_expt/utils/tabulate.py +++ b/deepmd/pt_expt/utils/tabulate.py @@ -41,9 +41,10 @@ def __init__( descrpt: Any, neuron: list[int], type_one_side: bool = False, - exclude_types: list[list[int]] = [], + exclude_types: list[list[int]] | None = None, activation_fn_name: str = "tanh", ) -> None: + exclude_types = [] if exclude_types is None else exclude_types super().__init__( descrpt, neuron, diff --git a/deepmd/utils/tabulate_math.py b/deepmd/utils/tabulate_math.py index a1d1393850..555c9a244b 100644 --- a/deepmd/utils/tabulate_math.py +++ b/deepmd/utils/tabulate_math.py @@ -103,7 +103,7 @@ def grad_grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: d_sig = sig * (1 - sig) return 2 * d_sig + xbar * d_sig * (1 - 2 * sig) else: - return -np.ones_like(xbar) + raise ValueError(f"Unsupported function type: {functype}") # ---- Chain-rule derivative propagation (numpy) ---- @@ -254,9 +254,10 @@ def __init__( descrpt: Any, neuron: list[int], type_one_side: bool = False, - exclude_types: list[list[int]] = [], + exclude_types: list[list[int]] | None = None, activation_fn_name: str = "tanh", ) -> None: + exclude_types = [] if exclude_types is None else exclude_types super().__init__( descrpt, neuron, diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py index 70e41d0e13..185d8cd04c 100644 --- a/source/tests/pt/test_tabulate.py +++ b/source/tests/pt/test_tabulate.py @@ -27,12 +27,6 @@ 7: "silu", } -# functype=0 (linear/none) is not supported by TF custom ops, -# so we test it separately against numerical derivatives. -ACTIVATION_NAMES_NUMPY_ONLY = { - 0: "linear", -} - def get_activation_function(functype: int): """Get activation function corresponding to functype.""" @@ -192,10 +186,10 @@ def _test_single_activation( ) def test_linear_activation(self) -> None: - """Test functype=0 (linear/none) against numerical derivatives. + """Test functype=0 (linear/none) with direct numpy expectations. - TF custom ops don't support functype=0, so we validate against - finite-difference derivatives instead. + TF custom ops don't support functype=0, so we validate the numpy + derivative helpers and unaggregated tabulate ops directly. """ from deepmd.utils.tabulate_math import ( grad, @@ -204,7 +198,6 @@ def test_linear_activation(self) -> None: fn = get_activation_fn("linear") y = fn(self.xbar) - h = 1e-7 # grad: f'(x) = 1 for identity dy_ana = grad(self.xbar, y, 0) From 35f1777605eca75b1f85f0f7792cf13c8e2118d8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 26 Mar 2026 00:37:19 +0800 Subject: [PATCH 10/10] fix(tabulate): harden shared softplus tabulation numerics --- deepmd/dpmodel/utils/network.py | 16 +++++++++--- deepmd/utils/tabulate_math.py | 17 +++++++++--- source/tests/pt/test_tabulate.py | 45 ++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index da6a305b9b..ff46c5787d 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -44,6 +44,18 @@ def sigmoid_t(x): # noqa: ANN001, ANN201 return xp_sigmoid(x) +def softplus_t(x): # noqa: ANN001, ANN201 + """Numerically stable softplus.""" + xp = array_api_compat.array_namespace(x) + positive = x > 0 + exp_neg_abs = xp.exp(xp.where(positive, -x, x)) + return xp.where( + positive, + x + xp.log(1 + exp_neg_abs), + xp.log(1 + exp_neg_abs), + ) + + class Identity(NativeOP): def __init__(self) -> None: super().__init__() @@ -330,9 +342,7 @@ def fn(x): # noqa: ANN001, ANN202 elif activation_function == "softplus": def fn(x): # noqa: ANN001, ANN202 - xp = array_api_compat.array_namespace(x) - # generated by GitHub Copilot - return xp.log(1 + xp.exp(x)) + return softplus_t(x) return fn elif activation_function == "sigmoid": diff --git a/deepmd/utils/tabulate_math.py b/deepmd/utils/tabulate_math.py index 555c9a244b..67b56b311e 100644 --- a/deepmd/utils/tabulate_math.py +++ b/deepmd/utils/tabulate_math.py @@ -47,6 +47,17 @@ # ---- Activation derivatives (numpy) ---- +def _stable_sigmoid(xbar: np.ndarray) -> np.ndarray: + """Compute sigmoid without overflow for large-magnitude inputs.""" + positive = xbar >= 0 + exp_neg_abs = np.exp(np.where(positive, -xbar, xbar)) + return np.where( + positive, + 1.0 / (1.0 + exp_neg_abs), + exp_neg_abs / (1.0 + exp_neg_abs), + ) + + def grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: """First derivative of the activation function.""" if functype == 0: @@ -67,7 +78,7 @@ def grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: (xbar > 0) & (xbar < 6), np.ones_like(xbar), np.zeros_like(xbar) ) elif functype == 5: - return 1.0 - 1.0 / (1.0 + np.exp(xbar)) + return _stable_sigmoid(xbar) elif functype == 6: return y * (1 - y) elif functype == 7: @@ -94,8 +105,8 @@ def grad_grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: elif functype in [3, 4]: return np.zeros_like(xbar) elif functype == 5: - exp_xbar = np.exp(xbar) - return exp_xbar / ((1 + exp_xbar) * (1 + exp_xbar)) + sig = _stable_sigmoid(xbar) + return sig * (1 - sig) elif functype == 6: return y * (1 - y) * (1 - 2 * y) elif functype == 7: diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py index 185d8cd04c..1608d363a9 100644 --- a/source/tests/pt/test_tabulate.py +++ b/source/tests/pt/test_tabulate.py @@ -215,6 +215,51 @@ def test_linear_activation(self) -> None: # Second derivative of identity is zero everywhere np.testing.assert_allclose(dy2, np.zeros_like(dy2), atol=1e-12) + def test_softplus_activation_is_numerically_stable(self) -> None: + """Test softplus tabulation helpers on large extrapolated inputs.""" + from deepmd.utils.tabulate_math import ( + grad, + grad_grad, + ) + + xbar = np.array([[100.0, 500.0, 1000.0]], dtype=np.float64) + + with np.errstate(over="raise", invalid="raise"): + y = get_activation_fn("softplus")(xbar) + dy = grad(xbar, y, 5) + dy2 = grad_grad(xbar, y, 5) + + np.testing.assert_allclose(y, xbar, atol=1e-12) + np.testing.assert_allclose(dy, np.ones_like(xbar), atol=1e-12) + np.testing.assert_allclose(dy2, np.zeros_like(xbar), atol=1e-12) + + def test_softplus_derivatives_match_finite_differences(self) -> None: + """Test softplus derivatives against finite differences on both branches.""" + from deepmd.utils.tabulate_math import ( + grad, + grad_grad, + ) + + fn = get_activation_fn("softplus") + xbar = np.array([[-5.0, -0.5, 0.0, 0.5, 5.0]], dtype=np.float64) + y = fn(xbar) + + dy = grad(xbar, y, 5) + dy2 = grad_grad(xbar, y, 5) + + h_grad = 3e-5 + y_plus = fn(xbar + h_grad) + y_minus = fn(xbar - h_grad) + dy_fd = (y_plus - y_minus) / (2 * h_grad) + + h_grad2 = 3e-4 + y_plus = fn(xbar + h_grad2) + y_minus = fn(xbar - h_grad2) + dy2_fd = (y_plus - 2 * y + y_minus) / (h_grad2**2) + + np.testing.assert_allclose(dy, dy_fd, rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(dy2, dy2_fd, rtol=1e-6, atol=1e-8) + if __name__ == "__main__": unittest.main()