Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,12 @@
AxisIdxes = tuple[int, ...]

BATCH = "activation_batch"
BATCH_NO_EXP = "activation_batch_no_exp"

ATTN_LENGTH = "activation_attn_length"
ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"

LENGTH = "activation_length"
LENGTH_NO_EXP = "activation_length_no_exp"
PREFILL_LENGTH = "prefill_activation_length"
Q_LENGTH = "activation_q_length"
Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
Q_LORA_UP_PROJ = "q_lora_up_proj"
KV_LENGTH = "activation_kv_length"
KV_LORA_UP_PROJ = "kv_lora_up_proj"
Expand All @@ -50,7 +46,6 @@
HEAD = "activation_heads"
PREFILL_KV_BATCH = "activation_prefill_kv_batch"
KV_BATCH = "activation_kv_batch"
KV_BATCH_NO_EXP = "activation_kv_batch_no_exp"
KV_HEAD = "activation_kv_heads"
KV_HEAD_DIM = "activation_kv_head_dim"
D_KV = "activation_kv"
Expand Down
43 changes: 14 additions & 29 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -450,38 +450,31 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_length', ['sequence', 'context', 'expert']],
['activation_length', ['context', 'expert']],
['activation_attn_length', ['sequence', 'context', 'expert']],
['activation_attn_length', ['context', 'expert']],
['activation_attn_length_no_exp', ['sequence', 'context']],
['activation_attn_length_no_exp', ['context']],
['activation_length_no_exp', ['sequence', 'context']],
['activation_length_no_exp', ['context']],
['activation_length_no_exp_moe', ['sequence', 'context']],
['activation_length_no_exp_moe', ['context']],
['activation_length', ['sequence', 'context']],
['activation_length', ['context']],
['activation_attn_length', ['sequence', 'context']],
['activation_attn_length', ['context']],
['activation_length_moe', ['sequence', 'context']],
['activation_length_moe', ['context']],
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
['activation_q_length', ['context', 'expert']],
['activation_q_length_no_exp', ['context']],
['activation_q_length', ['context']],
['prefill_activation_length', ['sequence', 'context']],
['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_kv_length', []],
['activation_attn_embed', ['tensor', 'tensor_transpose']],
['activation_embed', ['tensor', 'tensor_transpose']],
['activation_embed_moe', ['tensor', 'tensor_transpose']],
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']],
['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose']],
Expand All @@ -501,18 +494,10 @@ logical_axis_rules: [
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['embed', ['fsdp', 'sequence', 'context', 'expert']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['embed_moe', ['fsdp', 'sequence', 'context', 'expert']],
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp_moe', ['fsdp', 'sequence', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_moe', ['fsdp', 'sequence', 'context']],
['embed_tensor_transpose', ['tensor_transpose']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp']],
['activation_batch_no_exp_moe', ['data', 'fsdp']],
['activation_batch_moe', ['data', 'fsdp']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
['activation_heads', ['tensor']],
Expand All @@ -46,9 +44,9 @@ logical_axis_rules: [
['activation_embed', ['tensor']],
['activation_embed_moe', ['tensor']],
['activation_mlp', ['tensor']],
['activation_mlp_moe', ['tensor']],
['activation_kv', ['tensor']],
['activation_kv_batch', ['data', 'fsdp', 'expert']],
['activation_kv_batch_no_exp', ['data', 'fsdp']],
['activation_kv_batch', ['data', 'fsdp']],
['activation_kv_head_dim', ['tensor']],
['activation_vocab', ['tensor']],
['activation_stage', 'stage'],
Expand All @@ -60,9 +58,7 @@ logical_axis_rules: [
['q_heads', ['tensor']],
['kv_heads', ['tensor']],
['embed', ['fsdp', 'expert']], # remove context from embed sharding
['embed_moe', ['fsdp', 'expert']],
['embed_no_exp', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
['embed_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['norm', ['tensor']],
Expand Down
5 changes: 0 additions & 5 deletions src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,14 @@ mesh_axes: ['fsdp']
data_sharding: [['fsdp']]
logical_axis_rules: [
['activation_batch', ['fsdp']],
['activation_batch_no_exp', ['fsdp']],
['activation_batch_moe', ['fsdp']],
['activation_batch_no_exp_moe', ['fsdp']],
['activation_embed_and_logits_batch', ['fsdp']],
['activation_embed_and_logits_batch_sequence', ['fsdp']],
['activation_prefill_kv_batch', ['fsdp']],
['activation_kv_batch', ['fsdp']],
['activation_kv_batch_no_exp', ['fsdp']],
['decode_batch', ['fsdp']],
['embed', ['fsdp']],
['embed_no_exp', ['fsdp']],
['embed_moe', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['exp_with_fsdp', 'fsdp'],
Expand Down
10 changes: 5 additions & 5 deletions src/maxtext/configs/inference/inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ base_config: "base.yml"

logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
Expand All @@ -13,6 +12,7 @@ logical_axis_rules: [
['activation_norm_length', ['tensor_sequence', 'sequence']],
['activation_embed', ['tensor_transpose']],
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
Expand All @@ -34,10 +34,10 @@ logical_axis_rules: [
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
['embed', ['fsdp', 'sequence', 'expert']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['layers', 'stage'],
['kv', []],
Expand Down
20 changes: 7 additions & 13 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,23 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
logical_axis_rules: [
['activation_batch', ['data']],
['activation_batch_moe', []],
['activation_batch_no_exp', []],
['activation_batch_no_exp_moe', []],
['activation_embed_and_logits_batch', ['data', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
['activation_heads', ['model', 'expert']],
['activation_kv_heads', ['model', 'expert']],
['activation_attn_length', ['expert']],
['activation_attn_length_no_exp', []],
['activation_length', ['data', 'expert']],
['activation_attn_length', []],
['activation_length', ['data']],
['activation_length_moe', ['data', 'expert']],
['activation_length_no_exp', 'data'],
['activation_length_no_exp_moe', 'data'],
['activation_length_moe', 'data'],
['activation_q_length', ['expert', 'attn_dp_expert']],
['activation_attn_embed', 'model'],
['activation_embed', ['model', 'attn_dp']],
['activation_embed_moe', ['model', 'attn_dp']],
['activation_mlp', ['model', 'attn_dp']],
['activation_mlp_moe', ['model', 'attn_dp']],
['activation_kv', ['model']],
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
['activation_kv_batch', ['data', 'expert', 'attn_dp_expert']],
['activation_kv_batch_no_exp', ['data']],
['activation_kv_batch', ['data']],
['activation_kv_head_dim', ['model']],
['activation_vocab', ['model', 'attn_dp']],
['activation_norm_length', []],
Expand All @@ -70,11 +66,9 @@ logical_axis_rules: [
['kv', []],
['embed', ['expert', 'attn_dp_expert']],
['embed', ['attn_dp_expert']],
['embed_moe', ['expert', 'attn_dp_expert']],
['embed_moe', ['attn_dp_expert']],
['embed_moe', []],
['embed_moe', []],
['embed_tensor_transpose', ['attn_dp', 'model']],
['embed_no_exp', []],
['embed_no_exp_moe', []],
['q_lora', ['expert', 'attn_dp_expert']],
['kv_lora', ['expert', 'attn_dp_expert']],
['norm', []],
Expand Down
2 changes: 0 additions & 2 deletions src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ logical_axis_rules: [
['activation_stage', 'stage'],
['embed', ['fsdp']],
['embed_moe', ['fsdp']],
['embed_no_exp', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['layers', 'stage'],
Expand Down
2 changes: 0 additions & 2 deletions src/maxtext/configs/models/deepseek3-671b-batchsplit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ logical_axis_rules: [
['activation_stage', 'stage'],
['embed', ['fsdp']],
['embed_moe', ['fsdp']],
['embed_no_exp', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['layers', 'stage'],
Expand Down
9 changes: 4 additions & 5 deletions src/maxtext/configs/post_train/rl_mt_jt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ logical_axis_rules: [
['prefill_activation_length', ['data']],
['prefill_activation_norm_length', ['data']],
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
Expand Down Expand Up @@ -50,10 +49,10 @@ logical_axis_rules: [
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
['embed', ['fsdp', 'sequence', 'expert']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['layers', 'stage'],
['kv', []],
Expand Down
Loading
Loading