feat(pd): add `add_chg_spin_ebd` parameter to `DescrptDPA3` #5333
Conversation
Co-authored-by: HydrogenSulfate <23737287+HydrogenSulfate@users.noreply.github.com> Agent-Logs-Url: https://github.com/HydrogenSulfate/deepmd-kit/sessions/730a0b97-f969-4779-8394-1758329031b6
Pull request overview
Adds an optional charge/spin embedding pathway to the Paddle DPA3 descriptor (matching the dpmodel/pt variants) and threads fparam through the PD atomic model + descriptor stack, with a consistency test update.
Changes:
- Add `add_chg_spin_ebd` initialization, (de)serialization, and `fparam` handling to `DescrptDPA3`.
- Plumb an optional `fparam` argument through PD descriptors and `DPAtomicModel.forward_atomic`.
- Extend the PD DPA3 consistency test to cover the new `add_chg_spin_ebd` mode and pass `fparam` to PD + dpmodel implementations.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pd/model/test_dpa3.py | Expands DPA3 consistency test matrix and supplies fparam when charge/spin embedding is enabled. |
| deepmd/pd/model/descriptor/dpa3.py | Implements add_chg_spin_ebd embeddings + mixing MLP, adds fparam to forward, and serializes the new parameters. |
| deepmd/pd/model/descriptor/dpa1.py | Adds optional fparam parameter to forward for signature alignment. |
| deepmd/pd/model/descriptor/dpa2.py | Adds optional fparam parameter to forward for signature alignment. |
| deepmd/pd/model/descriptor/se_a.py | Adds optional fparam parameter to forward for signature alignment. |
| deepmd/pd/model/descriptor/se_t_tebd.py | Adds optional fparam parameter to forward for signature alignment. |
| deepmd/pd/model/atomic_model/dp_atomic_model.py | Detects descriptor support for charge/spin embedding and conditionally forwards fparam into the descriptor call. |
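The conditional `fparam` forwarding described for `dp_atomic_model.py` can be sketched in plain Python (the descriptor classes below are hypothetical stand-ins, not the repository's paddle modules):

```python
import inspect

# Hypothetical stand-ins for descriptors with and without fparam support.
class DescriptorWithFparam:
    def forward(self, coord, atype, nlist, fparam=None):
        return {"fparam_seen": fparam is not None}

class DescriptorWithoutFparam:
    def forward(self, coord, atype, nlist):
        return {"fparam_seen": False}

def forward_atomic(descriptor, coord, atype, nlist, fparam=None):
    # Detect whether the descriptor's forward() accepts fparam and
    # pass it only in that case, mirroring the PR's conditional call.
    params = inspect.signature(descriptor.forward).parameters
    if "fparam" in params:
        return descriptor.forward(coord, atype, nlist, fparam=fparam)
    return descriptor.forward(coord, atype, nlist)
```

Signature inspection keeps older descriptors working unchanged while letting DPA3 receive the extra argument.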
Comments suppressed due to low confidence (1)
deepmd/pd/model/descriptor/dpa3.py:566
- The updated `forward()` signature adds `fparam` and now returns a tuple with optional elements, but the docstring doesn't document the new `fparam` parameter or that some returned values may be `None`. Updating the docstring here would prevent misuse and keep the API contract clear.
```python
def forward(
    self,
    extended_coord: paddle.Tensor,
    extended_atype: paddle.Tensor,
    nlist: paddle.Tensor,
    mapping: paddle.Tensor | None = None,
    comm_dict: list[paddle.Tensor] | None = None,
    fparam: paddle.Tensor | None = None,
) -> tuple[
    paddle.Tensor,
    paddle.Tensor | None,
    paddle.Tensor | None,
    paddle.Tensor | None,
    paddle.Tensor | None,
]:
    """Compute the descriptor.

    Parameters
    ----------
    extended_coord
        The extended coordinates of atoms. shape: nf x (nall x 3)
    extended_atype
        The extended atom types. shape: nf x nall
    nlist
        The neighbor list. shape: nf x nloc x nnei
    mapping
        The index mapping, maps extended region index to local region.
    comm_dict
        The data needed for communication for parallel inference.
    """
```
```python
    200,
    self.tebd_dim,
    precision=precision,
    seed=child_seed(seed, 3),
)
# 100 is a conservative upper bound
self.spin_embedding = TypeEmbedNet(
    100,
```
The charge/spin embedding table sizes look off-by-one given TypeEmbedNet uses a dedicated last index for padding. With TypeEmbedNet(200, ...) and charge = fparam[:,0].int64 + 100, charge==100 maps to index 200 (padding), not a learnable embedding, contradicting the “-100 ~ 100” comment. Consider either (a) resizing to cover the intended inclusive bounds (e.g., 201/101) or (b) tightening the documented/validated allowed ranges and raising a clear error when out of range.
```diff
-    200,
+    201,
     self.tebd_dim,
     precision=precision,
     seed=child_seed(seed, 3),
 )
 # 100 is a conservative upper bound
 self.spin_embedding = TypeEmbedNet(
-    100,
+    101,
```
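To see the off-by-one concretely, here is a tiny standalone sketch of inclusive-range table sizing (plain Python; that `TypeEmbedNet` reserves a trailing padding row is an assumption from this review):

```python
def table_rows_for_inclusive_range(lo: int, hi: int) -> int:
    """Number of learnable rows needed to cover integers lo..hi inclusive."""
    return hi - lo + 1

def shifted_index(value: int, lo: int) -> int:
    """Map a value in [lo, hi] to a non-negative table index."""
    return value - lo

# A charge range of -100..100 needs 201 learnable rows, not 200.
# With only 200 rows (plus a dedicated padding row at index 200),
# charge == 100 shifts to index 200, i.e. the padding row.
rows = table_rows_for_inclusive_range(-100, 100)
```

The classic fencepost: an inclusive range of width `hi - lo` contains `hi - lo + 1` values, hence 201 and 101 in the suggested fix.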
```python
if self.add_chg_spin_ebd:
    assert fparam is not None
    assert self.chg_embedding is not None
    assert self.spin_embedding is not None
    charge = fparam[:, 0].to(dtype=paddle.int64) + 100
    spin = fparam[:, 1].to(dtype=paddle.int64)
    chg_ebd = self.chg_embedding(charge)
    spin_ebd = self.spin_embedding(spin)
    sys_cs_embd = self.act(
        self.mix_cs_mlp(paddle.concat([chg_ebd, spin_ebd], axis=-1))
    )
```
Using assert for required runtime inputs (fparam and embedding modules) is fragile because asserts can be disabled with Python optimizations and they don’t provide actionable error messages. Prefer explicit validation (e.g., raise ValueError with expected shape nf x 2 and valid ranges) and also validate mix_cs_mlp is present when add_chg_spin_ebd is enabled.
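The explicit-validation suggestion can be sketched standalone as follows (the `nf x 2` shape and the error wording are assumptions from this review, not the repository's code):

```python
def validate_chg_spin_inputs(fparam, chg_embedding, spin_embedding, mix_cs_mlp):
    """Raise ValueError (not AssertionError) so checks survive `python -O`."""
    if fparam is None:
        raise ValueError(
            "add_chg_spin_ebd is enabled but fparam is None; "
            "expected a tensor of shape nf x 2 (charge, spin)."
        )
    if chg_embedding is None or spin_embedding is None or mix_cs_mlp is None:
        raise ValueError(
            "add_chg_spin_ebd is enabled but chg_embedding, "
            "spin_embedding, or mix_cs_mlp was not constructed."
        )
    if fparam.shape[-1] < 2:
        raise ValueError(
            f"fparam must have at least 2 columns, got {fparam.shape[-1]}."
        )
```

Unlike `assert`, these checks run under optimized bytecode and tell the user what shape was expected.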
```python
chg_embedding = data.pop("chg_embedding", None)
spin_embedding = data.pop("spin_embedding", None)
mix_cs_mlp = data.pop("mix_cs_mlp", None)
data["repflow"] = RepFlowArgs(**data.pop("repflow_args"))
obj = cls(**data)
obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
    type_embedding
)

if obj.add_chg_spin_ebd and chg_embedding is not None:
    obj.chg_embedding.embedding = TypeEmbedNetConsistent.deserialize(
        chg_embedding
    )
    obj.spin_embedding.embedding = TypeEmbedNetConsistent.deserialize(
        spin_embedding
    )
    obj.mix_cs_mlp = MLPLayer.deserialize(mix_cs_mlp)
```
deserialize() only guards on chg_embedding is not None before deserializing all three add_chg_spin components. If the serialized dict is partially missing spin_embedding or mix_cs_mlp, this will fail later or leave the object in an inconsistent state. Consider checking that all required keys are present when add_chg_spin_ebd is true, and raising a clear error if any are missing.
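The all-or-nothing key check suggested here can be sketched standalone (key names follow this review; the real `deserialize()` may organize things differently):

```python
def check_chg_spin_payload(add_chg_spin_ebd: bool, data: dict) -> None:
    """Require all three serialized components together, or none."""
    if not add_chg_spin_ebd:
        return
    required = ("chg_embedding", "spin_embedding", "mix_cs_mlp")
    missing = [k for k in required if data.get(k) is None]
    if missing:
        raise KeyError(
            "add_chg_spin_ebd is enabled but serialized data is missing: "
            + ", ".join(missing)
        )
```

Failing fast on a partial payload avoids constructing an object whose embeddings are silently half-restored.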
```python
    mapping: paddle.Tensor | None = None,
    comm_dict: list[paddle.Tensor] | None = None,
    fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
```
forward() is annotated as returning paddle.Tensor, but it actually returns a 5-tuple (descriptor, rot_mat, g2, h2, sw). Since this method signature was updated, please fix the return type annotation to match the actual return value (see se_atten.forward for an example).
```diff
-) -> paddle.Tensor:
+) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor | None, paddle.Tensor | None, paddle.Tensor]:
```
```python
    mapping: paddle.Tensor | None = None,
    comm_dict: list[paddle.Tensor] | None = None,
    fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
```
forward() is annotated as returning paddle.Tensor, but it actually returns a 5-tuple (g1, None, None, None, sw). Since this method signature was updated, please fix the return type annotation to match the actual return value for better static checking and consistency with other descriptors.
```diff
-) -> paddle.Tensor:
+) -> tuple[paddle.Tensor, None, None, None, paddle.Tensor]:
```
```python
    mapping: paddle.Tensor | None = None,
    comm_dict: list[paddle.Tensor] | None = None,
    fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
```
forward() is annotated as returning paddle.Tensor, but it actually returns a 5-tuple. Since this method signature was updated, please fix the return type annotation to reflect the tuple return type.
```diff
-) -> paddle.Tensor:
+) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
```
```python
    mapping: paddle.Tensor | None = None,
    comm_dict: list[paddle.Tensor] | None = None,
    fparam: paddle.Tensor | None = None,
) -> paddle.Tensor:
```
forward() is annotated as returning paddle.Tensor, but it actually returns a 5-tuple (g1, rot_mat, g2, h2, sw). Since this method signature was updated, please fix the return type annotation to reflect the tuple return type.
```diff
-) -> paddle.Tensor:
+) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
```
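These annotation fixes share one idea: declare the tuple that is actually returned so static checkers can catch bad unpacking. A minimal standalone sketch (plain Python lists stand in for tensors; the function is hypothetical, not the repository's `forward`):

```python
from typing import Optional

def forward_sketch() -> tuple[list, Optional[list], Optional[list], Optional[list], list]:
    # Mirrors the (g1, rot_mat, g2, h2, sw) shape of the descriptor outputs,
    # with the optional middle elements annotated as Optional.
    g1, rot_mat, g2, h2, sw = [1.0], None, None, None, [0.5]
    return g1, rot_mat, g2, h2, sw

# A caller that trusted a declared `-> Tensor` would mis-unpack this;
# with the tuple annotation, mypy/pyright flag such misuse statically.
g1, rot_mat, g2, h2, sw = forward_sketch()
```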
📝 Walkthrough

This PR introduces charge and spin embedding support in the DPA3 descriptor, keyed by a new `add_chg_spin_ebd` option.
Sequence Diagram

```mermaid
sequenceDiagram
    actor Client
    participant AtomicModel as DP Atomic Model
    participant Descriptor as DPA3 Descriptor
    participant ChgEmbed as Charge Embedding
    participant SpinEmbed as Spin Embedding
    participant MixMLP as Mix CS MLP
    Client->>AtomicModel: forward(coord, atype, nlist, fparam)
    Note over AtomicModel: add_chg_spin_ebd = True
    AtomicModel->>Descriptor: forward(..., fparam=fparam)
    Note over Descriptor: Compute node_ebd_ext
    alt add_chg_spin_ebd enabled
        Descriptor->>ChgEmbed: forward(fparam[:, 0] + 100)
        ChgEmbed-->>Descriptor: charge_embedding
        Descriptor->>SpinEmbed: forward(fparam[:, 1])
        SpinEmbed-->>Descriptor: spin_embedding
        Descriptor->>MixMLP: forward([charge_emb, spin_emb])
        MixMLP-->>Descriptor: mixed_embedding
        Note over Descriptor: Add mixed_embedding to node_ebd_ext
    end
    Descriptor-->>AtomicModel: (node_ebd, rot_mat, edge_ebd, h2, sw)
    AtomicModel-->>Client: model outputs
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pd/model/descriptor/dpa3.py (1)
179-228: ⚠️ Potential issue | 🟠 Major

Share the new charge/spin sublayers in `share_params()`.

When `add_chg_spin_ebd` is enabled, this descriptor gains three extra trainable sublayers, but `share_params()` still aliases only `type_embedding` and `repflows`. `shared_level == 0` therefore no longer shares the full descriptor, so multitask runs can diverge on the new path.

Possible follow-up in `share_params()`:

```python
if self.add_chg_spin_ebd != base_class.add_chg_spin_ebd:
    raise ValueError(
        "Descriptors with different add_chg_spin_ebd settings cannot share params."
    )
if shared_level == 0 and self.add_chg_spin_ebd:
    self._sub_layers["chg_embedding"] = base_class._sub_layers["chg_embedding"]
    self._sub_layers["spin_embedding"] = base_class._sub_layers["spin_embedding"]
    self._sub_layers["mix_cs_mlp"] = base_class._sub_layers["mix_cs_mlp"]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pd/model/descriptor/dpa3.py` around lines 179 - 228, share_params() currently only aliases type_embedding and repflows, but when add_chg_spin_ebd is true the descriptor creates three extra trainable sublayers (chg_embedding, spin_embedding, mix_cs_mlp) that must also be shared to avoid divergence; update share_params() to check that base_class.add_chg_spin_ebd matches self.add_chg_spin_ebd and when shared_level == 0 and add_chg_spin_ebd is true assign self._sub_layers["chg_embedding"] = base_class._sub_layers["chg_embedding"], self._sub_layers["spin_embedding"] = base_class._sub_layers["spin_embedding"], and self._sub_layers["mix_cs_mlp"] = base_class._sub_layers["mix_cs_mlp"] (and raise a ValueError if the add_chg_spin_ebd flags differ).
🧹 Nitpick comments (1)
source/tests/pd/model/test_dpa3.py (1)
123-153: Add one model-level coverage case for the new `fparam` plumbing.

These assertions call `DescrptDPA3`/`DPDescrptDPA3` directly, so the new `DPAtomicModel.forward_atomic()` branch that conditionally forwards `fparam` is still untested. A regression in the production call path would leave this test green.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pd/model/test_dpa3.py` around lines 123 - 153, The test exercises DescrptDPA3.deserialize/DPDescrptDPA3.call directly so the DPAtomicModel.forward_atomic branch that conditionally forwards fparam is not covered; add one model-level test that invokes the production call path (the model entrypoint that uses DPAtomicModel.forward_atomic) with fparam present and with fparam absent to hit both branches, verifying outputs match the direct DescrptDPA3/DPDescrptDPA3 results; reference DescrptDPA3, DPDescrptDPA3, DPAtomicModel.forward_atomic and the fparam/fparam_np variables to locate where to wire the call in the existing test harness.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pd/model/descriptor/dpa3.py`:
- Around line 467-474: The serialized payload now includes the new constructor
kwarg add_chg_spin_ebd and three extra blobs (chg_embedding, spin_embedding,
mix_cs_mlp) but the descriptor still advertises `@version`: 2; update the schema
version in the class's serialize output to a new integer (e.g., 3) so
DescrptDPA3.deserialize() can reject older v2 checkpoints; locate the serialize
method that builds the dict with "add_chg_spin_ebd", "type_map", and
"type_embedding" (and the conditional chg/spin/mix_cs_mlp entries) and increment
the version field, and ensure DescrptDPA3.deserialize() handles or rejects the
new version accordingly.
---
Outside diff comments:
In `@deepmd/pd/model/descriptor/dpa3.py`:
- Around line 179-228: share_params() currently only aliases type_embedding and
repflows, but when add_chg_spin_ebd is true the descriptor creates three extra
trainable sublayers (chg_embedding, spin_embedding, mix_cs_mlp) that must also
be shared to avoid divergence; update share_params() to check that
base_class.add_chg_spin_ebd matches self.add_chg_spin_ebd and when shared_level
== 0 and add_chg_spin_ebd is true assign self._sub_layers["chg_embedding"] =
base_class._sub_layers["chg_embedding"], self._sub_layers["spin_embedding"] =
base_class._sub_layers["spin_embedding"], and self._sub_layers["mix_cs_mlp"] =
base_class._sub_layers["mix_cs_mlp"] (and raise a ValueError if the
add_chg_spin_ebd flags differ).
---
Nitpick comments:
In `@source/tests/pd/model/test_dpa3.py`:
- Around line 123-153: The test exercises
DescrptDPA3.deserialize/DPDescrptDPA3.call directly so the
DPAtomicModel.forward_atomic branch that conditionally forwards fparam is not
covered; add one model-level test that invokes the production call path (the
model entrypoint that uses DPAtomicModel.forward_atomic) with fparam present and
with fparam absent to hit both branches, verifying outputs match the direct
DescrptDPA3/DPDescrptDPA3 results; reference DescrptDPA3, DPDescrptDPA3,
DPAtomicModel.forward_atomic and the fparam/fparam_np variables to locate where
to wire the call in the existing test harness.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: a2cd380d-d050-49ff-b382-221d67b60fcd
📒 Files selected for processing (7)
- deepmd/pd/model/atomic_model/dp_atomic_model.py
- deepmd/pd/model/descriptor/dpa1.py
- deepmd/pd/model/descriptor/dpa2.py
- deepmd/pd/model/descriptor/dpa3.py
- deepmd/pd/model/descriptor/se_a.py
- deepmd/pd/model/descriptor/se_t_tebd.py
- source/tests/pd/model/test_dpa3.py
```python
    "add_chg_spin_ebd": self.add_chg_spin_ebd,
    "type_map": self.type_map,
    "type_embedding": self.type_embedding.embedding.serialize(),
}
if self.add_chg_spin_ebd:
    data["chg_embedding"] = self.chg_embedding.embedding.serialize()
    data["spin_embedding"] = self.spin_embedding.embedding.serialize()
    data["mix_cs_mlp"] = self.mix_cs_mlp.serialize()
```
Bump the descriptor schema version for the new serialized fields.
This payload now includes a new constructor kwarg plus three extra parameter blobs, but it still advertises @version: 2. Older DescrptDPA3.deserialize() implementations that accept v2 won't reject these checkpoints early; they'll fail later on unexpected fields instead.
Suggested version bump:

```diff
-    "@version": 2,
+    "@version": 3,
```

```diff
-    check_version_compatibility(version, 2, 1)
+    check_version_compatibility(version, 3, 1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pd/model/descriptor/dpa3.py` around lines 467 - 474, The serialized
payload now includes the new constructor kwarg add_chg_spin_ebd and three extra
blobs (chg_embedding, spin_embedding, mix_cs_mlp) but the descriptor still
advertises `@version`: 2; update the schema version in the class's serialize
output to a new integer (e.g., 3) so DescrptDPA3.deserialize() can reject older
v2 checkpoints; locate the serialize method that builds the dict with
"add_chg_spin_ebd", "type_map", and "type_embedding" (and the conditional
chg/spin/mix_cs_mlp entries) and increment the version field, and ensure
DescrptDPA3.deserialize() handles or rejects the new version accordingly.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #5333      +/-   ##
==========================================
+ Coverage   82.30%   82.40%   +0.10%
==========================================
  Files         775      783       +8
  Lines       77628    79062    +1434
  Branches     3675     3675
==========================================
+ Hits        63888    65153    +1265
- Misses      12568    12736     +168
- Partials     1172     1173       +1
==========================================
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit

New Features
- Optional charge/spin embedding (`add_chg_spin_ebd`) for the DPA3 descriptor, with `fparam` plumbed through the PD descriptor stack.

Tests
- Extended the PD DPA3 consistency test to cover the new mode.