Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 151 additions & 73 deletions fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)

from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
Expand All @@ -53,6 +54,12 @@
)


def deep_batch_gemm(x, y, expert_idx_per_token):
    """Run a grouped BF16 GEMM over expert-partitioned tokens.

    Multiplies permuted token activations ``x`` against the stacked expert
    weight tensor ``y`` using paddlefleet's
    ``m_grouped_bf16_gemm_nn_contiguous`` kernel, which groups rows of ``x``
    by ``expert_idx_per_token`` so each token row is multiplied by its
    assigned expert's weights.

    Args:
        x: Permuted input activations; the output row count is ``x.shape[0]``
           and the output dtype matches ``x``.
        y: Stacked expert weights; the output column count is ``y.shape[-1]``.
        expert_idx_per_token: Per-row expert index used by the grouped kernel
           to select which expert's weight slice applies to each row.
           (NOTE(review): assumed int32/int64 as required by the kernel —
           confirm against paddlefleet's API.)

    Returns:
        A new tensor of shape ``[x.shape[0], y.shape[-1]]`` filled in place
        by the kernel.

    Raises:
        RuntimeError: If ``paddlefleet.ops`` is not installed.
    """
    # paddlefleet_ops comes from an optional import and is None when the
    # paddlefleet package is unavailable; fail with an actionable message
    # instead of an opaque AttributeError on the line below.
    if paddlefleet_ops is None:
        raise RuntimeError(
            "paddlefleet.ops is required for FD_USE_PHI_MOE_PERMUTE=1. "
            "Please install paddlefleet or disable this feature."
        )
    # The kernel writes its result into a preallocated output buffer.
    out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype)
    paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token)
    return out


class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.
Expand Down Expand Up @@ -107,24 +114,31 @@ def compute_ffn(
used_in_ep_low_latency,
)
else:
ffn_out_without_down_proj_bias = fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
dequant_scale,
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
max_tokens_per_expert,
self.moe_quant_type,
used_in_ep_low_latency,
estimate_total_token_nums,
getattr(layer.moe_quant_config, "hadamard_block_size", 128),
layer.activation,
)
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
out = deep_batch_gemm(permute_input, getattr(layer, self.added_weight_attrs[0]), expert_idx_per_token)
out = paddle.nn.functional.swiglu(out)
ffn_out_without_down_proj_bias = deep_batch_gemm(
out, getattr(layer, self.added_weight_attrs[1]), expert_idx_per_token
)
else:
ffn_out_without_down_proj_bias = fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
permute_input,
token_nums_per_expert,
getattr(layer, self.added_weight_attrs[0]),
getattr(layer, self.added_weight_attrs[1]),
dequant_scale,
(layer.up_gate_proj_bias if hasattr(layer, "up_gate_proj_bias") else None),
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
(layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
expert_idx_per_token,
max_tokens_per_expert,
self.moe_quant_type,
used_in_ep_low_latency,
estimate_total_token_nums,
getattr(layer.moe_quant_config, "hadamard_block_size", 128),
layer.activation,
)

if layer.with_bias:
down_proj_bias_expand = paddle.index_select(layer.down_proj_bias, expert_idx_per_token, axis=0)
Expand Down Expand Up @@ -292,6 +306,9 @@ def apply_tp(
Paddle Cutlass compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_nums_per_expert = None
dequant_scale = None
max_tokens_per_expert = None
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
Expand Down Expand Up @@ -325,27 +342,48 @@ def apply_tp(
dequant_scale = None
max_tokens_per_expert = None
else:
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
topk_idx = topk_idx.astype(paddle.int32)
override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
(
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
), # if set, permute_input will be int8_t
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=True,
)
permute_input,
permute_indices_per_token, # == zipped_expertwise_rowmap
topk_weights,
permute_scale,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 permute_scale 变量未被使用

moe_permute 返回的 permute_scale 在后续代码中未被使用。如果确实不需要,建议用 _ 替代以明确表示该变量被有意忽略:

(
    permute_input,
    permute_indices_per_token,
    topk_weights,
    _,  # permute_scale not used in bf16 path
    expert_idx_per_token,
) = paddle.nn.functional.moe_permute(...)

注意:行 423 处也存在相同情况。

expert_idx_per_token,
) = paddle.nn.functional.moe_permute(
hidden_states=x,
scale=None,
expert_routemap_topk=topk_idx,
expert_prob_topk=topk_weights,
num_experts=layer.num_experts,
tokens_per_expert=[],
padding_alignment=128,
return_expert_indices=True,
override_buffer_size=override_buffer_size,
)
else:
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
None, # Use layer.gate_correction_bias in get_moe_scores.
(
layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
), # if set, permute_input will be int8_t
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=True,
)
else:
if current_platform.is_iluvatar():
(
Expand All @@ -368,38 +406,67 @@ def apply_tp(
dequant_scale = None
max_tokens_per_expert = None
else:
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=False,
)
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
topk_idx = topk_idx.astype(paddle.int32)
override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
(
permute_input,
permute_indices_per_token, # == zipped_expertwise_rowmap
topk_weights,
permute_scale,
expert_idx_per_token,
) = paddle.nn.functional.moe_permute(
hidden_states=x,
scale=None,
expert_routemap_topk=topk_idx,
expert_prob_topk=topk_weights,
num_experts=layer.num_experts,
tokens_per_expert=[],
padding_alignment=128,
return_expert_indices=True,
override_buffer_size=override_buffer_size,
)
else:
(
permute_input,
token_nums_per_expert,
permute_indices_per_token,
topk_weights,
topk_idx,
expert_idx_per_token,
dequant_scale,
max_tokens_per_expert,
) = moe_expert_dispatch(
x,
gate_out,
layer.gate_correction_bias,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
layer.top_k,
False,
self.moe_quant_type,
topk_only_mode=False,
)

if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)

if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 need expert_idx_per_token
# Other need not this tensor, so we make it None.
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
if not fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 need expert_idx_per_token
# Other need not this tensor, so we make it None.
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")

ffn_out = self.compute_ffn(
layer,
Expand All @@ -412,16 +479,27 @@ def apply_tp(
max_tokens_per_expert,
)

# reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor
fused_moe_out = moe_expert_reduce(
ffn_out,
topk_weights,
permute_indices_per_token,
topk_idx,
None,
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
routed_scaling_factor=1.0,
)
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
fused_moe_out, out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=topk_idx,
token_prob_unzipped=topk_weights,
total_zipped_tokens=x.shape[0],
num_experts=layer.num_experts,
using_weighted_combine=True,
)
else:
# reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor
fused_moe_out = moe_expert_reduce(
ffn_out,
topk_weights,
permute_indices_per_token,
topk_idx,
None,
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
routed_scaling_factor=1.0,
)

return fused_moe_out

Expand Down
Loading