Skip to content

Commit 758cb92

Browse files
committed
weight loading
1 parent b6fe087 commit 758cb92

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ def load_vae_weights(
322322

323323
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
324324
flax_key = _tuple_str_to_int(flax_key)
325+
max_logging.log(f"Mapped VAE key: {pt_key} -> {flax_key}")
325326

326327
if resnet_index is not None:
327328
str_flax_key = tuple([str(x) for x in flax_key])
@@ -347,7 +348,7 @@ def load_vae_weights(
347348

348349
def rename_for_ltx2_vocoder(key):
349350
key = key.replace("ups.", "upsamplers.")
350-
key = key.replace("resblocks", "resnets")
351+
key = key.replace("resblocks.", "resblocks_")
351352
key = key.replace("conv_post", "conv_out")
352353
key = key.replace("conv_pre", "conv_in")
353354
key = key.replace("act_post", "act_out")
@@ -376,6 +377,10 @@ def load_vocoder_weights(
376377

377378
flax_key = _tuple_str_to_int(parts)
378379

380+
# Skip filter keys as they are derived in NNX model
381+
if "filter" in flax_key:
382+
continue
383+
379384
if flax_key[-1] == "kernel":
380385
if "upsamplers" in flax_key:
381386
tensor = tensor.transpose(2, 0, 1)[::-1, :, :]

0 commit comments

Comments
 (0)