@@ -313,22 +313,46 @@ def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=N
313313 self .inpaint_image = inpaint_image
314314 self .mask = mask
315315 self .strength = strength
316- self .encoded_image = self .encode_latent_cond (image )
317- self .encoded_image_size = (image .shape [1 ], image .shape [2 ])
318- self .temp_data = None
316+ self .is_inpaint = self .model_patch .model .additional_in_dim > 0
319317
320- def encode_latent_cond (self , control_image , inpaint_image = None ):
321- latent_image = comfy .latent_formats .Flux ().process_in (self .vae .encode (control_image ))
322- if self .model_patch .model .additional_in_dim > 0 :
323- if self .mask is None :
324- mask_ = torch .zeros_like (latent_image )[:, :1 ]
318+ skip_encoding = False
319+ if self .image is not None and self .inpaint_image is not None :
320+ if self .image .shape != self .inpaint_image .shape :
321+ skip_encoding = True
322+
323+ if skip_encoding :
324+ self .encoded_image = None
325+ else :
326+ self .encoded_image = self .encode_latent_cond (self .image , self .inpaint_image )
327+ if self .image is None :
328+ self .encoded_image_size = (self .inpaint_image .shape [1 ], self .inpaint_image .shape [2 ])
325329 else :
326- mask_ = comfy .utils .common_upscale (self .mask .mean (dim = 1 , keepdim = True ), latent_image .shape [- 1 ], latent_image .shape [- 2 ], "bilinear" , "none" )
330+ self .encoded_image_size = (self .image .shape [1 ], self .image .shape [2 ])
331+ self .temp_data = None
332+
def encode_latent_cond(self, control_image=None, inpaint_image=None):
    """Encode the conditioning image(s) into the latent handed to the control model.

    Args:
        control_image: Image passed to ``self.vae.encode``; may be ``None``.
        inpaint_image: Image to inpaint; only read when ``self.is_inpaint``
            is true. May be ``None``.

    Returns:
        When ``self.is_inpaint`` is false: the encoded control latent, or
        ``None`` if ``control_image`` is ``None``.
        When ``self.is_inpaint`` is true: the concatenation along dim 1 of
        ``[control latent, mask, inpaint latent]``.

    Raises:
        ValueError: if ``self.is_inpaint`` is true and both images are
            ``None`` (there is nothing to derive a placeholder from).
    """
    def encode(image):
        # VAE-encode, then apply the Flux latent format's input processing.
        return comfy.latent_formats.Flux().process_in(self.vae.encode(image))

    latent_image = encode(control_image) if control_image is not None else None

    if not self.is_inpaint:
        return latent_image

    if inpaint_image is None:
        if control_image is None:
            raise ValueError(
                "encode_latent_cond needs at least one of control_image or "
                "inpaint_image when the control model expects inpaint inputs"
            )
        # No inpaint image supplied: use a constant 0.5 image of the same shape.
        inpaint_image = torch.ones_like(control_image) * 0.5

    mask_2d = None
    if self.mask is not None:
        mask = self.mask
        # Collapse every dim between batch and the last two into one averaged channel.
        mask_2d = mask.view(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).mean(dim=1, keepdim=True)
        # Sizes are read from shape[-2] / shape[-3] of the image.
        # NOTE(review): assumes images are channels-last (B, H, W, C) -- confirm against callers.
        mask_inpaint = comfy.utils.common_upscale(mask_2d, inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
        # Where the rounded mask is 0 the pixel becomes 0.5; where it is 1 it is kept.
        inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5

    inpaint_image_latent = encode(inpaint_image)

    if mask_2d is None:
        # No mask: a single all-zero channel shaped like the inpaint latent.
        mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
    else:
        mask_ = comfy.utils.common_upscale(mask_2d, inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")

    if latent_image is None:
        # No control image: encode a constant 0.5 image as the control latent.
        latent_image = encode(torch.ones_like(inpaint_image) * 0.5)

    return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
@@ -344,13 +368,18 @@ def __call__(self, kwargs):
344368 block_type = kwargs .get ("block_type" , "" )
345369 spacial_compression = self .vae .spacial_compression_encode ()
346370 if self .encoded_image is None or self .encoded_image_size != (x .shape [- 2 ] * spacial_compression , x .shape [- 1 ] * spacial_compression ):
347- image_scaled = comfy .utils .common_upscale (self .image .movedim (- 1 , 1 ), x .shape [- 1 ] * spacial_compression , x .shape [- 2 ] * spacial_compression , "area" , "center" )
371+ image_scaled = None
372+ if self .image is not None :
373+ image_scaled = comfy .utils .common_upscale (self .image .movedim (- 1 , 1 ), x .shape [- 1 ] * spacial_compression , x .shape [- 2 ] * spacial_compression , "area" , "center" ).movedim (1 , - 1 )
374+ self .encoded_image_size = (image_scaled .shape [- 3 ], image_scaled .shape [- 2 ])
375+
348376 inpaint_scaled = None
349377 if self .inpaint_image is not None :
350378 inpaint_scaled = comfy .utils .common_upscale (self .inpaint_image .movedim (- 1 , 1 ), x .shape [- 1 ] * spacial_compression , x .shape [- 2 ] * spacial_compression , "area" , "center" ).movedim (1 , - 1 )
379+ self .encoded_image_size = (inpaint_scaled .shape [- 3 ], inpaint_scaled .shape [- 2 ])
380+
351381 loaded_models = comfy .model_management .loaded_models (only_currently_used = True )
352- self .encoded_image = self .encode_latent_cond (image_scaled .movedim (1 , - 1 ), inpaint_scaled )
353- self .encoded_image_size = (image_scaled .shape [- 2 ], image_scaled .shape [- 1 ])
382+ self .encoded_image = self .encode_latent_cond (image_scaled , inpaint_scaled )
354383 comfy .model_management .load_models_gpu (loaded_models )
355384
356385 cnet_blocks = self .model_patch .model .n_control_layers
@@ -391,7 +420,8 @@ def __call__(self, kwargs):
391420
def to(self, device_or_dtype):
    """Move the stored encoded image to a device and reset ``temp_data``.

    Args:
        device_or_dtype: Only ``torch.device`` values have an effect; any
            other value (for example a dtype) leaves the object untouched.

    Returns:
        ``self``.
    """
    if not isinstance(device_or_dtype, torch.device):
        return self

    # encoded_image can be None when encoding was skipped, so guard the move.
    if self.encoded_image is not None:
        self.encoded_image = self.encoded_image.to(device_or_dtype)
    self.temp_data = None
    return self
397427
@@ -414,9 +444,12 @@ def INPUT_TYPES(s):
414444
415445 CATEGORY = "advanced/loaders/qwen"
416446
417- def diffsynth_controlnet (self , model , model_patch , vae , image , strength , mask = None ):
447+ def diffsynth_controlnet (self , model , model_patch , vae , image = None , strength = 1.0 , inpaint_image = None , mask = None ):
418448 model_patched = model .clone ()
419- image = image [:, :, :, :3 ]
449+ if image is not None :
450+ image = image [:, :, :, :3 ]
451+ if inpaint_image is not None :
452+ inpaint_image = inpaint_image [:, :, :, :3 ]
420453 if mask is not None :
421454 if mask .ndim == 3 :
422455 mask = mask .unsqueeze (1 )
@@ -425,13 +458,24 @@ def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=No
425458 mask = 1.0 - mask
426459
427460 if isinstance (model_patch .model , comfy .ldm .lumina .controlnet .ZImage_Control ):
428- patch = ZImageControlPatch (model_patch , vae , image , strength , mask = mask )
461+ patch = ZImageControlPatch (model_patch , vae , image , strength , inpaint_image = inpaint_image , mask = mask )
429462 model_patched .set_model_noise_refiner_patch (patch )
430463 model_patched .set_model_double_block_patch (patch )
431464 else :
432465 model_patched .set_model_double_block_patch (DiffSynthCnetPatch (model_patch , vae , image , strength , mask ))
433466 return (model_patched ,)
434467
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
    """Variant of QwenImageDiffsynthControlnet with its own input spec and category.

    ``image``, ``inpaint_image`` and ``mask`` are declared as optional inputs.
    """

    CATEGORY = "advanced/loaders/zimage"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec as a dict with "required" and "optional" sections."""
        strength_options = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        required_inputs = {
            "model": ("MODEL",),
            "model_patch": ("MODEL_PATCH",),
            "vae": ("VAE",),
            "strength": ("FLOAT", strength_options),
        }
        optional_inputs = {
            "image": ("IMAGE",),
            "inpaint_image": ("IMAGE",),
            "mask": ("MASK",),
        }
        return {"required": required_inputs, "optional": optional_inputs}
435479
436480class UsoStyleProjectorPatch :
437481 def __init__ (self , model_patch , encoded_image ):
@@ -479,5 +523,6 @@ def apply_patch(self, model, model_patch, clip_vision_output):
# Maps each node name string to the class object that implements it.
479523NODE_CLASS_MAPPINGS = {
480524 "ModelPatchLoader" : ModelPatchLoader ,
481525 "QwenImageDiffsynthControlnet" : QwenImageDiffsynthControlnet ,
526+ "ZImageFunControlnet" : ZImageFunControlnet ,
482527 "USOStyleReference" : USOStyleReference ,
483528}
0 commit comments