Skip to content

Commit a505f8c

Browse files
author
钮圣虓
committed
perf
1 parent d150428 commit a505f8c

3 files changed

Lines changed: 56 additions & 7 deletions

File tree

lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,20 @@ def _nsa_prefill_att(
7676
nsa_dict = att_control.nsa_prefill_dict
7777
layer_index = nsa_dict["layer_index"]
7878
topk_indices = nsa_dict["topk_indices"]
79+
topk_indices_local = nsa_dict["topk_indices_local"]
80+
prefill_cache_kv = nsa_dict["prefill_cache_kv"]
7981
softmax_scale = nsa_dict["softmax_scale"]
8082
kv_lora_rank = nsa_dict["kv_lora_rank"]
8183

82-
kv = self.infer_state.mem_manager.get_prefill_kv_cache(layer_index)
84+
if self.infer_state.prefix_total_token_num > 0:
85+
kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices(
86+
layer_index=layer_index,
87+
topk_indices=topk_indices,
88+
)
89+
else:
90+
kv = prefill_cache_kv
91+
topk_indices = topk_indices_local
92+
8393
if topk_indices.ndim == 2:
8494
topk_indices = topk_indices.unsqueeze(1)
8595

lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def get_att_input_params(self, layer_index: int) -> Any:
6363
def get_flashmla_kv_cache(self, layer_index: int) -> torch.Tensor:
    """Expose the layer's packed KV buffer in FlashMLA layout.

    Reshapes the raw per-layer buffer to (tokens, 1, 1, bytes_per_token)
    without copying; the total byte count must be divisible by
    `self.flashmla_bytes_per_token`.
    """
    layer_buffer = self.kv_buffer[layer_index]
    return layer_buffer.view(-1, 1, 1, self.flashmla_bytes_per_token)
6565

66-
def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor:
67-
packed_kv = self.kv_buffer[layer_index]
66+
def _dequantize_packed_kv(self, packed_kv: torch.Tensor) -> torch.Tensor:
6867
kv_nope = packed_kv[:, :, : self.kv_nope_dim].view(torch.float8_e4m3fn)
6968
kv_scale = packed_kv[:, :, self.kv_nope_dim : self.kv_nope_dim + self.quant_group_num * 4].view(torch.float32)
7069
kv_rope = packed_kv[:, :, self.kv_nope_dim + self.quant_group_num * 4 :].view(torch.bfloat16)
@@ -82,6 +81,39 @@ def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor:
8281
kv[:, :, self.kv_nope_dim :] = kv_rope.to(self.prefill_dtype)
8382
return kv
8483

84+
def get_prefill_kv_cache(self, layer_index: int) -> torch.Tensor:
    """Return the whole KV cache of `layer_index`, dequantized for prefill.

    Reads the layer's packed (quantized) buffer and runs it through the
    shared dequantization helper; no row selection is performed here.
    """
    packed = self.kv_buffer[layer_index]
    return self._dequantize_packed_kv(packed)
86+
87+
def get_prefill_kv_cache_and_remap_indices(self, layer_index: int, topk_indices: torch.Tensor):
    """Gather and dequantize only the KV rows referenced by `topk_indices`.

    Instead of dequantizing the entire layer cache, collect the unique
    memory indices that appear in `topk_indices` (entries equal to -1
    mark unused slots and are preserved), dequantize just those rows,
    and rewrite the indices to point into the compacted tensor.

    Args:
        layer_index: layer whose packed KV buffer is read.
        topk_indices: integer index tensor, 2-D (tokens, k) or 3-D with a
            middle head dim; -1 marks invalid slots.

    Returns:
        (compact_kv, remapped): dequantized rows for the unique valid
        indices, and `topk_indices` rewritten into positions within
        `compact_kv` (same shape as the input; -1 entries stay -1).
    """
    had_no_head_dim = topk_indices.ndim == 2
    if had_no_head_dim:
        topk_indices = topk_indices.unsqueeze(1)

    usable = topk_indices.ne(-1)
    flat_valid = topk_indices[usable]

    if flat_valid.numel() == 0:
        # Nothing to gather: hand back an empty cache and the indices untouched.
        no_rows = torch.empty(
            (0, 1, self.kv_head_dim),
            dtype=self.prefill_dtype,
            device=topk_indices.device,
        )
        passthrough = topk_indices.clone()
        return no_rows, passthrough.squeeze(1) if had_no_head_dim else passthrough

    # Deduplicate referenced rows; `positions` maps every valid entry to
    # its row inside the compacted gather below.
    unique_rows, positions = torch.unique(flat_valid, sorted=False, return_inverse=True)
    gathered = self.kv_buffer[layer_index].index_select(0, unique_rows.to(torch.int64))
    compact_kv = self._dequantize_packed_kv(gathered)

    remapped = torch.full_like(topk_indices, -1)
    remapped[usable] = positions.to(remapped.dtype)

    return compact_kv, remapped.squeeze(1) if had_no_head_dim else remapped
116+
85117
def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
    """Return the indexer K buffer tensor belonging to `layer_index`."""
    buffers = self.indexer_k_buffer
    return buffers[layer_index]
87119

lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,13 @@ def _context_attention_kernel(
7575

7676
# 计算 topk_indices
7777
att_state = infer_state.prefill_att_state
78-
topk_indices = self.indexer.get_indices(
78+
topk_indices_local, topk_indices = self.indexer.get_indices(
7979
hidden_states=infer_state.get_topk_indices_params["hidden_states"],
8080
q_lora=infer_state.get_topk_indices_params["q_lora"],
8181
infer_state=infer_state,
8282
att_state=att_state,
8383
layer_weight=layer_weight,
84+
return_local_index=True,
8485
)
8586
del infer_state.get_topk_indices_params
8687

@@ -90,6 +91,8 @@ def _context_attention_kernel(
9091
nsa_prefill_dict={
9192
"layer_index": self.layer_num_,
9293
"topk_indices": topk_indices,
94+
"topk_indices_local": topk_indices_local,
95+
"prefill_cache_kv": kv,
9396
"softmax_scale": self.softmax_scale,
9497
"kv_lora_rank": self.kv_lora_rank,
9598
},
@@ -171,6 +174,7 @@ def get_indices(
171174
infer_state: Deepseek2InferStateInfo,
172175
att_state: Any,
173176
layer_weight: Deepseek3_2TransformerLayerWeight,
177+
return_local_index: bool = False,
174178
) -> torch.Tensor:
175179

176180
q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight)
@@ -234,15 +238,18 @@ def get_indices(
234238
row_starts=ks,
235239
)
236240
b_topk_index = torch.where(b_topk_index != -1, b_topk_index + ks.view(-1, 1), -1)
241+
local_topk_index = b_topk_index
237242
# 将 topk index 转化为 mem index
238243
from ..triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index
239244

240-
b_topk_index = trans_topk_index_to_mem_index(
241-
topk_index=b_topk_index,
245+
b_topk_mem_index = trans_topk_index_to_mem_index(
246+
topk_index=local_topk_index,
242247
ragged_mem_index=att_state.ragged_mem_index,
243248
)
244249

245-
return b_topk_index
250+
if return_local_index:
251+
return local_topk_index, b_topk_mem_index
252+
return b_topk_mem_index
246253

247254
@staticmethod
248255
def _rotate_activation(x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)