From 7aaed12f61a00f0196943f7503e327a6a3c05337 Mon Sep 17 00:00:00 2001
From: Duyi-Wang
Date: Wed, 15 Apr 2026 17:22:27 +0800
Subject: [PATCH] Multiprocess Parallel Random Data Generation for Benchmark
 Serving.

---
 utils/bench_serving/benchmark_serving.py | 180 +++++++++++++++++++----
 1 file changed, 152 insertions(+), 28 deletions(-)

diff --git a/utils/bench_serving/benchmark_serving.py b/utils/bench_serving/benchmark_serving.py
index 38365dbfc..877d1b8c8 100644
--- a/utils/bench_serving/benchmark_serving.py
+++ b/utils/bench_serving/benchmark_serving.py
@@ -35,6 +35,7 @@
 import warnings
 from dataclasses import dataclass
 from datetime import datetime
+from multiprocessing import Pool, cpu_count
 from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
 
 import numpy as np
@@ -88,6 +89,64 @@ class BenchmarkMetrics:
     percentiles_e2el_ms: List[Tuple[float, float]]
 
 
+# --- Multiprocessing helpers for sample_random_requests ---
+_worker_tokenizer = None
+
+
+def _init_tokenizer_worker(tokenizer_id, tokenizer_mode, trust_remote_code):
+    """Initialize tokenizer once per worker process."""
+    global _worker_tokenizer
+    _worker_tokenizer = get_tokenizer(
+        tokenizer_id,
+        tokenizer_mode=tokenizer_mode,
+        trust_remote_code=trust_remote_code,
+    )
+
+
+def _process_prompt_chunk(chunk_args):
+    """Generate a chunk of random prompts in a worker process."""
+    (indices, prefix_token_ids, input_lens, output_lens, offsets,
+     prefix_len, vocab_size, use_chat_template, seed) = chunk_args
+
+    rng = np.random.RandomState(seed)
+    tokenizer = _worker_tokenizer
+
+    results = []
+    for local_idx, global_idx in enumerate(indices):
+        tgt_prompt_len = prefix_len + input_lens[local_idx]
+        prompt_token_ids = prefix_token_ids + [
+            (offsets[local_idx] + global_idx + j) % vocab_size
+            for j in range(input_lens[local_idx])
+        ]
+        prompt = tokenizer.decode(prompt_token_ids)
+
+        max_retries = 10
+        for _ in range(max_retries):
+            prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+            if len(prompt_token_ids) < tgt_prompt_len:
+                num_extras = tgt_prompt_len - len(prompt_token_ids)
+                prompt_token_ids.extend(
+                    rng.randint(0, vocab_size, size=num_extras).tolist())
+            elif len(prompt_token_ids) > tgt_prompt_len:
+                prompt_token_ids = prompt_token_ids[:tgt_prompt_len]
+            else:
+                break
+            prompt = tokenizer.decode(prompt_token_ids)
+
+        if use_chat_template:
+            prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+
+        prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
+        mismatch = prompt_len - tgt_prompt_len
+        results.append((prompt, prompt_len, output_lens[local_idx], None, mismatch))
+
+    return results
+
+
 def sample_random_requests(
     prefix_len: int,
     input_len: int,
@@ -96,8 +155,13 @@ def sample_random_requests(
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
     use_chat_template: bool = False,
+    tokenizer_id: Optional[str] = None,
+    tokenizer_mode: str = "auto",
+    trust_remote_code: bool = False,
+    num_workers: int = 0,
 ) -> List[Tuple[str, int, int]]:
-    prefix_token_ids = np.random.randint(0, tokenizer.vocab_size, size=prefix_len).tolist()
+    vocab_size = tokenizer.vocab_size
+    prefix_token_ids = np.random.randint(0, vocab_size, size=prefix_len).tolist()
 
     if use_chat_template:
         chat_template_dummy = tokenizer.apply_chat_template(
@@ -117,37 +181,85 @@ def sample_uniform(seq_len):
     input_lens = sample_uniform(input_len)
     output_lens = sample_uniform(output_len)
-    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
-
-    input_requests = []
-    mismatches = []
-    for i in range(num_prompts):
-        tgt_prompt_len = prefix_len + input_lens[i]
-        prompt_token_ids = prefix_token_ids + [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
-        prompt = tokenizer.decode(prompt_token_ids)
-
-        max_retries = 10
-        for _ in range(max_retries):
-            prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
-            if len(prompt_token_ids) < tgt_prompt_len:
-                num_extras = tgt_prompt_len - len(prompt_token_ids)
-                prompt_token_ids.extend(np.random.randint(0, tokenizer.vocab_size, size=num_extras).tolist())
-            elif len(prompt_token_ids) > tgt_prompt_len:
-                prompt_token_ids = prompt_token_ids[:tgt_prompt_len]
-            else:
+    offsets = np.random.randint(0, vocab_size, size=num_prompts)
+
+    # Decide whether to use multiprocessing
+    if num_workers <= 0:
+        num_workers = min(cpu_count() or 1, 8)
+    use_parallel = num_workers > 1 and tokenizer_id is not None
+
+    if use_parallel:
+        # Split work into chunks, one per worker
+        chunk_size = (num_prompts + num_workers - 1) // num_workers
+        chunk_args_list = []
+        for w in range(num_workers):
+            start = w * chunk_size
+            end = min(start + chunk_size, num_prompts)
+            if start >= num_prompts:
                 break
+            chunk_args_list.append((
+                list(range(start, end)),
+                prefix_token_ids,
+                input_lens[start:end],
+                output_lens[start:end],
+                offsets[start:end].tolist(),
+                prefix_len,
+                vocab_size,
+                use_chat_template,
+                int(np.random.randint(0, 2**31)),
+            ))
+
+        actual_workers = len(chunk_args_list)
+        print(f"Generating {num_prompts} prompts using {actual_workers} worker processes...")
+        t0 = time.perf_counter()
+        with Pool(
+            processes=actual_workers,
+            initializer=_init_tokenizer_worker,
+            initargs=(tokenizer_id, tokenizer_mode, trust_remote_code),
+        ) as pool:
+            chunk_results = pool.map(_process_prompt_chunk, chunk_args_list)
+
+        input_requests = []
+        mismatches = []
+        for chunk in chunk_results:
+            for prompt, prompt_len, out_len, mm_content, mismatch in chunk:
+                input_requests.append((prompt, prompt_len, out_len, mm_content))
+                mismatches.append(mismatch)
+        elapsed = time.perf_counter() - t0
+        print(f"Prompt generation completed in {elapsed:.1f}s")
+    else:
+        # Original serial path
+        if tokenizer_id is None and num_workers > 1:
+            print("Warning: tokenizer_id not provided, falling back to serial prompt generation.")
+        input_requests = []
+        mismatches = []
+        for i in range(num_prompts):
+            tgt_prompt_len = prefix_len + input_lens[i]
+            prompt_token_ids = prefix_token_ids + [(offsets[i] + i + j) % vocab_size for j in range(input_lens[i])]
             prompt = tokenizer.decode(prompt_token_ids)
 
-        if use_chat_template:
-            prompt = tokenizer.apply_chat_template(
-                [{"role": "user", "content": prompt}],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            max_retries = 10
+            for _ in range(max_retries):
+                prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+                if len(prompt_token_ids) < tgt_prompt_len:
+                    num_extras = tgt_prompt_len - len(prompt_token_ids)
+                    prompt_token_ids.extend(np.random.randint(0, vocab_size, size=num_extras).tolist())
+                elif len(prompt_token_ids) > tgt_prompt_len:
+                    prompt_token_ids = prompt_token_ids[:tgt_prompt_len]
+                else:
+                    break
+                prompt = tokenizer.decode(prompt_token_ids)
 
-        prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False))
-        mismatches.append(prompt_len - tgt_prompt_len)
-        input_requests.append((prompt, prompt_len, output_lens[i], None))
+            if use_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
"user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer.encode(prompt, add_special_tokens=False)) + mismatches.append(prompt_len - tgt_prompt_len) + input_requests.append((prompt, prompt_len, output_lens[i], None)) header_str = f'{"-"*19} Input/Output Length Statistics {"-"*19}' print(header_str) @@ -663,6 +775,10 @@ def main(args: argparse.Namespace): range_ratio=args.random_range_ratio, tokenizer=tokenizer, use_chat_template=args.use_chat_template, + tokenizer_id=tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code, + num_workers=args.random_num_workers, ) else: @@ -1023,6 +1139,14 @@ def main(args: argparse.Namespace): action="store_true", help="Use chat template to format the prompt.", ) + random_group.add_argument( + '--random-num-workers', + type=int, + default=0, + help="Number of worker processes for parallel random prompt generation. " + "Only used with --dataset-name random. " + "0 (default) = auto (min(cpu_count, 8)). 1 = serial (no multiprocessing).", + ) hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset",