@@ -322,6 +322,7 @@ def load_vae_weights(
322322
323323 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
324324 flax_key = _tuple_str_to_int (flax_key )
325+ max_logging .log (f"Mapped VAE key: { pt_key } -> { flax_key } " )
325326
326327 if resnet_index is not None :
327328 str_flax_key = tuple ([str (x ) for x in flax_key ])
@@ -347,7 +348,7 @@ def load_vae_weights(
347348
348349def rename_for_ltx2_vocoder (key ):
349350 key = key .replace ("ups." , "upsamplers." )
350- key = key .replace ("resblocks" , "resnets " )
351+ key = key .replace ("resblocks. " , "resblocks_ " )
351352 key = key .replace ("conv_post" , "conv_out" )
352353 key = key .replace ("conv_pre" , "conv_in" )
353354 key = key .replace ("act_post" , "act_out" )
@@ -376,6 +377,10 @@ def load_vocoder_weights(
376377
377378 flax_key = _tuple_str_to_int (parts )
378379
380+ # Skip filter keys as they are derived in NNX model
381+ if "filter" in flax_key :
382+ continue
383+
379384 if flax_key [- 1 ] == "kernel" :
380385 if "upsamplers" in flax_key :
381386 tensor = tensor .transpose (2 , 0 , 1 )[::- 1 , :, :]
0 commit comments