Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,187 @@ def freeze(
log.info("Saved frozen model to %s", output)


def change_bias(
    input_file: str,
    mode: str = "change",
    bias_value: list | None = None,
    datafile: str | None = None,
    system: str = ".",
    numb_batch: int = 0,
    model_branch: str | None = None,
    output: str | None = None,
) -> None:
    """Change the output bias of a pt_expt model.

    The bias can either be set directly from user-supplied values
    (*bias_value*, energy models only) or recomputed from data statistics.
    The updated model is written back in the same format as the input.

    Parameters
    ----------
    input_file : str
        Path to the model file (``.pt`` checkpoint or ``.pte``/``.pt2``
        frozen model).
    mode : str
        ``"change"`` (adjust bias by statistics) or ``"set"`` (overwrite
        bias from statistics).
    bias_value : list or None
        User-defined bias values (one per type).  When given, no data is
        read and the bias is assigned directly.
    datafile : str or None
        File listing data system paths, one per line.
    system : str
        Data system path (used when *datafile* is None).
    numb_batch : int
        Number of batches for statistics (0 = all).
    model_branch : str or None
        Branch name for multi-task models.  Currently unused: multi-task
        checkpoints are rejected below; kept for CLI compatibility.
    output : str or None
        Output file path.  Defaults to the input path with ``_updated``
        inserted before the extension.

    Raises
    ------
    ValueError
        On an unsupported checkpoint layout, an unknown *mode*, or a
        *bias_value* that does not match the model.
    RuntimeError
        If *input_file* has an unrecognized extension.
    NotImplementedError
        For multi-task checkpoints.
    """
    import torch

    from deepmd.common import (
        expand_sys_str,
    )
    from deepmd.dpmodel.common import (
        to_numpy_array,
    )
    from deepmd.pt_expt.model.get_model import (
        get_model,
    )
    from deepmd.pt_expt.train.training import (
        get_additional_data_requirement,
        get_loss,
        model_change_out_bias,
    )
    from deepmd.pt_expt.train.wrapper import (
        ModelWrapper,
    )
    from deepmd.pt_expt.utils.env import (
        DEVICE,
    )
    from deepmd.pt_expt.utils.serialization import (
        deserialize_to_file,
        serialize_from_file,
    )
    from deepmd.pt_expt.utils.stat import (
        make_stat_input,
    )

    # --- Load the model (and, for checkpoints, its hyper-parameters). ---
    if input_file.endswith(".pt"):
        # Training checkpoint: the architecture parameters travel in the
        # state dict's ``_extra_state`` so the model can be rebuilt.
        old_state_dict = torch.load(input_file, map_location=DEVICE, weights_only=True)
        if "model" in old_state_dict:
            model_state_dict = old_state_dict["model"]
        else:
            model_state_dict = old_state_dict
        extra_state = model_state_dict.get("_extra_state")
        if not isinstance(extra_state, dict) or "model_params" not in extra_state:
            raise ValueError(
                f"Unsupported checkpoint format at '{input_file}': missing "
                "'_extra_state.model_params' in model state dict."
            )
        model_params = extra_state["model_params"]
    elif input_file.endswith((".pte", ".pt2")):
        # Frozen model: fully self-describing, deserialize it directly.
        pte_data = serialize_from_file(input_file)
        from deepmd.pt_expt.model.model import (
            BaseModel,
        )

        model_to_change = BaseModel.deserialize(pte_data["model"])
        model_params = None
    else:
        raise RuntimeError(
            "The model provided must be a checkpoint file with a .pt extension "
            "or a frozen model with a .pte/.pt2 extension"
        )

    # Validate the mode before doing any heavy model construction.
    if mode == "change":
        bias_adjust_mode = "change-by-statistic"
    elif mode == "set":
        bias_adjust_mode = "set-by-statistic"
    else:
        raise ValueError(f"Unsupported mode '{mode}'. Expected 'change' or 'set'.")

    if input_file.endswith(".pt"):
        multi_task = "model_dict" in model_params
        if multi_task:
            raise NotImplementedError(
                "Multi-task change-bias is not yet supported for the pt_expt backend."
            )
        type_map = model_params["type_map"]
        # Rebuild the architecture and restore the trained weights.
        model = get_model(model_params)
        wrapper = ModelWrapper(model)
        wrapper.load_state_dict(model_state_dict)
        model_to_change = model

    if input_file.endswith((".pte", ".pt2")):
        type_map = model_to_change.get_type_map()

    if bias_value is not None:
        # Direct assignment path: no data needed, energy models only.
        if "energy" not in model_to_change.model_output_type():
            raise ValueError("User-defined bias is only available for energy models!")
        if len(bias_value) != len(type_map):
            raise ValueError(
                f"The number of elements in the bias ({len(bias_value)}) must match "
                f"the number of types in type_map ({len(type_map)}): {type_map}."
            )
        old_bias = model_to_change.get_out_bias()
        bias_to_set = torch.tensor(
            bias_value, dtype=old_bias.dtype, device=old_bias.device
        ).view(old_bias.shape)
        model_to_change.set_out_bias(bias_to_set)
        log.info(
            f"Change output bias of {type_map!s} "
            f"from {to_numpy_array(old_bias).reshape(-1)!s} "
            f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
        )
    else:
        # Statistics path: sample training data and recompute the bias.
        if datafile is not None:
            with open(datafile) as datalist:
                all_sys = datalist.read().splitlines()
        else:
            all_sys = expand_sys_str(system)
        data_systems = process_systems(all_sys)
        data = DeepmdDataSystem(
            systems=data_systems,
            batch_size=1,
            test_size=1,
            rcut=model_to_change.get_rcut(),
            type_map=type_map,
        )
        # The inference loss only declares which labels the data must carry.
        mock_loss = get_loss({"inference": True}, 1.0, len(type_map), model_to_change)
        data.add_data_requirements(mock_loss.label_requirement)
        data.add_data_requirements(get_additional_data_requirement(model_to_change))
        if numb_batch != 0:
            nbatches = numb_batch
        else:
            # Cap at the minimum across systems so no system wraps and
            # overweights short systems (matching PT behavior).
            nbatches = min(data.get_nbatches())
        sampled_data = make_stat_input(data, nbatches)
        model_to_change = model_change_out_bias(
            model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
        )

    # --- Persist the updated model, mirroring the input format. ---
    if input_file.endswith(".pt"):
        # Strip only the trailing extension: ``str.replace`` would also
        # rewrite any ".pt" occurring earlier in the path.
        output_path = (
            output
            if output is not None
            else input_file.removesuffix(".pt") + "_updated.pt"
        )
        wrapper = ModelWrapper(model_to_change)
        if "model" in old_state_dict:
            old_state_dict["model"] = wrapper.state_dict()
            old_state_dict["model"]["_extra_state"] = extra_state
        else:
            old_state_dict = wrapper.state_dict()
            old_state_dict["_extra_state"] = extra_state
        torch.save(old_state_dict, output_path)
    elif input_file.endswith((".pte", ".pt2")):
        suffix = ".pte" if input_file.endswith(".pte") else ".pt2"
        output_path = (
            output
            if output is not None
            else input_file.removesuffix(suffix) + "_updated" + suffix
        )
        model_dict = model_to_change.serialize()
        deserialize_to_file(output_path, {"model": model_dict})
    # Log for both formats (previously only the frozen-model branch logged).
    log.info(f"Saved model to {output_path}")


def main(args: list[str] | argparse.Namespace | None = None) -> None:
"""Entry point for the pt_expt backend CLI.

Expand Down Expand Up @@ -323,6 +504,17 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
if not FLAGS.output.endswith((".pte", ".pt2")):
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
elif FLAGS.command == "change-bias":
change_bias(
input_file=FLAGS.INPUT,
mode=FLAGS.mode,
bias_value=FLAGS.bias_value,
datafile=FLAGS.datafile,
system=FLAGS.system,
numb_batch=FLAGS.numb_batch,
model_branch=FLAGS.model_branch,
output=FLAGS.output,
)
elif FLAGS.command == "compress":
from deepmd.pt_expt.entrypoints.compress import (
enable_compression,
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,14 @@ def model_change_out_bias(
bias_adjust_mode=_bias_adjust_mode,
)
new_bias = deepcopy(_model.get_out_bias())

from deepmd.dpmodel.model.dp_model import (
DPModelCommon,
)

if isinstance(_model, DPModelCommon) and _bias_adjust_mode == "set-by-statistic":
_model.get_fitting_net().compute_input_stats(_sample_func)

model_type_map = _model.get_type_map()
log.info(
f"Change output bias of {model_type_map!s} "
Expand Down
23 changes: 20 additions & 3 deletions source/tests/pt_expt/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@
)


@pytest.fixture(autouse=True)
def _clear_leaked_device_context():
"""Pop any stale ``DeviceContext`` before each test, restore after."""
def _pop_device_contexts() -> list:
"""Pop all stale DeviceContext modes from the torch function mode stack."""
popped = []
while True:
modes = _get_current_function_mode_stack()
Expand All @@ -46,6 +45,24 @@ def _clear_leaked_device_context():
popped.append(top)
else:
break
return popped


@pytest.fixture(autouse=True, scope="session")
def _clear_leaked_device_context_session():
    """Discard stale ``DeviceContext`` modes once, at session start.

    Clears anything leaked onto the torch function-mode stack before the
    first test executes.  Because this fires ahead of every
    ``setUpClass``, it avoids CUDA initialization errors in tests whose
    class setup calls ``trainer.run()``.
    """
    _pop_device_contexts()
    yield

@pytest.fixture(autouse=True)
def _clear_leaked_device_context():
"""Pop any stale ``DeviceContext`` before each test, restore after."""
popped = _pop_device_contexts()
yield
# Restore in reverse order so the stack is back to its original state.
for ctx in reversed(popped):
Expand Down
Loading
Loading