Skip to content

Commit 4ae0dff

Browse files
committed
Fixes: add separate a2v/v2a attention-kernel config options, hoist prompt-embed sharding out of the denoising loop, and document the scheduler's JIT static arguments
1 parent 7b86470 commit 4ae0dff

4 files changed

Lines changed: 25 additions & 7 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
44
attention: 'flash'
5+
a2v_attention_kernel: 'flash'
6+
v2a_attention_kernel: 'dot_product'
57
attention_sharding_uniform: True
68
precision: 'bf16'
79
scan_layers: True

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def __init__(
106106
names_which_can_be_saved: list = [],
107107
names_which_can_be_offloaded: list = [],
108108
attention_kernel: str = "flash",
109+
a2v_attention_kernel: str = "flash",
110+
v2a_attention_kernel: str = "dot_product",
109111
flash_block_sizes: BlockSizes = None,
110112
flash_min_seq_length: int = 4096,
111113
):
@@ -243,7 +245,7 @@ def __init__(
243245
eps=norm_eps,
244246
dtype=dtype,
245247
mesh=mesh,
246-
attention_kernel="flash",
248+
attention_kernel=a2v_attention_kernel,
247249
rope_type=rope_type,
248250
flash_block_sizes=flash_block_sizes,
249251
flash_min_seq_length=0,
@@ -270,7 +272,7 @@ def __init__(
270272
eps=norm_eps,
271273
dtype=dtype,
272274
mesh=mesh,
273-
attention_kernel=self.attention_kernel,
275+
attention_kernel=v2a_attention_kernel,
274276
rope_type=rope_type,
275277
flash_block_sizes=flash_block_sizes,
276278
flash_min_seq_length=flash_min_seq_length,
@@ -571,6 +573,8 @@ def __init__(
571573
names_which_can_be_offloaded: list = [],
572574
scan_layers: bool = True,
573575
attention_kernel: str = "flash",
576+
a2v_attention_kernel: str = "flash",
577+
v2a_attention_kernel: str = "dot_product",
574578
qk_norm: str = "rms_norm_across_heads",
575579
flash_block_sizes: BlockSizes = None,
576580
flash_min_seq_length: int = 4096,
@@ -620,6 +624,8 @@ def __init__(
620624
self.names_which_can_be_offloaded = names_which_can_be_offloaded
621625
self.scan_layers = scan_layers
622626
self.attention_kernel = attention_kernel
627+
self.a2v_attention_kernel = a2v_attention_kernel
628+
self.v2a_attention_kernel = v2a_attention_kernel
623629
self.flash_min_seq_length = flash_min_seq_length
624630

625631
_out_channels = self.out_channels or self.in_channels
@@ -813,6 +819,8 @@ def init_block(rngs):
813819
names_which_can_be_saved=self.names_which_can_be_saved,
814820
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
815821
attention_kernel=self.attention_kernel,
822+
a2v_attention_kernel=self.a2v_attention_kernel,
823+
v2a_attention_kernel=self.v2a_attention_kernel,
816824
flash_block_sizes=flash_block_sizes,
817825
flash_min_seq_length=self.flash_min_seq_length,
818826
)
@@ -846,6 +854,8 @@ def init_block(rngs):
846854
names_which_can_be_saved=self.names_which_can_be_saved,
847855
names_which_can_be_offloaded=self.names_which_can_be_offloaded,
848856
attention_kernel=self.attention_kernel,
857+
a2v_attention_kernel=self.a2v_attention_kernel,
858+
v2a_attention_kernel=self.v2a_attention_kernel,
849859
flash_block_sizes=flash_block_sizes,
850860
flash_min_seq_length=self.flash_min_seq_length,
851861
)

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,22 +1232,27 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12321232
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
12331233
)
12341234

1235+
video_embeds_sharded = video_embeds
1236+
audio_embeds_sharded = audio_embeds
1237+
1238+
if not self.transformer.scan_layers:
1239+
activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
1240+
spec = NamedSharding(self.mesh, P(*activation_axes))
1241+
video_embeds_sharded = jax.device_put(video_embeds, spec)
1242+
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
1243+
12351244
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1236-
for i, t_val in enumerate(timesteps):
1245+
for i in range(len(timesteps_jax)):
12371246
t = timesteps_jax[i]
12381247

12391248
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
12401249
latents_jax_sharded = latents_jax
12411250
audio_latents_jax_sharded = audio_latents_jax
1242-
video_embeds_sharded = video_embeds
1243-
audio_embeds_sharded = audio_embeds
12441251

12451252
if not self.transformer.scan_layers:
12461253
activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
12471254
latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names)
12481255
audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names)
1249-
video_embeds_sharded = jax.lax.with_sharding_constraint(video_embeds, activation_axis_names)
1250-
audio_embeds_sharded = jax.lax.with_sharding_constraint(audio_embeds, activation_axis_names)
12511256

12521257
noise_pred, noise_pred_audio = transformer_forward_pass(
12531258
graphdef,

src/maxdiffusion/schedulers/scheduling_flow_match_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarra
244244
diffs = jnp.abs(state.timesteps[None, :] - timestep[:, None])
245245
return jnp.argmin(diffs, axis=1)
246246

247+
# Arguments at indices 0 (self), 5 (to_final), and 6 (return_dict) are kept static for JIT compilation.
247248
@partial(jax.jit, static_argnums=(0, 5, 6))
248249
def step(
249250
self,

0 commit comments

Comments (0)