Skip to content

Commit 645ee18

Browse files
Inpainting for z image fun control. Use the ZImageFunControlnet node. (Comfy-Org#11346)
Input mapping: image → control image (e.g. pose); inpaint_image → image to be inpainted; mask → inpaint mask.
1 parent 3d082c3 commit 645ee18

1 file changed

Lines changed: 61 additions & 16 deletions

File tree

comfy_extras/nodes_model_patch.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -313,22 +313,46 @@ def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=N
313313
self.inpaint_image = inpaint_image
314314
self.mask = mask
315315
self.strength = strength
316-
self.encoded_image = self.encode_latent_cond(image)
317-
self.encoded_image_size = (image.shape[1], image.shape[2])
318-
self.temp_data = None
316+
self.is_inpaint = self.model_patch.model.additional_in_dim > 0
319317

320-
def encode_latent_cond(self, control_image, inpaint_image=None):
321-
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
322-
if self.model_patch.model.additional_in_dim > 0:
323-
if self.mask is None:
324-
mask_ = torch.zeros_like(latent_image)[:, :1]
318+
skip_encoding = False
319+
if self.image is not None and self.inpaint_image is not None:
320+
if self.image.shape != self.inpaint_image.shape:
321+
skip_encoding = True
322+
323+
if skip_encoding:
324+
self.encoded_image = None
325+
else:
326+
self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
327+
if self.image is None:
328+
self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
325329
else:
326-
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
330+
self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
331+
self.temp_data = None
332+
333+
def encode_latent_cond(self, control_image=None, inpaint_image=None):
334+
latent_image = None
335+
if control_image is not None:
336+
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
337+
338+
if self.is_inpaint:
327339
if inpaint_image is None:
328340
inpaint_image = torch.ones_like(control_image) * 0.5
329341

342+
if self.mask is not None:
343+
mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
344+
inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
345+
330346
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
331347

348+
if self.mask is None:
349+
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
350+
else:
351+
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
352+
353+
if latent_image is None:
354+
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
355+
332356
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
333357
else:
334358
return latent_image
@@ -344,13 +368,18 @@ def __call__(self, kwargs):
344368
block_type = kwargs.get("block_type", "")
345369
spacial_compression = self.vae.spacial_compression_encode()
346370
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
347-
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
371+
image_scaled = None
372+
if self.image is not None:
373+
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
374+
self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
375+
348376
inpaint_scaled = None
349377
if self.inpaint_image is not None:
350378
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
379+
self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
380+
351381
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
352-
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
353-
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
382+
self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
354383
comfy.model_management.load_models_gpu(loaded_models)
355384

356385
cnet_blocks = self.model_patch.model.n_control_layers
@@ -391,7 +420,8 @@ def __call__(self, kwargs):
391420

392421
def to(self, device_or_dtype):
393422
if isinstance(device_or_dtype, torch.device):
394-
self.encoded_image = self.encoded_image.to(device_or_dtype)
423+
if self.encoded_image is not None:
424+
self.encoded_image = self.encoded_image.to(device_or_dtype)
395425
self.temp_data = None
396426
return self
397427

@@ -414,9 +444,12 @@ def INPUT_TYPES(s):
414444

415445
CATEGORY = "advanced/loaders/qwen"
416446

417-
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
447+
def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
418448
model_patched = model.clone()
419-
image = image[:, :, :, :3]
449+
if image is not None:
450+
image = image[:, :, :, :3]
451+
if inpaint_image is not None:
452+
inpaint_image = inpaint_image[:, :, :, :3]
420453
if mask is not None:
421454
if mask.ndim == 3:
422455
mask = mask.unsqueeze(1)
@@ -425,13 +458,24 @@ def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=No
425458
mask = 1.0 - mask
426459

427460
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
428-
patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
461+
patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
429462
model_patched.set_model_noise_refiner_patch(patch)
430463
model_patched.set_model_double_block_patch(patch)
431464
else:
432465
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
433466
return (model_patched,)
434467

468+
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
469+
@classmethod
470+
def INPUT_TYPES(s):
471+
return {"required": { "model": ("MODEL",),
472+
"model_patch": ("MODEL_PATCH",),
473+
"vae": ("VAE",),
474+
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
475+
},
476+
"optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
477+
478+
CATEGORY = "advanced/loaders/zimage"
435479

436480
class UsoStyleProjectorPatch:
437481
def __init__(self, model_patch, encoded_image):
@@ -479,5 +523,6 @@ def apply_patch(self, model, model_patch, clip_vision_output):
479523
NODE_CLASS_MAPPINGS = {
480524
"ModelPatchLoader": ModelPatchLoader,
481525
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
526+
"ZImageFunControlnet": ZImageFunControlnet,
482527
"USOStyleReference": USOStyleReference,
483528
}

0 commit comments

Comments (0)