From 35efff13cb894195205775ccbff94161487ef003 Mon Sep 17 00:00:00 2001
From: NuojCheng
Date: Wed, 8 Apr 2026 18:01:02 +0000
Subject: [PATCH 1/3] remove no_exp from attentions

---
 src/maxtext/common/common_types.py            |  3 --
 src/maxtext/configs/base.yml                  |  6 +--
 .../pipeline-large-moe.yml                    |  3 +-
 .../custom_mesh_and_rule/pure-fsdp.yml        |  1 -
 src/maxtext/configs/inference/vllm.yml        |  6 +--
 src/maxtext/layers/attention_op.py            | 10 ++---
 src/maxtext/layers/attentions.py              | 39 +++++--------------
 7 files changed, 20 insertions(+), 48 deletions(-)

diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py
index 4f5b825c00..8ab7182779 100644
--- a/src/maxtext/common/common_types.py
+++ b/src/maxtext/common/common_types.py
@@ -34,12 +34,10 @@
 BATCH = "activation_batch"
 ATTN_LENGTH = "activation_attn_length"
-ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"
 LENGTH = "activation_length"
 PREFILL_LENGTH = "prefill_activation_length"
 Q_LENGTH = "activation_q_length"
-Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
 Q_LORA_UP_PROJ = "q_lora_up_proj"
 KV_LENGTH = "activation_kv_length"
 KV_LORA_UP_PROJ = "kv_lora_up_proj"
@@ -48,7 +46,6 @@
 HEAD = "activation_heads"
 PREFILL_KV_BATCH = "activation_prefill_kv_batch"
 KV_BATCH = "activation_kv_batch"
-KV_BATCH_NO_EXP = "activation_kv_batch_no_exp"
 KV_HEAD = "activation_kv_heads"
 KV_HEAD_DIM = "activation_kv_head_dim"
 D_KV = "activation_kv"
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 0b335608c4..a135071d83 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -468,8 +468,7 @@ logical_axis_rules: [
   ['activation_length_moe', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
-  ['activation_q_length', ['context', 'expert']],
-  ['activation_q_length_no_exp', ['context']],
+  ['activation_q_length', ['context']],
   ['prefill_activation_length', ['sequence', 'context']],
   ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
@@ -479,8 +478,7 @@ logical_axis_rules: [
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_vocab', ['tensor', 'tensor_transpose']],
diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
index 57f181bb85..7da8e4b4a9 100644
--- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
+++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
@@ -46,8 +46,7 @@ logical_axis_rules: [
   ['activation_embed_moe', ['tensor']],
   ['activation_mlp', ['tensor']],
   ['activation_kv', ['tensor']],
-  ['activation_kv_batch', ['data', 'fsdp', 'expert']],
-  ['activation_kv_batch_no_exp', ['data', 'fsdp']],
+  ['activation_kv_batch', ['data', 'fsdp']],
   ['activation_kv_head_dim', ['tensor']],
   ['activation_vocab', ['tensor']],
   ['activation_stage', 'stage'],
diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
index caf6a70e7f..5d6939bad6 100644
--- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
+++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
@@ -24,7 +24,6 @@ logical_axis_rules: [
   ['activation_embed_and_logits_batch_sequence', ['fsdp']],
   ['activation_prefill_kv_batch', ['fsdp']],
   ['activation_kv_batch', ['fsdp']],
-  ['activation_kv_batch_no_exp', ['fsdp']],
   ['decode_batch', ['fsdp']],
   ['embed', ['fsdp']],
   ['embed_no_exp', ['fsdp']],
diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml
index 4e9315dbb4..3664e5ce1a 100644
--- a/src/maxtext/configs/inference/vllm.yml
+++ b/src/maxtext/configs/inference/vllm.yml
@@ -36,8 +36,7 @@ logical_axis_rules: [
   ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
   ['activation_heads', ['model', 'expert']],
   ['activation_kv_heads', ['model', 'expert']],
-  ['activation_attn_length', ['expert']],
-  ['activation_attn_length_no_exp', []],
+  ['activation_attn_length', []],
   ['activation_length', ['data']],
   ['activation_length_moe', ['data', 'expert']],
   ['activation_length_moe', 'data'],
@@ -48,8 +47,7 @@ logical_axis_rules: [
   ['activation_mlp', ['model', 'attn_dp']],
   ['activation_kv', ['model']],
   ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
-  ['activation_kv_batch', ['data', 'expert', 'attn_dp_expert']],
-  ['activation_kv_batch_no_exp', ['data']],
+  ['activation_kv_batch', ['data']],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model', 'attn_dp']],
   ['activation_norm_length', []],
diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py
index 8a795ad5da..3bcd07e3e1 100644
--- a/src/maxtext/layers/attention_op.py
+++ b/src/maxtext/layers/attention_op.py
@@ -63,7 +63,7 @@
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
     PREFILL_LENGTH,
-    Q_LENGTH_NO_EXP,
+    Q_LENGTH,
 )
 from maxtext.inference import page_manager
 from maxtext.inference.kvcache import KVQuant, KVTensor
@@ -1134,13 +1134,13 @@ def tpu_flash_attention(
     segment_axis_names_kv = None
     sink_axis_names = self._logical_to_mesh_axes((HEAD,))
     if decoder_segment_ids is not None:
-      segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP))
+      segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH))
       segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH))
     axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
     axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
     axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
-    indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP, KV_LENGTH))
+    indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))

     global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
     global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
@@ -1269,11 +1269,11 @@ def wrap_splash_kernel(single_head_mask):
         return splash_kernel

       splash_kernel = wrap_splash_kernel(single_head_mask)
-      segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
+      segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
     elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
       if self.config.use_max_logit_estimate > 0:
         sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
-      segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
+      segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))
     else:
       # Create multi-head mask
       multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py
index 50472b0646..76fa6291f0 100644
--- a/src/maxtext/layers/attentions.py
+++ b/src/maxtext/layers/attentions.py
@@ -34,7 +34,6 @@
     AxisNames,
     AxisIdxes,
     ATTN_LENGTH,
-    ATTN_LENGTH_NO_EXP,
     DType,
     Config,
     Array,
@@ -44,12 +43,10 @@
     KV_HEAD,
     KV_HEAD_DIM,
     KV_BATCH,
-    KV_BATCH_NO_EXP,
     ATTN_EMBED,
     MODEL_MODE_AUTOREGRESSIVE,
     MODEL_MODE_TRAIN,
     MODEL_MODE_PREFILL,
-    EP_AS_CONTEXT,
     AttentionType,
 )
 from maxtext.layers import nnx_wrappers
@@ -141,14 +138,11 @@ def attention_as_linen(
     prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
     prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-    ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-    input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
-    out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
+    query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+    input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED),
+    out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV),
     prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
     decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
     prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -208,9 +202,6 @@ def attention_as_linen(
       query_axis_names=query_axis_names,
       key_axis_names=key_axis_names,
       value_axis_names=value_axis_names,
-      ep_query_axis_names=ep_query_axis_names,
-      ep_key_axis_names=ep_key_axis_names,
-      ep_value_axis_names=ep_value_axis_names,
       input_axis_names=input_axis_names,
       out_axis_names=out_axis_names,
       prefill_input_axis_names=prefill_input_axis_names,
@@ -304,14 +295,11 @@ def __init__(
       prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
       prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
       prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
-      query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
-      ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-      ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-      ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
-      input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED),
-      out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV),
+      query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+      key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+      value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM),
+      input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED),
+      out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV),
       prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED),
       decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED),
       prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
@@ -413,9 +401,6 @@ def __init__(
     self.query_axis_names = query_axis_names
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
-    self.ep_query_axis_names = ep_query_axis_names
-    self.ep_key_axis_names = ep_key_axis_names
-    self.ep_value_axis_names = ep_value_axis_names
     self.input_axis_names = input_axis_names
     self.out_axis_names = out_axis_names
     self.prefill_input_axis_names = prefill_input_axis_names
@@ -1161,10 +1146,6 @@ def __call__(
       query = self._maybe_shard_with_logical(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV))
       key = self._maybe_shard_with_logical(key, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
       value = self._maybe_shard_with_logical(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV))
-    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      query = self._maybe_shard_with_logical(query, self.ep_query_axis_names)
-      key = self._maybe_shard_with_logical(key, self.ep_key_axis_names)
-      value = self._maybe_shard_with_logical(value, self.ep_value_axis_names)
     else:
       query = self._maybe_shard_with_logical(query, self.query_axis_names)
       key = self._maybe_shard_with_logical(key, self.key_axis_names)

From a3b3ccef1cc2d9bf17d13e3dca1c00d33559b60d Mon Sep 17 00:00:00 2001
From: NuojCheng
Date: Wed, 8 Apr 2026 18:40:27 +0000
Subject: [PATCH 2/3] remove no_exp from moe

---
 src/maxtext/configs/base.yml                  |  20 +-
 .../pipeline-large-moe.yml                    |   8 +-
 .../custom_mesh_and_rule/pure-fsdp.yml        |   3 -
 src/maxtext/configs/inference/inference.yml   |   9 +-
 src/maxtext/configs/inference/vllm.yml        |   8 +-
 .../configs/models/deepseek3-671b-2dfsdp.yml  |   2 -
 .../models/deepseek3-671b-batchsplit.yml      |   2 -
 src/maxtext/configs/post_train/rl_mt_jt.yml   |   8 +-
 src/maxtext/layers/moe.py                     | 244 +++++++-----------
 9 files changed, 115 insertions(+), 189 deletions(-)

diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index a135071d83..10a778cb1e 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -454,8 +454,7 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
@@ -476,6 +475,7 @@ logical_axis_rules: [
   ['activation_embed', ['tensor', 'tensor_transpose']],
   ['activation_embed_moe', ['tensor', 'tensor_transpose']],
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']],
@@ -498,18 +498,10 @@ logical_axis_rules: [
   ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['embed', ['fsdp', 'sequence', 'context', 'expert']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed_moe', ['fsdp', 'sequence', 'context', 'expert']],
-  ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_no_exp_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'context']],
   ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
index 7da8e4b4a9..4bf27f68ff 100644
--- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
+++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
@@ -30,8 +30,7 @@ mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
 data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'expert']],
-  ['activation_batch_moe', ['data', 'fsdp', 'expert']],
-  ['activation_batch_no_exp_moe', ['data', 'fsdp']],
+  ['activation_batch_moe', ['data', 'fsdp']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
   ['activation_heads', ['tensor']],
@@ -45,6 +44,7 @@ logical_axis_rules: [
   ['activation_embed', ['tensor']],
   ['activation_embed_moe', ['tensor']],
   ['activation_mlp', ['tensor']],
+  ['activation_mlp_moe', ['tensor']],
   ['activation_kv', ['tensor']],
   ['activation_kv_batch', ['data', 'fsdp']],
   ['activation_kv_head_dim', ['tensor']],
@@ -58,9 +58,7 @@ logical_axis_rules: [
   ['q_heads', ['tensor']],
   ['kv_heads', ['tensor']],
   ['embed', ['fsdp', 'expert']], # remove context from embed sharding
-  ['embed_moe', ['fsdp', 'expert']],
-  ['embed_no_exp', ['fsdp']],
-  ['embed_no_exp_moe', ['fsdp']],
+  ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['norm', ['tensor']],
diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
index 5d6939bad6..8b35fadff2 100644
--- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
+++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
@@ -19,16 +19,13 @@ data_sharding: [['fsdp']]
 logical_axis_rules: [
   ['activation_batch', ['fsdp']],
   ['activation_batch_moe', ['fsdp']],
-  ['activation_batch_no_exp_moe', ['fsdp']],
   ['activation_embed_and_logits_batch', ['fsdp']],
   ['activation_embed_and_logits_batch_sequence', ['fsdp']],
   ['activation_prefill_kv_batch', ['fsdp']],
   ['activation_kv_batch', ['fsdp']],
   ['decode_batch', ['fsdp']],
   ['embed', ['fsdp']],
-  ['embed_no_exp', ['fsdp']],
   ['embed_moe', ['fsdp']],
-  ['embed_no_exp_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['exp_with_fsdp', 'fsdp'],
diff --git a/src/maxtext/configs/inference/inference.yml b/src/maxtext/configs/inference/inference.yml
index 40bede0b50..91847c19e7 100644
--- a/src/maxtext/configs/inference/inference.yml
+++ b/src/maxtext/configs/inference/inference.yml
@@ -12,6 +12,7 @@ logical_axis_rules: [
   ['activation_norm_length', ['tensor_sequence', 'sequence']],
   ['activation_embed', ['tensor_transpose']],
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
@@ -33,10 +34,10 @@ logical_axis_rules: [
   ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
   ['embed', ['fsdp', 'sequence', 'expert']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
+  ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
   ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['layers', 'stage'],
   ['kv', []],
diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml
index 3664e5ce1a..2df79b24f3 100644
--- a/src/maxtext/configs/inference/vllm.yml
+++ b/src/maxtext/configs/inference/vllm.yml
@@ -31,7 +31,6 @@
 mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
 logical_axis_rules: [
   ['activation_batch', ['data']],
   ['activation_batch_moe', []],
-  ['activation_batch_no_exp_moe', []],
   ['activation_embed_and_logits_batch', ['data', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
@@ -45,6 +44,7 @@ logical_axis_rules: [
   ['activation_embed', ['model', 'attn_dp']],
   ['activation_embed_moe', ['model', 'attn_dp']],
   ['activation_mlp', ['model', 'attn_dp']],
+  ['activation_mlp_moe', ['model', 'attn_dp']],
   ['activation_kv', ['model']],
   ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
   ['activation_kv_batch', ['data']],
@@ -66,11 +66,9 @@ logical_axis_rules: [
   ['kv', []],
   ['embed', ['expert', 'attn_dp_expert']],
   ['embed', ['attn_dp_expert']],
-  ['embed_moe', ['expert', 'attn_dp_expert']],
-  ['embed_moe', ['attn_dp_expert']],
+  ['embed_moe', []],
+  ['embed_moe', []],
   ['embed_tensor_transpose', ['attn_dp', 'model']],
-  ['embed_no_exp', []],
-  ['embed_no_exp_moe', []],
   ['q_lora', ['expert', 'attn_dp_expert']],
   ['kv_lora', ['expert', 'attn_dp_expert']],
   ['norm', []],
diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
index 1face8d5a0..62b2b842d3 100644
--- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
+++ b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
@@ -70,8 +70,6 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['embed', ['fsdp']],
   ['embed_moe', ['fsdp']],
-  ['embed_no_exp', ['fsdp']],
-  ['embed_no_exp_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['layers', 'stage'],
diff --git a/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml
index cf8fa0fc5f..44486934ec 100644
--- a/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml
+++ b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml
@@ -70,8 +70,6 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['embed', ['fsdp']],
   ['embed_moe', ['fsdp']],
-  ['embed_no_exp', ['fsdp']],
-  ['embed_no_exp_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['layers', 'stage'],
diff --git a/src/maxtext/configs/post_train/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml
index c83addc01c..4383b1c4ac 100644
--- a/src/maxtext/configs/post_train/rl_mt_jt.yml
+++ b/src/maxtext/configs/post_train/rl_mt_jt.yml
@@ -49,10 +49,10 @@ logical_axis_rules: [
   ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
   ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
   ['embed', ['fsdp', 'sequence', 'expert']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
+  ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
   ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['layers', 'stage'],
   ['kv', []],
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 7cc2227722..5afddd1ac2 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -102,8 +102,8 @@ def _sort_activations_custom_bwd(residuals: jax.Array, grads: jax.Array) -> tupl
 def get_batchsplit_init_kernel_axes():
   return (
-      ("embed_no_exp", "fsdp_transpose_only", "expert_only"),
-      ("embed_no_exp", "fsdp_transpose_and_expert", None),
+      ("embed_moe", "fsdp_transpose_only", "expert_only"),
+      ("embed_moe", "fsdp_transpose_and_expert", None),
   )
@@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
     contract_ind = tuple(range(0, len(norm_axis)))

     output_sharding = (
-        create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_moe", None))
+        create_sharding(self.mesh, ("activation_batch", "activation_length", None))
         if self.shard_mode == ShardMode.EXPLICIT
         else None
     )
@@ -351,16 +351,16 @@ def __init__(
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
-      self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp")
-      self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None)
+      self.wi_kernel_axes = ("embed_moe", None, "mlp")
+      self.wo_kernel_axes = ("embed_moe", "mlp", None)
     elif self.config.use_2d_fsdp_sharding:
-      self.wi_kernel_axes = ("embed_no_exp_moe", "mlp", None)
-      self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None)
+      self.wi_kernel_axes = ("embed_moe", "mlp", None)
+      self.wo_kernel_axes = ("embed_moe", "mlp", None)
     elif self.config.use_batch_split_schedule:
       self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
     else:
-      self.wi_kernel_axes = ("exp", "embed_no_exp_moe", "mlp")
-      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp_moe")
+      self.wi_kernel_axes = ("exp", "embed_moe", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_moe")

     if self.config.attention == "vllm_rpa":
       # vLLM uses 'model' as the tensor parallelism axis name
@@ -437,7 +437,7 @@ def __init__(
     if self.config.mlp_bias:
       wi_bias_axes = ("exp", "activation_mlp")
-      wo_bias_axes = ("exp", "activation_embed_moe")
+      wo_bias_axes = ("exp", "activation_embed")
       wi_bias_shape = (self.num_experts, self.intermediate_dim)
       wo_bias_shape = (self.num_experts, self.config.emb_dim)
       self.wi_0_bias = nnx.Param(
@@ -1020,48 +1020,24 @@ def gmm(
       output = output[: hs_shape[0]]
       return output

-    # Currently, we support data, tensor, and expert parallelism with Megablox.
-    # We all gather the input activations over tensor parallelism to follow
-    # https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf.
-
-    # Check if the batch should be sharded by expert and whether the batch_size
-    # supports this. For example, for interleaved inference, prefill always has
-    # batch_size=1 while decode can have batch_size > 1.
-    try:
-      is_batch_sharded_by_expert = (
-          self._expert_parallelism_name
-          in tuple(
-              filter(
-                  lambda tup: tup[0] == "activation_batch_moe",
-                  self.config.logical_axis_rules,
-              )
-          )[
-              0
-          ][1]
-      )
-    except:  # pylint: disable=bare-except
-      is_batch_sharded_by_expert = False
-    if is_batch_sharded_by_expert and inputs.shape[0] > 1:
-      batch_logical_axis = "activation_batch_moe"
-    else:
-      batch_logical_axis = "activation_batch_no_exp_moe"
+    batch_logical_axis = "activation_batch"

     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
-          (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
+          (batch_logical_axis, "activation_norm_length", "activation_embed")
       )
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
     else:
-      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
+      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))

-    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
+    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
+      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
     else:
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits_pspec = None
@@ -1113,7 +1089,7 @@ def gmm(
             P(),  # Replicate the input key
         ),
         out_specs=(
-            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")),
+            self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
             P(),  # Handle None or replicate the output
             P(),  # Handle None or replicate the output
         ),
@@ -1159,58 +1135,48 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
     )

     if num_expert_parallelism > 1:
-      batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
+      batch_axis = "data"
       # get group sizes for all shards
       local_expert_size = self.config.num_experts // num_expert_parallelism
       reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
       global_group_sizes = group_sizes
-      if is_batch_sharded_by_expert:
-        all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
-        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-            all_shards_group_sizes,
-            expert_shard_id,
-            num_expert_parallelism,
-        )
-        # TODO(ranran): For better performance, we could update output buffer to a smaller
-        # size to replace self.get_expert_parallelism_size() for efficiency,
-        # Or we could apply capacity_factor for excessive experts.
-        # Note: Reducing buffer increase the risk of token dropping under unbalanced distribution.
-
-        # In the worst case, all of the global input data is assigned to each expert in the current shard.
-        # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
-        # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
-        max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
-        buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-        output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
-
-        x = jax.lax.ragged_all_to_all(
-            x,
-            output_shape,
-            input_offsets,
-            send_sizes,
-            output_offsets,
-            recv_sizes,
-            axis_name=self._expert_parallelism_name,
-        )
-        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
-        x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
-            x,
-            global_group_sizes,
-            local_expert_size,
-            shard_index=expert_shard_id,
-            use_custom_sort_vjp=self.config.use_custom_sort_vjp,
-        )
-      else:
-        x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
-            x,
-            global_group_sizes[None, :],
-            local_expert_size,
-            shard_index=expert_shard_id,
-            is_offset=True,
-            global_sorted_experts=selected_experts,
-            use_custom_sort_vjp=self.config.use_custom_sort_vjp,
-        )
+      all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
+      input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+          all_shards_group_sizes,
+          expert_shard_id,
+          num_expert_parallelism,
+      )
+
+      # TODO(ranran): For better performance, we could update output buffer to a smaller
+      # size to replace self.get_expert_parallelism_size() for efficiency,
+      # Or we could apply capacity_factor for excessive experts.
+      # Note: Reducing buffer increases the risk of token dropping under unbalanced distribution.
+
+      # In the worst case, all of the global input data is assigned to each expert in the current shard.
+      # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
+      # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
+      max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
+      buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
+      output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
+
+      x = jax.lax.ragged_all_to_all(
+          x,
+          output_shape,
+          input_offsets,
+          send_sizes,
+          output_offsets,
+          recv_sizes,
+          axis_name=self._expert_parallelism_name,
+      )
+      global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
+      x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
+          x,
+          global_group_sizes,
+          local_expert_size,
+          shard_index=expert_shard_id,
+          use_custom_sort_vjp=self.config.use_custom_sort_vjp,
+      )

     if self.config.mlp_bias:
       w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias)
@@ -1352,47 +1318,27 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
             ),
             dtype=intermediate_output.dtype,
         )
-        if is_batch_sharded_by_expert:
-          # locally unpermute back to the original order
-          local_output = _sort_activations(
-              intermediate_output,
-              jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
-              self.config.use_custom_sort_vjp,
-          )
-          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-              jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
-              expert_shard_id,
-              num_expert_parallelism,
-          )
-          intermediate_output = jax.lax.ragged_all_to_all(
-              local_output,
-              output_shape,
-              input_offsets,
-              send_sizes,
-              output_offsets,
-              recv_sizes,
-              axis_name=self._expert_parallelism_name,
-          )
-        else:
-          # If bach is replicated across EP shards then each shard should send
-          # 0..local_shard_size data to the other shards and receive the
-          # local_shard data from all of the other shards using
-          # ragged_all_to_all.
-          input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
-              reshaped_group_sizes,  # pylint: disable=undefined-variable
-              expert_shard_id,
-              num_expert_parallelism,
-              is_batch_sharded=False,
-          )
-          intermediate_output = jax.lax.ragged_all_to_all(
-              intermediate_output,
-              output_shape,
-              input_offsets,
-              send_sizes,
-              output_offsets,
-              recv_sizes,
-              axis_name=self._expert_parallelism_name,
-          )
+
+        # locally unpermute back to the original order
+        local_output = _sort_activations(
+            intermediate_output,
+            jnp.argsort(local_sorted_indices),  # pylint: disable=undefined-variable
+            self.config.use_custom_sort_vjp,
+        )
+        input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
+            jnp.transpose(all_shards_group_sizes),  # pylint: disable=undefined-variable
+            expert_shard_id,
+            num_expert_parallelism,
+        )
+        intermediate_output = jax.lax.ragged_all_to_all(
+            local_output,
+            output_shape,
+            input_offsets,
+            send_sizes,
+            output_offsets,
+            recv_sizes,
+            axis_name=self._expert_parallelism_name,
+        )

         output = self.unpermute(
             intermediate_output,
@@ -1425,13 +1371,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))

     if self.get_tensor_transpose_parallelism_size() > 1:
-      input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
+      input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
     else:
-      input_axes = (batch_logical_axis, "activation_norm_length_moe", None)
+      input_axes = (batch_logical_axis, "activation_norm_length", None)

-    gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
+    gate_logits_axes = (batch_logical_axis, "activation_norm_length", None)
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
+      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
     else:
       pre_bias_logits_axes = None
@@ -1449,14 +1395,12 @@ def reshape_and_update_weights(self, weights, indices):
     # output of updated weights: (batch_size, seq_len, num_experts)
     update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype)
     index_update = (
-        self._maybe_shard_with_logical(
-            jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None)
-        ),
-        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_moe", None)),
+        self._maybe_shard_with_logical(jnp.arange(weights.shape[0])[:, None, None], ("activation_batch", None, None)),
+        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length", None)),
         indices,
     )
     weight_sharding = (
-        create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_moe", None))
+        create_sharding(self.mesh, ("activation_batch", "activation_length", None))
         if self.config.shard_mode == ShardMode.EXPLICIT
         else None
     )
@@ -1511,7 +1455,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
         expert_mask,
         (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts),
     )
-    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None, None))
+    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None))

     expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2)
     expert_token_count = jnp.reshape(
         expert_token_count_fused,
@@ -1519,7 +1463,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch_moe", "activation_norm_length_moe", None, None, None),
+        ("activation_batch", "activation_norm_length", None, None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1607,7 +1551,7 @@ def generate_masks(self, top_k_indices, softmax_probs):
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch_moe", "activation_norm_length_moe", None, None),
+        ("activation_batch", "activation_norm_length", None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
@@ -1752,13 +1696,13 @@ def dense_matmul(
       mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None)
       dispatch_axis = (
           "activation_exp",
-          "activation_batch_no_exp_moe",
+          "activation_batch_moe",
           None,
           "activation_embed_moe",
       )
       mlp_axis = (
           "activation_exp",
-          "activation_batch_no_exp_moe",
+          "activation_batch_moe",
           None,
           "activation_mlp",
       )
@@ -1787,14 +1731,14 @@ def dense_matmul(
       )
       dispatch_axis = (
           "activation_exp",
-          "activation_batch_no_exp_moe",
+          "activation_batch_moe",
           None,
           None,
           "activation_embed_moe",
       )
       mlp_axis = (
           "activation_exp",
-          "activation_batch_no_exp_moe",
+          "activation_batch_moe",
           None,
           None,
           "activation_mlp",
@@ -1815,14 +1759,14 @@ def dense_matmul(
       )
       dispatch_axis = (
           "activation_exp",
-          "activation_batch_no_exp_moe",
+          "activation_batch_moe",
           None,
           None,
           "activation_embed_moe",
       )
       mlp_axis = (
           "activation_exp",
-          "activation_batch_no_exp_moe",
+          "activation_batch_moe",
           None,
           None,
           "activation_mlp",
@@ -1848,7 +1792,7 @@ def dense_matmul(
           dispatch,
           (
               None,
-              "activation_batch_no_exp_moe",
+              "activation_batch_moe",
               "activation_norm_length_moe",
               None,
               "activation_embed_moe",
@@ -1911,7 +1855,7 @@ def dense_matmul(
             intermediate_layer,
             (
                 "activation_exp",
-                "activation_batch_no_exp_moe",
+                "activation_batch_moe",
                 None,
                 "activation_embed_moe",
             ),

From 77add5ced505f6659271ce3c7357e6b504368927 Mon Sep 17 00:00:00 2001
From: NuojCheng
Date: Wed, 8 Apr 2026 19:44:42 +0000
Subject: [PATCH 3/3] refactor rule order and add vocab embed

---
 src/maxtext/configs/base.yml                  | 94 +++++++++++--------
 .../pipeline-large-moe.yml                    |  1 +
 src/maxtext/configs/inference/inference.yml   |  2 +
 src/maxtext/configs/inference/vllm.yml        |  1 +
 src/maxtext/configs/post_train/rl_mt_jt.yml   |  1 +
 src/maxtext/layers/decoders.py                |  2 +-
 src/maxtext/layers/embeddings.py              |  2 +-
 src/maxtext/layers/moe.py                     |  3 +-
 src/maxtext/layers/nnx_decoders.py            |  2 +-
 9 files changed, 62 insertions(+), 46 deletions(-)

diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 10a778cb1e..5e20b5441e 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -454,55 +454,46 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
+  # Vocab activation
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_length', ['sequence', 'context']],
-  ['activation_length', ['context']],
-  ['activation_attn_length', ['sequence', 'context']],
-  ['activation_attn_length', ['context']],
+  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_vocab', ['tensor', 'tensor_transpose']],
+  ['activation_vocab', 'tensor_sequence'],
+  ['activation_vocab', ['sequence','context']],
+  # Vocab weight
+  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  # MoE activation
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_length_moe', ['sequence', 'context']],
   ['activation_length_moe', ['context']],
-  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
+  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_exp', ['expert']],
+  # MoE weight
+  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], # should be deprecated
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'sequence', 'context']],
+  ['embed_tensor_transpose', ['tensor_transpose']], # should be deprecated
+  ['exp_with_fsdp', 'fsdp'], # should be deprecated
+  # Attn activation
+  ['activation_attn_length', ['sequence', 'context']],
+  ['activation_attn_length', ['context']],
   ['activation_q_length', ['context']],
-  ['prefill_activation_length', ['sequence', 'context']],
-  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
-  ['activation_embed', ['tensor', 'tensor_transpose']],
-  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
-  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence','context']],
-  ['activation_stage', 'stage'],
-  ['activation_exp', ['expert']],
-  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['decode_length', ['sequence']],
-  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
-  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  # Attn weight
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
-  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
-  ['embed_moe', ['fsdp', 'sequence', 'context']],
-  ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
@@ -513,29 +504,50 @@ logical_axis_rules: [
   ['q_lora', ['fsdp', 'sequence', 'context', 'expert']],
   ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
   ["kv_lora_up_proj",[]],
+  # Other activation
+  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+  ['activation_length', ['sequence', 'context']],
+  ['activation_length', ['context']],
+  ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_embed', ['tensor', 'tensor_transpose']],
+  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_stage', 'stage'],
+  # Other weight
+  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'context', 'expert']],
   ['norm', ['tensor', 'tensor_transpose']],
   ['layers', 'stage'],
+  # Others (inference etc.)
+  ['prefill_activation_length', ['sequence', 'context']],
+  ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['decode_length', ['sequence']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
+  ['paged_kv_heads', ['tensor']],
+  ['diloco', 'diloco'],
+  ['engram_dim', ['tensor']],
+  # Should remove following names as they duplicate shardings
   ['qkv', []],
   ['kv', []],
   ['kv_head_dim', []],
   ['cache_batch_prefill', []],
   ['cache_batch', []],
   ['cache_heads_none', []],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
   ['cache_kv', []],
   ['cache_sequence', []],
   ['exp', 'expert'],
-  ['exp_with_fsdp', 'fsdp'],
-  ['paged_kv_heads', ['tensor']],
   ['num_pages', []],
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
   ['dense_layers', []],
   ['moe_layers', []],
-  ['engram_dim', ['tensor']],
   ['mhc', []],
-  ['diloco', 'diloco'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
index 4bf27f68ff..c83b212412 100644
--- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
+++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
@@ -52,6 +52,7 @@ logical_axis_rules: [
   ['activation_stage', 'stage'],
   ['activation_exp', ['expert']],
   ['mlp', ['tensor']],
+  ['mlp_moe', ['tensor']],
   ['mlp_no_fsdp', ['tensor']],
   ['vocab', ['tensor']],
   ['heads', ['tensor']],
diff --git a/src/maxtext/configs/inference/inference.yml b/src/maxtext/configs/inference/inference.yml
index 91847c19e7..fa972a0343 100644
--- a/src/maxtext/configs/inference/inference.yml
+++ b/src/maxtext/configs/inference/inference.yml
@@ -26,7 +26,9 @@ logical_axis_rules: [
   ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
   ['decode_length', []],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml
index 2df79b24f3..fd818f773e 100644
--- a/src/maxtext/configs/inference/vllm.yml
+++ b/src/maxtext/configs/inference/vllm.yml
@@ -56,6 +56,7 @@ logical_axis_rules: [
   ['decode_batch', ['expert', 'attn_dp_expert']],
   ['decode_length', []],
   ['mlp', ['model', 'attn_dp']],
+  ['mlp_moe', ['model', 'attn_dp']],
   ['mlp_no_fsdp', ['model', 'attn_dp']],
   ['moe_mlp', ['model', 'attn_dp']],
   ['vocab', ['model', 'attn_dp']],
diff --git a/src/maxtext/configs/post_train/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml
index 4383b1c4ac..715f02962d 100644
--- a/src/maxtext/configs/post_train/rl_mt_jt.yml
+++ b/src/maxtext/configs/post_train/rl_mt_jt.yml
@@ -42,6 +42,7 @@ logical_axis_rules: [
   ['decode_length', []],
   ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
   ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
+  ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
   ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py
index 7932961182..a2870a0c1f 100644
--- a/src/maxtext/layers/decoders.py
+++ b/src/maxtext/layers/decoders.py
@@ -735,7 +735,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
         out_features_shape=cfg.vocab_size,
         weight_dtype=cfg.weight_dtype,
         dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
-        kernel_axes=("embed", "vocab"),
+        kernel_axes=("embed_vocab", "vocab"),
         shard_mode=cfg.shard_mode,
         name="logits_dense",
         matmul_precision=self.config.matmul_precision,
diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py
index 466be00e5f..c3f703699e 100644
--- a/src/maxtext/layers/embeddings.py
+++ b/src/maxtext/layers/embeddings.py
@@ -132,7 +132,7 @@ def __init__(
             (self.num_embeddings, self.num_features),
             self.config.weight_dtype,
         ),
-        sharding=("vocab", "embed"),
+        sharding=("vocab", "embed_vocab"),
     )

   def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 5afddd1ac2..56ca595c7a 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -1135,13 +1135,12 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
     )

     if num_expert_parallelism > 1:
-      batch_axis = "data"
       # get group sizes for all shards
       local_expert_size = self.config.num_experts // num_expert_parallelism
       reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
       global_group_sizes = group_sizes
-      all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis)
+      all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=self._expert_parallelism_name)
       input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params(
           all_shards_group_sizes,
           expert_shard_id,
diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py
index e35933c1bf..6619580499 100644
--- a/src/maxtext/layers/nnx_decoders.py
+++ b/src/maxtext/layers/nnx_decoders.py
@@ -287,7 +287,7 @@ def __init__(
         out_features_shape=config.vocab_size,
         weight_dtype=config.weight_dtype,
         dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype,
-        kernel_axes=("embed", "vocab"),
+        kernel_axes=("embed_vocab", "vocab"),
         shard_mode=config.shard_mode,
         matmul_precision=self.config.matmul_precision,
         parameter_memory_host_offload=config.parameter_memory_host_offload,