68 changes: 31 additions & 37 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -155,13 +155,11 @@ def update_step_context(cls, step_context):
"""Update step context."""

block_num, block_size, *_ = step_context.kv_caches[0][0].shape
is_unpaged_prefill = False
is_prefill_no_cache = False
if not step_context.is_decoding:
is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
if step_context.block_offsets.dtype != torch.int32:
step_context.block_offsets = step_context.block_offsets.to(torch.int32)
if not (step_context.is_decoding or is_unpaged_prefill):
step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0)
if step_context.kv_seqlens.dtype != torch.int32:
step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32)
if step_context.q_seqlens.dtype != torch.int32:
@@ -175,7 +173,7 @@ def get_total_slots():
cls.total_slots = cls.total_slots.view(block_num, block_size)
return cls.total_slots

def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
def get_cpu_seqlens(is_decoding, is_prefill_no_cache):
"""Get sequence lengths on CPU.

Returns:
Expand All @@ -187,37 +185,43 @@ def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
"""
if is_decoding:
q_seqlens_cpu = None
kv_seqlens_cpu = kv_seqlens_expanded = step_context.kv_seqlens.cpu()
elif is_unpaged_prefill:
kv_seqlens_cpu = step_context.kv_seqlens.cpu()
elif is_prefill_no_cache:
q_seqlens_cpu = step_context.q_seqlens.cpu()
kv_seqlens_cpu = kv_seqlens_expanded = q_seqlens_cpu
kv_seqlens_cpu = q_seqlens_cpu
else:
q_seqlens_cpu = step_context.q_seqlens.cpu()
kv_seqlens_cpu = step_context.kv_seqlens.cpu()
# Expand kv_seqlens to per-token for paged prefill attention
kv_seqlens_expanded = kv_seqlens_cpu.repeat_interleave(q_seqlens_cpu, 0)
return q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded
return q_seqlens_cpu, kv_seqlens_cpu

def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):
def get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None, kv_seqlens_cpu=None):
if is_decoding:
q_seqlens_list, kv_seqlens_list = None, None
elif is_unpaged_prefill:
elif is_prefill_no_cache:
q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist()
else:
q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist()
return q_seqlens_list, kv_seqlens_list

def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None):
def get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list=None, kv_seqlens_list=None):
if is_decoding:
max_q_seq_len, max_kv_seq_len = 1, None
elif is_unpaged_prefill:
elif is_prefill_no_cache:
max_q_seq_len = max_kv_seq_len = max(q_seqlens_list)
else:
max_q_seq_len = max(q_seqlens_list)
max_kv_seq_len = max(kv_seqlens_list)
return max_q_seq_len, max_kv_seq_len

def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list,
def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None):
if is_decoding:
batch_size = step_context.q_seqlens.size(0)
return torch.arange(1, batch_size + 1, dtype=torch.int32)
elif is_prefill_no_cache:
return q_seqlens_cpu
return q_seqlens_cpu.cumsum(dim=0)

def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list,
max_q_seq_len, max_kv_seq_len):
kv_start_indices, attention_mask = [], []
if is_decoding:
@@ -236,25 +240,17 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
slots = slot_tables[history_length:kv_seq_len]
kv_start_indices.append(slots)

if not is_unpaged_prefill:
single_attention_mask = torch.triu(
torch.ones(q_seq_len,
step_context.block_offsets.shape[1] * block_size,
dtype=torch.bool,
device=step_context.block_offsets.device),
diagonal=kv_seq_len - q_seq_len + 1,
)
attention_mask.append(single_attention_mask)

if is_unpaged_prefill:
if is_prefill_no_cache:
attention_mask.append(
torch.triu(torch.ones(max_q_seq_len,
max_kv_seq_len,
dtype=step_context.kv_caches[0][0].dtype,
device=step_context.block_offsets.device),
diagonal=max_kv_seq_len - max_q_seq_len + 1))
else:
attention_mask = [torch.cat(attention_mask)]
attention_mask.append(
torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=step_context.block_offsets.device),
diagonal=1))

kv_start_indices = torch.cat(kv_start_indices)

@@ -357,16 +353,16 @@ def get_moe_group_name(group):
group_name = backend.get_hccl_comm_name(local_rank)
return group_name

q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
is_unpaged_prefill)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu,
kv_seqlens_cpu)
max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,
max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list,
kv_seqlens_list)
kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
is_unpaged_prefill, q_seqlens_list,
is_prefill_no_cache, q_seqlens_list,
kv_seqlens_list, max_q_seq_len,
max_kv_seq_len)
q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu)

if not cls.enable_graph and step_context.kv_quant_policy == 8:
record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
@@ -387,13 +383,11 @@ def get_moe_group_name(group):
step_context.block_offsets,
q_start_loc=None,
q_seqlens=q_seqlens_cpu,
# kv_seqlens_expanded is only expanded in paged prefill,
# otherwise it equals kv_seqlens_cpu
kv_seqlens=kv_seqlens_expanded,
kv_seqlens=kv_seqlens_cpu,
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=attention_mask,
is_unpaged_prefill=is_unpaged_prefill,
is_prefill_no_cache=is_prefill_no_cache,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
quant_policy=step_context.kv_quant_policy,
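The ascend hunks above change three pieces of per-step metadata: prefill without kv cache keeps a (max_q_seq_len, max_kv_seq_len) triangular mask built in the kv-cache dtype, paged prefill now gets a single fixed 2048x2048 boolean causal mask instead of per-request masks, and update_q_seqlens rewrites q_seqlens into cumulative offsets (an arange in decode). A minimal standalone sketch of those values with invented batch sizes (not the backend code; CPU tensors and dtypes here are assumptions):

```python
import torch

# --- prefill without kv cache (every request has q_seqlen == kv_seqlen) ---
q_seqlens = torch.tensor([3, 5], dtype=torch.int32)
max_q = max_kv = int(q_seqlens.max())
# single (max_q, max_kv) triangular mask; the backend builds it in the kv-cache dtype
no_cache_mask = torch.triu(torch.ones(max_q, max_kv, dtype=torch.float32),
                           diagonal=max_kv - max_q + 1)
q_seqlens_meta = q_seqlens  # update_q_seqlens passes them through unchanged

# --- paged prefill (kv cache already holds history, kv_seqlens > q_seqlens) ---
q_seqlens = torch.tensor([3, 5], dtype=torch.int32)
kv_seqlens = torch.tensor([7, 9], dtype=torch.int32)
q_seqlens_meta = q_seqlens.cumsum(dim=0)  # cumulative offsets: tensor([3, 8])
# one shared boolean causal mask replaces the per-request triu masks
paged_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), diagonal=1)

# --- decode ---
batch_size = 4  # invented
# one new token per request, so the cumulative offsets are simply 1..batch_size
q_seqlens_meta = torch.arange(1, batch_size + 1, dtype=torch.int32)
```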
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -13,7 +13,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
kv_start_indices: Tensor | None = None
block_size: int = 64
attention_mask: Sequence[Tensor] = tuple()
is_unpaged_prefill: bool | None = None
is_prefill_no_cache: bool | None = None
max_q_seq_len: int = 1
max_kv_seq_len: int = 1
quant_meta: dict = None
@@ -79,7 +79,7 @@ def forward(
kv_start_indices = attn_metadata.kv_start_indices
block_size = attn_metadata.block_size
attn_mask = attn_metadata.attention_mask
is_unpaged_prefill = attn_metadata.is_unpaged_prefill
is_prefill_no_cache = attn_metadata.is_prefill_no_cache
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
quant_bits = attn_metadata.quant_policy
@@ -138,7 +138,7 @@ def forward(
v_head_size=self.v_head_size,
attn_mask=attn_mask,
softmax_scale=self.scale,
is_unpaged_prefill=is_unpaged_prefill,
is_prefill_no_cache=is_prefill_no_cache,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
@@ -60,7 +60,7 @@ def get_total_slots():
kv_start_indices = []
block_num, _, block_size, _ = step_context.kv_caches[0][0].shape

is_unpaged_prefill = False
is_prefill_no_cache = False
q_start_loc = step_context.q_start_loc
q_seqlens = step_context.q_seqlens
kv_seqlens = step_context.kv_seqlens.to(torch.int32)
@@ -74,7 +74,7 @@ def get_total_slots():
q_seqlens_list = step_context.q_seqlens.tolist()
kv_seqlens_list = step_context.kv_seqlens.tolist()
if not step_context.is_decoding:
is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
is_prefill_no_cache = q_seqlens_list == kv_seqlens_list
# get kv_indices
for i in range(q_start_loc.size(0)):
q_seq_len = q_seqlens_list[i]
@@ -86,7 +86,7 @@ def get_total_slots():
slots = slot_tables[history_length:kv_seq_len]
kv_start_indices.append(slots)
kv_start_indices = torch.cat(kv_start_indices)
if not is_unpaged_prefill:
if not is_prefill_no_cache:
cu_seq_lens_kv = torch.cat((torch.tensor([0], device=kv_seqlens.device), kv_seqlens.cumsum(0))).int()
else:
# collect kv_start_indices without using a for-loop,
@@ -108,7 +108,7 @@ def get_total_slots():
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=None,
is_unpaged_prefill=is_unpaged_prefill,
is_prefill_no_cache=is_prefill_no_cache,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
)
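When the camb prefill does read history from the kv cache, the backend above builds cumulative kv offsets (cu_seq_lens_kv) rather than per-request lengths. A standalone sketch of that construction with invented values (not the backend code):

```python
import torch

# three requests with different kv lengths (invented)
kv_seqlens = torch.tensor([4, 7, 2])
# prepend 0 and take the running sum, then cast to int32, as in the diff above
cu_seq_lens_kv = torch.cat((torch.zeros(1, dtype=kv_seqlens.dtype),
                            kv_seqlens.cumsum(0))).int()
print(cu_seq_lens_kv)  # tensor([ 0,  4, 11, 13], dtype=torch.int32)
```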
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
@@ -52,9 +52,9 @@ def get_total_slots():
kv_start_indices, attention_mask = [], []
block_num, block_size, _, _ = step_context.kv_caches[0][1].shape

is_unpaged_prefill = False
is_prefill_no_cache = False
if not step_context.is_decoding:
is_unpaged_prefill = \
is_prefill_no_cache = \
all((step_context.q_seqlens ==
step_context.kv_seqlens).tolist())
q_start_loc = step_context.q_start_loc
@@ -99,7 +99,7 @@ def get_total_slots():
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=attention_mask,
is_unpaged_prefill=is_unpaged_prefill,
is_prefill_no_cache=is_prefill_no_cache,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
)
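The maca backend applies the same rename; the flag itself is just the per-request comparison of query and kv lengths. A tiny standalone sketch of that check (invented values, not the backend code):

```python
import torch

q_seqlens = torch.tensor([3, 5])
kv_seqlens = torch.tensor([3, 5])
# every request's kv length equals its query length, so nothing is read from the cache
is_prefill_no_cache = all((q_seqlens == kv_seqlens).tolist())  # True
```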
11 changes: 7 additions & 4 deletions lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -25,12 +25,12 @@ def prefill_attention(
head_size_v: int,
attn_mask: Sequence[Tensor | None],
softmax_scale: float | None,
is_unpaged_prefill: bool | None,
is_prefill_no_cache: bool | None,
kv_scales: Tensor | None,
kv_zeros: Tensor | None,
quant_bits: int | None,
) -> Tensor:
if is_unpaged_prefill:
if is_prefill_no_cache:
return ext_ops.prefill_attention(
query_states,
key_states,
@@ -79,6 +79,7 @@ def paged_token_attention(
k_cache,
v_cache,
attn_output,
q_seqlens,
kv_seq_len,
max_kv_seq_len,
block_offsets,
@@ -97,6 +98,7 @@ def paged_token_attention(
v_cache,
block_offsets,
block_size,
q_seqlens,
kv_seq_len,
max_kv_seq_len,
num_q_heads,
@@ -131,7 +133,7 @@ def paged_attention_fwd(
v_head_size: int,
attn_mask: Sequence[Tensor | None] = (),
softmax_scale: float | None = None,
is_unpaged_prefill: bool | None = None,
is_prefill_no_cache: bool | None = None,
kv_scales: Tensor | None = None,
kv_zeros: Tensor | None = None,
quant_bits: int | None = 0,
@@ -157,7 +159,7 @@ def paged_attention_fwd(
v_head_size,
attn_mask,
softmax_scale,
is_unpaged_prefill,
is_prefill_no_cache,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
@@ -168,6 +170,7 @@ def paged_attention_fwd(
key_cache,
value_cache,
attn_output,
q_seqlens,
kv_seqlens,
max_kv_seq_len,
block_offsets,
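Taken together, the kernel changes above keep three attention paths: prefill_attention falls back to the plain no-cache kernel when is_prefill_no_cache is set, paged prefill stays on the paged kernel, and decode goes through paged_token_attention, which now also receives q_seqlens. A purely illustrative dispatch sketch (path names and the helper are placeholders, not the dlinfer API):

```python
def select_attention_path(is_decoding: bool, is_prefill_no_cache: bool | None) -> str:
    """Illustrative only: mirrors the branching implied by paged_attention_fwd above."""
    if not is_decoding:
        # prefill_attention picks the no-cache kernel when the whole prompt is new
        return "prefill_no_cache" if is_prefill_no_cache else "paged_prefill"
    # decode path: paged_token_attention, which now also takes q_seqlens
    return "paged_decode"


assert select_attention_path(False, True) == "prefill_no_cache"
assert select_attention_path(False, False) == "paged_prefill"
assert select_attention_path(True, None) == "paged_decode"
```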