From 0f91addba77e8b6d2dfb373cc1c1fb59604401c3 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Fri, 3 Apr 2026 02:24:15 +0000 Subject: [PATCH] add ragged buffer --- src/maxtext/configs/base.yml | 5 +++++ .../custom_mesh_and_rule/pipeline-large-moe.yml | 4 ++-- src/maxtext/configs/types.py | 1 + src/maxtext/layers/moe.py | 11 +++++++++-- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index fb58aa79b4..e0bc5a6b84 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -191,6 +191,11 @@ num_experts_per_tok: 1 megablox: true sparse_matmul: true capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default +ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations. +# By default (-1), this buffer will be worst case size to ensure no dropping. +# When set to 1.0 this buffer is set to the size assuming perfectly balanced routing. If the routing dictates +# a size larger than this then tokens will be dropped. +# In general if ragged_buffer_factor>0, the ragged_buffer_size is balanced_size * ragged_buffer_factor. 
load_balance_loss_weight: 0.0 # weight for the load balance loss use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index 8209dece2d..7ec1edda14 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -59,8 +59,8 @@ logical_axis_rules: [ ['heads', ['tensor']], ['q_heads', ['tensor']], ['kv_heads', ['tensor']], - ['embed', ['fsdp', 'expert']], # remove context from embed sharding - ['embed_moe', ['fsdp', 'expert']], + ['embed', ['fsdp']], # remove context from embed sharding + ['embed_moe', ['fsdp']], ['embed_no_exp', ['fsdp']], ['embed_no_exp_moe', ['fsdp']], ['q_lora', ['fsdp']], diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 8f1f62922b..d3caaef6a0 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -640,6 +640,7 @@ class MoEGeneral(BaseModel): num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.") num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.") + ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.") capacity_factor: float = Field(-1.0, description="Expert capacity factor. 
If < 0, no token dropping.") load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.") use_custom_sort_vjp: bool = Field( diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 5f6570bc88..a008f7816d 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -887,6 +887,14 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded): ) return input_offsets, send_sizes, output_offsets, recv_sizes + def get_ragged_buffer_size(self, local_expert_size, local_batch): + if self.config.ragged_buffer_factor > 0.0: + balanced_size = local_batch + return int(balanced_size * self.config.ragged_buffer_factor) + else: + max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) + return int(local_batch * max_local_experts_per_tok) + def transform_bias(self, experts_index, *biases): """Selects bias values for a variable number of bias tensors based on chosen experts.""" return tuple(bias[experts_index] for bias in biases) @@ -1180,8 +1188,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r # In the worst case, all of the global input data is assigned to each expert in the current shard. # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs. - max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) - buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok) + buffer_size = self.get_ragged_buffer_size(local_expert_size, jnp.shape(x)[0]) output_shape = jnp.zeros((buffer_size, self.config.emb_dim), dtype=x.dtype) x = jax.lax.ragged_all_to_all(