Skip to content

Commit 828ec0d

Browse files
committed
Add LoRA Inference Support for LTX2 Model
1 parent 6de9d57 commit 828ec0d

9 files changed

Lines changed: 460 additions & 181 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ flash_block_sizes: {
6868
block_kv_dkv_compute: 2048,
6969
use_fused_bwd_kernel: True,
7070
}
71+
flash_min_seq_length: 4096
7172
dcn_context_parallelism: 1
7273
dcn_tensor_parallelism: 1
7374
ici_data_parallelism: 1
@@ -102,3 +103,23 @@ jit_initializers: True
102103
enable_single_replica_ckpt_restoring: False
103104
seed: 0
104105
audio_format: "s16"
106+
107+
# LoRA parameters
108+
enable_lora: False
109+
110+
# Distilled LoRA
111+
# lora_config: {
112+
# lora_model_name_or_path: ["Lightricks/LTX-2"],
113+
# weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
114+
# adapter_name: ["distilled-lora-384"],
115+
# rank: [384]
116+
# }
117+
118+
# Standard LoRA
119+
lora_config: {
120+
lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
121+
weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
122+
adapter_name: ["camera-control-dolly-in"],
123+
rank: [32]
124+
}
125+

src/maxdiffusion/generate_ltx2.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.api_core.exceptions import GoogleAPIError
2626
import flax
2727
from maxdiffusion.utils.export_utils import export_to_video_with_audio
28+
from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader
2829

2930

3031
def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -118,6 +119,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
118119
checkpoint_loader = LTX2Checkpointer(config=config)
119120
pipeline, _, _ = checkpoint_loader.load_checkpoint()
120121

122+
# If LoRA is specified, inject layers and load weights.
123+
if (
124+
getattr(config, "enable_lora", False)
125+
and hasattr(config, "lora_config")
126+
and config.lora_config
127+
and config.lora_config.get("lora_model_name_or_path")
128+
):
129+
lora_loader = LTX2NNXLoraLoader()
130+
lora_config = config.lora_config
131+
paths = lora_config["lora_model_name_or_path"]
132+
weights = lora_config.get("weight_name", [None] * len(paths))
133+
scales = lora_config.get("scale", [1.0] * len(paths))
134+
ranks = lora_config.get("rank", [64] * len(paths))
135+
136+
for i in range(len(paths)):
137+
pipeline = lora_loader.load_lora_weights(
138+
pipeline,
139+
paths[i],
140+
transformer_weight_name=weights[i],
141+
rank=ranks[i],
142+
scale=scales[i],
143+
scan_layers=config.scan_layers,
144+
dtype=config.weights_dtype,
145+
)
146+
121147
pipeline.enable_vae_slicing()
122148
pipeline.enable_vae_tiling()
123149

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,98 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
703703
return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
704704

705705
return None
706+
707+
708+
def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
  """
  Translate an LTX2 NNX module path into the matching Diffusers/LoRA key.

  Args:
    nnx_path_str: Dotted NNX module path, e.g. ``"transformer_blocks.3.attn1.to_q"``.
    scan_layers: When True, block paths carry no per-layer index and the
      returned key contains a literal ``{}`` placeholder for the layer index.

  Returns:
    The translated Diffusers/LoRA key, or ``None`` when the path has no
    known LoRA target.
  """
  # Per-block attention suffixes: q/k/v map one-to-one, while the output
  # projection gains a trailing ".0" in the Diffusers naming scheme.
  _attn_modules = (
      "attn1",
      "audio_attn1",
      "audio_attn2",
      "attn2",
      "audio_to_video_attn",
      "video_to_audio_attn",
  )
  block_suffixes = {}
  for attn in _attn_modules:
    for proj in ("to_q", "to_k", "to_v"):
      block_suffixes[f"{attn}.{proj}"] = f"{attn}.{proj}"
    block_suffixes[f"{attn}.to_out"] = f"{attn}.to_out.0"
  # Feed-forward suffixes (video and audio branches share the same shape).
  for ff in ("ff", "audio_ff"):
    block_suffixes[f"{ff}.net_0"] = f"{ff}.net.0.proj"
    block_suffixes[f"{ff}.net_2"] = f"{ff}.net.2"

  # Non-block (global) modules: projections, AdaLN conditioning, caption
  # projections and the text-embedding connector.
  global_map = {
      "proj_in": "diffusion_model.patchify_proj",
      "audio_proj_in": "diffusion_model.audio_patchify_proj",
      "proj_out": "diffusion_model.proj_out",
      "audio_proj_out": "diffusion_model.audio_proj_out",
      "time_embed.linear": "diffusion_model.adaln_single.linear",
      "audio_time_embed.linear": "diffusion_model.audio_adaln_single.linear",
      "av_cross_attn_video_a2v_gate.linear": "diffusion_model.av_ca_a2v_gate_adaln_single.linear",
      "av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear",
      "av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear",
      "av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear",
      # Nested conditioning layers
      "time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1",
      "time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2",
      "audio_time_embed.emb.timestep_embedder.linear_1": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_1",
      "audio_time_embed.emb.timestep_embedder.linear_2": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_2",
      "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
      "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
      "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
      "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
      "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1",
      "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_2",
      "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_1",
      "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_2",
      "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1",
      "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
      "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
      "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
      # Connectors
      "feature_extractor.linear": "text_embedding_projection.aggregate_embed",
  }

  # Global modules take precedence over per-block matching.
  mapped = global_map.get(nnx_path_str)
  if mapped is not None:
    return mapped

  prefix = "transformer_blocks."
  if scan_layers:
    # Scanned layout: no numeric index in the NNX path; emit a "{}"
    # placeholder the caller fills in per layer.
    if nnx_path_str.startswith(prefix):
      mapped = block_suffixes.get(nnx_path_str[len(prefix):])
      if mapped is not None:
        return f"diffusion_model.transformer_blocks.{{}}.{mapped}"
    return None

  # Unscanned layout: the path carries an explicit layer index.
  match = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str)
  if match:
    layer_idx, inner = match.group(1), match.group(2)
    mapped = block_suffixes.get(inner)
    if mapped is not None:
      return f"diffusion_model.transformer_blocks.{layer_idx}.{mapped}"

  return None
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""NNX-based LoRA loader for LTX2 models."""
16+
17+
from flax import nnx
18+
from .lora_base import LoRABaseMixin
19+
from .lora_pipeline import StableDiffusionLoraLoaderMixin
20+
from ..models import lora_nnx
21+
from .. import max_logging
22+
from . import lora_conversion_utils
23+
24+
25+
class LTX2NNXLoraLoader(LoRABaseMixin):
  """
  Loads and merges LoRA weights into an NNX-based LTX2 pipeline.

  The pipeline is expected to expose an NNX ``transformer`` module, and
  optionally a ``connectors`` module; LoRA deltas are merged into both
  when present.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      transformer_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      dtype: str = "float32",
      **kwargs,
  ):
    """
    Merge LoRA weights from a checkpoint into ``pipeline`` and return it.

    Args:
      pipeline: Pipeline holding the NNX modules to merge into.
      lora_model_path: Repo id or local path of the LoRA checkpoint.
      transformer_weight_name: Weight file name inside the checkpoint.
      rank: LoRA rank used when merging.
      scale: Scaling factor applied to the LoRA delta.
      scan_layers: Whether transformer blocks use the scanned layout.
      dtype: Target dtype for the merged weights.
      **kwargs: Forwarded to ``lora_state_dict`` (e.g. revision/cache args).
    """
    sd_loader = StableDiffusionLoraLoaderMixin()

    # Scanned layers need the stacked-merge variant.
    merge = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

    def to_diffusers_key(nnx_path_str):
      return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    state_dict = None
    if hasattr(pipeline, "transformer") and transformer_weight_name:
      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
      state_dict, _ = sd_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
      # Restrict to transformer keys so the merge does not warn about
      # unrelated connector entries.
      transformer_keys = {k: v for k, v in state_dict.items() if k.startswith("diffusion_model")}
      merge(pipeline.transformer, transformer_keys, rank, scale, to_diffusers_key, dtype=dtype)
    else:
      max_logging.log("transformer not found or no weight name provided for LoRA.")

    if hasattr(pipeline, "connectors"):
      max_logging.log(f"Merging LoRA into connectors with rank={rank}")
      # The transformer branch may have been skipped; fetch the state dict
      # here if we still need it and a weight name was given.
      if state_dict is None and transformer_weight_name:
        state_dict, _ = sd_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)

      if state_dict is not None:
        # Restrict to connector keys to avoid confusing merge warnings.
        connector_keys = {k: v for k, v in state_dict.items() if k.startswith("text_embedding_projection")}
        merge(pipeline.connectors, connector_keys, rank, scale, to_diffusers_key, dtype=dtype)
      else:
        max_logging.log("Could not load LoRA state dict for connectors.")

    return pipeline

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Optional, Tuple
1818
from flax import nnx
19+
import jax
1920
import jax.numpy as jnp
2021
from ... import common_types
2122
from ..attention_flax import NNXAttentionOp
@@ -194,7 +195,7 @@ def prepare_video_coords(
194195
# pixel_coords[:, 0, ...] selects Frame dimension.
195196
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
196197
frame_coords = pixel_coords[:, 0, ...]
197-
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
198+
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], a_min=0)
198199
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
199200

200201
return pixel_coords
@@ -211,12 +212,12 @@ def prepare_audio_coords(
211212
# 2. Start timestamps
212213
audio_scale_factor = self.scale_factors[0]
213214
grid_start_mel = grid_f * audio_scale_factor
214-
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
215+
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, a_min=0)
215216
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
216217

217218
# 3. End timestamps
218219
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
219-
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
220+
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, a_min=0)
220221
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
221222

222223
# Stack [num_patches, 2]
@@ -347,6 +348,7 @@ def __init__(
347348
attention_kernel: str = "flash",
348349
rope_type: str = "interleaved",
349350
flash_block_sizes: BlockSizes = None,
351+
flash_min_seq_length: int = 4096,
350352
):
351353
self.heads = heads
352354
self.rope_type = rope_type
@@ -434,6 +436,7 @@ def __init__(
434436
axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV),
435437
axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV),
436438
flash_block_sizes=flash_block_sizes,
439+
flash_min_seq_length=flash_min_seq_length,
437440
)
438441

439442
def __call__(
@@ -447,46 +450,49 @@ def __call__(
447450
# Determine context (Self or Cross)
448451
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
449452

450-
# 1. Project
451-
query = self.to_q(hidden_states)
452-
key = self.to_k(context)
453-
value = self.to_v(context)
453+
# 1. Project and Norm
454+
with jax.named_scope("QKV Projection"):
455+
query = self.to_q(hidden_states)
456+
key = self.to_k(context)
457+
value = self.to_v(context)
454458

455-
# 2. Norm (Full Inner Dimension)
456-
query = self.norm_q(query)
457-
key = self.norm_k(key)
459+
with jax.named_scope("QKV Norm"):
460+
query = self.norm_q(query)
461+
key = self.norm_k(key)
458462

459463
# 3. Apply RoPE to tensors of shape [B, S, InnerDim]
460464
# Frequencies are shape [B, S, InnerDim]
461465
# 3. Apply RoPE
462-
if rotary_emb is not None:
463-
if hasattr(self, "rope_type") and self.rope_type == "split":
464-
# Split RoPE: passing full freqs [B, H, S, D//2]
465-
# apply_split_rotary_emb handles reshaping query/key
466-
467-
query = apply_split_rotary_emb(query, rotary_emb)
468-
469-
if k_rotary_emb is not None:
470-
key = apply_split_rotary_emb(key, k_rotary_emb)
471-
elif encoder_hidden_states is None:
472-
key = apply_split_rotary_emb(key, rotary_emb)
473-
474-
else:
475-
# Interleaved (Default)
476-
query = apply_rotary_emb(query, rotary_emb)
477-
if k_rotary_emb is not None:
478-
key = apply_rotary_emb(key, k_rotary_emb)
479-
elif encoder_hidden_states is None:
480-
key = apply_rotary_emb(key, rotary_emb)
481-
482-
# 4. Attention
483-
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
484-
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
485-
486-
# 7. Output Projection
487-
hidden_states = self.to_out(attn_output)
488-
489-
if self.dropout_layer is not None:
490-
hidden_states = self.dropout_layer(hidden_states)
466+
with jax.named_scope("Apply RoPE"):
467+
if rotary_emb is not None:
468+
if hasattr(self, "rope_type") and self.rope_type == "split":
469+
# Split RoPE: passing full freqs [B, H, S, D//2]
470+
# apply_split_rotary_emb handles reshaping query/key
471+
472+
query = apply_split_rotary_emb(query, rotary_emb)
473+
474+
if k_rotary_emb is not None:
475+
key = apply_split_rotary_emb(key, k_rotary_emb)
476+
elif encoder_hidden_states is None:
477+
key = apply_split_rotary_emb(key, rotary_emb)
478+
479+
else:
480+
# Interleaved (Default)
481+
query = apply_rotary_emb(query, rotary_emb)
482+
if k_rotary_emb is not None:
483+
key = apply_rotary_emb(key, k_rotary_emb)
484+
elif encoder_hidden_states is None:
485+
key = apply_rotary_emb(key, rotary_emb)
486+
487+
with jax.named_scope("Attention and Output Project"):
488+
# 4. Attention
489+
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
490+
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
491+
492+
# 7. Output Projection
493+
hidden_states = self.to_out(attn_output)
494+
495+
if self.dropout_layer is not None:
496+
hidden_states = self.dropout_layer(hidden_states)
491497

492498
return hidden_states

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,12 @@ def __call__(
108108
Returns:
109109
(video_embeds, audio_embeds, new_attention_mask)
110110
"""
111-
# 1. Shared Feature Extraction
112-
features = self.feature_extractor(hidden_states, attention_mask)
111+
with jax.named_scope("Text Encoder Forward"):
112+
# 1. Shared Feature Extraction
113+
features = self.feature_extractor(hidden_states, attention_mask)
113114

114-
# 2. Parallel Connection
115-
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
116-
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
115+
# 2. Parallel Connection
116+
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
117+
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
117118

118-
return video_embeds, audio_embeds, new_attention_mask
119+
return video_embeds, audio_embeds, new_attention_mask

0 commit comments

Comments
 (0)