Skip to content
Merged
3 changes: 3 additions & 0 deletions dlinfer/ops/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def paged_decode_attention(
value_cache: Tensor,
block_table: Tensor,
block_size: int,
q_seq_len: Tensor,
kv_seq_len: Tensor,
max_kv_seq_len: int,
num_q_heads: int,
Expand All @@ -292,6 +293,7 @@ def paged_decode_attention(
block_table (Tensor): A tensor that maps each position in the query sequence to the corresponding
block in the key/value cache.
block_size (int): The size of each block in the input sequence.
q_seq_len (Tensor): The length of query sequence.
kv_seq_len (Tensor): The length of each key/value sequence.
max_kv_seq_len (int): The maximum length of any key/value sequence.
num_q_heads (int): The number of query heads.
Expand All @@ -313,6 +315,7 @@ def paged_decode_attention(
value_cache,
block_table,
block_size,
q_seq_len,
kv_seq_len,
max_kv_seq_len,
num_q_heads,
Expand Down
20 changes: 4 additions & 16 deletions dlinfer/vendor/ascend/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def decode_attention(
scale_value: float,
block_table: Tensor,
block_size: int,
q_seq_len: Tensor,
kv_seq_len: Tensor,
softmax_scale: float,
attn_output: Tensor,
Expand Down Expand Up @@ -57,12 +58,11 @@ def decode_attention(
)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
elif AscendGraphRunner.capturing:
else:
bs, _, dim = query.shape
block_num = key_cache.size(0)
query = query.contiguous()
attn_output = attn_output.contiguous()
query = query.view(bs, 1, num_q_heads * dim)
key_cache = key_cache.view(block_num, block_size, -1)
value_cache = value_cache.view(block_num, block_size, -1)
scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(dim)
Expand All @@ -73,27 +73,15 @@ def decode_attention(
value=value_cache,
atten_mask=None,
block_table=block_table,
input_layout="BSH",
input_layout="TND",
block_size=block_size,
actual_seq_lengths=None,
actual_seq_lengths=q_seq_len,
actual_seq_lengths_kv=kv_seq_len,
num_key_value_heads=num_kv_heads,
num_heads=num_q_heads,
scale=scale_value,
sparse_mode=0,
)
else:
torch.ops.atb._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_q_heads,
scale_value=scale_value,
block_table=block_table,
context_lens=kv_seq_len,
out=attn_output,
)
return attn_output


Expand Down
11 changes: 6 additions & 5 deletions dlinfer/vendor/ascend/torch_npu_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def paged_decode_attention(
value_cache: Tensor,
block_table: Optional[Tensor],
block_size: int,
q_seq_len: Tensor,
kv_seq_len: Tensor,
max_kv_seq_len: int,
num_q_heads: int,
Expand Down Expand Up @@ -383,6 +384,7 @@ def paged_decode_attention(
scale_value=scale_value,
block_table=block_table,
block_size=block_size,
q_seq_len=q_seq_len,
kv_seq_len=kv_seq_len,
softmax_scale=softmax_scale,
attn_output=attn_output,
Expand Down Expand Up @@ -432,26 +434,25 @@ def paged_prefill_attention(
)

scale_value = softmax_scale if softmax_scale else 1.0 / math.sqrt(query.shape[-1])
query = query.contiguous().view(query.shape[0], 1, -1)
query = query.contiguous()
block_num = key_cache.size(0)
key_cache = key_cache.view(block_num, block_size, -1)
value_cache = value_cache.view(block_num, block_size, -1)

# Note: with input_layout="TND" the query tokens are packed, so per-batch
# query lengths must be passed explicitly via actual_seq_lengths (q_seq_len).
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
query=query,
key=key_cache,
value=value_cache,
atten_mask=attn_mask[0],
block_table=block_table,
input_layout="BSH",
input_layout="TND",
block_size=block_size,
actual_seq_lengths=q_seq_len,
actual_seq_lengths_kv=kv_seq_len,
num_key_value_heads=num_kv_heads,
num_heads=num_q_heads,
scale=scale_value,
sparse_mode=0,
sparse_mode=3,
)

return attn_output
Expand Down
1 change: 1 addition & 0 deletions dlinfer/vendor/camb/camb_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def paged_decode_attention(
value_cache: Tensor,
block_table: Optional[Tensor],
block_size: int,
q_seq_len: Tensor,
kv_seq_len: Tensor,
max_kv_seq_len: int,
num_q_heads: int,
Expand Down
1 change: 1 addition & 0 deletions dlinfer/vendor/maca/maca_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def paged_decode_attention(
value_cache: Tensor,
block_table: Optional[Tensor],
block_size: int,
q_seq_len: Tensor,
kv_seq_len: Tensor,
max_kv_seq_len: int,
num_q_heads: int,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_lmdeploy/e2e/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,19 @@ pytorch_vl_model:
- internvl_model/InternVL2-26B
- internvl_model/InternVL3-8B
- qwen_model/Qwen2-VL-7B-Instruct

mixed_prefill_precision:
model_case: qwen_model/Qwen3-30B-A3B-Instruct-2507
tp: 2
session_len: 8192
max_batch_size: 8
max_prefill_token_num: 2048
prompt_token_lengths:
- 3000
- 4000
- 5000
- 200
- 200
- 200
- 200
- 200
238 changes: 238 additions & 0 deletions tests/test_lmdeploy/e2e/test_mixed_prefill_precision_tp2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import gc
import os
from pathlib import Path

import pytest
import torch

import dlinfer
from lmdeploy import GenerationConfig, PytorchEngineConfig, Tokenizer, pipeline

ANSWER_TAG_SEEDS = (
    "OK_314159",
    "OK_271828",
    "OK_161803",
    "OK_141421",
    "OK_173205",
    "OK_223606",
    "OK_244949",
    "OK_264575",
)


def _resolve_case_config(config):
    """Fetch and validate the ``mixed_prefill_precision`` section of the e2e config.

    Raises ``ValueError`` when the section is absent, when the prompt length
    list is missing/empty, when it is longer than the available answer tags,
    or when the configured batch size cannot hold all prompts.
    """
    section = config.get("mixed_prefill_precision")
    if section is None:
        raise ValueError(
            "missing mixed_prefill_precision config in test_lmdeploy/e2e/config.yaml"
        )
    lengths = section.get("prompt_token_lengths")
    is_usable_list = isinstance(lengths, list) and len(lengths) > 0
    if not is_usable_list:
        raise ValueError(
            "mixed_prefill_precision.prompt_token_lengths must be a non-empty list"
        )
    # Each prompt needs its own unique answer tag seed.
    if len(lengths) > len(ANSWER_TAG_SEEDS):
        raise ValueError(
            f"prompt_token_lengths length {len(lengths)} exceeds supported answer tag count {len(ANSWER_TAG_SEEDS)}"
        )
    # All prompts are submitted in a single batch, so the batch must fit them.
    if section["max_batch_size"] < len(lengths):
        raise ValueError(
            "max_batch_size must be greater than or equal to the number of prompts"
        )
    return section


def _resolve_model_path(model_root, model_case):
    """Join *model_root* and *model_case*, failing fast if the path is absent."""
    candidate = os.path.join(model_root, model_case)
    if Path(candidate).exists():
        return candidate
    raise FileNotFoundError(f"failed to locate model path: {candidate}")


def _token_len(tokenizer, text):
    """Return the number of tokens *text* encodes to, excluding the BOS token."""
    token_ids = tokenizer.encode(text, add_bos=False)
    return len(token_ids)


def _build_prompt(tokenizer, target_len, answer_tag):
    """Build a filler prompt whose token length is as close to *target_len*
    as possible without exceeding it.

    The prompt instructs the model to answer with exactly *answer_tag*.
    Returns ``(prompt, actual_token_length)``.  Raises ``ValueError`` when
    even the bare prefix+suffix scaffold exceeds *target_len*.
    """
    prefix = "下面是一些背景片段,你只需要阅读,不要总结。\n" "背景开始:\n"
    suffix = (
        "\n背景结束。\n"
        "请完成下面任务。\n"
        "你必须遵守以下规则:\n"
        "1. 不要思考过程。\n"
        "2. 不要解释。\n"
        "3. 不要重复题目。\n"
        f"4. 最终答案必须且只能是:{answer_tag}\n"
        "答案:"
    )
    filler_unit = "这是一段用于混合prefill精度回归测试的背景信息。"
    fine_grained_units = [
        "补充",
        "说明",
        "数据",
        "样例",
        "文本",
        "A",
        "B",
        "C",
        "。",
        "\n",
    ]

    def measure(text):
        # Token count without BOS; inlined equivalent of _token_len.
        return len(tokenizer.encode(text, add_bos=False))

    base_len = measure(prefix + suffix)
    if base_len > target_len:
        raise ValueError(f"base prompt length {base_len} exceeds target {target_len}")

    # Exponential probe: find an upper bound on whole-filler repetitions.
    lo, hi = 0, 1
    while measure(prefix + filler_unit * hi + suffix) <= target_len:
        lo, hi = hi, hi * 2

    # Binary search: largest repetition count that still fits the budget.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if measure(prefix + filler_unit * mid + suffix) <= target_len:
            lo = mid
        else:
            hi = mid

    filler = filler_unit * lo
    prompt = prefix + filler + suffix

    # Greedy pass with ever-smaller units to close the remaining gap.
    progressed = True
    while progressed:
        progressed = False
        for unit in fine_grained_units:
            candidate = prefix + filler + unit + suffix
            if measure(candidate) <= target_len:
                filler += unit
                prompt = candidate
                progressed = True

    return prompt, measure(prompt)


def _build_prompts(model_path, case_config):
    """Materialize one prompt per configured token length.

    Returns ``(prompts, actual_token_lengths, answer_tags)``, all aligned
    index-by-index with ``case_config["prompt_token_lengths"]``.
    """
    tokenizer = Tokenizer(model_path)
    targets = case_config["prompt_token_lengths"]
    answer_tags = _build_answer_tags(targets)
    prompts = []
    actual_lengths = []
    for target_len, tag in zip(targets, answer_tags):
        prompt, prompt_len = _build_prompt(tokenizer, target_len, tag)
        prompts.append(prompt)
        actual_lengths.append(prompt_len)
    return prompts, actual_lengths, answer_tags


def _build_answer_tags(prompt_token_lengths):
    """Return one position-stamped answer tag per prompt (e.g. ``P1_OK_314159``)."""
    tags = []
    for idx, _ in enumerate(prompt_token_lengths):
        tags.append(f"P{idx + 1}_{ANSWER_TAG_SEEDS[idx]}")
    return tuple(tags)


def _strip_thinking(text):
    """Drop any leading chain-of-thought section, keeping the trimmed answer.

    Only the text after the first ``</think>`` marker is kept; without a
    marker the whole (stripped) text is returned.
    """
    trimmed = text.strip()
    marker = "</think>"
    if marker not in trimmed:
        return trimmed
    _, _, tail = trimmed.partition(marker)
    return tail.strip()


@pytest.mark.lmdeploy
@pytest.mark.parametrize(
    "eager_mode",
    [True, False],
    ids=["eager", "graph"],
)
def test_mixed_prefill_precision_tp2(config, eager_mode):
    """E2E regression: mixed long/short prompts must all yield their answer tags.

    Builds a batch of prompts of varying token lengths, runs them through an
    lmdeploy pytorch pipeline (tp=2, ascend device) in both eager and graph
    modes, logs every output, and asserts each prompt's unique answer tag
    appears in the corresponding final output.
    """
    case_config = _resolve_case_config(config)
    model_path = _resolve_model_path(config["model_path"], case_config["model_case"])
    prompts, actual_lengths, answer_tags = _build_prompts(model_path, case_config)
    model_name = Path(case_config["model_case"]).name

    backend_config = PytorchEngineConfig(
        tp=case_config["tp"],
        session_len=case_config["session_len"],
        max_batch_size=case_config["max_batch_size"],
        max_prefill_token_num=case_config["max_prefill_token_num"],
        device_type="ascend",
        eager_mode=eager_mode,
    )
    # Greedy decoding so results are deterministic across runs.
    gen_config = GenerationConfig(
        do_sample=False,
        top_k=1,
        temperature=0.0,
        max_new_tokens=96,
        random_seed=0,
    )

    mode_name = "eager" if eager_mode else "graph"
    log_dir = config["log_path"]
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(
        log_dir,
        f"ascend_pipeline_chat_pytorch_{model_name}_mixed_prefill_tp2_{mode_name}.log",
    )

    pipe = None
    try:
        pipe = pipeline(model_path, backend_config=backend_config)
        responses = pipe(
            prompts,
            gen_config=gen_config,
            chat_template_kwargs={"enable_thinking": False},
        )
        outputs = [response.text.strip() for response in responses]
        # Strip any chain-of-thought prefix before matching answer tags.
        final_outputs = [_strip_thinking(output) for output in outputs]

        # Echo per-prompt results to stdout for CI log visibility.
        print(
            f"[mixed-prefill] model_case={case_config['model_case']} mode={mode_name}"
        )
        print(f"[mixed-prefill] prompt_token_lengths={actual_lengths}")
        for idx, (target, output, final_output) in enumerate(
            zip(answer_tags, outputs, final_outputs),
            start=1,
        ):
            print(f"[mixed-prefill] prompt_{idx}_target={target}")
            print(f"[mixed-prefill] output_{idx}={output}")
            print(f"[mixed-prefill] final_output_{idx}={final_output}")
            print(f"[mixed-prefill] match_{idx}={target in final_output}")

        # Persist a detailed per-prompt report for post-mortem debugging.
        with open(log_path, "w") as file:
            file.writelines(
                [
                    f"model_case: {case_config['model_case']}\n",
                    f"model_path: {model_path}\n",
                    f"backend_config: {backend_config}\n",
                    f"gen_config: {gen_config}\n",
                    f"prompt_token_lengths: {actual_lengths}\n",
                ]
            )
            for idx, (target, actual_len, prompt, output, final_output) in enumerate(
                zip(answer_tags, actual_lengths, prompts, outputs, final_outputs),
                start=1,
            ):
                file.writelines(
                    [
                        f"prompt_{idx}_target: {target}\n",
                        f"prompt_{idx}_token_len: {actual_len}\n",
                        f"prompt_{idx}_tail: {prompt[-256:]}\n",
                        f"output_{idx}: {output}\n",
                        f"final_output_{idx}: {final_output}\n",
                        f"match_{idx}: {target in final_output}\n",
                    ]
                )

        # The actual regression check: every tag must survive mixed prefill.
        for target, output, final_output in zip(answer_tags, outputs, final_outputs):
            assert target in final_output, (
                f"mixed prefill precision regression failed in {mode_name} mode: "
                f"expected {target} in final output {final_output!r}, raw output {output!r}"
            )
    finally:
        # Always release engine resources, even on assertion failure.
        if pipe is not None:
            pipe.close()
        del pipe
        gc.collect()
        if hasattr(torch, "npu"):
            torch.npu.empty_cache()
Loading
Loading