-
Notifications
You must be signed in to change notification settings - Fork 736
[Cherry-pick][Optimization] enable trtllm_all_reduce fusion kernel in glm model #7228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/2.6
Are you sure you want to change the base?
Changes from all commits
9e65d0b
c1dd403
ceec743
4de40ed
1a50583
fd2289c
101a9f8
39a6174
84f8d28
ae2ef92
a04b37d
ff9651c
5d19d7e
ee030a9
ab67c64
d534781
98e9f9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| from typing import Optional, Tuple | ||
|
|
||
| import paddle | ||
| import paddle.distributed as dist | ||
|
|
||
| # from sglang.srt.distributed import get_tensor_model_parallel_world_size | ||
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.model_executor.utils import has_flashinfer | ||
| from fastdeploy.utils import get_logger | ||
|
|
||
| logger = get_logger("flashinfer", "flashinfer.log") | ||
|
|
||
| _flashinfer_comm = None | ||
| _workspace_manager = None | ||
|
|
||
| # fd_config.parallel_config.tensor_parallel_size | ||
|
|
||
| if has_flashinfer(): | ||
| try: | ||
| paddle.compat.enable_torch_proxy(scope={"flashinfer"}) | ||
| import flashinfer.comm as comm | ||
|
|
||
| _flashinfer_comm = comm | ||
| except ImportError: | ||
| logger.warning("flashinfer.comm is not available, falling back to standard " "implementation") | ||
|
|
||
|
|
||
| class FlashInferWorkspaceManager: | ||
| def __init__(self): | ||
| self.workspace_tensor = None | ||
| self.ipc_handles = None | ||
| self.world_size = None | ||
| self.rank = None | ||
| self.initialized = False | ||
|
|
||
| def initialize( | ||
| self, | ||
| world_size: int, | ||
| rank: int, | ||
| max_token_num: int, | ||
| hidden_dim: int, | ||
| group=None, | ||
| use_fp32_lamport: bool = False, | ||
| ): | ||
| """Initialize workspace""" | ||
| if self.initialized and self.world_size == world_size: | ||
| return | ||
|
|
||
| if _flashinfer_comm is None: | ||
| logger.warning("FlashInfer comm not available, skipping workspace " "initialization") | ||
| return | ||
|
|
||
| self.cleanup() | ||
|
|
||
| self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( | ||
| rank, | ||
| world_size, | ||
| max_token_num, | ||
| hidden_dim, | ||
| group=group, | ||
| use_fp32_lamport=use_fp32_lamport, | ||
| ) | ||
|
|
||
| self.world_size = world_size | ||
| self.rank = rank | ||
| self.initialized = True | ||
|
|
||
| logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}") | ||
|
|
||
| def cleanup(self): | ||
| """Clean up workspace""" | ||
| if self.initialized and self.ipc_handles is not None: | ||
| try: | ||
| _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group()) | ||
| except Exception as e: | ||
| logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") | ||
| finally: | ||
| self.workspace_tensor = None | ||
| self.ipc_handles = None | ||
| self.initialized = False | ||
|
|
||
|
|
||
| _workspace_manager = FlashInferWorkspaceManager() | ||
|
|
||
|
|
||
| def ensure_workspace_initialized( | ||
| fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False | ||
| ): | ||
| """Ensure workspace is initialized""" | ||
| if not has_flashinfer() or _flashinfer_comm is None: | ||
| return False | ||
|
|
||
| assert fd_config is not None | ||
| world_size = fd_config.parallel_config.tensor_parallel_size | ||
| if world_size <= 1: | ||
| return False | ||
|
|
||
| rank = dist.get_rank() | ||
|
|
||
| if not _workspace_manager.initialized or _workspace_manager.world_size != world_size: | ||
| _workspace_manager.initialize( | ||
| world_size=world_size, | ||
| rank=rank, | ||
| max_token_num=max_token_num, | ||
| hidden_dim=hidden_dim, | ||
| use_fp32_lamport=use_fp32_lamport, | ||
| ) | ||
|
|
||
| return _workspace_manager.initialized | ||
|
|
||
|
|
||
| def flashinfer_allreduce_residual_rmsnorm( | ||
| fd_config: FDConfig, | ||
| input_tensor: paddle.Tensor, | ||
| residual: paddle.Tensor, | ||
| weight: paddle.Tensor, | ||
| eps: float = 1e-6, | ||
| max_token_num: int = 2048, | ||
| use_oneshot: Optional[bool] = None, | ||
| trigger_completion_at_end: bool = False, | ||
| fp32_acc: bool = False, | ||
| ) -> Tuple[paddle.Tensor, paddle.Tensor]: | ||
| """ | ||
| Use FlashInfer's fused allreduce + residual + RMS norm operation | ||
|
|
||
| Args: | ||
| input_tensor: Input tensor that needs allreduce | ||
| residual: Residual tensor | ||
| weight: RMS norm weight | ||
| eps: RMS norm epsilon | ||
| max_token_num: Maximum token number | ||
| use_oneshot: Whether to use oneshot mode | ||
| trigger_completion_at_end: Whether to trigger completion at end | ||
| fp32_acc: Whether to use fp32 precision | ||
|
|
||
| Returns: | ||
| Tuple[paddle.Tensor, paddle.Tensor]: (norm_output, residual_output) | ||
| """ | ||
| if not has_flashinfer() or _flashinfer_comm is None: | ||
| logger.debug("FlashInfer not available, falling back to standard " "implementation") | ||
| return None, None | ||
|
|
||
| assert fd_config is not None | ||
| world_size = fd_config.parallel_config.tensor_parallel_size | ||
| if world_size <= 1: | ||
| logger.debug("Single GPU, no need for allreduce fusion") | ||
| return None, None | ||
|
|
||
| assert input_tensor.shape[0] <= max_token_num | ||
|
|
||
| if not ensure_workspace_initialized( | ||
| fd_config=fd_config, | ||
| max_token_num=max_token_num, | ||
| hidden_dim=input_tensor.shape[-1], | ||
| use_fp32_lamport=(input_tensor.dtype == paddle.float32), | ||
| ): | ||
| logger.debug("FlashInfer workspace not available") | ||
| return None, None | ||
|
|
||
| token_num, hidden_dim = input_tensor.shape | ||
|
|
||
| residual_out = paddle.empty_like(residual) | ||
| norm_out = paddle.empty_like(input_tensor) | ||
| # support empty tensor | ||
| if input_tensor.shape[0] == 0: | ||
| return norm_out, residual_out | ||
| _flashinfer_comm.trtllm_allreduce_fusion( | ||
| allreduce_in=input_tensor, | ||
| world_size=world_size, | ||
| world_rank=dist.get_rank(), | ||
| token_num=token_num, | ||
| hidden_dim=hidden_dim, | ||
| workspace_ptrs=_workspace_manager.workspace_tensor, | ||
| launch_with_pdl=True, | ||
| use_oneshot=use_oneshot, | ||
| trigger_completion_at_end=trigger_completion_at_end, | ||
| fp32_acc=fp32_acc, | ||
| pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm), | ||
| allreduce_out=None, | ||
| residual_in=residual, | ||
| residual_out=residual_out, | ||
| norm_out=norm_out, | ||
| quant_out=None, | ||
| scale_out=None, | ||
| rms_gamma=weight, | ||
| rms_eps=eps, | ||
| scale_factor=None, | ||
| layout_code=None, | ||
| ) | ||
|
|
||
| return norm_out, residual_out | ||
|
|
||
|
|
||
| def fake_flashinfer_allreduce_residual_rmsnorm( | ||
| input_tensor: paddle.Tensor, | ||
| residual: paddle.Tensor, | ||
| weight: paddle.Tensor, | ||
| eps: float = 1e-6, | ||
| max_token_num: int = 16384, | ||
| use_oneshot: Optional[bool] = None, | ||
| trigger_completion_at_end: bool = False, | ||
| fp32_acc: bool = False, | ||
| ) -> Tuple[paddle.Tensor, paddle.Tensor]: | ||
| residual_out = paddle.empty_like(residual) | ||
| norm_out = paddle.empty_like(input_tensor) | ||
| return norm_out, residual_out | ||
|
|
||
|
|
||
| def cleanup_flashinfer_workspace(): | ||
| global _workspace_manager | ||
| if _workspace_manager is not None: | ||
| _workspace_manager.cleanup() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -860,6 +860,9 @@ def __init__( | |
| skip_quant (bool): Whether to skip quantization. Defaults to False. | ||
| """ | ||
| self.fd_config = fd_config | ||
| self.enable_all_reduce_fusion = ( | ||
| fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "enable_all_reduce" in prefix | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 建议 通过在 prefix 字符串中嵌入 `enable_all_reduce` 标记来控制融合开关过于隐晦且易碎。建议使用明确的配置参数,例如: # 在 ParallelConfig 中添加
self.enable_o_proj_fusion: bool = False
# 在 linear.py 中
self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and fd_config.parallel_config.enable_o_proj_fusion |
||
| ) | ||
| self.ep_size = fd_config.parallel_config.expert_parallel_size | ||
| self.tp_size = fd_config.parallel_config.tensor_parallel_size | ||
| self.tp_group = fd_config.parallel_config.tp_group | ||
|
|
@@ -937,7 +940,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: | |
|
|
||
| out = self.quant_method.apply(self, x) | ||
|
|
||
| if self.reduce_results and self.tp_size > 1: | ||
| need_tp_all_reduce = ( | ||
| self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048) | ||
| ) | ||
| if need_tp_all_reduce: | ||
| out = tensor_model_parallel_all_reduce(out, self.tp_group) | ||
|
|
||
| return out | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,7 @@ | |
| is_batch_invariant_mode_enabled, | ||
| rms_norm_batch_invariant, | ||
| ) | ||
| from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm | ||
| from .utils import get_tensor, modules_to_convert | ||
|
|
||
|
|
||
|
|
@@ -122,6 +123,10 @@ def __init__( | |
| self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank | ||
| self.tp_group = self.fd_config.parallel_config.tp_group | ||
| is_input_norm = prefix.endswith(".input_layernorm") | ||
| self.enable_all_reduce_fusion = ( | ||
| fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix | ||
| ) | ||
|
|
||
| self.is_last_norm = prefix.endswith(".norm") | ||
| self.split_x = ( | ||
| self.fd_config.parallel_config.use_sequence_parallel_moe | ||
|
|
@@ -240,6 +245,11 @@ def forward( | |
| norm_out = rms_norm(x, self.weight, self.eps) | ||
| return norm_out.astype(x_dtype), residual_out | ||
| norm_out = self.norm_func(x, residual_input, self.weight, self.eps) | ||
| # enable trtllm all reduce fusion | ||
| elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🔴 Bug 当 `flashinfer_allreduce_residual_rmsnorm` 返回 `(None, None)` 时(例如 FlashInfer 不可用或 workspace 初始化失败),`norm_out` 会被赋值为 `(None, None)`,导致后续计算出错。建议添加 fallback 逻辑检查返回值是否为 None:
fusion_result = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
if fusion_result[0] is not None:
norm_out = fusion_result
else:
# Fallback to standard implementation
norm_out = self.norm_func(
x, norm_weight=self.weight, norm_bias=None, epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis, bias=self.bias, residual=residual_input,
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
) |
||
| norm_out = flashinfer_allreduce_residual_rmsnorm( | ||
| fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps | ||
| ) | ||
| else: | ||
| if is_batch_invariant_mode_enabled(): | ||
| # M-invariant path: per-row Triton kernel, no cross-row reduction | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,7 +129,6 @@ def __init__( | |
| self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size | ||
| self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank | ||
| self.tp_group = fd_config.parallel_config.tp_group | ||
|
|
||
| self.use_ep = self.expert_parallel_size > 1 | ||
| self.use_tp = self.tensor_parallel_size > 1 | ||
|
|
||
|
|
@@ -224,7 +223,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None | |
|
|
||
| self.o_proj = RowParallelLinear( | ||
| fd_config, | ||
| prefix=f"{prefix}.o_proj", | ||
| prefix=f"{prefix}.enable_all_reduce.o_proj", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🔴 Bug 这会导致模型权重无法正确加载,因为模型参数名变成了 `{prefix}.enable_all_reduce.o_proj`,与 checkpoint 中保存的 `{prefix}.o_proj` 不匹配。建议:改用显式配置参数控制融合开关,不要修改 prefix。
|
||
| input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim, | ||
| output_size=fd_config.model_config.hidden_size, | ||
| layer_id=layer_id, | ||
|
|
@@ -299,15 +298,15 @@ def __init__( | |
| fd_config, | ||
| hidden_size=fd_config.model_config.hidden_size, | ||
| eps=fd_config.model_config.rms_norm_eps, | ||
| prefix=f"{prefix}.input_layernorm", | ||
| prefix=f"{prefix}.enable_all_reduce_fusion.input_layernorm", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🔴 Bug 同上,在 `input_layernorm` 的 prefix 中插入 `enable_all_reduce_fusion` 会改变参数名,破坏权重加载时的名称匹配。
||
| layer_id=layer_id, | ||
| ) | ||
|
|
||
| self.post_attention_layernorm = RMSNorm( | ||
| fd_config, | ||
| hidden_size=fd_config.model_config.hidden_size, | ||
| eps=fd_config.model_config.rms_norm_eps, | ||
| prefix=f"{prefix}.post_attention_layernorm", | ||
| prefix=f"{prefix}.enable_all_reduce_fusion.post_attention_layernorm", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🔴 Bug 同上,在 `post_attention_layernorm` 的 prefix 中插入 `enable_all_reduce_fusion` 会改变参数名,破坏权重加载时的名称匹配。
||
| layer_id=layer_id, | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🟡 建议 cleanup 时使用 `dist.get_group()` 获取默认 group,但初始化时(第 55 行)没有指定 group,这意味着初始化与销毁可能使用了不同的(非默认)通信组。建议:在 `FlashInferWorkspaceManager` 中记录初始化时实际使用的 group,cleanup 时复用同一个 group。