diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index ec1b6b4fe2..7b4d1fb341 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -73,6 +73,7 @@ # expert_shard_attention_option EP_AS_CONTEXT = "context" EP_AS_FSDP = "fsdp" +EP_AS_TP = "tp" DECODING_ACTIVE_SEQUENCE_INDICATOR = 1 diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 477ee223dc..d0d6459f85 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -231,8 +231,10 @@ norm_topk_prob: false # boolean to enable the top-k probability normalization. q # how the expert axis is used to shard attention weights and activations # "fsdp" (ep acts as fsdp parallelism) # "context" (ep acts as context parallelism, training only) +# "tp" (ep acts as tp for attn/shared exp, useful for autoregressive inference or training/prefill with very low token count) expert_shard_attention_option: "fsdp" + # when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls moe_fsdp_use_two_stage_all_gather: false # Shard the expert dimension of the MLP weights on the FSDP axis. 
diff --git a/src/maxtext/configs/custom_mesh_and_rule/ep-tp.yml b/src/maxtext/configs/custom_mesh_and_rule/ep-tp.yml new file mode 100644 index 0000000000..cd57045edc --- /dev/null +++ b/src/maxtext/configs/custom_mesh_and_rule/ep-tp.yml @@ -0,0 +1,57 @@ +mesh_axes: ['data', 'expert'] +data_sharding: [['data', 'expert']] +# currently only EP support, will add data + EP +logical_axis_rules: [ + ['activation_batch', ['data']], + ['activation_batch_moe', ['data']], + ['activation_batch_no_exp', ['data']], + ['activation_batch_no_exp_moe', ['data']], + ['activation_embed_and_logits_batch', ['data','expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'expert']], + ['activation_heads', ['expert']], + ['activation_kv_heads', ['expert']], + ['activation_attn_length', ['expert']], + ['activation_attn_length_no_exp', []], + ['activation_length', ['expert']], + ['activation_length_moe', ['expert']], + ['activation_length_no_exp', []], + ['activation_length_no_exp_moe', []], + ['activation_q_length', ['expert']], + ['activation_attn_embed', []], + ['activation_embed', ['expert']], # this says that the expert physical axes should act like TP for attention (and shared expert) activations + ['activation_embed_moe', ['expert']], # this is bad, it should be none, but we currently use this as an out_spec of wrapper when we should use activation_embed instead + ['activation_mlp', []], + ['activation_kv', []], + ['activation_prefill_kv_batch', ['data', 'expert']], + ['activation_kv_batch', ['data', 'expert']], + ['activation_kv_batch_no_exp', []], + ['activation_kv_head_dim', []], + ['activation_vocab', []], + ['activation_norm_length', []], + ['activation_norm_length_moe', []], + ['activation_exp', ['expert']], + ['decode_batch', ['data', 'expert']], + ['decode_length', []], + ['mlp', []], + ['mlp_no_fsdp', []], + ['moe_mlp', []], + ['vocab', []], + ['heads', []], + ['q_heads', ['expert']], + ['kv_heads', ['expert']], + ['kv_head_dim', []], + ['kv', []], + ['embed', 
['expert']], + ['embed', []], + ['embed_moe', ['expert']], + ['embed_moe', []], + ['embed_tensor_transpose', []], + ['embed_no_exp', []], + ['embed_no_exp_moe', []], + ['q_lora', ['expert']], + ['kv_lora', ['expert']], + ['norm', []], + ['cache_heads', []], + ['exp', ['expert']], + ['paged_kv_heads', []], + ] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index ed4836571b..b9717d28c5 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -636,7 +636,7 @@ class MoEGeneral(BaseModel): ) use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.") interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.") - expert_shard_attention_option: Literal["fsdp", "context"] = Field( + expert_shard_attention_option: Literal["fsdp", "context", "tp"] = Field( "fsdp", description="How the expert axis is used to shard attention weights and activations.", ) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 9d53d1149d..39999d8a6e 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1040,7 +1040,8 @@ def gmm( w1_bias_pspec = self._logical_to_mesh_axes(("exp", None)) wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) else: - input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + # TODO(review): reusing "activation_embed" shards the gmm input's embed dim whenever that logical axis maps to a mesh axis (EP-as-TP); a dedicated logical axis would avoid over-constraining non-TP configs — confirm intent. 
+ input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed")) w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) @@ -1117,11 +1118,19 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r # The ring-of-experts strategy first duplicates the inputs to all # expert shards, and then routes within each shard. - # Duplicate inputs to all expert shards. - x, logits, pre_bias_logits = tuple( - jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True) - for z in (x, logits, pre_bias_logits) - ) + def ag_activations(x, logits, pre_bias_logits): + if self.config.expert_shard_attention_option == ctypes.EP_AS_FSDP: + x, logits, pre_bias_logits = tuple( + jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, axis=0, tiled=True) + for z in (x, logits, pre_bias_logits) + ) + else: + # logits and pre_bias_logits are already replicated over EP, only AG x + x = jax.lax.all_gather(x, axis_name=self._expert_parallelism_name, axis=2, tiled=True) + return x, logits, pre_bias_logits + + x, logits, pre_bias_logits = ag_activations(x, logits, pre_bias_logits) + # "Route" tokens within each shard. num_experts_per_shard = self.config.num_experts // num_expert_parallelism @@ -1324,7 +1333,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): # Sum up the partial outputs across the expert shards. 
output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size())) - output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True) + if self.config.expert_shard_attention_option == ctypes.EP_AS_FSDP: + scatter_dimension=0 + elif self.config.expert_shard_attention_option == ctypes.EP_AS_CONTEXT: + scatter_dimension=1 + elif self.config.expert_shard_attention_option == ctypes.EP_AS_TP: + scatter_dimension=2 + output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=scatter_dimension, tiled=True) else: if num_expert_parallelism > 1: