Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
# expert_shard_attention_option
EP_AS_CONTEXT = "context"
EP_AS_FSDP = "fsdp"
EP_AS_TP = "tp"

DECODING_ACTIVE_SEQUENCE_INDICATOR = 1

Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ norm_topk_prob: false # boolean to enable the top-k probability normalization. q
# how the expert axis is used to shard attention weights and activations
# "fsdp" (ep acts as fsdp parallelism)
# "context" (ep acts as context parallelism, training only)
# "tp" (ep acts as tp for attn/shared exp, useful for autoregressive inference or training/prefill with very low token count)
expert_shard_attention_option: "fsdp"


# when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
moe_fsdp_use_two_stage_all_gather: false
# Shard the expert dimension of the MLP weights on the FSDP axis.
Expand Down
57 changes: 57 additions & 0 deletions src/maxtext/configs/custom_mesh_and_rule/ep-tp.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Custom mesh + logical-axis rules for expert_shard_attention_option: "tp" —
# the physical 'expert' axis is reused as tensor parallelism for attention and
# shared-expert weights/activations (see the "tp" option documented in base.yml).
mesh_axes: ['data', 'expert']
data_sharding: [['data', 'expert']]
# Currently only EP is supported; data + EP will be added later.
logical_axis_rules: [
['activation_batch', ['data']],
['activation_batch_moe', ['data']],
['activation_batch_no_exp', ['data']],
['activation_batch_no_exp_moe', ['data']],
['activation_embed_and_logits_batch', ['data','expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
['activation_heads', ['expert']],
['activation_kv_heads', ['expert']],
['activation_attn_length', ['expert']],
['activation_attn_length_no_exp', []],
['activation_length', ['expert']],
['activation_length_moe', ['expert']],
['activation_length_no_exp', []],
['activation_length_no_exp_moe', []],
['activation_q_length', ['expert']],
['activation_attn_embed', []],
['activation_embed', ['expert']], # the 'expert' physical axis acts like TP for attention (and shared-expert) activations
['activation_embed_moe', ['expert']], # NOTE(review): should be [] — currently used as the wrapper's out_spec where activation_embed should be used instead; TODO fix
['activation_mlp', []],
['activation_kv', []],
['activation_prefill_kv_batch', ['data', 'expert']],
['activation_kv_batch', ['data', 'expert']],
['activation_kv_batch_no_exp', []],
['activation_kv_head_dim', []],
['activation_vocab', []],
['activation_norm_length', []],
['activation_norm_length_moe', []],
['activation_exp', ['expert']],
['decode_batch', ['data', 'expert']],
['decode_length', []],
['mlp', []],
['mlp_no_fsdp', []],
['moe_mlp', []],
['vocab', []],
['heads', []],
['q_heads', ['expert']],
['kv_heads', ['expert']],
['kv_head_dim', []],
['kv', []],
# NOTE(review): 'embed' (and 'embed_moe' below) are listed twice — presumably
# an ordered fallback where the first rule whose mesh axes are free wins;
# confirm against the sharding-rule resolution logic.
['embed', ['expert']],
['embed', []],
['embed_moe', ['expert']],
['embed_moe', []],
['embed_tensor_transpose', []],
['embed_no_exp', []],
['embed_no_exp_moe', []],
['q_lora', ['expert']],
['kv_lora', ['expert']],
['norm', []],
['cache_heads', []],
['exp', ['expert']],
['paged_kv_heads', []],
]
2 changes: 1 addition & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ class MoEGeneral(BaseModel):
)
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
expert_shard_attention_option: Literal["fsdp", "context"] = Field(
expert_shard_attention_option: Literal["fsdp", "context", "tp"] = Field(
"fsdp",
description="How the expert axis is used to shard attention weights and activations.",
)
Expand Down
29 changes: 22 additions & 7 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,8 @@ def gmm(
w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
else:
input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
# This can be improved!
input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed"))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why this is terrible?

Copy link
Copy Markdown
Collaborator Author

@gobbleturk gobbleturk Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a lot of things wrong with our shardings. Any _moe rule should only be used deep inside the MoE layer, after tokens have been routed. At this point in the model/code, things should still be sharded like attention (tokens are not routed yet), so we should not use any _moe rule. Additionally, the weights below use "activation" logical axis rules, when "activation" should only be used for activations.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The resultant physical specs are probably what we want, but we arrived at them in a very poor way that is hard to read and maintain.

w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
Expand Down Expand Up @@ -1117,11 +1118,19 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
# The ring-of-experts strategy first duplicates the inputs to all
# expert shards, and then routes within each shard.

# Duplicate inputs to all expert shards.
x, logits, pre_bias_logits = tuple(
jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
for z in (x, logits, pre_bias_logits)
)
def ag_activations(x, logits, pre_bias_logits):
if self.config.expert_shard_attention_option == ctypes.EP_AS_FSDP:
x, logits, pre_bias_logits = tuple(
jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, axis=0, tiled=True)
for z in (x, logits, pre_bias_logits)
)
else:
# logits and pre_bias_logits are already replicated over EP, only AG x
x = jax.lax.all_gather(x, axis_name=self._expert_parallelism_name, axis=2, tiled=True)
return x, logits, pre_bias_logits

x, logits, pre_bias_logits = ag_activations(x, logits, pre_bias_logits)


# "Route" tokens within each shard.
num_experts_per_shard = self.config.num_experts // num_expert_parallelism
Expand Down Expand Up @@ -1324,7 +1333,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):

# Sum up the partial outputs across the expert shards.
output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
if self.config.expert_shard_attention_option == ctypes.EP_AS_FSDP:
scatter_dimension=0
elif self.config.expert_shard_attention_option == ctypes.EP_AS_CONTEXT:
scatter_dimension=1
elif self.config.expert_shard_attention_option == ctypes.EP_AS_TP:
scatter_dimension=2
output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=scatter_dimension, tiled=True)

else:
if num_expert_parallelism > 1:
Expand Down
Loading