-
Notifications
You must be signed in to change notification settings - Fork 601
feat(pd): add add_chg_spin_ebd parameter to DescrptDPA3 #5333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -732,6 +732,7 @@ def forward( | |||||
| nlist: paddle.Tensor, | ||||||
| mapping: paddle.Tensor | None = None, | ||||||
| comm_dict: list[paddle.Tensor] | None = None, | ||||||
| fparam: paddle.Tensor | None = None, | ||||||
| ) -> paddle.Tensor: | ||||||
|
||||||
| ) -> paddle.Tensor: | |
| ) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -32,6 +32,7 @@ | |||||||||||||||||||||||||||||||||
| UpdateSel, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| from deepmd.pd.utils.utils import ( | ||||||||||||||||||||||||||||||||||
| ActivationFn, | ||||||||||||||||||||||||||||||||||
| to_numpy_array, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| from deepmd.utils.data_system import ( | ||||||||||||||||||||||||||||||||||
|
|
@@ -120,6 +121,7 @@ def __init__( | |||||||||||||||||||||||||||||||||
| use_tebd_bias: bool = False, | ||||||||||||||||||||||||||||||||||
| use_loc_mapping: bool = True, | ||||||||||||||||||||||||||||||||||
| type_map: list[str] | None = None, | ||||||||||||||||||||||||||||||||||
| add_chg_spin_ebd: bool = False, | ||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
@@ -174,6 +176,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: | |||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| self.use_econf_tebd = use_econf_tebd | ||||||||||||||||||||||||||||||||||
| self.add_chg_spin_ebd = add_chg_spin_ebd | ||||||||||||||||||||||||||||||||||
| self.use_loc_mapping = use_loc_mapping | ||||||||||||||||||||||||||||||||||
| self.use_tebd_bias = use_tebd_bias | ||||||||||||||||||||||||||||||||||
| self.type_map = type_map | ||||||||||||||||||||||||||||||||||
|
|
@@ -196,6 +199,34 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: | |||||||||||||||||||||||||||||||||
| self.concat_output_tebd = concat_output_tebd | ||||||||||||||||||||||||||||||||||
| self.precision = precision | ||||||||||||||||||||||||||||||||||
| self.prec = PRECISION_DICT[self.precision] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if self.add_chg_spin_ebd: | ||||||||||||||||||||||||||||||||||
| self.act = ActivationFn(activation_function) | ||||||||||||||||||||||||||||||||||
| # -100 ~ 100 is a conservative bound | ||||||||||||||||||||||||||||||||||
| self.chg_embedding = TypeEmbedNet( | ||||||||||||||||||||||||||||||||||
| 200, | ||||||||||||||||||||||||||||||||||
| self.tebd_dim, | ||||||||||||||||||||||||||||||||||
| precision=precision, | ||||||||||||||||||||||||||||||||||
| seed=child_seed(seed, 3), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| # 100 is a conservative upper bound | ||||||||||||||||||||||||||||||||||
| self.spin_embedding = TypeEmbedNet( | ||||||||||||||||||||||||||||||||||
| 100, | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+207
to
+214
|
||||||||||||||||||||||||||||||||||
| 200, | |
| self.tebd_dim, | |
| precision=precision, | |
| seed=child_seed(seed, 3), | |
| ) | |
| # 100 is a conservative upper bound | |
| self.spin_embedding = TypeEmbedNet( | |
| 100, | |
| 201, | |
| self.tebd_dim, | |
| precision=precision, | |
| seed=child_seed(seed, 3), | |
| ) | |
| # 100 is a conservative upper bound | |
| self.spin_embedding = TypeEmbedNet( | |
| 101, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bump the descriptor schema version for the new serialized fields.
This payload now includes a new constructor kwarg plus three extra parameter blobs, but it still advertises @version: 2. Older DescrptDPA3.deserialize() implementations that accept v2 won't reject these checkpoints early; they'll fail later on unexpected fields instead.
Suggested version bump
- "@version": 2,
+ "@version": 3,- check_version_compatibility(version, 2, 1)
+ check_version_compatibility(version, 3, 1)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pd/model/descriptor/dpa3.py` around lines 467 - 474, The serialized
payload now includes the new constructor kwarg add_chg_spin_ebd and three extra
blobs (chg_embedding, spin_embedding, mix_cs_mlp) but the descriptor still
advertises `@version`: 2; update the schema version in the class's serialize
output to a new integer (e.g., 3) so DescrptDPA3.deserialize() can reject older
v2 checkpoints; locate the serialize method that builds the dict with
"add_chg_spin_ebd", "type_map", and "type_embedding" (and the conditional
chg/spin/mix_cs_mlp entries) and increment the version field, and ensure
DescrptDPA3.deserialize() handles or rejects the new version accordingly.
Copilot
AI
Mar 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
deserialize() only guards on chg_embedding is not None before deserializing all three add_chg_spin components. If the serialized dict is partially missing spin_embedding or mix_cs_mlp, this will fail later or leave the object in an inconsistent state. Consider checking that all required keys are present when add_chg_spin_ebd is true, and raising a clear error if any are missing.
Copilot
AI
Mar 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using assert for required runtime inputs (fparam and embedding modules) is fragile because asserts can be disabled with Python optimizations and they don’t provide actionable error messages. Prefer explicit validation (e.g., raise ValueError with expected shape nf x 2 and valid ranges) and also validate mix_cs_mlp is present when add_chg_spin_ebd is enabled.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -288,6 +288,7 @@ def forward( | |||||
| nlist: paddle.Tensor, | ||||||
| mapping: paddle.Tensor | None = None, | ||||||
| comm_dict: list[paddle.Tensor] | None = None, | ||||||
| fparam: paddle.Tensor | None = None, | ||||||
| ) -> paddle.Tensor: | ||||||
|
||||||
| ) -> paddle.Tensor: | |
| ) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor | None, paddle.Tensor | None, paddle.Tensor]: |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -436,6 +436,7 @@ def forward( | |||||
| nlist: paddle.Tensor, | ||||||
| mapping: paddle.Tensor | None = None, | ||||||
| comm_dict: list[paddle.Tensor] | None = None, | ||||||
| fparam: paddle.Tensor | None = None, | ||||||
| ) -> paddle.Tensor: | ||||||
|
||||||
| ) -> paddle.Tensor: | |
| ) -> tuple[paddle.Tensor, None, None, None, paddle.Tensor]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
forward()is annotated as returningpaddle.Tensor, but it actually returns a 5-tuple. Since this method signature was updated, please fix the return type annotation to reflect the tuple return type.