Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
a2v_attention_kernel: 'flash'
v2a_attention_kernel: 'dot_product'
attention_sharding_uniform: True
precision: 'bf16'
scan_layers: True
Expand Down Expand Up @@ -68,6 +70,7 @@ flash_block_sizes: {
block_kv_dkv_compute: 2048,
use_fused_bwd_kernel: True,
}
flash_min_seq_length: 4096
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
Expand Down
78 changes: 42 additions & 36 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Optional, Tuple
from flax import nnx
import jax
import jax.numpy as jnp
from ... import common_types
from ..attention_flax import NNXAttentionOp
Expand Down Expand Up @@ -347,6 +348,7 @@ def __init__(
attention_kernel: str = "flash",
rope_type: str = "interleaved",
flash_block_sizes: BlockSizes = None,
flash_min_seq_length: int = 4096,
):
self.heads = heads
self.rope_type = rope_type
Expand Down Expand Up @@ -434,6 +436,7 @@ def __init__(
axis_names_q=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV),
axis_names_kv=(common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV),
flash_block_sizes=flash_block_sizes,
flash_min_seq_length=flash_min_seq_length,
)

def __call__(
Expand All @@ -447,46 +450,49 @@ def __call__(
# Determine context (Self or Cross)
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

# 1. Project
query = self.to_q(hidden_states)
key = self.to_k(context)
value = self.to_v(context)
# 1. Project and Norm
with jax.named_scope("QKV Projection"):
query = self.to_q(hidden_states)
key = self.to_k(context)
value = self.to_v(context)

# 2. Norm (Full Inner Dimension)
query = self.norm_q(query)
key = self.norm_k(key)
with jax.named_scope("QKV Norm"):
query = self.norm_q(query)
key = self.norm_k(key)

# 3. Apply RoPE to tensors of shape [B, S, InnerDim]
# Frequencies are shape [B, S, InnerDim]
# 3. Apply RoPE
if rotary_emb is not None:
if hasattr(self, "rope_type") and self.rope_type == "split":
# Split RoPE: passing full freqs [B, H, S, D//2]
# apply_split_rotary_emb handles reshaping query/key

query = apply_split_rotary_emb(query, rotary_emb)

if k_rotary_emb is not None:
key = apply_split_rotary_emb(key, k_rotary_emb)
elif encoder_hidden_states is None:
key = apply_split_rotary_emb(key, rotary_emb)

else:
# Interleaved (Default)
query = apply_rotary_emb(query, rotary_emb)
if k_rotary_emb is not None:
key = apply_rotary_emb(key, k_rotary_emb)
elif encoder_hidden_states is None:
key = apply_rotary_emb(key, rotary_emb)

# 4. Attention
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)

# 7. Output Projection
hidden_states = self.to_out(attn_output)

if self.dropout_layer is not None:
hidden_states = self.dropout_layer(hidden_states)
with jax.named_scope("Apply RoPE"):
if rotary_emb is not None:
if hasattr(self, "rope_type") and self.rope_type == "split":
# Split RoPE: passing full freqs [B, H, S, D//2]
# apply_split_rotary_emb handles reshaping query/key

query = apply_split_rotary_emb(query, rotary_emb)

if k_rotary_emb is not None:
key = apply_split_rotary_emb(key, k_rotary_emb)
elif encoder_hidden_states is None:
key = apply_split_rotary_emb(key, rotary_emb)

else:
# Interleaved (Default)
query = apply_rotary_emb(query, rotary_emb)
if k_rotary_emb is not None:
key = apply_rotary_emb(key, k_rotary_emb)
elif encoder_hidden_states is None:
key = apply_rotary_emb(key, rotary_emb)

with jax.named_scope("Attention and Output Project"):
# 4. Attention
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)

# 7. Output Projection
hidden_states = self.to_out(attn_output)

if self.dropout_layer is not None:
hidden_states = self.dropout_layer(hidden_states)

return hidden_states
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,12 @@ def __call__(
Returns:
(video_embeds, audio_embeds, new_attention_mask)
"""
# 1. Shared Feature Extraction
features = self.feature_extractor(hidden_states, attention_mask)
with jax.named_scope("Text Encoder Forward"):
# 1. Shared Feature Extraction
features = self.feature_extractor(hidden_states, attention_mask)

# 2. Parallel Connection
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
# 2. Parallel Connection
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)

return video_embeds, audio_embeds, new_attention_mask
return video_embeds, audio_embeds, new_attention_mask
Loading
Loading