diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index d265459e69..0fea057876 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -32,7 +32,7 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding +from jax.sharding import Mesh from maxtext.common.common_types import ( Array, AttentionType, @@ -78,7 +78,7 @@ from maxtext.layers.initializers import variable_to_logically_partitioned from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils -from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name +from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec import numpy as np from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask @@ -1484,26 +1484,19 @@ def kernel_fn(q, k, v, d, s): return attention_output, None - def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): - # decoder_segment_ids can be None - if pspec is None: - return None - sharding = NamedSharding(self.mesh, pspec) - return maybe_shard_with_name( - inputs, - sharding, - shard_mode=self.config.shard_mode, - debug_sharding=self.config.debug_sharding, - extra_stack_level=1, - ) - - query = _maybe_shard_with_pspec(query, axis_names_q) - key = _maybe_shard_with_pspec(key, axis_names_kv) - value = _maybe_shard_with_pspec(value, axis_names_kv) - decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q) - decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv) - sinks = _maybe_shard_with_pspec(sinks, sink_axis_names) - indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names) + query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding) + key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding) + value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding) + decoder_segment_ids_q = maybe_shard_with_pspec( + decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding + ) + decoder_segment_ids_kv = maybe_shard_with_pspec( + decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding + ) + sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding) + indexer_mask = maybe_shard_with_pspec( + indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding + ) ret = wrap_flash_attention( query, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 314c450b03..5669ac5fae 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -36,7 +36,7 @@ from maxtext.kernels import megablox as mblx from maxtext.utils import max_logging from maxtext.utils import max_utils -from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding +from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_pspec from maxtext.utils.sharding import logical_to_mesh_axes import numpy as np import qwix.pallas as qpl @@ -1439,6 +1439,16 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes) + w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec) + w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec) + wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec) + if w0_bias is not None: + w0_bias = maybe_shard_with_pspec(w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec) + if w1_bias is not None: + w1_bias = maybe_shard_with_pspec(w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec) + if wo_bias is not None: + wo_bias = maybe_shard_with_pspec(wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec) + return wrapper( inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs ) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 74b22548b0..5b8468c749 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -115,6 +115,19 @@ def maybe_shard_with_name( return jax.lax.with_sharding_constraint(inputs, named_sharding) +def maybe_shard_with_pspec(inputs, mesh, shard_mode, pspec: jax.sharding.PartitionSpec | None, debug_sharding=False): + if pspec is None: + return None + sharding = NamedSharding(mesh, pspec) + return maybe_shard_with_name( + inputs, + sharding, + shard_mode=shard_mode, + debug_sharding=debug_sharding, + extra_stack_level=1, + ) + + def maybe_shard_with_logical( inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" ): diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 3ed0b40c14..8115c2a40b 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -14,9 +14,8 @@ """ Smoke test """ import os -import unittest -from absl.testing import absltest +from absl.testing import absltest, parameterized from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory from maxtext.common.gcloud_stub import is_decoupled @@ -24,7 +23,7 @@ from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT -class Train(unittest.TestCase): +class Train(parameterized.TestCase): """Smoke test G3 only""" def setUp(self): @@ -46,7 +45,7 @@ def test_tiny_config(self): # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", - r"dataset_path={self.dataset_path}", + f"dataset_path={self.dataset_path}", "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", @@ -74,7 +73,7 @@ def test_tiny_config_no_scan(self): # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", - r"dataset_path={self.dataset_path}", + f"dataset_path={self.dataset_path}", "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", @@ -103,7 +102,7 @@ def test_tiny_config_explicit_shardmode(self): # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", - r"dataset_path={self.dataset_path}", + f"dataset_path={self.dataset_path}", "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", @@ -123,6 +122,93 @@ def test_tiny_config_explicit_shardmode(self): ] ) + def test_tiny_config_explicit_shardmode_deepseek(self): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + train_main( + [ + None, + get_test_config_path(), + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test_deepseek", + f"dataset_path={self.dataset_path}", + "model_name=deepseek3-test", + "base_emb_dim=32", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_decoder_layers=2", + "first_num_dense_layers=1", + "head_dim=32", + "v_head_dim=32", + "qk_nope_head_dim=32", + "qk_rope_head_dim=16", + "q_lora_rank=16", + "kv_lora_rank=16", + "per_device_batch_size=1", + "max_target_length=64", + "dataset_type=synthetic", + "steps=2", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "sparse_matmul=True", + "capacity_factor=-1", + "shard_mode=explicit", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + "abort_on_nan_loss=False", + "abort_on_inf_loss=False", + ] + ) + + @parameterized.named_parameters( + ("fsdp_expert", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=True"]), + ("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]), + ("fsdp", ["ici_fsdp_parallelism=-1"]), + ) + def test_parallelism_configs(self, parallelism_args): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + + base_args = [ + None, + get_test_config_path(), + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test_parallelism", + f"dataset_path={self.dataset_path}", + "model_name=deepseek3-test", + "base_emb_dim=32", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_decoder_layers=2", + "first_num_dense_layers=1", + "head_dim=32", + "v_head_dim=32", + "qk_nope_head_dim=32", + "qk_rope_head_dim=16", + "q_lora_rank=16", + "kv_lora_rank=16", + "per_device_batch_size=1", + "max_target_length=64", + "dataset_type=synthetic", + "steps=2", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "capacity_factor=-1", + "shard_mode=explicit", + "enable_goodput_recording=False", + "monitor_goodput=False", + "abort_on_nan_loss=False", + "abort_on_inf_loss=False", + ] + + full_args = base_args + parallelism_args + + train_main(full_args) + if __name__ == "__main__": absltest.main()