From 12b2e74bf65c81712837063ff9dad025078eb054 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Sun, 22 Mar 2026 10:58:17 -0700 Subject: [PATCH] Enable fused MoE kernel for Qwen 3.5 MoE model Replace the compute-all-gather approach (ConditionalFeedForward with grouped nn.Linear) with FusedMoEExperts that calls the fused MoE Triton kernel directly. Expert weights are quantized to packed INT4 using HQQ (Half-Quadratic Quantization) scale-only optimization from torchao, separate from the tinygemm path used for attention and shared expert linears. For decode (M=1), only 8 of 256 experts' weights are loaded from HBM per layer (128x less memory traffic vs the old approach). Depends on the fused MoE Triton kernel (triton::fused_moe). --- examples/models/qwen3_5_moe/export.py | 86 +++++++++++++++- examples/models/qwen3_5_moe/model.py | 143 +++++++++----------------- 2 files changed, 134 insertions(+), 95 deletions(-) diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 3f6a874ec14..432b7c65f36 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE +from executorch.examples.models.qwen3_5_moe.model import FusedMoEExperts, Qwen35MoE # --------------------------------------------------------------------------- @@ -36,6 +36,10 @@ def load_and_quantize(args): ) if args.qlinear or args.qembedding: + if args.qlinear: + _quantize_experts_int4( + model, config, args.qlinear_group_size, use_hqq=args.hqq + ) _quantize(model, config, args) else: model.to(dtype=torch.bfloat16) @@ -43,6 +47,82 @@ def load_and_quantize(args): return model, config +def _quantize_experts_int4(model, config, group_size=32, use_hqq=False): + """Quantize expert weights to packed INT4 for the fused MoE kernel. + + Two quantization methods: + --hqq: HQQ (Half-Quadratic Quantization) iteratively refines scales + via least-squares for better accuracy (slower). + default: Standard min/max symmetric quantization (faster). + + Converts w1_weight [E, N, K] and w2_weight [E, N, K] to: + w1 [E, N, K//2] int8 packed, w1_scale [E, N, K//gs] bf16 + w2 [E, N, K//2] int8 packed, w2_scale [E, N, K//gs] bf16 + """ + if use_hqq: + from torchao.quantization.quant_primitives import ( + _choose_qparams_and_quantize_scale_only_hqq, + ) + else: + from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + ) + + method = "HQQ" if use_hqq else "min/max" + + for i, layer in enumerate(model.layers): + experts = layer.mlp.experts + if not isinstance(experts, FusedMoEExperts): + continue + + experts.group_size = group_size + for name in ("w1_weight", "w2_weight"): + w = getattr(experts, name).data.float() + E, N, K = w.shape + + if use_hqq: + qdata, scale = _choose_qparams_and_quantize_scale_only_hqq( + w.view(E * N, K), + block_size=[1, group_size], + qmin=-8, + qmax=7, + ) + int_data = qdata.to(torch.int8).view(E, N, K) + scale = scale.view(E, N, -1) + else: + block_size = (1, 1, group_size) + scale, zero_point = choose_qparams_affine( + w, MappingType.SYMMETRIC, block_size, + target_dtype=torch.int8, quant_min=-8, quant_max=7, + ) + int_data = quantize_affine( + w, block_size, scale, zero_point, + output_dtype=torch.int8, quant_min=-8, quant_max=7, + ) + scale = scale.reshape(E, N, -1) + + # Pack two int4 values per byte: even K -> low nibble, odd K -> high nibble + uint4 = (int_data + 8).to(torch.int16) # shift to unsigned [0, 15] + low = uint4[:, :, 0::2] + high = uint4[:, :, 1::2] + packed = (low | (high << 4)).to(torch.int8) # [E, N, K//2] + + buf_name = name.replace("_weight", "") + experts.register_buffer(buf_name, packed) + experts.register_buffer( + f"{buf_name}_scale", scale.to(torch.bfloat16) + ) + delattr(experts, name) + + print( + f" Quantized experts (INT4 {method}) layer {i + 1}/{config.num_hidden_layers}", + end="\r", + ) + print() + + def _to_device_skip_meta(module, device, dtype=None): """Move submodules to device, skipping any that have meta-device buffers. @@ -287,6 +367,10 @@ def main(): parser.add_argument( "--qembedding", default=None, choices=["8w"], help="Quantize embedding layers." ) + parser.add_argument( + "--hqq", action="store_true", + help="Use HQQ scale-only optimization for expert quantization (slower, better accuracy).", + ) args = parser.parse_args() # Register FLA Triton kernel diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 67c83d1a7f8..00c3deb20c9 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -399,72 +399,41 @@ def forward(self, x, input_pos): # --------------------------------------------------------------------------- -# MoE: stacked expert weights + index by top-k +# MoE: expert weights for fused MoE Triton kernel -# 16 experts per group keeps each nn.Linear under ~32K output features, -# within tinygemm int4 packing limits while keeping the graph small -# (32 matmul nodes per layer instead of 768 with per-expert linears). -_EXPERTS_PER_GROUP = 16 +class FusedMoEExperts(nn.Module): + """Expert weights stored as stacked tensors for the fused MoE Triton kernel. -class ConditionalFeedForward(nn.Module): - """Grouped expert weights as nn.Linear for quantization compatibility. + Before quantization: w1_weight [E, 2*inter, hidden] and w2_weight [E, hidden, inter] + are nn.Parameter tensors loaded from the checkpoint. - Experts are split into groups of _EXPERTS_PER_GROUP. Each group has: - gate_up_projs[g]: nn.Linear(hidden_size, G * intermediate_size * 2) - down_projs[g]: nn.Linear(intermediate_size, G * hidden_size) - This keeps each nn.Linear small enough for tinygemm int4 packing while - allowing quantize_model_() to handle them automatically. + After quantization (in export.py): replaced with packed INT4 buffers + w1 [E, 2*inter, hidden//2], w1_scale, w2 [E, hidden, inter//2], w2_scale. """ - def __init__(self, hidden_size, intermediate_size, num_experts): + def __init__(self, config): super().__init__() - self.num_experts = num_experts - self.intermediate_size = intermediate_size - self.hidden_size = hidden_size - G = _EXPERTS_PER_GROUP - assert num_experts % G == 0 - num_groups = num_experts // G - - self.gate_up_projs = nn.ModuleList( - [ - nn.Linear(hidden_size, G * intermediate_size * 2, bias=False) - for _ in range(num_groups) - ] + self.num_experts = config.num_experts + self.intermediate_size = config.moe_intermediate_size + self.hidden_size = config.hidden_size + self.group_size = 32 + + self.w1_weight = nn.Parameter( + torch.empty( + config.num_experts, + 2 * config.moe_intermediate_size, + config.hidden_size, + ) ) - self.down_projs = nn.ModuleList( - [ - nn.Linear(intermediate_size, G * hidden_size, bias=False) - for _ in range(num_groups) - ] + self.w2_weight = nn.Parameter( + torch.empty( + config.num_experts, + config.hidden_size, + config.moe_intermediate_size, + ) ) - def forward(self, x, expert_indices): - # x: (T, D), expert_indices: (T, top_k) - T = x.size(0) - top_k = expert_indices.size(1) - G = _EXPERTS_PER_GROUP - H = self.intermediate_size - D = self.hidden_size - - # Gate + Up: compute per-group, cat, gather top-k - gate_up_parts = [proj(x).view(T, G, 2, H) for proj in self.gate_up_projs] - gate_up = torch.cat(gate_up_parts, dim=1) # (T, E, 2, H) - - idx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2, H) - gate_up_sel = gate_up.gather(1, idx) # (T, top_k, 2, H) - intermediate = F.silu(gate_up_sel[:, :, 0, :]) * gate_up_sel[:, :, 1, :] - - # Down: compute per-group, cat, gather correct expert per slot - intermediate_flat = intermediate.reshape(T * top_k, H) - down_parts = [ - proj(intermediate_flat).view(T, top_k, G, D) for proj in self.down_projs - ] - all_down = torch.cat(down_parts, dim=2) # (T, top_k, E, D) - - eidx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, D) - return all_down.gather(2, eidx).squeeze(2) # (T, top_k, D) - class SwiGLU(nn.Module): """SwiGLU MLP for shared expert.""" @@ -484,12 +453,9 @@ class SparseMoE(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.cond_ffn = ConditionalFeedForward( - config.hidden_size, - config.moe_intermediate_size, - config.num_experts, - ) + self.experts = FusedMoEExperts(config) self.shared_expert = SwiGLU( config.hidden_size, config.shared_expert_intermediate_size ) @@ -503,8 +469,18 @@ def forward(self, x): expert_weights, expert_indices = torch.topk(scores, self.top_k, dim=-1) expert_weights = expert_weights.softmax(dim=-1) - expert_outs = self.cond_ffn(x_flat, expert_indices) - routed_out = torch.einsum("tai,ta->ti", expert_outs, expert_weights) + routed_out = torch.ops.triton.fused_moe( + x_flat, + self.experts.w1, + self.experts.w1_scale, + self.experts.w2, + self.experts.w2_scale, + expert_weights.float(), + expert_indices, + self.top_k, + self.num_experts, + self.experts.group_size, + ) shared_out = self.shared_expert(x_flat) shared_gate = torch.sigmoid(self.shared_expert_gate(x_flat)) @@ -641,9 +617,8 @@ def _load_and_remap_checkpoint(model_dir, config): expert_weights, ) - # Stack per-expert weights, split into groups, reshape for nn.Linear + # Stack per-expert weights into [E, N, K] tensors for FusedMoEExperts if expert_weights: - G = _EXPERTS_PER_GROUP for layer_idx in range(config.num_hidden_layers): gate_list = [ expert_weights.get((layer_idx, "gate", e)) @@ -661,21 +636,13 @@ def _load_and_remap_checkpoint(model_dir, config): if gate_list[0] is not None: w_gate = torch.stack(gate_list, dim=0) # (E, H, D) w_up = torch.stack(up_list, dim=0) - fused = torch.cat([w_gate, w_up], dim=1) # (E, 2*H, D) - num_groups = config.num_experts // G - for g in range(num_groups): - chunk = fused[g * G : (g + 1) * G] - state_dict[ - f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight" - ] = chunk.reshape(-1, chunk.size(-1)) + state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = torch.cat( + [w_gate, w_up], dim=1 + ) # (E, 2*H, D) if down_list[0] is not None: - w_down = torch.stack(down_list, dim=0) # (E, D, H) - num_groups = config.num_experts // G - for g in range(num_groups): - chunk = w_down[g * G : (g + 1) * G] - state_dict[ - f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight" - ] = chunk.reshape(-1, chunk.size(-1)) + state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = torch.stack( + down_list, dim=0 + ) # (E, D, H) del expert_weights # Handle tied embeddings @@ -697,27 +664,15 @@ def _process_checkpoint_key(ckpt_key, tensor, state_dict, expert_weights): if norm_key.startswith(("model.visual.", "model.mtp_")): return - # Fused expert weights: split into groups of _EXPERTS_PER_GROUP + # Fused expert weights: store directly as [E, N, K] for FusedMoEExperts m = _FUSED_EXPERT_RE.match(norm_key) if m: layer_idx = int(m.group(1)) proj_name = m.group(2) - G = _EXPERTS_PER_GROUP - num_groups = tensor.size(0) // G if proj_name == "gate_up_proj": - # (E, 2*H, D) → groups of (G, 2*H, D) → each (G*2*H, D) - for g in range(num_groups): - chunk = tensor[g * G : (g + 1) * G] - state_dict[ - f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight" - ] = chunk.reshape(-1, chunk.size(-1)).contiguous() + state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = tensor else: - # down_proj: (E, D, H) → groups of (G, D, H) → each (G*D, H) - for g in range(num_groups): - chunk = tensor[g * G : (g + 1) * G] - state_dict[f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight"] = ( - chunk.reshape(-1, chunk.size(-1)).contiguous() - ) + state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = tensor return # Per-expert weights