Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""
Export Qwen 3.5 MoE to ExecuTorch .pte format (CUDA only).

Expand All @@ -12,7 +12,7 @@
import torch
import torch.nn as nn

from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE
from executorch.examples.models.qwen3_5_moe.model import FusedMoEExperts, Qwen35MoE


# ---------------------------------------------------------------------------
Expand All @@ -36,13 +36,93 @@
)

if args.qlinear or args.qembedding:
if args.qlinear:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do this inside _quantize

_quantize_experts_int4(
model, config, args.qlinear_group_size, use_hqq=args.hqq
)
_quantize(model, config, args)
else:
Comment on lines 37 to 44
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_and_quantize() only calls _quantize_experts_int4() when args.qlinear is set, but the model’s MoE path expects packed experts.w1/w2 buffers. Running export with no quantization (or with --qembedding only) will fail once SparseMoE.forward is traced/executed. Either always quantize experts when using the fused kernel, add a fallback for unquantized experts, or explicitly error out when --qlinear is not provided.

Suggested change
if args.qlinear or args.qembedding:
if args.qlinear:
_quantize_experts_int4(model, config, args.qlinear_group_size)
_quantize(model, config, args)
else:
# Detect whether the model is using the fused MoE experts implementation.
uses_fused_experts = any(
isinstance(layer.mlp.experts, FusedMoEExperts) for layer in model.layers
)
if args.qlinear or args.qembedding:
# The fused MoE kernel expects packed INT4 expert weights. If fused
# experts are present, we require --qlinear so that
# _quantize_experts_int4() runs and prepares the expected buffers.
if uses_fused_experts and not args.qlinear:
raise RuntimeError(
"Exporting a model with FusedMoEExperts requires --qlinear to "
"quantize and pack expert weights for the fused MoE kernel. "
"Please rerun export with --qlinear (e.g. --qlinear 4w)."
)
if args.qlinear:
_quantize_experts_int4(model, config, args.qlinear_group_size)
_quantize(model, config, args)
else:
# No quantization flags were provided. If fused experts are used, we
# cannot safely run the fused MoE kernel without packed INT4 weights.
if uses_fused_experts:
raise RuntimeError(
"Exporting a model with FusedMoEExperts without quantization "
"is not supported. Please rerun export with --qlinear (e.g. "
"--qlinear 4w)."
)

Copilot uses AI. Check for mistakes.
model.to(dtype=torch.bfloat16)

return model, config


def _quantize_experts_int4(model, config, group_size=32, use_hqq=False):
    """Quantize expert weights to packed INT4 for the fused MoE kernel.

    Two quantization methods:
        --hqq:   HQQ (Half-Quadratic Quantization) iteratively refines scales
                 via least-squares for better accuracy (slower).
        default: Standard min/max symmetric quantization (faster).

    Converts w1_weight [E, N, K] and w2_weight [E, N, K] to:
        w1 [E, N, K//2] int8 packed, w1_scale [E, N, K//gs] bf16
        w2 [E, N, K//2] int8 packed, w2_scale [E, N, K//gs] bf16

    Args:
        model: Qwen35MoE instance whose layers may contain FusedMoEExperts.
        config: model config (used only for progress reporting).
        group_size: number of K elements sharing one scale. K must be a
            multiple of this.
        use_hqq: select the HQQ scale-only path instead of min/max.

    Raises:
        ValueError: if a weight's K dimension is odd (cannot nibble-pack)
            or not divisible by ``group_size`` (scale shape would not match
            the documented [E, N, K//gs] layout).
    """
    if use_hqq:
        from torchao.quantization.quant_primitives import (
            _choose_qparams_and_quantize_scale_only_hqq,
        )
    else:
        from torchao.quantization.quant_primitives import (
            choose_qparams_affine,
            MappingType,
            quantize_affine,
        )

    method = "HQQ" if use_hqq else "min/max"

    for i, layer in enumerate(model.layers):
        experts = layer.mlp.experts
        # Only fused-expert modules carry stacked [E, N, K] weights; other
        # MLP implementations are left untouched.
        if not isinstance(experts, FusedMoEExperts):
            continue

        experts.group_size = group_size
        for name in ("w1_weight", "w2_weight"):
            # detach() instead of .data: same no-grad view of the storage
            # without bypassing autograd's safety checks.
            w = getattr(experts, name).detach().float()
            E, N, K = w.shape

            # Validate packing preconditions up front with clear errors
            # rather than failing later with a shape mismatch.
            if K % 2 != 0:
                raise ValueError(
                    f"{name}: K={K} must be even to pack two INT4 values per byte"
                )
            if K % group_size != 0:
                raise ValueError(
                    f"{name}: K={K} must be divisible by group_size={group_size} "
                    f"for per-group scales"
                )

            if use_hqq:
                # HQQ operates on 2-D [rows, K] with per-row groups of
                # `group_size`; fold E and N into the row dimension.
                qdata, scale = _choose_qparams_and_quantize_scale_only_hqq(
                    w.view(E * N, K),
                    block_size=[1, group_size],
                    qmin=-8,
                    qmax=7,
                )
                int_data = qdata.to(torch.int8).view(E, N, K)
                scale = scale.view(E, N, -1)
            else:
                block_size = (1, 1, group_size)
                scale, zero_point = choose_qparams_affine(
                    w, MappingType.SYMMETRIC, block_size,
                    target_dtype=torch.int8, quant_min=-8, quant_max=7,
                )
                int_data = quantize_affine(
                    w, block_size, scale, zero_point,
                    output_dtype=torch.int8, quant_min=-8, quant_max=7,
                )
                scale = scale.reshape(E, N, -1)

            # Pack two int4 values per byte: even K -> low nibble, odd K -> high nibble
            uint4 = (int_data + 8).to(torch.int16)  # shift to unsigned [0, 15]
            low = uint4[:, :, 0::2]
            high = uint4[:, :, 1::2]
            packed = (low | (high << 4)).to(torch.int8)  # [E, N, K//2]

            # Replace the float parameter with packed buffers the fused
            # Triton kernel reads (w1/w1_scale, w2/w2_scale).
            buf_name = name.replace("_weight", "")
            experts.register_buffer(buf_name, packed)
            experts.register_buffer(
                f"{buf_name}_scale", scale.to(torch.bfloat16)
            )
            delattr(experts, name)

        print(
            f"  Quantized experts (INT4 {method}) layer {i + 1}/{config.num_hidden_layers}",
            end="\r",
        )
    print()


def _to_device_skip_meta(module, device, dtype=None):
"""Move submodules to device, skipping any that have meta-device buffers.

Expand Down Expand Up @@ -287,6 +367,10 @@
parser.add_argument(
"--qembedding", default=None, choices=["8w"], help="Quantize embedding layers."
)
parser.add_argument(
"--hqq", action="store_true",
help="Use HQQ scale-only optimization for expert quantization (slower, better accuracy).",
)
args = parser.parse_args()

# Register FLA Triton kernel
Expand Down
143 changes: 49 additions & 94 deletions examples/models/qwen3_5_moe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,72 +399,41 @@ def forward(self, x, input_pos):


# ---------------------------------------------------------------------------
# MoE: stacked expert weights + index by top-k
# MoE: expert weights for fused MoE Triton kernel

# 16 experts per group keeps each nn.Linear under ~32K output features,
# within tinygemm int4 packing limits while keeping the graph small
# (32 matmul nodes per layer instead of 768 with per-expert linears).
_EXPERTS_PER_GROUP = 16

class FusedMoEExperts(nn.Module):
"""Expert weights stored as stacked tensors for the fused MoE Triton kernel.

class ConditionalFeedForward(nn.Module):
"""Grouped expert weights as nn.Linear for quantization compatibility.
Before quantization: w1_weight [E, 2*inter, hidden] and w2_weight [E, hidden, inter]
are nn.Parameter tensors loaded from the checkpoint.

Experts are split into groups of _EXPERTS_PER_GROUP. Each group has:
gate_up_projs[g]: nn.Linear(hidden_size, G * intermediate_size * 2)
down_projs[g]: nn.Linear(intermediate_size, G * hidden_size)
This keeps each nn.Linear small enough for tinygemm int4 packing while
allowing quantize_model_() to handle them automatically.
After quantization (in export.py): replaced with packed INT4 buffers
w1 [E, 2*inter, hidden//2], w1_scale, w2 [E, hidden, inter//2], w2_scale.
"""

def __init__(self, hidden_size, intermediate_size, num_experts):
def __init__(self, config):
super().__init__()
self.num_experts = num_experts
self.intermediate_size = intermediate_size
self.hidden_size = hidden_size
G = _EXPERTS_PER_GROUP
assert num_experts % G == 0
num_groups = num_experts // G

self.gate_up_projs = nn.ModuleList(
[
nn.Linear(hidden_size, G * intermediate_size * 2, bias=False)
for _ in range(num_groups)
]
self.num_experts = config.num_experts
self.intermediate_size = config.moe_intermediate_size
self.hidden_size = config.hidden_size
self.group_size = 32

self.w1_weight = nn.Parameter(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@claude what are w1 and w2 here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still not working :(.

torch.empty(
config.num_experts,
2 * config.moe_intermediate_size,
config.hidden_size,
)
)
self.down_projs = nn.ModuleList(
[
nn.Linear(intermediate_size, G * hidden_size, bias=False)
for _ in range(num_groups)
]
self.w2_weight = nn.Parameter(
torch.empty(
config.num_experts,
config.hidden_size,
config.moe_intermediate_size,
)
)
Comment on lines +405 to 435
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FusedMoEExperts defines w1_weight/w2_weight, but SparseMoE.forward accesses self.experts.w1, w1_scale, w2, w2_scale. These attributes only get created after _quantize_experts_int4() runs, so the model will error (AttributeError) when running or exporting without --qlinear (or with --qembedding only). Consider providing a non-quantized path (or initializing w1/w2 aliases) or enforcing/validating that expert quantization is always performed before any forward/export.

Copilot uses AI. Check for mistakes.

def forward(self, x, expert_indices):
# x: (T, D), expert_indices: (T, top_k)
T = x.size(0)
top_k = expert_indices.size(1)
G = _EXPERTS_PER_GROUP
H = self.intermediate_size
D = self.hidden_size

# Gate + Up: compute per-group, cat, gather top-k
gate_up_parts = [proj(x).view(T, G, 2, H) for proj in self.gate_up_projs]
gate_up = torch.cat(gate_up_parts, dim=1) # (T, E, 2, H)

idx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2, H)
gate_up_sel = gate_up.gather(1, idx) # (T, top_k, 2, H)
intermediate = F.silu(gate_up_sel[:, :, 0, :]) * gate_up_sel[:, :, 1, :]

# Down: compute per-group, cat, gather correct expert per slot
intermediate_flat = intermediate.reshape(T * top_k, H)
down_parts = [
proj(intermediate_flat).view(T, top_k, G, D) for proj in self.down_projs
]
all_down = torch.cat(down_parts, dim=2) # (T, top_k, E, D)

eidx = expert_indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, D)
return all_down.gather(2, eidx).squeeze(2) # (T, top_k, D)


class SwiGLU(nn.Module):
"""SwiGLU MLP for shared expert."""
Expand All @@ -484,12 +453,9 @@ class SparseMoE(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.cond_ffn = ConditionalFeedForward(
config.hidden_size,
config.moe_intermediate_size,
config.num_experts,
)
self.experts = FusedMoEExperts(config)
self.shared_expert = SwiGLU(
config.hidden_size, config.shared_expert_intermediate_size
)
Expand All @@ -503,8 +469,18 @@ def forward(self, x):
expert_weights, expert_indices = torch.topk(scores, self.top_k, dim=-1)
expert_weights = expert_weights.softmax(dim=-1)

expert_outs = self.cond_ffn(x_flat, expert_indices)
routed_out = torch.einsum("tai,ta->ti", expert_outs, expert_weights)
routed_out = torch.ops.triton.fused_moe(
x_flat,
self.experts.w1,
self.experts.w1_scale,
self.experts.w2,
Comment on lines +472 to +476
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SparseMoE.forward unconditionally calls torch.ops.triton.fused_moe, but this repo does not appear to register a triton::fused_moe op (no kernel wrapper / @triton_op found). This will raise at runtime and will also block torch.export(strict=True) unless a fake/abstract impl is registered. Please add the fused MoE Triton op registration (similar to triton::chunk_gated_delta_rule) and ensure it’s imported during export so the op is available.

Copilot uses AI. Check for mistakes.
self.experts.w2_scale,
expert_weights.float(),
expert_indices,
self.top_k,
self.num_experts,
self.experts.group_size,
)

shared_out = self.shared_expert(x_flat)
shared_gate = torch.sigmoid(self.shared_expert_gate(x_flat))
Expand Down Expand Up @@ -641,9 +617,8 @@ def _load_and_remap_checkpoint(model_dir, config):
expert_weights,
)

# Stack per-expert weights, split into groups, reshape for nn.Linear
# Stack per-expert weights into [E, N, K] tensors for FusedMoEExperts
if expert_weights:
G = _EXPERTS_PER_GROUP
for layer_idx in range(config.num_hidden_layers):
gate_list = [
expert_weights.get((layer_idx, "gate", e))
Expand All @@ -661,21 +636,13 @@ def _load_and_remap_checkpoint(model_dir, config):
if gate_list[0] is not None:
w_gate = torch.stack(gate_list, dim=0) # (E, H, D)
w_up = torch.stack(up_list, dim=0)
fused = torch.cat([w_gate, w_up], dim=1) # (E, 2*H, D)
num_groups = config.num_experts // G
for g in range(num_groups):
chunk = fused[g * G : (g + 1) * G]
state_dict[
f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight"
] = chunk.reshape(-1, chunk.size(-1))
state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = torch.cat(
[w_gate, w_up], dim=1
) # (E, 2*H, D)
if down_list[0] is not None:
w_down = torch.stack(down_list, dim=0) # (E, D, H)
num_groups = config.num_experts // G
for g in range(num_groups):
chunk = w_down[g * G : (g + 1) * G]
state_dict[
f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight"
] = chunk.reshape(-1, chunk.size(-1))
state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = torch.stack(
down_list, dim=0
) # (E, D, H)
del expert_weights

# Handle tied embeddings
Expand All @@ -697,27 +664,15 @@ def _process_checkpoint_key(ckpt_key, tensor, state_dict, expert_weights):
if norm_key.startswith(("model.visual.", "model.mtp_")):
return

# Fused expert weights: split into groups of _EXPERTS_PER_GROUP
# Fused expert weights: store directly as [E, N, K] for FusedMoEExperts
m = _FUSED_EXPERT_RE.match(norm_key)
if m:
layer_idx = int(m.group(1))
proj_name = m.group(2)
G = _EXPERTS_PER_GROUP
num_groups = tensor.size(0) // G
if proj_name == "gate_up_proj":
# (E, 2*H, D) → groups of (G, 2*H, D) → each (G*2*H, D)
for g in range(num_groups):
chunk = tensor[g * G : (g + 1) * G]
state_dict[
f"layers.{layer_idx}.mlp.cond_ffn.gate_up_projs.{g}.weight"
] = chunk.reshape(-1, chunk.size(-1)).contiguous()
state_dict[f"layers.{layer_idx}.mlp.experts.w1_weight"] = tensor
else:
# down_proj: (E, D, H) → groups of (G, D, H) → each (G*D, H)
for g in range(num_groups):
chunk = tensor[g * G : (g + 1) * G]
state_dict[f"layers.{layer_idx}.mlp.cond_ffn.down_projs.{g}.weight"] = (
chunk.reshape(-1, chunk.size(-1)).contiguous()
)
state_dict[f"layers.{layer_idx}.mlp.experts.w2_weight"] = tensor
return

# Per-expert weights
Expand Down
Loading