diff --git a/docker/Dockerfile b/docker/Dockerfile
index e766107ae..439ecddb3 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -4,6 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
 ARG PYTHON_VERSION=3.10
 ARG MAMBA_VERSION=24.7.1-0
 ARG VLLM_VERSION=0.16.0
+ARG FLASH_MLA_REF=47c35a7
 ARG TARGETPLATFORM
 ARG ENABLE_DEEPEP=1
 ARG ENABLE_NIXL=1
@@ -45,6 +46,11 @@ COPY ./requirements.txt /lightllm/requirements.txt
 RUN pip install -U pip
 RUN pip install -r /lightllm/requirements.txt --no-cache-dir
 RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
+RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
+    cd /root/FlashMLA && \
+    git checkout ${FLASH_MLA_REF} && \
+    git submodule update --init --recursive && \
+    FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
 
 RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*
 
diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py
index 0eea52cc8..10cd3b086 100644
--- a/lightllm/common/basemodel/attention/__init__.py
+++ b/lightllm/common/basemodel/attention/__init__.py
@@ -12,6 +12,7 @@
 
 # NSA backend
 from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
+from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend
 
 from .create_utils import (
     get_prefill_att_backend_class,
diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py
index 3ba16e218..2c4a34d32 100644
--- a/lightllm/common/basemodel/attention/create_utils.py
+++ b/lightllm/common/basemodel/attention/create_utils.py
@@ -15,6 +15,7 @@
 from .flashinfer.fp import FlashInferAttBackend
 from .flashinfer.mla import MlaFlashInferAttBackend
 from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
+from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend
 
 logger = init_logger(__name__)
 
@@ -56,6 +57,9 @@
         "flashmla_sparse": NsaFlashMlaSparseAttBackend,
         # Future backends: "fa3", "tilelang", "aiter"
     },
+    "fp8kv_dsa": {
+        "flashmla_sparse": NsaFlashMlaFp8SparseAttBackend,
+    },
 }
 
 
diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py
index 11a1ebfdc..f9db52dc2 100644
--- a/lightllm/common/basemodel/attention/nsa/__init__.py
+++ b/lightllm/common/basemodel/attention/nsa/__init__.py
@@ -5,9 +5,17 @@
     NsaFlashMlaSparsePrefillAttState,
     NsaFlashMlaSparseDecodeAttState,
 )
+from .fp8_flashmla_sparse import (
+    NsaFlashMlaFp8SparseAttBackend,
+    NsaFlashMlaFp8SparsePrefillAttState,
+    NsaFlashMlaFp8SparseDecodeAttState,
+)
 
 __all__ = [
     "NsaFlashMlaSparseAttBackend",
     "NsaFlashMlaSparsePrefillAttState",
     "NsaFlashMlaSparseDecodeAttState",
+    "NsaFlashMlaFp8SparseAttBackend",
+    "NsaFlashMlaFp8SparsePrefillAttState",
+    "NsaFlashMlaFp8SparseDecodeAttState",
 ]
diff --git a/lightllm/common/basemodel/attention/nsa/fp8.py b/lightllm/common/basemodel/attention/nsa/fp8.py
new file mode 100644
index 000000000..4ddccf391
--- /dev/null
+++ b/lightllm/common/basemodel/attention/nsa/fp8.py
@@ -0,0 +1,187 @@
+import dataclasses
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+
+from lightllm.utils.dist_utils import get_current_device_id
+
+from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState
+
+if TYPE_CHECKING:
+    from lightllm.common.basemodel.infer_struct import InferStateInfo
+
+
+class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
+    def __init__(self, model):
+        super().__init__(model=model)
+        device = 
get_current_device_id() + self.ragged_mem_buffers = [ + torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) + for _ in range(2) + ] + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": + return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": + return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + + def init_state(self): + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + return + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + return self._nsa_prefill_att(q=q, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + import flash_mla + + nsa_dict = att_control.nsa_prefill_dict + layer_index = nsa_dict["layer_index"] + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + kv = self.infer_state.mem_manager.get_prefill_kv_cache(layer_index) + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32) + if topk_length.ndim == 2 and topk_length.shape[1] == 1: + topk_length = topk_length[:, 0].contiguous() + + mla_out, _, _ = flash_mla.flash_mla_sparse_fwd( + q=q.contiguous(), + kv=kv.contiguous(), + indices=topk_indices.contiguous(), + sm_scale=softmax_scale, + d_v=kv_lora_rank, + topk_length=topk_length, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + flashmla_sched_meta: object = None + + def init_state(self): + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend + model = self.backend.model + use_cuda_graph = ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ) + + if use_cuda_graph: + self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] + else: + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + 
dtype=torch.int32, + device=get_current_device_id(), + ) + + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.b_seq_len.shape[0], + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + import flash_mla + + self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() + return + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + return self._nsa_decode_att(q=q, kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + import flash_mla + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" + + q_nope, q_rope = q + q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() + + o_tensor, _ = flash_mla.flash_mla_with_kvcache( + q=q_all, + k_cache=kv.contiguous(), + block_table=None, + cache_seqlens=None, + head_dim_v=kv_lora_rank, + tile_scheduler_metadata=self.flashmla_sched_meta, + num_splits=None, + softmax_scale=softmax_scale, + causal=False, + is_fp8_kvcache=True, + indices=topk_indices.contiguous(), + ) + return o_tensor[:, 0, :, :] diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py new file mode 100644 index 000000000..135aa92db --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py @@ -0,0 +1,197 @@ +import dataclasses +from typing import TYPE_CHECKING, Tuple + +import torch + +from lightllm.utils.dist_utils import get_current_device_id + +from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + device = get_current_device_id() + self.ragged_mem_buffers = [ + torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) + for _ in range(2) + ] + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": + return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": + return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = 
None + + def init_state(self): + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + return + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + return self._nsa_prefill_att(q=q, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + import flash_mla + + nsa_dict = att_control.nsa_prefill_dict + layer_index = nsa_dict["layer_index"] + topk_indices = nsa_dict["topk_indices"] + topk_indices_local = nsa_dict["topk_indices_local"] + prefill_cache_kv = nsa_dict["prefill_cache_kv"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + if self.infer_state.prefix_total_token_num > 0: + kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices( + layer_index=layer_index, + topk_indices=topk_indices, + ) + else: + kv = prefill_cache_kv + topk_indices = topk_indices_local + + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32) + if topk_length.ndim == 2 and topk_length.shape[1] == 1: + topk_length = topk_length[:, 0].contiguous() + + mla_out, _, _ = flash_mla.flash_mla_sparse_fwd( + q=q.contiguous(), + kv=kv.contiguous(), + indices=topk_indices.contiguous(), + sm_scale=softmax_scale, + d_v=kv_lora_rank, + topk_length=topk_length, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + flashmla_sched_meta: object = None + + def init_state(self): + self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend + model = self.backend.model + use_cuda_graph = ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ) + + if use_cuda_graph: + self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] + else: + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.b_seq_len.shape[0], + 
ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + import flash_mla + + self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() + return + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + return self._nsa_decode_att(q=q, kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + import flash_mla + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" + + q_nope, q_rope = q + q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() + + o_tensor, _ = flash_mla.flash_mla_with_kvcache( + q=q_all, + k_cache=kv.contiguous(), + block_table=None, + cache_seqlens=None, + head_dim_v=kv_lora_rank, + tile_scheduler_metadata=self.flashmla_sched_meta, + num_splits=None, + softmax_scale=softmax_scale, + causal=False, + is_fp8_kvcache=True, + indices=topk_indices.contiguous(), + ) + return o_tensor[:, 0, :, :] diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index bfccc8b48..79e75b348 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -3,6 +3,7 @@ from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager from .deepseek3_2mem_manager import Deepseek3_2MemoryManager +from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager @@ -13,6 +14,7 @@ "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", "Deepseek3_2MemoryManager", + "FP8PerTokenGroupQuantDeepseek3_2MemoryManager", "FP8StaticPerHeadQuantMemManager", "FP8StaticPerTensorQuantMemManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py index fbf9f88c8..66f37a16f 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py @@ -34,3 +34,6 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: def get_att_input_params(self, layer_index: int) -> Any: kv = self.kv_buffer[layer_index][:, :, : (self.head_dim - (144 // 2))] return kv + + def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: + return self.kv_buffer[layer_index].view(dtype=torch.uint8)[:, :, -132:] diff --git a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py new file mode 100644 index 000000000..301698287 --- /dev/null +++ 
b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py @@ -0,0 +1,347 @@ +import torch +from typing import Any, List + +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.distributed.pynccl import PyNcclCommunicator +from lightllm.common.kv_trans_kernel.kv_trans import kv_trans +from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io + +from .deepseek2_mem_manager import Deepseek2MemoryManager + + +class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager): + flashmla_bytes_per_token = 656 + indexer_bytes_per_token = 132 + kv_head_dim = 576 + kv_nope_dim = 512 + kv_rope_dim = 64 + quant_group_size = 128 + quant_group_num = kv_nope_dim // quant_group_size + + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1" + self.prefill_dtype = dtype + super().__init__( + size, torch.uint8, head_num, self.flashmla_bytes_per_token, layer_num, always_copy, mem_fraction + ) + + def get_cell_size(self): + return self.layer_num * (self.flashmla_bytes_per_token + self.indexer_bytes_per_token) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, size + 1, head_num, self.flashmla_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + self.indexer_k_buffer = torch.empty( + (layer_num, size + 1, head_num, self.indexer_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import ( + destindex_copy_kv_flashmla_fp8, + ) + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank == 512, f"Expected kv_lora_rank=512, got {kv_lora_rank}" + + o_nope = self.kv_buffer[layer_index][:, :, :512].view(torch.float8_e4m3fn) + o_scale = self.kv_buffer[layer_index][:, :, 512:528].view(torch.float32) + o_rope = self.kv_buffer[layer_index][:, :, 528:].view(torch.bfloat16) + destindex_copy_kv_flashmla_fp8( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + o_nope, + o_scale, + o_rope, + ) + + def get_att_input_params(self, layer_index: int) -> Any: + return self.get_flashmla_kv_cache(layer_index) + + def get_flashmla_kv_cache(self, layer_index: int) -> torch.Tensor: + return self.kv_buffer[layer_index].view(-1, 1, 1, self.flashmla_bytes_per_token) + + def _dequantize_packed_kv(self, packed_kv: torch.Tensor) -> torch.Tensor: + kv_nope = packed_kv[:, :, : self.kv_nope_dim].view(torch.float8_e4m3fn) + kv_scale = packed_kv[:, :, self.kv_nope_dim : self.kv_nope_dim + self.quant_group_num * 4].view(torch.float32) + kv_rope = packed_kv[:, :, self.kv_nope_dim + self.quant_group_num * 4 :].view(torch.bfloat16) + + kv_nope = kv_nope.view(-1, 1, self.quant_group_num, self.quant_group_size).to(self.prefill_dtype) + kv_scale = kv_scale.to(self.prefill_dtype).unsqueeze(-1) + kv_nope = (kv_nope * kv_scale).view(-1, 1, self.kv_nope_dim) + + kv = torch.empty( + (packed_kv.shape[0], packed_kv.shape[1], self.kv_head_dim), + dtype=self.prefill_dtype, + device=packed_kv.device, + ) + kv[:, :, : self.kv_nope_dim] = kv_nope + kv[:, :, self.kv_nope_dim :] = kv_rope.to(self.prefill_dtype) + return kv + + def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor: + return self._dequantize_packed_kv(self.kv_buffer[layer_index]) + + def 
get_prefill_kv_cache_and_remap_indices(self, layer_index: int, topk_indices: torch.Tensor): + squeeze_h_kv = topk_indices.ndim == 2 + if squeeze_h_kv: + topk_indices = topk_indices.unsqueeze(1) + + valid_mask = topk_indices != -1 + valid_indices = topk_indices[valid_mask] + + if valid_indices.numel() == 0: + empty_kv = torch.empty( + (0, 1, self.kv_head_dim), + dtype=self.prefill_dtype, + device=topk_indices.device, + ) + remapped = topk_indices.clone() + if squeeze_h_kv: + remapped = remapped.squeeze(1) + return empty_kv, remapped + + unique_mem_index, inverse = torch.unique(valid_indices, sorted=False, return_inverse=True) + packed_kv = self.kv_buffer[layer_index].index_select(0, unique_mem_index.to(torch.int64)) + compact_kv = self._dequantize_packed_kv(packed_kv) + + remapped = torch.full_like(topk_indices, -1) + remapped[valid_mask] = inverse.to(remapped.dtype) + + if squeeze_h_kv: + remapped = remapped.squeeze(1) + return compact_kv, remapped + + def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: + return self.indexer_k_buffer[layer_index] + + def alloc_kv_move_buffer(self, max_req_total_len): + self.kv_move_buffer = torch.empty( + (1, max_req_total_len + 8, self.head_num, self.flashmla_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + self.indexer_k_move_buffer = torch.empty( + (1, max_req_total_len + 8, self.head_num, self.indexer_bytes_per_token), dtype=torch.uint8, device="cuda" + ) + self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") + self.token_dim_size = self.flashmla_bytes_per_token + self.indexer_token_dim_size = self.indexer_bytes_per_token + return + + def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: + self.kv_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, self.head_num, self.flashmla_bytes_per_token), + dtype=torch.uint8, + device="cuda", + ) + self.indexer_k_paged_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, self.head_num, self.indexer_bytes_per_token), + dtype=torch.uint8, + device="cuda", + ) + self._buffer_mem_indexes_tensors = [ + torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num) + ] + return self.kv_move_buffer + + def write_mem_to_page_kv_move_buffer( + self, + mem_indexes: List[int], + page_index: int, + dp_index: int, + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], + dp_world_size: int, + ): + cur_page = self.kv_move_buffer[page_index] + cur_indexer_page = self.indexer_k_paged_move_buffer[page_index] + pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] + pin_mem_indexes.numpy()[:] = mem_indexes + mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] + mla_page_io(mem_indexes=mem_indexes_gpu, page_tensor=cur_page, kv_buffer=dp_mems[0].kv_buffer, mode="write") + mla_page_io( + mem_indexes=mem_indexes_gpu, + page_tensor=cur_indexer_page, + kv_buffer=dp_mems[0].indexer_k_buffer, + mode="write", + ) + + def read_page_kv_move_buffer_to_mem( + self, + mem_indexes: List[int], + page_index: int, + dp_index: int, + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], + dp_world_size: int, + ): + cur_page = self.kv_move_buffer[page_index] + cur_indexer_page = self.indexer_k_paged_move_buffer[page_index] + pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] + pin_mem_indexes.numpy()[:] = 
mem_indexes + mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] + for mem in dp_mems: + mla_page_io(mem_indexes=mem_indexes_gpu, page_tensor=cur_page, kv_buffer=mem.kv_buffer, mode="read") + mla_page_io( + mem_indexes=mem_indexes_gpu, + page_tensor=cur_indexer_page, + kv_buffer=mem.indexer_k_buffer, + mode="read", + ) + + def send_to_decode_node( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) + cur_device_index = self.kv_buffer.get_device() + cur_mem = mem_managers[cur_device_index] + for layer_index in range(cur_mem.layer_num): + nccl_comm.send(self._get_main_move_data(move_token_indexes, layer_index), dst=1) + nccl_comm.send(self._get_indexer_move_data(move_token_indexes, layer_index), dst=1) + + def receive_from_prefill_node( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) + + cur_device_index = self.kv_buffer.get_device() + token_num = len(move_token_indexes) + main_buffer = self.kv_move_buffer.view(-1)[0 : self.token_dim_size * token_num].view( + 1, token_num, self.head_num, self.flashmla_bytes_per_token + ) + indexer_buffer = self.indexer_k_move_buffer.view(-1)[0 : self.indexer_token_dim_size * token_num].view( + 1, token_num, self.head_num, self.indexer_bytes_per_token + ) + for i, mem in enumerate(mem_managers): + for layer_index in range(mem.layer_num): + nccl_comm.recv(main_buffer, src=0) + nccl_comm.recv(indexer_buffer, src=0) + if i == cur_device_index: + mem._write_main_move_data(move_token_indexes, main_buffer, layer_index) + mem._write_indexer_move_data(move_token_indexes, indexer_buffer, layer_index) + else: + new_main = mem.kv_move_buffer.view(-1)[0 : self.token_dim_size * token_num].view(main_buffer.shape) + new_indexer = mem.indexer_k_move_buffer.view(-1)[0 : self.indexer_token_dim_size * token_num].view( + indexer_buffer.shape + ) + from torch.cuda import comm + + comm.broadcast(main_buffer, out=[new_main]) + comm.broadcast(indexer_buffer, out=[new_indexer]) + mem._write_main_move_data(move_token_indexes, new_main, layer_index) + mem._write_indexer_move_data(move_token_indexes, new_indexer, layer_index) + + def send_to_decode_node_p2p( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + for layer_index in range(self.layer_num): + nccl_comm.send(self._get_main_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer), dst=1) + nccl_comm.send( + self._get_indexer_move_data_p2p(move_token_indexes, layer_index, self.indexer_k_move_buffer), dst=1 + ) + + def 
receive_from_prefill_node_p2p( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["FP8PerTokenGroupQuantDeepseek3_2MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, + ): + assert dp_size_in_node == 1 + move_token_indexes = [] + for task in move_tasks: + if task.move_kv_len != 0: + move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + token_num = len(move_token_indexes) + main_buffer = self.kv_move_buffer.view(-1)[0 : self.token_dim_size * token_num].view( + token_num, self.head_num, self.flashmla_bytes_per_token + ) + indexer_buffer = self.indexer_k_move_buffer.view(-1)[0 : self.indexer_token_dim_size * token_num].view( + token_num, self.head_num, self.indexer_bytes_per_token + ) + for mem in mem_managers: + for layer_index in range(mem.layer_num): + nccl_comm.recv(main_buffer, src=0) + nccl_comm.recv(indexer_buffer, src=0) + mem._write_main_move_data_p2p(move_token_indexes, main_buffer, layer_index) + mem._write_indexer_move_data_p2p(move_token_indexes, indexer_buffer, layer_index) + + def _get_main_move_data(self, token_indexes: List[int], layer_index: int): + move_size = self.token_dim_size * len(token_indexes) + move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view( + 1, len(token_indexes), self.head_num, self.flashmla_bytes_per_token + ) + move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :] + return move_buffer + + def _get_indexer_move_data(self, token_indexes: List[int], layer_index: int): + move_size = self.indexer_token_dim_size * len(token_indexes) + move_buffer = self.indexer_k_move_buffer.view(-1)[0:move_size].view( + 1, len(token_indexes), self.head_num, self.indexer_bytes_per_token + ) + move_buffer[:, :, :, :] = self.indexer_k_buffer[layer_index, token_indexes, :, :] + return move_buffer + + def _write_main_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor + + def _write_indexer_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + self.indexer_k_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor + + def _get_main_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): + move_token_num = len(token_indexes) + move_size = self.token_dim_size * move_token_num + move_buffer = kv_move_buffer.view(-1)[0:move_size].view( + move_token_num, self.head_num, self.flashmla_bytes_per_token + ) + kv_trans(self.kv_buffer[layer_index], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]) + return move_buffer + + def _get_indexer_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): + move_token_num = len(token_indexes) + move_size = self.indexer_token_dim_size * move_token_num + move_buffer = kv_move_buffer.view(-1)[0:move_size].view( + move_token_num, self.head_num, self.indexer_bytes_per_token + ) + kv_trans( + self.indexer_k_buffer[layer_index], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num] + ) + return move_buffer + + def _write_main_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + move_token_num = len(token_indexes) + kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes) + + def 
_write_indexer_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index: int): + move_token_num = len(token_indexes) + kv_trans( + buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.indexer_k_buffer[layer_index], token_indexes + ) diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 36ca8646a..79ea44879 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -24,7 +24,12 @@ def select_mem_manager_class(): from lightllm.models import Deepseek3_2TpPartModel if issubclass(model_class, Deepseek3_2TpPartModel): - mem_class = Deepseek3_2MemoryManager + if get_env_start_args().llm_kv_type == "fp8kv_dsa": + from . import FP8PerTokenGroupQuantDeepseek3_2MemoryManager + + mem_class = FP8PerTokenGroupQuantDeepseek3_2MemoryManager + else: + mem_class = Deepseek3_2MemoryManager logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") return mem_class @@ -55,4 +60,9 @@ def select_mem_manager_class(): @lru_cache(maxsize=None) def used_mem_manager_has_scale() -> bool: mem_class = select_mem_manager_class() - return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, FP8StaticPerHeadQuantMemManager, FP8StaticPerTensorQuantMemManager] + return mem_class in [ + PPLINT8KVMemoryManager, + PPLINT4KVMemoryManager, + FP8StaticPerHeadQuantMemManager, + FP8StaticPerTensorQuantMemManager, + ] diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index b00612017..c3fa32e25 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,12 +1,11 @@ import torch -from typing import Union +from typing import Any from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl -from lightllm.common.basemodel.attention.nsa import NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks @@ -76,12 +75,13 @@ def _context_attention_kernel( # 计算 topk_indices att_state = infer_state.prefill_att_state - topk_indices = self.indexer.get_indices( + topk_indices_local, topk_indices = self.indexer.get_indices( hidden_states=infer_state.get_topk_indices_params["hidden_states"], q_lora=infer_state.get_topk_indices_params["q_lora"], infer_state=infer_state, att_state=att_state, layer_weight=layer_weight, + return_local_index=True, ) del infer_state.get_topk_indices_params @@ -89,7 +89,10 @@ def _context_attention_kernel( att_control = AttControl( nsa_prefill=True, nsa_prefill_dict={ + "layer_index": self.layer_num_, "topk_indices": topk_indices, + "topk_indices_local": topk_indices_local, + "prefill_cache_kv": kv, 
"softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, }, @@ -129,6 +132,7 @@ def _token_attention_kernel( att_control = AttControl( nsa_decode=True, nsa_decode_dict={ + "layer_index": self.layer_num_, "topk_indices": topk_indices, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, @@ -168,8 +172,9 @@ def get_indices( hidden_states: torch.Tensor, q_lora: torch.Tensor, infer_state: Deepseek2InferStateInfo, - att_state: Union[NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState], + att_state: Any, layer_weight: Deepseek3_2TransformerLayerWeight, + return_local_index: bool = False, ) -> torch.Tensor: q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) @@ -195,7 +200,7 @@ def get_indices( K_fp8=k_fp8, K_scale=k_scale, DestLoc=infer_state.mem_index, - O_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:], + O_buffer=infer_state.mem_manager.get_indexer_k_buffer(self.layer_idx_), ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale @@ -211,7 +216,7 @@ def get_indices( mtp_step = get_env_start_args().mtp_step # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( - I_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:], + I_buffer=infer_state.mem_manager.get_indexer_k_buffer(self.layer_idx_), b_seq_len=infer_state.b_seq_len, b_req_idx=infer_state.b_req_idx, req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, @@ -233,15 +238,18 @@ def get_indices( row_starts=ks, ) b_topk_index = torch.where(b_topk_index != -1, b_topk_index + ks.view(-1, 1), -1) + local_topk_index = b_topk_index # 将 topk index 转化为 mem index from ..triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index - b_topk_index = trans_topk_index_to_mem_index( - topk_index=b_topk_index, + b_topk_mem_index = trans_topk_index_to_mem_index( + topk_index=local_topk_index, ragged_mem_index=att_state.ragged_mem_index, ) - return b_topk_index + if return_local_index: + return local_topk_index, b_topk_mem_index + return b_topk_mem_index @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py new file mode 100644 index 000000000..9e90050e1 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py @@ -0,0 +1,163 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _quant_scale(max_nope, fp8_max): + return tl.exp2(tl.ceil(tl.log2(tl.maximum(max_nope / fp8_max, 1e-4)))) + + +@triton.jit +def _fwd_kernel_destindex_copy_kv_flashmla_fp8( + KV_nope, + KV_rope, + Dest_loc, + O_nope, + O_scale, + O_rope, + stride_kv_nope_bs, + stride_kv_nope_h, + stride_kv_nope_d, + stride_kv_rope_bs, + stride_kv_rope_h, + stride_kv_rope_d, + stride_o_nope_bs, + stride_o_nope_h, + stride_o_nope_d, + stride_o_scale_bs, + stride_o_scale_h, + stride_o_scale_d, + stride_o_rope_bs, + stride_o_rope_h, + stride_o_rope_d, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + BLOCK_DMODEL_NOPE: tl.constexpr, + BLOCK_DMODEL_ROPE: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + cur_index = tl.program_id(0) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + offs_rope = tl.arange(0, BLOCK_DMODEL_ROPE) + + # This kernel is only used by the DeepSeek-V3.2 DSA 
FP8 path, which + # stores a single MQA-style KV head per token. Keep all accesses 1-D so + # Triton treats per-tile scales as scalars instead of 1-element blocks. + kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_rope + + kv_rope = tl.load(kv_rope_ptrs) + + o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_rope + tl.store(o_rope_ptrs, kv_rope) + + num_tiles = BLOCK_DMODEL_NOPE // GROUP_SIZE + for tile_idx in range(0, num_tiles): + offs_tile = tile_idx * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + kv_nope_tile_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_tile + kv_nope_tile = tl.load(kv_nope_tile_ptrs) + max_nope = tl.max(tl.abs(kv_nope_tile), axis=0) + kv_scale = _quant_scale(max_nope, FP8_MAX) + kv_nope_fp8 = tl.clamp(kv_nope_tile / kv_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + + o_nope_ptrs = ( + O_nope + + dest_index * stride_o_nope_bs + + (tile_idx * GROUP_SIZE) * stride_o_nope_d + + tl.arange(0, GROUP_SIZE) * stride_o_nope_d + ) + tl.store(o_nope_ptrs, kv_nope_fp8) + + o_scale_ptrs = O_scale + dest_index * stride_o_scale_bs + tile_idx * stride_o_scale_d + tl.store(o_scale_ptrs, kv_scale.to(tl.float32)) + return + + +@torch.no_grad() +def destindex_copy_kv_flashmla_fp8( + KV_nope: torch.Tensor, + KV_rope: torch.Tensor, + DestLoc: torch.Tensor, + O_nope: torch.Tensor, + O_scale: torch.Tensor, + O_rope: torch.Tensor, +): + seq_len = DestLoc.shape[0] + kv_nope_head_dim = KV_nope.shape[2] + kv_rope_head_dim = KV_rope.shape[2] + + assert kv_nope_head_dim == 512, f"Expected kv_nope_head_dim=512, got {kv_nope_head_dim}" + assert kv_rope_head_dim == 64, f"Expected kv_rope_head_dim=64, got {kv_rope_head_dim}" + assert O_nope.shape[2] == 512 + assert O_scale.shape[2] == 4 + assert O_rope.shape[2] == 64 + + _fwd_kernel_destindex_copy_kv_flashmla_fp8[(seq_len,)]( + KV_nope, + KV_rope, + DestLoc, + O_nope, + O_scale, + O_rope, + KV_nope.stride(0), + KV_nope.stride(1), + KV_nope.stride(2), + KV_rope.stride(0), + KV_rope.stride(1), + KV_rope.stride(2), + O_nope.stride(0), + O_nope.stride(1), + O_nope.stride(2), + O_scale.stride(0), + O_scale.stride(1), + O_scale.stride(2), + O_rope.stride(0), + O_rope.stride(1), + O_rope.stride(2), + FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + BLOCK_DMODEL_NOPE=512, + BLOCK_DMODEL_ROPE=64, + GROUP_SIZE=128, + num_warps=4, + num_stages=1, + ) + return + + +def pack_kv_reference(kv: torch.Tensor) -> torch.Tensor: + assert kv.shape[-1] == 576 + kv = kv.view(-1, 1, 576) + kv_nope = kv[:, :, :512].contiguous().view(-1, 512) + kv_rope = kv[:, :, 512:].contiguous().view(-1, 64) + out = torch.empty((kv.shape[0], 1, 656), dtype=torch.uint8, device=kv.device) + out_nope = out[:, :, :512].view(torch.float8_e4m3fn).view(-1, 1, 512) + out_scale = out[:, :, 512:528].view(torch.float32).view(-1, 1, 4) + out_rope = out[:, :, 528:].view(torch.bfloat16).view(-1, 1, 64) + out_rope.copy_(kv_rope.view(-1, 1, 64)) + for tile_idx in range(4): + start = tile_idx * 128 + end = start + 128 + tile = kv_nope[:, start:end] + scale = torch.pow(2, torch.clamp_min(tile.abs().amax(dim=-1).float() / 448.0, 1e-4).log2().ceil()) + out_scale[:, 0, tile_idx] = scale + out_nope[:, 0, start:end] = (tile.float() / scale[:, None]).to(torch.float8_e4m3fn) + return out + + +def dequantize_kv_reference(packed: torch.Tensor) -> torch.Tensor: + packed = packed.view(-1, 1, 656) + out = torch.empty((packed.shape[0], 1, 576), dtype=torch.bfloat16, device=packed.device) + 
+    packed_nope = packed[:, :, :512].view(torch.float8_e4m3fn).view(-1, 1, 512)
+    packed_scale = packed[:, :, 512:528].view(torch.float32).view(-1, 1, 4)
+    packed_rope = packed[:, :, 528:].view(torch.bfloat16).view(-1, 1, 64)
+    out[:, :, 512:].copy_(packed_rope)
+    for tile_idx in range(4):
+        start = tile_idx * 128
+        end = start + 128
+        out[:, :, start:end] = (
+            packed_nope[:, :, start:end].to(torch.float32) * packed_scale[:, :, tile_idx : tile_idx + 1]
+        )
+    return out
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index d32da8097..ad1826757 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -363,13 +363,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--llm_kv_type",
         type=str,
-        choices=["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"],
+        choices=["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"],
         default="None",
         help="""kv type used in llm, None for dtype that llm used in config.json.
         fp8kv_sph: use float8_e4m3fn to store kv cache for inference,
         quant way is static per head kv quant.
         fp8kv_spt: use float8_e4m3fn to store kv cache for inference,
         quant way is static per tensor kv quant.
+        fp8kv_dsa: use DeepSeek-V3.2 DSA-specific FlashMLA FP8 sparse KV cache,
+        intended for the deepseek_v32 model path.
         fp8kv_sph and fp8kv_spt requires --kv_quant_calibration_config_path
         to load pre-computed FP8 scales.
         Note: fp8kv_spt requires flashinfer-python>=0.6.5 (default is 0.6.3,
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 37c022f3a..4daf0bbe3 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -125,7 +125,9 @@ class StartArgs:
     vit_att_backend: List[str] = field(
         default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
     )
-    llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]})
+    llm_kv_type: str = field(
+        default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]}
+    )
     llm_kv_quant_group_size: int = field(default=8)
     sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})
     penalty_counter_mode: str = field(
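Review note (editorial, not part of the patch): the new Triton kernel in `lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py` packs each token into 656 bytes (512 FP8 e4m3 "nope" values, four float32 per-128-group scales taking 16 bytes, and 64 bf16 "rope" values), and the same file ships `pack_kv_reference` / `dequantize_kv_reference` as eager references. A minimal GPU-only round-trip sketch against those helpers is shown below; the test name, location, and tolerances are illustrative assumptions, and it presumes a CUDA machine with Triton and the patched lightllm importable.

```python
# Illustrative only: round-trips the FP8 DSA KV packing added in this patch.
import torch

from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import (
    dequantize_kv_reference,
    destindex_copy_kv_flashmla_fp8,
    pack_kv_reference,
)


def test_flashmla_fp8_pack_roundtrip():
    torch.manual_seed(0)
    n = 37
    # (tokens, heads=1, 576) = 512 "nope" dims + 64 "rope" dims, as the kernel asserts.
    kv = torch.randn((n, 1, 576), dtype=torch.bfloat16, device="cuda")
    dest_loc = torch.randperm(n, device="cuda")

    # Per-token packed layout: 512 B fp8 nope | 16 B (4 x fp32 group scales) | 128 B bf16 rope = 656 B.
    packed = torch.zeros((n, 1, 656), dtype=torch.uint8, device="cuda")
    o_nope = packed[:, :, :512].view(torch.float8_e4m3fn)
    o_scale = packed[:, :, 512:528].view(torch.float32)
    o_rope = packed[:, :, 528:].view(torch.bfloat16)

    destindex_copy_kv_flashmla_fp8(kv[:, :, :512], kv[:, :, 512:], dest_loc, o_nope, o_scale, o_rope)

    # The kernel scatters token i to row dest_loc[i]; gather back before comparing.
    kernel_packed = packed[dest_loc]
    reference_packed = pack_kv_reference(kv)

    # Both paths should reproduce the bf16 input within the fp8 + power-of-two-scale error budget
    # (tolerances here are illustrative, not part of the patch).
    for p in (kernel_packed, reference_packed):
        recovered = dequantize_kv_reference(p)
        assert torch.allclose(recovered.float(), kv.float(), atol=2e-2, rtol=8e-2)
```

A check along these lines would also pin down the power-of-two group-scale convention (`_quant_scale`) that `FP8PerTokenGroupQuantDeepseek3_2MemoryManager._dequantize_packed_kv` and the prefill remap path both rely on.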
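Usage note (editorial): end to end, the path is enabled purely by the new CLI value, e.g. `--llm_kv_type fp8kv_dsa` on the usual `python -m lightllm.server.api_server ...` launch line (entrypoint name assumed here, other flags elided). With that flag, `select_mem_manager_class()` picks `FP8PerTokenGroupQuantDeepseek3_2MemoryManager` for `Deepseek3_2TpPartModel`, the `"fp8kv_dsa"` entry added in `create_utils.py` routes attention to `NsaFlashMlaFp8SparseAttBackend`, and the Dockerfile change supplies the pinned FlashMLA build (`FLASH_MLA_REF=47c35a7`) that these backends import at runtime.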