Skip to content

Commit 8faed51

Browse files
gobbleturkGoogle-ML-Automation
authored and committed
Ragged buffer size control
PiperOrigin-RevId: 885683593
1 parent 2a57a30 commit 8faed51

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

src/maxtext/configs/base.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ num_experts_per_tok: 1
191191
megablox: true
192192
sparse_matmul: true
193193
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
194+
ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
195+
# By default (-1), this buffer will be worst case size to ensure no dropping.
196+
# When set to 1.0, this buffer is set to the size assuming perfectly balanced routing. If the routing dictates
197+
# a size larger than this then tokens will be dropped.
198+
# In general, if ragged_buffer_factor>0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
194199
load_balance_loss_weight: 0.0 # weight for the load balance loss
195200
use_random_routing: false # whether to use random routing for debug/test purpose
196201
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,7 @@ class MoEGeneral(BaseModel):
641641
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
642642
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
643643
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
644+
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
644645
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
645646
use_custom_sort_vjp: bool = Field(
646647
True,

src/maxtext/layers/moe.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,14 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded):
887887
)
888888
return input_offsets, send_sizes, output_offsets, recv_sizes
889889

890+
def get_ragged_buffer_size(self, local_expert_size, local_batch):
  """Return the buffer size for ragged routed-MoE activations.

  If ``config.ragged_buffer_factor`` is positive, the buffer is the
  perfectly-balanced size (``local_batch``) scaled by that factor; per the
  config docs, routings needing more space than this will drop tokens.
  Otherwise (the default, -1), the buffer is worst-case sized so no tokens
  are ever dropped: every local token assigned to every local expert,
  capped at the per-token top-k.

  Args:
    local_expert_size: number of experts hosted on this shard.
    local_batch: number of token assignments local to this shard.

  Returns:
    The buffer size as a Python int.
  """
  factor = self.config.ragged_buffer_factor
  if factor > 0.0:
    # User-controlled size: balanced size scaled by the configured factor.
    return int(local_batch * factor)
  # Worst case: a token cannot be routed to more than num_experts_per_tok
  # experts, even when this shard holds more experts than that.
  per_token_cap = min(local_expert_size, self.config.num_experts_per_tok)
  return int(local_batch * per_token_cap)
897+
890898
def transform_bias(self, experts_index, *biases):
891899
"""Selects bias values for a variable number of bias tensors based on chosen experts."""
892900
return tuple(bias[experts_index] for bias in biases)
@@ -1180,8 +1188,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11801188
# In the worst case, all of the global input data is assigned to each expert in the current shard.
11811189
# This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
11821190
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
1183-
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
1184-
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
1191+
buffer_size = self.get_ragged_buffer_size(local_expert_size, jnp.shape(x)[0])
11851192
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
11861193

11871194
x = jax.lax.ragged_all_to_all(

0 commit comments

Comments
 (0)