Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion comfy/memory_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):

return dest_views

aimdo_allocator = None
aimdo_enabled = False
3 changes: 1 addition & 2 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):

mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
return torch_dev
else:
return cpu_dev
Expand Down Expand Up @@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
Expand Down
2 changes: 1 addition & 1 deletion comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)

def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
if not comfy.memory_management.aimdo_enabled:
return trange(*args, **kwargs)

pbar = trange(*args, **kwargs, smoothing=1.0)
Expand Down
8 changes: 1 addition & 7 deletions cuda_malloc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os
import importlib.util
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
from comfy.cli_args import args, PerformanceFeature
import subprocess

import comfy_aimdo.control

#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
if os.name == 'nt':
Expand Down Expand Up @@ -87,10 +85,6 @@ def cuda_malloc_supported():
except:
pass

if enables_dynamic_vram() and comfy_aimdo.control.init():
args.cuda_malloc = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""

if args.disable_cuda_malloc:
args.cuda_malloc = False

Expand Down
22 changes: 8 additions & 14 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
from contextlib import nullcontext

import torch

Expand Down Expand Up @@ -521,19 +520,14 @@ def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)

#Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
#that we just want to cull out each model run.
allocator = comfy.memory_management.aimdo_allocator
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
try:
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
try:
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if comfy.memory_management.aimdo_enabled:
if args.verbose == "DEBUG":
comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()

if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
Expand Down
11 changes: 5 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ def execute_script(script_path):
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")

import comfy_aimdo.control

if enables_dynamic_vram():
comfy_aimdo.control.init()

import comfy.utils

Expand All @@ -188,13 +192,9 @@ def execute_script(script_path):
import comfy.memory_management
import comfy.model_patcher

import comfy_aimdo.control
import comfy_aimdo.torch

if enables_dynamic_vram():
if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug()
Expand All @@ -208,11 +208,10 @@ def execute_script(script_path):
comfy_aimdo.control.set_log_info()

comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
comfy.memory_management.aimdo_enabled = True
logging.info("DynamicVRAM support detected and enabled")
else:
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None


def cuda_malloc_warning():
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.8
comfy-aimdo>=0.2.0
requests

#non essential dependencies:
Expand Down
Loading