From e4bd151fdcfc1711b3a1bc1ba4b5788e795e1fd6 Mon Sep 17 00:00:00 2001 From: Ran Ran Date: Fri, 3 Apr 2026 22:58:10 +0000 Subject: [PATCH] Internal test about EP PiperOrigin-RevId: 894290655 --- src/maxtext/configs/base.yml | 2 + src/maxtext/configs/types.py | 8 + src/maxtext/layers/moe.py | 705 +++++++++++++++++++++++++++++------ tests/unit/moe_test.py | 45 +++ 4 files changed, 641 insertions(+), 119 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index fb58aa79b4..82879fd95e 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -195,6 +195,8 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism +use_iterative_moe: false # whether to use iterative routing for sparse matmul to save memory +ra2a_num_chunks: 1 # number of chunks to split tokens into for iterative MoE # tunable tiling dimensions used for mlp gmm # megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`) # tokamax ragged dot - supports all 18 configs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 454a9f23f5..8e828caff9 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -650,6 +650,14 @@ class MoEGeneral(BaseModel): False, description="Whether to use Ring of Experts for sparse matmul expert parallelism.", ) + use_iterative_moe: bool = Field( + False, + description="Whether to use iterative MoE routing to save memory.", + ) + ra2a_num_chunks: int = Field( + 1, + description="Number of chunks for iterative MoE routing.", + ) use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.") interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.") expert_shard_attention_option: Literal["fsdp", "context"] = Field( diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 314c450b03..f8a8a966f5 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -891,134 +891,158 @@ def transform_bias(self, experts_index, *biases): """Selects bias values for a variable number of bias tensors based on chosen experts.""" return tuple(bias[experts_index] for bias in biases) - def sparse_matmul( + def _gmm( self, inputs, - gate_logits, - pre_bias_logits, - w0_kernel, - w1_kernel, - wo_kernel, - w0_bias, - w1_bias, - wo_bias, + kernel, + tiling, + group_sizes, + expert_assignments, + weight_gather_axes, + input_buffer_count, + combine_scopes, ): - """Perform sparse matrix multiplication of inputs and Experts.""" - - def gmm( - inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes + # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm + if ( + self.config.using_pipeline_parallelism + and self.config.pipeline_fsdp_ag_per_repeat ): - # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm - if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat: - tokamax_group_sizes = group_sizes - elif self.config.attention == "vllm_rpa": - tokamax_group_sizes = group_sizes + tokamax_group_sizes = group_sizes + elif self.config.attention == "vllm_rpa": + tokamax_group_sizes = group_sizes + else: + tokamax_group_sizes = tokamax.RaggedDotGroupSizes( + group_sizes, + max_utils.generate_representative_group_sizes( + inputs.shape[0], kernel.shape[0] + ), + ) + pad_length = self.config.wi_tile_fwd_batch_seq + hs_shape = inputs.shape + # pad length is the 1st dimension of tiling size in gmm call + if inputs.shape[0] != expert_assignments.shape[0]: + raise ValueError( + "The number of input tokens must match the number of expert" + " assignments!" + ) + padding_amount = 0 + if hs_shape[0] % pad_length: + padding_amount = pad_length - hs_shape[0] % pad_length + inputs = jax.lax.pad( + inputs, + jnp.array(0.0, dtype=inputs.dtype), + [(0, padding_amount, 0), (0, 0, 0)], + ) + + inputs = inputs.astype(self.dtype) + kernel = kernel.astype(self.dtype) + + lhs_quantize_dtype, rhs_quantize_dtype = None, None + if self.quant is not None: + quant_dg = self.quant.quant_dg + lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype() + rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype() + m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2] + if not self.config.megablox and not self.config.use_tokamax_gmm: + tiling = ( + min(tiling[0], m), + min(tiling[1], k), + min(tiling[2], n), + ) + if self.config.use_tokamax_gmm: + if self.config.quantization: + output = mblx.gmm( + lhs=inputs, + rhs=kernel, + group_sizes=group_sizes, + preferred_element_type=self.dtype, + tiling=tiling, + lhs_quantize_dtype=lhs_quantize_dtype, + rhs_quantize_dtype=rhs_quantize_dtype, + use_qwix_quantization=self.config.use_qwix_quantization, + use_tokamax_backend=self.config.use_tokamax_gmm, + weight_gather_axes=weight_gather_axes, + input_buffer_count=input_buffer_count, + combine_scopes=combine_scopes, + ) else: - tokamax_group_sizes = tokamax.RaggedDotGroupSizes( - group_sizes, - max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax_group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=self.dtype, + implementation="mosaic", ) - pad_length = self.config.wi_tile_fwd_batch_seq - hs_shape = inputs.shape - # pad length is the 1st dimension of tiling size in gmm call - if inputs.shape[0] != expert_assignments.shape[0]: - raise ValueError("The number of input tokens must match the number of expert" " assignments!") - padding_amount = 0 - if hs_shape[0] % pad_length: - padding_amount = pad_length - hs_shape[0] % pad_length - inputs = jax.lax.pad(inputs, jnp.array(0.0, dtype=inputs.dtype), [(0, padding_amount, 0), (0, 0, 0)]) - - inputs = inputs.astype(self.dtype) - kernel = kernel.astype(self.dtype) - - lhs_quantize_dtype, rhs_quantize_dtype = None, None - if self.quant is not None: - quant_dg = self.quant.quant_dg - lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype() - rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype() - m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2] - if not self.config.megablox and not self.config.use_tokamax_gmm: - tiling = ( - min(tiling[0], m), - min(tiling[1], k), - min(tiling[2], n), + else: + if self.config.megablox: + output = mblx.gmm( + lhs=inputs, + rhs=kernel, + group_sizes=group_sizes, + preferred_element_type=self.dtype, + tiling=tiling, + lhs_quantize_dtype=lhs_quantize_dtype, + rhs_quantize_dtype=rhs_quantize_dtype, + use_qwix_quantization=self.config.use_qwix_quantization, + use_tokamax_backend=self.config.use_tokamax_gmm, + weight_gather_axes=weight_gather_axes, ) - if self.config.use_tokamax_gmm: - if self.config.quantization: - output = mblx.gmm( - lhs=inputs, - rhs=kernel, - group_sizes=group_sizes, - preferred_element_type=self.dtype, - tiling=tiling, - lhs_quantize_dtype=lhs_quantize_dtype, - rhs_quantize_dtype=rhs_quantize_dtype, - use_qwix_quantization=self.config.use_qwix_quantization, - use_tokamax_backend=self.config.use_tokamax_gmm, - weight_gather_axes=weight_gather_axes, - input_buffer_count=input_buffer_count, - combine_scopes=combine_scopes, - ) - else: - output = tokamax.ragged_dot( - lhs=inputs, - rhs=kernel, - group_sizes=tokamax_group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=self.dtype, - implementation="mosaic", - ) else: - if self.config.megablox: - output = mblx.gmm( + rhs_inputs = kernel + if isinstance(kernel, aqt.QTensor): + if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1: + raise ValueError( + "Unsupported usecase for ragged_dot with quantized kernel." + ) + rhs_inputs = kernel.qvalue + if self.config.use_qwix_quantization: + # Use full contraction for QWIX quantization to allow quantization + # fusion (max reduce over contracting dimension). + tiling = (tiling[0], k, tiling[2]) + + is_tpu = self.mesh.devices.flat[0] == "tpu" + # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync + mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0" + with set_xla_metadata( + ragged_dot_tiling=",".join([str(t) for t in tiling]), + mosaic_fusion_group=mosaic_group_id, + ): + output = jax.lax.ragged_dot( lhs=inputs, - rhs=kernel, + rhs=rhs_inputs, group_sizes=group_sizes, preferred_element_type=self.dtype, - tiling=tiling, - lhs_quantize_dtype=lhs_quantize_dtype, - rhs_quantize_dtype=rhs_quantize_dtype, - use_qwix_quantization=self.config.use_qwix_quantization, - use_tokamax_backend=self.config.use_tokamax_gmm, - weight_gather_axes=weight_gather_axes, ) - else: - rhs_inputs = kernel - if isinstance(kernel, aqt.QTensor): - if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1: - raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.") - rhs_inputs = kernel.qvalue - if self.config.use_qwix_quantization: - # Use full contraction for QWIX quantization to allow quantization - # fusion (max reduce over contracting dimension). - tiling = (tiling[0], k, tiling[2]) - - is_tpu = self.mesh.devices.flat[0] == "tpu" - # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync - mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0" - with set_xla_metadata( - ragged_dot_tiling=",".join([str(t) for t in tiling]), - mosaic_fusion_group=mosaic_group_id, - ): - output = jax.lax.ragged_dot( - lhs=inputs, - rhs=rhs_inputs, - group_sizes=group_sizes, - preferred_element_type=self.dtype, + if isinstance(kernel, aqt.QTensor): + # Multiply outputs by the kernely scale + scales = jnp.take( + kernel.scale[0].squeeze(), indices=expert_assignments, axis=0 + ) + if padding_amount > 0: + scales = jax.lax.pad( + scales, + jnp.array(0.0, dtype=scales.dtype), + [(0, padding_amount, 0), (0, 0, 0)], ) - if isinstance(kernel, aqt.QTensor): - # Multiply outputs by the kernely scale - scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0) - if padding_amount > 0: - scales = jax.lax.pad( - scales, - jnp.array(0.0, dtype=scales.dtype), - [(0, padding_amount, 0), (0, 0, 0)], - ) - output *= scales - if padding_amount > 0: - output = output[: hs_shape[0]] - return output + output *= scales + if padding_amount > 0: + output = output[: hs_shape[0]] + return output + + def sparse_matmul( + self, + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + ): + """Perform sparse matrix multiplication of inputs and Experts.""" # Currently, we support data, tensor, and expert parallelism with Megablox. # We all gather the input activations over tensor parallelism to follow @@ -1237,7 +1261,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[0], 0)) wo_gather_axes.extend(get_active_sharding_axes(wo_pspec[1], 1)) gmm_fn = functools.partial( - gmm, + self._gmm, group_sizes=group_sizes, expert_assignments=selected_experts, ) @@ -1405,6 +1429,428 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): return output, lb_loss, bias_updates + if self.config.moe_fsdp_use_two_stage_all_gather: + # Unshard on fsdp axis + w0_kernel = self._maybe_shard_with_logical( + w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp") + ) + w1_kernel = self._maybe_shard_with_logical( + w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp") + ) + + # Unshard on fsdp_transpose axis + wo_kernel = self._maybe_shard_with_logical( + wo_kernel, ("exp_with_fsdp", "mlp", "embed_tensor_transpose") + ) + + # Make sure XLA does not optimize by combining above All-Gather to unshard + # on FSDP axis and the subsequent unshard on fsdp_transpose axis + w0_kernel = jax.lax.optimization_barrier(w0_kernel) + w1_kernel = jax.lax.optimization_barrier(w1_kernel) + wo_kernel = jax.lax.optimization_barrier(wo_kernel) + + # Unshard on both fsdp and fsdp_transpose transpose + w0_kernel = self._maybe_shard_with_logical( + w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp") + ) + w1_kernel = self._maybe_shard_with_logical( + w1_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp_no_fsdp") + ) + wo_kernel = self._maybe_shard_with_logical( + wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose") + ) + + if self.get_tensor_transpose_parallelism_size() > 1: + input_axes = ( + batch_logical_axis, + "activation_norm_length_moe", + "activation_embed_moe", + ) + else: + input_axes = (batch_logical_axis, "activation_norm_length_moe", None) + + gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None) + if self.config.model_name.startswith("deepseek3"): + pre_bias_logits_axes = ( + batch_logical_axis, + "activation_norm_length_moe", + None, + ) + else: + pre_bias_logits_axes = None + + inputs = self._maybe_shard_with_logical(inputs, input_axes) + gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes) + pre_bias_logits = self._maybe_shard_with_logical( + pre_bias_logits, pre_bias_logits_axes + ) + + return wrapper( + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + self.rngs, + ) + + def sparse_matmul_iterative( + self, + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + ): + """Perform sparse matrix multiplication of inputs and Experts iteratively.""" + # Currently, we support data, tensor, and expert parallelism with Megablox. + # We all gather the input activations over tensor parallelism to follow + # https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf. + + # Check if the batch should be sharded by expert and whether the batch_size + # supports this. For example, for interleaved inference, prefill always has + # batch_size=1 while decode can have batch_size > 1. + try: + is_batch_sharded_by_expert = ( + self._expert_parallelism_name + in tuple( + filter( + lambda tup: tup[0] == "activation_batch_moe", + self.config.logical_axis_rules, + ) + )[0][1] + ) + except: # pylint: disable=bare-except + is_batch_sharded_by_expert = False + if is_batch_sharded_by_expert and inputs.shape[0] > 1: + batch_logical_axis = "activation_batch_moe" + else: + batch_logical_axis = "activation_batch_no_exp_moe" + + if self.get_tensor_transpose_parallelism_size() > 1: + input_partition_pspec = self._logical_to_mesh_axes(( + batch_logical_axis, + "activation_norm_length_moe", + "activation_embed_moe", + )) + w0_bias_pspec = self._logical_to_mesh_axes(("exp", None)) + w1_bias_pspec = self._logical_to_mesh_axes(("exp", None)) + wo_bias_pspec = self._logical_to_mesh_axes( + ("exp", "activation_embed_moe") + ) + else: + input_partition_pspec = self._logical_to_mesh_axes( + (batch_logical_axis, "activation_norm_length_moe", None) + ) + w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) + w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) + wo_bias_pspec = self._logical_to_mesh_axes( + ("exp", "activation_embed_moe") + ) + + gate_logits_pspec = self._logical_to_mesh_axes( + (batch_logical_axis, "activation_norm_length_moe", None) + ) + if self.config.model_name.startswith("deepseek3"): + pre_bias_logits_pspec = self._logical_to_mesh_axes( + (batch_logical_axis, "activation_norm_length_moe", None) + ) + else: + # pre_bias_logits is None for non-DeepSeek v3 models + pre_bias_logits_pspec = None + + # w0, w1, wo needs to be un sharded on fsdp / fsdp_transpose axis, so use + # mlp_no_fsdp axis + weight_gather = False + if self.config.shard_exp_on_fsdp: + quantization_rule = qpl.get_current_rule("gmm") + if ( + quantization_rule + and quantization_rule.weight_calibration_method.startswith("fixed") + ): + # special sharding when using static scaling for weights in quantization with shard_exp_on_fsdp + w0_pspec = self._logical_to_mesh_axes(self.wi_kernel_axes) + w1_pspec = self._logical_to_mesh_axes(self.wi_kernel_axes) + wo_pspec = self._logical_to_mesh_axes(self.wo_kernel_axes) + weight_gather = True + else: + # special sharding for dsv3 to remove overhead between gmm/AG + w0_pspec = self._logical_to_mesh_axes( + ("embed_tensor_transpose", None, "mlp_no_fsdp") + ) + w1_pspec = self._logical_to_mesh_axes( + ("embed_tensor_transpose", None, "mlp_no_fsdp") + ) + wo_pspec = self._logical_to_mesh_axes( + ("embed_tensor_transpose", "mlp_no_fsdp", None) + ) + elif self.config.use_2d_fsdp_sharding: + w0_pspec = self._logical_to_mesh_axes( + ("embed_tensor_transpose", "mlp_no_fsdp", None) + ) + w1_pspec = self._logical_to_mesh_axes( + ("embed_tensor_transpose", "mlp_no_fsdp", None) + ) + wo_pspec = self._logical_to_mesh_axes( + ("embed_tensor_transpose", "mlp_no_fsdp", None) + ) + else: + w0_pspec = self._logical_to_mesh_axes( + ("exp", "embed_tensor_transpose", "mlp_no_fsdp") + ) + w1_pspec = self._logical_to_mesh_axes( + ("exp", "embed_tensor_transpose", "mlp_no_fsdp") + ) + wo_pspec = self._logical_to_mesh_axes( + ("exp", "mlp_no_fsdp", "embed_tensor_transpose") + ) + if isinstance(w0_kernel, aqt.QTensor): + w0_pspec = aqt.partition_spec( + w0_pspec, (1,), w0_kernel.dtype, use_bias=False + ) + if isinstance(w1_kernel, aqt.QTensor): + w1_pspec = aqt.partition_spec( + w1_pspec, (1,), w1_kernel.dtype, use_bias=False + ) + if isinstance(wo_kernel, aqt.QTensor): + wo_pspec = aqt.partition_spec( + wo_pspec, (1,), wo_kernel.dtype, use_bias=False + ) + + @functools.partial( + jax.shard_map, + mesh=self.mesh, + in_specs=( + input_partition_pspec, + gate_logits_pspec, + pre_bias_logits_pspec, + w0_pspec, + w1_pspec, + wo_pspec, + w0_bias_pspec, + w1_bias_pspec, + wo_bias_pspec, + P(), # Replicate the input key + ), + out_specs=( + self._logical_to_mesh_axes(( + batch_logical_axis, + "activation_norm_length_moe", + "activation_embed_moe", + )), + P(), # Handle None or replicate the output + P(), # Handle None or replicate the output + ), + check_vma=False, + ) + def wrapper( + x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs + ): + batch_size, sequence_length, _ = x.shape + num_expert_parallelism = self.get_expert_parallelism_size() + + expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name) + + # Permute tokens + ( + x_permuted, + sorted_selected_experts, + weights, + _, + selected_experts, + lb_loss, + bias_updates, + ) = self.permute( + x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs + ) + + num_chunks = self.config.ra2a_num_chunks + total_tokens = x_permuted.shape[0] + chunk_size = total_tokens // num_chunks + + output_accum = jnp.zeros( + ( + total_tokens, + self.config.emb_dim // self.get_tensor_parallelism_size(), + ), + dtype=x.dtype, + ) + + def loop_body(i, state): + output_accum = state + start_idx = i * chunk_size + + x_chunk = jax.lax.dynamic_slice_in_dim( + x_permuted, start_idx, chunk_size, axis=0 + ) + expert_assignments_chunk = jax.lax.dynamic_slice_in_dim( + selected_experts, start_idx, chunk_size, axis=0 + ) + + # Compute group sizes for this chunk + chunk_group_sizes = jnp.bincount( + expert_assignments_chunk, length=self.config.num_experts + ) + + # reshaped_group_sizes for this chunk + local_expert_size = ( + self.config.num_experts // self.get_expert_parallelism_size() + ) + reshaped_chunk_group_sizes = jnp.sum( + chunk_group_sizes.reshape(-1, local_expert_size), axis=1 + ) + + # all_gather to get full matrix for this chunk + all_shards_chunk_group_sizes = jax.lax.all_gather( + reshaped_chunk_group_sizes, + axis_name=self._expert_parallelism_name, + ) + + # get all_to_all params for this chunk + ( + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + ) = RoutedMoE.get_all_to_all_params( + all_shards_chunk_group_sizes, + expert_shard_id, + num_expert_parallelism, + ) + + # Receiver buffer size for this chunk + max_recv_size = num_expert_parallelism * chunk_size + output_shape = jnp.zeros( + (max_recv_size, self.config.emb_dim), dtype=x.dtype + ) + + # All-to-all chunk + x_recv = jax.lax.ragged_all_to_all( + x_chunk, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self._expert_parallelism_name, + ) + + # Local permute + global_chunk_group_sizes = jax.lax.all_gather( + chunk_group_sizes, axis_name=self._expert_parallelism_name + ) + + ( + x_local, + local_sorted_indices, + group_sizes_local, + selected_experts_local, + ) = RoutedMoE.local_permute( + x_recv, + global_chunk_group_sizes, + local_expert_size, + shard_index=expert_shard_id, + use_custom_sort_vjp=self.config.use_custom_sort_vjp, + ) + + # Call GMM + tiling0 = ( + self.config.wi_tile_fwd_batch_seq, + self.config.emb_dim, + self.config.moe_mlp_dim, + ) + tiling1 = ( + self.config.wo_tile_fwd_batch_seq, + self.config.moe_mlp_dim, + self.config.emb_dim, + ) + + # Rematerialize the GMM computation to save memory in backward pass + def gmm_compute(x_loc, gs_loc, se_loc): + out = self._gmm( + inputs=x_loc, + kernel=w0, + tiling=tiling0, + group_sizes=gs_loc, + expert_assignments=se_loc, + weight_gather_axes=[], + input_buffer_count=0, + combine_scopes=False, + ) + out = adc.checkpoint_name(out, "moe_mlpwi_0") + + out2 = self._gmm( + inputs=x_loc, + kernel=w1, + tiling=tiling0, + group_sizes=gs_loc, + expert_assignments=se_loc, + weight_gather_axes=[], + input_buffer_count=0, + combine_scopes=False, + ) + out2 = adc.checkpoint_name(out2, "moe_mlpwi_1") + + out = self.apply_ffn_activation(out, out2) + + out = self._gmm( + inputs=out, + kernel=wo, + tiling=tiling1, + group_sizes=gs_loc, + expert_assignments=se_loc, + weight_gather_axes=[], + input_buffer_count=0, + combine_scopes=False, + ) + out = adc.checkpoint_name(out, "moe_mlpwo") + return out + + output_chunk = jax.checkpoint(gmm_compute)( + x_local, group_sizes_local, selected_experts_local + ) + + # Send back + output_chunk_shape = jnp.zeros( + (chunk_size, self.config.emb_dim), dtype=x.dtype + ) + output_chunk_sent_back = jax.lax.ragged_all_to_all( + output_chunk, + output_chunk_shape, + output_offsets, + recv_sizes, + input_offsets, + send_sizes, + axis_name=self._expert_parallelism_name, + ) + + output_accum = jax.lax.dynamic_update_slice( + output_accum, output_chunk_sent_back, (start_idx, 0) + ) + return output_accum + + output_accum = jax.lax.fori_loop(0, num_chunks, loop_body, output_accum) + + # Unpermute + output = self.unpermute( + output_accum, + sorted_selected_experts, + weights, + batch_size=batch_size, + sequence_length=sequence_length, + use_custom_sort_vjp=self.config.use_custom_sort_vjp, + ) + + return output, lb_loss, bias_updates + if self.config.moe_fsdp_use_two_stage_all_gather: # Unshard on fsdp axis w0_kernel = self._maybe_shard_with_logical(w0_kernel, ("exp_with_fsdp", "embed_tensor_transpose", "mlp")) @@ -2053,9 +2499,30 @@ def __call__( w1_bias, wo_bias, ) - output, lb_loss, bias_updates = self.sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias - ) + if self.config.use_iterative_moe: + output, lb_loss, bias_updates = self.sparse_matmul_iterative( + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + ) + else: + output, lb_loss, bias_updates = self.sparse_matmul( + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, + w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + ) else: output, lb_loss, bias_updates = self.dense_matmul( inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index c3e83025f9..c268f19f11 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -509,6 +509,51 @@ def test_ragged_dot(self): actual_output, _, _ = self.get_moe_output(variables, hidden_states, cfg, mesh) self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False)) + @pytest.mark.tpu_only + def test_iterative_moe(self): + cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="moe_block_iterative_test", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + megablox=False, + sparse_matmul=True, + use_iterative_moe=True, + ra2a_num_chunks=2, + per_device_batch_size=1, + max_target_length=128, + ) + + rng = jax.random.PRNGKey(1234) + rng_model, rng_hidden_states = jax.random.split(rng) + device_count = jax.device_count() + hidden_states = jax.random.uniform( + rng_hidden_states, + ( + int(cfg.per_device_batch_size) * device_count, + cfg.max_target_length, + cfg.base_emb_dim, + ), + dtype=cfg.dtype, + ) + + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + variables, expected_output = self.get_expected_output( + rng_model, hidden_states, cfg, mesh + ) + actual_output = self.get_moe_output(variables, hidden_states, cfg, mesh) + self.assertTrue( + jax.numpy.allclose( + expected_output, + actual_output, + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) + @pytest.mark.tpu_only def test_dense(self): cfg = pyconfig.initialize(