From 3ebe1ac22e090c10ecf4c478fe6f89dc8b398fa0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 24 Feb 2026 16:13:46 -0800 Subject: [PATCH] Disable dynamic_vram when using torch compiler (#12612) * mp: attach re-construction arguments to model patcher When making a model-patcher from a unet or ckpt, attach a callable function that can be called to replay the model construction. This can be used to deep clone model patcher WRT the actual model. Originally written by Kosinkadink https://github.com/Comfy-Org/ComfyUI/commit/f4b99bc62389af315013dda85f24f2bbd262b686 * mp: Add disable_dynamic clone argument Add a clone argument that lets a caller clone a ModelPatcher but disable dynamic to demote the clone to regular MP. This is useful for legacy features where dynamic_vram support is missing or TBD. * torch_compile: disable dynamic_vram This is a bigger feature. Disable for the interim to preserve functionality. --- comfy/model_patcher.py | 14 ++++++++++++-- comfy/sd.py | 29 +++++++++++++++++++++-------- comfy_extras/nodes_torch_compile.py | 2 +- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 21b4ce53e3fd..1c9ba8096b30 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -271,6 +271,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.cached_patcher_init: tuple[Callable, tuple] | None = None if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -307,8 +308,15 @@ def lowvram_patch_counter(self): def get_free_memory(self, device): return comfy.model_management.get_free_memory(device) - def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) + def clone(self, disable_dynamic=False): + class_ = self.__class__ + model = self.model + if self.is_dynamic() and disable_dynamic: + class_ = ModelPatcher + temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model = temp_model_patcher.model + + n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -362,6 +370,8 @@ def clone(self): n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.cached_patcher_init = self.cached_patcher_init + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n diff --git a/comfy/sd.py b/comfy/sd.py index ce6ca5d17c27..69d4531e31ee 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1530,14 +1530,24 @@ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.mo return (model, clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) + if output_model: + out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): +def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, + embedding_directory=embedding_directory, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return model + +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): clip = None clipvision = None vae = None @@ -1586,7 +1596,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: @@ -1637,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1721,7 +1732,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) if not model_management.is_device_cpu(offload_device): model.to(offload_device) model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) @@ -1730,12 +1742,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): logging.info("left over keys in diffusion model: {}".format(left_over)) return model_patcher -def load_diffusion_model(unet_path, model_options={}): +def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) - model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) + model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model def load_unet(unet_path, dtype=None): diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 00e9f8b1fd7e..c9e2e0026f4a 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -25,7 +25,7 @@ def define_schema(cls) -> io.Schema: @classmethod def execute(cls, model, backend) -> io.NodeOutput: - m = model.clone() + m = model.clone(disable_dynamic=True) set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict}) return io.NodeOutput(m)