Skip to content

Commit aca8691

Browse files
committed
Fix WAN transformer partitioning for bias and kernel init
1 parent 6de9d57 commit aca8691

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

src/maxdiffusion/models/attention_flax.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -979,7 +979,7 @@ def __init__(
979979
precision=precision,
980980
bias_init=nnx.with_partitioning(
981981
nnx.initializers.zeros,
982-
("embed",),
982+
("heads",),
983983
),
984984
)
985985

@@ -993,7 +993,7 @@ def __init__(
993993
precision=precision,
994994
bias_init=nnx.with_partitioning(
995995
nnx.initializers.zeros,
996-
("embed",),
996+
("heads",),
997997
),
998998
)
999999

@@ -1007,7 +1007,7 @@ def __init__(
10071007
precision=precision,
10081008
bias_init=nnx.with_partitioning(
10091009
nnx.initializers.zeros,
1010-
("embed",),
1010+
("heads",),
10111011
),
10121012
)
10131013

@@ -1021,7 +1021,7 @@ def __init__(
10211021
precision=precision,
10221022
bias_init=nnx.with_partitioning(
10231023
nnx.initializers.zeros,
1024-
("heads",),
1024+
("embed",),
10251025
),
10261026
)
10271027

@@ -1332,11 +1332,13 @@ def setup(self):
13321332
precision=self.precision,
13331333
)
13341334

1335+
proj_attn_kernel_axes = ("heads", "embed")
1336+
13351337
self.proj_attn = nn.Dense(
13361338
self.query_dim,
1337-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
1339+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
13381340
use_bias=True,
1339-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
1341+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
13401342
dtype=self.dtype,
13411343
param_dtype=self.weights_dtype,
13421344
name="i_proj",
@@ -1345,9 +1347,9 @@ def setup(self):
13451347

13461348
self.encoder_proj_attn = nn.Dense(
13471349
self.query_dim,
1348-
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes),
1350+
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes),
13491351
use_bias=True,
1350-
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)),
1352+
bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
13511353
dtype=self.dtype,
13521354
param_dtype=self.weights_dtype,
13531355
name="e_proj",

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -193,11 +193,11 @@ def __init__(
193193
kernel_init=nnx.with_partitioning(
194194
nnx.initializers.xavier_uniform(),
195195
(
196-
"mlp",
197196
"embed",
197+
"mlp",
198198
),
199199
),
200-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
200+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
201201
)
202202

203203
def __call__(self, x: jax.Array) -> jax.Array:
@@ -249,8 +249,8 @@ def __init__(
249249
kernel_init=nnx.with_partitioning(
250250
nnx.initializers.xavier_uniform(),
251251
(
252+
"mlp",
252253
"embed",
253-
"mlp",
254254
),
255255
),
256256
)

0 commit comments

Comments (0)