diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index 1747be8c..40cebf5b 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -2,6 +2,8 @@ hardware: 'tpu' skip_jax_distributed_system: False attention: 'flash' +a2v_attention_kernel: 'flash' +v2a_attention_kernel: 'dot_product' attention_sharding_uniform: True precision: 'bf16' scan_layers: True @@ -68,6 +70,7 @@ flash_block_sizes: { block_kv_dkv_compute: 2048, use_fused_bwd_kernel: True, } +flash_min_seq_length: 4096 dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: 1 diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py index 2ccce748..7441a203 100644 --- a/src/maxdiffusion/models/ltx2/attention_ltx2.py +++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py @@ -16,6 +16,7 @@ from typing import Optional, Tuple from flax import nnx +import jax import jax.numpy as jnp from ... import common_types from ..attention_flax import NNXAttentionOp @@ -347,6 +348,7 @@ def __init__( attention_kernel: str = "flash", rope_type: str = "interleaved", flash_block_sizes: BlockSizes = None, + flash_min_seq_length: int = 4096, ): self.heads = heads self.rope_type = rope_type @@ -434,6 +436,7 @@ def __init__( axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV), axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV), flash_block_sizes=flash_block_sizes, + flash_min_seq_length=flash_min_seq_length, ) def __call__( @@ -447,46 +450,49 @@ def __call__( # Determine context (Self or Cross) context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # 1. Project - query = self.to_q(hidden_states) - key = self.to_k(context) - value = self.to_v(context) + # 1. Project and Norm + with jax.named_scope("QKV Projection"): + query = self.to_q(hidden_states) + key = self.to_k(context) + value = self.to_v(context) - # 2. Norm (Full Inner Dimension) - query = self.norm_q(query) - key = self.norm_k(key) + with jax.named_scope("QKV Norm"): + query = self.norm_q(query) + key = self.norm_k(key) # 3. Apply RoPE to tensors of shape [B, S, InnerDim] # Frequencies are shape [B, S, InnerDim] # 3. Apply RoPE - if rotary_emb is not None: - if hasattr(self, "rope_type") and self.rope_type == "split": - # Split RoPE: passing full freqs [B, H, S, D//2] - # apply_split_rotary_emb handles reshaping query/key - - query = apply_split_rotary_emb(query, rotary_emb) - - if k_rotary_emb is not None: - key = apply_split_rotary_emb(key, k_rotary_emb) - elif encoder_hidden_states is None: - key = apply_split_rotary_emb(key, rotary_emb) - - else: - # Interleaved (Default) - query = apply_rotary_emb(query, rotary_emb) - if k_rotary_emb is not None: - key = apply_rotary_emb(key, k_rotary_emb) - elif encoder_hidden_states is None: - key = apply_rotary_emb(key, rotary_emb) - - # 4. Attention - # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel - attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask) - - # 7. Output Projection - hidden_states = self.to_out(attn_output) - - if self.dropout_layer is not None: - hidden_states = self.dropout_layer(hidden_states) + with jax.named_scope("Apply RoPE"): + if rotary_emb is not None: + if hasattr(self, "rope_type") and self.rope_type == "split": + # Split RoPE: passing full freqs [B, H, S, D//2] + # apply_split_rotary_emb handles reshaping query/key + + query = apply_split_rotary_emb(query, rotary_emb) + + if k_rotary_emb is not None: + key = apply_split_rotary_emb(key, k_rotary_emb) + elif encoder_hidden_states is None: + key = apply_split_rotary_emb(key, rotary_emb) + + else: + # Interleaved (Default) + query = apply_rotary_emb(query, rotary_emb) + if k_rotary_emb is not None: + key = apply_rotary_emb(key, k_rotary_emb) + elif encoder_hidden_states is None: + key = apply_rotary_emb(key, rotary_emb) + + with jax.named_scope("Attention and Output Project"): + # 4. Attention + # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel + attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask) + + # 7. Output Projection + hidden_states = self.to_out(attn_output) + + if self.dropout_layer is not None: + hidden_states = self.dropout_layer(hidden_states) return hidden_states diff --git a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py index 2da22b88..aff898f2 100644 --- a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py +++ b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py @@ -108,11 +108,12 @@ def __call__( Returns: (video_embeds, audio_embeds, new_attention_mask) """ - # 1. Shared Feature Extraction - features = self.feature_extractor(hidden_states, attention_mask) + with jax.named_scope("Text Encoder Forward"): + # 1. Shared Feature Extraction + features = self.feature_extractor(hidden_states, attention_mask) - # 2. Parallel Connection - video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask) - audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask) + # 2. Parallel Connection + video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask) + audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask) - return video_embeds, audio_embeds, new_attention_mask + return video_embeds, audio_embeds, new_attention_mask diff --git a/src/maxdiffusion/models/ltx2/transformer_ltx2.py b/src/maxdiffusion/models/ltx2/transformer_ltx2.py index 767e5823..a047e475 100644 --- a/src/maxdiffusion/models/ltx2/transformer_ltx2.py +++ b/src/maxdiffusion/models/ltx2/transformer_ltx2.py @@ -106,7 +106,10 @@ def __init__( names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = [], attention_kernel: str = "flash", + a2v_attention_kernel: str = "flash", + v2a_attention_kernel: str = "dot_product", flash_block_sizes: BlockSizes = None, + flash_min_seq_length: int = 4096, ): self.dim = dim self.norm_eps = norm_eps @@ -137,6 +140,7 @@ def __init__( attention_kernel=self.attention_kernel, rope_type=rope_type, flash_block_sizes=flash_block_sizes, + flash_min_seq_length=flash_min_seq_length, ) self.audio_norm1 = nnx.RMSNorm( @@ -162,6 +166,7 @@ def __init__( attention_kernel=self.attention_kernel, rope_type=rope_type, flash_block_sizes=flash_block_sizes, + flash_min_seq_length=flash_min_seq_length, ) # 2. Prompt Cross-Attention @@ -215,6 +220,7 @@ def __init__( attention_kernel=self.attention_kernel, rope_type=rope_type, flash_block_sizes=flash_block_sizes, + flash_min_seq_length=flash_min_seq_length, ) # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention @@ -239,9 +245,10 @@ def __init__( eps=norm_eps, dtype=dtype, mesh=mesh, - attention_kernel=self.attention_kernel, + attention_kernel=a2v_attention_kernel, rope_type=rope_type, flash_block_sizes=flash_block_sizes, + flash_min_seq_length=0, ) self.video_to_audio_norm = nnx.RMSNorm( @@ -265,9 +272,10 @@ def __init__( eps=norm_eps, dtype=dtype, mesh=mesh, - attention_kernel=self.attention_kernel, + attention_kernel=v2a_attention_kernel, rope_type=rope_type, flash_block_sizes=flash_block_sizes, + flash_min_seq_length=flash_min_seq_length, ) # 4. Feed Forward @@ -350,7 +358,8 @@ def __call__( axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) - audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names) + axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed")) + audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names_audio) if encoder_hidden_states is not None: encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) @@ -378,11 +387,12 @@ def __call__( norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - attn_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=None, - rotary_emb=video_rotary_emb, - ) + with jax.named_scope("Video Self-Attention"): + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + rotary_emb=video_rotary_emb, + ) hidden_states = hidden_states + attn_hidden_states * gate_msa # Calculate Audio AdaLN values @@ -402,11 +412,12 @@ def __call__( norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa - attn_audio_hidden_states = self.audio_attn1( - hidden_states=norm_audio_hidden_states, - encoder_hidden_states=None, - rotary_emb=audio_rotary_emb, - ) + with jax.named_scope("Audio Self-Attention"): + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + rotary_emb=audio_rotary_emb, + ) audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa # 2. Video and Audio Cross-Attention with the text embeddings @@ -473,26 +484,28 @@ def __call__( mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale) + video_a2v_ca_shift mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale) + audio_a2v_ca_shift - a2v_attn_hidden_states = self.audio_to_video_attn( - mod_norm_hidden_states, - encoder_hidden_states=mod_norm_audio_hidden_states, - rotary_emb=ca_video_rotary_emb, - k_rotary_emb=ca_audio_rotary_emb, - attention_mask=a2v_cross_attention_mask, - ) + with jax.named_scope("Audio-to-Video Cross-Attention"): + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + rotary_emb=ca_video_rotary_emb, + k_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states # Video-to-Audio Cross Attention: Q: Audio; K,V: Video mod_norm_hidden_states_v2a = norm_hidden_states * (1 + video_v2a_ca_scale) + video_v2a_ca_shift mod_norm_audio_hidden_states_v2a = norm_audio_hidden_states * (1 + audio_v2a_ca_scale) + audio_v2a_ca_shift - v2a_attn_hidden_states = self.video_to_audio_attn( - mod_norm_audio_hidden_states_v2a, - encoder_hidden_states=mod_norm_hidden_states_v2a, - rotary_emb=ca_audio_rotary_emb, - k_rotary_emb=ca_video_rotary_emb, - attention_mask=v2a_cross_attention_mask, - ) + with jax.named_scope("Video-to-Audio Cross-Attention"): + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states_v2a, + encoder_hidden_states=mod_norm_hidden_states_v2a, + rotary_emb=ca_audio_rotary_emb, + k_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states # 4. Feedforward @@ -560,8 +573,11 @@ def __init__( names_which_can_be_offloaded: list = [], scan_layers: bool = True, attention_kernel: str = "flash", + a2v_attention_kernel: str = "flash", + v2a_attention_kernel: str = "dot_product", qk_norm: str = "rms_norm_across_heads", flash_block_sizes: BlockSizes = None, + flash_min_seq_length: int = 4096, **kwargs, ): self.in_channels = in_channels @@ -608,6 +624,9 @@ def __init__( self.names_which_can_be_offloaded = names_which_can_be_offloaded self.scan_layers = scan_layers self.attention_kernel = attention_kernel + self.a2v_attention_kernel = a2v_attention_kernel + self.v2a_attention_kernel = v2a_attention_kernel + self.flash_min_seq_length = flash_min_seq_length _out_channels = self.out_channels or self.in_channels _audio_out_channels = self.audio_out_channels or self.audio_in_channels @@ -800,7 +819,10 @@ def init_block(rngs): names_which_can_be_saved=self.names_which_can_be_saved, names_which_can_be_offloaded=self.names_which_can_be_offloaded, attention_kernel=self.attention_kernel, + a2v_attention_kernel=self.a2v_attention_kernel, + v2a_attention_kernel=self.v2a_attention_kernel, flash_block_sizes=flash_block_sizes, + flash_min_seq_length=self.flash_min_seq_length, ) if self.scan_layers: @@ -832,7 +854,10 @@ def init_block(rngs): names_which_can_be_saved=self.names_which_can_be_saved, names_which_can_be_offloaded=self.names_which_can_be_offloaded, attention_kernel=self.attention_kernel, + a2v_attention_kernel=self.a2v_attention_kernel, + v2a_attention_kernel=self.v2a_attention_kernel, flash_block_sizes=flash_block_sizes, + flash_min_seq_length=self.flash_min_seq_length, ) blocks.append(block) self.transformer_blocks = nnx.List(blocks) @@ -900,113 +925,79 @@ def __call__( batch_size = hidden_states.shape[0] # 1. Prepare RoPE positional embeddings - if video_coords is None: - video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, fps=fps) - if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames) + with jax.named_scope("RoPE Preparation"): + if video_coords is None: + video_coords = self.rope.prepare_video_coords(batch_size, num_frames, height, width, fps=fps) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords(batch_size, audio_num_frames) - video_rotary_emb = self.rope(video_coords) - audio_rotary_emb = self.audio_rope(audio_coords) + video_rotary_emb = self.rope(video_coords) + audio_rotary_emb = self.audio_rope(audio_coords) - video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :]) - audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :]) + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :]) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(audio_coords[:, 0:1, :]) # 2. Patchify input projections - hidden_states = self.proj_in(hidden_states) - audio_hidden_states = self.audio_proj_in(audio_hidden_states) + with jax.named_scope("Input Projection"): + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) # 3. Prepare timestep embeddings and modulation parameters - timestep_cross_attn_gate_scale_factor = self.cross_attn_timestep_scale_multiplier / self.timestep_scale_multiplier + with jax.named_scope("Timestep and Caption Projection"): + timestep_cross_attn_gate_scale_factor = self.cross_attn_timestep_scale_multiplier / self.timestep_scale_multiplier - temb, embedded_timestep = self.time_embed( - timestep.flatten(), - hidden_dtype=hidden_states.dtype, - ) - temb = temb.reshape(batch_size, -1, temb.shape[-1]) - embedded_timestep = embedded_timestep.reshape(batch_size, -1, embedded_timestep.shape[-1]) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.reshape(batch_size, -1, temb.shape[-1]) + embedded_timestep = embedded_timestep.reshape(batch_size, -1, embedded_timestep.shape[-1]) - temb_audio, audio_embedded_timestep = self.audio_time_embed( - audio_timestep.flatten(), - hidden_dtype=audio_hidden_states.dtype, - ) - temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1]) - audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1]) + temb_audio, audio_embedded_timestep = self.audio_time_embed( + audio_timestep.flatten(), + hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1]) + audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1]) - video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( - timestep.flatten(), - hidden_dtype=hidden_states.dtype, - ) - video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( - timestep.flatten() * timestep_cross_attn_gate_scale_factor, - hidden_dtype=hidden_states.dtype, - ) - video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape( - batch_size, -1, video_cross_attn_scale_shift.shape[-1] - ) - video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( + timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) - audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( - audio_timestep.flatten(), - hidden_dtype=audio_hidden_states.dtype, - ) - audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( - audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, - hidden_dtype=audio_hidden_states.dtype, - ) - audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape( - batch_size, -1, audio_cross_attn_scale_shift.shape[-1] - ) - audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( + audio_timestep.flatten(), + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) - # 4. Prepare prompt embeddings - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1]) + # 4. Prepare prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1]) - audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1]) + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1]) # 5. Run transformer blocks def scan_fn(carry, block): hidden_states, audio_hidden_states, rngs_carry = carry - hidden_states_out, audio_hidden_states_out = block( - hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, - encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, - temb=temb, - temb_audio=temb_audio, - temb_ca_scale_shift=video_cross_attn_scale_shift, - temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, - temb_ca_gate=video_cross_attn_a2v_gate, - temb_ca_audio_gate=audio_cross_attn_v2a_gate, - video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - ca_video_rotary_emb=video_cross_attn_rotary_emb, - ca_audio_rotary_emb=audio_cross_attn_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, - ) - return ( - hidden_states_out.astype(hidden_states.dtype), - audio_hidden_states_out.astype(audio_hidden_states.dtype), - rngs_carry, - ), None - - if self.scan_layers: - rematted_scan_fn = self.gradient_checkpoint.apply( - scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers - ) - carry = (hidden_states, audio_hidden_states, nnx.Rngs(0)) # Placeholder RNGs for now if not used in block - (hidden_states, audio_hidden_states, _), _ = nnx.scan( - rematted_scan_fn, - length=self.num_layers, - in_axes=(nnx.Carry, 0), - out_axes=(nnx.Carry, 0), - transform_metadata={nnx.PARTITION_NAME: "layers"}, - )(carry, self.transformer_blocks) - else: - for block in self.transformer_blocks: - hidden_states, audio_hidden_states = block( + with jax.named_scope("Transformer Layer"): + hidden_states_out, audio_hidden_states_out = block( hidden_states=hidden_states, audio_hidden_states=audio_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1024,25 +1015,65 @@ def scan_fn(carry, block): encoder_attention_mask=encoder_attention_mask, audio_encoder_attention_mask=audio_encoder_attention_mask, ) + return ( + hidden_states_out.astype(hidden_states.dtype), + audio_hidden_states_out.astype(audio_hidden_states.dtype), + rngs_carry, + ), None + + with jax.named_scope("Transformer Blocks"): + if self.scan_layers: + rematted_scan_fn = self.gradient_checkpoint.apply( + scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + ) + carry = (hidden_states, audio_hidden_states, nnx.Rngs(0)) # Placeholder RNGs for now if not used in block + (hidden_states, audio_hidden_states, _), _ = nnx.scan( + rematted_scan_fn, + length=self.num_layers, + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + transform_metadata={nnx.PARTITION_NAME: "layers"}, + )(carry, self.transformer_blocks) + else: + for block in self.transformer_blocks: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + ) # 6. Output layers - scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2) - shift = scale_shift_values[:, :, 0, :] - scale = scale_shift_values[:, :, 1, :] + with jax.named_scope("Output Projection & Norm"): + scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2) + shift = scale_shift_values[:, :, 0, :] + scale = scale_shift_values[:, :, 1, :] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - output = self.proj_out(hidden_states) + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) - audio_scale_shift_values = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + jnp.expand_dims( - audio_embedded_timestep, axis=2 - ) - audio_shift = audio_scale_shift_values[:, :, 0, :] - audio_scale = audio_scale_shift_values[:, :, 1, :] + audio_scale_shift_values = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + jnp.expand_dims( + audio_embedded_timestep, axis=2 + ) + audio_shift = audio_scale_shift_values[:, :, 0, :] + audio_scale = audio_scale_shift_values[:, :, 1, :] - audio_hidden_states = self.audio_norm_out(audio_hidden_states) - audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift - audio_output = self.audio_proj_out(audio_hidden_states) + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) if not return_dict: return (output, audio_output) diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 0db6c398..5581f45b 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -125,6 +125,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict): ltx2_config["attention_kernel"] = config.attention ltx2_config["precision"] = get_precision(config) ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config) + ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096) ltx2_config["remat_policy"] = config.remat_policy ltx2_config["names_which_can_be_saved"] = config.names_which_can_be_saved ltx2_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded @@ -1231,15 +1232,36 @@ def run_connectors(graphdef, state, hidden_states, attention_mask): connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_) ) - for i, t in enumerate(timesteps): + video_embeds_sharded = video_embeds + audio_embeds_sharded = audio_embeds + + if not self.transformer.scan_layers: + activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) + spec = NamedSharding(self.mesh, P(*activation_axes)) + video_embeds_sharded = jax.device_put(video_embeds, spec) + audio_embeds_sharded = jax.device_put(audio_embeds, spec) + + timesteps_jax = jnp.array(timesteps, dtype=jnp.float32) + for i in range(len(timesteps_jax)): + t = timesteps_jax[i] + + # Isolate input sharding to scan_layers=False to avoid affecting the standard path + latents_jax_sharded = latents_jax + audio_latents_jax_sharded = audio_latents_jax + + if not self.transformer.scan_layers: + activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) + latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names) + audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names) + noise_pred, noise_pred_audio = transformer_forward_pass( graphdef, state, - latents_jax, - audio_latents_jax, + latents_jax_sharded, + audio_latents_jax_sharded, t, - video_embeds, - audio_embeds, + video_embeds_sharded, + audio_embeds_sharded, new_attention_mask, new_attention_mask, guidance_scale > 1.0, @@ -1317,6 +1339,21 @@ def run_connectors(graphdef, state, hidden_states, attention_mask): if output_type == "latent": return LTX2PipelineOutput(frames=latents, audio=audio_latents) + # Force latents and VAE weights to be fully replicated using with_sharding_constraint, this speeds up single video latency ~3x + try: + mesh = latents.sharding.mesh + replicated_sharding = NamedSharding(mesh, P()) + latents = jax.lax.with_sharding_constraint(latents, replicated_sharding) + + # Replicate VAE weights + graphdef, state = nnx.split(self.vae) + state = jax.tree_util.tree_map( + lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state + ) + self.vae = nnx.merge(graphdef, state) + except Exception as e: + max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}") + if getattr(self.vae.config, "timestep_conditioning", False): noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype) diff --git a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py index 9b69a9b6..754a6826 100644 --- a/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_flow_match_flax.py @@ -15,6 +15,7 @@ # DISCLAIMER: This is a JAX/Flax conversion of a PyTorch implementation. # The original PyTorch code was provided by the user. +from functools import partial from typing import Optional, Tuple, Union import flax @@ -243,6 +244,8 @@ def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarra diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None]) return jnp.argmin(diffs, axis=1) + # Arguments at indices 0 (self), 5 (to_final), and 6 (return_dict) are kept static for JIT compilation. + @partial(jax.jit, static_argnums=(0, 5, 6)) def step( self, state: FlowMatchSchedulerState, diff --git a/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py b/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py index 9acc147e..d29fee9a 100644 --- a/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py +++ b/src/maxdiffusion/tests/ltx2/test_attention_ltx2.py @@ -378,6 +378,7 @@ def test_attention_mask_parity(self): jax_model.attention_op.attention_kernel = "flash" jax_model.attention_op.mesh = mesh + jax_model.attention_op.flash_min_seq_length = 0 mask_pattern_np = np.random.randint(0, 2, (self.B, S_flash)).astype(np.float32) pt_mask_additive = torch.from_numpy((1.0 - mask_pattern_np) * -1e9)[:, None, None, :]