-
Notifications
You must be signed in to change notification settings - Fork 891
Enable fused MoE kernel for Qwen 3.5 MoE model #18388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,3 +1,3 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Export Qwen 3.5 MoE to ExecuTorch .pte format (CUDA only). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -12,7 +12,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from executorch.examples.models.qwen3_5_moe.model import FusedMoEExperts, Qwen35MoE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # --------------------------------------------------------------------------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -36,13 +36,93 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.qlinear or args.qembedding: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.qlinear: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _quantize_experts_int4( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model, config, args.qlinear_group_size, use_hqq=args.hqq | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _quantize(model, config, args) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
37
to
44
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.qlinear or args.qembedding: | |
| if args.qlinear: | |
| _quantize_experts_int4(model, config, args.qlinear_group_size) | |
| _quantize(model, config, args) | |
| else: | |
| # Detect whether the model is using the fused MoE experts implementation. | |
| uses_fused_experts = any( | |
| isinstance(layer.mlp.experts, FusedMoEExperts) for layer in model.layers | |
| ) | |
| if args.qlinear or args.qembedding: | |
| # The fused MoE kernel expects packed INT4 expert weights. If fused | |
| # experts are present, we require --qlinear so that | |
| # _quantize_experts_int4() runs and prepares the expected buffers. | |
| if uses_fused_experts and not args.qlinear: | |
| raise RuntimeError( | |
| "Exporting a model with FusedMoEExperts requires --qlinear to " | |
| "quantize and pack expert weights for the fused MoE kernel. " | |
| "Please rerun export with --qlinear (e.g. --qlinear 4w)." | |
| ) | |
| if args.qlinear: | |
| _quantize_experts_int4(model, config, args.qlinear_group_size) | |
| _quantize(model, config, args) | |
| else: | |
| # No quantization flags were provided. If fused experts are used, we | |
| # cannot safely run the fused MoE kernel without packed INT4 weights. | |
| if uses_fused_experts: | |
| raise RuntimeError( | |
| "Exporting a model with FusedMoEExperts without quantization " | |
| "is not supported. Please rerun export with --qlinear (e.g. " | |
| "--qlinear 4w)." | |
| ) |
Copilot
AI
Mar 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
w = getattr(experts, name).data.float() uses .data, which bypasses autograd safety checks and is generally unsafe. Use getattr(experts, name).detach() (optionally under torch.no_grad()) instead to avoid accidental in-place/autograd issues.
| w = getattr(experts, name).data.float() | |
| w = getattr(experts, name).detach().float() |
Copilot
AI
Mar 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
INT4 packing assumes K is even (and that K is compatible with group_size for per-group scales). If K is odd, low/high shapes differ and packing will error; if K % group_size != 0, scale shapes won’t match the documented [E, N, K//gs]. Please add explicit validation (e.g., assert K % 2 == 0 and assert K % group_size == 0 with a clear error) before packing.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -399,72 +399,41 @@ def forward(self, x, input_pos): | |
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # MoE: stacked expert weights + index by top-k | ||
| # MoE: expert weights for fused MoE Triton kernel | ||
|
|
||
| # 16 experts per group keeps each nn.Linear under ~32K output features, | ||
| # within tinygemm int4 packing limits while keeping the graph small | ||
| # (32 matmul nodes per layer instead of 768 with per-expert linears). | ||
| _EXPERTS_PER_GROUP = 16 | ||
|
|
||
| class FusedMoEExperts(nn.Module): | ||
| """Expert weights stored as stacked tensors for the fused MoE Triton kernel. | ||
|
|
||
| class ConditionalFeedForward(nn.Module): | ||
| """Grouped expert weights as nn.Linear for quantization compatibility. | ||
| Before quantization: w1_weight [E, 2*inter, hidden] and w2_weight [E, hidden, inter] | ||
| are nn.Parameter tensors loaded from the checkpoint. | ||
|
|
||
| Experts are split into groups of _EXPERTS_PER_GROUP. Each group has: | ||
| gate_up_projs[g]: nn.Linear(hidden_size, G * intermediate_size * 2) | ||
| down_projs[g]: nn.Linear(intermediate_size, G * hidden_size) | ||
| This keeps each nn.Linear small enough for tinygemm int4 packing while | ||
| allowing quantize_model_() to handle them automatically. | ||
| After quantization (in export.py): replaced with packed INT4 buffers | ||
| w1 [E, 2*inter, hidden//2], w1_scale, w2 [E, hidden, inter//2], w2_scale. | ||
| """ | ||
|
|
||
| def __init__(self, hidden_size, intermediate_size, num_experts): | ||
| def __init__(self, config): | ||
| super().__init__() | ||
| self.num_experts = num_experts | ||
| self.intermediate_size = intermediate_size | ||
| self.hidden_size = hidden_size | ||
| G = _EXPERTS_PER_GROUP | ||
| assert num_experts % G == 0 | ||
| num_groups = num_experts // G | ||
|
|
||
| self.gate_up_projs = nn.ModuleList( | ||
| [ | ||
| nn.Linear(hidden_size, G * intermediate_size * 2, bias=False) | ||
| for _ in range(num_groups) | ||
| ] | ||
| self.num_experts = config.num_experts | ||
| self.intermediate_size = config.moe_intermediate_size | ||
| self.hidden_size = config.hidden_size | ||
| self.group_size = 32 | ||
|
|
||
| self.w1_weight = nn.Parameter( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @claude what are w1 and w2 here
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still not working :(. |
||
| torch.empty( | ||
| config.num_experts, | ||
| 2 * config.moe_intermediate_size, | ||
| config.hidden_size, | ||
| ) | ||
| ) | ||
| self.down_projs = nn.ModuleList( | ||
| [ | ||
| nn.Linear(intermediate_size, G * hidden_size, bias=False) | ||
| for _ in range(num_groups) | ||
| ] | ||
| self.w2_weight = nn.Parameter( | ||
| torch.empty( | ||
| config.num_experts, | ||
| config.hidden_size, | ||
| config.moe_intermediate_size, | ||
| ) | ||
| ) | ||
|
Comment on lines
+405
to
435
|
||
|
|
||
| def forward(self, x, expert_indices): | ||
| # x: (T, D), expert_indices: (T, top_k) | ||
| T = x.size(0) | ||
| top_k = expert_indices.size(1) | ||
| G = _EXPERTS_PER_GROUP | ||
| H = self.intermediate_size | ||
| D = self.hidden_size | ||
|
|
||
| # Gate + Up: compute per-group, cat, gather top-k | ||
| gate_up_parts = [proj(x).view(T, G, 2, H) for proj in self.gate_up_projs] | ||
| gate_up = torch.cat(gate_up_parts, dim=1) # (T, E, 2, H) | ||
|
|
||
| idx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2, H) | ||
| gate_up_sel = gate_up.gather(1, idx) # (T, top_k, 2, H) | ||
| intermediate = F.silu(gate_up_sel[:, :, 0, :]) * gate_up_sel[:, :, 1, :] | ||
|
|
||
| # Down: compute per-group, cat, gather correct expert per slot | ||
| intermediate_flat = intermediate.reshape(T * top_k, H) | ||
| down_parts = [ | ||
| proj(intermediate_flat).view(T, top_k, G, D) for proj in self.down_projs | ||
| ] | ||
| all_down = torch.cat(down_parts, dim=2) # (T, top_k, E, D) | ||
|
|
||
| eidx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, D) | ||
| return all_down.gather(2, eidx).squeeze(2) # (T, top_k, D) | ||
|
|
||
|
|
||
| class SwiGLU(nn.Module): | ||
| """SwiGLU MLP for shared expert.""" | ||
|
|
@@ -484,12 +453,9 @@ class SparseMoE(nn.Module): | |
| def __init__(self, config): | ||
| super().__init__() | ||
| self.top_k = config.num_experts_per_tok | ||
| self.num_experts = config.num_experts | ||
| self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) | ||
| self.cond_ffn = ConditionalFeedForward( | ||
| config.hidden_size, | ||
| config.moe_intermediate_size, | ||
| config.num_experts, | ||
| ) | ||
| self.experts = FusedMoEExperts(config) | ||
| self.shared_expert = SwiGLU( | ||
| config.hidden_size, config.shared_expert_intermediate_size | ||
| ) | ||
|
|
@@ -503,8 +469,18 @@ def forward(self, x): | |
| expert_weights, expert_indices = torch.topk(scores, self.top_k, dim=-1) | ||
| expert_weights = expert_weights.softmax(dim=-1) | ||
|
|
||
| expert_outs = self.cond_ffn(x_flat, expert_indices) | ||
| routed_out = torch.einsum("tai,ta->ti", expert_outs, expert_weights) | ||
| routed_out = torch.ops.triton.fused_moe( | ||
| x_flat, | ||
| self.experts.w1, | ||
| self.experts.w1_scale, | ||
| self.experts.w2, | ||
|
Comment on lines
+472
to
+476
|
||
| self.experts.w2_scale, | ||
| expert_weights.float(), | ||
| expert_indices, | ||
| self.top_k, | ||
| self.num_experts, | ||
| self.experts.group_size, | ||
| ) | ||
|
|
||
| shared_out = self.shared_expert(x_flat) | ||
| shared_gate = torch.sigmoid(self.shared_expert_gate(x_flat)) | ||
|
|
@@ -641,9 +617,8 @@ def _load_and_remap_checkpoint(model_dir, config): | |
| expert_weights, | ||
| ) | ||
|
|
||
| # Stack per-expert weights, split into groups, reshape for nn.Linear | ||
| # Stack per-expert weights into [E, N, K] tensors for FusedMoEExperts | ||
| if expert_weights: | ||
| G = _EXPERTS_PER_GROUP | ||
| for layer_idx in range(config.num_hidden_layers): | ||
| gate_list = [ | ||
| expert_weights.get((layer_idx, "gate", e)) | ||
|
|
@@ -661,21 +636,13 @@ def _load_and_remap_checkpoint(model_dir, config): | |
| if gate_list[0] is not None: | ||
| w_gate = torch.stack(gate_list, dim=0) # (E, H, D) | ||
| w_up = torch.stack(up_list, dim=0) | ||
| fused = torch.cat([w_gate, w_up], dim=1) # (E, 2*H, D) | ||
| num_groups = config.num_experts // G | ||
| for g in range(num_groups): | ||
| chunk = fused[g * G : (g + 1) * G] | ||
| state_dict[ | ||
| f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight" | ||
| ] = chunk.reshape(-1, chunk.size(-1)) | ||
| state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = torch.cat( | ||
| [w_gate, w_up], dim=1 | ||
| ) # (E, 2*H, D) | ||
| if down_list[0] is not None: | ||
| w_down = torch.stack(down_list, dim=0) # (E, D, H) | ||
| num_groups = config.num_experts // G | ||
| for g in range(num_groups): | ||
| chunk = w_down[g * G : (g + 1) * G] | ||
| state_dict[ | ||
| f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight" | ||
| ] = chunk.reshape(-1, chunk.size(-1)) | ||
| state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = torch.stack( | ||
| down_list, dim=0 | ||
| ) # (E, D, H) | ||
| del expert_weights | ||
|
|
||
| # Handle tied embeddings | ||
|
|
@@ -697,27 +664,15 @@ def _process_checkpoint_key(ckpt_key, tensor, state_dict, expert_weights): | |
| if norm_key.startswith(("model.visual.", "model.mtp_")): | ||
| return | ||
|
|
||
| # Fused expert weights: split into groups of _EXPERTS_PER_GROUP | ||
| # Fused expert weights: store directly as [E, N, K] for FusedMoEExperts | ||
| m = _FUSED_EXPERT_RE.match(norm_key) | ||
| if m: | ||
| layer_idx = int(m.group(1)) | ||
| proj_name = m.group(2) | ||
| G = _EXPERTS_PER_GROUP | ||
| num_groups = tensor.size(0) // G | ||
| if proj_name == "gate_up_proj": | ||
| # (E, 2*H, D) → groups of (G, 2*H, D) → each (G*2*H, D) | ||
| for g in range(num_groups): | ||
| chunk = tensor[g * G : (g + 1) * G] | ||
| state_dict[ | ||
| f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight" | ||
| ] = chunk.reshape(-1, chunk.size(-1)).contiguous() | ||
| state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = tensor | ||
| else: | ||
| # down_proj: (E, D, H) → groups of (G, D, H) → each (G*D, H) | ||
| for g in range(num_groups): | ||
| chunk = tensor[g * G : (g + 1) * G] | ||
| state_dict[f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight"] = ( | ||
| chunk.reshape(-1, chunk.size(-1)).contiguous() | ||
| ) | ||
| state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = tensor | ||
| return | ||
|
|
||
| # Per-expert weights | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: do this inside _quantize