@@ -979,7 +979,7 @@ def __init__(
979979 precision = precision ,
980980 bias_init = nnx .with_partitioning (
981981 nnx .initializers .zeros ,
982- ("embed " ,),
982+ ("heads " ,),
983983 ),
984984 )
985985
@@ -993,7 +993,7 @@ def __init__(
993993 precision = precision ,
994994 bias_init = nnx .with_partitioning (
995995 nnx .initializers .zeros ,
996- ("embed " ,),
996+ ("heads " ,),
997997 ),
998998 )
999999
@@ -1007,7 +1007,7 @@ def __init__(
10071007 precision = precision ,
10081008 bias_init = nnx .with_partitioning (
10091009 nnx .initializers .zeros ,
1010- ("embed " ,),
1010+ ("heads " ,),
10111011 ),
10121012 )
10131013
@@ -1021,7 +1021,7 @@ def __init__(
10211021 precision = precision ,
10221022 bias_init = nnx .with_partitioning (
10231023 nnx .initializers .zeros ,
1024- ("heads " ,),
1024+ ("embed " ,),
10251025 ),
10261026 )
10271027
@@ -1332,11 +1332,13 @@ def setup(self):
13321332 precision = self .precision ,
13331333 )
13341334
1335+ proj_attn_kernel_axes = ("heads" , "embed" )
1336+
13351337 self .proj_attn = nn .Dense (
13361338 self .query_dim ,
1337- kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), kernel_axes ),
1339+ kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), proj_attn_kernel_axes ),
13381340 use_bias = True ,
1339- bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("heads " ,)),
1341+ bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("embed " ,)),
13401342 dtype = self .dtype ,
13411343 param_dtype = self .weights_dtype ,
13421344 name = "i_proj" ,
@@ -1345,9 +1347,9 @@ def setup(self):
13451347
13461348 self .encoder_proj_attn = nn .Dense (
13471349 self .query_dim ,
1348- kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), kernel_axes ),
1350+ kernel_init = nn .with_logical_partitioning (nn .initializers .lecun_normal (), proj_attn_kernel_axes ),
13491351 use_bias = True ,
1350- bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("heads " ,)),
1352+ bias_init = nn .with_logical_partitioning (nn .initializers .zeros , ("embed " ,)),
13511353 dtype = self .dtype ,
13521354 param_dtype = self .weights_dtype ,
13531355 name = "e_proj" ,
0 commit comments