diff --git a/pr2.patch b/pr2.patch
new file mode 100644
index 000000000..0d8cabe9f
--- /dev/null
+++ b/pr2.patch
@@ -0,0 +1,284 @@
+diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml
+index 2b716755..cf9d8438 100644
+--- a/src/maxdiffusion/configs/ltx2_video.yml
++++ b/src/maxdiffusion/configs/ltx2_video.yml
+@@ -103,23 +103,3 @@ jit_initializers: True
+ enable_single_replica_ckpt_restoring: False
+ seed: 0
+ audio_format: "s16"
+-
+-# LoRA parameters
+-enable_lora: False
+-
+-# Distilled LoRA
+-# lora_config: {
+-#   lora_model_name_or_path: ["Lightricks/LTX-2"],
+-#   weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
+-#   adapter_name: ["distilled-lora-384"],
+-#   rank: [384]
+-# }
+-
+-# Standard LoRA
+-lora_config: {
+-  lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
+-  weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
+-  adapter_name: ["camera-control-dolly-in"],
+-  rank: [32]
+-}
+-
+diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py
+index 88260b5f..01dfae0a 100644
+--- a/src/maxdiffusion/generate_ltx2.py
++++ b/src/maxdiffusion/generate_ltx2.py
+@@ -25,7 +25,6 @@ from google.cloud import storage
+ from google.api_core.exceptions import GoogleAPIError
+ import flax
+ from maxdiffusion.utils.export_utils import export_to_video_with_audio
+-from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader
+
+
+ def upload_video_to_gcs(output_dir: str, video_path: str):
+@@ -119,31 +118,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
+   checkpoint_loader = LTX2Checkpointer(config=config)
+   pipeline, _, _ = checkpoint_loader.load_checkpoint()
+
+-  # If LoRA is specified, inject layers and load weights.
+-  if (
+-      getattr(config, "enable_lora", False)
+-      and hasattr(config, "lora_config")
+-      and config.lora_config
+-      and config.lora_config.get("lora_model_name_or_path")
+-  ):
+-    lora_loader = LTX2NNXLoraLoader()
+-    lora_config = config.lora_config
+-    paths = lora_config["lora_model_name_or_path"]
+-    weights = lora_config.get("weight_name", [None] * len(paths))
+-    scales = lora_config.get("scale", [1.0] * len(paths))
+-    ranks = lora_config.get("rank", [64] * len(paths))
+-
+-    for i in range(len(paths)):
+-      pipeline = lora_loader.load_lora_weights(
+-          pipeline,
+-          paths[i],
+-          transformer_weight_name=weights[i],
+-          rank=ranks[i],
+-          scale=scales[i],
+-          scan_layers=config.scan_layers,
+-          dtype=config.weights_dtype,
+-      )
+-
+   pipeline.enable_vae_slicing()
+   pipeline.enable_vae_tiling()
+
+diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py
+index ca0371b7..96bdb0c8 100644
+--- a/src/maxdiffusion/loaders/lora_conversion_utils.py
++++ b/src/maxdiffusion/loaders/lora_conversion_utils.py
+@@ -703,98 +703,3 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
+       return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"
+
+   return None
+-
+-
+-def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
+-  """
+-  Translates LTX2 NNX path to Diffusers/LoRA keys.
+-  """
+-  # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+-  suffix_map = {
+-      # Self Attention (attn1)
+-      "attn1.to_q": "attn1.to_q",
+-      "attn1.to_k": "attn1.to_k",
+-      "attn1.to_v": "attn1.to_v",
+-      "attn1.to_out": "attn1.to_out.0",
+-      # Audio Self Attention (audio_attn1)
+-      "audio_attn1.to_q": "audio_attn1.to_q",
+-      "audio_attn1.to_k": "audio_attn1.to_k",
+-      "audio_attn1.to_v": "audio_attn1.to_v",
+-      "audio_attn1.to_out": "audio_attn1.to_out.0",
+-      # Audio Cross Attention (audio_attn2)
+-      "audio_attn2.to_q": "audio_attn2.to_q",
+-      "audio_attn2.to_k": "audio_attn2.to_k",
+-      "audio_attn2.to_v": "audio_attn2.to_v",
+-      "audio_attn2.to_out": "audio_attn2.to_out.0",
+-      # Cross Attention (attn2)
+-      "attn2.to_q": "attn2.to_q",
+-      "attn2.to_k": "attn2.to_k",
+-      "attn2.to_v": "attn2.to_v",
+-      "attn2.to_out": "attn2.to_out.0",
+-      # Audio to Video Cross Attention
+-      "audio_to_video_attn.to_q": "audio_to_video_attn.to_q",
+-      "audio_to_video_attn.to_k": "audio_to_video_attn.to_k",
+-      "audio_to_video_attn.to_v": "audio_to_video_attn.to_v",
+-      "audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0",
+-      # Video to Audio Cross Attention
+-      "video_to_audio_attn.to_q": "video_to_audio_attn.to_q",
+-      "video_to_audio_attn.to_k": "video_to_audio_attn.to_k",
+-      "video_to_audio_attn.to_v": "video_to_audio_attn.to_v",
+-      "video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0",
+-      # Feed Forward
+-      "ff.net_0": "ff.net.0.proj",
+-      "ff.net_2": "ff.net.2",
+-      # Audio Feed Forward
+-      "audio_ff.net_0": "audio_ff.net.0.proj",
+-      "audio_ff.net_2": "audio_ff.net.2",
+-  }
+-
+-  # --- 3. Translation Logic ---
+-  global_map = {
+-      "proj_in": "diffusion_model.patchify_proj",
+-      "audio_proj_in": "diffusion_model.audio_patchify_proj",
+-      "proj_out": "diffusion_model.proj_out",
+-      "audio_proj_out": "diffusion_model.audio_proj_out",
+-      "time_embed.linear": "diffusion_model.adaln_single.linear",
+-      "audio_time_embed.linear": "diffusion_model.audio_adaln_single.linear",
+-      "av_cross_attn_video_a2v_gate.linear": "diffusion_model.av_ca_a2v_gate_adaln_single.linear",
+-      "av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear",
+-      "av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear",
+-      "av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear",
+-      # Nested conditioning layers
+-      "time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1",
+-      "time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2",
+-      "audio_time_embed.emb.timestep_embedder.linear_1": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_1",
+-      "audio_time_embed.emb.timestep_embedder.linear_2": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_2",
+-      "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
+-      "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
+-      "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
+-      "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
+-      "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1",
+-      "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_2",
+-      "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_1",
+-      "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_2",
+-      "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1",
+-      "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
+-      "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
+-      "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
+-      # Connectors
+-      "feature_extractor.linear": "text_embedding_projection.aggregate_embed",
+-  }
+-
+-  if nnx_path_str in global_map:
+-    return global_map[nnx_path_str]
+-
+-  if scan_layers:
+-    if nnx_path_str.startswith("transformer_blocks."):
+-      inner_suffix = nnx_path_str[len("transformer_blocks.") :]
+-      if inner_suffix in suffix_map:
+-        return f"diffusion_model.transformer_blocks.{{}}.{suffix_map[inner_suffix]}"
+-  else:
+-    m = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str)
+-    if m:
+-      idx, inner_suffix = m.group(1), m.group(2)
+-      if inner_suffix in suffix_map:
+-        return f"diffusion_model.transformer_blocks.{idx}.{suffix_map[inner_suffix]}"
+-
+-  return None
+diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
+deleted file mode 100644
+index 247b3ba2..00000000
+--- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
++++ /dev/null
+@@ -1,75 +0,0 @@
+-# Copyright 2026 Google LLC
+-#
+-# Licensed under the Apache License, Version 2.0 (the "License");
+-# you may not use this file except in compliance with the License.
+-# You may obtain a copy of the License at
+-#
+-#     https://www.apache.org/licenses/LICENSE-2.0
+-#
+-# Unless required by applicable law or agreed to in writing, software
+-# distributed under the License is distributed on an "AS IS" BASIS,
+-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-# See the License for the specific language governing permissions and
+-# limitations under the License.
+-
+-"""NNX-based LoRA loader for LTX2 models."""
+-
+-from flax import nnx
+-from .lora_base import LoRABaseMixin
+-from .lora_pipeline import StableDiffusionLoraLoaderMixin
+-from ..models import lora_nnx
+-from .. import max_logging
+-from . import lora_conversion_utils
+-
+-
+-class LTX2NNXLoraLoader(LoRABaseMixin):
+-  """
+-  Handles loading LoRA weights into NNX-based LTX2 model.
+-  Assumes LTX2 pipeline contains 'transformer'
+-  attributes that are NNX Modules.
+-  """
+-
+-  def load_lora_weights(
+-      self,
+-      pipeline: nnx.Module,
+-      lora_model_path: str,
+-      transformer_weight_name: str,
+-      rank: int,
+-      scale: float = 1.0,
+-      scan_layers: bool = False,
+-      dtype: str = "float32",
+-      **kwargs,
+-  ):
+-    """
+-    Merges LoRA weights into the pipeline from a checkpoint.
+-    """
+-    lora_loader = StableDiffusionLoraLoaderMixin()
+-
+-    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+-
+-    def translate_fn(nnx_path_str):
+-      return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
+-
+-    h_state_dict = None
+-    if hasattr(pipeline, "transformer") and transformer_weight_name:
+-      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
+-      h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
+-      # Filter state dict for transformer keys to avoid confusing warnings
+-      transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")}
+-      merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype)
+-    else:
+-      max_logging.log("transformer not found or no weight name provided for LoRA.")
+-
+-    if hasattr(pipeline, "connectors"):
+-      max_logging.log(f"Merging LoRA into connectors with rank={rank}")
+-      if h_state_dict is None and transformer_weight_name:
+-        h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
+-
+-      if h_state_dict is not None:
+-        # Filter state dict for connector keys to avoid confusing warnings
+-        connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")}
+-        merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
+-      else:
+-        max_logging.log("Could not load LoRA state dict for connectors.")
+-
+-    return pipeline
+diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py
+index 8500af61..7441a203 100644
+--- a/src/maxdiffusion/models/ltx2/attention_ltx2.py
++++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py
+@@ -195,7 +195,7 @@ class LTX2RotaryPosEmbed(nnx.Module):
+     # pixel_coords[:, 0, ...] selects Frame dimension.
+     # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
+     frame_coords = pixel_coords[:, 0, ...]
+-    frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], a_min=0)
++    frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
+     pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
+
+     return pixel_coords
+@@ -212,12 +212,12 @@ class LTX2RotaryPosEmbed(nnx.Module):
+     # 2. Start timestamps
+     audio_scale_factor = self.scale_factors[0]
+     grid_start_mel = grid_f * audio_scale_factor
+-    grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, a_min=0)
++    grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
+     grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
+
+     # 3. End timestamps
+     grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
+-    grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, a_min=0)
++    grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
+     grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
+
+     # Stack [num_patches, 2]
diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml
index 1747be8c8..2b7167553 100644
--- a/src/maxdiffusion/configs/ltx2_video.yml
+++ b/src/maxdiffusion/configs/ltx2_video.yml
@@ -68,6 +68,7 @@ flash_block_sizes: {
   block_kv_dkv_compute: 2048,
   use_fused_bwd_kernel: True,
 }
+flash_min_seq_length: 4096
 dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
 ici_data_parallelism: 1
@@ -102,3 +103,23 @@ jit_initializers: True
 enable_single_replica_ckpt_restoring: False
 seed: 0
 audio_format: "s16"
+
+# LoRA parameters
+enable_lora: False
+
+# Distilled LoRA
+# lora_config: {
+#   lora_model_name_or_path: ["Lightricks/LTX-2"],
+#   weight_name: ["ltx-2-19b-distilled-lora-384.safetensors"],
+#   adapter_name: ["distilled-lora-384"],
+#   rank: [384]
+# }
+
+# Standard LoRA
+lora_config: {
+  lora_model_name_or_path: ["Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In"],
+  weight_name: ["ltx-2-19b-lora-camera-control-dolly-in.safetensors"],
+  adapter_name: ["camera-control-dolly-in"],
+  rank: [32]
+}
+
diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py
index 01dfae0a7..88260b5f1 100644
--- a/src/maxdiffusion/generate_ltx2.py
+++ b/src/maxdiffusion/generate_ltx2.py
@@ -25,6 +25,7 @@
 from google.api_core.exceptions import GoogleAPIError
 import flax
 from maxdiffusion.utils.export_utils import export_to_video_with_audio
+from maxdiffusion.loaders.ltx2_lora_nnx_loader import LTX2NNXLoraLoader


 def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -118,6 +119,31 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
   checkpoint_loader = LTX2Checkpointer(config=config)
   pipeline, _, _ = checkpoint_loader.load_checkpoint()

+  # If LoRA is specified, inject layers and load weights.
+  if (
+      getattr(config, "enable_lora", False)
+      and hasattr(config, "lora_config")
+      and config.lora_config
+      and config.lora_config.get("lora_model_name_or_path")
+  ):
+    lora_loader = LTX2NNXLoraLoader()
+    lora_config = config.lora_config
+    paths = lora_config["lora_model_name_or_path"]
+    weights = lora_config.get("weight_name", [None] * len(paths))
+    scales = lora_config.get("scale", [1.0] * len(paths))
+    ranks = lora_config.get("rank", [64] * len(paths))
+
+    for i in range(len(paths)):
+      pipeline = lora_loader.load_lora_weights(
+          pipeline,
+          paths[i],
+          transformer_weight_name=weights[i],
+          rank=ranks[i],
+          scale=scales[i],
+          scan_layers=config.scan_layers,
+          dtype=config.weights_dtype,
+      )
+
   pipeline.enable_vae_slicing()
   pipeline.enable_vae_tiling()

diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py
index 96bdb0c84..ca0371b76 100644
--- a/src/maxdiffusion/loaders/lora_conversion_utils.py
+++ b/src/maxdiffusion/loaders/lora_conversion_utils.py
@@ -703,3 +703,98 @@ def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
       return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"

   return None
+
+
+def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
+  """
+  Translates LTX2 NNX path to Diffusers/LoRA keys.
+  """
+  # --- 2. Map NNX Suffixes to LoRA Suffixes ---
+  suffix_map = {
+      # Self Attention (attn1)
+      "attn1.to_q": "attn1.to_q",
+      "attn1.to_k": "attn1.to_k",
+      "attn1.to_v": "attn1.to_v",
+      "attn1.to_out": "attn1.to_out.0",
+      # Audio Self Attention (audio_attn1)
+      "audio_attn1.to_q": "audio_attn1.to_q",
+      "audio_attn1.to_k": "audio_attn1.to_k",
+      "audio_attn1.to_v": "audio_attn1.to_v",
+      "audio_attn1.to_out": "audio_attn1.to_out.0",
+      # Audio Cross Attention (audio_attn2)
+      "audio_attn2.to_q": "audio_attn2.to_q",
+      "audio_attn2.to_k": "audio_attn2.to_k",
+      "audio_attn2.to_v": "audio_attn2.to_v",
+      "audio_attn2.to_out": "audio_attn2.to_out.0",
+      # Cross Attention (attn2)
+      "attn2.to_q": "attn2.to_q",
+      "attn2.to_k": "attn2.to_k",
+      "attn2.to_v": "attn2.to_v",
+      "attn2.to_out": "attn2.to_out.0",
+      # Audio to Video Cross Attention
+      "audio_to_video_attn.to_q": "audio_to_video_attn.to_q",
+      "audio_to_video_attn.to_k": "audio_to_video_attn.to_k",
+      "audio_to_video_attn.to_v": "audio_to_video_attn.to_v",
+      "audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0",
+      # Video to Audio Cross Attention
+      "video_to_audio_attn.to_q": "video_to_audio_attn.to_q",
+      "video_to_audio_attn.to_k": "video_to_audio_attn.to_k",
+      "video_to_audio_attn.to_v": "video_to_audio_attn.to_v",
+      "video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0",
+      # Feed Forward
+      "ff.net_0": "ff.net.0.proj",
+      "ff.net_2": "ff.net.2",
+      # Audio Feed Forward
+      "audio_ff.net_0": "audio_ff.net.0.proj",
+      "audio_ff.net_2": "audio_ff.net.2",
+  }
+
+  # --- 3. Translation Logic ---
+  global_map = {
+      "proj_in": "diffusion_model.patchify_proj",
+      "audio_proj_in": "diffusion_model.audio_patchify_proj",
+      "proj_out": "diffusion_model.proj_out",
+      "audio_proj_out": "diffusion_model.audio_proj_out",
+      "time_embed.linear": "diffusion_model.adaln_single.linear",
+      "audio_time_embed.linear": "diffusion_model.audio_adaln_single.linear",
+      "av_cross_attn_video_a2v_gate.linear": "diffusion_model.av_ca_a2v_gate_adaln_single.linear",
+      "av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear",
+      "av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear",
+      "av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear",
+      # Nested conditioning layers
+      "time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1",
+      "time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2",
+      "audio_time_embed.emb.timestep_embedder.linear_1": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_1",
+      "audio_time_embed.emb.timestep_embedder.linear_2": "diffusion_model.audio_adaln_single.emb.timestep_embedder.linear_2",
+      "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
+      "av_cross_attn_video_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_video_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
+      "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_1",
+      "av_cross_attn_audio_scale_shift.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_audio_scale_shift_adaln_single.emb.timestep_embedder.linear_2",
+      "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1",
+      "av_cross_attn_video_a2v_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_2",
+      "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_1": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_1",
+      "av_cross_attn_audio_v2a_gate.emb.timestep_embedder.linear_2": "diffusion_model.av_ca_v2a_gate_adaln_single.emb.timestep_embedder.linear_2",
+      "caption_projection.linear_1": "diffusion_model.caption_projection.linear_1",
+      "caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
+      "audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
+      "audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
+      # Connectors
+      "feature_extractor.linear": "text_embedding_projection.aggregate_embed",
+  }
+
+  if nnx_path_str in global_map:
+    return global_map[nnx_path_str]
+
+  if scan_layers:
+    if nnx_path_str.startswith("transformer_blocks."):
+      inner_suffix = nnx_path_str[len("transformer_blocks.") :]
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.transformer_blocks.{{}}.{suffix_map[inner_suffix]}"
+  else:
+    m = re.match(r"^transformer_blocks\.(\d+)\.(.+)$", nnx_path_str)
+    if m:
+      idx, inner_suffix = m.group(1), m.group(2)
+      if inner_suffix in suffix_map:
+        return f"diffusion_model.transformer_blocks.{idx}.{suffix_map[inner_suffix]}"
+
+  return None
diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
new file mode 100644
index 000000000..247b3ba2e
--- /dev/null
+++ b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py
@@ -0,0 +1,75 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""NNX-based LoRA loader for LTX2 models."""
+
+from flax import nnx
+from .lora_base import LoRABaseMixin
+from .lora_pipeline import StableDiffusionLoraLoaderMixin
+from ..models import lora_nnx
+from .. import max_logging
+from . import lora_conversion_utils
+
+
+class LTX2NNXLoraLoader(LoRABaseMixin):
+  """
+  Handles loading LoRA weights into NNX-based LTX2 model.
+  Assumes LTX2 pipeline contains 'transformer'
+  attributes that are NNX Modules.
+  """
+
+  def load_lora_weights(
+      self,
+      pipeline: nnx.Module,
+      lora_model_path: str,
+      transformer_weight_name: str,
+      rank: int,
+      scale: float = 1.0,
+      scan_layers: bool = False,
+      dtype: str = "float32",
+      **kwargs,
+  ):
+    """
+    Merges LoRA weights into the pipeline from a checkpoint.
+    """
+    lora_loader = StableDiffusionLoraLoaderMixin()
+
+    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+
+    def translate_fn(nnx_path_str):
+      return lora_conversion_utils.translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
+
+    h_state_dict = None
+    if hasattr(pipeline, "transformer") and transformer_weight_name:
+      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
+      h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
+      # Filter state dict for transformer keys to avoid confusing warnings
+      transformer_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("diffusion_model")}
+      merge_fn(pipeline.transformer, transformer_state_dict, rank, scale, translate_fn, dtype=dtype)
+    else:
+      max_logging.log("transformer not found or no weight name provided for LoRA.")
+
+    if hasattr(pipeline, "connectors"):
+      max_logging.log(f"Merging LoRA into connectors with rank={rank}")
+      if h_state_dict is None and transformer_weight_name:
+        h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
+
+      if h_state_dict is not None:
+        # Filter state dict for connector keys to avoid confusing warnings
+        connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")}
+        merge_fn(pipeline.connectors, connector_state_dict, rank, scale, translate_fn, dtype=dtype)
+      else:
+        max_logging.log("Could not load LoRA state dict for connectors.")
+
+    return pipeline
diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py
index 2ccce7488..7441a2038 100644
--- a/src/maxdiffusion/models/ltx2/attention_ltx2.py
+++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py
@@ -16,6 +16,7 @@
 from typing import Optional, Tuple

 from flax import nnx
+import jax
 import jax.numpy as jnp
 from ... import common_types
 from ..attention_flax import NNXAttentionOp
@@ -347,6 +348,7 @@ def __init__(
       attention_kernel: str = "flash",
       rope_type: str = "interleaved",
       flash_block_sizes: BlockSizes = None,
+      flash_min_seq_length: int = 4096,
   ):
     self.heads = heads
     self.rope_type = rope_type
@@ -434,6 +436,7 @@ def __init__(
         axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV),
         axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV),
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=flash_min_seq_length,
     )

   def __call__(
@@ -447,46 +450,49 @@ def __call__(
     # Determine context (Self or Cross)
     context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

-    # 1. Project
-    query = self.to_q(hidden_states)
-    key = self.to_k(context)
-    value = self.to_v(context)
+    # 1. Project and Norm
+    with jax.named_scope("QKV Projection"):
+      query = self.to_q(hidden_states)
+      key = self.to_k(context)
+      value = self.to_v(context)

-    # 2. Norm (Full Inner Dimension)
-    query = self.norm_q(query)
-    key = self.norm_k(key)
+    with jax.named_scope("QKV Norm"):
+      query = self.norm_q(query)
+      key = self.norm_k(key)

     # 3. Apply RoPE to tensors of shape [B, S, InnerDim]
     # Frequencies are shape [B, S, InnerDim]
     # 3. Apply RoPE
-    if rotary_emb is not None:
-      if hasattr(self, "rope_type") and self.rope_type == "split":
-        # Split RoPE: passing full freqs [B, H, S, D//2]
-        # apply_split_rotary_emb handles reshaping query/key
-
-        query = apply_split_rotary_emb(query, rotary_emb)
-
-        if k_rotary_emb is not None:
-          key = apply_split_rotary_emb(key, k_rotary_emb)
-        elif encoder_hidden_states is None:
-          key = apply_split_rotary_emb(key, rotary_emb)
-
-      else:
-        # Interleaved (Default)
-        query = apply_rotary_emb(query, rotary_emb)
-        if k_rotary_emb is not None:
-          key = apply_rotary_emb(key, k_rotary_emb)
-        elif encoder_hidden_states is None:
-          key = apply_rotary_emb(key, rotary_emb)
-
-    # 4. Attention
-    # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
-    attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
-
-    # 7. Output Projection
-    hidden_states = self.to_out(attn_output)
-
-    if self.dropout_layer is not None:
-      hidden_states = self.dropout_layer(hidden_states)
+    with jax.named_scope("Apply RoPE"):
+      if rotary_emb is not None:
+        if hasattr(self, "rope_type") and self.rope_type == "split":
+          # Split RoPE: passing full freqs [B, H, S, D//2]
+          # apply_split_rotary_emb handles reshaping query/key
+
+          query = apply_split_rotary_emb(query, rotary_emb)
+
+          if k_rotary_emb is not None:
+            key = apply_split_rotary_emb(key, k_rotary_emb)
+          elif encoder_hidden_states is None:
+            key = apply_split_rotary_emb(key, rotary_emb)
+
+        else:
+          # Interleaved (Default)
+          query = apply_rotary_emb(query, rotary_emb)
+          if k_rotary_emb is not None:
+            key = apply_rotary_emb(key, k_rotary_emb)
+          elif encoder_hidden_states is None:
+            key = apply_rotary_emb(key, rotary_emb)
+
+    with jax.named_scope("Attention and Output Project"):
+      # 4. Attention
+      # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
+      attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
+
+      # 7. Output Projection
+      hidden_states = self.to_out(attn_output)
+
+      if self.dropout_layer is not None:
+        hidden_states = self.dropout_layer(hidden_states)

     return hidden_states
diff --git a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py
index 2da22b883..aff898f21 100644
--- a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py
+++ b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py
@@ -108,11 +108,12 @@ def __call__(
     Returns:
       (video_embeds, audio_embeds, new_attention_mask)
     """
-    # 1. Shared Feature Extraction
-    features = self.feature_extractor(hidden_states, attention_mask)
+    with jax.named_scope("Text Encoder Forward"):
+      # 1. Shared Feature Extraction
+      features = self.feature_extractor(hidden_states, attention_mask)

-    # 2. Parallel Connection
-    video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
-    audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
+      # 2. Parallel Connection
+      video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
+      audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)

-    return video_embeds, audio_embeds, new_attention_mask
+      return video_embeds, audio_embeds, new_attention_mask
diff --git a/src/maxdiffusion/models/ltx2/transformer_ltx2.py b/src/maxdiffusion/models/ltx2/transformer_ltx2.py
index 767e58235..7b6d714cc 100644
--- a/src/maxdiffusion/models/ltx2/transformer_ltx2.py
+++ b/src/maxdiffusion/models/ltx2/transformer_ltx2.py
@@ -107,6 +107,7 @@ def __init__(
       names_which_can_be_offloaded: list = [],
       attention_kernel: str = "flash",
       flash_block_sizes: BlockSizes = None,
+      flash_min_seq_length: int = 4096,
   ):
     self.dim = dim
     self.norm_eps = norm_eps
@@ -137,6 +138,7 @@ def __init__(
         attention_kernel=self.attention_kernel,
         rope_type=rope_type,
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=flash_min_seq_length,
     )

     self.audio_norm1 = nnx.RMSNorm(
@@ -162,6 +164,7 @@ def __init__(
         attention_kernel=self.attention_kernel,
         rope_type=rope_type,
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=flash_min_seq_length,
     )

     # 2. Prompt Cross-Attention
@@ -215,6 +218,7 @@ def __init__(
         attention_kernel=self.attention_kernel,
         rope_type=rope_type,
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=flash_min_seq_length,
     )

     # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -239,9 +243,10 @@ def __init__(
         eps=norm_eps,
         dtype=dtype,
         mesh=mesh,
-        attention_kernel=self.attention_kernel,
+        attention_kernel="flash",
         rope_type=rope_type,
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=0,
     )

     self.video_to_audio_norm = nnx.RMSNorm(
@@ -268,6 +273,7 @@ def __init__(
         attention_kernel=self.attention_kernel,
         rope_type=rope_type,
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=flash_min_seq_length,
     )

     # 4. Feed Forward
@@ -350,7 +356,8 @@ def __call__(
     axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
-    audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names)
+    axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))
+    audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names_audio)

     if encoder_hidden_states is not None:
       encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
@@ -378,11 +385,12 @@ def __call__(
     norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

-    attn_hidden_states = self.attn1(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=None,
-        rotary_emb=video_rotary_emb,
-    )
+    with jax.named_scope("Video Self-Attention"):
+      attn_hidden_states = self.attn1(
+          hidden_states=norm_hidden_states,
+          encoder_hidden_states=None,
+          rotary_emb=video_rotary_emb,
+      )
     hidden_states = hidden_states + attn_hidden_states * gate_msa

     # Calculate Audio AdaLN values
@@ -402,11 +410,12 @@ def __call__(
     norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa

-    attn_audio_hidden_states = self.audio_attn1(
-        hidden_states=norm_audio_hidden_states,
-        encoder_hidden_states=None,
-        rotary_emb=audio_rotary_emb,
-    )
+    with jax.named_scope("Audio Self-Attention"):
+      attn_audio_hidden_states = self.audio_attn1(
+          hidden_states=norm_audio_hidden_states,
+          encoder_hidden_states=None,
+          rotary_emb=audio_rotary_emb,
+      )
     audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa

     # 2. Video and Audio Cross-Attention with the text embeddings
@@ -473,26 +482,28 @@ def __call__(
     mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale) + video_a2v_ca_shift
     mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale) + audio_a2v_ca_shift

-    a2v_attn_hidden_states = self.audio_to_video_attn(
-        mod_norm_hidden_states,
-        encoder_hidden_states=mod_norm_audio_hidden_states,
-        rotary_emb=ca_video_rotary_emb,
-        k_rotary_emb=ca_audio_rotary_emb,
-        attention_mask=a2v_cross_attention_mask,
-    )
+    with jax.named_scope("Audio-to-Video Cross-Attention"):
+      a2v_attn_hidden_states = self.audio_to_video_attn(
+          mod_norm_hidden_states,
+          encoder_hidden_states=mod_norm_audio_hidden_states,
+          rotary_emb=ca_video_rotary_emb,
+          k_rotary_emb=ca_audio_rotary_emb,
+          attention_mask=a2v_cross_attention_mask,
+      )
     hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states

     # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
     mod_norm_hidden_states_v2a = norm_hidden_states * (1 + video_v2a_ca_scale) + video_v2a_ca_shift
     mod_norm_audio_hidden_states_v2a = norm_audio_hidden_states * (1 + audio_v2a_ca_scale) + audio_v2a_ca_shift

-    v2a_attn_hidden_states = self.video_to_audio_attn(
-        mod_norm_audio_hidden_states_v2a,
-        encoder_hidden_states=mod_norm_hidden_states_v2a,
-        rotary_emb=ca_audio_rotary_emb,
-        k_rotary_emb=ca_video_rotary_emb,
-        attention_mask=v2a_cross_attention_mask,
-    )
+    with jax.named_scope("Video-to-Audio Cross-Attention"):
+      v2a_attn_hidden_states = self.video_to_audio_attn(
+          mod_norm_audio_hidden_states_v2a,
+          encoder_hidden_states=mod_norm_hidden_states_v2a,
+          rotary_emb=ca_audio_rotary_emb,
+          k_rotary_emb=ca_video_rotary_emb,
+          attention_mask=v2a_cross_attention_mask,
+      )
     audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states

     # 4. Feedforward
@@ -562,6 +573,7 @@ def __init__(
       attention_kernel: str = "flash",
       qk_norm: str = "rms_norm_across_heads",
       flash_block_sizes: BlockSizes = None,
+      flash_min_seq_length: int = 4096,
       **kwargs,
   ):
     self.in_channels = in_channels
@@ -608,6 +620,7 @@ def __init__(
     self.names_which_can_be_offloaded = names_which_can_be_offloaded
     self.scan_layers = scan_layers
    self.attention_kernel = attention_kernel
+    self.flash_min_seq_length = flash_min_seq_length

     _out_channels = self.out_channels or self.in_channels
     _audio_out_channels = self.audio_out_channels or self.audio_in_channels
@@ -801,6 +814,7 @@ def init_block(rngs):
         names_which_can_be_offloaded=self.names_which_can_be_offloaded,
         attention_kernel=self.attention_kernel,
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=self.flash_min_seq_length,
     )

     if self.scan_layers:
@@ -833,6 +847,7 @@ def init_block(rngs):
         names_which_can_be_offloaded=self.names_which_can_be_offloaded,
         attention_kernel=self.attention_kernel,
         flash_block_sizes=flash_block_sizes,
+        flash_min_seq_length=self.flash_min_seq_length,
     )
     blocks.append(block)
     self.transformer_blocks = nnx.List(blocks)
@@ -900,113 +915,79 @@ def __call__(
     batch_size = hidden_states.shape[0]

     # 1. Prepare RoPE positional embeddings
-    if video_coords is None:
-      video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, fps=fps)
-    if audio_coords is None:
-      audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames)
+    with jax.named_scope("RoPE Preparation"):
+      if video_coords is None:
+        video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, fps=fps)
+      if audio_coords is None:
+        audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames)

-    video_rotary_emb = self.rope(video_coords)
-    audio_rotary_emb = self.audio_rope(audio_coords)
+      video_rotary_emb = self.rope(video_coords)
+      audio_rotary_emb = self.audio_rope(audio_coords)

-    video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :])
-    audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :])
+      video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :])
+      audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :])

     # 2. Patchify input projections
-    hidden_states = self.proj_in(hidden_states)
-    audio_hidden_states = self.audio_proj_in(audio_hidden_states)
+    with jax.named_scope("Input Projection"):
+      hidden_states = self.proj_in(hidden_states)
+      audio_hidden_states = self.audio_proj_in(audio_hidden_states)

     # 3. Prepare timestep embeddings and modulation parameters
-    timestep_cross_attn_gate_scale_factor = self.cross_attn_timestep_scale_multiplier / self.timestep_scale_multiplier
+    with jax.named_scope("Timestep and Caption Projection"):
+      timestep_cross_attn_gate_scale_factor = self.cross_attn_timestep_scale_multiplier / self.timestep_scale_multiplier

-    temb, embedded_timestep = self.time_embed(
-        timestep.flatten(),
-        hidden_dtype=hidden_states.dtype,
-    )
-    temb = temb.reshape(batch_size, -1, temb.shape[-1])
-    embedded_timestep = embedded_timestep.reshape(batch_size, -1, embedded_timestep.shape[-1])
+      temb, embedded_timestep = self.time_embed(
+          timestep.flatten(),
+          hidden_dtype=hidden_states.dtype,
+      )
+      temb = temb.reshape(batch_size, -1, temb.shape[-1])
+      embedded_timestep = embedded_timestep.reshape(batch_size, -1, embedded_timestep.shape[-1])

-    temb_audio, audio_embedded_timestep = self.audio_time_embed(
-        audio_timestep.flatten(),
-        hidden_dtype=audio_hidden_states.dtype,
-    )
-    temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
-    audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])
+      temb_audio, audio_embedded_timestep = self.audio_time_embed(
+          audio_timestep.flatten(),
+          hidden_dtype=audio_hidden_states.dtype,
+      )
+      temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
+      audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])

-    video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
-        timestep.flatten(),
-        hidden_dtype=hidden_states.dtype,
-    )
-    video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
-        timestep.flatten() * timestep_cross_attn_gate_scale_factor,
-        hidden_dtype=hidden_states.dtype,
-    )
-    video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
-        batch_size, -1, video_cross_attn_scale_shift.shape[-1]
-    )
-    video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
+      video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
+          timestep.flatten(),
+          hidden_dtype=hidden_states.dtype,
+      )
+      video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
+          timestep.flatten() * timestep_cross_attn_gate_scale_factor,
+          hidden_dtype=hidden_states.dtype,
+      )
+      video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
+          batch_size, -1, video_cross_attn_scale_shift.shape[-1]
+      )
+      video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])

-    audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
-        audio_timestep.flatten(),
-        hidden_dtype=audio_hidden_states.dtype,
-    )
-    audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
-        audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
-        hidden_dtype=audio_hidden_states.dtype,
-    )
-    audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(
-        batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
-    )
-    audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
+      audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
+          audio_timestep.flatten(),
+          hidden_dtype=audio_hidden_states.dtype,
+      )
+      audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
+          audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
+          hidden_dtype=audio_hidden_states.dtype,
+      )
+      audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(
+          batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
+      )
+      audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])

-    # 4. Prepare prompt embeddings
-    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-    encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
+      # 4. Prepare prompt embeddings
+      encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+      encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])

-    audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
-    audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
+      audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
+      audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])

     # 5. Run transformer blocks
     def scan_fn(carry, block):
       hidden_states, audio_hidden_states, rngs_carry = carry
-      hidden_states_out, audio_hidden_states_out = block(
-          hidden_states=hidden_states,
-          audio_hidden_states=audio_hidden_states,
-          encoder_hidden_states=encoder_hidden_states,
-          audio_encoder_hidden_states=audio_encoder_hidden_states,
-          temb=temb,
-          temb_audio=temb_audio,
-          temb_ca_scale_shift=video_cross_attn_scale_shift,
-          temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
-          temb_ca_gate=video_cross_attn_a2v_gate,
-          temb_ca_audio_gate=audio_cross_attn_v2a_gate,
-          video_rotary_emb=video_rotary_emb,
-          audio_rotary_emb=audio_rotary_emb,
-          ca_video_rotary_emb=video_cross_attn_rotary_emb,
-          ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
-          encoder_attention_mask=encoder_attention_mask,
-          audio_encoder_attention_mask=audio_encoder_attention_mask,
-      )
-      return (
-          hidden_states_out.astype(hidden_states.dtype),
-          audio_hidden_states_out.astype(audio_hidden_states.dtype),
-          rngs_carry,
-      ), None
-
-    if self.scan_layers:
-      rematted_scan_fn = self.gradient_checkpoint.apply(
-          scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
-      )
-      carry = (hidden_states, audio_hidden_states, nnx.Rngs(0))  # Placeholder RNGs for now if not used in block
-      (hidden_states, audio_hidden_states, _), _ = nnx.scan(
-          rematted_scan_fn,
-          length=self.num_layers,
-          in_axes=(nnx.Carry, 0),
-          out_axes=(nnx.Carry, 0),
-          transform_metadata={nnx.PARTITION_NAME: "layers"},
-      )(carry, self.transformer_blocks)
-    else:
-      for block in self.transformer_blocks:
-        hidden_states, audio_hidden_states = block(
+      with jax.named_scope("Transformer Layer"):
+        hidden_states_out, audio_hidden_states_out = block(
             hidden_states=hidden_states,
             audio_hidden_states=audio_hidden_states,
             encoder_hidden_states=encoder_hidden_states,
             audio_encoder_hidden_states=audio_encoder_hidden_states,
             temb=temb,
             temb_audio=temb_audio,
             temb_ca_scale_shift=video_cross_attn_scale_shift,
             temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
             temb_ca_gate=video_cross_attn_a2v_gate,
             temb_ca_audio_gate=audio_cross_attn_v2a_gate,
             video_rotary_emb=video_rotary_emb,
             audio_rotary_emb=audio_rotary_emb,
             ca_video_rotary_emb=video_cross_attn_rotary_emb,
             ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
             encoder_attention_mask=encoder_attention_mask,
             audio_encoder_attention_mask=audio_encoder_attention_mask,
         )
@@ -1024,25 +1005,65 @@ def scan_fn(carry, block):
+      return (
+          hidden_states_out.astype(hidden_states.dtype),
+          audio_hidden_states_out.astype(audio_hidden_states.dtype),
+          rngs_carry,
+      ), None
+
+    with jax.named_scope("Transformer Blocks"):
+      if self.scan_layers:
+        rematted_scan_fn = self.gradient_checkpoint.apply(
+            scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
+        )
+        carry = (hidden_states, audio_hidden_states, nnx.Rngs(0))  # Placeholder RNGs for now if not used in block
+        (hidden_states, audio_hidden_states, _), _ = nnx.scan(
+            rematted_scan_fn,
+            length=self.num_layers,
+            in_axes=(nnx.Carry, 0),
+            out_axes=(nnx.Carry, 0),
+            transform_metadata={nnx.PARTITION_NAME: "layers"},
+        )(carry, self.transformer_blocks)
+      else:
+        for block in self.transformer_blocks:
+          hidden_states, audio_hidden_states = block(
+              hidden_states=hidden_states,
+              audio_hidden_states=audio_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              audio_encoder_hidden_states=audio_encoder_hidden_states,
+              temb=temb,
+              temb_audio=temb_audio,
+              temb_ca_scale_shift=video_cross_attn_scale_shift,
+              temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
+              temb_ca_gate=video_cross_attn_a2v_gate,
+              temb_ca_audio_gate=audio_cross_attn_v2a_gate,
+              video_rotary_emb=video_rotary_emb,
+              audio_rotary_emb=audio_rotary_emb,
+              ca_video_rotary_emb=video_cross_attn_rotary_emb,
+              ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
+              encoder_attention_mask=encoder_attention_mask,
+              audio_encoder_attention_mask=audio_encoder_attention_mask,
+          )

     # 6. Output layers
-    scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2)
-    shift = scale_shift_values[:, :, 0, :]
-    scale = scale_shift_values[:, :, 1, :]
+    with jax.named_scope("Output Projection & Norm"):
+      scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2)
+      shift = scale_shift_values[:, :, 0, :]
+      scale = scale_shift_values[:, :, 1, :]

-    hidden_states = self.norm_out(hidden_states)
-    hidden_states = hidden_states * (1 + scale) + shift
-    output = self.proj_out(hidden_states)
+      hidden_states = self.norm_out(hidden_states)
+      hidden_states = hidden_states * (1 + scale) + shift
+      output = self.proj_out(hidden_states)

-    audio_scale_shift_values = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + jnp.expand_dims(
-        audio_embedded_timestep, axis=2
-    )
-    audio_shift = audio_scale_shift_values[:, :, 0, :]
-    audio_scale = audio_scale_shift_values[:, :, 1, :]
+      audio_scale_shift_values = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + jnp.expand_dims(
+          audio_embedded_timestep, axis=2
+      )
+      audio_shift = audio_scale_shift_values[:, :, 0, :]
+      audio_scale = audio_scale_shift_values[:, :, 1, :]

-    audio_hidden_states = self.audio_norm_out(audio_hidden_states)
-    audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
-    audio_output = self.audio_proj_out(audio_hidden_states)
+      audio_hidden_states = self.audio_norm_out(audio_hidden_states)
+      audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
+      audio_output = self.audio_proj_out(audio_hidden_states)

     if not return_dict:
       return (output, audio_output)
diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
index 0db6c398a..808c2ae03 100644
--- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
+++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
@@ -125,6 +125,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
   ltx2_config["attention_kernel"] = config.attention
   ltx2_config["precision"] = get_precision(config)
   ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config)
+  ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096)
   ltx2_config["remat_policy"] = config.remat_policy
   ltx2_config["names_which_can_be_saved"] = config.names_which_can_be_saved
   ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
@@ -1231,15 +1232,31 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
       connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
   )

-  for i, t in enumerate(timesteps):
+  timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
+  for i, t_val in enumerate(timesteps):
+    t = timesteps_jax[i]
+
+    # Isolate input sharding to scan_layers=False to avoid affecting the standard path
+    latents_jax_sharded = latents_jax
+    audio_latents_jax_sharded = audio_latents_jax
+    video_embeds_sharded = video_embeds
+    audio_embeds_sharded = audio_embeds
+
+    if not self.transformer.scan_layers:
+      activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
+      latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
+      audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
+      video_embeds_sharded = jax.lax.with_sharding_constraint(video_embeds, activation_axis_names)
+      audio_embeds_sharded = jax.lax.with_sharding_constraint(audio_embeds, activation_axis_names)
+
     noise_pred, noise_pred_audio = transformer_forward_pass(
         graphdef,
         state,
-        latents_jax,
-        audio_latents_jax,
+        latents_jax_sharded,
+        audio_latents_jax_sharded,
         t,
-        video_embeds,
-        audio_embeds,
+        video_embeds_sharded,
+        audio_embeds_sharded,
         new_attention_mask,
         new_attention_mask,
         guidance_scale > 1.0,
@@ -1317,6 +1334,21 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
   if output_type == "latent":
     return LTX2PipelineOutput(frames=latents, audio=audio_latents)

+  # Force latents and VAE weights to be fully replicated using with_sharding_constraint, this speeds up single video latency ~3x
+  try:
+    mesh = latents.sharding.mesh
+    replicated_sharding = NamedSharding(mesh, P())
+    latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
+
+    # Replicate VAE weights
+    graphdef, state = nnx.split(self.vae)
+    state = jax.tree_util.tree_map(
+        lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
+    )
+    self.vae = nnx.merge(graphdef, state)
+  except Exception as e:
+    max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}")
+
   if getattr(self.vae.config, "timestep_conditioning", False):
     noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)
diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py
index 9b69a9b6d..b9c980292 100644
--- a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py
+++ b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py
@@ -15,6 +15,7 @@
 # DISCLAIMER: This is a JAX/Flax conversion of a PyTorch implementation.
 # The original PyTorch code was provided by the user.

+from functools import partial
 from typing import Optional, Tuple, Union

 import flax
@@ -243,6 +244,7 @@ def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarra
     diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None])
     return jnp.argmin(diffs, axis=1)

+  @partial(jax.jit, static_argnums=(0, 5, 6))
   def step(
       self,
       state: FlowMatchSchedulerState,
diff --git a/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py b/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py
index 9acc147e6..d29fee9a4 100644
--- a/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py
+++ b/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py
@@ -378,6 +378,7 @@ def test_attention_mask_parity(self):
     jax_model.attention_op.attention_kernel = "flash"
     jax_model.attention_op.mesh = mesh
+    jax_model.attention_op.flash_min_seq_length = 0

     mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32)
     pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]