diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1a24e6d95f5f..d5674dea676b 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -171,7 +171,10 @@ def forward_orig( pe = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -215,7 +218,10 @@ def block_wrap(args): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/model_management.py b/comfy/model_management.py index a9327ac80091..9c403d580a79 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = lowvram_model_memory - loaded_memory if lowvram_model_memory == 0: @@ -1012,7 +1012,7 @@ def force_channels_last(): STREAMS = {} -NUM_STREAMS = 1 +NUM_STREAMS = 0 if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) @@ -1030,7 +1030,7 @@ def current_stream(device): stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: + if NUM_STREAMS == 0: return None if device in STREAMS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 73adc7f706d9..3eac772758b5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -148,6 +148,15 @@ def __call__(self, weight): else: return out +#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + def get_key_weight(model, key): set_func = None convert_func = None @@ -269,6 +278,9 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None + if not hasattr(self.model, 'model_offload_buffer_memory'): + self.model.model_offload_buffer_memory = 0 + def model_size(self): if self.size > 0: return self.size @@ -662,7 +674,16 @@ def _load_list(self): skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + module_mem = comfy.model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if weight_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) + if bias_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + loading.append((module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -676,20 +697,22 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False load_completely = [] offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 1 lowvram_mem_counter += module_mem @@ -723,9 +746,11 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -766,7 +791,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) @@ -778,6 +803,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.model_offload_buffer_memory = offload_buffer self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -831,6 +857,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -849,13 +876,14 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload + + potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -906,15 +934,18 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem + offload_buffer = max(offload_buffer, potential_offload) logging.debug("freed {}".format(n)) for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed - logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) + self.model.model_offload_buffer_memory = offload_buffer + logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):