Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ def __init__(
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
self.disable_custom_all_reduce: bool = False
self.enable_flashinfer_allreduce_fusion: bool = False
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
Expand Down
11 changes: 11 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ class EngineArgs:
Flag to disable the custom all-reduce kernel.
"""

enable_flashinfer_allreduce_fusion: bool = False
"""
Flag to enable all reduce fusion kernel in flashinfer.
"""

use_internode_ll_two_stage: bool = False
"""
Flag to use the internode_ll_two_stage kernel.
Expand Down Expand Up @@ -977,6 +982,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.disable_custom_all_reduce,
help="Flag to disable custom all-reduce.",
)
parallel_group.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
default=EngineArgs.enable_flashinfer_allreduce_fusion,
help="Flag to enable all reduce fusion kernel in flashinfer.",
)
parallel_group.add_argument(
"--use-internode-ll-two-stage",
action="store_true",
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2509,6 +2509,7 @@ def _start_worker_service(self):
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ def _start_worker_service(self):
"enable_entropy": self.cfg.model_config.enable_entropy,
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
212 changes: 212 additions & 0 deletions fastdeploy/model_executor/layers/flashinfer_comm_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from typing import Optional, Tuple

import paddle
import paddle.distributed as dist

# from sglang.srt.distributed import get_tensor_model_parallel_world_size
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import has_flashinfer
from fastdeploy.utils import get_logger

logger = get_logger("flashinfer", "flashinfer.log")

_flashinfer_comm = None
_workspace_manager = None

# fd_config.parallel_config.tensor_parallel_size

if has_flashinfer():
try:
paddle.compat.enable_torch_proxy(scope={"flashinfer"})
import flashinfer.comm as comm

_flashinfer_comm = comm
except ImportError:
logger.warning("flashinfer.comm is not available, falling back to standard " "implementation")


class FlashInferWorkspaceManager:
    """Manages the process-wide FlashInfer IPC workspace for fused all-reduce.

    The workspace is a set of IPC-shared buffers created once per
    (rank, world_size) pair and reused across fused all-reduce calls;
    `cleanup` releases it.
    """

    def __init__(self):
        # Workspace tensor returned by FlashInfer; None until initialized.
        self.workspace_tensor = None
        # IPC handles for the shared buffers; required again at destroy time.
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        # Communication group the workspace was created on; reused by cleanup
        # so creation and destruction happen on the same group.
        self.group = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Create (or re-create) the IPC workspace for all-reduce fusion.

        Args:
            world_size: Tensor-parallel world size.
            rank: This process's rank within the group.
            max_token_num: Maximum number of tokens a fused call may carry.
            hidden_dim: Hidden dimension of the tensors being reduced.
            group: Communication group to create the workspace on
                (default group when None).
            use_fp32_lamport: Whether buffers are laid out for fp32 inputs.
        """
        # Already set up for this world size: nothing to do.
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning("FlashInfer comm not available, skipping workspace " "initialization")
            return

        # Drop any stale workspace (e.g. world size changed) before re-creating.
        self.cleanup()

        self.ipc_handles, self.workspace_tensor = _flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            max_token_num,
            hidden_dim,
            group=group,
            use_fp32_lamport=use_fp32_lamport,
        )

        self.world_size = world_size
        self.rank = rank
        # Remember the group so cleanup destroys the workspace on the same
        # group it was created on (not necessarily the default group).
        self.group = group
        self.initialized = True

        logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}")

    def cleanup(self):
        """Destroy the IPC workspace and reset all bookkeeping state."""
        if self.initialized and self.ipc_handles is not None:
            try:
                # Use the group recorded at initialize(); fall back to the
                # default group only when none was given.
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                    self.ipc_handles, group=self.group if self.group is not None else dist.get_group()
                )
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.group = None
                self.initialized = False


_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(
    fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Make sure the global FlashInfer workspace exists for the current TP setup.

    Returns True when the workspace is ready to use, False when FlashInfer is
    unavailable or tensor parallelism is disabled.
    """
    if not has_flashinfer() or _flashinfer_comm is None:
        return False

    assert fd_config is not None
    tp_size = fd_config.parallel_config.tensor_parallel_size
    if tp_size <= 1:
        # Single GPU: the fused all-reduce path is never needed.
        return False

    cur_rank = dist.get_rank()

    needs_init = (not _workspace_manager.initialized) or _workspace_manager.world_size != tp_size
    if needs_init:
        _workspace_manager.initialize(
            world_size=tp_size,
            rank=cur_rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized


def flashinfer_allreduce_residual_rmsnorm(
    fd_config: FDConfig,
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 2048,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Args:
        fd_config: FastDeploy config carrying the tensor-parallel size.
        input_tensor: Input tensor that needs allreduce; assumed 2-D
            [token_num, hidden_dim] (the shape unpack below requires it).
        residual: Residual tensor added before normalization.
        weight: RMS norm weight (gamma).
        eps: RMS norm epsilon.
        max_token_num: Maximum token number the shared workspace supports.
        use_oneshot: Whether to use oneshot mode (None lets FlashInfer decide).
        trigger_completion_at_end: Whether to trigger completion at end.
        fp32_acc: Whether to accumulate in fp32.

    Returns:
        Tuple[paddle.Tensor, paddle.Tensor]: (norm_output, residual_output),
        or (None, None) when the fused path is unavailable so the caller must
        fall back to the standard implementation.

    Raises:
        ValueError: If the input carries more tokens than ``max_token_num``.
    """
    if not has_flashinfer() or _flashinfer_comm is None:
        logger.debug("FlashInfer not available, falling back to standard implementation")
        return None, None

    assert fd_config is not None
    world_size = fd_config.parallel_config.tensor_parallel_size
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    # Explicit check instead of `assert`: asserts are stripped under -O, and an
    # oversized input would otherwise overrun the fixed-size IPC workspace.
    if input_tensor.shape[0] > max_token_num:
        raise ValueError(f"token_num {input_tensor.shape[0]} exceeds max_token_num {max_token_num}")

    if not ensure_workspace_initialized(
        fd_config=fd_config,
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == paddle.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = paddle.empty_like(residual)
    norm_out = paddle.empty_like(input_tensor)
    # Empty batch: nothing to reduce; return the allocated outputs directly.
    if token_num == 0:
        return norm_out, residual_out
    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out


def fake_flashinfer_allreduce_residual_rmsnorm(
    input_tensor: paddle.Tensor,
    residual: paddle.Tensor,
    weight: paddle.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 16384,
    use_oneshot: Optional[bool] = None,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Shape-only stand-in for the fused allreduce + residual + RMS norm op.

    Allocates outputs matching the inputs' shapes and dtypes; no communication
    or normalization is performed.
    """
    norm_out = paddle.empty_like(input_tensor)
    residual_out = paddle.empty_like(residual)
    return norm_out, residual_out


def cleanup_flashinfer_workspace():
    """Release the process-wide FlashInfer IPC workspace, if one exists."""
    global _workspace_manager
    if _workspace_manager is None:
        return
    _workspace_manager.cleanup()
8 changes: 7 additions & 1 deletion fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,9 @@ def __init__(
skip_quant (bool): Whether to skip quantization. Defaults to False.
"""
self.fd_config = fd_config
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "enable_all_reduce" in prefix
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 通过 "enable_all_reduce" in prefix 字符串匹配来判断是否启用 fusion,这种方式不够直观且容易出错(如本次 PR 中的权重加载问题)。

建议使用明确的配置参数,例如:

# 在 ParallelConfig 中添加
self.enable_o_proj_fusion: bool = False

# 在 linear.py 中
self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and fd_config.parallel_config.enable_o_proj_fusion

)
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group
Expand Down Expand Up @@ -937,7 +940,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:

out = self.quant_method.apply(self, x)

if self.reduce_results and self.tp_size > 1:
need_tp_all_reduce = (
self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048)
)
if need_tp_all_reduce:
out = tensor_model_parallel_all_reduce(out, self.tp_group)

return out
Expand Down
10 changes: 10 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_batch_invariant_mode_enabled,
rms_norm_batch_invariant,
)
from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm
from .utils import get_tensor, modules_to_convert


Expand Down Expand Up @@ -122,6 +123,10 @@ def __init__(
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm")
self.enable_all_reduce_fusion = (
fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix
)

self.is_last_norm = prefix.endswith(".norm")
self.split_x = (
self.fd_config.parallel_config.use_sequence_parallel_moe
Expand Down Expand Up @@ -240,6 +245,11 @@ def forward(
norm_out = rms_norm(x, self.weight, self.eps)
return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bugflashinfer_allreduce_residual_rmsnorm 返回 (None, None) 时(flashinfer 不可用或未初始化),后续代码访问 norm_out 会抛出 TypeError

建议添加 fallback 逻辑检查返回值是否为 None:

elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
    fusion_result = flashinfer_allreduce_residual_rmsnorm(
        fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
    )
    if fusion_result[0] is not None:
        norm_out = fusion_result
    else:
        # Fallback to standard implementation
        norm_out = self.norm_func(
            x, norm_weight=self.weight, norm_bias=None, epsilon=self.eps,
            begin_norm_axis=self.begin_norm_axis, bias=self.bias, residual=residual_input,
            quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )

norm_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
else:
if is_batch_invariant_mode_enabled():
# M-invariant path: per-row Triton kernel, no cross-row reduction
Expand Down
9 changes: 2 additions & 7 deletions fastdeploy/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
"""

import importlib
import importlib.util
import math
from enum import Enum
from typing import Callable, Optional
Expand All @@ -25,11 +23,12 @@

from fastdeploy import envs
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs
from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch

from fastdeploy.utils import get_logger

from ..moe import FusedMoE
Expand Down Expand Up @@ -59,10 +58,6 @@ def check_device_capability(num):
return False


def has_flashinfer():
return importlib.util.find_spec("flashinfer") is not None


def round_up(a, b):
return ((a + b - 1) // b) * b

Expand Down
7 changes: 3 additions & 4 deletions fastdeploy/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def __init__(
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group

self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1

Expand Down Expand Up @@ -224,7 +223,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None

self.o_proj = RowParallelLinear(
fd_config,
prefix=f"{prefix}.o_proj",
prefix=f"{prefix}.enable_all_reduce.o_proj",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug o_proj 的 prefix 从 f"{prefix}.o_proj" 修改为 f"{prefix}.enable_all_reduce.o_proj",但 _get_tensor_parallel_mappings 中的权重映射仍使用 layers.0.self_attn.o_proj.weight

这会导致模型权重无法正确加载,因为模型参数名变成了 layers.{i}.self_attn.enable_all_reduce.o_proj.weight,但权重文件中的 key 是 layers.{i}.self_attn.o_proj.weight

建议:

  1. 不修改 prefix,使用单独的配置参数来控制 fusion 行为
  2. 或者同时更新权重映射逻辑以匹配新的 prefix

input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
output_size=fd_config.model_config.hidden_size,
layer_id=layer_id,
Expand Down Expand Up @@ -299,15 +298,15 @@ def __init__(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
prefix=f"{prefix}.enable_all_reduce_fusion.input_layernorm",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug 同上,input_layernorm 的 prefix 修改为包含 enable_all_reduce_fusion,但权重映射逻辑没有相应更新,会导致权重加载失败。

layer_id=layer_id,
)

self.post_attention_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
prefix=f"{prefix}.enable_all_reduce_fusion.post_attention_layernorm",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug 同上,post_attention_layernorm 的 prefix 修改为包含 enable_all_reduce_fusion,但权重映射逻辑没有相应更新。

layer_id=layer_id,
)

Expand Down
Loading
Loading