Skip to content

Commit 0ff0457

Browse files
authored
mm: wrap the raw stream in context manager (Comfy-Org#10958)
The documentation of torch.foo.Stream being usable with with: suggests it starts at version 2.7. Use the old API for backwards compatibility.
1 parent 6484ac8 commit 0ff0457

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

comfy/model_management.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,15 +1055,19 @@ def get_offload_stream(device):
10551055
elif is_device_cuda(device):
10561056
ss = []
10571057
for k in range(NUM_STREAMS):
1058-
ss.append(torch.cuda.Stream(device=device, priority=0))
1058+
s1 = torch.cuda.Stream(device=device, priority=0)
1059+
s1.as_context = torch.cuda.stream
1060+
ss.append(s1)
10591061
STREAMS[device] = ss
10601062
s = ss[stream_counter]
10611063
stream_counters[device] = stream_counter
10621064
return s
10631065
elif is_device_xpu(device):
10641066
ss = []
10651067
for k in range(NUM_STREAMS):
1066-
ss.append(torch.xpu.Stream(device=device, priority=0))
1068+
s1 = torch.xpu.Stream(device=device, priority=0)
1069+
s1.as_context = torch.xpu.stream
1070+
ss.append(s1)
10671071
STREAMS[device] = ss
10681072
s = ss[stream_counter]
10691073
stream_counters[device] = stream_counter
@@ -1081,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
10811085
if dtype is None or weight.dtype == dtype:
10821086
return weight
10831087
if stream is not None:
1084-
with stream:
1088+
wf_context = stream
1089+
if hasattr(wf_context, "as_context"):
1090+
wf_context = wf_context.as_context(stream)
1091+
with wf_context:
10851092
return weight.to(dtype=dtype, copy=copy)
10861093
return weight.to(dtype=dtype, copy=copy)
10871094

1095+
10881096
if stream is not None:
1089-
with stream:
1097+
wf_context = stream
1098+
if hasattr(wf_context, "as_context"):
1099+
wf_context = wf_context.as_context(stream)
1100+
with wf_context:
10901101
r = torch.empty_like(weight, dtype=dtype, device=device)
10911102
r.copy_(weight, non_blocking=non_blocking)
10921103
else:

comfy/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
9595

9696
if offload_stream is not None:
9797
wf_context = offload_stream
98+
if hasattr(wf_context, "as_context"):
99+
wf_context = wf_context.as_context(offload_stream)
98100
else:
99101
wf_context = contextlib.nullcontext()
100102

0 commit comments

Comments
 (0)