-
Notifications
You must be signed in to change notification settings - Fork 737
[Cherry-pick][Optimization] enable trtllm_all_reduce fusion kernel in glm model #7219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
BingooYang
wants to merge
18
commits into
PaddlePaddle:release/2.5
Choose a base branch
from
BingooYang:2.5/trtllm_allreduce
base: release/2.5
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
0009064
enable trtllm_all_reduce fusion kernel in glm model
BingooYang 6bfe1d9
update flashinfer paddle version
BingooYang e9dbd82
format update
BingooYang bb80bd8
fix a bug
BingooYang 756e6d8
modify test
BingooYang fe3df94
modify test
BingooYang a46338c
support empty tensor and modify test
BingooYang 1ae1e27
fix test_linear config issues
BingooYang 777edc6
modify test name
BingooYang a3979f7
add edge test case
BingooYang cbe082d
modify format
BingooYang 4e52802
fix conflict
BingooYang 2d18ba6
modify default max token num in trtllm_allreduce_fusion
BingooYang 9b01b07
add max token num branch for trtllm_allreduce_fusion
BingooYang 9e84a16
fix format
BingooYang f825e4b
fix rmsnorm config issue
BingooYang 6b1554f
modify 2025 to 2026
BingooYang 5b2d2f8
del redundent file
BingooYang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
212 changes: 212 additions & 0 deletions
212
fastdeploy/model_executor/layers/flashinfer_comm_fusion.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| from typing import Optional, Tuple | ||
|
|
||
| import paddle | ||
| import paddle.distributed as dist | ||
|
|
||
| # from sglang.srt.distributed import get_tensor_model_parallel_world_size | ||
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.model_executor.utils import has_flashinfer | ||
| from fastdeploy.utils import get_logger | ||
|
|
||
| logger = get_logger("flashinfer", "flashinfer.log") | ||
|
|
||
| _flashinfer_comm = None | ||
| _workspace_manager = None | ||
|
|
||
| # fd_config.parallel_config.tensor_parallel_size | ||
|
|
||
| if has_flashinfer(): | ||
| try: | ||
| paddle.compat.enable_torch_proxy(scope={"flashinfer"}) | ||
| import flashinfer.comm as comm | ||
|
|
||
| _flashinfer_comm = comm | ||
| except ImportError: | ||
| logger.warning("flashinfer.comm is not available, falling back to standard " "implementation") | ||
|
|
||
|
|
||
| class FlashInferWorkspaceManager: | ||
| def __init__(self): | ||
| self.workspace_tensor = None | ||
| self.ipc_handles = None | ||
| self.world_size = None | ||
| self.rank = None | ||
| self.initialized = False | ||
|
|
||
| def initialize( | ||
| self, | ||
| world_size: int, | ||
| rank: int, | ||
| max_token_num: int, | ||
| hidden_dim: int, | ||
| group=None, | ||
| use_fp32_lamport: bool = False, | ||
| ): | ||
| """Initialize workspace""" | ||
| if self.initialized and self.world_size == world_size: | ||
| return | ||
|
|
||
| if _flashinfer_comm is None: | ||
| logger.warning("FlashInfer comm not available, skipping workspace " "initialization") | ||
| return | ||
|
|
||
| self.cleanup() | ||
|
|
||
| self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( | ||
| rank, | ||
| world_size, | ||
| max_token_num, | ||
| hidden_dim, | ||
| group=group, | ||
| use_fp32_lamport=use_fp32_lamport, | ||
| ) | ||
|
|
||
| self.world_size = world_size | ||
| self.rank = rank | ||
| self.initialized = True | ||
|
|
||
| logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}") | ||
|
|
||
| def cleanup(self): | ||
| """Clean up workspace""" | ||
| if self.initialized and self.ipc_handles is not None: | ||
| try: | ||
| _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group()) | ||
| except Exception as e: | ||
| logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") | ||
| finally: | ||
| self.workspace_tensor = None | ||
| self.ipc_handles = None | ||
| self.initialized = False | ||
|
|
||
|
|
||
| _workspace_manager = FlashInferWorkspaceManager() | ||
|
|
||
|
|
||
| def ensure_workspace_initialized( | ||
| fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False | ||
| ): | ||
| """Ensure workspace is initialized""" | ||
| if not has_flashinfer() or _flashinfer_comm is None: | ||
| return False | ||
|
|
||
| assert fd_config is not None | ||
| world_size = fd_config.parallel_config.tensor_parallel_size | ||
| if world_size <= 1: | ||
| return False | ||
|
|
||
| rank = dist.get_rank() | ||
|
|
||
| if not _workspace_manager.initialized or _workspace_manager.world_size != world_size: | ||
| _workspace_manager.initialize( | ||
| world_size=world_size, | ||
| rank=rank, | ||
| max_token_num=max_token_num, | ||
| hidden_dim=hidden_dim, | ||
| use_fp32_lamport=use_fp32_lamport, | ||
| ) | ||
|
|
||
| return _workspace_manager.initialized | ||
|
|
||
|
|
||
| def flashinfer_allreduce_residual_rmsnorm( | ||
| fd_config: FDConfig, | ||
| input_tensor: paddle.Tensor, | ||
| residual: paddle.Tensor, | ||
| weight: paddle.Tensor, | ||
| eps: float = 1e-6, | ||
| max_token_num: int = 2048, | ||
| use_oneshot: Optional[bool] = None, | ||
| trigger_completion_at_end: bool = False, | ||
| fp32_acc: bool = False, | ||
| ) -> Tuple[paddle.Tensor, paddle.Tensor]: | ||
| """ | ||
| Use FlashInfer's fused allreduce + residual + RMS norm operation | ||
|
|
||
| Args: | ||
| input_tensor: Input tensor that needs allreduce | ||
| residual: Residual tensor | ||
| weight: RMS norm weight | ||
| eps: RMS norm epsilon | ||
| max_token_num: Maximum token number | ||
| use_oneshot: Whether to use oneshot mode | ||
| trigger_completion_at_end: Whether to trigger completion at end | ||
| fp32_acc: Whether to use fp32 precision | ||
|
|
||
| Returns: | ||
| Tuple[paddle.Tensor, paddle.Tensor]: (norm_output, residual_output) | ||
| """ | ||
| if not has_flashinfer() or _flashinfer_comm is None: | ||
| logger.debug("FlashInfer not available, falling back to standard " "implementation") | ||
| return None, None | ||
|
|
||
| assert fd_config is not None | ||
| world_size = fd_config.parallel_config.tensor_parallel_size | ||
| if world_size <= 1: | ||
| logger.debug("Single GPU, no need for allreduce fusion") | ||
| return None, None | ||
|
|
||
| assert input_tensor.shape[0] <= max_token_num | ||
|
|
||
| if not ensure_workspace_initialized( | ||
| fd_config=fd_config, | ||
| max_token_num=max_token_num, | ||
| hidden_dim=input_tensor.shape[-1], | ||
| use_fp32_lamport=(input_tensor.dtype == paddle.float32), | ||
| ): | ||
| logger.debug("FlashInfer workspace not available") | ||
| return None, None | ||
|
|
||
| token_num, hidden_dim = input_tensor.shape | ||
|
|
||
| residual_out = paddle.empty_like(residual) | ||
| norm_out = paddle.empty_like(input_tensor) | ||
| # support empty tensor | ||
| if input_tensor.shape[0] == 0: | ||
| return norm_out, residual_out | ||
| _flashinfer_comm.trtllm_allreduce_fusion( | ||
| allreduce_in=input_tensor, | ||
| world_size=world_size, | ||
| world_rank=dist.get_rank(), | ||
| token_num=token_num, | ||
| hidden_dim=hidden_dim, | ||
| workspace_ptrs=_workspace_manager.workspace_tensor, | ||
| launch_with_pdl=True, | ||
| use_oneshot=use_oneshot, | ||
| trigger_completion_at_end=trigger_completion_at_end, | ||
| fp32_acc=fp32_acc, | ||
| pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm), | ||
| allreduce_out=None, | ||
| residual_in=residual, | ||
| residual_out=residual_out, | ||
| norm_out=norm_out, | ||
| quant_out=None, | ||
| scale_out=None, | ||
| rms_gamma=weight, | ||
| rms_eps=eps, | ||
| scale_factor=None, | ||
| layout_code=None, | ||
| ) | ||
|
|
||
| return norm_out, residual_out | ||
|
|
||
|
|
||
| def fake_flashinfer_allreduce_residual_rmsnorm( | ||
| input_tensor: paddle.Tensor, | ||
| residual: paddle.Tensor, | ||
| weight: paddle.Tensor, | ||
| eps: float = 1e-6, | ||
| max_token_num: int = 16384, | ||
| use_oneshot: Optional[bool] = None, | ||
| trigger_completion_at_end: bool = False, | ||
| fp32_acc: bool = False, | ||
| ) -> Tuple[paddle.Tensor, paddle.Tensor]: | ||
| residual_out = paddle.empty_like(residual) | ||
| norm_out = paddle.empty_like(input_tensor) | ||
| return norm_out, residual_out | ||
|
|
||
|
|
||
| def cleanup_flashinfer_workspace(): | ||
| global _workspace_manager | ||
| if _workspace_manager is not None: | ||
| _workspace_manager.cleanup() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,11 @@ | |
| from fastdeploy.config import FDConfig | ||
| from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused | ||
|
|
||
| from .batch_invariant_ops import ( | ||
| is_batch_invariant_mode_enabled, | ||
| rms_norm_batch_invariant, | ||
| ) | ||
| from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm | ||
| from .utils import get_tensor, modules_to_convert | ||
|
|
||
|
|
||
|
|
@@ -118,6 +123,10 @@ def __init__( | |
| self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank | ||
| self.tp_group = self.fd_config.parallel_config.tp_group | ||
| is_input_norm = prefix.endswith(".input_layernorm") | ||
| self.enable_all_reduce_fusion = ( | ||
| fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix | ||
| ) | ||
|
|
||
| self.is_last_norm = prefix.endswith(".norm") | ||
| self.split_x = ( | ||
| self.fd_config.parallel_config.use_sequence_parallel_moe | ||
|
|
@@ -236,6 +245,11 @@ def forward( | |
| norm_out = rms_norm(x, self.weight, self.eps) | ||
| return norm_out.astype(x_dtype), residual_out | ||
| norm_out = self.norm_func(x, residual_input, self.weight, self.eps) | ||
| # enable trtllm all reduce fusion | ||
| elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: | ||
| norm_out = flashinfer_allreduce_residual_rmsnorm( | ||
| fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 Bug 当 建议修复方式: # enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
norm_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
# Check if fusion succeeded, fallback to normal path if not
if norm_out[0] is None or norm_out[1] is None:
norm_out = self.norm_func(
x,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.bias,
residual=residual_input,
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
) |
||
| ) | ||
| else: | ||
| norm_out = self.norm_func( | ||
| x, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 建议
max_token_num在多处硬编码为 2048,限制了配置灵活性。建议从 FDConfig 中读取此参数。影响位置:
linear.py:941-out.shape[0] <= 2048normalization.py:249-x.shape[0] <= 2048flashinfer_comm_fusion.py:87-max_token_num: int = 2048(默认参数)flashinfer_comm_fusion.py:118-max_token_num: int = 2048(默认参数)建议在
FDConfig中添加flashinfer_allreduce_max_token_num字段,统一配置。