Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions comfy_extras/nodes_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import os
import re
import math
import torch
import comfy.utils

Expand Down Expand Up @@ -682,6 +683,172 @@ def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
upscale = execute # TODO: remove


class SplitImageToTileList(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SplitImageToTileList",
category="image/batch",
search_aliases=["split image", "tile image", "slice image"],
display_name="Split Image into List of Tiles",
description="Splits an image into a batched list of tiles with a specified overlap.",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION),
IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION),
IO.Int.Input("overlap", default=128, min=0, max=4096),
],
outputs=[
IO.Image.Output(is_output_list=True),
],
)

@staticmethod
def get_grid_coords(width, height, tile_width, tile_height, overlap):
coords = []
stride_x = max(1, tile_width - overlap)
stride_y = max(1, tile_height - overlap)

y = 0
while y < height:
x = 0
y_end = min(y + tile_height, height)
y_start = max(0, y_end - tile_height)

while x < width:
x_end = min(x + tile_width, width)
x_start = max(0, x_end - tile_width)

coords.append((x_start, y_start, x_end, y_end))

if x_end >= width:
break
x += stride_x

if y_end >= height:
break
y += stride_y

return coords

@classmethod
def execute(cls, image, tile_width, tile_height, overlap):
b, h, w, c = image.shape
coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap)

output_list = []
for (x_start, y_start, x_end, y_end) in coords:
tile = image[:, y_start:y_end, x_start:x_end, :]
output_list.append(tile)

return IO.NodeOutput(output_list)


class ImageMergeTileList(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageMergeTileList",
display_name="Merge List of Tiles to Image",
category="image/batch",
search_aliases=["split image", "tile image", "slice image"],
is_input_list=True,
inputs=[
IO.Image.Input("image_list"),
IO.Int.Input("final_width", default=1024, min=64, max=32768),
IO.Int.Input("final_height", default=1024, min=64, max=32768),
IO.Int.Input("overlap", default=128, min=0, max=4096),
],
outputs=[
IO.Image.Output(is_output_list=False),
],
)

@staticmethod
def get_grid_coords(width, height, tile_width, tile_height, overlap):
coords = []
stride_x = max(1, tile_width - overlap)
stride_y = max(1, tile_height - overlap)

y = 0
while y < height:
x = 0
y_end = min(y + tile_height, height)
y_start = max(0, y_end - tile_height)

while x < width:
x_end = min(x + tile_width, width)
x_start = max(0, x_end - tile_width)

coords.append((x_start, y_start, x_end, y_end))

if x_end >= width:
break
x += stride_x

if y_end >= height:
break
y += stride_y

return coords

@classmethod
def execute(cls, image_list, final_width, final_height, overlap):
w = final_width[0]
h = final_height[0]
ovlp = overlap[0]
feather_str = 1.0

first_tile = image_list[0]
b, t_h, t_w, c = first_tile.shape
device = first_tile.device
dtype = first_tile.dtype

coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp)

canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)

if ovlp > 0:
y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype))
x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype))
y_w = torch.clamp(y_w, min=1e-5)
x_w = torch.clamp(x_w, min=1e-5)

sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1)
flat_mask = torch.ones_like(sine_mask)

weight_mask = torch.lerp(flat_mask, sine_mask, feather_str)
else:
weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype)

for i, (x_start, y_start, x_end, y_end) in enumerate(coords):
if i >= len(image_list):
break

tile = image_list[i]

region_h = y_end - y_start
region_w = x_end - x_start

real_h = min(region_h, tile.shape[1])
real_w = min(region_w, tile.shape[2])

y_end_actual = y_start + real_h
x_end_actual = x_start + real_w

tile_crop = tile[:, :real_h, :real_w, :]
mask_crop = weight_mask[:, :real_h, :real_w, :]

canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop
weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop

weights[weights == 0] = 1.0
merged_image = canvas / weights

return IO.NodeOutput(merged_image)


class ImagesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
Expand All @@ -701,6 +868,8 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
ImageRotate,
ImageFlip,
ImageScaleToMaxDimension,
SplitImageToTileList,
ImageMergeTileList,
]


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.14
comfyui-workflow-templates==0.8.43
comfyui-workflow-templates==0.9.2
comfyui-embedded-docs==0.4.1
torch
torchsde
Expand Down
Loading