From aca86918dff899ef529dfaf1749ae5e33dbd2693 Mon Sep 17 00:00:00 2001
From: Sagar Chapara
Date: Wed, 8 Apr 2026 05:47:43 +0000
Subject: [PATCH 1/2] Fix WAN transformer partitioning for bias and kernel init

---
 src/maxdiffusion/models/attention_flax.py          | 18 ++++++++++--------
 .../models/wan/transformers/transformer_wan.py     |  6 +++---
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 453c1650..9530dd09 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -979,7 +979,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )
 
@@ -993,7 +993,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )
 
@@ -1007,7 +1007,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("embed",),
+            ("heads",),
         ),
     )
 
@@ -1021,7 +1021,7 @@ def __init__(
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
-            ("heads",),
+            ("embed",),
         ),
     )
 
@@ -1332,11 +1332,13 @@ def setup(self):
         precision=self.precision,
     )
 
+    proj_attn_kernel_axes = ("heads", "embed")
+
     self.proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="i_proj",
@@ -1345,9 +1347,9 @@ def setup(self):
 
     self.encoder_proj_attn = nn.Dense(
         self.query_dim,
-        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
+        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
         use_bias=True,
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         name="e_proj",
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 11d7cad2..5e800598 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -193,11 +193,11 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "mlp",
                 "embed",
+                "mlp",
             ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
     )
 
   def __call__(self, x: jax.Array) -> jax.Array:
@@ -249,8 +249,8 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
+            "mlp",
                 "embed",
-                "mlp",
             ),
         ),
     )

From 039ceb3e68a98c10c84e01eb08b38a1c0e5f490b Mon Sep 17 00:00:00 2001
From: Sagar Chapara
Date: Thu, 9 Apr 2026 11:59:20 +0000
Subject: [PATCH 2/2] fix lint issues

---
 src/maxdiffusion/models/wan/transformers/transformer_wan.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 5e800598..e64d4de2 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -249,7 +249,7 @@ def __init__(
         kernel_init=nnx.with_partitioning(
            nnx.initializers.xavier_uniform(),
             (
-            "mlp",
+                "mlp",
                 "embed",
             ),
         ),
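
Reviewer note, not part of the patch series: every change above follows one
convention. An nnx.Linear / nn.Dense kernel has shape (in_features,
out_features) and its logical axes must be listed in that order, while the
bias has shape (out_features,) and therefore shares the kernel's last axis.
That is why the q/k/v projections in attention_flax.py get bias axis
("heads",), the output projections get ("embed",), and the WAN feed-forward
up-projection gets kernel axes ("embed", "mlp") with bias axis ("mlp",).
Below is a minimal, self-contained Flax NNX sketch of that convention; the
layer sizes and mesh-free setup are illustrative, not maxdiffusion code:

    from flax import nnx

    # Up-projection: embed -> mlp. The kernel has shape
    # (in_features, out_features) = (512, 2048), so its logical axes are
    # listed in the same order: ("embed", "mlp").
    layer = nnx.Linear(
        in_features=512,
        out_features=2048,
        kernel_init=nnx.with_partitioning(
            nnx.initializers.xavier_uniform(),
            ("embed", "mlp"),
        ),
        # The bias has shape (out_features,) = (2048,), so it is partitioned
        # along the kernel's output axis. Pairing it with the input axis,
        # ("embed",), is the mismatch the first patch corrects.
        bias_init=nnx.with_partitioning(
            nnx.initializers.zeros,
            ("mlp",),
        ),
        rngs=nnx.Rngs(0),
    )

    # nnx.with_partitioning attaches the axis names as sharding metadata:
    assert layer.kernel.sharding == ("embed", "mlp")
    assert layer.bias.sharding == ("mlp",)

Under a device mesh, these logical names are resolved to mesh axes when the
params are sharded; a bias annotated with the wrong logical axis can then be
mapped to a mesh axis whose size does not divide the bias length, producing
sharding errors or an unintended layout rather than a clean failure at init.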