From 0009064e6f6ed105cfa5d39c4a52021251e275a2 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Wed, 4 Mar 2026 21:05:40 +0800 Subject: [PATCH 01/18] enable trtllm_all_reduce fusion kernel in glm model --- fastdeploy/config.py | 1 + fastdeploy/engine/args_utils.py | 11 + fastdeploy/engine/common_engine.py | 1 + fastdeploy/engine/engine.py | 1 + .../layers/flashinfer_comm_fusion.py | 209 +++++++++++++++ fastdeploy/model_executor/layers/linear.py | 5 +- .../model_executor/layers/normalization.py | 14 + .../layers/quantization/mxfp4.py | 9 +- fastdeploy/model_executor/models/glm4_moe.py | 8 +- fastdeploy/model_executor/utils.py | 4 + fastdeploy/worker/worker_process.py | 6 + tests/layers/test_rms_allreduce_fusion.py | 249 ++++++++++++++++++ 12 files changed, 506 insertions(+), 12 deletions(-) create mode 100644 fastdeploy/model_executor/layers/flashinfer_comm_fusion.py create mode 100644 tests/layers/test_rms_allreduce_fusion.py diff --git a/fastdeploy/config.py b/fastdeploy/config.py index d4adafa9d6f..3d03f70636a 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -644,6 +644,7 @@ def __init__( self.pod_ip: str = None # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). self.disable_custom_all_reduce: bool = False + self.enable_flashinfer_allreduce_fusion: bool = False for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 1b9e21c2020..c502f1226df 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -273,6 +273,11 @@ class EngineArgs: Flag to disable the custom all-reduce kernel. """ + enable_flashinfer_allreduce_fusion: bool = False + """ + Flag to enable all reduce fusion kernel in flashinfer. + """ + use_internode_ll_two_stage: bool = False """ Flag to use the internode_ll_two_stage kernel. 
@@ -970,6 +975,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_custom_all_reduce, help="Flag to disable custom all-reduce.", ) + parallel_group.add_argument( + "--enable-flashinfer-allreduce-fusion", + action="store_true", + default=EngineArgs.enable_flashinfer_allreduce_fusion, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parallel_group.add_argument( "--use-internode-ll-two-stage", action="store_true", diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index b2c6cffce2f..61944c9bc97 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2372,6 +2372,7 @@ def _start_worker_service(self): "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, + "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index e0f27cb0509..ffd4ef3d6fa 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -622,6 +622,7 @@ def _start_worker_service(self): "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, + "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py new file mode 100644 index 00000000000..4cdca3c89ce --- /dev/null +++ 
b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py @@ -0,0 +1,209 @@ +from typing import Optional, Tuple + +import paddle +import paddle.distributed as dist + +# from sglang.srt.distributed import get_tensor_model_parallel_world_size +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.utils import has_flashinfer +from fastdeploy.utils import get_logger + +logger = get_logger("flashinfer", "flashinfer.log") + +_flashinfer_comm = None +_workspace_manager = None + +# fd_config.parallel_config.tensor_parallel_size + +if has_flashinfer(): + try: + paddle.compat.enable_torch_proxy(scope={"flashinfer"}) + import flashinfer.comm as comm + _flashinfer_comm = comm + except ImportError: + logger.warning("flashinfer.comm is not available, falling back to standard " "implementation") + + +class FlashInferWorkspaceManager: + def __init__(self): + self.workspace_tensor = None + self.ipc_handles = None + self.world_size = None + self.rank = None + self.initialized = False + + def initialize( + self, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + group=None, + use_fp32_lamport: bool = False, + ): + """Initialize workspace""" + if self.initialized and self.world_size == world_size: + return + + if _flashinfer_comm is None: + logger.warning("FlashInfer comm not available, skipping workspace " "initialization") + return + + self.cleanup() + + self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + max_token_num, + hidden_dim, + group=group, + use_fp32_lamport=use_fp32_lamport, + ) + + self.world_size = world_size + self.rank = rank + self.initialized = True + + logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}") + + def cleanup(self): + """Clean up workspace""" + if self.initialized and self.ipc_handles is not None: + try: + _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, 
group=dist.get_group()) + except Exception as e: + logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") + finally: + self.workspace_tensor = None + self.ipc_handles = None + self.initialized = False + + +_workspace_manager = FlashInferWorkspaceManager() + + +def ensure_workspace_initialized( + fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False +): + """Ensure workspace is initialized""" + if not has_flashinfer() or _flashinfer_comm is None: + return False + + assert fd_config is not None + world_size = fd_config.parallel_config.tensor_parallel_size + if world_size <= 1: + return False + + rank = dist.get_rank() + + if not _workspace_manager.initialized or _workspace_manager.world_size != world_size: + _workspace_manager.initialize( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + use_fp32_lamport=use_fp32_lamport, + ) + + return _workspace_manager.initialized + + +def flashinfer_allreduce_residual_rmsnorm( + fd_config: FDConfig, + input_tensor: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-6, + max_token_num: int = 4096, + use_oneshot: Optional[bool] = None, + trigger_completion_at_end: bool = False, + fp32_acc: bool = False, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + Use FlashInfer's fused allreduce + residual + RMS norm operation + + Args: + input_tensor: Input tensor that needs allreduce + residual: Residual tensor + weight: RMS norm weight + eps: RMS norm epsilon + max_token_num: Maximum token number + use_oneshot: Whether to use oneshot mode + trigger_completion_at_end: Whether to trigger completion at end + fp32_acc: Whether to use fp32 precision + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: (norm_output, residual_output) + """ + if not has_flashinfer() or _flashinfer_comm is None: + logger.debug("FlashInfer not available, falling back to standard " "implementation") + return None, None + + assert 
fd_config is not None + world_size = fd_config.parallel_config.tensor_parallel_size + if world_size <= 1: + logger.debug("Single GPU, no need for allreduce fusion") + return None, None + + assert input_tensor.shape[0] <= max_token_num + + if not ensure_workspace_initialized( + fd_config=fd_config, + max_token_num=max_token_num, + hidden_dim=input_tensor.shape[-1], + use_fp32_lamport=(input_tensor.dtype == paddle.float32), + ): + logger.debug("FlashInfer workspace not available") + return None, None + + token_num, hidden_dim = input_tensor.shape + + residual_out = paddle.empty_like(residual) + norm_out = paddle.empty_like(input_tensor) + + _flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + world_size=world_size, + world_rank=dist.get_rank(), + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=_workspace_manager.workspace_tensor, + launch_with_pdl=True, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=(_flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm), + allreduce_out=None, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=eps, + scale_factor=None, + layout_code=None, + ) + + return norm_out, residual_out + + +def fake_flashinfer_allreduce_residual_rmsnorm( + input_tensor: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-6, + max_token_num: int = 16384, + use_oneshot: Optional[bool] = None, + trigger_completion_at_end: bool = False, + fp32_acc: bool = False, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + residual_out = paddle.empty_like(residual) + norm_out = paddle.empty_like(input_tensor) + return norm_out, residual_out + + +def cleanup_flashinfer_workspace(): + global _workspace_manager + if _workspace_manager is not None: + _workspace_manager.cleanup() diff --git a/fastdeploy/model_executor/layers/linear.py 
b/fastdeploy/model_executor/layers/linear.py index ebdfe088cc7..5f9e04d3fee 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -857,6 +857,9 @@ def __init__( skip_quant (bool): Whether to skip quantization. Defaults to False. """ self.fd_config = fd_config + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "enable_all_reduce" in prefix + ) self.ep_size = fd_config.parallel_config.expert_parallel_size self.tp_size = fd_config.parallel_config.tensor_parallel_size self.tp_group = fd_config.parallel_config.tp_group @@ -934,7 +937,7 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: out = self.quant_method.apply(self, x) - if self.reduce_results and self.tp_size > 1: + if self.reduce_results and self.tp_size > 1 and not self.enable_all_reduce_fusion: out = tensor_model_parallel_all_reduce(out, self.tp_group) return out diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index b9e3add5a46..73d2139335e 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -31,6 +31,11 @@ from fastdeploy.config import FDConfig from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused +from .batch_invariant_ops import ( + is_batch_invariant_mode_enabled, + rms_norm_batch_invariant, +) +from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm from .utils import get_tensor, modules_to_convert @@ -118,6 +123,10 @@ def __init__( self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank self.tp_group = self.fd_config.parallel_config.tp_group is_input_norm = prefix.endswith(".input_layernorm") + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and "post_attention_layernorm" in prefix + ) + self.is_last_norm = prefix.endswith(".norm") self.split_x 
= ( self.fd_config.parallel_config.use_sequence_parallel_moe @@ -236,6 +245,11 @@ def forward( norm_out = rms_norm(x, self.weight, self.eps) return norm_out.astype(x_dtype), residual_out norm_out = self.norm_func(x, residual_input, self.weight, self.eps) + # enable trtllm all reduce fusion + elif self.enable_all_reduce_fusion: + norm_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps + ) else: norm_out = self.norm_func( x, diff --git a/fastdeploy/model_executor/layers/quantization/mxfp4.py b/fastdeploy/model_executor/layers/quantization/mxfp4.py index e64dc10b76b..ebe992364b1 100644 --- a/fastdeploy/model_executor/layers/quantization/mxfp4.py +++ b/fastdeploy/model_executor/layers/quantization/mxfp4.py @@ -14,8 +14,6 @@ # limitations under the License. """ -import importlib -import importlib.util import math from enum import Enum from typing import Callable, Optional @@ -25,11 +23,12 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase -from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch + from fastdeploy.utils import get_logger from ..moe import FusedMoE @@ -59,10 +58,6 @@ def check_device_capability(num): return False -def has_flashinfer(): - return importlib.util.find_spec("flashinfer") is not None - - def round_up(a, b): return ((a + b - 1) // b) * b diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index 20d86fbaaf8..fd288f19fcc 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -127,7 +127,7 @@ def __init__( self.tensor_parallel_size = 
fd_config.parallel_config.tensor_parallel_size self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank self.tp_group = fd_config.parallel_config.tp_group - + self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion self.use_ep = self.expert_parallel_size > 1 self.use_tp = self.tensor_parallel_size > 1 @@ -210,7 +210,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None self.o_proj = RowParallelLinear( fd_config, - prefix=f"{prefix}.o_proj", + prefix=f"{prefix}.enable_all_reduce.o_proj", input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim, output_size=fd_config.model_config.hidden_size, layer_id=layer_id, @@ -285,7 +285,7 @@ def __init__( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, - prefix=f"{prefix}.input_layernorm", + prefix=f"{prefix}.enable_all_reduce_fusion.input_layernorm", layer_id=layer_id, ) @@ -293,7 +293,7 @@ def __init__( fd_config, hidden_size=fd_config.model_config.hidden_size, eps=fd_config.model_config.rms_norm_eps, - prefix=f"{prefix}.post_attention_layernorm", + prefix=f"{prefix}.enable_all_reduce_fusion.post_attention_layernorm", layer_id=layer_id, ) diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index a2d9bb3205e..a25213735c8 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -16,6 +16,8 @@ import os import re +import importlib +import importlib.util from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass, field @@ -614,6 +616,8 @@ def reconstruct_memory(model): paddle.device.cuda.empty_cache() _reload_model(model) +def has_flashinfer(): + return importlib.util.find_spec("flashinfer") is not None def need_memory_reconstruction(fd_config): _need_memory_reconstruction_archs = ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"] diff --git 
a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 5eccb9f92da..e39ba2278b3 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -831,6 +831,12 @@ def parse_args(): default=None, help="Configuration of SpeculativeConfig.", ) + parser.add_argument( + "--enable_flashinfer_allreduce_fusion", + action="store_true", + default=False, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parser.add_argument( "--max_num_batched_tokens", type=int, diff --git a/tests/layers/test_rms_allreduce_fusion.py b/tests/layers/test_rms_allreduce_fusion.py new file mode 100644 index 00000000000..2a5a3153c55 --- /dev/null +++ b/tests/layers/test_rms_allreduce_fusion.py @@ -0,0 +1,249 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import time +import unittest +from unittest.mock import Mock + +import numpy as np +import paddle +import paddle.distributed as dist + + +class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase): + """测试 FlashInfer AllReduce + Residual + RMSNorm 融合算子""" + + @classmethod + def setUpClass(cls): + """设置测试环境""" + if paddle.is_compiled_with_cuda(): + paddle.set_device("gpu") + else: + paddle.set_device("cpu") + dist.init_parallel_env() + + def setUp(self): + """每个测试用例的初始化""" + # 固定随机种子,确保可复现性 + paddle.seed(42) + np.random.seed(42) + + self.dtype = paddle.float32 + self.token_num = 128 + self.hidden_dim = 768 + self.eps = 1e-6 + self.epsilon = 1e-6 + self.max_token_num = 2048 + + # 创建 mock FDConfig + self.fd_config = Mock() + self.fd_config.parallel_config = Mock() + self.fd_config.parallel_config.tensor_parallel_size = dist.get_world_size() + self.begin_norm_axis = 1 + + # 性能测试参数 - 增加迭代次数提高稳定性 + self.warmup_iterations = 20 # 增加warmup + self.test_iterations = 200 # 增加测试迭代 + + def tearDown(self): + """清理资源""" + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.empty_cache() + paddle.device.cuda.synchronize() + + def create_test_tensors(self): + """创建测试用的张量""" + input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + residual = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + weight = paddle.randn([self.hidden_dim], dtype=self.dtype) + return input_tensor, residual, weight + + def compute_reference_output(self, input_tensor, residual, weight, eps): + """参考实现:手动计算 AllReduce + Residual + RMSNorm""" + # # Step 1: AllReduce (在单卡情况下就是原值) + # allreduce_out = input_tensor.clone() + # 添加 all reduce 算子 + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + # Step 2: Add residual + residual_out = input_tensor + residual + + # Step 3: RMSNorm + variance = residual_out.pow(2).mean(axis=-1, keepdim=True) + norm_out = residual_out * paddle.rsqrt(variance + eps) + norm_out = norm_out * weight + + # 
dist.all_reduce(residual_out, op=dist.ReduceOp.SUM) + return norm_out, residual_out + + def paddle_rms_fuse(self, input_tensor, residual, weight, eps): + from paddle.incubate.nn.functional import fused_rms_norm + + # 添加 all reduce 算子 + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + out_fused = fused_rms_norm( + input_tensor, + norm_weight=weight, + norm_bias=None, + epsilon=eps, + begin_norm_axis=self.begin_norm_axis, + bias=None, + residual=residual, + ) + + return out_fused[0], out_fused[1] + + def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps): + """FlashInfer融合算子""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=eps, + max_token_num=self.max_token_num, + use_oneshot=False, + ) + return norm_out, residual_out + + def benchmark_function(self, func, *args, name="", **kwargs): + """ + 改进的性能基准测试 + - 增加GPU频率稳定等待 + - 使用中位数而非平均值(更稳定) + - 过滤异常值 + """ + # 强制GPU频率稳定 + if paddle.is_compiled_with_cuda(): + for _ in range(5): + paddle.device.cuda.synchronize() + time.sleep(0.01) + + # Warmup - 充分预热 + for _ in range(self.warmup_iterations): + result = func(*args, **kwargs) + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + # 额外等待,确保GPU稳定 + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + time.sleep(0.1) + + # 正式测试 + times = [] + for i in range(self.test_iterations): + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + start = time.perf_counter() + result = func(*args, **kwargs) + + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + end = time.perf_counter() + elapsed = (end - start) * 1000 # 转换为毫秒 + times.append(elapsed) + + times = np.array(times) + + # 使用IQR方法过滤异常值 + q1, q3 = np.percentile(times, [25, 75]) + iqr = q3 - q1 + 
lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + filtered_times = times[(times >= lower_bound) & (times <= upper_bound)] + + # 如果过滤后数据太少,使用原始数据 + if len(filtered_times) < self.test_iterations * 0.5: + filtered_times = times + + # 统计信息 + avg_time = np.mean(filtered_times) + median_time = np.median(filtered_times) + std_time = np.std(filtered_times) + min_time = np.min(filtered_times) + max_time = np.max(filtered_times) + cv = (std_time / avg_time) * 100 # 变异系数 (%) + + print(f"\n{'='*70}") + print(f"Performance Benchmark: {name}") + print(f"{'='*70}") + print(f"Iterations: {len(filtered_times)}/{self.test_iterations} (after {self.warmup_iterations} warmup)") + print(f"Median: {median_time:.4f} ms (most stable metric)") + print(f"Average: {avg_time:.4f} ms") + print(f"Std Dev: {std_time:.4f} ms (CV: {cv:.2f}%)") + print(f"Min: {min_time:.4f} ms") + print(f"Max: {max_time:.4f} ms") + print(f"{'='*70}\n") + + # 返回中位数(更稳定)和结果 + return median_time, result + + def test_accuracy_fused_vs_reference(self): + """测试融合算子与参考实现的准确性""" + input_tensor, residual, weight = self.create_test_tensors() + reference_output, ref_res = self.compute_reference_output( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + fused_output, paddle_res = self.paddle_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + # 验证结果 + np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref_res.numpy(), flashinfer_res.numpy(), rtol=1e-5, atol=1e-5) + + +class TestFlashInferWorkspaceManager(unittest.TestCase): + """测试 FlashInferWorkspaceManager""" + + def 
setUp(self): + """初始化""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + self.manager = FlashInferWorkspaceManager() + + def test_initialization(self): + """测试初始化状态""" + self.assertIsNone(self.manager.workspace_tensor) + self.assertIsNone(self.manager.ipc_handles) + self.assertIsNone(self.manager.world_size) + self.assertIsNone(self.manager.rank) + self.assertFalse(self.manager.initialized) + + def test_cleanup(self): + """测试清理功能""" + self.manager.cleanup() + self.assertFalse(self.manager.initialized) + self.assertIsNone(self.manager.workspace_tensor) + + +if __name__ == "__main__": + unittest.main(verbosity=2) + + # 多卡运行示例: + # python -m paddle.distributed.launch --gpus=0,1 test_flashinfer.py From 6bfe1d97338d8e73c8ddd3a5fa3e266e3628d78d Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 5 Mar 2026 11:22:55 +0800 Subject: [PATCH 02/18] update flashinfer paddle version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a6a7b6619c9..bed4165b191 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,7 +46,7 @@ setproctitle aistudio_sdk p2pstore py-cpuinfo -flashinfer-python-paddle +flashinfer-python-paddle @ @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl transformers>=4.55.1,<5.0.0 From e9dbd825644ea0324654fe1bd0ac225e6d5811ae Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 5 Mar 2026 11:25:56 +0800 Subject: [PATCH 03/18] format update --- fastdeploy/model_executor/layers/flashinfer_comm_fusion.py | 1 + fastdeploy/model_executor/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git 
a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py index 4cdca3c89ce..a922e262b4d 100644 --- a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py +++ b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py @@ -19,6 +19,7 @@ try: paddle.compat.enable_torch_proxy(scope={"flashinfer"}) import flashinfer.comm as comm + _flashinfer_comm = comm except ImportError: logger.warning("flashinfer.comm is not available, falling back to standard " "implementation") diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index a25213735c8..1a03b1bc71e 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -14,10 +14,10 @@ # limitations under the License. """ -import os -import re import importlib import importlib.util +import os +import re from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass, field From bb80bd8dbb8f315673e90b36e5f1dd52787da4cf Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 5 Mar 2026 13:33:35 +0800 Subject: [PATCH 04/18] fix a bug --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bed4165b191..78f006dec53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,7 +46,7 @@ setproctitle aistudio_sdk p2pstore py-cpuinfo -flashinfer-python-paddle @ @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl +flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.2-py3-none-any.whl flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl transformers>=4.55.1,<5.0.0 From 756e6d85316e8ff7d444f64bc612a0802d543783 Mon Sep 17 00:00:00 2001 
From: Bingoo <1575938147@qq.com> Date: Thu, 5 Mar 2026 16:30:52 +0800 Subject: [PATCH 05/18] modify test --- tests/layers/test_rms_allreduce_fusion.py | 127 +++++++++++++++------- 1 file changed, 89 insertions(+), 38 deletions(-) diff --git a/tests/layers/test_rms_allreduce_fusion.py b/tests/layers/test_rms_allreduce_fusion.py index 2a5a3153c55..c6aed9729a2 100644 --- a/tests/layers/test_rms_allreduce_fusion.py +++ b/tests/layers/test_rms_allreduce_fusion.py @@ -24,11 +24,11 @@ class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase): - """测试 FlashInfer AllReduce + Residual + RMSNorm 融合算子""" + """Test FlashInfer AllReduce + Residual + RMSNorm fused operator""" @classmethod def setUpClass(cls): - """设置测试环境""" + """Set up test environment""" if paddle.is_compiled_with_cuda(): paddle.set_device("gpu") else: @@ -36,8 +36,8 @@ def setUpClass(cls): dist.init_parallel_env() def setUp(self): - """每个测试用例的初始化""" - # 固定随机种子,确保可复现性 + """Initialize each test case""" + # Fix random seed for reproducibility paddle.seed(42) np.random.seed(42) @@ -48,34 +48,34 @@ def setUp(self): self.epsilon = 1e-6 self.max_token_num = 2048 - # 创建 mock FDConfig + # Create mock FDConfig self.fd_config = Mock() self.fd_config.parallel_config = Mock() self.fd_config.parallel_config.tensor_parallel_size = dist.get_world_size() self.begin_norm_axis = 1 - # 性能测试参数 - 增加迭代次数提高稳定性 - self.warmup_iterations = 20 # 增加warmup - self.test_iterations = 200 # 增加测试迭代 + # Performance test params - increase iterations for stability + self.warmup_iterations = 20 # Increase warmup + self.test_iterations = 200 # Increase test iterations def tearDown(self): - """清理资源""" + """Clean up resources""" if paddle.is_compiled_with_cuda(): paddle.device.cuda.empty_cache() paddle.device.cuda.synchronize() def create_test_tensors(self): - """创建测试用的张量""" + """Create test tensors""" input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) residual = paddle.randn([self.token_num, self.hidden_dim], 
dtype=self.dtype) weight = paddle.randn([self.hidden_dim], dtype=self.dtype) return input_tensor, residual, weight def compute_reference_output(self, input_tensor, residual, weight, eps): - """参考实现:手动计算 AllReduce + Residual + RMSNorm""" - # # Step 1: AllReduce (在单卡情况下就是原值) + """Reference implementation: manually compute AllReduce + Residual + RMSNorm""" + # # Step 1: AllReduce (identity on single device) # allreduce_out = input_tensor.clone() - # 添加 all reduce 算子 + # Apply all reduce operator dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) # Step 2: Add residual residual_out = input_tensor + residual @@ -91,7 +91,7 @@ def compute_reference_output(self, input_tensor, residual, weight, eps): def paddle_rms_fuse(self, input_tensor, residual, weight, eps): from paddle.incubate.nn.functional import fused_rms_norm - # 添加 all reduce 算子 + # Apply all reduce operator dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) out_fused = fused_rms_norm( input_tensor, @@ -106,7 +106,7 @@ def paddle_rms_fuse(self, input_tensor, residual, weight, eps): return out_fused[0], out_fused[1] def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps): - """FlashInfer融合算子""" + """FlashInfer fused operator""" from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( flashinfer_allreduce_residual_rmsnorm, ) @@ -124,29 +124,29 @@ def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps): def benchmark_function(self, func, *args, name="", **kwargs): """ - 改进的性能基准测试 - - 增加GPU频率稳定等待 - - 使用中位数而非平均值(更稳定) - - 过滤异常值 + Improved performance benchmark + - Wait for GPU frequency stabilization + - Use median instead of mean (more stable) + - Filter outliers """ - # 强制GPU频率稳定 + # Force GPU frequency stabilization if paddle.is_compiled_with_cuda(): for _ in range(5): paddle.device.cuda.synchronize() time.sleep(0.01) - # Warmup - 充分预热 + # Warmup - thorough warm-up for _ in range(self.warmup_iterations): result = func(*args, **kwargs) if paddle.is_compiled_with_cuda(): 
paddle.device.cuda.synchronize() - # 额外等待,确保GPU稳定 + # Extra wait to ensure GPU stability if paddle.is_compiled_with_cuda(): paddle.device.cuda.synchronize() time.sleep(0.1) - # 正式测试 + # Benchmark run times = [] for i in range(self.test_iterations): if paddle.is_compiled_with_cuda(): @@ -159,29 +159,29 @@ def benchmark_function(self, func, *args, name="", **kwargs): paddle.device.cuda.synchronize() end = time.perf_counter() - elapsed = (end - start) * 1000 # 转换为毫秒 + elapsed = (end - start) * 1000 # Convert to milliseconds times.append(elapsed) times = np.array(times) - # 使用IQR方法过滤异常值 + # Filter outliers using IQR method q1, q3 = np.percentile(times, [25, 75]) iqr = q3 - q1 lower_bound = q1 - 1.5 * iqr upper_bound = q3 + 1.5 * iqr filtered_times = times[(times >= lower_bound) & (times <= upper_bound)] - # 如果过滤后数据太少,使用原始数据 + # Fall back to raw data if too many samples filtered out if len(filtered_times) < self.test_iterations * 0.5: filtered_times = times - # 统计信息 + # Statistics avg_time = np.mean(filtered_times) median_time = np.median(filtered_times) std_time = np.std(filtered_times) min_time = np.min(filtered_times) max_time = np.max(filtered_times) - cv = (std_time / avg_time) * 100 # 变异系数 (%) + cv = (std_time / avg_time) * 100 # Coefficient of variation (%) print(f"\n{'='*70}") print(f"Performance Benchmark: {name}") @@ -194,11 +194,11 @@ def benchmark_function(self, func, *args, name="", **kwargs): print(f"Max: {max_time:.4f} ms") print(f"{'='*70}\n") - # 返回中位数(更稳定)和结果 + # Return median (more stable) and result return median_time, result def test_accuracy_fused_vs_reference(self): - """测试融合算子与参考实现的准确性""" + """Test accuracy of fused operator vs reference implementation""" input_tensor, residual, weight = self.create_test_tensors() reference_output, ref_res = self.compute_reference_output( input_tensor.clone(), residual.clone(), weight.clone(), self.eps @@ -209,7 +209,7 @@ def test_accuracy_fused_vs_reference(self): flashinfer_output, flashinfer_res = 
self.flashinfer_rms_fuse( input_tensor.clone(), residual.clone(), weight.clone(), self.eps ) - # 验证结果 + # Verify results np.testing.assert_allclose(fused_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) np.testing.assert_allclose(ref_res.numpy(), paddle_res.numpy(), rtol=1e-5, atol=1e-5) np.testing.assert_allclose(flashinfer_output.numpy(), reference_output.numpy(), rtol=1e-5, atol=1e-5) @@ -217,10 +217,10 @@ def test_accuracy_fused_vs_reference(self): class TestFlashInferWorkspaceManager(unittest.TestCase): - """测试 FlashInferWorkspaceManager""" + """Test FlashInferWorkspaceManager""" def setUp(self): - """初始化""" + """Initialize""" from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( FlashInferWorkspaceManager, ) @@ -228,7 +228,7 @@ def setUp(self): self.manager = FlashInferWorkspaceManager() def test_initialization(self): - """测试初始化状态""" + """Test initialization state""" self.assertIsNone(self.manager.workspace_tensor) self.assertIsNone(self.manager.ipc_handles) self.assertIsNone(self.manager.world_size) @@ -236,14 +236,65 @@ def test_initialization(self): self.assertFalse(self.manager.initialized) def test_cleanup(self): - """测试清理功能""" + """Test cleanup functionality""" self.manager.cleanup() self.assertFalse(self.manager.initialized) self.assertIsNone(self.manager.workspace_tensor) -if __name__ == "__main__": +def run_tests(): + """Run tests directly (called by subprocess after distributed launch)""" unittest.main(verbosity=2) - # 多卡运行示例: - # python -m paddle.distributed.launch --gpus=0,1 test_flashinfer.py + +def check_gpus(gpu_ids): + """Check whether the specified GPUs are available, raise error if not""" + import paddle + + if not paddle.is_compiled_with_cuda(): + raise RuntimeError("Paddle is not compiled with CUDA support, cannot use GPU") + + available_count = paddle.device.cuda.device_count() + if available_count == 0: + raise RuntimeError("No available GPU detected") + + missing = [gid for gid in gpu_ids if gid >= 
available_count] + if missing: + raise RuntimeError( + f"Required GPU {missing}, but only {available_count} GPU(s) detected (index 0~{available_count - 1})" + ) + + print(f"GPU check passed: required {gpu_ids}, {available_count} GPU(s) available") + + +def run_distributed(): + """Launch multi-GPU distributed test via paddle.distributed.launch as subprocess""" + import os + import subprocess + import sys + + gpu_ids = [0, 1] + check_gpus(gpu_ids) + + gpus_str = ",".join(str(g) for g in gpu_ids) + script_path = os.path.abspath(__file__) + cmd = [ + sys.executable, "-m", "paddle.distributed.launch", + f"--gpus={gpus_str}", + script_path, "--run-tests", + ] + print(f"Launching distributed test: {' '.join(cmd)}") + result = subprocess.run(cmd, cwd=os.path.dirname(script_path)) + sys.exit(result.returncode) + + +if __name__ == "__main__": + import sys + + if "--run-tests" in sys.argv: + # Launched by paddle.distributed.launch, run tests directly + sys.argv.remove("--run-tests") + run_tests() + else: + # Default entry: launch distributed test as subprocess + run_distributed() From fe3df9480ff471ff078276889d8515226b5ff87d Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 5 Mar 2026 16:32:04 +0800 Subject: [PATCH 06/18] modify test --- tests/layers/test_rms_allreduce_fusion.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/layers/test_rms_allreduce_fusion.py b/tests/layers/test_rms_allreduce_fusion.py index c6aed9729a2..a0cdc5e268b 100644 --- a/tests/layers/test_rms_allreduce_fusion.py +++ b/tests/layers/test_rms_allreduce_fusion.py @@ -279,9 +279,12 @@ def run_distributed(): gpus_str = ",".join(str(g) for g in gpu_ids) script_path = os.path.abspath(__file__) cmd = [ - sys.executable, "-m", "paddle.distributed.launch", + sys.executable, + "-m", + "paddle.distributed.launch", f"--gpus={gpus_str}", - script_path, "--run-tests", + script_path, + "--run-tests", ] print(f"Launching distributed test: {' '.join(cmd)}") result = 
subprocess.run(cmd, cwd=os.path.dirname(script_path)) From a46338ce3abdb81e8965cf1dee6c861cd6a08fd1 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 5 Mar 2026 22:03:10 +0800 Subject: [PATCH 07/18] support empty tensor and modify test --- .../layers/flashinfer_comm_fusion.py | 4 +- .../test_trtllm_allreduce_rms_fusion.py | 49 ++++++++++++++++ ...n.py => trtllm_allreduce_rms_fusion.py.py} | 58 +------------------ 3 files changed, 53 insertions(+), 58 deletions(-) create mode 100644 tests/layers/test_trtllm_allreduce_rms_fusion.py rename tests/layers/{test_rms_allreduce_fusion.py => trtllm_allreduce_rms_fusion.py.py} (84%) diff --git a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py index a922e262b4d..d6a8c6e1ba8 100644 --- a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py +++ b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py @@ -161,7 +161,9 @@ def flashinfer_allreduce_residual_rmsnorm( residual_out = paddle.empty_like(residual) norm_out = paddle.empty_like(input_tensor) - + # support empty tensor + if input_tensor.shape[0] == 0: + return norm_out, residual_out _flashinfer_comm.trtllm_allreduce_fusion( allreduce_in=input_tensor, world_size=world_size, diff --git a/tests/layers/test_trtllm_allreduce_rms_fusion.py b/tests/layers/test_trtllm_allreduce_rms_fusion.py new file mode 100644 index 00000000000..13b078f5f7d --- /dev/null +++ b/tests/layers/test_trtllm_allreduce_rms_fusion.py @@ -0,0 +1,49 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import subprocess +import sys + + +def test_run_distributed(): + """Launch multi-GPU distributed test via paddle.distributed.launch as subprocess""" + + current_dir = os.path.dirname(os.path.abspath(__file__)) + run_script = os.path.join(current_dir, "trtllm_allreduce_rms_fusion.py") + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + command = [ + sys.executable, + "-m", + "paddle.distributed.launch", + "--gpus", + "0,1", + run_script, + ] + + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + try: + stdout, stderr = process.communicate(timeout=400) + return_code = process.returncode + except subprocess.TimeoutExpired: + process.kill() + stdout, stderr = process.communicate() + return_code = -1 + assert return_code in (0, 250), f"Process exited with code {return_code}" + + +test_run_distributed() diff --git a/tests/layers/test_rms_allreduce_fusion.py b/tests/layers/trtllm_allreduce_rms_fusion.py.py similarity index 84% rename from tests/layers/test_rms_allreduce_fusion.py rename to tests/layers/trtllm_allreduce_rms_fusion.py.py index a0cdc5e268b..38ccbf07f28 100644 --- a/tests/layers/test_rms_allreduce_fusion.py +++ b/tests/layers/trtllm_allreduce_rms_fusion.py.py @@ -242,62 +242,6 @@ def test_cleanup(self): self.assertIsNone(self.manager.workspace_tensor) -def run_tests(): +if __name__ == "__main__": """Run tests directly (called by subprocess after distributed launch)""" unittest.main(verbosity=2) - - -def check_gpus(gpu_ids): - """Check whether the specified GPUs are available, raise error if 
not""" - import paddle - - if not paddle.is_compiled_with_cuda(): - raise RuntimeError("Paddle is not compiled with CUDA support, cannot use GPU") - - available_count = paddle.device.cuda.device_count() - if available_count == 0: - raise RuntimeError("No available GPU detected") - - missing = [gid for gid in gpu_ids if gid >= available_count] - if missing: - raise RuntimeError( - f"Required GPU {missing}, but only {available_count} GPU(s) detected (index 0~{available_count - 1})" - ) - - print(f"GPU check passed: required {gpu_ids}, {available_count} GPU(s) available") - - -def run_distributed(): - """Launch multi-GPU distributed test via paddle.distributed.launch as subprocess""" - import os - import subprocess - import sys - - gpu_ids = [0, 1] - check_gpus(gpu_ids) - - gpus_str = ",".join(str(g) for g in gpu_ids) - script_path = os.path.abspath(__file__) - cmd = [ - sys.executable, - "-m", - "paddle.distributed.launch", - f"--gpus={gpus_str}", - script_path, - "--run-tests", - ] - print(f"Launching distributed test: {' '.join(cmd)}") - result = subprocess.run(cmd, cwd=os.path.dirname(script_path)) - sys.exit(result.returncode) - - -if __name__ == "__main__": - import sys - - if "--run-tests" in sys.argv: - # Launched by paddle.distributed.launch, run tests directly - sys.argv.remove("--run-tests") - run_tests() - else: - # Default entry: launch distributed test as subprocess - run_distributed() From 1ae1e2799a790c6723cdfe44ac5e32403a1f93a6 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Fri, 6 Mar 2026 10:16:26 +0800 Subject: [PATCH 08/18] fix test_linear config issues From 777edc6293e2498d138f408dea81d6b81c13804b Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Fri, 6 Mar 2026 14:09:34 +0800 Subject: [PATCH 09/18] modify test name --- ..._allreduce_rms_fusion.py.py => trtllm_allreduce_rms_fusion.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/layers/{trtllm_allreduce_rms_fusion.py.py => 
trtllm_allreduce_rms_fusion.py} (100%) diff --git a/tests/layers/trtllm_allreduce_rms_fusion.py.py b/tests/layers/trtllm_allreduce_rms_fusion.py similarity index 100% rename from tests/layers/trtllm_allreduce_rms_fusion.py.py rename to tests/layers/trtllm_allreduce_rms_fusion.py From a3979f7d5d13e1c5dc2dcfed352c7654fd8ba621 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Tue, 10 Mar 2026 10:33:21 +0800 Subject: [PATCH 10/18] add edge test case --- .../test_trtllm_allreduce_rms_fusion.py | 353 ++++++++++++++++++ 1 file changed, 353 insertions(+) diff --git a/tests/layers/test_trtllm_allreduce_rms_fusion.py b/tests/layers/test_trtllm_allreduce_rms_fusion.py index 13b078f5f7d..73acb0a707c 100644 --- a/tests/layers/test_trtllm_allreduce_rms_fusion.py +++ b/tests/layers/test_trtllm_allreduce_rms_fusion.py @@ -17,6 +17,10 @@ import os import subprocess import sys +import unittest +from unittest.mock import Mock, patch + +import paddle def test_run_distributed(): @@ -47,3 +51,352 @@ def test_run_distributed(): test_run_distributed() + + +class TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase): + """Test FlashInferWorkspaceManager edge cases and fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + # Patch before importing to test fallback paths + self.patcher_has_flashinfer = patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer" + ) + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_initialization_early_return_when_already_initialized(self): + """Test line 47: early return when already initialized with same world_size""" + # Patch _flashinfer_comm to be available + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" + ) as mock_comm: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = 
FlashInferWorkspaceManager() + + # First initialization + manager.initialized = True + manager.world_size = 2 + + # Mock the comm functions + mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock( + return_value=(Mock(), Mock()) + ) + + # Second initialization with same world_size - should return early + manager.initialize( + world_size=2, + rank=0, + max_token_num=2048, + hidden_dim=4096, + ) + + def test_initialization_warning_when_comm_none(self): + """Test lines 50-51: warning when _flashinfer_comm is None""" + # Patch to ensure _flashinfer_comm is None + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm", + None, + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + + # Should not raise, just log warning and return + manager.initialize( + world_size=2, + rank=0, + max_token_num=2048, + hidden_dim=4096, + ) + + # Verify not initialized + self.assertFalse(manager.initialized) + + def test_cleanup_with_exception(self): + """Test lines 73-80: cleanup with exception handling""" + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" + ) as mock_comm: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + manager.initialized = True + manager.ipc_handles = Mock() + manager.workspace_tensor = Mock() + + # Mock the destroy function to raise exception + mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock( + side_effect=RuntimeError("Cleanup error") + ) + + # Should not raise, just log warning + manager.cleanup() + + # Verify cleanup happened + self.assertFalse(manager.initialized) + self.assertIsNone(manager.workspace_tensor) + self.assertIsNone(manager.ipc_handles) + + def test_cleanup_without_initialization(self): + """Test cleanup when not initialized""" + from 
fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + manager.initialized = False + + # Should not raise + manager.cleanup() + + # Verify state + self.assertFalse(manager.initialized) + + +class TestEnsureWorkspaceInitialized(unittest.TestCase): + """Test ensure_workspace_initialized fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + self.patcher_has_flashinfer = patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer" + ) + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_ensure_workspace_when_flashinfer_not_available(self): + """Test line 91: early return when flashinfer not available""" + self.mock_has_flashinfer.return_value = False + + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + result = ensure_workspace_initialized(fd_config) + + # Should return False (not initialized) + self.assertFalse(result) + + def test_ensure_workspace_when_comm_none(self): + """Test ensure_workspace_initialized when _flashinfer_comm is None""" + self.mock_has_flashinfer.return_value = True + + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm", + None, + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + result = ensure_workspace_initialized(fd_config) + + # Should return False + self.assertFalse(result) + + def test_ensure_workspace_single_gpu(self): + """Test line 96: early return when world_size <= 1""" + self.mock_has_flashinfer.return_value = True 
+ + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 1 + + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.dist.get_rank", return_value=0): + result = ensure_workspace_initialized(fd_config) + + # Should return False for single GPU + self.assertFalse(result) + + +class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase): + """Test flashinfer_allreduce_residual_rmsnorm fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + self.patcher_has_flashinfer = patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer" + ) + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_flashinfer_not_available_fallback(self): + """Test lines 140-141: fallback when flashinfer not available""" + self.mock_has_flashinfer.return_value = False + + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + input_tensor = paddle.randn([128, 768]) + residual = paddle.randn([128, 768]) + weight = paddle.randn([768]) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return None, None when flashinfer not available + self.assertIsNone(norm_out) + self.assertIsNone(residual_out) + + def test_single_gpu_fallback(self): + """Test lines 146-147: fallback for single GPU""" + self.mock_has_flashinfer.return_value = True + + with 
patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 1 + + input_tensor = paddle.randn([128, 768]) + residual = paddle.randn([128, 768]) + weight = paddle.randn([768]) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return None, None for single GPU + self.assertIsNone(norm_out) + self.assertIsNone(residual_out) + + def test_empty_tensor_handling(self): + """Test line 166: empty tensor handling""" + self.mock_has_flashinfer.return_value = True + + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" + ) as mock_comm, patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized", + return_value=True, + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + # Empty tensor (0 tokens) + input_tensor = paddle.zeros([0, 768]) + residual = paddle.zeros([0, 768]) + weight = paddle.randn([768]) + + # Mock the trtllm_allreduce_fusion to not be called + mock_comm.trtllm_allreduce_fusion = Mock() + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return empty tensors, not call flashinfer + self.assertEqual(norm_out.shape[0], 0) + self.assertEqual(residual_out.shape[0], 0) + mock_comm.trtllm_allreduce_fusion.assert_not_called() + + +class 
TestFakeFlashInferFunction(unittest.TestCase): + """Test fake_flashinfer_allreduce_residual_rmsnorm function""" + + def test_fake_function_basic(self): + """Test lines 204-206: fake function basic functionality""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + fake_flashinfer_allreduce_residual_rmsnorm, + ) + + input_tensor = paddle.randn([128, 768]) + residual = paddle.randn([128, 768]) + weight = paddle.randn([768]) + + norm_out, residual_out = fake_flashinfer_allreduce_residual_rmsnorm( + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=16384, + use_oneshot=None, + trigger_completion_at_end=False, + fp32_acc=False, + ) + + # Should return empty-like tensors + self.assertEqual(norm_out.shape, input_tensor.shape) + self.assertEqual(residual_out.shape, residual.shape) + + +class TestCleanupFlashInferWorkspace(unittest.TestCase): + """Test cleanup_flashinfer_workspace function""" + + def test_cleanup_workspace_function(self): + """Test lines 211-212: cleanup function""" + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager" + ) as mock_manager: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + cleanup_flashinfer_workspace, + ) + + mock_manager.cleanup = Mock() + + cleanup_flashinfer_workspace() + + mock_manager.cleanup.assert_called_once() From cbe082d4cf5baa60435a8dbf38cd1044df754d73 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Tue, 10 Mar 2026 11:13:55 +0800 Subject: [PATCH 11/18] modify format --- .../test_trtllm_allreduce_rms_fusion.py | 51 ++++++------------- 1 file changed, 16 insertions(+), 35 deletions(-) diff --git a/tests/layers/test_trtllm_allreduce_rms_fusion.py b/tests/layers/test_trtllm_allreduce_rms_fusion.py index 73acb0a707c..edf6ca8e710 100644 --- a/tests/layers/test_trtllm_allreduce_rms_fusion.py +++ b/tests/layers/test_trtllm_allreduce_rms_fusion.py @@ -59,9 +59,7 @@ class 
TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase): def setUp(self): """Initialize test fixtures""" # Patch before importing to test fallback paths - self.patcher_has_flashinfer = patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer" - ) + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") self.mock_has_flashinfer = self.patcher_has_flashinfer.start() def tearDown(self): @@ -71,9 +69,7 @@ def tearDown(self): def test_initialization_early_return_when_already_initialized(self): """Test line 47: early return when already initialized with same world_size""" # Patch _flashinfer_comm to be available - with patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" - ) as mock_comm: + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm: from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( FlashInferWorkspaceManager, ) @@ -85,9 +81,7 @@ def test_initialization_early_return_when_already_initialized(self): manager.world_size = 2 # Mock the comm functions - mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock( - return_value=(Mock(), Mock()) - ) + mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock(return_value=(Mock(), Mock())) # Second initialization with same world_size - should return early manager.initialize( @@ -123,9 +117,7 @@ def test_initialization_warning_when_comm_none(self): def test_cleanup_with_exception(self): """Test lines 73-80: cleanup with exception handling""" - with patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" - ) as mock_comm: + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm: from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( FlashInferWorkspaceManager, ) @@ -136,9 +128,7 @@ def test_cleanup_with_exception(self): 
manager.workspace_tensor = Mock() # Mock the destroy function to raise exception - mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock( - side_effect=RuntimeError("Cleanup error") - ) + mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock(side_effect=RuntimeError("Cleanup error")) # Should not raise, just log warning manager.cleanup() @@ -169,9 +159,7 @@ class TestEnsureWorkspaceInitialized(unittest.TestCase): def setUp(self): """Initialize test fixtures""" - self.patcher_has_flashinfer = patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer" - ) + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") self.mock_has_flashinfer = self.patcher_has_flashinfer.start() def tearDown(self): @@ -220,9 +208,7 @@ def test_ensure_workspace_single_gpu(self): """Test line 96: early return when world_size <= 1""" self.mock_has_flashinfer.return_value = True - with patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" - ): + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"): from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( ensure_workspace_initialized, ) @@ -243,9 +229,7 @@ class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase): def setUp(self): """Initialize test fixtures""" - self.patcher_has_flashinfer = patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer" - ) + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") self.mock_has_flashinfer = self.patcher_has_flashinfer.start() def tearDown(self): @@ -285,9 +269,7 @@ def test_single_gpu_fallback(self): """Test lines 146-147: fallback for single GPU""" self.mock_has_flashinfer.return_value = True - with patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" - ): + with 
patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"): from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( flashinfer_allreduce_residual_rmsnorm, ) @@ -317,11 +299,12 @@ def test_empty_tensor_handling(self): """Test line 166: empty tensor handling""" self.mock_has_flashinfer.return_value = True - with patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm" - ) as mock_comm, patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized", - return_value=True, + with ( + patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm, + patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized", + return_value=True, + ), ): from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( flashinfer_allreduce_residual_rmsnorm, @@ -388,9 +371,7 @@ class TestCleanupFlashInferWorkspace(unittest.TestCase): def test_cleanup_workspace_function(self): """Test lines 211-212: cleanup function""" - with patch( - "fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager" - ) as mock_manager: + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager") as mock_manager: from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( cleanup_flashinfer_workspace, ) From 4e52802b7c4730bc5a5e691968202c11ebbc3a60 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Tue, 10 Mar 2026 14:08:33 +0800 Subject: [PATCH 12/18] fix conflict --- fastdeploy/model_executor/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 1a03b1bc71e..8543699226b 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -555,6 +555,10 @@ def fn(loaded_weight_name, is_moe): return fn +def has_flashinfer(): + return 
importlib.util.find_spec("flashinfer") is not None + + @cache def get_sm_version(): if paddle.cuda.is_available(): @@ -616,8 +620,6 @@ def reconstruct_memory(model): paddle.device.cuda.empty_cache() _reload_model(model) -def has_flashinfer(): - return importlib.util.find_spec("flashinfer") is not None def need_memory_reconstruction(fd_config): _need_memory_reconstruction_archs = ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"] @@ -628,3 +630,6 @@ def need_memory_reconstruction(fd_config): return True else: return False + +def has_flashinfer(): + return importlib.util.find_spec("flashinfer") is not None From 2d18ba6fe889e67c19e0f70df4038d7e0bc51453 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Wed, 11 Mar 2026 20:43:34 +0800 Subject: [PATCH 13/18] modify default max token num in trtllm_allreduce_fusion --- fastdeploy/model_executor/layers/flashinfer_comm_fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py index d6a8c6e1ba8..b9b84938416 100644 --- a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py +++ b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py @@ -115,7 +115,7 @@ def flashinfer_allreduce_residual_rmsnorm( residual: paddle.Tensor, weight: paddle.Tensor, eps: float = 1e-6, - max_token_num: int = 4096, + max_token_num: int = 2048, use_oneshot: Optional[bool] = None, trigger_completion_at_end: bool = False, fp32_acc: bool = False, From 9b01b07c15425cb1e9b58abca5cbec9e69bde730 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Wed, 11 Mar 2026 21:38:31 +0800 Subject: [PATCH 14/18] add max token num branch for trtllm_allreduce_fusion --- fastdeploy/model_executor/layers/linear.py | 5 ++++- fastdeploy/model_executor/layers/normalization.py | 2 +- fastdeploy/model_executor/models/glm4_moe.py | 1 - 3 files changed, 5 insertions(+), 3 deletions(-) diff --git 
a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 5f9e04d3fee..67888891511 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -937,7 +937,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: out = self.quant_method.apply(self, x) - if self.reduce_results and self.tp_size > 1 and not self.enable_all_reduce_fusion: + need_tp_all_reduce = ( + self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048) + ) + if need_tp_all_reduce: out = tensor_model_parallel_all_reduce(out, self.tp_group) return out diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 73d2139335e..f9b4d42c7d5 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -246,7 +246,7 @@ def forward( return norm_out.astype(x_dtype), residual_out norm_out = self.norm_func(x, residual_input, self.weight, self.eps) # enable trtllm all reduce fusion - elif self.enable_all_reduce_fusion: + elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: norm_out = flashinfer_allreduce_residual_rmsnorm( fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps ) diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index fd288f19fcc..0995cde9407 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -127,7 +127,6 @@ def __init__( self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank self.tp_group = fd_config.parallel_config.tp_group - self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion self.use_ep = self.expert_parallel_size > 1 self.use_tp = 
self.tensor_parallel_size > 1 From 9e84a16bba5253adda8f36906876236c7af87679 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Tue, 24 Mar 2026 11:26:25 +0800 Subject: [PATCH 15/18] fix format --- fastdeploy/model_executor/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 8543699226b..211467df57d 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -630,6 +630,3 @@ def need_memory_reconstruction(fd_config): return True else: return False - -def has_flashinfer(): - return importlib.util.find_spec("flashinfer") is not None From f825e4b059ef58be971887986e89d70f99088b7b Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Thu, 26 Mar 2026 15:35:58 +0800 Subject: [PATCH 16/18] fix rmsnorm config issue --- .../test_rmsnorm_layer_batch_invariant.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py diff --git a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py new file mode 100644 index 00000000000..54ce40ca5d4 --- /dev/null +++ b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py @@ -0,0 +1,102 @@ +"""Test RMSNorm layer's batch_invariant_mode forward path (normalization.py:244-248). + +This covers the integration between the RMSNorm *layer* and the Triton +rms_norm_batch_invariant kernel when batch_invariant_mode is enabled. +We bypass RMSNorm.__init__ (heavy FDConfig dependency) and set only +the attributes needed by forward(). 
+""" + +import unittest + +import paddle + +from fastdeploy.model_executor.layers.batch_invariant_ops import ( + rms_norm_batch_invariant, + set_batch_invariant_mode, +) +from fastdeploy.model_executor.layers.normalization import RMSNorm + + +def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"): + """Create a minimal RMSNorm without FDConfig by bypassing __init__.""" + layer = object.__new__(RMSNorm) + paddle.nn.Layer.__init__(layer) + # Attributes used by forward() + layer.weight = paddle.create_parameter( + shape=[hidden_size], + dtype=dtype, + default_initializer=paddle.nn.initializer.Constant(value=1.0), + ) + layer.eps = eps + layer.bias = None + layer.split_x = False + layer.allgather_out = False + layer.enable_all_reduce_fusion = False + return layer + + +class TestRMSNormBatchInvariantPath(unittest.TestCase): + """Test RMSNorm.forward with batch_invariant_mode enabled.""" + + def setUp(self): + paddle.set_device("gpu") + + def test_no_residual(self): + """batch_invariant path without residual_input.""" + D = 1024 + layer = _make_minimal_rmsnorm(D, dtype="float32") + paddle.seed(42) + x = paddle.randn([16, D], dtype="float32") + + with set_batch_invariant_mode(True): + out, residual_out = layer.forward(x, residual_input=None) + + # residual_out should be x itself (line 236: residual_out = x) + expected_norm = rms_norm_batch_invariant(x, layer.weight, layer.eps) + paddle.device.synchronize() + self.assertEqual(out.shape, [16, D]) + diff = (out.astype("float32") - expected_norm.astype("float32")).abs().max().item() + self.assertEqual(diff, 0.0, f"Output mismatch: diff={diff}") + + def test_with_residual(self): + """batch_invariant path with residual_input (covers lines 246-248).""" + D = 1024 + layer = _make_minimal_rmsnorm(D, dtype="float32") + paddle.seed(42) + x = paddle.randn([16, D], dtype="float32") + residual = paddle.randn([16, D], dtype="float32") + + with set_batch_invariant_mode(True): + out, residual_out = layer.forward(x, 
residual_input=residual) + + # Expected: x + residual -> rms_norm_batch_invariant, residual_out = x + residual + fused_x = x + residual + expected_norm = rms_norm_batch_invariant(fused_x, layer.weight, layer.eps) + paddle.device.synchronize() + + norm_diff = (out.astype("float32") - expected_norm.astype("float32")).abs().max().item() + res_diff = (residual_out.astype("float32") - fused_x.astype("float32")).abs().max().item() + self.assertEqual(norm_diff, 0.0, f"Norm output mismatch: diff={norm_diff}") + self.assertEqual(res_diff, 0.0, f"Residual output mismatch: diff={res_diff}") + + def test_bfloat16(self): + """batch_invariant path with bfloat16 input.""" + D = 3584 + layer = _make_minimal_rmsnorm(D, dtype="bfloat16") + paddle.seed(0) + x = paddle.randn([32, D], dtype="bfloat16") + residual = paddle.randn([32, D], dtype="bfloat16") + + with set_batch_invariant_mode(True): + out, residual_out = layer.forward(x, residual_input=residual) + + fused_x = x + residual + expected_norm = rms_norm_batch_invariant(fused_x, layer.weight, layer.eps) + paddle.device.synchronize() + + norm_diff = (out.astype("float32") - expected_norm.astype("float32")).abs().max().item() + self.assertEqual(norm_diff, 0.0, f"bf16 norm output mismatch: diff={norm_diff}") + + +if __name__ == "__main__": + unittest.main() From 6b1554ff263b9f5e378dcb9b83bb2c11854c6004 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Mon, 30 Mar 2026 00:51:53 +0800 Subject: [PATCH 17/18] modify 2025 to 2026 --- tests/layers/test_trtllm_allreduce_rms_fusion.py | 2 +- tests/layers/trtllm_allreduce_rms_fusion.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/layers/test_trtllm_allreduce_rms_fusion.py b/tests/layers/test_trtllm_allreduce_rms_fusion.py index edf6ca8e710..6699038d5fd 100644 --- a/tests/layers/test_trtllm_allreduce_rms_fusion.py +++ b/tests/layers/test_trtllm_allreduce_rms_fusion.py @@ -1,5 +1,5 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. 
All Rights Reserved. +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. diff --git a/tests/layers/trtllm_allreduce_rms_fusion.py b/tests/layers/trtllm_allreduce_rms_fusion.py index 38ccbf07f28..5770900df95 100644 --- a/tests/layers/trtllm_allreduce_rms_fusion.py +++ b/tests/layers/trtllm_allreduce_rms_fusion.py @@ -1,5 +1,5 @@ """ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. From 5b2d2f8eac801a27731eb9578d2ee485cb1921 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Tue, 7 Apr 2026 20:45:10 +0800 Subject: [PATCH 18/18] del redundant file --- .../test_rmsnorm_layer_batch_invariant.py | 102 ------------------ 1 file changed, 102 deletions(-) delete mode 100644 tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py diff --git a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py deleted file mode 100644 index 54ce40ca5d4..00000000000 --- a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Test RMSNorm layer's batch_invariant_mode forward path (normalization.py:244-248). - -This covers the integration between the RMSNorm *layer* and the Triton -rms_norm_batch_invariant kernel when batch_invariant_mode is enabled. -We bypass RMSNorm.__init__ (heavy FDConfig dependency) and set only -the attributes needed by forward().
-""" - -import unittest - -import paddle - -from fastdeploy.model_executor.layers.batch_invariant_ops import ( - rms_norm_batch_invariant, - set_batch_invariant_mode, -) -from fastdeploy.model_executor.layers.normalization import RMSNorm - - -def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"): - """Create a minimal RMSNorm without FDConfig by bypassing __init__.""" - layer = object.__new__(RMSNorm) - paddle.nn.Layer.__init__(layer) - # Attributes used by forward() - layer.weight = paddle.create_parameter( - shape=[hidden_size], - dtype=dtype, - default_initializer=paddle.nn.initializer.Constant(value=1.0), - ) - layer.eps = eps - layer.bias = None - layer.split_x = False - layer.allgather_out = False - layer.enable_all_reduce_fusion = False - return layer - - -class TestRMSNormBatchInvariantPath(unittest.TestCase): - """Test RMSNorm.forward with batch_invariant_mode enabled.""" - - def setUp(self): - paddle.set_device("gpu") - - def test_no_residual(self): - """batch_invariant path without residual_input.""" - D = 1024 - layer = _make_minimal_rmsnorm(D, dtype="float32") - paddle.seed(42) - x = paddle.randn([16, D], dtype="float32") - - with set_batch_invariant_mode(True): - out, residual_out = layer.forward(x, residual_input=None) - - # residual_out should be x itself (line 236: residual_out = x) - expected_norm = rms_norm_batch_invariant(x, layer.weight, layer.eps) - paddle.device.synchronize() - self.assertEqual(out.shape, [16, D]) - diff = (out.astype("float32") - expected_norm.astype("float32")).abs().max().item() - self.assertEqual(diff, 0.0, f"Output mismatch: diff={diff}") - - def test_with_residual(self): - """batch_invariant path with residual_input (covers lines 246-248).""" - D = 1024 - layer = _make_minimal_rmsnorm(D, dtype="float32") - paddle.seed(42) - x = paddle.randn([16, D], dtype="float32") - residual = paddle.randn([16, D], dtype="float32") - - with set_batch_invariant_mode(True): - out, residual_out = layer.forward(x, 
residual_input=residual) - - # Expected: x + residual -> rms_norm_batch_invariant, residual_out = x + residual - fused_x = x + residual - expected_norm = rms_norm_batch_invariant(fused_x, layer.weight, layer.eps) - paddle.device.synchronize() - - norm_diff = (out.astype("float32") - expected_norm.astype("float32")).abs().max().item() - res_diff = (residual_out.astype("float32") - fused_x.astype("float32")).abs().max().item() - self.assertEqual(norm_diff, 0.0, f"Norm output mismatch: diff={norm_diff}") - self.assertEqual(res_diff, 0.0, f"Residual output mismatch: diff={res_diff}") - - def test_bfloat16(self): - """batch_invariant path with bfloat16 input.""" - D = 3584 - layer = _make_minimal_rmsnorm(D, dtype="bfloat16") - paddle.seed(0) - x = paddle.randn([32, D], dtype="bfloat16") - residual = paddle.randn([32, D], dtype="bfloat16") - - with set_batch_invariant_mode(True): - out, residual_out = layer.forward(x, residual_input=residual) - - fused_x = x + residual - expected_norm = rms_norm_batch_invariant(fused_x, layer.weight, layer.eps) - paddle.device.synchronize() - - norm_diff = (out.astype("float32") - expected_norm.astype("float32")).abs().max().item() - self.assertEqual(norm_diff, 0.0, f"bf16 norm output mismatch: diff={norm_diff}") - - -if __name__ == "__main__": - unittest.main()