utils/bench_serving/benchmark_serving.py (180 changes: 152 additions & 28 deletions)
@@ -35,6 +35,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import Pool, cpu_count
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple

import numpy as np
@@ -88,6 +89,64 @@ class BenchmarkMetrics:
percentiles_e2el_ms: List[Tuple[float, float]]


# --- Multiprocessing helpers for sample_random_requests ---
_worker_tokenizer = None


def _init_tokenizer_worker(tokenizer_id, tokenizer_mode, trust_remote_code):
"""Initialize tokenizer once per worker process."""
global _worker_tokenizer
_worker_tokenizer = get_tokenizer(
tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
)


def _process_prompt_chunk(chunk_args):
"""Generate a chunk of random prompts in a worker process."""
(indices, prefix_token_ids, input_lens, output_lens, offsets,
prefix_len, vocab_size, use_chat_template, seed) = chunk_args

rng = np.random.RandomState(seed)
tokenizer = _worker_tokenizer

results = []
for local_idx, global_idx in enumerate(indices):
tgt_prompt_len = prefix_len + input_lens[local_idx]
prompt_token_ids = prefix_token_ids + [
(offsets[local_idx] + global_idx + j) % vocab_size
for j in range(input_lens[local_idx])
]
prompt = tokenizer.decode(prompt_token_ids)

max_retries = 10
for _ in range(max_retries):
prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
if len(prompt_token_ids) < tgt_prompt_len:
num_extras = tgt_prompt_len - len(prompt_token_ids)
prompt_token_ids.extend(
rng.randint(0, vocab_size, size=num_extras).tolist())
elif len(prompt_token_ids) > tgt_prompt_len:
prompt_token_ids = prompt_token_ids[:tgt_prompt_len]
else:
break
prompt = tokenizer.decode(prompt_token_ids)

if use_chat_template:
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)

prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
mismatch = prompt_len - tgt_prompt_len
results.append((prompt, prompt_len, output_lens[local_idx], None, mismatch))

return results


def sample_random_requests(
prefix_len: int,
input_len: int,
@@ -96,8 +155,13 @@ def sample_random_requests(
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
use_chat_template: bool = False,
tokenizer_id: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
num_workers: int = 0,
) -> List[Tuple[str, int, int]]:
prefix_token_ids = np.random.randint(0, tokenizer.vocab_size, size=prefix_len).tolist()
vocab_size = tokenizer.vocab_size
prefix_token_ids = np.random.randint(0, vocab_size, size=prefix_len).tolist()

if use_chat_template:
chat_template_dummy = tokenizer.apply_chat_template(
@@ -117,37 +181,85 @@ def sample_uniform(seq_len):

input_lens = sample_uniform(input_len)
output_lens = sample_uniform(output_len)
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)

input_requests = []
mismatches = []
for i in range(num_prompts):
tgt_prompt_len = prefix_len + input_lens[i]
prompt_token_ids = prefix_token_ids + [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
prompt = tokenizer.decode(prompt_token_ids)

max_retries = 10
for _ in range(max_retries):
prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
if len(prompt_token_ids) < tgt_prompt_len:
num_extras = tgt_prompt_len - len(prompt_token_ids)
prompt_token_ids.extend(np.random.randint(0, tokenizer.vocab_size, size=num_extras).tolist())
elif len(prompt_token_ids) > tgt_prompt_len:
prompt_token_ids = prompt_token_ids[:tgt_prompt_len]
else:
offsets = np.random.randint(0, vocab_size, size=num_prompts)

# Decide whether to use multiprocessing
if num_workers <= 0:
num_workers = min(cpu_count() or 1, 8)
use_parallel = num_workers > 1 and tokenizer_id is not None

if use_parallel:
# Split work into chunks, one per worker
chunk_size = (num_prompts + num_workers - 1) // num_workers
chunk_args_list = []
for w in range(num_workers):
start = w * chunk_size
end = min(start + chunk_size, num_prompts)
if start >= num_prompts:
break
chunk_args_list.append((
list(range(start, end)),
prefix_token_ids,
input_lens[start:end],
output_lens[start:end],
offsets[start:end].tolist(),
prefix_len,
vocab_size,
use_chat_template,
int(np.random.randint(0, 2**31)),
))

actual_workers = len(chunk_args_list)
print(f"Generating {num_prompts} prompts using {actual_workers} worker processes...")
t0 = time.perf_counter()
with Pool(
processes=actual_workers,
initializer=_init_tokenizer_worker,
initargs=(tokenizer_id, tokenizer_mode, trust_remote_code),
) as pool:
chunk_results = pool.map(_process_prompt_chunk, chunk_args_list)

input_requests = []
mismatches = []
for chunk in chunk_results:
for prompt, prompt_len, out_len, mm_content, mismatch in chunk:
input_requests.append((prompt, prompt_len, out_len, mm_content))
mismatches.append(mismatch)
elapsed = time.perf_counter() - t0
print(f"Prompt generation completed in {elapsed:.1f}s")
else:
# Original serial path
if tokenizer_id is None and num_workers > 1:
print("Warning: tokenizer_id not provided, falling back to serial prompt generation.")
input_requests = []
mismatches = []
for i in range(num_prompts):
tgt_prompt_len = prefix_len + input_lens[i]
prompt_token_ids = prefix_token_ids + [(offsets[i] + i + j) % vocab_size for j in range(input_lens[i])]
prompt = tokenizer.decode(prompt_token_ids)

if use_chat_template:
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)
max_retries = 10
for _ in range(max_retries):
prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
if len(prompt_token_ids) < tgt_prompt_len:
num_extras = tgt_prompt_len - len(prompt_token_ids)
prompt_token_ids.extend(np.random.randint(0, vocab_size, size=num_extras).tolist())
elif len(prompt_token_ids) > tgt_prompt_len:
prompt_token_ids = prompt_token_ids[:tgt_prompt_len]
else:
break
prompt = tokenizer.decode(prompt_token_ids)

prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
mismatches.append(prompt_len - tgt_prompt_len)
input_requests.append((prompt, prompt_len, output_lens[i], None))
if use_chat_template:
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)

prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
mismatches.append(prompt_len - tgt_prompt_len)
input_requests.append((prompt, prompt_len, output_lens[i], None))

header_str = f'{"-"*19} Input/Output Length Statistics {"-"*19}'
print(header_str)
@@ -663,6 +775,10 @@ def main(args: argparse.Namespace):
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
use_chat_template=args.use_chat_template,
tokenizer_id=tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code,
num_workers=args.random_num_workers,
)

else:
@@ -1023,6 +1139,14 @@ def main(args: argparse.Namespace):
action="store_true",
help="Use chat template to format the prompt.",
)
random_group.add_argument(
'--random-num-workers',
type=int,
default=0,
help="Number of worker processes for parallel random prompt generation. "
"Only used with --dataset-name random. "
"0 (default) = auto (min(cpu_count, 8)). 1 = serial (no multiprocessing).",
)

hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
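For readers skimming the diff, the core technique above is a `multiprocessing.Pool` whose workers each build their tokenizer exactly once via an initializer and then process contiguous chunks of prompt indices. The following is a minimal, self-contained sketch of that pattern, not code from the PR: `_init_worker`, `_process_chunk`, `run_parallel`, and the dict standing in for the tokenizer are all made-up illustration names, and only the chunking arithmetic (`(num_items + num_workers - 1) // num_workers`) and the 0 = auto / 1 = serial convention mirror the actual changes.

```python
# Sketch of the per-worker-initializer pool pattern used by the PR (illustration only).
from multiprocessing import Pool, cpu_count

_worker_state = None  # populated once per worker process by the initializer


def _init_worker(config):
    """Build expensive per-worker state (a tokenizer, in the real script) exactly once."""
    global _worker_state
    _worker_state = {"config": config}


def _process_chunk(indices):
    """Process one chunk of global indices using the worker-local state."""
    cfg = _worker_state["config"]
    return [(i, f"{cfg}-{i}") for i in indices]


def run_parallel(num_items, num_workers=0):
    # Same convention as --random-num-workers: 0 = auto (capped at 8), 1 = serial-ish.
    if num_workers <= 0:
        num_workers = min(cpu_count() or 1, 8)
    chunk_size = (num_items + num_workers - 1) // num_workers
    chunks = [list(range(start, min(start + chunk_size, num_items)))
              for start in range(0, num_items, chunk_size)]
    with Pool(processes=len(chunks),
              initializer=_init_worker,
              initargs=("demo-config",)) as pool:
        chunk_results = pool.map(_process_chunk, chunks)
    # Flatten chunk results back into a single ordered list, as the PR does.
    return [item for chunk in chunk_results for item in chunk]


if __name__ == "__main__":
    print(run_parallel(10, num_workers=3))
```

In the actual script, `--random-num-workers` selects how many such worker processes are spawned for random prompt generation (0 = auto, capped at 8; 1 keeps the original serial path), and the initializer avoids re-loading the tokenizer for every prompt, which is where the speedup comes from.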