From 91b7c290bf3c57c79fde9055666fe3ea9c821c9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=92=AE=E5=9C=A3=E8=99=93?= Date: Wed, 25 Mar 2026 12:54:33 +0800 Subject: [PATCH 1/4] feat: fp8 dsa support --- .../common/basemodel/attention/__init__.py | 1 + .../basemodel/attention/create_utils.py | 4 + .../basemodel/attention/nsa/__init__.py | 8 + .../common/basemodel/attention/nsa/fp8.py | 187 +++++++++++ .../common/kv_cache_mem_manager/__init__.py | 2 + .../deepseek3_2_dsa_fp8_mem_manager.py | 300 ++++++++++++++++++ .../deepseek3_2mem_manager.py | 3 + .../common/kv_cache_mem_manager/mem_utils.py | 14 +- .../layer_infer/transformer_layer_infer.py | 11 +- .../destindex_copy_kv_flashmla_fp8.py | 163 ++++++++++ lightllm/server/api_cli.py | 4 +- lightllm/server/core/objs/start_args_type.py | 4 +- lightllm/utils/flashmla_utils.py | 33 ++ test/acc/test_deepseekv32_fp8kv_dsa.sh | 13 + .../kernel/benchmark_deepseekv32_fp8kv_dsa.py | 105 ++++++ ...k_deepseekv32_sparse_decode_fp8_vs_bf16.py | 160 ++++++++++ .../test_flashmla_fp8_sparse_decode.py | 80 +++++ .../test_destindex_copy_kv_flashmla_fp8.py | 53 ++++ 18 files changed, 1136 insertions(+), 9 deletions(-) create mode 100644 lightllm/common/basemodel/attention/nsa/fp8.py create mode 100644 lightllm/common/kv_cache_mem_manager/deepseek3_2_dsa_fp8_mem_manager.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py create mode 100644 lightllm/utils/flashmla_utils.py create mode 100644 test/acc/test_deepseekv32_fp8kv_dsa.sh create mode 100644 test/kernel/benchmark_deepseekv32_fp8kv_dsa.py create mode 100644 test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py create mode 100644 unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py create mode 100644 unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 0eea52cc89..9ddac8b18d 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -12,6 +12,7 @@ # NSA backend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend +from .nsa.fp8 import NsaFlashMlaFp8AttBackend from .create_utils import ( get_prefill_att_backend_class, diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 3ba16e2189..aeb7b5dced 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -15,6 +15,7 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend +from .nsa.fp8 import NsaFlashMlaFp8AttBackend logger = init_logger(__name__) @@ -56,6 +57,9 @@ "flashmla_sparse": NsaFlashMlaSparseAttBackend, # Future backends: "fa3", "tilelang", "aiter" }, + "fp8kv_dsa": { + "flashmla_sparse": NsaFlashMlaFp8AttBackend, + }, } diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py index 11a1ebfdcd..6c154e5544 100644 --- a/lightllm/common/basemodel/attention/nsa/__init__.py +++ b/lightllm/common/basemodel/attention/nsa/__init__.py @@ -5,9 +5,17 @@ NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState, ) +from .fp8 import ( + NsaFlashMlaFp8AttBackend, + NsaFlashMlaFp8PrefillAttState, + NsaFlashMlaFp8DecodeAttState, +) __all__ = [ "NsaFlashMlaSparseAttBackend", 
"NsaFlashMlaSparsePrefillAttState", "NsaFlashMlaSparseDecodeAttState", + "NsaFlashMlaFp8AttBackend", + "NsaFlashMlaFp8PrefillAttState", + "NsaFlashMlaFp8DecodeAttState", ] diff --git a/lightllm/common/basemodel/attention/nsa/fp8.py b/lightllm/common/basemodel/attention/nsa/fp8.py new file mode 100644 index 0000000000..cac7e7eb14 --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/fp8.py @@ -0,0 +1,187 @@ +import dataclasses +from typing import TYPE_CHECKING, Tuple + +import torch + +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.flashmla_utils import import_flash_mla + +from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class NsaFlashMlaFp8AttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + device = get_current_device_id() + self.ragged_mem_buffers = [ + torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) + for _ in range(2) + ] + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8PrefillAttState": + return NsaFlashMlaFp8PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8DecodeAttState": + return NsaFlashMlaFp8DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaFp8PrefillAttState(BasePrefillAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + + def init_state(self): + self.backend: NsaFlashMlaFp8AttBackend = self.backend + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + return + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + return self._nsa_prefill_att(q=q, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + flash_mla = import_flash_mla() + + nsa_dict = att_control.nsa_prefill_dict + layer_index = nsa_dict["layer_index"] + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + kv = self.infer_state.mem_manager.get_prefill_kv_cache(layer_index) + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32) + if topk_length.ndim == 2 and topk_length.shape[1] == 1: + topk_length = topk_length[:, 0].contiguous() + + mla_out, _, _ = flash_mla.flash_mla_sparse_fwd( 
+ q=q.contiguous(), + kv=kv.contiguous(), + indices=topk_indices.contiguous(), + sm_scale=softmax_scale, + d_v=kv_lora_rank, + topk_length=topk_length, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaFp8DecodeAttState(BaseDecodeAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + flashmla_sched_meta: object = None + + def init_state(self): + self.backend: NsaFlashMlaFp8AttBackend = self.backend + model = self.backend.model + use_cuda_graph = ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ) + + if use_cuda_graph: + self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] + else: + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.b_seq_len.shape[0], + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + flash_mla = import_flash_mla() + self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() + return + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + return self._nsa_decode_att(q=q, kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + flash_mla = import_flash_mla() + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" + + q_nope, q_rope = q + q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() + + o_tensor, _ = flash_mla.flash_mla_with_kvcache( + q=q_all, + k_cache=kv.contiguous(), + block_table=None, + cache_seqlens=None, + head_dim_v=kv_lora_rank, + tile_scheduler_metadata=self.flashmla_sched_meta, + num_splits=None, + softmax_scale=softmax_scale, + causal=False, + is_fp8_kvcache=True, + indices=topk_indices.contiguous(), + ) + return o_tensor[:, 0, :, :] diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index bfccc8b482..7421b5c73e 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -3,6 +3,7 @@ from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager from .deepseek3_2mem_manager import Deepseek3_2MemoryManager +from .deepseek3_2_dsa_fp8_mem_manager import Deepseek3_2DSAFP8MemoryManager from .fp8_static_per_head_quant_mem_manager import 
FP8StaticPerHeadQuantMemManager from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager @@ -13,6 +14,7 @@ "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", "Deepseek3_2MemoryManager", + "Deepseek3_2DSAFP8MemoryManager", "FP8StaticPerHeadQuantMemManager", "FP8StaticPerTensorQuantMemManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2_dsa_fp8_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2_dsa_fp8_mem_manager.py new file mode 100644 index 0000000000..aa6e9bcd89 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2_dsa_fp8_mem_manager.py @@ -0,0 +1,300 @@ +import torch +from typing import Any, List + +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.distributed.pynccl import PyNcclCommunicator +from lightllm.common.kv_trans_kernel.kv_trans import kv_trans +from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io + +from .deepseek2_mem_manager import Deepseek2MemoryManager + + +class Deepseek3_2DSAFP8MemoryManager(Deepseek2MemoryManager): + flashmla_bytes_per_token = 656 + indexer_bytes_per_token = 132 + kv_head_dim = 576 + + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1" + self.prefill_dtype = dtype + super().__init__( + size, torch.uint8, head_num, self.flashmla_bytes_per_token, layer_num, always_copy, mem_fraction + ) + + def get_cell_size(self): + prefill_bytes = self.kv_head_dim * torch._utils._element_size(self.prefill_dtype) + return self.layer_num * (self.flashmla_bytes_per_token + self.indexer_bytes_per_token + prefill_bytes) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, size + 1, head_num, self.flashmla_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + self.indexer_k_buffer = torch.empty( + (layer_num, size + 1, head_num, self.indexer_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + self.prefill_kv_buffer = torch.empty( + (layer_num, size + 1, head_num, self.kv_head_dim), dtype=self.prefill_dtype, device="cuda" + ) + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import ( + destindex_copy_kv_flashmla_fp8, + ) + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank == 512, f"Expected kv_lora_rank=512, got {kv_lora_rank}" + + o_nope = self.kv_buffer[layer_index][:, :, :512].view(torch.float8_e4m3fn) + o_scale = self.kv_buffer[layer_index][:, :, 512:528].view(torch.float32) + o_rope = self.kv_buffer[layer_index][:, :, 528:].view(torch.bfloat16) + destindex_copy_kv_flashmla_fp8( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + o_nope, + o_scale, + o_rope, + ) + self.prefill_kv_buffer[layer_index, mem_index, :, :] = kv + + def get_att_input_params(self, layer_index: int) -> Any: + return self.get_flashmla_kv_cache(layer_index) + + def get_flashmla_kv_cache(self, layer_index: int) -> torch.Tensor: + return self.kv_buffer[layer_index].view(-1, 1, 1, self.flashmla_bytes_per_token) + + def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor: + return self.prefill_kv_buffer[layer_index] + + def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: + return self.indexer_k_buffer[layer_index] + + def alloc_kv_move_buffer(self, max_req_total_len): + self.kv_move_buffer = 
torch.empty( + (1, max_req_total_len + 8, self.head_num, self.flashmla_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + self.indexer_k_move_buffer = torch.empty( + (1, max_req_total_len + 8, self.head_num, self.indexer_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") + self.token_dim_size = self.flashmla_bytes_per_token + self.indexer_token_dim_size = self.indexer_bytes_per_token + return + + def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: + self.kv_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, self.head_num, self.flashmla_bytes_per_token), + dtype=torch.uint8, + device="cuda", + ) + self.indexer_k_paged_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, self.head_num, self.indexer_bytes_per_token), + dtype=torch.uint8, + device="cuda", + ) + self._buffer_mem_indexes_tensors = [ + torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num) + ] + return self.kv_move_buffer + + def write_mem_to_page_kv_move_buffer( + self, + mem_indexes: List[int], + page_index: int, + dp_index: int, + mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + dp_world_size: int, + ): + cur_page = self.kv_move_buffer[page_index] + cur_indexer_page = self.indexer_k_paged_move_buffer[page_index] + pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] + pin_mem_indexes.numpy()[:] = mem_indexes + mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] + mla_page_io(mem_indexes=mem_indexes_gpu, page_tensor=cur_page, kv_buffer=dp_mems[0].kv_buffer, mode="write") + mla_page_io( + mem_indexes=mem_indexes_gpu, + page_tensor=cur_indexer_page, + kv_buffer=dp_mems[0].indexer_k_buffer, + mode="write", + ) + + def read_page_kv_move_buffer_to_mem( + self, + mem_indexes: List[int], + page_index: int, + dp_index: int, + mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + dp_world_size: int, + ): + cur_page = self.kv_move_buffer[page_index] + cur_indexer_page = self.indexer_k_paged_move_buffer[page_index] + pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] + pin_mem_indexes.numpy()[:] = mem_indexes + mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] + for mem in dp_mems: + mla_page_io(mem_indexes=mem_indexes_gpu, page_tensor=cur_page, kv_buffer=mem.kv_buffer, mode="read") + mla_page_io( + mem_indexes=mem_indexes_gpu, + page_tensor=cur_indexer_page, + kv_buffer=mem.indexer_k_buffer, + mode="read", + ) + + def send_to_decode_node( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) + cur_device_index = self.kv_buffer.get_device() + cur_mem = mem_managers[cur_device_index] + for layer_index in range(cur_mem.layer_num): + nccl_comm.send(self._get_main_move_data(move_token_indexes, layer_index), dst=1) + nccl_comm.send(self._get_indexer_move_data(move_token_indexes, layer_index), dst=1) + + def receive_from_prefill_node( + self, + move_tasks: List[KVMoveTask], + mem_managers: 
List["Deepseek3_2DSAFP8MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) + + cur_device_index = self.kv_buffer.get_device() + token_num = len(move_token_indexes) + main_buffer = self.kv_move_buffer.view(-1)[0 : self.token_dim_size * token_num].view( + 1, token_num, self.head_num, self.flashmla_bytes_per_token + ) + indexer_buffer = self.indexer_k_move_buffer.view(-1)[0 : self.indexer_token_dim_size * token_num].view( + 1, token_num, self.head_num, self.indexer_bytes_per_token + ) + for i, mem in enumerate(mem_managers): + for layer_index in range(mem.layer_num): + nccl_comm.recv(main_buffer, src=0) + nccl_comm.recv(indexer_buffer, src=0) + if i == cur_device_index: + mem._write_main_move_data(move_token_indexes, main_buffer, layer_index) + mem._write_indexer_move_data(move_token_indexes, indexer_buffer, layer_index) + else: + new_main = mem.kv_move_buffer.view(-1)[0 : self.token_dim_size * token_num].view(main_buffer.shape) + new_indexer = mem.indexer_k_move_buffer.view(-1)[0 : self.indexer_token_dim_size * token_num].view( + indexer_buffer.shape + ) + from torch.cuda import comm + + comm.broadcast(main_buffer, out=[new_main]) + comm.broadcast(indexer_buffer, out=[new_indexer]) + mem._write_main_move_data(move_token_indexes, new_main, layer_index) + mem._write_indexer_move_data(move_token_indexes, new_indexer, layer_index) + + def send_to_decode_node_p2p( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + for layer_index in range(self.layer_num): + nccl_comm.send(self._get_main_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer), dst=1) + nccl_comm.send( + self._get_indexer_move_data_p2p(move_token_indexes, layer_index, self.indexer_k_move_buffer), dst=1 + ) + + def receive_from_prefill_node_p2p( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + token_num = len(move_token_indexes) + main_buffer = self.kv_move_buffer.view(-1)[0 : self.token_dim_size * token_num].view( + token_num, self.head_num, self.flashmla_bytes_per_token + ) + indexer_buffer = self.indexer_k_move_buffer.view(-1)[0 : self.indexer_token_dim_size * token_num].view( + token_num, self.head_num, self.indexer_bytes_per_token + ) + for mem in mem_managers: + for layer_index in range(mem.layer_num): + nccl_comm.recv(main_buffer, src=0) + nccl_comm.recv(indexer_buffer, src=0) + mem._write_main_move_data_p2p(move_token_indexes, main_buffer, layer_index) + mem._write_indexer_move_data_p2p(move_token_indexes, indexer_buffer, layer_index) + + def _get_main_move_data(self, token_indexes: List[int], layer_index: int): + move_size = self.token_dim_size 
* len(token_indexes) + move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view( + 1, len(token_indexes), self.head_num, self.flashmla_bytes_per_token + ) + move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :] + return move_buffer + + def _get_indexer_move_data(self, token_indexes: List[int], layer_index: int): + move_size = self.indexer_token_dim_size * len(token_indexes) + move_buffer = self.indexer_k_move_buffer.view(-1)[0:move_size].view( + 1, len(token_indexes), self.head_num, self.indexer_bytes_per_token + ) + move_buffer[:, :, :, :] = self.indexer_k_buffer[layer_index, token_indexes, :, :] + return move_buffer + + def _write_main_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor + + def _write_indexer_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + self.indexer_k_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor + + def _get_main_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): + move_token_num = len(token_indexes) + move_size = self.token_dim_size * move_token_num + move_buffer = kv_move_buffer.view(-1)[0:move_size].view( + move_token_num, self.head_num, self.flashmla_bytes_per_token + ) + kv_trans(self.kv_buffer[layer_index], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]) + return move_buffer + + def _get_indexer_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): + move_token_num = len(token_indexes) + move_size = self.indexer_token_dim_size * move_token_num + move_buffer = kv_move_buffer.view(-1)[0:move_size].view( + move_token_num, self.head_num, self.indexer_bytes_per_token + ) + kv_trans( + self.indexer_k_buffer[layer_index], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num] + ) + return move_buffer + + def _write_main_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + move_token_num = len(token_indexes) + kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes) + + def _write_indexer_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + move_token_num = len(token_indexes) + kv_trans( + buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.indexer_k_buffer[layer_index], token_indexes + ) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py index fbf9f88c84..66f37a16f1 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py @@ -34,3 +34,6 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: def get_att_input_params(self, layer_index: int) -> Any: kv = self.kv_buffer[layer_index][:, :, : (self.head_dim - (144 // 2))] return kv + + def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: + return self.kv_buffer[layer_index].view(dtype=torch.uint8)[:, :, -132:] diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 36ca8646a2..d64c645ffb 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -24,7 +24,12 @@ def 
select_mem_manager_class(): from lightllm.models import Deepseek3_2TpPartModel if issubclass(model_class, Deepseek3_2TpPartModel): - mem_class = Deepseek3_2MemoryManager + if get_env_start_args().llm_kv_type == "fp8kv_dsa": + from . import Deepseek3_2DSAFP8MemoryManager + + mem_class = Deepseek3_2DSAFP8MemoryManager + else: + mem_class = Deepseek3_2MemoryManager logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") return mem_class @@ -55,4 +60,9 @@ def select_mem_manager_class(): @lru_cache(maxsize=None) def used_mem_manager_has_scale() -> bool: mem_class = select_mem_manager_class() - return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, FP8StaticPerHeadQuantMemManager, FP8StaticPerTensorQuantMemManager] + return mem_class in [ + PPLINT8KVMemoryManager, + PPLINT4KVMemoryManager, + FP8StaticPerHeadQuantMemManager, + FP8StaticPerTensorQuantMemManager, + ] diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index b006120179..06631a4286 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,12 +1,11 @@ import torch -from typing import Union +from typing import Any from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl -from lightllm.common.basemodel.attention.nsa import NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks @@ -89,6 +88,7 @@ def _context_attention_kernel( att_control = AttControl( nsa_prefill=True, nsa_prefill_dict={ + "layer_index": self.layer_num_, "topk_indices": topk_indices, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, @@ -129,6 +129,7 @@ def _token_attention_kernel( att_control = AttControl( nsa_decode=True, nsa_decode_dict={ + "layer_index": self.layer_num_, "topk_indices": topk_indices, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, @@ -168,7 +169,7 @@ def get_indices( hidden_states: torch.Tensor, q_lora: torch.Tensor, infer_state: Deepseek2InferStateInfo, - att_state: Union[NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState], + att_state: Any, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: @@ -195,7 +196,7 @@ def get_indices( K_fp8=k_fp8, K_scale=k_scale, DestLoc=infer_state.mem_index, - O_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:], + O_buffer=infer_state.mem_manager.get_indexer_k_buffer(self.layer_idx_), ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale @@ -211,7 +212,7 @@ def get_indices( mtp_step = get_env_start_args().mtp_step # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = 
extract_indexer_ks( - I_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:], + I_buffer=infer_state.mem_manager.get_indexer_k_buffer(self.layer_idx_), b_seq_len=infer_state.b_seq_len, b_req_idx=infer_state.b_req_idx, req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py new file mode 100644 index 0000000000..9e90050e11 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py @@ -0,0 +1,163 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _quant_scale(max_nope, fp8_max): + return tl.exp2(tl.ceil(tl.log2(tl.maximum(max_nope / fp8_max, 1e-4)))) + + +@triton.jit +def _fwd_kernel_destindex_copy_kv_flashmla_fp8( + KV_nope, + KV_rope, + Dest_loc, + O_nope, + O_scale, + O_rope, + stride_kv_nope_bs, + stride_kv_nope_h, + stride_kv_nope_d, + stride_kv_rope_bs, + stride_kv_rope_h, + stride_kv_rope_d, + stride_o_nope_bs, + stride_o_nope_h, + stride_o_nope_d, + stride_o_scale_bs, + stride_o_scale_h, + stride_o_scale_d, + stride_o_rope_bs, + stride_o_rope_h, + stride_o_rope_d, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + BLOCK_DMODEL_NOPE: tl.constexpr, + BLOCK_DMODEL_ROPE: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + cur_index = tl.program_id(0) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + offs_rope = tl.arange(0, BLOCK_DMODEL_ROPE) + + # This kernel is only used by the DeepSeek-V3.2 DSA FP8 path, which + # stores a single MQA-style KV head per token. Keep all accesses 1-D so + # Triton treats per-tile scales as scalars instead of 1-element blocks. 
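+    # Packed per-token layout seen by callers (656 bytes in total): 512
+    # fp8e4m3 nope values, 4 fp32 power-of-two tile scales (16 bytes), then
+    # 64 bf16 rope values (128 bytes). The rope tail is copied through
+    # unchanged; only the nope part is quantized, one 128-wide tile at a time.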
+ kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_rope + + kv_rope = tl.load(kv_rope_ptrs) + + o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_rope + tl.store(o_rope_ptrs, kv_rope) + + num_tiles = BLOCK_DMODEL_NOPE // GROUP_SIZE + for tile_idx in range(0, num_tiles): + offs_tile = tile_idx * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + kv_nope_tile_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_tile + kv_nope_tile = tl.load(kv_nope_tile_ptrs) + max_nope = tl.max(tl.abs(kv_nope_tile), axis=0) + kv_scale = _quant_scale(max_nope, FP8_MAX) + kv_nope_fp8 = tl.clamp(kv_nope_tile / kv_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + + o_nope_ptrs = ( + O_nope + + dest_index * stride_o_nope_bs + + (tile_idx * GROUP_SIZE) * stride_o_nope_d + + tl.arange(0, GROUP_SIZE) * stride_o_nope_d + ) + tl.store(o_nope_ptrs, kv_nope_fp8) + + o_scale_ptrs = O_scale + dest_index * stride_o_scale_bs + tile_idx * stride_o_scale_d + tl.store(o_scale_ptrs, kv_scale.to(tl.float32)) + return + + +@torch.no_grad() +def destindex_copy_kv_flashmla_fp8( + KV_nope: torch.Tensor, + KV_rope: torch.Tensor, + DestLoc: torch.Tensor, + O_nope: torch.Tensor, + O_scale: torch.Tensor, + O_rope: torch.Tensor, +): + seq_len = DestLoc.shape[0] + kv_nope_head_dim = KV_nope.shape[2] + kv_rope_head_dim = KV_rope.shape[2] + + assert kv_nope_head_dim == 512, f"Expected kv_nope_head_dim=512, got {kv_nope_head_dim}" + assert kv_rope_head_dim == 64, f"Expected kv_rope_head_dim=64, got {kv_rope_head_dim}" + assert O_nope.shape[2] == 512 + assert O_scale.shape[2] == 4 + assert O_rope.shape[2] == 64 + + _fwd_kernel_destindex_copy_kv_flashmla_fp8[(seq_len,)]( + KV_nope, + KV_rope, + DestLoc, + O_nope, + O_scale, + O_rope, + KV_nope.stride(0), + KV_nope.stride(1), + KV_nope.stride(2), + KV_rope.stride(0), + KV_rope.stride(1), + KV_rope.stride(2), + O_nope.stride(0), + O_nope.stride(1), + O_nope.stride(2), + O_scale.stride(0), + O_scale.stride(1), + O_scale.stride(2), + O_rope.stride(0), + O_rope.stride(1), + O_rope.stride(2), + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + BLOCK_DMODEL_NOPE=512, + BLOCK_DMODEL_ROPE=64, + GROUP_SIZE=128, + num_warps=4, + num_stages=1, + ) + return + + +def pack_kv_reference(kv: torch.Tensor) -> torch.Tensor: + assert kv.shape[-1] == 576 + kv = kv.view(-1, 1, 576) + kv_nope = kv[:, :, :512].contiguous().view(-1, 512) + kv_rope = kv[:, :, 512:].contiguous().view(-1, 64) + out = torch.empty((kv.shape[0], 1, 656), dtype=torch.uint8, device=kv.device) + out_nope = out[:, :, :512].view(torch.float8_e4m3fn).view(-1, 1, 512) + out_scale = out[:, :, 512:528].view(torch.float32).view(-1, 1, 4) + out_rope = out[:, :, 528:].view(torch.bfloat16).view(-1, 1, 64) + out_rope.copy_(kv_rope.view(-1, 1, 64)) + for tile_idx in range(4): + start = tile_idx * 128 + end = start + 128 + tile = kv_nope[:, start:end] + scale = torch.pow(2, torch.clamp_min(tile.abs().amax(dim=-1).float() / 448.0, 1e-4).log2().ceil()) + out_scale[:, 0, tile_idx] = scale + out_nope[:, 0, start:end] = (tile.float() / scale[:, None]).to(torch.float8_e4m3fn) + return out + + +def dequantize_kv_reference(packed: torch.Tensor) -> torch.Tensor: + packed = packed.view(-1, 1, 656) + out = torch.empty((packed.shape[0], 1, 576), dtype=torch.bfloat16, device=packed.device) + packed_nope = packed[:, :, :512].view(torch.float8_e4m3fn).view(-1, 1, 512) + packed_scale = packed[:, :, 512:528].view(torch.float32).view(-1, 1, 4) + packed_rope 
= packed[:, :, 528:].view(torch.bfloat16).view(-1, 1, 64) + out[:, :, 512:].copy_(packed_rope) + for tile_idx in range(4): + start = tile_idx * 128 + end = start + 128 + out[:, :, start:end] = ( + packed_nope[:, :, start:end].to(torch.float32) * packed_scale[:, :, tile_idx : tile_idx + 1] + ) + return out diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d32da8097c..ad1826757f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -363,13 +363,15 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--llm_kv_type", type=str, - choices=["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"], + choices=["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"], default="None", help="""kv type used in llm, None for dtype that llm used in config.json. fp8kv_sph: use float8_e4m3fn to store kv cache for inference, quant way is static per head kv quant. fp8kv_spt: use float8_e4m3fn to store kv cache for inference, quant way is static per tensor kv quant. + fp8kv_dsa: use DeepSeek-V3.2 DSA-specific FlashMLA FP8 sparse KV cache, + intended for the deepseek_v32 model path. fp8kv_sph and fp8kv_spt requires --kv_quant_calibration_config_path to load pre-computed FP8 scales. Note: fp8kv_spt requires flashinfer-python>=0.6.5 (default is 0.6.3, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 37c022f3a3..4daf0bbe33 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -125,7 +125,9 @@ class StartArgs: vit_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} ) - llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]}) + llm_kv_type: str = field( + default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]} + ) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( diff --git a/lightllm/utils/flashmla_utils.py b/lightllm/utils/flashmla_utils.py new file mode 100644 index 0000000000..8e1ba0e30f --- /dev/null +++ b/lightllm/utils/flashmla_utils.py @@ -0,0 +1,33 @@ +import sys +from importlib import import_module +from pathlib import Path + + +def _candidate_roots() -> list[Path]: + repo_root = Path(__file__).resolve().parents[2] + return [ + repo_root / "FlashMLA", + repo_root.parent / "FlashMLA", + ] + + +def import_flash_mla(): + try: + return import_module("flash_mla") + except ModuleNotFoundError: + pass + + for root in _candidate_roots(): + if root.exists(): + root_str = str(root) + if root_str not in sys.path: + sys.path.insert(0, root_str) + try: + return import_module("flash_mla") + except ModuleNotFoundError: + continue + + raise ModuleNotFoundError( + "flash_mla is not installed and no local FlashMLA checkout was found. " + "Install FlashMLA or place the repository at ./FlashMLA." + ) diff --git a/test/acc/test_deepseekv32_fp8kv_dsa.sh b/test/acc/test_deepseekv32_fp8kv_dsa.sh new file mode 100644 index 0000000000..008ef43b3b --- /dev/null +++ b/test/acc/test_deepseekv32_fp8kv_dsa.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." 
&& pwd)" +export PYTHONPATH="${ROOT_DIR}:${ROOT_DIR}/FlashMLA:${ROOT_DIR}/sglang/sgl-kernel/python:${PYTHONPATH:-}" + +pytest \ + "${ROOT_DIR}/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py" \ + "${ROOT_DIR}/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py" \ + -s + +python "${ROOT_DIR}/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py" --tokens-list 10000 100000 1000000 --page-size-list 1 64 128 256 --iters 100 --warmup 20 +python "${ROOT_DIR}/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py" --batch 64 --heads 128 --cache-tokens-list 10000 100000 1000000 --page-size-list 1 64 128 256 --topk 2048 --iters 100 --warmup 20 --check-correctness diff --git a/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py b/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py new file mode 100644 index 0000000000..3f4f0ef80b --- /dev/null +++ b/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py @@ -0,0 +1,105 @@ +import argparse +import statistics +import sys +from pathlib import Path + +import torch + +CUR_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str((CUR_DIR / "../../lightllm/common/basemodel/triton_kernel/kv_copy").resolve())) +sys.path.insert(0, str((CUR_DIR / "../../lightllm/models/deepseek3_2/triton_kernel").resolve())) + +from mla_copy_kv import destindex_copy_kv +from destindex_copy_kv_flashmla_fp8 import destindex_copy_kv_flashmla_fp8 + + +def _time_cuda(fn, warmup: int, iters: int) -> float: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + samples_ms = [] + for _ in range(iters): + start.record() + fn() + end.record() + end.synchronize() + samples_ms.append(start.elapsed_time(end)) + return statistics.mean(samples_ms) + + +def _gbps(token_num: int, bytes_per_token: int, elapsed_ms: float) -> float: + return (token_num * bytes_per_token) / (elapsed_ms / 1e3) / 1e9 + + +def _build_random_page_mapping(token_num: int, page_size: int, device: str): + num_pages = (token_num + page_size - 1) // page_size + padded_token_num = num_pages * page_size + physical_page_ids = torch.randperm(num_pages, dtype=torch.int64, device=device) + logical_tokens = torch.arange(token_num, dtype=torch.int64, device=device) + dest_loc = physical_page_ids[logical_tokens // page_size] * page_size + (logical_tokens % page_size) + return dest_loc, padded_token_num + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark DeepSeek V3.2 bf16 KV store vs fp8kv_dsa KV store") + parser.add_argument("--tokens", type=int, default=65536) + parser.add_argument("--tokens-list", type=int, nargs="*", default=None) + parser.add_argument("--page-size", type=int, default=None) + parser.add_argument("--page-size-list", type=int, nargs="*", default=[1, 64, 128, 256]) + parser.add_argument("--iters", type=int, default=100) + parser.add_argument("--warmup", type=int, default=20) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark") + + device = "cuda" + dtype = torch.bfloat16 + token_list = args.tokens_list if args.tokens_list else [args.tokens] + + bf16_total_bytes = 576 * 2 + 576 * 2 + fp8_total_bytes = 576 * 2 + 656 + + page_sizes = [args.page_size] if args.page_size else args.page_size_list + + print(f"iters={args.iters} warmup={args.warmup}") + for token_num in token_list: + kv = torch.randn((token_num, 1, 576), dtype=dtype, device=device) + + for page_size in page_sizes: + dest_loc, 
page_token_num = _build_random_page_mapping(token_num, page_size, device) + + bf16_nope = torch.empty((page_token_num, 1, 512), dtype=dtype, device=device) + bf16_rope = torch.empty((page_token_num, 1, 64), dtype=dtype, device=device) + fp8_packed = torch.empty((page_token_num, 1, 656), dtype=torch.uint8, device=device) + fp8_nope = fp8_packed[:, :, :512].view(torch.float8_e4m3fn) + fp8_scale = fp8_packed[:, :, 512:528].view(torch.float32) + fp8_rope = fp8_packed[:, :, 528:].view(torch.bfloat16) + + def run_bf16(): + destindex_copy_kv(kv[:, :, :512], kv[:, :, 512:], dest_loc, bf16_nope, bf16_rope) + + def run_fp8(): + destindex_copy_kv_flashmla_fp8( + kv[:, :, :512], + kv[:, :, 512:], + dest_loc, + fp8_nope, + fp8_scale, + fp8_rope, + ) + + bf16_ms = _time_cuda(run_bf16, warmup=args.warmup, iters=args.iters) + fp8_ms = _time_cuda(run_fp8, warmup=args.warmup, iters=args.iters) + + print(f"page_size={page_size} seqlen={token_num}") + print(f"bf16_kv: avg_ms={bf16_ms:.4f} total_traffic_gbps={_gbps(token_num, bf16_total_bytes, bf16_ms):.2f}") + print(f"fp8kv_dsa: avg_ms={fp8_ms:.4f} total_traffic_gbps={_gbps(token_num, fp8_total_bytes, fp8_ms):.2f}") + print(f"speedup={bf16_ms / fp8_ms:.3f}x compression={(576 * 2) / 656:.3f}x") + + +if __name__ == "__main__": + main() diff --git a/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py b/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py new file mode 100644 index 0000000000..5ddbee1140 --- /dev/null +++ b/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py @@ -0,0 +1,160 @@ +import argparse +import statistics +import sys +from pathlib import Path + +import torch + +CUR_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str((CUR_DIR / "../../lightllm/models/deepseek3_2/triton_kernel").resolve())) +sys.path.insert(0, str((CUR_DIR / "../../lightllm/utils").resolve())) + +from destindex_copy_kv_flashmla_fp8 import pack_kv_reference +from flashmla_utils import import_flash_mla + + +def _time_cuda(fn, warmup: int, iters: int) -> float: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + samples_ms = [] + for _ in range(iters): + start.record() + fn() + end.record() + end.synchronize() + samples_ms.append(start.elapsed_time(end)) + return statistics.mean(samples_ms) + + +def _gbps(batch: int, topk: int, bytes_per_token: int, elapsed_ms: float) -> float: + return (batch * topk * bytes_per_token) / (elapsed_ms / 1e3) / 1e9 + + +def _build_random_page_layout( + cache_tokens: int, + page_size: int, + device: str, + dtype: torch.dtype, +): + num_pages = (cache_tokens + page_size - 1) // page_size + padded_cache_tokens = num_pages * page_size + physical_page_ids = torch.randperm(num_pages, dtype=torch.int64, device=device) + + logical_kv = torch.randn((cache_tokens, 1, 576), dtype=dtype, device=device) + physical_kv = torch.zeros((padded_cache_tokens, 1, 576), dtype=dtype, device=device) + + logical_tokens = torch.arange(cache_tokens, dtype=torch.int64, device=device) + physical_token_locs = physical_page_ids[logical_tokens // page_size] * page_size + (logical_tokens % page_size) + physical_kv[physical_token_locs] = logical_kv + + return physical_kv, physical_page_ids, physical_token_locs, padded_cache_tokens + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark DeepSeek V3.2 decode: bf16 sparse-selected vs fp8 DSA") + parser.add_argument("--batch", type=int, default=64) + 
parser.add_argument("--heads", type=int, default=128) + parser.add_argument("--cache-tokens", type=int, default=131072) + parser.add_argument("--cache-tokens-list", type=int, nargs="*", default=None) + parser.add_argument("--page-size", type=int, default=None) + parser.add_argument("--page-size-list", type=int, nargs="*", default=[1, 64, 128, 256]) + parser.add_argument("--topk", type=int, default=2048) + parser.add_argument("--iters", type=int, default=100) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--check-correctness", action="store_true") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark") + + from sgl_kernel.flash_attn import flash_attn_with_kvcache + + flash_mla = import_flash_mla() + device = "cuda" + dtype = torch.bfloat16 + sm_scale = 576 ** (-0.5) + cache_token_list = args.cache_tokens_list if args.cache_tokens_list else [args.cache_tokens] + page_sizes = [args.page_size] if args.page_size else args.page_size_list + + q_nope = torch.randn((args.batch, args.heads, 512), dtype=dtype, device=device) + q_rope = torch.randn((args.batch, args.heads, 64), dtype=dtype, device=device) + q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() + print(f"batch={args.batch} heads={args.heads} topk={args.topk} iters={args.iters} warmup={args.warmup}") + for idx, cache_tokens in enumerate(cache_token_list): + for page_idx, page_size in enumerate(page_sizes): + physical_kv, physical_page_ids, physical_token_locs, padded_cache_tokens = _build_random_page_layout( + cache_tokens, page_size, device, dtype + ) + num_pages = padded_cache_tokens // page_size + + k_rope = physical_kv[:, :, 512:].view(num_pages, page_size, 1, 64).contiguous() + kv_nope = physical_kv[:, :, :512].view(num_pages, page_size, 1, 512).contiguous() + kv_fp8 = pack_kv_reference(physical_kv).view(num_pages, page_size, 1, 656).contiguous() + + selected_pages = (args.topk + page_size - 1) // page_size + page_table = physical_page_ids[:selected_pages].to(torch.int32).repeat(args.batch, 1) + cache_seqlens = torch.full((args.batch,), args.topk, dtype=torch.int32, device=device) + cu_seqlens_q = torch.arange(0, args.batch + 1, dtype=torch.int32, device=device) + cu_seqlens_k_new = torch.arange( + 0, (args.batch + 1) * args.topk, args.topk, dtype=torch.int32, device=device + ) + fp8_indices = ( + physical_token_locs[: args.topk].to(torch.int32).view(1, 1, args.topk).repeat(args.batch, 1, 1) + ) + sched_meta, _ = flash_mla.get_mla_metadata() + + def run_bf16(): + return flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=page_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=1, + softmax_scale=sm_scale, + causal=True, + ) + + def run_fp8(): + out, _ = flash_mla.flash_mla_with_kvcache( + q=q_all, + k_cache=kv_fp8, + block_table=None, + cache_seqlens=None, + head_dim_v=512, + tile_scheduler_metadata=sched_meta, + num_splits=None, + softmax_scale=sm_scale, + causal=False, + is_fp8_kvcache=True, + indices=fp8_indices, + ) + return out[:, 0] + + if args.check_correctness and idx == 0 and page_idx == 0: + bf16_out = run_bf16() + fp8_out = run_fp8() + max_diff = (bf16_out - fp8_out).abs().max().item() + mean_diff = (bf16_out - fp8_out).abs().mean().item() + print(f"correctness: max_diff={max_diff:.6f} mean_diff={mean_diff:.6f}") + + bf16_ms = _time_cuda(run_bf16, warmup=args.warmup, iters=args.iters) + 
fp8_ms = _time_cuda(run_fp8, warmup=args.warmup, iters=args.iters) + + print(f"page_size={page_size} seqlen={cache_tokens}") + print( + f"bf16_decode: avg_ms={bf16_ms:.4f} kv_read_gbps={_gbps(args.batch, args.topk, 576 * 2, bf16_ms):.2f}" + ) + print(f"fp8_dsa_decode: avg_ms={fp8_ms:.4f} kv_read_gbps={_gbps(args.batch, args.topk, 656, fp8_ms):.2f}") + print(f"speedup={bf16_ms / fp8_ms:.3f}x") + + +if __name__ == "__main__": + main() diff --git a/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py b/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py new file mode 100644 index 0000000000..1582c0d08a --- /dev/null +++ b/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py @@ -0,0 +1,80 @@ +import sys +from pathlib import Path + +import pytest +import torch + +CUR_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str((CUR_DIR / "../../../lightllm/models/deepseek3_2/triton_kernel").resolve())) +sys.path.insert(0, str((CUR_DIR / "../../../lightllm/utils").resolve())) + +from destindex_copy_kv_flashmla_fp8 import dequantize_kv_reference, pack_kv_reference +from flashmla_utils import import_flash_mla + + +def _manual_sparse_decode( + q: torch.Tensor, dense_kv: torch.Tensor, indices: torch.Tensor, sm_scale: float +) -> torch.Tensor: + batch, _, heads, _ = q.shape + topk = indices.shape[-1] + out = torch.zeros((batch, heads, 512), dtype=torch.float32, device=q.device) + + for b in range(batch): + valid = indices[b, 0] >= 0 + cur_idx = indices[b, 0, valid] + assert cur_idx.numel() > 0 + cur_k = dense_kv[cur_idx, 0, :] + cur_v = cur_k[:, :512] + logits = torch.einsum("hd,td->ht", q[b, 0].float(), cur_k.float()) * sm_scale + probs = torch.softmax(logits, dim=-1) + out[b] = torch.einsum("ht,td->hd", probs, cur_v.float()) + + if cur_idx.numel() < topk: + assert torch.all(indices[b, 0, cur_idx.numel() :] == -1) + + return out.to(torch.bfloat16) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_flashmla_fp8_sparse_decode_matches_manual_reference(): + try: + flash_mla = import_flash_mla() + except ModuleNotFoundError as exc: + pytest.skip(str(exc)) + + batch = 2 + seq_q = 1 + heads = 64 + token_num = 128 + topk = 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn((batch, seq_q, heads, 576), dtype=dtype, device=device) + kv = torch.randn((token_num, 1, 576), dtype=dtype, device=device) + packed = pack_kv_reference(kv).view(token_num, 1, 1, 656) + + indices = torch.randint(0, token_num, (batch, seq_q, topk), dtype=torch.int32, device=device) + indices[0, 0, -3:] = -1 + indices[1, 0, -5:] = -1 + sm_scale = 576 ** (-0.5) + + sched_meta, _ = flash_mla.get_mla_metadata() + out, _ = flash_mla.flash_mla_with_kvcache( + q=q, + k_cache=packed, + block_table=None, + cache_seqlens=None, + head_dim_v=512, + tile_scheduler_metadata=sched_meta, + num_splits=None, + softmax_scale=sm_scale, + causal=False, + is_fp8_kvcache=True, + indices=indices, + ) + torch.cuda.synchronize() + + dense_kv = dequantize_kv_reference(packed) + ref = _manual_sparse_decode(q, dense_kv, indices, sm_scale) + assert torch.allclose(out[:, 0], ref, rtol=7e-2, atol=7e-2) diff --git a/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py b/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py new file mode 100644 index 0000000000..840b23159f --- /dev/null +++ b/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py @@ -0,0 +1,53 @@ +import sys +from pathlib 
import Path

+import pytest
+import torch
+
+CUR_DIR = Path(__file__).resolve().parent
+sys.path.insert(0, str((CUR_DIR / "../../../../lightllm/models/deepseek3_2/triton_kernel").resolve()))
+
+from destindex_copy_kv_flashmla_fp8 import (
+    dequantize_kv_reference,
+    destindex_copy_kv_flashmla_fp8,
+    pack_kv_reference,
+)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
+def test_destindex_copy_kv_flashmla_fp8_matches_reference():
+    token_num = 257
+    kv = torch.randn((token_num, 1, 576), dtype=torch.bfloat16, device="cuda")
+    dest_loc = torch.randperm(token_num, device="cuda", dtype=torch.int64)
+
+    out = torch.empty((token_num, 1, 656), dtype=torch.uint8, device="cuda")
+    out_nope = out[:, :, :512].view(torch.float8_e4m3fn)
+    out_scale = out[:, :, 512:528].view(torch.float32)
+    out_rope = out[:, :, 528:].view(torch.bfloat16)
+
+    destindex_copy_kv_flashmla_fp8(
+        kv[:, :, :512],
+        kv[:, :, 512:],
+        dest_loc,
+        out_nope,
+        out_scale,
+        out_rope,
+    )
+    torch.cuda.synchronize()
+
+    ref = pack_kv_reference(kv)
+    assert torch.equal(out[dest_loc], ref)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
+def test_destindex_copy_kv_flashmla_fp8_roundtrip():
+    token_num = 257
+    kv = torch.randn((token_num, 1, 576), dtype=torch.bfloat16, device="cuda")
+    packed = pack_kv_reference(kv)
+    dequant = dequantize_kv_reference(packed)
+
+    rope_err = (dequant[:, :, 512:] - kv[:, :, 512:]).abs().max().item()
+    nope_err = (dequant[:, :, :512] - kv[:, :, :512]).abs().max().item()
+
+    assert rope_err == 0.0
+    assert nope_err < 4e-1

From 14f0df0fb5a0f3b58b18879526c8f70865614fe2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=92=AE=E5=9C=A3=E8=99=93?=
Date: Thu, 26 Mar 2026 16:50:59 +0800
Subject: [PATCH 2/4] refine: install FlashMLA in Docker image and import flash_mla directly

---
 docker/Dockerfile                                  |  6 ++++
 .../common/basemodel/attention/nsa/fp8.py          |  8 ++---
 lightllm/utils/flashmla_utils.py                   | 33 -------------------
 ...k_deepseekv32_sparse_decode_fp8_vs_bf16.py      |  5 ++-
 .../test_flashmla_fp8_sparse_decode.py             |  7 +---
 5 files changed, 13 insertions(+), 46 deletions(-)
 delete mode 100644 lightllm/utils/flashmla_utils.py

diff --git a/docker/Dockerfile b/docker/Dockerfile
index e766107ae7..439ecddb34 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -4,6 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
 ARG PYTHON_VERSION=3.10
 ARG MAMBA_VERSION=24.7.1-0
 ARG VLLM_VERSION=0.16.0
+ARG FLASH_MLA_REF=47c35a7
 ARG TARGETPLATFORM
 ARG ENABLE_DEEPEP=1
 ARG ENABLE_NIXL=1
@@ -45,6 +46,11 @@ COPY ./requirements.txt /lightllm/requirements.txt
 RUN pip install -U pip
 RUN pip install -r /lightllm/requirements.txt --no-cache-dir
 RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
+RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
+    cd /root/FlashMLA && \
+    git checkout ${FLASH_MLA_REF} && \
+    git submodule update --init --recursive && \
+    FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
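+# The fp8kv_dsa attention backend now does a plain `import flash_mla` at
+# runtime, so FlashMLA must be importable inside the image instead of being
+# located through the removed flashmla_utils helper.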
RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* diff --git a/lightllm/common/basemodel/attention/nsa/fp8.py b/lightllm/common/basemodel/attention/nsa/fp8.py index cac7e7eb14..509f315b0c 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8.py +++ b/lightllm/common/basemodel/attention/nsa/fp8.py @@ -4,7 +4,6 @@ import torch from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.flashmla_utils import import_flash_mla from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState @@ -72,7 +71,7 @@ def _nsa_prefill_att( q: torch.Tensor, att_control: AttControl, ) -> torch.Tensor: - flash_mla = import_flash_mla() + import flash_mla nsa_dict = att_control.nsa_prefill_dict layer_index = nsa_dict["layer_index"] @@ -135,7 +134,8 @@ def init_state(self): ragged_mem_index=self.ragged_mem_index, hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, ) - flash_mla = import_flash_mla() + import flash_mla + self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() return @@ -157,7 +157,7 @@ def _nsa_decode_att( kv: torch.Tensor, att_control: AttControl, ) -> torch.Tensor: - flash_mla = import_flash_mla() + import flash_mla nsa_dict = att_control.nsa_decode_dict topk_indices = nsa_dict["topk_indices"] diff --git a/lightllm/utils/flashmla_utils.py b/lightllm/utils/flashmla_utils.py deleted file mode 100644 index 8e1ba0e30f..0000000000 --- a/lightllm/utils/flashmla_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -import sys -from importlib import import_module -from pathlib import Path - - -def _candidate_roots() -> list[Path]: - repo_root = Path(__file__).resolve().parents[2] - return [ - repo_root / "FlashMLA", - repo_root.parent / "FlashMLA", - ] - - -def import_flash_mla(): - try: - return import_module("flash_mla") - except ModuleNotFoundError: - pass - - for root in _candidate_roots(): - if root.exists(): - root_str = str(root) - if root_str not in sys.path: - sys.path.insert(0, root_str) - try: - return import_module("flash_mla") - except ModuleNotFoundError: - continue - - raise ModuleNotFoundError( - "flash_mla is not installed and no local FlashMLA checkout was found. " - "Install FlashMLA or place the repository at ./FlashMLA." 
-    )
diff --git a/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py b/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py
index 5ddbee1140..e33441a56d 100644
--- a/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py
+++ b/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py
@@ -7,10 +7,8 @@ CUR_DIR = Path(__file__).resolve().parent
 sys.path.insert(0, str((CUR_DIR / "../../lightllm/models/deepseek3_2/triton_kernel").resolve()))
-sys.path.insert(0, str((CUR_DIR / "../../lightllm/utils").resolve()))
 
 from destindex_copy_kv_flashmla_fp8 import pack_kv_reference
-from flashmla_utils import import_flash_mla
 
 
 def _time_cuda(fn, warmup: int, iters: int) -> float:
@@ -73,7 +71,8 @@ def main():
 
     from sgl_kernel.flash_attn import flash_attn_with_kvcache
 
-    flash_mla = import_flash_mla()
+    import flash_mla
+
     device = "cuda"
     dtype = torch.bfloat16
     sm_scale = 576 ** (-0.5)
diff --git a/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py b/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py
index 1582c0d08a..80d8ab9e7b 100644
--- a/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py
+++ b/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py
@@ -6,10 +6,8 @@ CUR_DIR = Path(__file__).resolve().parent
 sys.path.insert(0, str((CUR_DIR / "../../../lightllm/models/deepseek3_2/triton_kernel").resolve()))
-sys.path.insert(0, str((CUR_DIR / "../../../lightllm/utils").resolve()))
 
 from destindex_copy_kv_flashmla_fp8 import dequantize_kv_reference, pack_kv_reference
-from flashmla_utils import import_flash_mla
 
 
 def _manual_sparse_decode(
@@ -37,10 +35,7 @@ def _manual_sparse_decode(
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
 def test_flashmla_fp8_sparse_decode_matches_manual_reference():
-    try:
-        flash_mla = import_flash_mla()
-    except ModuleNotFoundError as exc:
-        pytest.skip(str(exc))
+    import flash_mla
 
     batch = 2
     seq_q = 1

From d150428a06ed62e96a7e1c502348e2179a27dbf7 Mon Sep 17 00:00:00 2001
From: niushengxiao
Date: Thu, 26 Mar 2026 16:54:12 +0800
Subject: [PATCH 3/4] remove test files and rename fp8 DSA backend to fp8_flashmla_sparse

---
 .../common/basemodel/attention/__init__.py    |   2 +-
 .../basemodel/attention/create_utils.py       |   4 +-
 .../basemodel/attention/nsa/__init__.py       |  14 +-
 .../common/basemodel/attention/nsa/fp8.py     |  18 +-
 .../attention/nsa/fp8_flashmla_sparse.py      | 187 ++++++++++++++++++
 .../common/kv_cache_mem_manager/__init__.py   |   4 +-
 ...ken_group_quant_deepseek3_2mem_manager.py} |  43 ++--
 .../common/kv_cache_mem_manager/mem_utils.py  |   4 +-
 test/acc/test_deepseekv32_fp8kv_dsa.sh        |  13 --
 .../kernel/benchmark_deepseekv32_fp8kv_dsa.py | 105 ----------
 ...k_deepseekv32_sparse_decode_fp8_vs_bf16.py | 159 ---------------
 .../test_flashmla_fp8_sparse_decode.py        |  75 -------
 .../test_destindex_copy_kv_flashmla_fp8.py    |  53 -----
 13 files changed, 239 insertions(+), 442 deletions(-)
 create mode 100644 lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
 rename lightllm/common/kv_cache_mem_manager/{deepseek3_2_dsa_fp8_mem_manager.py => fp8_per_token_group_quant_deepseek3_2mem_manager.py} (89%)
 delete mode 100644 test/acc/test_deepseekv32_fp8kv_dsa.sh
 delete mode 100644 test/kernel/benchmark_deepseekv32_fp8kv_dsa.py
 delete mode 100644 test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py
 delete mode 100644 unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py
 delete mode 100644 unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py

diff --git
a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 9ddac8b18d..10cd3b0864 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -12,7 +12,7 @@ # NSA backend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend -from .nsa.fp8 import NsaFlashMlaFp8AttBackend +from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend from .create_utils import ( get_prefill_att_backend_class, diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index aeb7b5dced..2c4a34d325 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -15,7 +15,7 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend -from .nsa.fp8 import NsaFlashMlaFp8AttBackend +from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend logger = init_logger(__name__) @@ -58,7 +58,7 @@ # Future backends: "fa3", "tilelang", "aiter" }, "fp8kv_dsa": { - "flashmla_sparse": NsaFlashMlaFp8AttBackend, + "flashmla_sparse": NsaFlashMlaFp8SparseAttBackend, }, } diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py index 6c154e5544..f9db52dc2b 100644 --- a/lightllm/common/basemodel/attention/nsa/__init__.py +++ b/lightllm/common/basemodel/attention/nsa/__init__.py @@ -5,17 +5,17 @@ NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState, ) -from .fp8 import ( - NsaFlashMlaFp8AttBackend, - NsaFlashMlaFp8PrefillAttState, - NsaFlashMlaFp8DecodeAttState, +from .fp8_flashmla_sparse import ( + NsaFlashMlaFp8SparseAttBackend, + NsaFlashMlaFp8SparsePrefillAttState, + NsaFlashMlaFp8SparseDecodeAttState, ) __all__ = [ "NsaFlashMlaSparseAttBackend", "NsaFlashMlaSparsePrefillAttState", "NsaFlashMlaSparseDecodeAttState", - "NsaFlashMlaFp8AttBackend", - "NsaFlashMlaFp8PrefillAttState", - "NsaFlashMlaFp8DecodeAttState", + "NsaFlashMlaFp8SparseAttBackend", + "NsaFlashMlaFp8SparsePrefillAttState", + "NsaFlashMlaFp8SparseDecodeAttState", ] diff --git a/lightllm/common/basemodel/attention/nsa/fp8.py b/lightllm/common/basemodel/attention/nsa/fp8.py index 509f315b0c..4ddccf391e 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8.py +++ b/lightllm/common/basemodel/attention/nsa/fp8.py @@ -11,7 +11,7 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo -class NsaFlashMlaFp8AttBackend(BaseAttBackend): +class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) device = get_current_device_id() @@ -20,22 +20,22 @@ def __init__(self, model): for _ in range(2) ] - def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8PrefillAttState": - return NsaFlashMlaFp8PrefillAttState(backend=self, infer_state=infer_state) + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": + return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) - def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8DecodeAttState": - return NsaFlashMlaFp8DecodeAttState(backend=self, infer_state=infer_state) + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": + return 
NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) @dataclasses.dataclass -class NsaFlashMlaFp8PrefillAttState(BasePrefillAttState): +class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): ks: torch.Tensor = None ke: torch.Tensor = None lengths: torch.Tensor = None ragged_mem_index: torch.Tensor = None def init_state(self): - self.backend: NsaFlashMlaFp8AttBackend = self.backend + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend self.ragged_mem_index = torch.empty( self.infer_state.total_token_num, dtype=torch.int32, @@ -99,7 +99,7 @@ def _nsa_prefill_att( @dataclasses.dataclass -class NsaFlashMlaFp8DecodeAttState(BaseDecodeAttState): +class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): ks: torch.Tensor = None ke: torch.Tensor = None lengths: torch.Tensor = None @@ -107,7 +107,7 @@ class NsaFlashMlaFp8DecodeAttState(BaseDecodeAttState): flashmla_sched_meta: object = None def init_state(self): - self.backend: NsaFlashMlaFp8AttBackend = self.backend + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend model = self.backend.model use_cuda_graph = ( self.infer_state.batch_size <= model.graph_max_batch_size diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py new file mode 100644 index 0000000000..4ddccf391e --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -0,0 +1,187 @@ +import dataclasses +from typing import TYPE_CHECKING, Tuple + +import torch + +from lightllm.utils.dist_utils import get_current_device_id + +from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + device = get_current_device_id() + self.ragged_mem_buffers = [ + torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) + for _ in range(2) + ] + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": + return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": + return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + + def init_state(self): + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + return + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + 
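+        # note: k and v are accepted for interface compatibility but unused here; the packed kv cache is read from the mem manager inside _nsa_prefill_att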
att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + return self._nsa_prefill_att(q=q, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + import flash_mla + + nsa_dict = att_control.nsa_prefill_dict + layer_index = nsa_dict["layer_index"] + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + kv = self.infer_state.mem_manager.get_prefill_kv_cache(layer_index) + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32) + if topk_length.ndim == 2 and topk_length.shape[1] == 1: + topk_length = topk_length[:, 0].contiguous() + + mla_out, _, _ = flash_mla.flash_mla_sparse_fwd( + q=q.contiguous(), + kv=kv.contiguous(), + indices=topk_indices.contiguous(), + sm_scale=softmax_scale, + d_v=kv_lora_rank, + topk_length=topk_length, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + flashmla_sched_meta: object = None + + def init_state(self): + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend + model = self.backend.model + use_cuda_graph = ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ) + + if use_cuda_graph: + self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] + else: + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.b_seq_len.shape[0], + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + import flash_mla + + self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() + return + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + return self._nsa_decode_att(q=q, kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + import flash_mla + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" + + q_nope, q_rope = q + q_all = torch.cat([q_nope, q_rope], 
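+        # reassemble the full 576-dim MLA query head (512-dim nope part plus 64-dim rope part) with seq_len_q == 1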
dim=-1).unsqueeze(1).contiguous() + + o_tensor, _ = flash_mla.flash_mla_with_kvcache( + q=q_all, + k_cache=kv.contiguous(), + block_table=None, + cache_seqlens=None, + head_dim_v=kv_lora_rank, + tile_scheduler_metadata=self.flashmla_sched_meta, + num_splits=None, + softmax_scale=softmax_scale, + causal=False, + is_fp8_kvcache=True, + indices=topk_indices.contiguous(), + ) + return o_tensor[:, 0, :, :] diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 7421b5c73e..79e75b3485 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -3,7 +3,7 @@ from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager from .deepseek3_2mem_manager import Deepseek3_2MemoryManager -from .deepseek3_2_dsa_fp8_mem_manager import Deepseek3_2DSAFP8MemoryManager +from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager @@ -14,7 +14,7 @@ "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", "Deepseek3_2MemoryManager", - "Deepseek3_2DSAFP8MemoryManager", + "FP8PerTokenGroupQuantDeepseek3_2MemoryManager", "FP8StaticPerHeadQuantMemManager", "FP8StaticPerTensorQuantMemManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2_dsa_fp8_mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py similarity index 89% rename from lightllm/common/kv_cache_mem_manager/deepseek3_2_dsa_fp8_mem_manager.py rename to lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py index aa6e9bcd89..fb0eeb6ecd 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek3_2_dsa_fp8_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py @@ -9,10 +9,14 @@ from .deepseek2_mem_manager import Deepseek2MemoryManager -class Deepseek3_2DSAFP8MemoryManager(Deepseek2MemoryManager): +class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager): flashmla_bytes_per_token = 656 indexer_bytes_per_token = 132 kv_head_dim = 576 + kv_nope_dim = 512 + kv_rope_dim = 64 + quant_group_size = 128 + quant_group_num = kv_nope_dim // quant_group_size def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1" @@ -22,8 +26,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False ) def get_cell_size(self): - prefill_bytes = self.kv_head_dim * torch._utils._element_size(self.prefill_dtype) - return self.layer_num * (self.flashmla_bytes_per_token + self.indexer_bytes_per_token + prefill_bytes) + return self.layer_num * (self.flashmla_bytes_per_token + self.indexer_bytes_per_token) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty( @@ -32,9 +35,6 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.indexer_k_buffer = torch.empty( (layer_num, size + 1, head_num, self.indexer_bytes_per_token), dtype=torch.uint8, device="cuda" ) - self.prefill_kv_buffer = torch.empty( - (layer_num, size + 1, head_num, self.kv_head_dim), dtype=self.prefill_dtype, device="cuda" - ) def copy_kv_to_mem_manager(self, 
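        # mem_index gives each incoming token's destination slot in the packed fp8 kv buffer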
layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import ( @@ -56,7 +56,6 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: o_scale, o_rope, ) - self.prefill_kv_buffer[layer_index, mem_index, :, :] = kv def get_att_input_params(self, layer_index: int) -> Any: return self.get_flashmla_kv_cache(layer_index) @@ -65,7 +64,23 @@ def get_flashmla_kv_cache(self, layer_index: int) -> torch.Tensor: return self.kv_buffer[layer_index].view(-1, 1, 1, self.flashmla_bytes_per_token) def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor: - return self.prefill_kv_buffer[layer_index] + packed_kv = self.kv_buffer[layer_index] + kv_nope = packed_kv[:, :, : self.kv_nope_dim].view(torch.float8_e4m3fn) + kv_scale = packed_kv[:, :, self.kv_nope_dim : self.kv_nope_dim + self.quant_group_num * 4].view(torch.float32) + kv_rope = packed_kv[:, :, self.kv_nope_dim + self.quant_group_num * 4 :].view(torch.bfloat16) + + kv_nope = kv_nope.view(-1, 1, self.quant_group_num, self.quant_group_size).to(self.prefill_dtype) + kv_scale = kv_scale.to(self.prefill_dtype).unsqueeze(-1) + kv_nope = (kv_nope * kv_scale).view(-1, 1, self.kv_nope_dim) + + kv = torch.empty( + (packed_kv.shape[0], packed_kv.shape[1], self.kv_head_dim), + dtype=self.prefill_dtype, + device=packed_kv.device, + ) + kv[:, :, : self.kv_nope_dim] = kv_nope + kv[:, :, self.kv_nope_dim :] = kv_rope.to(self.prefill_dtype) + return kv def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: return self.indexer_k_buffer[layer_index] @@ -103,7 +118,7 @@ def write_mem_to_page_kv_move_buffer( mem_indexes: List[int], page_index: int, dp_index: int, - mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], dp_world_size: int, ): cur_page = self.kv_move_buffer[page_index] @@ -125,7 +140,7 @@ def read_page_kv_move_buffer_to_mem( mem_indexes: List[int], page_index: int, dp_index: int, - mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], dp_world_size: int, ): cur_page = self.kv_move_buffer[page_index] @@ -146,7 +161,7 @@ def read_page_kv_move_buffer_to_mem( def send_to_decode_node( self, move_tasks: List[KVMoveTask], - mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): @@ -164,7 +179,7 @@ def send_to_decode_node( def receive_from_prefill_node( self, move_tasks: List[KVMoveTask], - mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): @@ -204,7 +219,7 @@ def receive_from_prefill_node( def send_to_decode_node_p2p( self, move_tasks: List[KVMoveTask], - mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): @@ -223,7 +238,7 @@ def send_to_decode_node_p2p( def receive_from_prefill_node_p2p( self, move_tasks: List[KVMoveTask], - mem_managers: List["Deepseek3_2DSAFP8MemoryManager"], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py
index d64c645ffb..79ea448794 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -25,9 +25,9 @@ def select_mem_manager_class(): if issubclass(model_class, Deepseek3_2TpPartModel): if get_env_start_args().llm_kv_type == "fp8kv_dsa": - from . import Deepseek3_2DSAFP8MemoryManager + from . import FP8PerTokenGroupQuantDeepseek3_2MemoryManager - mem_class = Deepseek3_2DSAFP8MemoryManager + mem_class = FP8PerTokenGroupQuantDeepseek3_2MemoryManager else: mem_class = Deepseek3_2MemoryManager logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") diff --git a/test/acc/test_deepseekv32_fp8kv_dsa.sh b/test/acc/test_deepseekv32_fp8kv_dsa.sh deleted file mode 100644 index 008ef43b3b..0000000000 --- a/test/acc/test_deepseekv32_fp8kv_dsa.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" -export PYTHONPATH="${ROOT_DIR}:${ROOT_DIR}/FlashMLA:${ROOT_DIR}/sglang/sgl-kernel/python:${PYTHONPATH:-}" - -pytest \ - "${ROOT_DIR}/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py" \ - "${ROOT_DIR}/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py" \ - -s - -python "${ROOT_DIR}/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py" --tokens-list 10000 100000 1000000 --page-size-list 1 64 128 256 --iters 100 --warmup 20 -python "${ROOT_DIR}/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py" --batch 64 --heads 128 --cache-tokens-list 10000 100000 1000000 --page-size-list 1 64 128 256 --topk 2048 --iters 100 --warmup 20 --check-correctness diff --git a/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py b/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py deleted file mode 100644 index 3f4f0ef80b..0000000000 --- a/test/kernel/benchmark_deepseekv32_fp8kv_dsa.py +++ /dev/null @@ -1,105 +0,0 @@ -import argparse -import statistics -import sys -from pathlib import Path - -import torch - -CUR_DIR = Path(__file__).resolve().parent -sys.path.insert(0, str((CUR_DIR / "../../lightllm/common/basemodel/triton_kernel/kv_copy").resolve())) -sys.path.insert(0, str((CUR_DIR / "../../lightllm/models/deepseek3_2/triton_kernel").resolve())) - -from mla_copy_kv import destindex_copy_kv -from destindex_copy_kv_flashmla_fp8 import destindex_copy_kv_flashmla_fp8 - - -def _time_cuda(fn, warmup: int, iters: int) -> float: - for _ in range(warmup): - fn() - torch.cuda.synchronize() - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - samples_ms = [] - for _ in range(iters): - start.record() - fn() - end.record() - end.synchronize() - samples_ms.append(start.elapsed_time(end)) - return statistics.mean(samples_ms) - - -def _gbps(token_num: int, bytes_per_token: int, elapsed_ms: float) -> float: - return (token_num * bytes_per_token) / (elapsed_ms / 1e3) / 1e9 - - -def _build_random_page_mapping(token_num: int, page_size: int, device: str): - num_pages = (token_num + page_size - 1) // page_size - padded_token_num = num_pages * page_size - physical_page_ids = torch.randperm(num_pages, dtype=torch.int64, device=device) - logical_tokens = torch.arange(token_num, dtype=torch.int64, device=device) - dest_loc = physical_page_ids[logical_tokens // page_size] * page_size + (logical_tokens % page_size) - return dest_loc, padded_token_num - - -def main(): - parser = argparse.ArgumentParser(description="Benchmark DeepSeek V3.2 bf16 KV store vs fp8kv_dsa KV store") - parser.add_argument("--tokens", type=int, 
default=65536) - parser.add_argument("--tokens-list", type=int, nargs="*", default=None) - parser.add_argument("--page-size", type=int, default=None) - parser.add_argument("--page-size-list", type=int, nargs="*", default=[1, 64, 128, 256]) - parser.add_argument("--iters", type=int, default=100) - parser.add_argument("--warmup", type=int, default=20) - args = parser.parse_args() - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required for this benchmark") - - device = "cuda" - dtype = torch.bfloat16 - token_list = args.tokens_list if args.tokens_list else [args.tokens] - - bf16_total_bytes = 576 * 2 + 576 * 2 - fp8_total_bytes = 576 * 2 + 656 - - page_sizes = [args.page_size] if args.page_size else args.page_size_list - - print(f"iters={args.iters} warmup={args.warmup}") - for token_num in token_list: - kv = torch.randn((token_num, 1, 576), dtype=dtype, device=device) - - for page_size in page_sizes: - dest_loc, page_token_num = _build_random_page_mapping(token_num, page_size, device) - - bf16_nope = torch.empty((page_token_num, 1, 512), dtype=dtype, device=device) - bf16_rope = torch.empty((page_token_num, 1, 64), dtype=dtype, device=device) - fp8_packed = torch.empty((page_token_num, 1, 656), dtype=torch.uint8, device=device) - fp8_nope = fp8_packed[:, :, :512].view(torch.float8_e4m3fn) - fp8_scale = fp8_packed[:, :, 512:528].view(torch.float32) - fp8_rope = fp8_packed[:, :, 528:].view(torch.bfloat16) - - def run_bf16(): - destindex_copy_kv(kv[:, :, :512], kv[:, :, 512:], dest_loc, bf16_nope, bf16_rope) - - def run_fp8(): - destindex_copy_kv_flashmla_fp8( - kv[:, :, :512], - kv[:, :, 512:], - dest_loc, - fp8_nope, - fp8_scale, - fp8_rope, - ) - - bf16_ms = _time_cuda(run_bf16, warmup=args.warmup, iters=args.iters) - fp8_ms = _time_cuda(run_fp8, warmup=args.warmup, iters=args.iters) - - print(f"page_size={page_size} seqlen={token_num}") - print(f"bf16_kv: avg_ms={bf16_ms:.4f} total_traffic_gbps={_gbps(token_num, bf16_total_bytes, bf16_ms):.2f}") - print(f"fp8kv_dsa: avg_ms={fp8_ms:.4f} total_traffic_gbps={_gbps(token_num, fp8_total_bytes, fp8_ms):.2f}") - print(f"speedup={bf16_ms / fp8_ms:.3f}x compression={(576 * 2) / 656:.3f}x") - - -if __name__ == "__main__": - main() diff --git a/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py b/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py deleted file mode 100644 index e33441a56d..0000000000 --- a/test/kernel/benchmark_deepseekv32_sparse_decode_fp8_vs_bf16.py +++ /dev/null @@ -1,159 +0,0 @@ -import argparse -import statistics -import sys -from pathlib import Path - -import torch - -CUR_DIR = Path(__file__).resolve().parent -sys.path.insert(0, str((CUR_DIR / "../../lightllm/models/deepseek3_2/triton_kernel").resolve())) - -from destindex_copy_kv_flashmla_fp8 import pack_kv_reference - - -def _time_cuda(fn, warmup: int, iters: int) -> float: - for _ in range(warmup): - fn() - torch.cuda.synchronize() - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - samples_ms = [] - for _ in range(iters): - start.record() - fn() - end.record() - end.synchronize() - samples_ms.append(start.elapsed_time(end)) - return statistics.mean(samples_ms) - - -def _gbps(batch: int, topk: int, bytes_per_token: int, elapsed_ms: float) -> float: - return (batch * topk * bytes_per_token) / (elapsed_ms / 1e3) / 1e9 - - -def _build_random_page_layout( - cache_tokens: int, - page_size: int, - device: str, - dtype: torch.dtype, -): - num_pages = (cache_tokens + page_size - 1) // page_size - 
padded_cache_tokens = num_pages * page_size - physical_page_ids = torch.randperm(num_pages, dtype=torch.int64, device=device) - - logical_kv = torch.randn((cache_tokens, 1, 576), dtype=dtype, device=device) - physical_kv = torch.zeros((padded_cache_tokens, 1, 576), dtype=dtype, device=device) - - logical_tokens = torch.arange(cache_tokens, dtype=torch.int64, device=device) - physical_token_locs = physical_page_ids[logical_tokens // page_size] * page_size + (logical_tokens % page_size) - physical_kv[physical_token_locs] = logical_kv - - return physical_kv, physical_page_ids, physical_token_locs, padded_cache_tokens - - -def main(): - parser = argparse.ArgumentParser(description="Benchmark DeepSeek V3.2 decode: bf16 sparse-selected vs fp8 DSA") - parser.add_argument("--batch", type=int, default=64) - parser.add_argument("--heads", type=int, default=128) - parser.add_argument("--cache-tokens", type=int, default=131072) - parser.add_argument("--cache-tokens-list", type=int, nargs="*", default=None) - parser.add_argument("--page-size", type=int, default=None) - parser.add_argument("--page-size-list", type=int, nargs="*", default=[1, 64, 128, 256]) - parser.add_argument("--topk", type=int, default=2048) - parser.add_argument("--iters", type=int, default=100) - parser.add_argument("--warmup", type=int, default=20) - parser.add_argument("--check-correctness", action="store_true") - args = parser.parse_args() - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required for this benchmark") - - from sgl_kernel.flash_attn import flash_attn_with_kvcache - - import flash_mla - - device = "cuda" - dtype = torch.bfloat16 - sm_scale = 576 ** (-0.5) - cache_token_list = args.cache_tokens_list if args.cache_tokens_list else [args.cache_tokens] - page_sizes = [args.page_size] if args.page_size else args.page_size_list - - q_nope = torch.randn((args.batch, args.heads, 512), dtype=dtype, device=device) - q_rope = torch.randn((args.batch, args.heads, 64), dtype=dtype, device=device) - q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() - print(f"batch={args.batch} heads={args.heads} topk={args.topk} iters={args.iters} warmup={args.warmup}") - for idx, cache_tokens in enumerate(cache_token_list): - for page_idx, page_size in enumerate(page_sizes): - physical_kv, physical_page_ids, physical_token_locs, padded_cache_tokens = _build_random_page_layout( - cache_tokens, page_size, device, dtype - ) - num_pages = padded_cache_tokens // page_size - - k_rope = physical_kv[:, :, 512:].view(num_pages, page_size, 1, 64).contiguous() - kv_nope = physical_kv[:, :, :512].view(num_pages, page_size, 1, 512).contiguous() - kv_fp8 = pack_kv_reference(physical_kv).view(num_pages, page_size, 1, 656).contiguous() - - selected_pages = (args.topk + page_size - 1) // page_size - page_table = physical_page_ids[:selected_pages].to(torch.int32).repeat(args.batch, 1) - cache_seqlens = torch.full((args.batch,), args.topk, dtype=torch.int32, device=device) - cu_seqlens_q = torch.arange(0, args.batch + 1, dtype=torch.int32, device=device) - cu_seqlens_k_new = torch.arange( - 0, (args.batch + 1) * args.topk, args.topk, dtype=torch.int32, device=device - ) - fp8_indices = ( - physical_token_locs[: args.topk].to(torch.int32).view(1, 1, args.topk).repeat(args.batch, 1, 1) - ) - sched_meta, _ = flash_mla.get_mla_metadata() - - def run_bf16(): - return flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=page_table, - cache_seqlens=cache_seqlens, - 
cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=1, - softmax_scale=sm_scale, - causal=True, - ) - - def run_fp8(): - out, _ = flash_mla.flash_mla_with_kvcache( - q=q_all, - k_cache=kv_fp8, - block_table=None, - cache_seqlens=None, - head_dim_v=512, - tile_scheduler_metadata=sched_meta, - num_splits=None, - softmax_scale=sm_scale, - causal=False, - is_fp8_kvcache=True, - indices=fp8_indices, - ) - return out[:, 0] - - if args.check_correctness and idx == 0 and page_idx == 0: - bf16_out = run_bf16() - fp8_out = run_fp8() - max_diff = (bf16_out - fp8_out).abs().max().item() - mean_diff = (bf16_out - fp8_out).abs().mean().item() - print(f"correctness: max_diff={max_diff:.6f} mean_diff={mean_diff:.6f}") - - bf16_ms = _time_cuda(run_bf16, warmup=args.warmup, iters=args.iters) - fp8_ms = _time_cuda(run_fp8, warmup=args.warmup, iters=args.iters) - - print(f"page_size={page_size} seqlen={cache_tokens}") - print( - f"bf16_decode: avg_ms={bf16_ms:.4f} kv_read_gbps={_gbps(args.batch, args.topk, 576 * 2, bf16_ms):.2f}" - ) - print(f"fp8_dsa_decode: avg_ms={fp8_ms:.4f} kv_read_gbps={_gbps(args.batch, args.topk, 656, fp8_ms):.2f}") - print(f"speedup={bf16_ms / fp8_ms:.3f}x") - - -if __name__ == "__main__": - main() diff --git a/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py b/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py deleted file mode 100644 index 80d8ab9e7b..0000000000 --- a/unit_tests/models/deepseek3_2/test_flashmla_fp8_sparse_decode.py +++ /dev/null @@ -1,75 +0,0 @@ -import sys -from pathlib import Path - -import pytest -import torch - -CUR_DIR = Path(__file__).resolve().parent -sys.path.insert(0, str((CUR_DIR / "../../../lightllm/models/deepseek3_2/triton_kernel").resolve())) - -from destindex_copy_kv_flashmla_fp8 import dequantize_kv_reference, pack_kv_reference - - -def _manual_sparse_decode( - q: torch.Tensor, dense_kv: torch.Tensor, indices: torch.Tensor, sm_scale: float -) -> torch.Tensor: - batch, _, heads, _ = q.shape - topk = indices.shape[-1] - out = torch.zeros((batch, heads, 512), dtype=torch.float32, device=q.device) - - for b in range(batch): - valid = indices[b, 0] >= 0 - cur_idx = indices[b, 0, valid] - assert cur_idx.numel() > 0 - cur_k = dense_kv[cur_idx, 0, :] - cur_v = cur_k[:, :512] - logits = torch.einsum("hd,td->ht", q[b, 0].float(), cur_k.float()) * sm_scale - probs = torch.softmax(logits, dim=-1) - out[b] = torch.einsum("ht,td->hd", probs, cur_v.float()) - - if cur_idx.numel() < topk: - assert torch.all(indices[b, 0, cur_idx.numel() :] == -1) - - return out.to(torch.bfloat16) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") -def test_flashmla_fp8_sparse_decode_matches_manual_reference(): - import flash_mla - - batch = 2 - seq_q = 1 - heads = 64 - token_num = 128 - topk = 64 - dtype = torch.bfloat16 - device = "cuda" - - q = torch.randn((batch, seq_q, heads, 576), dtype=dtype, device=device) - kv = torch.randn((token_num, 1, 576), dtype=dtype, device=device) - packed = pack_kv_reference(kv).view(token_num, 1, 1, 656) - - indices = torch.randint(0, token_num, (batch, seq_q, topk), dtype=torch.int32, device=device) - indices[0, 0, -3:] = -1 - indices[1, 0, -5:] = -1 - sm_scale = 576 ** (-0.5) - - sched_meta, _ = flash_mla.get_mla_metadata() - out, _ = flash_mla.flash_mla_with_kvcache( - q=q, - k_cache=packed, - block_table=None, - cache_seqlens=None, - head_dim_v=512, - tile_scheduler_metadata=sched_meta, - num_splits=None, - softmax_scale=sm_scale, - causal=False, - 
is_fp8_kvcache=True, - indices=indices, - ) - torch.cuda.synchronize() - - dense_kv = dequantize_kv_reference(packed) - ref = _manual_sparse_decode(q, dense_kv, indices, sm_scale) - assert torch.allclose(out[:, 0], ref, rtol=7e-2, atol=7e-2) diff --git a/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py b/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py deleted file mode 100644 index 840b23159f..0000000000 --- a/unit_tests/models/deepseek3_2/triton_kernel/test_destindex_copy_kv_flashmla_fp8.py +++ /dev/null @@ -1,53 +0,0 @@ -import sys -from pathlib import Path - -import pytest -import torch - -CUR_DIR = Path(__file__).resolve().parent -sys.path.insert(0, str((CUR_DIR / "../../../../lightllm/models/deepseek3_2/triton_kernel").resolve())) - -from destindex_copy_kv_flashmla_fp8 import ( - dequantize_kv_reference, - destindex_copy_kv_flashmla_fp8, - pack_kv_reference, -) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") -def test_destindex_copy_kv_flashmla_fp8_matches_reference(): - token_num = 257 - kv = torch.randn((token_num, 1, 576), dtype=torch.bfloat16, device="cuda") - dest_loc = torch.randperm(token_num, device="cuda", dtype=torch.int64) - - out = torch.empty((token_num, 1, 656), dtype=torch.uint8, device="cuda") - out_nope = out[:, :, :512].view(torch.float8_e4m3fn) - out_scale = out[:, :, 512:528].view(torch.float32) - out_rope = out[:, :, 528:].view(torch.bfloat16) - - destindex_copy_kv_flashmla_fp8( - kv[:, :, :512], - kv[:, :, 512:], - dest_loc, - out_nope, - out_scale, - out_rope, - ) - torch.cuda.synchronize() - - ref = pack_kv_reference(kv) - assert torch.equal(out[dest_loc], ref) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") -def test_destindex_copy_kv_flashmla_fp8_roundtrip(): - token_num = 257 - kv = torch.randn((token_num, 1, 576), dtype=torch.bfloat16, device="cuda") - packed = pack_kv_reference(kv) - dequant = dequantize_kv_reference(packed) - - rope_err = (dequant[:, :, 512:] - kv[:, :, 512:]).abs().max().item() - nope_err = (dequant[:, :, :512] - kv[:, :, :512]).abs().max().item() - - assert rope_err == 0.0 - assert nope_err < 4e-1 From a505f8ca7c07905df9ea7bccc81aed133eec3198 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=92=AE=E5=9C=A3=E8=99=93?= Date: Fri, 27 Mar 2026 20:07:29 +0800 Subject: [PATCH 4/4] perf: avoid full kv-pool dequantization in the fp8 DSA prefill path --- .../attention/nsa/fp8_flashmla_sparse.py | 12 ++++++- ...oken_group_quant_deepseek3_2mem_manager.py | 36 +++++++++++++++++-- .../layer_infer/transformer_layer_infer.py | 15 +++++--- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py index 4ddccf391e..135aa92db6 100644 --- a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -76,10 +76,20 @@ def _nsa_prefill_att( nsa_dict = att_control.nsa_prefill_dict layer_index = nsa_dict["layer_index"] topk_indices = nsa_dict["topk_indices"] + topk_indices_local = nsa_dict["topk_indices_local"] + prefill_cache_kv = nsa_dict["prefill_cache_kv"] softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] - kv = self.infer_state.mem_manager.get_prefill_kv_cache(layer_index) + if self.infer_state.prefix_total_token_num > 0: + kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices( + layer_index=layer_index,
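+                    # gather and dequantize only the kv rows referenced by topk_indices; the returned indices are remapped into the compact tensor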
topk_indices=topk_indices, + ) + else: + kv = prefill_cache_kv + topk_indices = topk_indices_local + if topk_indices.ndim == 2: topk_indices = topk_indices.unsqueeze(1) diff --git a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py index fb0eeb6ecd..3016982879 100644 --- a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py @@ -63,8 +63,7 @@ def get_att_input_params(self, layer_index: int) -> Any: def get_flashmla_kv_cache(self, layer_index: int) -> torch.Tensor: return self.kv_buffer[layer_index].view(-1, 1, 1, self.flashmla_bytes_per_token) - def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor: - packed_kv = self.kv_buffer[layer_index] + def _dequantize_packed_kv(self, packed_kv: torch.Tensor) -> torch.Tensor: kv_nope = packed_kv[:, :, : self.kv_nope_dim].view(torch.float8_e4m3fn) kv_scale = packed_kv[:, :, self.kv_nope_dim : self.kv_nope_dim + self.quant_group_num * 4].view(torch.float32) kv_rope = packed_kv[:, :, self.kv_nope_dim + self.quant_group_num * 4 :].view(torch.bfloat16) @@ -82,6 +81,39 @@ def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor: kv[:, :, self.kv_nope_dim :] = kv_rope.to(self.prefill_dtype) return kv + def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor: + return self._dequantize_packed_kv(self.kv_buffer[layer_index]) + + def get_prefill_kv_cache_and_remap_indices(self, layer_index: int, topk_indices: torch.Tensor): + squeeze_h_kv = topk_indices.ndim == 2 + if squeeze_h_kv: + topk_indices = topk_indices.unsqueeze(1) + + valid_mask = topk_indices != -1 + valid_indices = topk_indices[valid_mask] + + if valid_indices.numel() == 0: + empty_kv = torch.empty( + (0, 1, self.kv_head_dim), + dtype=self.prefill_dtype, + device=topk_indices.device, + ) + remapped = topk_indices.clone() + if squeeze_h_kv: + remapped = remapped.squeeze(1) + return empty_kv, remapped + + unique_mem_index, inverse = torch.unique(valid_indices, sorted=False, return_inverse=True) + packed_kv = self.kv_buffer[layer_index].index_select(0, unique_mem_index.to(torch.int64)) + compact_kv = self._dequantize_packed_kv(packed_kv) + + remapped = torch.full_like(topk_indices, -1) + remapped[valid_mask] = inverse.to(remapped.dtype) + + if squeeze_h_kv: + remapped = remapped.squeeze(1) + return compact_kv, remapped + def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: return self.indexer_k_buffer[layer_index] diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 06631a4286..c3fa32e25f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -75,12 +75,13 @@ def _context_attention_kernel( # compute topk_indices att_state = infer_state.prefill_att_state - topk_indices = self.indexer.get_indices( + topk_indices_local, topk_indices = self.indexer.get_indices( hidden_states=infer_state.get_topk_indices_params["hidden_states"], q_lora=infer_state.get_topk_indices_params["q_lora"], infer_state=infer_state, att_state=att_state, layer_weight=layer_weight, + return_local_index=True, ) del infer_state.get_topk_indices_params @@ -90,6 +91,8 @@ nsa_prefill_dict={ "layer_index":
self.layer_num_, "topk_indices": topk_indices, + "topk_indices_local": topk_indices_local, + "prefill_cache_kv": kv, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, }, @@ -171,6 +174,7 @@ def get_indices( infer_state: Deepseek2InferStateInfo, att_state: Any, layer_weight: Deepseek3_2TransformerLayerWeight, + return_local_index: bool = False, ) -> torch.Tensor: q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) @@ -234,15 +238,18 @@ def get_indices( row_starts=ks, ) b_topk_index = torch.where(b_topk_index != -1, b_topk_index + ks.view(-1, 1), -1) + local_topk_index = b_topk_index # convert the topk indices to mem indices from ..triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index - b_topk_index = trans_topk_index_to_mem_index( - topk_index=b_topk_index, + b_topk_mem_index = trans_topk_index_to_mem_index( + topk_index=local_topk_index, ragged_mem_index=att_state.ragged_mem_index, ) - return b_topk_index + if return_local_index: + return local_topk_index, b_topk_mem_index + return b_topk_mem_index @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: