@@ -887,6 +887,14 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded):
887887 )
888888 return input_offsets , send_sizes , output_offsets , recv_sizes
889889
def get_ragged_buffer_size(self, local_expert_size, local_batch):
    """Return the element capacity to reserve for the ragged all-to-all buffer.

    Args:
        local_expert_size: Number of experts hosted on this shard.
        local_batch: Number of token assignments in this shard's local batch.

    Returns:
        int: Buffer capacity. When ``config.ragged_buffer_factor`` is positive
        it overrides the worst-case sizing: the local batch scaled by that
        factor. Otherwise the worst case is used — every local token assigned
        to every local expert, capped at ``config.num_experts_per_tok``
        assignments per token.
    """
    factor = self.config.ragged_buffer_factor
    if factor > 0.0:
        # Explicit override: scale the balanced per-shard batch by the factor.
        return int(local_batch * factor)
    # Worst case: each token routed to min(local experts, experts-per-token).
    assignments_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
    return int(local_batch * assignments_per_tok)
897+
def transform_bias(self, experts_index, *biases):
    """Gather per-expert bias entries from each supplied bias tensor.

    Indexes every tensor in ``biases`` with ``experts_index`` and returns the
    selected values as a tuple, preserving the argument order.
    """
    selected = [bias_tensor[experts_index] for bias_tensor in biases]
    return tuple(selected)
@@ -1180,8 +1188,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11801188 # In the worst case, all of the global input data is assigned to each expert in the current shard.
11811189 # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
11821190 # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
1183- max_local_experts_per_tok = min (local_expert_size , self .config .num_experts_per_tok )
1184- buffer_size = int (num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok )
1191+ buffer_size = self .get_ragged_buffer_size (local_expert_size , jnp .shape (x )[0 ])
11851192 output_shape = jax .lax .empty ((buffer_size , self .config .emb_dim ), dtype = x .dtype )
11861193
11871194 x = jax .lax .ragged_all_to_all (
0 commit comments