-
Notifications
You must be signed in to change notification settings - Fork 313
Deepseekv3.2 #1246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
blueswhen
wants to merge
4
commits into
main
Choose a base branch
from
deepseekv3.2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Deepseekv3.2 #1246
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,187 @@ | ||
| import dataclasses | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| from typing import TYPE_CHECKING, Tuple | ||
|
|
||
| import torch | ||
|
|
||
| from lightllm.utils.dist_utils import get_current_device_id | ||
|
|
||
| from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState | ||
|
|
||
| if TYPE_CHECKING: | ||
| from lightllm.common.basemodel.infer_struct import InferStateInfo | ||
|
|
||
|
|
||
| class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): | ||
| def __init__(self, model): | ||
| super().__init__(model=model) | ||
| device = get_current_device_id() | ||
| self.ragged_mem_buffers = [ | ||
| torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) | ||
| for _ in range(2) | ||
| ] | ||
|
|
||
| def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": | ||
| return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) | ||
|
|
||
| def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": | ||
| return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): | ||
| ks: torch.Tensor = None | ||
| ke: torch.Tensor = None | ||
| lengths: torch.Tensor = None | ||
| ragged_mem_index: torch.Tensor = None | ||
|
|
||
| def init_state(self): | ||
| self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend | ||
| self.ragged_mem_index = torch.empty( | ||
| self.infer_state.total_token_num, | ||
| dtype=torch.int32, | ||
| device=get_current_device_id(), | ||
| ) | ||
| from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke | ||
|
|
||
| self.ks, self.ke, self.lengths = gen_nsa_ks_ke( | ||
| b_seq_len=self.infer_state.b_seq_len, | ||
| b_q_seq_len=self.infer_state.b_q_seq_len, | ||
| b_req_idx=self.infer_state.b_req_idx, | ||
| req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, | ||
| q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, | ||
| ragged_mem_index=self.ragged_mem_index, | ||
| hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, | ||
| ) | ||
| return | ||
|
|
||
| def prefill_att( | ||
| self, | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| att_control: AttControl = AttControl(), | ||
| alloc_func=torch.empty, | ||
| ) -> torch.Tensor: | ||
| assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" | ||
| assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" | ||
| return self._nsa_prefill_att(q=q, att_control=att_control) | ||
|
|
||
| def _nsa_prefill_att( | ||
| self, | ||
| q: torch.Tensor, | ||
| att_control: AttControl, | ||
| ) -> torch.Tensor: | ||
| import flash_mla | ||
|
|
||
| nsa_dict = att_control.nsa_prefill_dict | ||
| layer_index = nsa_dict["layer_index"] | ||
| topk_indices = nsa_dict["topk_indices"] | ||
| softmax_scale = nsa_dict["softmax_scale"] | ||
| kv_lora_rank = nsa_dict["kv_lora_rank"] | ||
|
|
||
| kv = self.infer_state.mem_manager.get_prefill_kv_cache(layer_index) | ||
| if topk_indices.ndim == 2: | ||
| topk_indices = topk_indices.unsqueeze(1) | ||
|
|
||
| topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32) | ||
| if topk_length.ndim == 2 and topk_length.shape[1] == 1: | ||
| topk_length = topk_length[:, 0].contiguous() | ||
|
|
||
| mla_out, _, _ = flash_mla.flash_mla_sparse_fwd( | ||
| q=q.contiguous(), | ||
| kv=kv.contiguous(), | ||
| indices=topk_indices.contiguous(), | ||
| sm_scale=softmax_scale, | ||
| d_v=kv_lora_rank, | ||
| topk_length=topk_length, | ||
| ) | ||
| return mla_out | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): | ||
| ks: torch.Tensor = None | ||
| ke: torch.Tensor = None | ||
| lengths: torch.Tensor = None | ||
| ragged_mem_index: torch.Tensor = None | ||
| flashmla_sched_meta: object = None | ||
|
|
||
| def init_state(self): | ||
| self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend | ||
| model = self.backend.model | ||
| use_cuda_graph = ( | ||
| self.infer_state.batch_size <= model.graph_max_batch_size | ||
| and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch | ||
| ) | ||
|
|
||
| if use_cuda_graph: | ||
| self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] | ||
| else: | ||
| self.ragged_mem_index = torch.empty( | ||
| self.infer_state.total_token_num, | ||
| dtype=torch.int32, | ||
| device=get_current_device_id(), | ||
| ) | ||
|
|
||
| from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke | ||
|
|
||
| self.ks, self.ke, self.lengths = gen_nsa_ks_ke( | ||
| b_seq_len=self.infer_state.b_seq_len, | ||
| b_q_seq_len=self.infer_state.b_q_seq_len, | ||
| b_req_idx=self.infer_state.b_req_idx, | ||
| req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, | ||
| q_token_num=self.infer_state.b_seq_len.shape[0], | ||
| ragged_mem_index=self.ragged_mem_index, | ||
| hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, | ||
| ) | ||
| import flash_mla | ||
|
|
||
| self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() | ||
| return | ||
|
|
||
| def decode_att( | ||
| self, | ||
| q: Tuple[torch.Tensor, torch.Tensor], | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| att_control: AttControl = AttControl(), | ||
| alloc_func=torch.empty, | ||
| ) -> torch.Tensor: | ||
| assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" | ||
| assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" | ||
| return self._nsa_decode_att(q=q, kv=k, att_control=att_control) | ||
|
|
||
| def _nsa_decode_att( | ||
| self, | ||
| q: Tuple[torch.Tensor, torch.Tensor], | ||
| kv: torch.Tensor, | ||
| att_control: AttControl, | ||
| ) -> torch.Tensor: | ||
| import flash_mla | ||
|
|
||
| nsa_dict = att_control.nsa_decode_dict | ||
| topk_indices = nsa_dict["topk_indices"] | ||
| softmax_scale = nsa_dict["softmax_scale"] | ||
| kv_lora_rank = nsa_dict["kv_lora_rank"] | ||
|
|
||
| if topk_indices.ndim == 2: | ||
| topk_indices = topk_indices.unsqueeze(1) | ||
| assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" | ||
|
|
||
| q_nope, q_rope = q | ||
| q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() | ||
|
|
||
| o_tensor, _ = flash_mla.flash_mla_with_kvcache( | ||
| q=q_all, | ||
| k_cache=kv.contiguous(), | ||
| block_table=None, | ||
| cache_seqlens=None, | ||
| head_dim_v=kv_lora_rank, | ||
| tile_scheduler_metadata=self.flashmla_sched_meta, | ||
| num_splits=None, | ||
| softmax_scale=softmax_scale, | ||
| causal=False, | ||
| is_fp8_kvcache=True, | ||
| indices=topk_indices.contiguous(), | ||
| ) | ||
| return o_tensor[:, 0, :, :] | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To reduce the final Docker image size, it's a good practice to clean up build-time dependencies and source files within the same
RUNlayer. After installingFlashMLA, the cloned repository at/root/FlashMLAis no longer needed and can be removed.