Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions deepmd/dpmodel/utils/learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,180 @@ def _decay_value(self, step: int | Array) -> Array:
return step_lr


@BaseLR.register("wsd")
class LearningRateWSD(BaseLR):
r"""
Warmup-stable-decay learning rate schedule with configurable decay rules.

The schedule uses the shared warmup implementation from :class:`BaseLR`,
then keeps the learning rate at ``start_lr`` during the stable phase, and
finally applies one of the supported decay rules.

Let :math:`\tau \in [0, 1]` denote the normalized progress within the
decay phase.

**Inverse-linear mode (``decay_type="inverse_linear"``):**

.. math::

lr(t) = \frac{1}{
\tau / lr_{\text{stop}} + (1 - \tau) / lr_0
}

**Cosine mode (``decay_type="cosine"``):**

.. math::

lr(t) = lr_{\text{stop}} +
\frac{lr_0 - lr_{\text{stop}}}{2}
\left(1 + \cos(\pi \tau)\right)

**Linear mode (``decay_type="linear"``):**

.. math::

lr(t) = lr_0 + \left(lr_{\text{stop}} - lr_0\right)\tau
"""

def __init__(
    self,
    start_lr: float,
    num_steps: int,
    stop_lr: float | None = None,
    stop_lr_ratio: float | None = None,
    warmup_steps: int = 0,
    warmup_ratio: float | None = None,
    warmup_start_factor: float = 0.0,
    decay_phase_ratio: float = 0.1,
    decay_type: str = "inverse_linear",
    **kwargs: Any,
) -> None:
    """
    Construct a warmup-stable-decay learning rate schedule.

    Parameters
    ----------
    start_lr : float
        The learning rate at the start of the stable phase.
    num_steps : int
        The total training steps (including warmup).
    stop_lr : float, optional
        The final learning rate at the end of training.
        Mutually exclusive with stop_lr_ratio.
    stop_lr_ratio : float, optional
        The ratio of stop_lr to start_lr.
        Mutually exclusive with stop_lr.
    warmup_steps : int, optional
        The number of warmup steps.
        Mutually exclusive with warmup_ratio. Default is 0.
    warmup_ratio : float, optional
        The ratio of warmup steps to total training steps.
        Mutually exclusive with warmup_steps.
    warmup_start_factor : float, optional
        The factor of start_lr for the initial warmup learning rate.
        Default is 0.0.
    decay_phase_ratio : float, optional
        The ratio of the decay phase to total training steps.
        Default is 0.1.
    decay_type : str, optional
        The decay rule used in the decay phase.
        Supported values are ``inverse_linear``, ``cosine`` and ``linear``.
        Default is ``inverse_linear``.

    Raises
    ------
    ValueError
        If the learning rates are non-positive.
        If decay_phase_ratio is not in (0, 1].
        If decay_type is invalid.

    Notes
    -----
    When there are no post-warmup steps (``decay_num_steps == 0``), the
    schedule degenerates to a constant ``start_lr`` instead of raising.
    """
    super().__init__(
        start_lr=start_lr,
        stop_lr=stop_lr,
        stop_lr_ratio=stop_lr_ratio,
        num_steps=num_steps,
        warmup_steps=warmup_steps,
        warmup_ratio=warmup_ratio,
        warmup_start_factor=warmup_start_factor,
        **kwargs,
    )

    # === Validate WSD-specific invariants ===
    if self._start_lr <= 0:
        raise ValueError(f"start_lr ({self._start_lr}) must be positive.")
    if self.stop_lr <= 0:
        raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.")
    if decay_phase_ratio <= 0 or decay_phase_ratio > 1:
        raise ValueError(
            f"decay_phase_ratio ({decay_phase_ratio}) must be in (0, 1]."
        )
    if decay_type not in ("inverse_linear", "cosine", "linear"):
        raise ValueError(
            "decay_type must be one of "
            f"{('inverse_linear', 'cosine', 'linear')}. "
            f"Got decay_type={decay_type}."
        )

    # === Derive stable and decay phase lengths ===
    self.decay_phase_ratio = decay_phase_ratio
    self.decay_type = decay_type
    if self.decay_num_steps == 0:
        # Zero-decay run (all steps consumed by warmup, or num_steps == 0).
        # Without this guard the clamp below forces decay_phase_steps to 1,
        # which makes stable_steps negative and causes _decay_value(0) to
        # return stop_lr instead of start_lr.
        # NOTE(review): presumably decay_num_steps is set by BaseLR as
        # num_steps - warmup_steps -- confirm against BaseLR.
        self.decay_phase_steps = 0
        self.stable_steps = 0
        return
    # Clamp decay_phase_steps to valid range [1, decay_num_steps]
    self.decay_phase_steps = max(
        1, min(int(self.decay_phase_ratio * self.num_steps), self.decay_num_steps)
    )
    self.stable_steps = self.decay_num_steps - self.decay_phase_steps
Comment on lines +514 to +518
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Handle num_steps == 0 before forcing a one-step decay.

Lines 515-517 turn decay_num_steps == 0 into decay_phase_steps == 1, which makes stable_steps negative and causes Lines 565-566 to return stop_lr immediately for value(0). LearningRateExp and LearningRateCosine already special-case zero-decay runs, so WSD should do the same or reject num_steps == 0 explicitly.

💡 Possible fix
         # === Derive stable and decay phase lengths ===
         self.decay_phase_ratio = decay_phase_ratio
         self.decay_type = decay_type
+        if self.decay_num_steps == 0:
+            self.decay_phase_steps = 0
+            self.stable_steps = 0
+            return
         # Clamp decay_phase_steps to valid range [1, decay_num_steps]
         self.decay_phase_steps = max(
             1, min(int(self.decay_phase_ratio * self.num_steps), self.decay_num_steps)
         )
         self.stable_steps = self.decay_num_steps - self.decay_phase_steps
@@
         step_dtype = (
             step.dtype
             if xp.isdtype(step.dtype, "real floating")
             else get_xp_precision(xp, "global")
         )
+        if self.decay_num_steps == 0:
+            return xp.full_like(step, self._start_lr, dtype=step_dtype)

Also applies to: 565-566

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/utils/learning_rate.py` around lines 514 - 518, The WSD
scheduler currently forces decay_phase_steps to at least 1 even when
decay_num_steps == 0, producing negative stable_steps and making value(0) return
stop_lr; modify the LearningRateWSD initialization to special-case zero-decay
runs (or explicitly reject num_steps == 0): if self.decay_num_steps == 0 then
set self.decay_phase_steps = 0 and self.stable_steps = 0 (and ensure
value(index) treats all indices as pre-decay/start_lr), or raise a ValueError
when self.num_steps == 0; update references to decay_phase_steps,
decay_num_steps, stable_steps and the value(...) method so zero-decay behavior
matches LearningRateExp/LearningRateCosine.


def _decay_value(self, step: int | Array) -> Array:
    """
    Get the warmup-stable-decay learning rate at the given step.

    Parameters
    ----------
    step : int or Array
        The step index relative to the end of warmup.

    Returns
    -------
    Array
        The learning rate (absolute value).
    """
    if not array_api_compat.is_array_api_obj(step):
        step = np.asarray(step)
    xp = array_api_compat.array_namespace(step)
    step_dtype = (
        step.dtype
        if xp.isdtype(step.dtype, "real floating")
        else get_xp_precision(xp, "global")
    )
    # Zero-decay run (no post-warmup steps): keep a constant start_lr.
    # Without this guard, decay_phase_steps may be 0 (division by zero
    # below) or stable_steps may be negative, returning stop_lr at step 0.
    if self.decay_num_steps == 0:
        return xp.full_like(step, self._start_lr, dtype=step_dtype)

    # === Step 1. Build typed scalar constants ===
    typed_step = xp.astype(step, step_dtype)
    zero = xp.asarray(0.0, dtype=step_dtype)
    one = xp.asarray(1.0, dtype=step_dtype)
    start_lr = xp.asarray(self._start_lr, dtype=step_dtype)
    stop_lr = xp.asarray(self.stop_lr, dtype=step_dtype)
    stable_steps = xp.asarray(self.stable_steps, dtype=step_dtype)
    decay_phase_steps = xp.asarray(self.decay_phase_steps, dtype=step_dtype)

    # === Step 2. Keep a constant learning rate in the stable phase ===
    # tau is the normalized progress within the decay phase, clipped so the
    # stable phase maps to 0 and post-training steps map to 1.
    decay_progress = (typed_step - stable_steps) / decay_phase_steps
    tau = xp.clip(decay_progress, zero, one)

    # === Step 3. Apply the selected interpolation in the decay phase ===
    if self.decay_type == "inverse_linear":
        # Harmonic interpolation between start_lr and stop_lr.
        decay_lr = one / (tau / stop_lr + (one - tau) / start_lr)
    elif self.decay_type == "cosine":
        decay_lr = stop_lr + (start_lr - stop_lr) * 0.5 * (
            one + xp.cos(xp.asarray(xp.pi * tau, dtype=step_dtype))
        )
    else:
        # Plain linear interpolation (decay_type == "linear").
        decay_lr = start_lr + (stop_lr - start_lr) * tau
    # Mask the stable phase to start_lr and everything past the schedule
    # to stop_lr.
    step_lr = xp.where(step < self.stable_steps, start_lr, decay_lr)
    step_lr = xp.where(step >= self.decay_num_steps, stop_lr, step_lr)
    return step_lr


@BaseLR.register("cosine")
class LearningRateCosine(BaseLR):
r"""
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pd/utils/learning_rate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.utils.learning_rate import (
LearningRateExp,
LearningRateWSD,
)

__all__ = [
"LearningRateExp",
"LearningRateWSD",
]
2 changes: 2 additions & 0 deletions deepmd/pt/utils/learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
BaseLR,
LearningRateCosine,
LearningRateExp,
LearningRateWSD,
)

__all__ = [
"BaseLR",
"LearningRateCosine",
"LearningRateExp",
"LearningRateWSD",
]
93 changes: 93 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2594,6 +2594,61 @@ def _check_decay_steps_args(data: dict[str, Any]) -> bool:
return True


def _check_wsd_args(data: dict[str, Any]) -> bool:
"""
Check WSD-specific learning rate arguments.

Parameters
----------
data : dict[str, Any]
The learning rate configuration dictionary.

Returns
-------
bool
True if validation passes.

Raises
------
ValueError
If the WSD-specific arguments are invalid.
"""
lr_type = data.get("type", "exp")
if lr_type != "wsd":
return True

start_lr = data.get("start_lr")
if start_lr is not None and start_lr <= 0:
raise ValueError(f"start_lr ({start_lr}) must be positive for WSD.")

stop_lr = data.get("stop_lr")
if stop_lr is not None and stop_lr <= 0:
raise ValueError(f"stop_lr ({stop_lr}) must be positive for WSD.")

stop_lr_ratio = data.get("stop_lr_ratio")
if stop_lr_ratio is not None and stop_lr_ratio <= 0:
raise ValueError(f"stop_lr_ratio ({stop_lr_ratio}) must be positive for WSD.")

decay_phase_ratio = data.get("decay_phase_ratio")
if decay_phase_ratio is not None and (
decay_phase_ratio <= 0 or decay_phase_ratio > 1
):
raise ValueError(f"decay_phase_ratio ({decay_phase_ratio}) must be in (0, 1].")

decay_type = data.get("decay_type")
if decay_type is not None and decay_type not in (
"inverse_linear",
"cosine",
"linear",
):
raise ValueError(
"decay_type must be one of "
f"{('inverse_linear', 'cosine', 'linear')}. "
f"Got decay_type={decay_type}."
)
return True


@lr_args_plugin.register("exp")
def learning_rate_exp() -> list[Argument]:
"""
Expand Down Expand Up @@ -2645,6 +2700,42 @@ def learning_rate_cosine() -> list[Argument]:
return []


@lr_args_plugin.register("wsd")
def learning_rate_wsd() -> list[Argument]:
    """
    Defines a warmup-stable-decay learning rate schedule with configurable
    decay rules.

    The learning rate stays at `start_lr` during the stable phase and then
    decays to `stop_lr` with the selected decay rule.
    """
    doc_decay_phase_ratio = (
        "The ratio of the decay phase to total training steps. "
        "The remaining post-warmup steps are used as the stable phase. "
        "Default is 0.1."
    )
    doc_decay_type = (
        "The decay rule used in the decay phase. "
        "Supported values are `inverse_linear` (default), `cosine`, and `linear`."
    )
    # Assemble the WSD-specific arguments one at a time.
    wsd_args: list[Argument] = []
    wsd_args.append(
        Argument(
            "decay_phase_ratio",
            float,
            optional=True,
            default=0.1,
            doc=doc_decay_phase_ratio,
        )
    )
    wsd_args.append(
        Argument(
            "decay_type",
            str,
            optional=True,
            default="inverse_linear",
            doc=doc_decay_type,
        )
    )
    return wsd_args


def learning_rate_variant_type_args() -> Variant:
doc_lr = "The type of the learning rate."

Expand Down Expand Up @@ -2694,6 +2785,8 @@ def _check_lr_args(data: dict[str, Any]) -> bool:
_check_warmup_args(data)
# Check decay_steps and decay_rate
_check_decay_steps_args(data)
# Check WSD-specific arguments
_check_wsd_args(data)
return True

# Common arguments for all learning rate types (outside Variant)
Expand Down
Loading
Loading