Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@
AxisIdxes = tuple[int, ...]

BATCH = "activation_batch"
BATCH_NO_EXP = "activation_batch_no_exp"

ATTN_LENGTH = "activation_attn_length"
ATTN_LENGTH_NO_EXP = "activation_attn_length_no_exp"

LENGTH = "activation_length"
LENGTH_NO_EXP = "activation_length_no_exp"
PREFILL_LENGTH = "prefill_activation_length"
Q_LENGTH = "activation_q_length"
Q_LENGTH_NO_EXP = "activation_q_length_no_exp"
Expand Down
17 changes: 6 additions & 11 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -450,22 +450,17 @@ mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'co
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_length', ['sequence', 'context', 'expert']],
['activation_length', ['context', 'expert']],
['activation_attn_length', ['sequence', 'context', 'expert']],
['activation_attn_length', ['context', 'expert']],
['activation_attn_length_no_exp', ['sequence', 'context']],
['activation_attn_length_no_exp', ['context']],
['activation_length_no_exp', ['sequence', 'context']],
['activation_length_no_exp', ['context']],
['activation_length_no_exp_moe', ['sequence', 'context']],
['activation_length_no_exp_moe', ['context']],
['activation_length', ['sequence', 'context']],
['activation_length', ['context']],
['activation_attn_length', ['sequence', 'context']],
['activation_attn_length', ['context']],
['activation_length_moe', ['sequence', 'context']],
['activation_length_moe', ['context']],
['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
['activation_q_length', ['context', 'expert']],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'expert']],
['activation_batch_moe', ['data', 'fsdp', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp']],
['activation_batch_no_exp_moe', ['data', 'fsdp']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
Expand Down
1 change: 0 additions & 1 deletion src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ mesh_axes: ['fsdp']
data_sharding: [['fsdp']]
logical_axis_rules: [
['activation_batch', ['fsdp']],
['activation_batch_no_exp', ['fsdp']],
['activation_batch_moe', ['fsdp']],
['activation_batch_no_exp_moe', ['fsdp']],
['activation_embed_and_logits_batch', ['fsdp']],
Expand Down
1 change: 0 additions & 1 deletion src/maxtext/configs/inference/inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ base_config: "base.yml"

logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
Expand Down
6 changes: 2 additions & 4 deletions src/maxtext/configs/inference/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,16 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
logical_axis_rules: [
['activation_batch', ['data']],
['activation_batch_moe', []],
['activation_batch_no_exp', []],
['activation_batch_no_exp_moe', []],
['activation_embed_and_logits_batch', ['data', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
['activation_heads', ['model', 'expert']],
['activation_kv_heads', ['model', 'expert']],
['activation_attn_length', ['expert']],
['activation_attn_length_no_exp', []],
['activation_length', ['data', 'expert']],
['activation_length', ['data']],
['activation_length_moe', ['data', 'expert']],
['activation_length_no_exp', 'data'],
['activation_length_no_exp_moe', 'data'],
['activation_length_moe', 'data'],
['activation_q_length', ['expert', 'attn_dp_expert']],
['activation_attn_embed', 'model'],
['activation_embed', ['model', 'attn_dp']],
Expand Down
1 change: 0 additions & 1 deletion src/maxtext/configs/post_train/rl_mt_jt.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ logical_axis_rules: [
['prefill_activation_length', ['data']],
['prefill_activation_norm_length', ['data']],
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
Expand Down
58 changes: 13 additions & 45 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
AxisIdxes,
AxisNames,
BATCH,
BATCH_NO_EXP,
CACHE_BATCH,
CACHE_BATCH_PREFILL,
CACHE_SEQUENCE,
Expand All @@ -53,12 +52,10 @@
HEAD,
Q_LORA_UP_PROJ,
KV_BATCH,
KV_BATCH_NO_EXP,
KV_HEAD,
KV_HEAD_DIM,
KV_LORA_UP_PROJ,
LENGTH,
LENGTH_NO_EXP,
MODEL_MODE_PREFILL,
MODEL_MODE_TRAIN,
PREFILL_KV_BATCH,
Expand Down Expand Up @@ -425,16 +422,11 @@ def mla_as_linen(
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
Expand Down Expand Up @@ -499,13 +491,8 @@ def mla_as_linen(
query_axis_names=query_axis_names,
key_axis_names=key_axis_names,
value_axis_names=value_axis_names,
ep_query_axis_names=ep_query_axis_names,
ep_key_axis_names=ep_key_axis_names,
ep_value_axis_names=ep_value_axis_names,
input_axis_names=input_axis_names,
ep_input_axis_names=ep_input_axis_names,
out_axis_names=out_axis_names,
ep_out_axis_names=ep_out_axis_names,
prefill_input_axis_names=prefill_input_axis_names,
decode_input_axis_names=decode_input_axis_names,
prefill_out_axis_names=prefill_out_axis_names,
Expand Down Expand Up @@ -573,16 +560,11 @@ def __init__(
prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM),
query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM),
ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM),
input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED),
ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED),
out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV),
ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV),
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM),
input_axis_names: AxisNames = (BATCH, LENGTH, EMBED),
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV),
prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED),
decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED),
prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV),
Expand Down Expand Up @@ -664,13 +646,8 @@ def __init__(
query_axis_names=query_axis_names,
key_axis_names=key_axis_names,
value_axis_names=value_axis_names,
ep_query_axis_names=ep_query_axis_names,
ep_key_axis_names=ep_key_axis_names,
ep_value_axis_names=ep_value_axis_names,
input_axis_names=input_axis_names,
ep_input_axis_names=ep_input_axis_names,
out_axis_names=out_axis_names,
ep_out_axis_names=ep_out_axis_names,
prefill_input_axis_names=prefill_input_axis_names,
decode_input_axis_names=decode_input_axis_names,
prefill_out_axis_names=prefill_out_axis_names,
Expand Down Expand Up @@ -882,12 +859,9 @@ def mla_query_projection(
if model_mode == MODEL_MODE_PREFILL:
query_logical_name = self.prefill_query_axis_names
wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
query_logical_name = self.ep_query_axis_names
wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ)
else:
query_logical_name = self.query_axis_names
wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ)
wqa_logical_name = (KV_BATCH, LENGTH, Q_LORA_UP_PROJ)
query_sharding = create_sharding(self.mesh, query_logical_name)
wqa_out_sharding = create_sharding(self.mesh, wqa_logical_name)
# Set softmax scaling.
Expand Down Expand Up @@ -1038,10 +1012,8 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
"""MLA key/value projection with integrated rotary embedding."""
if model_mode == MODEL_MODE_PREFILL:
wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ)
else:
wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
wka_logical_name = (KV_BATCH, LENGTH, KV_LORA_UP_PROJ)
wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
low_rank = checkpoint_name(low_rank, "kv_wa_proj")
Expand Down Expand Up @@ -1178,14 +1150,10 @@ def __call__(
inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names)
out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names)
out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV)
else:
inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
out_logical_name = (BATCH, LENGTH, HEAD, D_KV)

if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)
Expand Down
Loading
Loading