Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
ARG PYTHON_VERSION=3.10
ARG MAMBA_VERSION=24.7.1-0
ARG VLLM_VERSION=0.16.0
ARG FLASH_MLA_REF=47c35a7
ARG TARGETPLATFORM
ARG ENABLE_DEEPEP=1
ARG ENABLE_NIXL=1
Expand Down Expand Up @@ -45,6 +46,11 @@ COPY ./requirements.txt /lightllm/requirements.txt
RUN pip install -U pip
RUN pip install -r /lightllm/requirements.txt --no-cache-dir
RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
cd /root/FlashMLA && \
git checkout ${FLASH_MLA_REF} && \
git submodule update --init --recursive && \
FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To reduce the final Docker image size, it's a good practice to clean up build-time dependencies and source files within the same RUN layer. After installing FlashMLA, the cloned repository at /root/FlashMLA is no longer needed and can be removed.

    FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir . && rm -rf /root/FlashMLA


RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*

Expand Down
1 change: 1 addition & 0 deletions lightllm/common/basemodel/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# NSA backend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

from .create_utils import (
get_prefill_att_backend_class,
Expand Down
4 changes: 4 additions & 0 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

logger = init_logger(__name__)

Expand Down Expand Up @@ -56,6 +57,9 @@
"flashmla_sparse": NsaFlashMlaSparseAttBackend,
# Future backends: "fa3", "tilelang", "aiter"
},
"fp8kv_dsa": {
"flashmla_sparse": NsaFlashMlaFp8SparseAttBackend,
},
}


Expand Down
8 changes: 8 additions & 0 deletions lightllm/common/basemodel/attention/nsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
NsaFlashMlaSparsePrefillAttState,
NsaFlashMlaSparseDecodeAttState,
)
from .fp8_flashmla_sparse import (
NsaFlashMlaFp8SparseAttBackend,
NsaFlashMlaFp8SparsePrefillAttState,
NsaFlashMlaFp8SparseDecodeAttState,
)

__all__ = [
"NsaFlashMlaSparseAttBackend",
"NsaFlashMlaSparsePrefillAttState",
"NsaFlashMlaSparseDecodeAttState",
"NsaFlashMlaFp8SparseAttBackend",
"NsaFlashMlaFp8SparsePrefillAttState",
"NsaFlashMlaFp8SparseDecodeAttState",
]
187 changes: 187 additions & 0 deletions lightllm/common/basemodel/attention/nsa/fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import dataclasses
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This file appears to be a duplicate of lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py. The rest of the codebase imports from fp8_flashmla_sparse.py, so this file seems to be unused and can be removed to avoid code duplication and potential confusion.

from typing import TYPE_CHECKING, Tuple

import torch

from lightllm.utils.dist_utils import get_current_device_id

from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState

if TYPE_CHECKING:
from lightllm.common.basemodel.infer_struct import InferStateInfo


class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
    """Attention backend built on FlashMLA sparse kernels with an FP8 KV cache.

    On construction it pre-allocates two int32 ragged-index buffers sized for
    the largest CUDA-graph batch (`graph_max_batch_size * max_seq_length`);
    decode states pick one by `microbatch_index` so graph capture can reuse a
    stable allocation instead of allocating per step.
    """

    def __init__(self, model):
        super().__init__(model=model)
        device = get_current_device_id()
        buffer_numel = model.graph_max_batch_size * model.max_seq_length
        # One buffer per microbatch (two microbatches max).
        self.ragged_mem_buffers = []
        for _ in range(2):
            buf = torch.empty(buffer_numel, dtype=torch.int32, device=device)
            self.ragged_mem_buffers.append(buf)

    def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState":
        """Return a fresh prefill attention state bound to this backend."""
        return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state)

    def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState":
        """Return a fresh decode attention state bound to this backend."""
        return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state)


@dataclasses.dataclass
class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState):
    """Prefill-phase attention state for the FP8 FlashMLA sparse backend.

    `init_state` derives the ragged key-range tensors (ks/ke/lengths) and a
    per-token ragged memory index from the batch metadata; `prefill_att` then
    dispatches to the FlashMLA sparse forward kernel using the top-k indices
    carried in `att_control.nsa_prefill_dict`.
    """

    ks: torch.Tensor = None
    ke: torch.Tensor = None
    lengths: torch.Tensor = None
    ragged_mem_index: torch.Tensor = None

    def init_state(self):
        # Re-bind with a narrowed annotation for readability/tooling.
        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
        state = self.infer_state
        self.ragged_mem_index = torch.empty(
            state.total_token_num,
            dtype=torch.int32,
            device=get_current_device_id(),
        )
        # Imported lazily to avoid a triton dependency at module import time.
        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        # New (non-prefix) query tokens only.
        new_token_num = state.total_token_num - state.prefix_total_token_num
        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
            b_seq_len=state.b_seq_len,
            b_q_seq_len=state.b_q_seq_len,
            b_req_idx=state.b_req_idx,
            req_to_token_index=state.req_manager.req_to_token_indexs,
            q_token_num=new_token_num,
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=state.req_manager.HOLD_REQUEST_ID,
        )
        return

    def prefill_att(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),  # NOTE(review): mutable default mirrors base-class signature — assumed immutable in practice
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        """Run prefill attention; `k`/`v` are unused (KV comes from the cache)."""
        assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
        assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
        return self._nsa_prefill_att(q=q, att_control=att_control)

    def _nsa_prefill_att(
        self,
        q: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        """Call the FlashMLA sparse forward kernel for the prefill tokens."""
        import flash_mla

        cfg = att_control.nsa_prefill_dict
        layer_index = cfg["layer_index"]
        topk_indices = cfg["topk_indices"]
        softmax_scale = cfg["softmax_scale"]
        kv_lora_rank = cfg["kv_lora_rank"]

        kv = self.infer_state.mem_manager.get_prefill_kv_cache(layer_index)
        # Kernel expects a (tokens, groups, topk) index tensor.
        if topk_indices.ndim == 2:
            topk_indices = topk_indices.unsqueeze(1)

        # Valid entries are non-negative; -1 marks padding.
        topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32)
        if topk_length.ndim == 2 and topk_length.shape[1] == 1:
            topk_length = topk_length[:, 0].contiguous()

        mla_out, _, _ = flash_mla.flash_mla_sparse_fwd(
            q=q.contiguous(),
            kv=kv.contiguous(),
            indices=topk_indices.contiguous(),
            sm_scale=softmax_scale,
            d_v=kv_lora_rank,
            topk_length=topk_length,
        )
        return mla_out


@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
    """Decode-phase attention state for the FP8 FlashMLA sparse backend.

    `init_state` picks (or allocates) a ragged memory-index buffer, builds the
    ragged key ranges, and fetches FlashMLA scheduler metadata; `decode_att`
    runs one-token-per-request sparse decode against the FP8 KV cache.
    """

    ks: torch.Tensor = None
    ke: torch.Tensor = None
    lengths: torch.Tensor = None
    ragged_mem_index: torch.Tensor = None
    flashmla_sched_meta: object = None

    def init_state(self):
        # Re-bind with a narrowed annotation for readability/tooling.
        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
        state = self.infer_state
        model = self.backend.model

        # Under CUDA-graph limits, reuse the backend's pre-allocated buffer so
        # the captured graph sees a stable tensor address.
        graph_ok = (
            state.batch_size <= model.graph_max_batch_size
            and state.max_kv_seq_len <= model.graph_max_len_in_batch
        )
        if graph_ok:
            self.ragged_mem_index = self.backend.ragged_mem_buffers[state.microbatch_index]
        else:
            self.ragged_mem_index = torch.empty(
                state.total_token_num,
                dtype=torch.int32,
                device=get_current_device_id(),
            )

        # Imported lazily to avoid a triton dependency at module import time.
        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
            b_seq_len=state.b_seq_len,
            b_q_seq_len=state.b_q_seq_len,
            b_req_idx=state.b_req_idx,
            req_to_token_index=state.req_manager.req_to_token_indexs,
            q_token_num=state.b_seq_len.shape[0],  # one query token per request in decode
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=state.req_manager.HOLD_REQUEST_ID,
        )

        import flash_mla

        # Tile-scheduler metadata for flash_mla_with_kvcache; num_splits is unused here.
        self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
        return

    def decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),  # NOTE(review): mutable default mirrors base-class signature — assumed immutable in practice
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        """Run decode attention; `q` is a (nope, rope) pair and `v` is unused."""
        assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
        assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
        return self._nsa_decode_att(q=q, kv=k, att_control=att_control)

    def _nsa_decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        kv: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        """Call flash_mla_with_kvcache on the FP8 cache with sparse top-k indices."""
        import flash_mla

        cfg = att_control.nsa_decode_dict
        topk_indices = cfg["topk_indices"]
        softmax_scale = cfg["softmax_scale"]
        kv_lora_rank = cfg["kv_lora_rank"]

        if topk_indices.ndim == 2:
            topk_indices = topk_indices.unsqueeze(1)
        assert topk_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1"

        # Fuse the no-rope and rope halves of q and add a seq_len_q==1 axis.
        q_nope, q_rope = q
        q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous()

        o_tensor, _ = flash_mla.flash_mla_with_kvcache(
            q=q_all,
            k_cache=kv.contiguous(),
            block_table=None,
            cache_seqlens=None,
            head_dim_v=kv_lora_rank,
            tile_scheduler_metadata=self.flashmla_sched_meta,
            num_splits=None,
            softmax_scale=softmax_scale,
            causal=False,
            is_fp8_kvcache=True,
            indices=topk_indices.contiguous(),
        )
        # Drop the singleton seq_len_q axis before returning.
        return o_tensor[:, 0, :, :]
Loading
Loading