From 287dce8f53b701e9016914527352dd091f1f2e13 Mon Sep 17 00:00:00 2001 From: Shuwen Fang Date: Tue, 7 Apr 2026 22:48:50 +0000 Subject: [PATCH 1/7] update --- src/maxtext/layers/moe.py | 21 ++++- tests/integration/smoke/train_smoke_test.py | 88 ++++++++++++++++++++- 2 files changed, 105 insertions(+), 4 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 314c450b03..16896c7c88 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1439,6 +1439,16 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes) + w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", None, None)) + w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", None, None)) + wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", None, None)) + if w0_bias is not None: + w0_bias = self._maybe_shard_with_logical(w0_bias, ("exp", None)) + if w1_bias is not None: + w1_bias = self._maybe_shard_with_logical(w1_bias, ("exp", None)) + if wo_bias is not None: + wo_bias = self._maybe_shard_with_logical(wo_bias, ("exp", None)) + return wrapper( inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs ) @@ -1712,7 +1722,16 @@ def dense_matmul( # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits = self._maybe_shard_with_logical( pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) - ) + ) + w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", None, None)) + w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", None, None)) + wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", None, None)) + if w0_bias is not None: + w0_bias = self._maybe_shard_with_logical(w0_bias, ("exp", None)) + if w1_bias is not None: + w1_bias = self._maybe_shard_with_logical(w1_bias, ("exp", None)) + if wo_bias is not None: + wo_bias = self._maybe_shard_with_logical(wo_bias, ("exp", None)) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 3ed0b40c14..12d0fb0414 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -46,7 +46,7 @@ def test_tiny_config(self): # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", - r"dataset_path={self.dataset_path}", + f"dataset_path={self.dataset_path}", "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", @@ -74,7 +74,7 @@ def test_tiny_config_no_scan(self): # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", - r"dataset_path={self.dataset_path}", + f"dataset_path={self.dataset_path}", "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", @@ -103,7 +103,7 @@ def test_tiny_config_explicit_shardmode(self): # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", - r"dataset_path={self.dataset_path}", + f"dataset_path={self.dataset_path}", "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", @@ -124,5 +124,87 @@ def test_tiny_config_explicit_shardmode(self): ) + def test_tiny_config_explicit_shardmode_deepseek(self): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + # Tests the Dense Matmul codepath + train_main( + [ + None, + get_test_config_path(), + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test_deepseek", + f"dataset_path={self.dataset_path}", + "model_name=deepseek3-test", + "base_emb_dim=32", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_decoder_layers=2", + "first_num_dense_layers=1", + "head_dim=32", + "v_head_dim=32", + "qk_nope_head_dim=32", + "qk_rope_head_dim=16", + "q_lora_rank=16", + "kv_lora_rank=16", + "per_device_batch_size=1", + "max_target_length=64", + "dataset_type=synthetic", + "steps=2", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "sparse_matmul=False", + "capacity_factor=-1", + "shard_mode=explicit", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + "abort_on_nan_loss=False", + "abort_on_inf_loss=False", + ] + ) + # Tests the Sparse Matmul codepath + train_main( + [ + None, + get_test_config_path(), + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test_deepseek", + f"dataset_path={self.dataset_path}", + "model_name=deepseek3-test", + "base_emb_dim=32", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_decoder_layers=2", + "first_num_dense_layers=1", + "head_dim=32", + "v_head_dim=32", + "qk_nope_head_dim=32", + "qk_rope_head_dim=16", + "q_lora_rank=16", + "kv_lora_rank=16", + "per_device_batch_size=1", + "max_target_length=64", + "dataset_type=synthetic", + "steps=2", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "sparse_matmul=True", + "capacity_factor=-1", + "shard_mode=explicit", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + "abort_on_nan_loss=False", + "abort_on_inf_loss=False", + ] + ) + + if __name__ == "__main__": absltest.main() From c491bf3a28230db02a88c2b9a70183d316bfcc9a Mon Sep 17 00:00:00 2001 From: Shuwen Fang Date: Wed, 8 Apr 2026 03:19:12 +0000 Subject: [PATCH 2/7] Update --- src/maxtext/layers/moe.py | 23 ++++++--- .../trainers/pre_train/train_compile.py | 2 +- tests/integration/smoke/train_smoke_test.py | 51 ++++++++++++++++++- 3 files changed, 65 insertions(+), 11 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 16896c7c88..2d7bda3e57 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -36,7 +36,7 @@ from maxtext.kernels import megablox as mblx from maxtext.utils import max_logging from maxtext.utils import max_utils -from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding +from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_name from maxtext.utils.sharding import logical_to_mesh_axes import numpy as np import qwix.pallas as qpl @@ -1439,15 +1439,21 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes) - w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", None, None)) - w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", None, None)) - wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", None, None)) + w0_ns = jax.sharding.NamedSharding(self.mesh, w0_pspec) + w1_ns = jax.sharding.NamedSharding(self.mesh, w1_pspec) + wo_ns = jax.sharding.NamedSharding(self.mesh, wo_pspec) + w0_bias_ns = jax.sharding.NamedSharding(self.mesh, w0_bias_pspec) + w1_bias_ns = jax.sharding.NamedSharding(self.mesh, w1_bias_pspec) + wo_bias_ns = jax.sharding.NamedSharding(self.mesh, wo_bias_pspec) + w0_kernel = maybe_shard_with_name(w0_kernel, w0_ns, self.config.shard_mode) + w1_kernel = maybe_shard_with_name(w1_kernel, w1_ns, self.config.shard_mode) + wo_kernel = maybe_shard_with_name(wo_kernel, wo_ns, self.config.shard_mode) if w0_bias is not None: - w0_bias = self._maybe_shard_with_logical(w0_bias, ("exp", None)) + w0_bias = maybe_shard_with_name(w0_bias, w0_bias_ns, self.config.shard_mode) if w1_bias is not None: - w1_bias = self._maybe_shard_with_logical(w1_bias, ("exp", None)) + w1_bias = maybe_shard_with_name(w1_bias, w1_bias_ns, self.config.shard_mode) if wo_bias is not None: - wo_bias = self._maybe_shard_with_logical(wo_bias, ("exp", None)) + wo_bias = maybe_shard_with_name(wo_bias, wo_bias_ns, self.config.shard_mode) return wrapper( inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs @@ -2044,8 +2050,9 @@ def __call__( gate_dtype = jnp.float32 if cfg.float32_gate_logits else cfg.dtype routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype) gate_logits, pre_bias_logits = self.gate(routing_inputs) - + w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) + print("shuwen w0 kernel init:", jax.typeof(w0_kernel)) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) wo_kernel = jnp.asarray(self.wo[...], self.dtype) diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 78392a388a..efbea75b56 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -219,7 +219,7 @@ def is_oom(argv: Sequence[str]) -> bool: def main(argv: Sequence[str]) -> None: - jax.config.update("jax_default_prng_impl", "unsafe_rbg") + jax.config.update("jax_default_prng_impl", "threefry2x32") os.environ["LIBTPU_INIT_ARGS"] = ( os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 12d0fb0414..6c8e50bd3d 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -16,7 +16,7 @@ import os import unittest -from absl.testing import absltest +from absl.testing import absltest, parameterized from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory from maxtext.common.gcloud_stub import is_decoupled @@ -24,7 +24,7 @@ from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT -class Train(unittest.TestCase): +class Train(parameterized.TestCase): """Smoke test G3 only""" def setUp(self): @@ -205,6 +205,53 @@ def test_tiny_config_explicit_shardmode_deepseek(self): ] ) + @parameterized.named_parameters( + ("fsdp_expert", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=True"]), + ("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]), + ("fsdp_tensor", ["ici_fsdp_parallelism=-1", "ici_tensor_parallelism=2"]), + ("fsdp", ["ici_fsdp_parallelism=-1"]), + ("fsdp_transpose", ["ici_fsdp_parallelism=-1", "ici_tensor_transpose_parallelism=2"]), + ) + def test_parallelism_configs(self, parallelism_args): + base_args = [ + None, + get_test_config_path(), + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test_parallelism", + f"dataset_path={self.dataset_path}", + "model_name=deepseek3-test", + "base_emb_dim=32", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_decoder_layers=2", + "first_num_dense_layers=1", + "head_dim=32", + "v_head_dim=32", + "qk_nope_head_dim=32", + "qk_rope_head_dim=16", + "q_lora_rank=16", + "kv_lora_rank=16", + "per_device_batch_size=1", + "max_target_length=64", + "dataset_type=synthetic", + "steps=2", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "capacity_factor=-1", + "shard_mode=auto", + "enable_goodput_recording=False", + "monitor_goodput=False", + "abort_on_nan_loss=False", + "abort_on_inf_loss=False", + ] + + full_args = base_args + parallelism_args + + train_main(full_args) + + if __name__ == "__main__": absltest.main() From 18f2e98a57aca45f05442dbfc81d000b6406f9d8 Mon Sep 17 00:00:00 2001 From: Shuwen Fang Date: Wed, 8 Apr 2026 03:40:03 +0000 Subject: [PATCH 3/7] remove dense matmul changes --- src/maxtext/layers/moe.py | 9 --------- tests/integration/smoke/train_smoke_test.py | 1 + 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 2d7bda3e57..6a14524e5e 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1729,15 +1729,6 @@ def dense_matmul( pre_bias_logits = self._maybe_shard_with_logical( pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) ) - w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp", None, None)) - w1_kernel = self._maybe_shard_with_logical(w1_kernel, ("exp", None, None)) - wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp", None, None)) - if w0_bias is not None: - w0_bias = self._maybe_shard_with_logical(w0_bias, ("exp", None)) - if w1_bias is not None: - w1_bias = self._maybe_shard_with_logical(w1_bias, ("exp", None)) - if wo_bias is not None: - wo_bias = self._maybe_shard_with_logical(wo_bias, ("exp", None)) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 6c8e50bd3d..9b52fd0d9c 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -208,6 +208,7 @@ def test_tiny_config_explicit_shardmode_deepseek(self): @parameterized.named_parameters( ("fsdp_expert", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=True"]), ("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]), + ("fsdp_expert_dense_matmul", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "sparse_matmul=False"]), ("fsdp_tensor", ["ici_fsdp_parallelism=-1", "ici_tensor_parallelism=2"]), ("fsdp", ["ici_fsdp_parallelism=-1"]), ("fsdp_transpose", ["ici_fsdp_parallelism=-1", "ici_tensor_transpose_parallelism=2"]), From 0a8693cec0d41a1740b8904eb80485e72e273ed3 Mon Sep 17 00:00:00 2001 From: Shuwen Fang Date: Wed, 8 Apr 2026 16:34:54 +0000 Subject: [PATCH 4/7] update --- src/maxtext/trainers/pre_train/train_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index efbea75b56..fcc8d5ecd3 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -219,7 +219,7 @@ def is_oom(argv: Sequence[str]) -> bool: def main(argv: Sequence[str]) -> None: - jax.config.update("jax_default_prng_impl", "threefry2x32") + jax.config.update("jax_default_prng_impl", "unsafe_rbg") #threefry2x32 os.environ["LIBTPU_INIT_ARGS"] = ( os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) From 644a941368ef580fc7f94c686cd5a85a390beba7 Mon Sep 17 00:00:00 2001 From: Shuwen Fang Date: Wed, 8 Apr 2026 16:38:36 +0000 Subject: [PATCH 5/7] formatting --- src/maxtext/layers/moe.py | 4 ++-- src/maxtext/trainers/pre_train/train_compile.py | 2 +- tests/integration/smoke/train_smoke_test.py | 14 ++++++-------- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 6a14524e5e..19e2600650 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1728,7 +1728,7 @@ def dense_matmul( # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits = self._maybe_shard_with_logical( pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) - ) + ) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 if is_llama4_decoder_layer: @@ -2041,7 +2041,7 @@ def __call__( gate_dtype = jnp.float32 if cfg.float32_gate_logits else cfg.dtype routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype) gate_logits, pre_bias_logits = self.gate(routing_inputs) - + w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) print("shuwen w0 kernel init:", jax.typeof(w0_kernel)) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index fcc8d5ecd3..26cddd2816 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -219,7 +219,7 @@ def is_oom(argv: Sequence[str]) -> bool: def main(argv: Sequence[str]) -> None: - jax.config.update("jax_default_prng_impl", "unsafe_rbg") #threefry2x32 + jax.config.update("jax_default_prng_impl", "unsafe_rbg") # threefry2x32 os.environ["LIBTPU_INIT_ARGS"] = ( os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 9b52fd0d9c..130e58be1d 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -123,7 +123,6 @@ def test_tiny_config_explicit_shardmode(self): ] ) - def test_tiny_config_explicit_shardmode_deepseek(self): test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable # Tests the Dense Matmul codepath @@ -206,12 +205,12 @@ def test_tiny_config_explicit_shardmode_deepseek(self): ) @parameterized.named_parameters( - ("fsdp_expert", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=True"]), - ("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]), - ("fsdp_expert_dense_matmul", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "sparse_matmul=False"]), - ("fsdp_tensor", ["ici_fsdp_parallelism=-1", "ici_tensor_parallelism=2"]), - ("fsdp", ["ici_fsdp_parallelism=-1"]), - ("fsdp_transpose", ["ici_fsdp_parallelism=-1", "ici_tensor_transpose_parallelism=2"]), + ("fsdp_expert", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=True"]), + ("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]), + ("fsdp_expert_dense_matmul", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "sparse_matmul=False"]), + ("fsdp_tensor", ["ici_fsdp_parallelism=-1", "ici_tensor_parallelism=2"]), + ("fsdp", ["ici_fsdp_parallelism=-1"]), + ("fsdp_transpose", ["ici_fsdp_parallelism=-1", "ici_tensor_transpose_parallelism=2"]), ) def test_parallelism_configs(self, parallelism_args): base_args = [ @@ -253,6 +252,5 @@ def test_parallelism_configs(self, parallelism_args): train_main(full_args) - if __name__ == "__main__": absltest.main() From c49646ffcb49d75fd8efe31f1c653d65423f4243 Mon Sep 17 00:00:00 2001 From: Shuwen Fang Date: Wed, 8 Apr 2026 17:34:46 +0000 Subject: [PATCH 6/7] update --- src/maxtext/layers/attention_op.py | 35 ++++++-------- src/maxtext/layers/moe.py | 27 +++++------ .../trainers/pre_train/train_compile.py | 2 +- src/maxtext/utils/sharding.py | 13 +++++ tests/integration/smoke/train_smoke_test.py | 48 ++----------------- 5 files changed, 44 insertions(+), 81 deletions(-) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index d265459e69..8cca826d70 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -78,7 +78,7 @@ from maxtext.layers.initializers import variable_to_logically_partitioned from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils -from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name +from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec import numpy as np from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask @@ -1484,26 +1484,19 @@ def kernel_fn(q, k, v, d, s): return attention_output, None - def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): - # decoder_segment_ids can be None - if pspec is None: - return None - sharding = NamedSharding(self.mesh, pspec) - return maybe_shard_with_name( - inputs, - sharding, - shard_mode=self.config.shard_mode, - debug_sharding=self.config.debug_sharding, - extra_stack_level=1, - ) - - query = _maybe_shard_with_pspec(query, axis_names_q) - key = _maybe_shard_with_pspec(key, axis_names_kv) - value = _maybe_shard_with_pspec(value, axis_names_kv) - decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q) - decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv) - sinks = _maybe_shard_with_pspec(sinks, sink_axis_names) - indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names) + query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding) + key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding) + value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding) + decoder_segment_ids_q = maybe_shard_with_pspec( + decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding + ) + decoder_segment_ids_kv = maybe_shard_with_pspec( + decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding + ) + sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding) + indexer_mask = maybe_shard_with_pspec( + indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding + ) ret = wrap_flash_attention( query, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 19e2600650..65a17358a1 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -36,7 +36,7 @@ from maxtext.kernels import megablox as mblx from maxtext.utils import max_logging from maxtext.utils import max_utils -from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_name +from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_pspec from maxtext.utils.sharding import logical_to_mesh_axes import numpy as np import qwix.pallas as qpl @@ -1439,21 +1439,21 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes) - w0_ns = jax.sharding.NamedSharding(self.mesh, w0_pspec) - w1_ns = jax.sharding.NamedSharding(self.mesh, w1_pspec) - wo_ns = jax.sharding.NamedSharding(self.mesh, wo_pspec) - w0_bias_ns = jax.sharding.NamedSharding(self.mesh, w0_bias_pspec) - w1_bias_ns = jax.sharding.NamedSharding(self.mesh, w1_bias_pspec) - wo_bias_ns = jax.sharding.NamedSharding(self.mesh, wo_bias_pspec) - w0_kernel = maybe_shard_with_name(w0_kernel, w0_ns, self.config.shard_mode) - w1_kernel = maybe_shard_with_name(w1_kernel, w1_ns, self.config.shard_mode) - wo_kernel = maybe_shard_with_name(wo_kernel, wo_ns, self.config.shard_mode) + w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec, self.config.debug_sharding) + w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec, self.config.debug_sharding) + wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec, self.config.debug_sharding) if w0_bias is not None: - w0_bias = maybe_shard_with_name(w0_bias, w0_bias_ns, self.config.shard_mode) + w0_bias = maybe_shard_with_pspec( + w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec, self.config.debug_sharding + ) if w1_bias is not None: - w1_bias = maybe_shard_with_name(w1_bias, w1_bias_ns, self.config.shard_mode) + w1_bias = maybe_shard_with_pspec( + w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec, self.config.debug_sharding + ) if wo_bias is not None: - wo_bias = maybe_shard_with_name(wo_bias, wo_bias_ns, self.config.shard_mode) + wo_bias = maybe_shard_with_pspec( + wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec, self.config.debug_sharding + ) return wrapper( inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs @@ -2043,7 +2043,6 @@ def __call__( gate_logits, pre_bias_logits = self.gate(routing_inputs) w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) - print("shuwen w0 kernel init:", jax.typeof(w0_kernel)) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) wo_kernel = jnp.asarray(self.wo[...], self.dtype) diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 26cddd2816..78392a388a 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -219,7 +219,7 @@ def is_oom(argv: Sequence[str]) -> bool: def main(argv: Sequence[str]) -> None: - jax.config.update("jax_default_prng_impl", "unsafe_rbg") # threefry2x32 + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["LIBTPU_INIT_ARGS"] = ( os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 74b22548b0..5b8468c749 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -115,6 +115,19 @@ def maybe_shard_with_name( return jax.lax.with_sharding_constraint(inputs, named_sharding) +def maybe_shard_with_pspec(inputs, mesh, shard_mode, pspec: jax.sharding.PartitionSpec | None, debug_sharding=False): + if pspec is None: + return None + sharding = NamedSharding(mesh, pspec) + return maybe_shard_with_name( + inputs, + sharding, + shard_mode=shard_mode, + debug_sharding=debug_sharding, + extra_stack_level=1, + ) + + def maybe_shard_with_logical( inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" ): diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 130e58be1d..8115c2a40b 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -14,7 +14,6 @@ """ Smoke test """ import os -import unittest from absl.testing import absltest, parameterized from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory @@ -125,46 +124,6 @@ def test_tiny_config_explicit_shardmode(self): def test_tiny_config_explicit_shardmode_deepseek(self): test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable - # Tests the Dense Matmul codepath - train_main( - [ - None, - get_test_config_path(), - # pylint: disable=f-string-without-interpolation - f"base_output_directory={self.base_output_directory}", - "run_name=runner_test_deepseek", - f"dataset_path={self.dataset_path}", - "model_name=deepseek3-test", - "base_emb_dim=32", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=64", - "base_moe_mlp_dim=64", - "base_num_decoder_layers=2", - "first_num_dense_layers=1", - "head_dim=32", - "v_head_dim=32", - "qk_nope_head_dim=32", - "qk_rope_head_dim=16", - "q_lora_rank=16", - "kv_lora_rank=16", - "per_device_batch_size=1", - "max_target_length=64", - "dataset_type=synthetic", - "steps=2", - "enable_checkpointing=False", - rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", - "sparse_matmul=False", - "capacity_factor=-1", - "shard_mode=explicit", - "enable_goodput_recording=False", - "enable_checkpoint_cloud_logger=False", - "monitor_goodput=False", - "abort_on_nan_loss=False", - "abort_on_inf_loss=False", - ] - ) - # Tests the Sparse Matmul codepath train_main( [ None, @@ -207,12 +166,11 @@ def test_tiny_config_explicit_shardmode_deepseek(self): @parameterized.named_parameters( ("fsdp_expert", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=True"]), ("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]), - ("fsdp_expert_dense_matmul", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "sparse_matmul=False"]), - ("fsdp_tensor", ["ici_fsdp_parallelism=-1", "ici_tensor_parallelism=2"]), ("fsdp", ["ici_fsdp_parallelism=-1"]), - ("fsdp_transpose", ["ici_fsdp_parallelism=-1", "ici_tensor_transpose_parallelism=2"]), ) def test_parallelism_configs(self, parallelism_args): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + base_args = [ None, get_test_config_path(), @@ -240,7 +198,7 @@ def test_parallelism_configs(self, parallelism_args): "enable_checkpointing=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", "capacity_factor=-1", - "shard_mode=auto", + "shard_mode=explicit", "enable_goodput_recording=False", "monitor_goodput=False", "abort_on_nan_loss=False", From de6fa70268410a3dca6eeb0eb61fbb70bf5d75d5 Mon Sep 17 00:00:00 2001 From: Shuwen Fang Date: Wed, 8 Apr 2026 18:06:19 +0000 Subject: [PATCH 7/7] fixes --- src/maxtext/layers/attention_op.py | 2 +- src/maxtext/layers/moe.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index 8cca826d70..0fea057876 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -32,7 +32,7 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding +from jax.sharding import Mesh from maxtext.common.common_types import ( Array, AttentionType, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 65a17358a1..5669ac5fae 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1439,21 +1439,15 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes) - w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec, self.config.debug_sharding) - w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec, self.config.debug_sharding) - wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec, self.config.debug_sharding) + w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec) + w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec) + wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec) if w0_bias is not None: - w0_bias = maybe_shard_with_pspec( - w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec, self.config.debug_sharding - ) + w0_bias = maybe_shard_with_pspec(w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec) if w1_bias is not None: - w1_bias = maybe_shard_with_pspec( - w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec, self.config.debug_sharding - ) + w1_bias = maybe_shard_with_pspec(w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec) if wo_bias is not None: - wo_bias = maybe_shard_with_pspec( - wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec, self.config.debug_sharding - ) + wo_bias = maybe_shard_with_pspec(wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec) return wrapper( inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs