Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,12 @@
AxisIdxes = tuple[int, ...]

BATCH = "activation_batch"
BATCH_NO_EXP = "activation_batch_no_exp"

ATTN_LENGTH = "activation_attn_length"
ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"

LENGTH = "activation_length"
LENGTH_NO_EXP = "activation_length_no_exp"
PREFILL_LENGTH = "prefill_activation_length"
Q_LENGTH = "activation_q_length"
Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
Q_LORA_UP_PROJ = "q_lora_up_proj"
KV_LENGTH = "activation_kv_length"
KV_LORA_UP_PROJ = "kv_lora_up_proj"
Expand All @@ -50,7 +46,6 @@
HEAD = "activation_heads"
PREFILL_KV_BATCH = "activation_prefill_kv_batch"
KV_BATCH = "activation_kv_batch"
KV_BATCH_NO_EXP = "activation_kv_batch_no_exp"
KV_HEAD = "activation_kv_heads"
KV_HEAD_DIM = "activation_kv_head_dim"
D_KV = "activation_kv"
Expand Down
116 changes: 56 additions & 60 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -450,70 +450,45 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
# Vocab activation
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_length', ['sequence', 'context', 'expert']],
['activation_length', ['context', 'expert']],
['activation_attn_length', ['sequence', 'context', 'expert']],
['activation_attn_length', ['context', 'expert']],
['activation_attn_length_no_exp', ['sequence', 'context']],
['activation_attn_length_no_exp', ['context']],
['activation_length_no_exp', ['sequence', 'context']],
['activation_length_no_exp', ['context']],
['activation_length_no_exp_moe', ['sequence', 'context']],
['activation_length_no_exp_moe', ['context']],
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
['activation_q_length', ['context', 'expert']],
['activation_q_length_no_exp', ['context']],
['prefill_activation_length', ['sequence', 'context']],
['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_kv_length', []],
['activation_attn_embed', ['tensor', 'tensor_transpose']],
['activation_embed', ['tensor', 'tensor_transpose']],
['activation_embed_moe', ['tensor', 'tensor_transpose']],
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose']],
['activation_vocab', 'tensor_sequence'],
['activation_vocab', ['sequence','context']],
['activation_stage', 'stage'],
['activation_exp', ['expert']],
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['decode_length', ['sequence']],
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
# Vocab weight
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
# MoE activation
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
['activation_length_moe', ['sequence', 'context']],
['activation_length_moe', ['context']],
['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
['activation_embed_moe', ['tensor', 'tensor_transpose']],
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_exp', ['expert']],
# MoE weight
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], # should be deprecated
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_moe', ['fsdp', 'sequence', 'context']],
['embed_tensor_transpose', ['tensor_transpose']], # should be deprecated
['exp_with_fsdp', 'fsdp'], # should be deprecated
# Attn activation
['activation_attn_length', ['sequence', 'context']],
['activation_attn_length', ['context']],
['activation_q_length', ['context']],
['activation_kv_length', []],
['activation_attn_embed', ['tensor', 'tensor_transpose']],
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']],
['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
# Attn weight
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['embed', ['fsdp', 'sequence', 'context', 'expert']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'context']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['embed_moe', ['fsdp', 'sequence', 'context', 'expert']],
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp_moe', ['fsdp', 'sequence', 'context']],
['embed_tensor_transpose', ['tensor_transpose']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
Expand All @@ -524,29 +499,50 @@ logical_axis_rules: [
['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['kv_lora', ['fsdp', 'sequence', 'context', 'expert']],
["kv_lora_up_proj",[]],
# Other activation
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_length', ['sequence', 'context']],
['activation_length', ['context']],
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_embed', ['tensor', 'tensor_transpose']],
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_stage', 'stage'],
# Other weight
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['embed', ['fsdp', 'sequence', 'context', 'expert']],
['norm', ['tensor', 'tensor_transpose']],
['layers', 'stage'],
# Others (inference etc.)
['prefill_activation_length', ['sequence', 'context']],
['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['decode_length', ['sequence']],
['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
['paged_kv_heads', ['tensor']],
['diloco', 'diloco'],
['engram_dim', ['tensor']],
# Should remove following names as they duplicate shardings
['qkv', []],
['kv', []],
['kv_head_dim', []],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads_none', []],
['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
['cache_kv', []],
['cache_sequence', []],
['exp', 'expert'],
['exp_with_fsdp', 'fsdp'],
['paged_kv_heads', ['tensor']],
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['engram_dim', ['tensor']],
['mhc', []],
['diloco', 'diloco'],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp']],
['activation_batch_no_exp_moe', ['data', 'fsdp']],
['activation_batch_moe', ['data', 'fsdp']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
['activation_heads', ['tensor']],
Expand All @@ -46,9 +44,9 @@ logical_axis_rules: [
['activation_embed', ['tensor']],
['activation_embed_moe', ['tensor']],
['activation_mlp', ['tensor']],
['activation_mlp_moe', ['tensor']],
['activation_kv', ['tensor']],
['activation_kv_batch', ['data', 'fsdp', 'expert']],
['activation_kv_batch_no_exp', ['data', 'fsdp']],
['activation_kv_batch', ['data', 'fsdp']],
['activation_kv_head_dim', ['tensor']],
['activation_vocab', ['tensor']],
['activation_stage', 'stage'],
Expand All @@ -60,9 +58,7 @@ logical_axis_rules: [
['q_heads', ['tensor']],
['kv_heads', ['tensor']],
['embed', ['fsdp', 'expert']], # remove context from embed sharding
['embed_moe', ['fsdp', 'expert']],
['embed_no_exp', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
['embed_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['norm', ['tensor']],
Expand Down
5 changes: 0 additions & 5 deletions src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,14 @@ mesh_axes: ['fsdp']
data_sharding: [['fsdp']]
logical_axis_rules: [
['activation_batch', ['fsdp']],
['activation_batch_no_exp', ['fsdp']],
['activation_batch_moe', ['fsdp']],
['activation_batch_no_exp_moe', ['fsdp']],
['activation_embed_and_logits_batch', ['fsdp']],
['activation_embed_and_logits_batch_sequence', ['fsdp']],
['activation_prefill_kv_batch', ['fsdp']],
['activation_kv_batch', ['fsdp']],
['activation_kv_batch_no_exp', ['fsdp']],
['decode_batch', ['fsdp']],
['embed', ['fsdp']],
['embed_no_exp', ['fsdp']],
['embed_moe', ['fsdp']],
['embed_no_exp_moe', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['exp_with_fsdp', 'fsdp'],
Expand Down
11 changes: 6 additions & 5 deletions src/maxtext/configs/inference/inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ base_config: "base.yml"

logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
Expand All @@ -13,6 +12,7 @@ logical_axis_rules: [
['activation_norm_length', ['tensor_sequence', 'sequence']],
['activation_embed', ['tensor_transpose']],
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
Expand All @@ -27,17 +27,18 @@ logical_axis_rules: [
['decode_length', []],
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
['embed', ['fsdp', 'sequence', 'expert']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['layers', 'stage'],
['kv', []],
Expand Down
Loading
Loading