diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 12fb44a39d..8ab7182779 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -32,16 +32,12 @@ AxisIdxes = tuple[int, ...] BATCH = "activation_batch" -BATCH_NO_EXP = "activation_batch_no_exp" ATTN_LENGTH = "activation_attn_length" -ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp" LENGTH = "activation_length" -LENGTH_NO_EXP = "activation_length_no_exp" PREFILL_LENGTH = "prefill_activation_length" Q_LENGTH = "activation_q_length" -Q_LENGTH_NO_EXP = "activation_q_length_no_exp" Q_LORA_UP_PROJ = "q_lora_up_proj" KV_LENGTH = "activation_kv_length" KV_LORA_UP_PROJ = "kv_lora_up_proj" @@ -50,7 +46,6 @@ HEAD = "activation_heads" PREFILL_KV_BATCH = "activation_prefill_kv_batch" KV_BATCH = "activation_kv_batch" -KV_BATCH_NO_EXP = "activation_kv_batch_no_exp" KV_HEAD = "activation_kv_heads" KV_HEAD_DIM = "activation_kv_head_dim" D_KV = "activation_kv" diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 06063c5fcf..cdc5fbdb02 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -450,27 +450,20 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], - ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], - ['activation_length', ['sequence', 'context', 'expert']], - ['activation_length', ['context', 'expert']], - ['activation_attn_length', ['sequence', 'context', 'expert']], - ['activation_attn_length', ['context', 'expert']], - ['activation_attn_length_no_exp', ['sequence', 'context']], - ['activation_attn_length_no_exp', ['context']], - ['activation_length_no_exp', ['sequence', 'context']], - ['activation_length_no_exp', ['context']], - ['activation_length_no_exp_moe', ['sequence', 'context']], - ['activation_length_no_exp_moe', ['context']], + ['activation_length', ['sequence', 'context']], + ['activation_length', ['context']], + ['activation_attn_length', ['sequence', 'context']], + ['activation_attn_length', ['context']], + ['activation_length_moe', ['sequence', 'context']], + ['activation_length_moe', ['context']], ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']], ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']], - ['activation_q_length', ['context', 'expert']], - ['activation_q_length_no_exp', ['context']], + ['activation_q_length', ['context']], ['prefill_activation_length', ['sequence', 'context']], ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']], ['activation_kv_length', []], @@ -478,10 +471,10 @@ 
logical_axis_rules: [ ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_vocab', ['tensor', 'tensor_transpose']], @@ -501,18 +494,10 @@ logical_axis_rules: [ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['embed', ['fsdp', 'sequence', 'context', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_no_exp', ['fsdp', 'sequence', 'context']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], - ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['embed_moe', ['fsdp', 'sequence', 'context', 'expert']], - ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_no_exp_moe', ['fsdp', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'context']], ['embed_tensor_transpose', ['tensor_transpose']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index 8209dece2d..4bf27f68ff 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -30,9 +30,7 @@ mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert'] data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'expert']], - ['activation_batch_moe', ['data', 'fsdp', 'expert']], - ['activation_batch_no_exp', ['data', 'fsdp']], - ['activation_batch_no_exp_moe', ['data', 'fsdp']], + ['activation_batch_moe', ['data', 'fsdp']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']], ['activation_heads', ['tensor']], @@ -46,9 +44,9 @@ logical_axis_rules: [ ['activation_embed', ['tensor']], ['activation_embed_moe', 
['tensor']], ['activation_mlp', ['tensor']], + ['activation_mlp_moe', ['tensor']], ['activation_kv', ['tensor']], - ['activation_kv_batch', ['data', 'fsdp', 'expert']], - ['activation_kv_batch_no_exp', ['data', 'fsdp']], + ['activation_kv_batch', ['data', 'fsdp']], ['activation_kv_head_dim', ['tensor']], ['activation_vocab', ['tensor']], ['activation_stage', 'stage'], @@ -60,9 +58,7 @@ logical_axis_rules: [ ['q_heads', ['tensor']], ['kv_heads', ['tensor']], ['embed', ['fsdp', 'expert']], # remove context from embed sharding - ['embed_moe', ['fsdp', 'expert']], - ['embed_no_exp', ['fsdp']], - ['embed_no_exp_moe', ['fsdp']], + ['embed_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['norm', ['tensor']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml index c8a28c5b24..8b35fadff2 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml @@ -18,19 +18,14 @@ mesh_axes: ['fsdp'] data_sharding: [['fsdp']] logical_axis_rules: [ ['activation_batch', ['fsdp']], - ['activation_batch_no_exp', ['fsdp']], ['activation_batch_moe', ['fsdp']], - ['activation_batch_no_exp_moe', ['fsdp']], ['activation_embed_and_logits_batch', ['fsdp']], ['activation_embed_and_logits_batch_sequence', ['fsdp']], ['activation_prefill_kv_batch', ['fsdp']], ['activation_kv_batch', ['fsdp']], - ['activation_kv_batch_no_exp', ['fsdp']], ['decode_batch', ['fsdp']], ['embed', ['fsdp']], - ['embed_no_exp', ['fsdp']], ['embed_moe', ['fsdp']], - ['embed_no_exp_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['exp_with_fsdp', 'fsdp'], diff --git a/src/maxtext/configs/inference/inference.yml b/src/maxtext/configs/inference/inference.yml index 3b83094909..91847c19e7 100644 --- a/src/maxtext/configs/inference/inference.yml +++ b/src/maxtext/configs/inference/inference.yml @@ -2,7 +2,6 @@ base_config: "base.yml" logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], @@ -13,6 +12,7 @@ logical_axis_rules: [ ['activation_norm_length', ['tensor_sequence', 'sequence']], ['activation_embed', ['tensor_transpose']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']], @@ -34,10 +34,10 @@ logical_axis_rules: [ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 
'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']], ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['layers', 'stage'], ['kv', []], diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index cc8f40c7e1..2df79b24f3 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -31,27 +31,23 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert'] logical_axis_rules: [ ['activation_batch', ['data']], ['activation_batch_moe', []], - ['activation_batch_no_exp', []], - ['activation_batch_no_exp_moe', []], ['activation_embed_and_logits_batch', ['data', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'expert']], ['activation_heads', ['model', 'expert']], ['activation_kv_heads', ['model', 'expert']], - ['activation_attn_length', ['expert']], - ['activation_attn_length_no_exp', []], - ['activation_length', ['data', 'expert']], + ['activation_attn_length', []], + ['activation_length', ['data']], ['activation_length_moe', ['data', 'expert']], - ['activation_length_no_exp', 'data'], - ['activation_length_no_exp_moe', 'data'], + ['activation_length_moe', 'data'], ['activation_q_length', ['expert', 'attn_dp_expert']], ['activation_attn_embed', 'model'], ['activation_embed', ['model', 'attn_dp']], ['activation_embed_moe', ['model', 'attn_dp']], ['activation_mlp', ['model', 'attn_dp']], + ['activation_mlp_moe', ['model', 'attn_dp']], ['activation_kv', ['model']], ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']], - ['activation_kv_batch', ['data', 'expert', 'attn_dp_expert']], - ['activation_kv_batch_no_exp', ['data']], + ['activation_kv_batch', ['data']], ['activation_kv_head_dim', ['model']], ['activation_vocab', ['model', 'attn_dp']], ['activation_norm_length', []], @@ -70,11 +66,9 @@ logical_axis_rules: [ ['kv', []], ['embed', ['expert', 'attn_dp_expert']], ['embed', ['attn_dp_expert']], - ['embed_moe', ['expert', 'attn_dp_expert']], - ['embed_moe', ['attn_dp_expert']], + ['embed_moe', []], + ['embed_moe', []], ['embed_tensor_transpose', ['attn_dp', 'model']], - ['embed_no_exp', []], - ['embed_no_exp_moe', []], ['q_lora', ['expert', 'attn_dp_expert']], ['kv_lora', ['expert', 'attn_dp_expert']], ['norm', []], diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml index 1face8d5a0..62b2b842d3 100644 --- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml +++ b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml @@ -70,8 +70,6 @@ logical_axis_rules: [ ['activation_stage', 'stage'], ['embed', ['fsdp']], ['embed_moe', ['fsdp']], - ['embed_no_exp', ['fsdp']], - ['embed_no_exp_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['layers', 'stage'], diff --git a/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml index cf8fa0fc5f..44486934ec 100644 --- a/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml +++ b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml @@ -70,8 +70,6 @@ logical_axis_rules: [ ['activation_stage', 'stage'], ['embed', ['fsdp']], ['embed_moe', ['fsdp']], - ['embed_no_exp', ['fsdp']], - ['embed_no_exp_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['layers', 
'stage'], diff --git a/src/maxtext/configs/post_train/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml index 38108dc82b..4383b1c4ac 100644 --- a/src/maxtext/configs/post_train/rl_mt_jt.yml +++ b/src/maxtext/configs/post_train/rl_mt_jt.yml @@ -18,7 +18,6 @@ logical_axis_rules: [ ['prefill_activation_length', ['data']], ['prefill_activation_norm_length', ['data']], ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], @@ -50,10 +49,10 @@ logical_axis_rules: [ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']], ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['layers', 'stage'], ['kv', []], diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index 2f67279426..13bec20f6c 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -37,7 +37,6 @@ AxisIdxes, AxisNames, BATCH, - BATCH_NO_EXP, CACHE_BATCH, CACHE_BATCH_PREFILL, CACHE_SEQUENCE, @@ -53,12 +52,10 @@ HEAD, Q_LORA_UP_PROJ, KV_BATCH, - KV_BATCH_NO_EXP, KV_HEAD, KV_HEAD_DIM, KV_LORA_UP_PROJ, LENGTH, - LENGTH_NO_EXP, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, PREFILL_KV_BATCH, @@ -425,16 +422,11 @@ def mla_as_linen( prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED), - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV), + query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), + key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), + value_axis_names: AxisNames = (KV_BATCH, LENGTH, 
KV_HEAD, KV_HEAD_DIM), + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED), + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), @@ -499,13 +491,8 @@ def mla_as_linen( query_axis_names=query_axis_names, key_axis_names=key_axis_names, value_axis_names=value_axis_names, - ep_query_axis_names=ep_query_axis_names, - ep_key_axis_names=ep_key_axis_names, - ep_value_axis_names=ep_value_axis_names, input_axis_names=input_axis_names, - ep_input_axis_names=ep_input_axis_names, out_axis_names=out_axis_names, - ep_out_axis_names=ep_out_axis_names, prefill_input_axis_names=prefill_input_axis_names, decode_input_axis_names=decode_input_axis_names, prefill_out_axis_names=prefill_out_axis_names, @@ -573,16 +560,11 @@ def __init__( prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED), - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV), + query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), + key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), + value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED), + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), @@ -664,13 +646,8 @@ def __init__( query_axis_names=query_axis_names, key_axis_names=key_axis_names, value_axis_names=value_axis_names, - ep_query_axis_names=ep_query_axis_names, - ep_key_axis_names=ep_key_axis_names, - ep_value_axis_names=ep_value_axis_names, input_axis_names=input_axis_names, - ep_input_axis_names=ep_input_axis_names, out_axis_names=out_axis_names, - ep_out_axis_names=ep_out_axis_names, prefill_input_axis_names=prefill_input_axis_names, decode_input_axis_names=decode_input_axis_names, prefill_out_axis_names=prefill_out_axis_names, @@ -882,12 +859,9 @@ def mla_query_projection( if model_mode == MODEL_MODE_PREFILL: query_logical_name = self.prefill_query_axis_names wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ) - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - query_logical_name = self.ep_query_axis_names - wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, 
Q_LORA_UP_PROJ) else: query_logical_name = self.query_axis_names - wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ) + wqa_logical_name = (KV_BATCH, LENGTH, Q_LORA_UP_PROJ) query_sharding = create_sharding(self.mesh, query_logical_name) wqa_out_sharding = create_sharding(self.mesh, wqa_logical_name) # Set softmax scaling. @@ -1038,10 +1012,8 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm """MLA key/value projection with integrated rotary embedding.""" if model_mode == MODEL_MODE_PREFILL: wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ) - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ) else: - wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ) + wka_logical_name = (KV_BATCH, LENGTH, KV_LORA_UP_PROJ) wkva_out_sharding = create_sharding(self.mesh, wka_logical_name) low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) low_rank = checkpoint_name(low_rank, "kv_wa_proj") @@ -1178,14 +1150,10 @@ def __call__( inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names) inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names) out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV) - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names) - inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names) - out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV) else: inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names) inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) - out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) + out_logical_name = (BATCH, LENGTH, HEAD, D_KV) if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None: decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index d265459e69..3bcd07e3e1 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -39,7 +39,6 @@ AxisIdxes, AxisNames, BATCH, - BATCH_NO_EXP, CACHE_BATCH, CACHE_BATCH_PREFILL, CACHE_HEADS, @@ -56,18 +55,15 @@ DEFAULT_MASK_VALUE, DType, D_KV, - EP_AS_CONTEXT, EP_AS_FSDP, HEAD, KV_LENGTH, LENGTH, - LENGTH_NO_EXP, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, PREFILL_LENGTH, Q_LENGTH, - Q_LENGTH_NO_EXP, ) from maxtext.inference import page_manager from maxtext.inference.kvcache import KVQuant, KVTensor @@ -302,12 +298,9 @@ def attention_op_as_linen( float32_qk_product: bool = False, max_prefill_predict_length: int = -1, float32_logits: bool = False, - flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV), - flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV), + flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), - flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV), - flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP), - flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH), + flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH), prefill_cache_logical_axis_names: AxisNames = ( CACHE_BATCH_PREFILL, CACHE_SEQUENCE, @@ -364,11 +357,8 @@ def attention_op_as_linen( 
max_prefill_predict_length=max_prefill_predict_length, float32_logits=float32_logits, flash_axis_names_q=flash_axis_names_q, - flash_axis_names_q_ep=flash_axis_names_q_ep, flash_axis_names_kv=flash_axis_names_kv, - flash_axis_names_kv_ep=flash_axis_names_kv_ep, flash_axis_names_splash_kernel=flash_axis_names_splash_kernel, - flash_axis_names_splash_kernel_ep=flash_axis_names_splash_kernel_ep, prefill_cache_logical_axis_names=prefill_cache_logical_axis_names, cache_logical_axis_names=cache_logical_axis_names, cache_scale_logical_axis_names=cache_scale_logical_axis_names, @@ -405,12 +395,9 @@ def __init__( float32_qk_product: bool = False, max_prefill_predict_length: int = -1, float32_logits: bool = False, - flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV), - flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV), + flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), - flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV), - flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP), - flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH), + flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH), prefill_cache_logical_axis_names: AxisNames = ( CACHE_BATCH_PREFILL, CACHE_SEQUENCE, @@ -492,11 +479,8 @@ def __init__( self.max_prefill_predict_length = max_prefill_predict_length self.float32_logits = float32_logits self.flash_axis_names_q = flash_axis_names_q - self.flash_axis_names_q_ep = flash_axis_names_q_ep self.flash_axis_names_kv = flash_axis_names_kv - self.flash_axis_names_kv_ep = flash_axis_names_kv_ep self.flash_axis_names_splash_kernel = flash_axis_names_splash_kernel - self.flash_axis_names_splash_kernel_ep = flash_axis_names_splash_kernel_ep self.prefill_cache_logical_axis_names = prefill_cache_logical_axis_names self.cache_logical_axis_names = cache_logical_axis_names self.cache_scale_logical_axis_names = cache_scale_logical_axis_names @@ -1150,23 +1134,13 @@ def tpu_flash_attention( segment_axis_names_kv = None sink_axis_names = self._logical_to_mesh_axes((HEAD,)) if decoder_segment_ids is not None: - if self.config.expert_shard_attention_option == EP_AS_CONTEXT: - segment_axis_names_q = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH)) - segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH)) - else: - segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP)) - segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH)) - - if self.config.expert_shard_attention_option == EP_AS_CONTEXT: - axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep) - axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep) - axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep) - indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH)) - else: - axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) - axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q) - axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) - indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) + segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH)) + segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH)) + + axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) + axis_names_q = 
self._logical_to_mesh_axes(self.flash_axis_names_q) + axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) + indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel @@ -1295,14 +1269,11 @@ def wrap_splash_kernel(single_head_mask): return splash_kernel splash_kernel = wrap_splash_kernel(single_head_mask) - if self.config.expert_shard_attention_option == EP_AS_CONTEXT: - segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,)) - else: - segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) + segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,)) elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP: if self.config.use_max_logit_estimate > 0: sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate) - segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) + segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,)) else: # Create multi-head mask multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index e699b4ac4e..76fa6291f0 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -28,14 +28,12 @@ from maxtext.common.common_types import ( DecoderBlockType, BATCH, - BATCH_NO_EXP, HEAD, PREFILL_LENGTH, D_KV, AxisNames, AxisIdxes, ATTN_LENGTH, - ATTN_LENGTH_NO_EXP, DType, Config, Array, @@ -45,12 +43,10 @@ KV_HEAD, KV_HEAD_DIM, KV_BATCH, - KV_BATCH_NO_EXP, ATTN_EMBED, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, - EP_AS_CONTEXT, AttentionType, ) from maxtext.layers import nnx_wrappers @@ -142,16 +138,11 @@ def attention_as_linen( prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), - query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), - input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED), - out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV), + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED), + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV), prefill_input_axis_names: AxisNames = 
(PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), @@ -211,13 +202,8 @@ def attention_as_linen( query_axis_names=query_axis_names, key_axis_names=key_axis_names, value_axis_names=value_axis_names, - ep_query_axis_names=ep_query_axis_names, - ep_key_axis_names=ep_key_axis_names, - ep_value_axis_names=ep_value_axis_names, input_axis_names=input_axis_names, - ep_input_axis_names=ep_input_axis_names, out_axis_names=out_axis_names, - ep_out_axis_names=ep_out_axis_names, prefill_input_axis_names=prefill_input_axis_names, decode_input_axis_names=decode_input_axis_names, prefill_out_axis_names=prefill_out_axis_names, @@ -309,16 +295,11 @@ def __init__( prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), - query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), - input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, ATTN_EMBED), - out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, ATTN_LENGTH, HEAD, D_KV), + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED), + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV), prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), @@ -420,13 +401,8 @@ def __init__( self.query_axis_names = query_axis_names self.key_axis_names = key_axis_names self.value_axis_names = value_axis_names - self.ep_query_axis_names = ep_query_axis_names - self.ep_key_axis_names = ep_key_axis_names - self.ep_value_axis_names = ep_value_axis_names self.input_axis_names = input_axis_names - self.ep_input_axis_names = ep_input_axis_names self.out_axis_names = out_axis_names - self.ep_out_axis_names = ep_out_axis_names self.prefill_input_axis_names = prefill_input_axis_names self.decode_input_axis_names = decode_input_axis_names self.prefill_out_axis_names = prefill_out_axis_names @@ -1100,8 +1076,6 @@ def __call__( """ if model_mode == MODEL_MODE_PREFILL: input_axis_names = self.prefill_input_axis_names - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - input_axis_names = self.ep_input_axis_names elif model_mode == MODEL_MODE_TRAIN: input_axis_names = self.input_axis_names else: @@ -1172,10 +1146,6 @@ def 
__call__( query = self._maybe_shard_with_logical(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) key = self._maybe_shard_with_logical(key, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) value = self._maybe_shard_with_logical(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - query = self._maybe_shard_with_logical(query, self.ep_query_axis_names) - key = self._maybe_shard_with_logical(key, self.ep_key_axis_names) - value = self._maybe_shard_with_logical(value, self.ep_value_axis_names) else: query = self._maybe_shard_with_logical(query, self.query_axis_names) key = self._maybe_shard_with_logical(key, self.key_axis_names) @@ -1219,8 +1189,6 @@ def __call__( out = jax.ad_checkpoint.checkpoint_name(out, "attention_out") if model_mode == MODEL_MODE_PREFILL: out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names) - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - out = self._maybe_shard_with_logical(out, self.ep_out_axis_names) elif model_mode == MODEL_MODE_TRAIN: out = self._maybe_shard_with_logical(out, self.out_axis_names) else: diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 6922bd0016..7932961182 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -27,7 +27,7 @@ from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh -from maxtext.common.common_types import Config, DecoderBlockType, EP_AS_CONTEXT, ShardMode +from maxtext.common.common_types import Config, DecoderBlockType, ShardMode from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN from maxtext.inference import page_manager from maxtext.layers import linears @@ -106,10 +106,8 @@ def __call__( if self.model_mode == MODEL_MODE_PREFILL: logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") - elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: - logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") else: - logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") + logical_axis_names = ("activation_batch", "activation_length", "activation_embed") if model_mode == MODEL_MODE_PREFILL: inputs = _maybe_shard_with_logical(inputs, logical_axis_names) @@ -692,7 +690,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi cfg = self.config if cfg.shard_mode == ShardMode.EXPLICIT: - norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_embed")) else: norm_out_sharding = None @@ -710,7 +708,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) else: out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") + self.mesh, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") ) # [batch, length, emb_dim] -> [batch, length, vocab_size] diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 652718e0eb..466be00e5f 100644 --- 
a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -165,7 +165,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: if model_mode == MODEL_MODE_PREFILL else ( "activation_embed_and_logits_batch", - "activation_length_no_exp", + "activation_length", "activation_embed", ) ) @@ -850,7 +850,7 @@ def __init__( self.attention_scaling = attention_scaling self.freqs_sharding = ( - create_sharding(mesh, ("activation_batch", "activation_length_no_exp", "q_heads")) + create_sharding(mesh, ("activation_batch", "activation_length", "q_heads")) if shard_mode == ShardMode.EXPLICIT else None ) @@ -976,7 +976,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array: inputs_complex = first_half + 1j * second_half # shape: [B, S, N, half_dim] # Apply the rotary transformation via complex multiplication. rotated_sharding = ( - create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", None, None)) + create_sharding(self.mesh, ("activation_batch", "activation_length", None, None)) if self.shard_mode == ShardMode.EXPLICIT else None ) diff --git a/src/maxtext/layers/linears.py b/src/maxtext/layers/linears.py index 4af9c5c530..8d9f094c98 100644 --- a/src/maxtext/layers/linears.py +++ b/src/maxtext/layers/linears.py @@ -30,7 +30,7 @@ import flax.linen as nn from maxtext.common.common_types import DecoderBlockType, ShardMode, DType, Array, Config -from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, EP_AS_CONTEXT +from maxtext.common.common_types import MODEL_MODE_PREFILL from maxtext.layers import nnx_wrappers, quantizations from maxtext.layers import normalizations from maxtext.layers.initializers import NdInitializer, nd_dense_init, default_bias_init, variable_to_logically_partitioned @@ -404,10 +404,8 @@ def __init__( if self.model_mode == MODEL_MODE_PREFILL: self.intermediate_logical = ("activation_batch", "prefill_activation_length", "activation_mlp") - elif config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: - self.intermediate_logical = ("activation_batch_no_exp", "activation_length", "activation_mlp") else: - self.intermediate_logical = ("activation_batch", "activation_length_no_exp", "activation_mlp") + self.intermediate_logical = ("activation_batch", "activation_length", "activation_mlp") if config.fused_mlp: self.wi = DenseGeneral( diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 314c450b03..5afddd1ac2 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -102,8 +102,8 @@ def _sort_activations_custom_bwd(residuals: jax.Array, grads: jax.Array) -> tupl def get_batchsplit_init_kernel_axes(): return ( - ("embed_no_exp", "fsdp_transpose_only", "expert_only"), - ("embed_no_exp", "fsdp_transpose_and_expert", None), + ("embed_moe", "fsdp_transpose_only", "expert_only"), + ("embed_moe", "fsdp_transpose_and_expert", None), ) @@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. 
contract_ind = tuple(range(0, len(norm_axis))) output_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None)) + create_sharding(self.mesh, ("activation_batch", "activation_length", None)) if self.shard_mode == ShardMode.EXPLICIT else None ) @@ -351,16 +351,16 @@ def __init__( if self.config.shard_exp_on_fsdp: # special sharding for dsv3 - self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp") - self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None) + self.wi_kernel_axes = ("embed_moe", None, "mlp") + self.wo_kernel_axes = ("embed_moe", "mlp", None) elif self.config.use_2d_fsdp_sharding: - self.wi_kernel_axes = ("embed_no_exp_moe", "mlp", None) - self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None) + self.wi_kernel_axes = ("embed_moe", "mlp", None) + self.wo_kernel_axes = ("embed_moe", "mlp", None) elif self.config.use_batch_split_schedule: self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes() else: - self.wi_kernel_axes = ("exp", "embed_no_exp_moe", "mlp") - self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp_moe") + self.wi_kernel_axes = ("exp", "embed_moe", "mlp") + self.wo_kernel_axes = ("exp", "mlp", "embed_moe") if self.config.attention == "vllm_rpa": # vLLM uses 'model' as the tensor parallelism axis name @@ -437,7 +437,7 @@ def __init__( if self.config.mlp_bias: wi_bias_axes = ("exp", "activation_mlp") - wo_bias_axes = ("exp", "activation_embed_moe") + wo_bias_axes = ("exp", "activation_embed") wi_bias_shape = (self.num_experts, self.intermediate_dim) wo_bias_shape = (self.num_experts, self.config.emb_dim) self.wi_0_bias = nnx.Param( @@ -1020,48 +1020,24 @@ def gmm( output = output[: hs_shape[0]] return output - # Currently, we support data, tensor, and expert parallelism with Megablox. - # We all gather the input activations over tensor parallelism to follow - # https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf. - - # Check if the batch should be sharded by expert and whether the batch_size - # supports this. For example, for interleaved inference, prefill always has - # batch_size=1 while decode can have batch_size > 1. 
- try: - is_batch_sharded_by_expert = ( - self._expert_parallelism_name - in tuple( - filter( - lambda tup: tup[0] == "activation_batch_moe", - self.config.logical_axis_rules, - ) - )[ - 0 - ][1] - ) - except: # pylint: disable=bare-except - is_batch_sharded_by_expert = False - if is_batch_sharded_by_expert and inputs.shape[0] > 1: - batch_logical_axis = "activation_batch_moe" - else: - batch_logical_axis = "activation_batch_no_exp_moe" + batch_logical_axis = "activation_batch" if self.get_tensor_transpose_parallelism_size() > 1: input_partition_pspec = self._logical_to_mesh_axes( - (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe") + (batch_logical_axis, "activation_norm_length", "activation_embed") ) w0_bias_pspec = self._logical_to_mesh_axes(("exp", None)) w1_bias_pspec = self._logical_to_mesh_axes(("exp", None)) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) else: - input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) - gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) else: # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits_pspec = None @@ -1113,7 +1089,7 @@ def gmm( P(), # Replicate the input key ), out_specs=( - self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")), + self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")), P(), # Handle None or replicate the output P(), # Handle None or replicate the output ), @@ -1159,58 +1135,48 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r ) if num_expert_parallelism > 1: - batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data" + batch_axis = "data" # get group sizes for all shards local_expert_size = self.config.num_experts // num_expert_parallelism reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1) global_group_sizes = group_sizes - if is_batch_sharded_by_expert: - all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - all_shards_group_sizes, - expert_shard_id, - num_expert_parallelism, - ) - # TODO(ranran): For better performance, we could update output buffer to a smaller - # size to replace self.get_expert_parallelism_size() for efficiency, - # Or we could apply capacity_factor for excessive experts. - # Note: Reducing buffer increase the risk of token dropping under unbalanced distribution. 
- - # In the worst case, all of the global input data is assigned to each expert in the current shard. - # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if - # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs. - max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) - buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok) - output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype) - - x = jax.lax.ragged_all_to_all( - x, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) - global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name) - x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( - x, - global_group_sizes, - local_expert_size, - shard_index=expert_shard_id, - use_custom_sort_vjp=self.config.use_custom_sort_vjp, - ) - else: - x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( - x, - global_group_sizes[None, :], - local_expert_size, - shard_index=expert_shard_id, - is_offset=True, - global_sorted_experts=selected_experts, - use_custom_sort_vjp=self.config.use_custom_sort_vjp, - ) + all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis) + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( + all_shards_group_sizes, + expert_shard_id, + num_expert_parallelism, + ) + + # TODO(ranran): For better performance, we could update output buffer to a smaller + # size to replace self.get_expert_parallelism_size() for efficiency, + # Or we could apply capacity_factor for excessive experts. + # Note: Reducing buffer increase the risk of token dropping under unbalanced distribution. + + # In the worst case, all of the global input data is assigned to each expert in the current shard. + # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if + # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs. 
+ max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) + buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok) + output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype) + + x = jax.lax.ragged_all_to_all( + x, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self._expert_parallelism_name, + ) + global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name) + x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( + x, + global_group_sizes, + local_expert_size, + shard_index=expert_shard_id, + use_custom_sort_vjp=self.config.use_custom_sort_vjp, + ) if self.config.mlp_bias: w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias) @@ -1352,47 +1318,27 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): ), dtype=intermediate_output.dtype, ) - if is_batch_sharded_by_expert: - # locally unpermute back to the original order - local_output = _sort_activations( - intermediate_output, - jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable - self.config.use_custom_sort_vjp, - ) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable - expert_shard_id, - num_expert_parallelism, - ) - intermediate_output = jax.lax.ragged_all_to_all( - local_output, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) - else: - # If bach is replicated across EP shards then each shard should send - # 0..local_shard_size data to the other shards and receive the - # local_shard data from all of the other shards using - # ragged_all_to_all. 
- input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - reshaped_group_sizes, # pylint: disable=undefined-variable - expert_shard_id, - num_expert_parallelism, - is_batch_sharded=False, - ) - intermediate_output = jax.lax.ragged_all_to_all( - intermediate_output, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) + + # locally unpermute back to the original order + local_output = _sort_activations( + intermediate_output, + jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable + self.config.use_custom_sort_vjp, + ) + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( + jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable + expert_shard_id, + num_expert_parallelism, + ) + intermediate_output = jax.lax.ragged_all_to_all( + local_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self._expert_parallelism_name, + ) output = self.unpermute( intermediate_output, @@ -1425,13 +1371,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose")) if self.get_tensor_transpose_parallelism_size() > 1: - input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe") + input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed") else: - input_axes = (batch_logical_axis, "activation_norm_length_moe", None) + input_axes = (batch_logical_axis, "activation_norm_length", None) - gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None) + gate_logits_axes = (batch_logical_axis, "activation_norm_length", None) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None) + pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None) else: pre_bias_logits_axes = None @@ -1449,14 +1395,12 @@ def reshape_and_update_weights(self, weights, indices): # output of updated weights: (batch_size, seq_len, num_experts) update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) index_update = ( - self._maybe_shard_with_logical( - jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None) - ), - self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[0])[:, None, None], ("activation_batch", None, None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length", None)), indices, ) weight_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None)) + create_sharding(self.mesh, ("activation_batch", "activation_length", None)) if self.config.shard_mode == ShardMode.EXPLICIT else None ) @@ -1511,7 +1455,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): expert_mask, (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = 
jnp.reshape( expert_token_count_fused, @@ -1519,7 +1463,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch_moe", "activation_norm_length_moe", None, None, None), + ("activation_batch", "activation_norm_length", None, None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3) @@ -1607,7 +1551,7 @@ def generate_masks(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch_moe", "activation_norm_length_moe", None, None), + ("activation_batch", "activation_norm_length", None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) @@ -1705,13 +1649,11 @@ def dense_matmul( ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" # gate_logits: batch, length, expert - gate_logits = self._maybe_shard_with_logical( - gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) - ) + gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch_moe", "activation_length_moe", None)) if self.config.model_name.startswith("deepseek3"): # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits = self._maybe_shard_with_logical( - pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None) + pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None) ) top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs) is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4 @@ -1754,13 +1696,13 @@ def dense_matmul( mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, "activation_mlp", ) @@ -1789,14 +1731,14 @@ def dense_matmul( ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_mlp", @@ -1817,14 +1759,14 @@ def dense_matmul( ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_mlp", @@ -1850,7 +1792,7 @@ def dense_matmul( dispatch, ( None, - "activation_batch_no_exp_moe", + "activation_batch_moe", "activation_norm_length_moe", None, "activation_embed_moe", @@ -1913,7 +1855,7 @@ def dense_matmul( intermediate_layer, ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, "activation_embed_moe", ), diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c96ec08c8d..e35933c1bf 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -30,7 +30,6 @@ from jax.sharding import Mesh from maxtext.common.common_types import ( - EP_AS_CONTEXT, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, @@ -169,10 +168,8 
diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py
index c96ec08c8d..e35933c1bf 100644
--- a/src/maxtext/layers/nnx_decoders.py
+++ b/src/maxtext/layers/nnx_decoders.py
@@ -30,7 +30,6 @@
 from jax.sharding import Mesh
 from maxtext.common.common_types import (
-    EP_AS_CONTEXT,
     MODEL_MODE_AUTOREGRESSIVE,
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
@@ -169,10 +168,8 @@ def __call__(
     if self.model_mode == MODEL_MODE_PREFILL:
       logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
-    elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN:
-      logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed")
     else:
-      logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed")
+      logical_axis_names = ("activation_batch", "activation_length", "activation_embed")
     inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -739,7 +736,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
     cfg = self.config
     if cfg.shard_mode == ShardMode.EXPLICIT:
-      norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
+      norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_embed"))
     else:
       norm_out_sharding = None
@@ -750,7 +747,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
       out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
     else:
       out_sharding = create_sharding(
-          self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")
+          self.mesh, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab")
       )
     # [batch, length, emb_dim] -> [batch, length, vocab_size]
diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py
index 1b130f1888..62ea52782b 100644
--- a/src/maxtext/layers/pipeline.py
+++ b/src/maxtext/layers/pipeline.py
@@ -28,7 +28,7 @@
 from flax import linen as nn
 from flax.linen.spmd import LogicallyPartitioned
-from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
+from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, ShardMode
 from maxtext.utils.sharding import (
     maybe_shard_with_logical,
     maybe_shard_with_name,
@@ -56,12 +56,8 @@ def setup(self):
     self.microbatches_per_stage = microbatches_per_stage
     self.use_circ_storage = self.need_circ_storage()
-    if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      self.batch_axis_name = "activation_batch_no_exp"
-      self.seq_len_axis_name = "activation_length"
-    else:
-      self.batch_axis_name = "activation_batch"
-      self.seq_len_axis_name = "activation_length_no_exp"
+    self.batch_axis_name = "activation_batch"
+    self.seq_len_axis_name = "activation_length"
     self.spmd_axis_name = "stage" if self.config.shard_mode == ShardMode.AUTO else None
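Both pipeline implementations now hard-code `activation_batch` and `activation_length`, so whatever mesh axes those logical names map to in `logical_axis_rules` apply uniformly, with no `_no_exp` alias to switch between. A small sketch, assuming the Flax partitioning helpers and a hypothetical cut-down rule set, of how such a logical-axis tuple resolves to a `PartitionSpec`:

```python
from flax.linen import partitioning as nn_partitioning

# Hypothetical subset of rules in the spirit of this change: "activation_length"
# now directly carries the ('sequence', 'context') mapping.
rules = (
    ("activation_batch", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("activation_length", ("sequence", "context")),
    ("activation_embed", ("tensor", "tensor_transpose")),
)
spec = nn_partitioning.logical_to_mesh_axes(
    ("activation_batch", "activation_length", "activation_embed"), rules
)
# Expected to resolve to roughly:
# PartitionSpec(('data', 'fsdp', 'fsdp_transpose', 'expert'),
#               ('sequence', 'context'),
#               ('tensor', 'tensor_transpose'))
print(spec)
```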
diff --git a/src/maxtext/layers/pipeline_deprecated.py b/src/maxtext/layers/pipeline_deprecated.py
index 79f71c78ff..0a17711b4e 100644
--- a/src/maxtext/layers/pipeline_deprecated.py
+++ b/src/maxtext/layers/pipeline_deprecated.py
@@ -28,7 +28,7 @@
 from flax import linen as nn
 from flax.linen.spmd import LogicallyPartitioned
-from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
+from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, ShardMode
 from maxtext.utils.sharding import (
     maybe_shard_with_logical,
     maybe_shard_with_name,
@@ -67,12 +67,8 @@ def setup(self):  # pylint: disable=missing-function-docstring
     self.microbatches_per_stage = microbatches_per_stage
     self.use_circ_storage = self.need_circ_storage()
-    if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      self.batch_axis_name = "activation_batch_no_exp"
-      self.seq_len_axis_name = "activation_length"
-    else:
-      self.batch_axis_name = "activation_batch"
-      self.seq_len_axis_name = "activation_length_no_exp"
+    self.batch_axis_name = "activation_batch"
+    self.seq_len_axis_name = "activation_length"
     # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2.
     self.spmd_axis_name = "stage" if self.config.shard_mode == ShardMode.AUTO else None
diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py
index 252dadc768..9ae60abe25 100644
--- a/src/maxtext/models/llama2.py
+++ b/src/maxtext/models/llama2.py
@@ -184,9 +184,7 @@ def __call__(
     hidden_states = self._maybe_shard_with_logical(hidden_states, self.activation_axis_names)
     # MLP block.
-    mlp_intermediate_sharding = create_sharding(
-        self.mesh, ("activation_batch", "activation_length_no_exp", "activation_mlp")
-    )
+    mlp_intermediate_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_mlp"))
     mlp_lnx = self.mlp(
         hidden_states,
         deterministic=deterministic,
diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py
index eb15747fc2..bc7d5fdfc1 100644
--- a/src/maxtext/models/qwen3.py
+++ b/src/maxtext/models/qwen3.py
@@ -29,7 +29,7 @@
 from flax import linen as nn
 from flax import nnx
-from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN
+from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, EMBED, MODEL_MODE_TRAIN, LENGTH
 from maxtext.layers import attentions
 from maxtext.layers import initializers as max_initializers
 from maxtext.layers import moe
@@ -723,7 +723,7 @@ def __init__(
         attention_kernel=cfg.attention,
         inputs_q_shape=dummy_inputs_shape,
         inputs_kv_shape=dummy_inputs_shape,
-        out_axis_names=(BATCH, LENGTH_NO_EXP, EMBED),
+        out_axis_names=(BATCH, LENGTH, EMBED),
         mesh=self.mesh,
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py
index ec68e9bc78..06184cff87 100644
--- a/src/maxtext/utils/vocabulary_tiling.py
+++ b/src/maxtext/utils/vocabulary_tiling.py
@@ -61,11 +61,11 @@ def vocab_tiling_linen_loss(
   param_spec = nn.get_partition_spec(params)
   hidden_spec = create_sharding(
       model.mesh,
-      ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_embed"),
+      ("activation_embed_and_logits_batch", "activation_length", "activation_embed"),
   )
   label_spec = create_sharding(
       model.mesh,
-      ("activation_embed_and_logits_batch", "activation_length_no_exp"),
+      ("activation_embed_and_logits_batch", "activation_length"),
   )
   reshaped_hidden_spec = create_sharding(
       model.mesh,
diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py
index 8287fb2602..4f4a5fed20 100644
--- a/tests/unit/attention_test.py
+++ b/tests/unit/attention_test.py
@@ -636,7 +636,6 @@ def test_share_kv_projections(self):
           "ici_context_parallelism": 4,
           "context_parallel_load_balance": False,
           "ici_expert_parallelism": 1,
-          "expert_shard_attention_option": "fsdp",
           "shard_mode": "auto",
       },
       {
@@ -644,7 +643,6 @@ def test_share_kv_projections(self):
           "ici_context_parallelism": 4,
           "context_parallel_load_balance": True,
           "ici_expert_parallelism": 1,
-          "expert_shard_attention_option": "fsdp",
           "shard_mode": "auto",
       },
      {
@@ -652,7 +650,6 @@ def test_share_kv_projections(self):
           "ici_context_parallelism": 2,
           "context_parallel_load_balance": False,
           "ici_expert_parallelism": 2,
-          "expert_shard_attention_option": "context",
           "shard_mode": "auto",
       },
       {
@@
-660,23 +657,6 @@ def test_share_kv_projections(self): "ici_context_parallelism": 2, "context_parallel_load_balance": True, "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_no_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_with_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", "shard_mode": "auto", }, { @@ -684,7 +664,6 @@ def test_share_kv_projections(self): "ici_context_parallelism": 4, "context_parallel_load_balance": False, "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", "shard_mode": "explicit", }, { @@ -692,7 +671,6 @@ def test_share_kv_projections(self): "ici_context_parallelism": 4, "context_parallel_load_balance": True, "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", "shard_mode": "explicit", }, { @@ -700,7 +678,6 @@ def test_share_kv_projections(self): "ici_context_parallelism": 2, "context_parallel_load_balance": False, "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", "shard_mode": "explicit", }, { @@ -708,23 +685,6 @@ def test_share_kv_projections(self): "ici_context_parallelism": 2, "context_parallel_load_balance": True, "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_no_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_with_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", "shard_mode": "explicit", }, ) @@ -735,7 +695,6 @@ def test_tpu_flash_attention_context_parallel( ici_context_parallelism, context_parallel_load_balance, ici_expert_parallelism, - expert_shard_attention_option, shard_mode, ): """Test equivalence between dot_product and flash attention + context/expert parallelism""" @@ -759,7 +718,6 @@ def test_tpu_flash_attention_context_parallel( ici_context_parallelism=ici_context_parallelism, context_parallel_load_balance=context_parallel_load_balance, ici_expert_parallelism=ici_expert_parallelism, - expert_shard_attention_option=expert_shard_attention_option, shard_mode=shard_mode, ) devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) @@ -801,7 +759,7 @@ def test_tpu_flash_attention_context_parallel( jax.numpy.allclose(mha_generic_output, mha_generic_flash_cp_output, rtol=1e-01, atol=1e-01, equal_nan=False), msg="Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," - f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", + f" ici_expert_parallelism={ici_expert_parallelism}.", ) @pytest.mark.tpu_only @@ -1460,7 +1418,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 4, "context_parallel_load_balance": False, "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", "shard_mode": 
"auto", }, { @@ -1468,7 +1425,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 4, "context_parallel_load_balance": True, "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", "shard_mode": "auto", }, { @@ -1476,7 +1432,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 2, "context_parallel_load_balance": False, "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", "shard_mode": "auto", }, { @@ -1484,23 +1439,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 2, "context_parallel_load_balance": True, "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_no_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_with_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", "shard_mode": "auto", }, { @@ -1508,7 +1446,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 4, "context_parallel_load_balance": False, "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", "shard_mode": "explicit", }, { @@ -1516,7 +1453,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 4, "context_parallel_load_balance": True, "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", "shard_mode": "explicit", }, { @@ -1524,7 +1460,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 2, "context_parallel_load_balance": False, "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", "shard_mode": "explicit", }, { @@ -1532,23 +1467,6 @@ def test_projection_initialization(self): "ici_context_parallelism": 2, "context_parallel_load_balance": True, "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_no_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_with_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", "shard_mode": "explicit", }, ) @@ -1559,7 +1477,6 @@ def test_tpu_flash_attention_context_parallel( ici_context_parallelism, context_parallel_load_balance, ici_expert_parallelism, - expert_shard_attention_option, shard_mode, ): """Test equivalence between dot_product and flash attention + context/expert parallelism""" @@ -1607,7 +1524,6 @@ def test_tpu_flash_attention_context_parallel( ici_context_parallelism=ici_context_parallelism, context_parallel_load_balance=context_parallel_load_balance, ici_expert_parallelism=ici_expert_parallelism, - expert_shard_attention_option=expert_shard_attention_option, ) devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto @@ -1653,7 +1569,7 @@ def test_tpu_flash_attention_context_parallel( jax.numpy.allclose(mla_generic_output, mla_generic_flash_cp_output, rtol=1e-01, atol=1e-01, equal_nan=False), msg="MLA Logits from generic dot product and flash attention 
+ context/expert parallelism are not close.\n" f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," - f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", + f" ici_expert_parallelism={ici_expert_parallelism}.", ) def get_indexer_test_data(self, batch_size, q_len, kv_len, num_heads, head_dim): diff --git a/tests/unit/custom_mesh_and_rule_test.py b/tests/unit/custom_mesh_and_rule_test.py index f00cb0ee28..eb0f585c24 100644 --- a/tests/unit/custom_mesh_and_rule_test.py +++ b/tests/unit/custom_mesh_and_rule_test.py @@ -63,6 +63,7 @@ def test_ds3_large_pp(self): "base_emb_dim=256", "base_mlp_dim=256", "base_num_decoder_layers=4", + "use_tokamax_splash=true", "custom_mesh_and_rule=pipeline-large-moe", ) ) diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index c3e83025f9..416c08a2f7 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -341,15 +341,13 @@ def __call__(self, inputs, deterministic: bool = False): weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok) weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype) mlp_lnx = jnp.zeros_like(inputs) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length_no_exp", "activation_embed")) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) for k in range(self.num_experts): weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1) getattr(self, f"mlp_{k}") mlp_lnx_exp = getattr(self, f"mlp_{k}")(inputs, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint( - mlp_lnx_exp, ("activation_batch", "activation_length_no_exp", "activation_embed") - ) + mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed")) mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp mlp_lnx += mlp_lnx_exp diff --git a/tests/utils/attention_test_util.py b/tests/utils/attention_test_util.py index 1284acbb17..4ce93e51be 100644 --- a/tests/utils/attention_test_util.py +++ b/tests/utils/attention_test_util.py @@ -23,7 +23,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from maxtext.configs import pyconfig from maxtext.common.gcloud_stub import is_decoupled -from maxtext.common.common_types import AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, EP_AS_CONTEXT, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ShardMode from maxtext.layers.attention_mla import MLA from maxtext.utils import max_utils from maxtext.utils import maxtext_utils @@ -204,12 +204,8 @@ def forward_with_context_expert_parallelism( decoder_positions = reordered_batch["inputs_position"] # apply attention with sharding with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules): - if cfg_cp.expert_shard_attention_option == EP_AS_CONTEXT: - batch_axis = "activation_batch_no_exp" - length_axis = "activation_length" - else: - batch_axis = "activation_batch" - length_axis = "activation_length_no_exp" + batch_axis = "activation_batch" + length_axis = "activation_length" lnx_spec = nn_partitioning.logical_to_mesh_axes( (batch_axis, length_axis, "activation_embed"), nn_partitioning.get_axis_rules(), diff --git 
a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json index cbee49b201..0ced56c01f 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json @@ -14,55 +14,55 @@ }, { "attention_mla/inputs_q: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P('fsdp', None, None)" } }, { "attention_mla/inputs_kv: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P('fsdp', None, None)" } }, { "attention_mla/q_nope: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/q_pe: bfloat16[192,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/query: bfloat16[192,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key_nope: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key_rope: bfloat16[192,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key: bfloat16[192,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/value: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, @@ -86,7 +86,7 @@ }, { "attention_mla/out: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "logic_axes": "('activation_batch', 'activation_length', 
'activation_heads', 'activation_kv')", "PartitionSpec": "P('fsdp', None, None, None)" } }, @@ -104,7 +104,7 @@ }, { "linears/x: bfloat16[192,2048,10944]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } }, @@ -134,7 +134,7 @@ }, { "linears/x: bfloat16[192,2048,2816]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json index a12030dbd9..950b64159c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json @@ -14,55 +14,55 @@ }, { "attention_mla/inputs_q: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "attention_mla/inputs_kv: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "attention_mla/q_nope: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/q_pe: bfloat16[768,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/query: bfloat16[768,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key_nope: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key_rope: bfloat16[768,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key: bfloat16[768,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": 
"('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/value: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, @@ -86,7 +86,7 @@ }, { "attention_mla/out: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, @@ -104,7 +104,7 @@ }, { "linears/x: bfloat16[768,2048,10944]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, @@ -134,7 +134,7 @@ }, { "linears/x: bfloat16[768,2048,2816]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json index 4172fc960f..f91f7f18a5 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json @@ -14,55 +14,55 @@ }, { "attention_mla/inputs_q: bfloat16[96,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P('fsdp', None, None)" } }, { "attention_mla/inputs_kv: bfloat16[96,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P('fsdp', None, None)" } }, { "attention_mla/q_nope: bfloat16[96,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/q_pe: bfloat16[96,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/query: bfloat16[96,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key_nope: bfloat16[96,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": 
"('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key_rope: bfloat16[96,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key: bfloat16[96,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/value: bfloat16[96,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, @@ -86,7 +86,7 @@ }, { "attention_mla/out: bfloat16[96,2048,16,128]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')", "PartitionSpec": "P('fsdp', None, None, None)" } }, @@ -104,7 +104,7 @@ }, { "linears/x: bfloat16[96,2048,10944]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } }, @@ -134,7 +134,7 @@ }, { "linears/x: bfloat16[96,2048,2816]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json index 2789aa367e..4f601415fb 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json @@ -14,55 +14,55 @@ }, { "attention_mla/inputs_q: bfloat16[384,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "attention_mla/inputs_kv: bfloat16[384,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "attention_mla/q_nope: bfloat16[384,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/q_pe: bfloat16[384,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 
'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/query: bfloat16[384,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key_nope: bfloat16[384,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key_rope: bfloat16[384,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key: bfloat16[384,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/value: bfloat16[384,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, @@ -86,7 +86,7 @@ }, { "attention_mla/out: bfloat16[384,2048,16,128]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, @@ -104,7 +104,7 @@ }, { "linears/x: bfloat16[384,2048,10944]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, @@ -134,7 +134,7 @@ }, { "linears/x: bfloat16[384,2048,2816]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json index cbee49b201..0ced56c01f 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json @@ -14,55 +14,55 @@ }, { "attention_mla/inputs_q: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P('fsdp', None, None)" } }, { "attention_mla/inputs_kv: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch', 
'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P('fsdp', None, None)" } }, { "attention_mla/q_nope: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/q_pe: bfloat16[192,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/query: bfloat16[192,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key_nope: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key_rope: bfloat16[192,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/key: bfloat16[192,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, { "attention_mla/value: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P('fsdp', None, None, None)" } }, @@ -86,7 +86,7 @@ }, { "attention_mla/out: bfloat16[192,2048,16,128]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_heads', 'activation_kv')", "PartitionSpec": "P('fsdp', None, None, None)" } }, @@ -104,7 +104,7 @@ }, { "linears/x: bfloat16[192,2048,10944]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } }, @@ -134,7 +134,7 @@ }, { "linears/x: bfloat16[192,2048,2816]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json 
b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json index a12030dbd9..950b64159c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json @@ -14,55 +14,55 @@ }, { "attention_mla/inputs_q: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "attention_mla/inputs_kv: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_embed')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_embed')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "attention_mla/q_nope: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/q_pe: bfloat16[768,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/query: bfloat16[768,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key_nope: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key_rope: bfloat16[768,2048,16,64]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/key: bfloat16[768,2048,16,192]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, { "attention_mla/value: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')", + "logic_axes": "('activation_kv_batch', 'activation_length', 'activation_kv_heads', 'activation_kv_head_dim')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, @@ -86,7 +86,7 @@ }, { "attention_mla/out: bfloat16[768,2048,16,128]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", + "logic_axes": "('activation_batch', 'activation_length', 
'activation_heads', 'activation_kv')", "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, @@ -104,7 +104,7 @@ }, { "linears/x: bfloat16[768,2048,10944]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, @@ -134,7 +134,7 @@ }, { "linears/x: bfloat16[768,2048,2816]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json index 0d5b2d8c24..3c34c7f827 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json @@ -56,7 +56,7 @@ }, { "linears/x: bfloat16[192,2048,3072]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json index 2146f74797..ffb2708907 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json @@ -56,7 +56,7 @@ }, { "linears/x: bfloat16[768,2048,3072]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json index 4a5224cd6d..3ca8176bfc 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json @@ -56,7 +56,7 @@ }, { "linears/x: bfloat16[96,2048,3072]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json index 6bb047297d..d9c70dde3f 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json @@ -56,7 +56,7 @@ }, { "linears/x: bfloat16[384,2048,3072]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json index 0d5b2d8c24..3c34c7f827 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json +++ 
b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json @@ -56,7 +56,7 @@ }, { "linears/x: bfloat16[192,2048,3072]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json index 2146f74797..ffb2708907 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json @@ -56,7 +56,7 @@ }, { "linears/x: bfloat16[768,2048,3072]": { - "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_mlp')", + "logic_axes": "('activation_batch', 'activation_length', 'activation_mlp')", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }
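The sharding_info fixtures above pair each activation's logical axes with the `PartitionSpec` it resolves to on the test topology. As a toy illustration (hypothetical two-axis mesh, small shapes) of what an entry like `P('fsdp', None, None)` means for a `[batch, length, mlp]` activation, i.e. batch split over `fsdp` with length and mlp replicated:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mini-mesh: all local devices on 'fsdp', 'data' kept trivial,
# mirroring the slice_1 fixtures where only the batch dimension is sharded.
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()).reshape(1, n), ("data", "fsdp"))

spec = P("fsdp", None, None)  # ('activation_batch', 'activation_length', 'activation_mlp')
x = jax.device_put(np.zeros((4 * n, 16, 32), np.float32), NamedSharding(mesh, spec))
print(x.sharding.spec)        # PartitionSpec('fsdp', None, None)
```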