diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index fb58aa79b4..442e67d70c 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -191,6 +191,11 @@ num_experts_per_tok: 1
 megablox: true
 sparse_matmul: true
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
+ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
+# By default (-1), this buffer will be worst case size to ensure no dropping.
+# When set to 1.0 this buffer is set to the size assuming perfectly balanced. If the routing dictates
+# a size larger than this then tokens will be dropped.
+# In general if ragged_buffer_factor>0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
 load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 454a9f23f5..d705a3be5c 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -641,6 +641,7 @@ class MoEGeneral(BaseModel):
   num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
   num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
+  ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
   load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
   use_custom_sort_vjp: bool = Field(
       True,
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 314c450b03..a66870fa4b 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -887,6 +887,14 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded):
     )
     return input_offsets, send_sizes, output_offsets, recv_sizes
 
+  def get_ragged_buffer_size(self, local_expert_size, local_batch):
+    if self.config.ragged_buffer_factor > 0.0:
+      balanced_size = local_batch
+      return int(balanced_size * self.config.ragged_buffer_factor)
+    else:
+      max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
+      return int(local_batch * max_local_experts_per_tok)
+
   def transform_bias(self, experts_index, *biases):
     """Selects bias values for a variable number of bias tensors based on chosen experts."""
     return tuple(bias[experts_index] for bias in biases)
@@ -1180,8 +1188,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       # In the worst case, all of the global input data is assigned to each expert in the current shard.
       # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
       # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
-      max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
-      buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
+      buffer_size = self.get_ragged_buffer_size(local_expert_size, jnp.shape(x)[0])
       output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
       x = jax.lax.ragged_all_to_all(