diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 858bd4cc782b..0b7da2852aaa 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered): return dest_views -aimdo_allocator = None +aimdo_enabled = False diff --git a/comfy/model_management.py b/comfy/model_management.py index 38c3e482b834..1fe56a62b3a6 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: + if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: return torch_dev else: return cpu_dev @@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref): synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer - #FIXME: This doesn't work in Aimdo because mempool cant clear cache soft_empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) diff --git a/comfy/utils.py b/comfy/utils.py index 17443b4ccd6d..5fe66ecdb34e 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) def model_trange(*args, **kwargs): - if comfy.memory_management.aimdo_allocator is None: + if not comfy.memory_management.aimdo_enabled: return trange(*args, **kwargs) pbar = trange(*args, **kwargs, smoothing=1.0) diff --git a/cuda_malloc.py b/cuda_malloc.py index b2182df374cb..f7651981c126 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,10 +1,8 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +from comfy.cli_args import args, PerformanceFeature import subprocess -import comfy_aimdo.control - #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): if os.name == 'nt': @@ -87,10 +85,6 @@ def cuda_malloc_supported(): except: pass -if enables_dynamic_vram() and comfy_aimdo.control.init(): - args.cuda_malloc = False - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" - if args.disable_cuda_malloc: args.cuda_malloc = False diff --git a/execution.py b/execution.py index f549a2f0ff2d..75b021892de0 100644 --- a/execution.py +++ b/execution.py @@ -9,7 +9,6 @@ from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio -from contextlib import nullcontext import torch @@ -521,19 +520,14 @@ def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows - #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc - #that we just want to cull out each model run. - allocator = comfy.memory_management.aimdo_allocator - with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): - try: - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) - finally: - if allocator is not None: - if args.verbose == "DEBUG": - comfy_aimdo.model_vbar.vbars_analyze() - comfy.model_management.reset_cast_buffers() - comfy_aimdo.model_vbar.vbars_reset_watermark_limits() + try: + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + finally: + if comfy.memory_management.aimdo_enabled: + if args.verbose == "DEBUG": + comfy_aimdo.control.analyze() + comfy.model_management.reset_cast_buffers() + comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: pending_async_nodes[unique_id] = output_data diff --git a/main.py b/main.py index 92d705b4d548..39e605debe27 100644 --- a/main.py +++ b/main.py @@ -173,6 +173,10 @@ def execute_script(script_path): if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() import comfy.utils @@ -188,13 +192,9 @@ def execute_script(script_path): import comfy.memory_management import comfy.model_patcher -import comfy_aimdo.control -import comfy_aimdo.torch - if enables_dynamic_vram(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - comfy.memory_management.aimdo_allocator = None elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() @@ -208,11 +208,10 @@ def execute_script(script_path): comfy_aimdo.control.set_log_info() comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic - comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() + comfy.memory_management.aimdo_enabled = True logging.info("DynamicVRAM support detected and enabled") else: logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - comfy.memory_management.aimdo_allocator = None def cuda_malloc_warning(): diff --git a/requirements.txt b/requirements.txt index 3a9bfde46987..8fbb0dbd6692 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.8 +comfy-aimdo>=0.2.0 requests #non essential dependencies: