SpecDec Bench: February Update #875
base: main
Conversation
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
📝 Walkthrough

This pull request integrates SPEED-Bench and SpecBench dataset support, introduces new model implementations (AutoDeployModel, SpecBenchMedusaModel), expands the CLI with dataset selection options, updates model and runner signatures for consistency, and adds corresponding metrics and data preparation utilities.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant CLI as run.py
    participant Dataset as Dataset Layer
    participant Model as Model Instance
    participant Metrics as Metrics Collector
    participant Output as File Output
    User->>CLI: run_simple(args)
    CLI->>CLI: Parse args & load tokenizer
    CLI->>Dataset: Load dataset (speed/specbench/etc)
    Dataset-->>CLI: Return prepared data
    CLI->>CLI: Initialize model & metrics
    loop For each request in dataset
        CLI->>CLI: Encode prompt (chat_template_args)
        CLI->>Model: run(prompt_ids, sampling_params)
        Model-->>CLI: output_ids, token_times
        CLI->>Metrics: process_step(output, request_id)
        Metrics-->>Metrics: Aggregate acceptance_rate
    end
    CLI->>Metrics: process_final(text_outputs)
    Metrics->>Metrics: Compute statistics & visualizations
    Metrics->>Output: Write results (jsonl, json, png)
    Output-->>Metrics: File saved
    Metrics-->>CLI: Results processed
    CLI-->>User: Benchmark complete
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 8
Note
Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
examples/specdec_bench/specdec_bench/metrics/aa_timing.py (1)

`49-63`: ⚠️ Potential issue | 🟡 Minor — Potential division by zero when all requests have ≤2 timing entries.

If `gen_tp_time` is empty (all requests have `len(times) <= 2`), the code correctly guards with `if gen_tp_time:`. However, if `self.timing` is non-empty but all timing lists share the same start/end time (Line 53), `end_time - start_time` could be zero, causing a `ZeroDivisionError` on Line 53.

Proposed guard:

```diff
 end_time = max([t[-1] for t in self.timing])
-self.out["AA Output TPS"] = sum(self.total_tokens) / (end_time - start_time)
+elapsed = end_time - start_time
+if elapsed > 0:
+    self.out["AA Output TPS"] = sum(self.total_tokens) / elapsed
+else:
+    self.out["AA Output TPS"] = 0.0
```

examples/specdec_bench/specdec_bench/metrics/base.py (1)

`36-48`: ⚠️ Potential issue | 🟡 Minor — `write()` lacks error handling for corrupt JSON files.

If the existing JSON file is malformed (e.g., from a previous interrupted write), `json.load` will raise a `JSONDecodeError` and the current metric data in `self.out` will be silently lost. Consider catching decode errors and falling back to starting a fresh list.

Proposed guard:

```diff
 if self.out:
     filename = os.path.join(self.directory, f"{self.name}.json")
     if os.path.exists(filename):
-        with open(filename, "r") as json_file:
-            existing_data = json.load(json_file)
-        existing_data.append(self.out)
+        try:
+            with open(filename, "r") as json_file:
+                existing_data = json.load(json_file)
+            existing_data.append(self.out)
+        except (json.JSONDecodeError, IOError):
+            existing_data = [self.out]
     else:
         existing_data = [self.out]
```

examples/specdec_bench/specdec_bench/datasets/base.py (1)

`29-31`: ⚠️ Potential issue | 🟠 Major — `output_turn_ids` is a class variable, not a dataclass field.

Line 30 assigns `output_turn_ids = None` without a type annotation, so the `@dataclass` decorator treats it as a class variable shared across all instances — not an instance field. It won't appear in `__init__`, `__repr__`, or `__eq__`. If this is intended as a per-instance field (like `output_turn_text` on Line 31), it needs a type annotation.

Additionally, Line 31 uses lowercase `list[str]` while Line 26 uses `List[str]` — inconsistent style.

Proposed fix:

```diff
 # not to be set by user
-output_turn_ids = None
-output_turn_text: list[str] = field(default_factory=list)
+output_turn_ids: Optional[Any] = None
+output_turn_text: List[str] = field(default_factory=list)
```

examples/specdec_bench/specdec_bench/models/base.py (1)

`21-28`: ⚠️ Potential issue | 🟠 Major — API signature mismatch: base class `run()` doesn't match any concrete implementation.

The base class now declares `run(self, prompt_ids, sampling_params, request_id, turn_id)`, but all five concrete model implementations (`vllm.py`, `auto_deploy.py`, `trtllm_torch_api.py`, `sglang.py`, `specbench_medusa.py`) and the runner (`simple.py`) still use the old signature `run(self, prompt_ids, max_length, end_id, request_id, turn_id)`. The runner calls the model with arguments in the old order, which will cause a runtime mismatch.

Either update all concrete models and callers to use `sampling_params` instead of `max_length` and `end_id`, or keep the base class aligned with the current implementations until a full migration is done.

examples/specdec_bench/specdec_bench/models/vllm.py (1)

`100-105`: ⚠️ Potential issue | 🟠 Major — `ignore_eos` is sticky: once set to `True` it's never cleared.

`self.sampling_config` is a shared mutable instance attribute. When `end_id == -1`, `ignore_eos` is set to `True` at line 105, but there's no `else` branch to reset it to `False`. All subsequent calls with a valid `end_id` will still have `ignore_eos=True`, silently ignoring EOS tokens.

Proposed fix:

```diff
 async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
     output_dict = {}
     self.sampling_config.max_tokens = max_length
     self.sampling_config.stop_token_ids = [end_id]
     if end_id == -1:
         self.sampling_config.ignore_eos = True
+    else:
+        self.sampling_config.ignore_eos = False
```

examples/specdec_bench/specdec_bench/models/sglang.py (1)

`97-106`: ⚠️ Potential issue | 🟠 Major — `outputs = None` leads to `TypeError` if `async_generate` yields no chunks.

If the async stream produces no chunks, `outputs` stays `None` and `end_id == outputs[-1]` at line 106 will raise `TypeError: 'NoneType' object is not subscriptable`. Initialize `outputs` as an empty list and add a guard:

Proposed fix:

```diff
-outputs = None
+outputs = []
 result = await self.model.async_generate(
     sampling_params=self.sampling_config, input_ids=prompt_ids, stream=True
 )
 async for chunk in result:
     timing.append(time.perf_counter())
     outputs = chunk["output_ids"]
     beam_lens[0].append(chunk["meta_info"]["completion_tokens"])
-    if end_id == outputs[-1]:
+    if outputs and end_id == outputs[-1]:
```

examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py (1)

`80-81`: ⚠️ Potential issue | 🔴 Critical — Bug: `clear()` resets `self.prompt_ar` to a `list` instead of a `dict`.

`__init__` (line 24) initializes `self.prompt_ar = {}`, and `process_step` (lines 28-34) uses it as a `dict` (keyed by `request_id`). After `clear()` is called, `self.prompt_ar` becomes `[]`, and the next `process_step` call will fail when attempting dict-style keyed access on a list.

🐛 Proposed fix:

```diff
 def clear(self):
-    self.prompt_ar = []
+    self.prompt_ar = {}
```
🤖 Fix all issues with AI agents
In `@examples/specdec_bench/README.md`:
- Line 84: The README example references a missing runtime_params file; either
add a new runtime_args_long_context.yaml with sensible runtime keys (e.g.,
max_sequence_length, batch_size, memory_pool_size, device_map, and any
TRTLLM-specific settings) and commit it, or remove the `--runtime_params
runtime_args_long_context.yaml` flag from the example command in the README (the
command shown passed to run.py) so the example works as-is; update README to
reflect the chosen approach and mention where to find or how to customize the
runtime params if you add the YAML.
In `@examples/specdec_bench/run.py`:
- Around line 138-148: When args.dataset == "random" the current uniform lookup
via datasets_available passes args.dataset_path into RandomToken but
RandomToken.__init__ requires a tokenizer; update the branch that handles
args.dataset (the datasets_available lookup) to special-case "random" and
instantiate datasets.RandomToken(tokenizer, args.dataset_path, **dataset_kwargs)
(or remove "random" from datasets_available if you prefer the latter) so
RandomToken is constructed with a tokenizer object rather than a path; ensure
you reference the RandomToken class and datasets_available map when making the
change.
In `@examples/specdec_bench/SPECBENCH_PORTING.md`:
- Around line 20-37: The Model interface's run signature (Model.run(self,
prompt_ids, sampling_params, request_id, turn_id)) conflicts with older examples
that use run(self, prompt_ids, max_length, end_id, request_id, turn_id); pick
the canonical API and make all examples consistent: either update the Step 4
implementation example and the testing example to call/run and implement
Model.run with (prompt_ids, sampling_params, request_id, turn_id) and adjust
their internals and docstrings to read sampling_params (containing
max_length/end_id), or change the interface to the older signature and update
the docstring to match; ensure Model.__init__, Model.run, the Step 4
implementation example, and the test example all use the same parameter names
and types (prompt_ids list, and a single sampling_params or explicit
max_length/end_id) and that returned dict keys and token type expectations
remain unchanged.
In `@examples/specdec_bench/specdec_bench/models/auto_deploy.py`:
- Around line 56-70: The loop assumes generate_async produced outputs and
non-empty beam_lens; add a guard before using outputs[-1] and indexing beam_len:
check if outputs is empty or beam_lens is empty and return an appropriate empty
structure (e.g., reformatted_output_ids or []) early, and inside the for loop
skip/continue any beam_len that is empty to avoid beam_len[0] / beam_len[-1]
access; update references in this block (outputs, beam_lens,
reformatted_output_ids, beam_idx, response) accordingly so no IndexError occurs
when generate_async yields nothing.
In `@examples/specdec_bench/specdec_bench/models/specbench_medusa.py`:
- Around line 38-57: Replace the brittle cwd-based sys.path insertion with a
configurable and robust lookup: read SPEC_BENCH_PATH from an environment
variable first (fallback to a path computed relative to this file using
os.path.dirname(__file__) and joining "..", "Spec-Bench"), and only modify
sys.path if that resolved path exists and is not already present (avoid blindly
inserting at index 0 to prevent module shadowing). Update the import block that
brings in MedusaModel, initialize_past_key_values, initialize_medusa,
generate_medusa_buffers, etc., to use that resolved path so tests/CI/notebook
runs can override via SPEC_BENCH_PATH and to avoid permanent, unsafe sys.path
manipulation.
In `@examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py`:
- Around line 99-107: EagleDecodingConfig.model_fields access can raise
AttributeError and trtllm.__version__ may not exist; change the logic in the
block around EagleDecodingConfig.model_fields to first check safely with
hasattr/getattr (e.g., fields = getattr(EagleDecodingConfig, "model_fields",
None)) and if fields is truthy use the existing membership check to set
extra_params["allow_advanced_sampling"] from kwargs, otherwise fall back to
checking "allow_advanced_sampling" in kwargs and emit a warning that does not
assume trtllm.__version__ (use getattr(trtllm, "__version__", "unknown version")
or omit version), referencing the EagleDecodingConfig.model_fields,
extra_params, kwargs, and trtllm.__version__ symbols so the change is made in
the same function.
In `@examples/specdec_bench/specdec_bench/runners/base.py`:
- Around line 24-25: SimpleRunner.run currently has the old signature; update
its method signature to match BaseRunner.run(prompt_ids, max_length, end_id,
sampling_kwargs), then pull request_id and turn_id from sampling_kwargs (e.g.,
request_id = sampling_kwargs.get("request_id")) before calling model.run and
before calling process_metrics_step; ensure model.run is invoked with those
extracted request_id and turn_id where it previously used the positional args
and pass sampling_kwargs (or remaining sampling params) to the model call as
appropriate so behavior is unchanged.
In `@examples/specdec_bench/specdec_bench/utils.py`:
- Around line 24-32: The function encode_chat uses a mutable default
chat_template_args={} which can lead to shared-state bugs; change the signature
to default chat_template_args=None and inside encode_chat set a local variable
(e.g., args = {} if chat_template_args is None else dict(chat_template_args)) so
callers' dicts aren’t mutated and you still pass args into
tokenizer.apply_chat_template; keep the completions branch behavior unchanged.
🟡 Minor comments (15)
examples/specdec_bench/specdec_bench/__init__.py (1)

`1-1`: ⚠️ Potential issue | 🟡 Minor — Copyright year appears to have regressed from 2025 to 2024.

Per the AI summary, this line was changed from 2025 to 2024, which moves the copyright year backwards. Given the current date (February 2026), the year should likely be 2025 or 2026, not 2024.

Suggested fix:

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

examples/specdec_bench/README.md (1)

`44-44`: ⚠️ Potential issue | 🟡 Minor — Clarify the HuggingFace authentication requirement for the SPEED-Bench dataset.

The SPEED-Bench dataset URL returns HTTP 401 (Unauthorized), indicating authentication is required. While the referenced models (Llama 3.3 70B and EAGLE3) are publicly accessible, the documentation should explicitly state that users need to authenticate with HuggingFace to access the SPEED-Bench dataset. Update the documentation to mention this requirement, either in the overview or in the data preparation section.

examples/specdec_bench/README.md (1)

`6-7`: ⚠️ Potential issue | 🟡 Minor — Update the SGLang docker image reference for specificity.

The SGLang docker image in the README references `lmsysorg/sglang:v0.5.7`, but the available published docker image tags include cuda variant suffixes (e.g., `lmsysorg/sglang:v0.5.7-cu130-runtime`). Clarify which specific variant is required for this benchmark, or document that the user should select the appropriate cuda version for their environment. The vLLM v0.15.0 and TensorRT-LLM 1.3.0rc2 image tags are correctly documented.

examples/specdec_bench/specdec_bench/datasets/base.py (1)

`43-43`: ⚠️ Potential issue | 🟡 Minor — `str | Path` union syntax requires Python 3.10+ at runtime.

Line 43 uses `str | Path`, which is only valid at runtime in Python 3.10+. The rest of the file uses `Optional[...]` and `List[...]` from `typing`, suggesting older Python compatibility is intended. Either add `from __future__ import annotations` at the top of the file, or use `Union[str, Path]`.

Proposed fix (option A — future annotations):

```diff
+from __future__ import annotations
 from typing import Any, List, Optional
```

examples/specdec_bench/requirements.txt (1)

`1-4`: ⚠️ Potential issue | 🟡 Minor — Update `datasets` and `rich` to their current releases.

All specified versions exist on PyPI. However, `datasets` (4.4.0) and `rich` (14.2.0) have newer releases available: `datasets` 4.5.0 (Jan 2026) and `rich` 14.3.2 (Feb 2026). Update these for the latest patches and features. `seaborn` 0.13.2 and `tiktoken` 0.12.0 are already at their latest versions.

examples/specdec_bench/specdec_bench/models/sglang.py (1)

`88-89`: ⚠️ Potential issue | 🟡 Minor — Misleading docstring.

The docstring says "Synchronous version of run for use with asyncio.to_thread", but the method is `async` and uses `await self.model.async_generate(...)`. This docstring appears to have been copied from another model.

examples/specdec_bench/specdec_bench/models/sglang.py (1)

`40-42`: ⚠️ Potential issue | 🟡 Minor — `assert False` is a poor substitute for raising an error.

`assert False` will be silently removed when Python is run with `-O`. Use `raise NotImplementedError(...)` instead:

Proposed fix:

```diff
 elif speculative_algorithm == "NGRAM":
     speculative_algorithm = "LOOKAHEAD"
-    assert False, "Needs more work"
+    raise NotImplementedError("NGRAM/LOOKAHEAD speculative decoding is not yet supported for SGLANG")
```

examples/specdec_bench/specdec_bench/models/specbench_medusa.py (1)

`100-106`: ⚠️ Potential issue | 🟡 Minor — `draft_model_path` defaults to `None` and is passed directly to `MedusaModel.from_pretrained`.

If `draft_model_dir` is not supplied in kwargs (line 78), `self.draft_model_path` will be `None`, likely causing `from_pretrained` to fail with an opaque error. Validate early:

Proposed fix:

```diff
 self.draft_model_path = kwargs.get("draft_model_dir", None)
+if self.draft_model_path is None:
+    raise ValueError("draft_model_dir is required for SpecBenchMedusaModel")
```

examples/specdec_bench/specdec_bench/models/auto_deploy.py (1)

`86-138`: ⚠️ Potential issue | 🟡 Minor — No guard against `LLM` being `None`.

If `tensorrt_llm` is unavailable, `LLM` is `None` and `LLM(**llm_kwargs)` at line 137 will raise a cryptic `TypeError: 'NoneType' object is not callable`. Add an early check for a clear failure message.

Proposed fix:

```diff
 def create_auto_deploy_model(
     model_path: str, max_concurrent_requests: int, kwargs: Dict[str, Any]
 ):
+    if LLM is None:
+        raise ImportError(
+            "tensorrt_llm._torch.auto_deploy is not installed. Cannot create AutoDeployModel."
+        )
     world_size = kwargs.get("world_size", kwargs.get("tensor_parallel_size", 1))
```

examples/specdec_bench/specdec_bench/models/specbench_medusa.py (1)

`160-205`: ⚠️ Potential issue | 🟡 Minor — `NameError` on `idx` if `self.max_steps` is 0.

If `self.max_steps == 0`, the `for` loop body never executes, leaving `idx` undefined. Line 205 (`idx + 1`) then raises a `NameError`.

Proposed fix:

```diff
 new_token = 0
+num_steps = 0
 for idx in range(self.max_steps):
+    num_steps = idx + 1
     candidates, tree_candidates = generate_candidates(
     ...
-return input_ids, new_token, idx + 1, accept_length_list, timing
+return input_ids, new_token, num_steps, accept_length_list, timing
```

examples/specdec_bench/specdec_bench/models/auto_deploy.py (1)

`19-25`: ⚠️ Potential issue | 🟡 Minor — Missing `SamplingParams` fallback when the import fails.

If `tensorrt_llm` is not installed, `LLM` is set to `None` but `SamplingParams` and `DraftTargetDecodingConfig` are left undefined. Any code path that reaches `check_sampling_config` (line 144) or `create_auto_deploy_model` with `speculative_algorithm == "DRAFT_TARGET"` (line 102) will raise a `NameError` instead of a clear error message.

Proposed fix:

```diff
 except ImportError:
     print("tensorrt_llm._torch.auto_deploy is not installed.")
     LLM = None
+    SamplingParams = None
+    DraftTargetDecodingConfig = None
```

examples/specdec_bench/specdec_bench/models/specbench_medusa.py (1)

`278-293`: ⚠️ Potential issue | 🟡 Minor — `stop()` can raise `AttributeError` if `__init__` failed partway through.

If model loading fails after entering `__init__` but before `self.model` is set (e.g., `from_pretrained` raises), calling `stop()` will hit `hasattr(self.model, ...)` on a non-existent attribute. The existing `hasattr(self, "model")` check at line 291 covers model deletion, but the earlier `hasattr(self.model, ...)` calls at lines 281/287 do not first check that `self.model` exists.

Proposed fix:

```diff
 def stop(self):
     """Cleanup resources."""
+    if not hasattr(self, "model") or self.model is None:
+        return
     # Clear cached KV states to free memory
     if hasattr(self.model, "past_key_values"):
         del self.model.past_key_values
         del self.model.past_key_values_data
         del self.model.current_length_data
     # Clear medusa buffers
     if hasattr(self.model, "medusa_buffers"):
         del self.model.medusa_buffers
     # Move model to CPU or delete to free GPU memory
-    if hasattr(self, "model") and self.model is not None:
-        del self.model
-        torch.cuda.empty_cache()
+    del self.model
+    torch.cuda.empty_cache()
```

examples/specdec_bench/specdec_bench/datasets/speed.py (1)

`175-256`: ⚠️ Potential issue | 🟡 Minor — `_generate_stackselect_prompt`: the `answer` parameter is shadowed by a loop variable, and `answers_to_add_stop` may include content beyond the token budget.

Line 237 `for i, answer in enumerate(answers)` shadows the `answer` parameter (line 177), which is the correct answer designation (e.g., `"A3"`). After the loop, `answer` refers to the last iterated text, not the original parameter. The correct answer index was already captured on line 228, so it works, but it's confusing.

`answers_to_add_stop` is initialized to `0` (line 236). If the very first non-correct answer exceeds the token budget and the loop breaks immediately, `answers_to_add_stop` remains `0`, and `answers[:1]` still includes the first answer — even though it doesn't fit. Consider initializing to `-1` or adding a guard.

Line 179: `random.seed(42)` mutates global state. This is a side effect that could affect other code using the `random` module. Consider using a local `random.Random(42)` instance instead.

examples/specdec_bench/specdec_bench/metrics/specbench.py (1)

`42-44`: ⚠️ Potential issue | 🟡 Minor — Error message references `seaborn` but it's never imported or used.

The `ImportError` message instructs users to install `seaborn`, but the actual imports only use `rich`, `matplotlib`, and `pandas`. The `"seaborn-v0_8"` style on line 133 is a built-in matplotlib style and doesn't require the `seaborn` package. Either remove `seaborn` from the message or add an actual `import seaborn` if it's intended for future use.

Proposed fix:

```diff
-raise ImportError(
-    "Please install rich, matplotlib, seaborn, and pandas to use the SpecBench metric"
-)
+raise ImportError(
+    "Please install rich, matplotlib, and pandas to use the SpecBench metric"
+)
```

examples/specdec_bench/run.py (1)

`59-70`: ⚠️ Potential issue | 🟡 Minor — Mutable default argument `chat_template_args={}`.

Using a mutable default (`{}`) is a well-known Python pitfall. Although `chat_template_args` is only read here (passed to `encode_chat`), it's best to use `None` as the default to avoid any accidental mutation surprises.

Proposed fix:

```diff
 async def run_loop(
     runner,
     dataset,
     tokenizer,
     output_length,
     postprocess,
     concurrency=10,
     end_id=-1,
     show_progress=False,
     completions=False,
-    chat_template_args={},
+    chat_template_args=None,
 ):
```

Then at the top of the function body:

```diff
 semaphore = asyncio.Semaphore(concurrency)
 max_length = output_length
+if chat_template_args is None:
+    chat_template_args = {}
```
🧹 Nitpick comments (18)
examples/specdec_bench/README.md (1)

`50-50`: Fix markdown formatting issues for better rendering.

Multiple markdown formatting issues were detected by linters and should be corrected for consistent rendering:

- Line 50: Add a blank line before the code block
- Line 56: Add a blank line before the `#### License` heading
- Line 57: Remove the trailing space at the end of the line
- Line 66: Add a blank line before the `#### Qualitative split:` heading
- Line 67: Add a blank line before the code block
- Line 71: Add a blank line before the `#### Throughput split:` heading
- Line 72: Add a blank line before the code block
- Line 77: Add a blank line before the code block
- Line 92: Add a final newline at the end of the file

📝 Proposed fix for markdown formatting: insert the missing blank lines listed above, drop the trailing space after "GOVERNING TERMS: This dataset is governed by the NVIDIA Evaluation Dataset License Agreement.", and end the file with a newline after "use a more robust benchmarking client like NVIDIA AI Perf."

Also applies to: 56-57, 66-67, 71-72, 77-77, 92-92

examples/specdec_bench/specdec_bench/utils.py (1)

`20-21`: `trust_remote_code=True` enables arbitrary code execution from model repos.

This is a known security trade-off. It's common in ML workflows but worth noting — any model repository could execute arbitrary Python during tokenizer loading. Ensure this is intentional and documented.

examples/specdec_bench/specdec_bench/models/auto_deploy.py (2)

`37-75`: `run()` and `check_sampling_config()` are near-identical copies of `TRTLLMPYTModel`.

The entire `run()` method (lines 37–75) and `check_sampling_config()` (lines 141–154) are duplicated almost verbatim from `trtllm_torch_api.py`. Consider extracting the shared streaming-output-reformatting logic and the sampling config builder into a common utility to avoid divergence.

`77-83`: `del self.model` doesn't guarantee GPU memory release.

`del self.model` removes the reference, but Python's garbage collector may not immediately free GPU resources. Consider calling `torch.cuda.empty_cache()` after deletion (as done in `SpecBenchMedusaModel.stop()`) or invoking any shutdown method the LLM object provides.

examples/specdec_bench/specdec_bench/models/specbench_medusa.py (2)

`76-76`: `assert` for input validation can be silently disabled with `python -O`.

Use an explicit `if` / `raise ValueError` instead:

Proposed fix:

```diff
-assert max_concurrent_requests == 1, "Only support batch size 1 for now!"
+if max_concurrent_requests != 1:
+    raise ValueError("SpecBenchMedusaModel only supports batch size 1.")
```

`200-203`: EOS check converts the tensor to a Python list on every iteration.

`input_ids[0, input_len:].tolist()` allocates a new list each step, growing linearly with generated tokens. For long sequences this adds overhead. Consider checking only the newly appended tokens or using a tensor-based check:

```diff
-if end_id in input_ids[0, input_len:].tolist():
+if (input_ids[0, input_len:] == end_id).any():
     break
```

examples/specdec_bench/specdec_bench/models/vllm.py (1)

`33-70`: Repeated `kwargs.get("speculative_algorithm", None)` calls are verbose but correct.

Each branch re-evaluates `kwargs.get(...)`. The pattern works but is noisy. Consider extracting to a local variable at the top, consistent with how `sglang.py` does it at line 35: `speculative_algorithm = kwargs.get("speculative_algorithm", None)`

examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py (1)

`40-50`: `use_draft_logits` parameter is accepted but never used.

The `__init__` signature includes `use_draft_logits=False` (line 46), but it's not stored or referenced anywhere in the class.

examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py (1)

`89-89`: Yoda condition is unconventional in Python.

`"system" == messages[0]["role"]` reads less naturally than the standard `messages[0]["role"] == "system"`. This is a minor style nit.

examples/specdec_bench/specdec_bench/datasets/random_token.py (1)

`22-26`: LGTM — tokenizer is now passed explicitly rather than stored.

Clean refactor. Note that `_preprocess(self, tokenizer)` diverges from the base class signature `_preprocess(self)` in `datasets/base.py` (line 39). This is the same pattern used by `MTBench._preprocess(self, path)` and `SpecBench._preprocess(self, path)`, so it's consistent within the PR, but the base class definition is now misleading.

Consider updating `Dataset._preprocess` in `base.py` to accept `*args, **kwargs` (or remove it entirely) so the contract matches actual usage.

Also applies to: 28-28

examples/specdec_bench/specdec_bench/runners/base.py (1)

`28-28`: List comprehensions used purely for side effects.

`process_metrics_final`, `process_metrics_step`, and `clear_metrics` all use list comprehensions solely for their side effects, which builds and discards a list of `None` values. Conventional Python style prefers a `for` loop here. This is pre-existing, so feel free to defer.

♻️ Example for `process_metrics_step`:

```diff
 def process_metrics_step(self, step_outputs, request_id, turn_id):
-    [
-        metric.process_step(step_outputs, request_id, turn_id)
-        for metric in self.metrics
-    ]
+    for metric in self.metrics:
+        metric.process_step(step_outputs, request_id, turn_id)
```

Also applies to: 31-34, 37-37

examples/specdec_bench/prepare_data.py (1)

`66-71`: `--config` choices are tightly coupled to `SPEEDBench`'s `config_type`.

The `--config` choices use `get_args(config_type)` from `speed.py`, which won't be meaningful if other datasets are added to `datasets_available`. If you plan to support more datasets here, consider making config choices dataset-dependent (e.g., validated inside `prepare_data()` after the dataset is known, rather than statically in argparse).

For now this is fine since only `"speed"` is available.

examples/specdec_bench/specdec_bench/metrics/specbench.py (2)

`65-68`: Variable `category_ar` is shadowed from `list` to `float`.

On line 65 `category_ar` is the loop variable (a list of AR values), but on line 67 it's reassigned to a scalar (the mean). This works but hurts readability — a reader may expect `category_ar` to remain a list through the loop body.

Suggested clarity improvement:

```diff
 for category_name, category_ar in per_category.items():
     if len(category_ar) > 0:
-        category_ar = mean(category_ar)
-        self.out["Category_AR"][category_name] = category_ar
+        self.out["Category_AR"][category_name] = mean(category_ar)
```

`127-130`: Remove the "Completely generated by Cursor" comment.

Tool-attribution comments are not useful for maintainers and may imply the code was not reviewed. Consider removing it.

examples/specdec_bench/run.py (2)

`169-170`: Verify `SpecBench` metrics are correctly triggered for `--dataset specbench`.

The condition `args.specbench is not None or args.dataset == "speed"` works because lines 349–350 set `args.specbench = args.dataset_path` when `--dataset specbench`. However, this relies on the cross-assignment at lines 349–352, which couples the new `--dataset` flag to the legacy flags. This works but is fragile — if someone adds a new dataset path without updating the cross-assignment block, metrics selection could silently fall through to the generic `AcceptanceRate`.

Suggested: use `args.dataset` directly in the metrics condition:

```diff
-if args.mtbench is not None:
-    metrics_list.insert(0, metrics.MTBench())
-elif args.specbench is not None or args.dataset == "speed":
-    metrics_list.insert(0, metrics.SpecBench(requests=dataset.data))
+if args.dataset == "mtbench" or args.mtbench is not None:
+    metrics_list.insert(0, metrics.MTBench())
+elif args.dataset in ("specbench", "speed") or args.specbench is not None:
+    metrics_list.insert(0, metrics.SpecBench(requests=dataset.data))
```

`341-352`: Cross-assignment of `args.specbench` and `args.mtbench` from `args.dataset_path` is a leaky abstraction.

These lines (`args.specbench = args.dataset_path`, `args.mtbench = args.dataset_path`) exist solely to make the metrics selection on lines 167–170 work. This couples the new `--dataset` interface to the legacy flag-based interface. If the metrics condition is updated per the earlier suggestion to check `args.dataset` directly, this cross-assignment can be removed.

examples/specdec_bench/specdec_bench/datasets/speed.py (2)

`340-344`: Redundant re-encoding on line 343.

`tokens` is already computed on line 338 and `current_num_tokens` derived from it. Line 343 re-encodes the same `prompt` string unnecessarily.

Proposed fix:

```diff
 def _pad_or_truncate_prompt(
     prompt: str, target_num_tokens: int, padding: str = "Answer now please.\n"
 ) -> str:
     encoder = tiktoken.get_encoding("o200k_base")
     tokens = encoder.encode(prompt, disallowed_special=())
     current_num_tokens = len(tokens)
     if current_num_tokens > target_num_tokens:
         # Truncate if too long
-        tokens = encoder.encode(prompt, disallowed_special=())
         return encoder.decode(tokens[:target_num_tokens])
```

`716-753`: Confusing variable name `config_name_or_dataset_path_path`.

The extra `_path` suffix on line 735 makes it awkward to read. Consider something like `dataset_path` or `config_path`.
```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_16k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress --runtime_params runtime_args_long_context.yaml
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
# Check if runtime_args_long_context.yaml exists
if [ -f "examples/specdec_bench/runtime_args_long_context.yaml" ]; then
echo "✓ runtime_args_long_context.yaml exists"
echo "Contents:"
cat examples/specdec_bench/runtime_args_long_context.yaml
else
echo "✗ runtime_args_long_context.yaml not found at examples/specdec_bench/"
# Search for the file in the repository
echo "Searching for runtime_args_long_context.yaml..."
find . -name "runtime_args_long_context.yaml" 2>/dev/null
# Also check what files exist in the examples/specdec_bench directory
echo "Files in examples/specdec_bench/:"
ls -la examples/specdec_bench/ 2>/dev/null | head -20
fi
```

Repository: NVIDIA/Model-Optimizer

Length of output: 724
Remove or provide the missing runtime_args_long_context.yaml file.
The README references --runtime_params runtime_args_long_context.yaml, but this file does not exist in the repository. Either create the file with appropriate configuration values or remove the parameter from the example command.
🤖 Prompt for AI Agents
In `@examples/specdec_bench/README.md` at line 84, The README example references a
missing runtime_params file; either add a new runtime_args_long_context.yaml
with sensible runtime keys (e.g., max_sequence_length, batch_size,
memory_pool_size, device_map, and any TRTLLM-specific settings) and commit it,
or remove the `--runtime_params runtime_args_long_context.yaml` flag from the
example command in the README (the command shown passed to run.py) so the
example works as-is; update README to reflect the chosen approach and mention
where to find or how to customize the runtime params if you add the YAML.
```python
    if args.num_requests is not None:
        dataset_kwargs["num_samples"] = args.num_requests
    if args.dataset is not None:
        dataset = datasets_available[args.dataset](args.dataset_path, **dataset_kwargs)
    else:
        if args.mtbench is not None:
            dataset = datasets.MTBench(args.mtbench, **dataset_kwargs)
        elif args.random_isl is not None:
            dataset = datasets.RandomToken(tokenizer, args.random_isl, **dataset_kwargs)
        elif args.specbench is not None:
            dataset = datasets.SpecBench(args.specbench, **dataset_kwargs)
```
--dataset random will break: RandomToken expects a tokenizer, not a path.
When args.dataset == "random", line 141 calls datasets.RandomToken(args.dataset_path, **dataset_kwargs), but RandomToken.__init__ expects a tokenizer object as its first argument (see the legacy path on line 146 which correctly passes tokenizer). This will fail at runtime.
Consider either special-casing "random" in the --dataset path (similar to lines 145–146), or removing it from datasets_available since it can't be used through the uniform interface.
🤖 Prompt for AI Agents
In `@examples/specdec_bench/run.py` around lines 138 - 148, When args.dataset ==
"random" the current uniform lookup via datasets_available passes
args.dataset_path into RandomToken but RandomToken.__init__ requires a
tokenizer; update the branch that handles args.dataset (the datasets_available
lookup) to special-case "random" and instantiate datasets.RandomToken(tokenizer,
args.dataset_path, **dataset_kwargs) (or remove "random" from datasets_available
if you prefer the latter) so RandomToken is constructed with a tokenizer object
rather than a path; ensure you reference the RandomToken class and
datasets_available map when making the change.
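A minimal sketch of the special-cased branch described above, reusing the `datasets_available`, `dataset_kwargs`, and `tokenizer` names from run.py (sketch only, not the committed fix):

```python
# Sketch: dispatch "random" separately so RandomToken receives the tokenizer object,
# while every other dataset keeps the uniform path-based constructor.
if args.dataset is not None:
    if args.dataset == "random":
        dataset = datasets.RandomToken(tokenizer, args.dataset_path, **dataset_kwargs)
    else:
        dataset = datasets_available[args.dataset](args.dataset_path, **dataset_kwargs)
```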
```python
class Model:
    def __init__(self, model_dir, tokenizer, max_draft_length):
        raise NotImplementedError

    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        """
        prompt_ids: list of token IDs (not a tensor!)
        Returns dict with:
        - output_ids: list of list of token chunks per step [[chunk1, chunk2, ...]]
        - output_logits: optional logits (usually None)
        - token_times: list of timestamps per decoding step
        """
        raise NotImplementedError

    def stop(self):
        pass
```
Interface definition contradicts implementation examples within the same document.
The Model interface at Line 25 declares run(self, prompt_ids, sampling_params, request_id, turn_id), but the Step 4 implementation example (Line 170) and the testing example (Lines 305-310) both use the old signature run(self, prompt_ids, max_length, end_id, request_id, turn_id). This inconsistency will confuse anyone following the porting guide.
Align all examples with whichever signature is chosen as the canonical API.
🤖 Prompt for AI Agents
In `@examples/specdec_bench/SPECBENCH_PORTING.md` around lines 20 - 37, The Model
interface's run signature (Model.run(self, prompt_ids, sampling_params,
request_id, turn_id)) conflicts with older examples that use run(self,
prompt_ids, max_length, end_id, request_id, turn_id); pick the canonical API and
make all examples consistent: either update the Step 4 implementation example
and the testing example to call/run and implement Model.run with (prompt_ids,
sampling_params, request_id, turn_id) and adjust their internals and docstrings
to read sampling_params (containing max_length/end_id), or change the interface
to the older signature and update the docstring to match; ensure Model.__init__,
Model.run, the Step 4 implementation example, and the test example all use the
same parameter names and types (prompt_ids list, and a single sampling_params or
explicit max_length/end_id) and that returned dict keys and token type
expectations remain unchanged.
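For illustration, a new-style implementation could unpack the legacy fields from `sampling_params` if it is a plain dict; the `"max_length"` and `"end_id"` key names below are assumptions, not something the guide specifies:

```python
# Sketch: Model.run using the new signature while keeping the old generation logic.
# The sampling_params key names ("max_length", "end_id") are assumed for illustration.
async def run(self, prompt_ids, sampling_params, request_id, turn_id):
    max_length = sampling_params["max_length"]
    end_id = sampling_params["end_id"]
    # ... existing generation logic that previously consumed max_length / end_id ...
    return {"output_ids": [[]], "output_logits": None, "token_times": []}
```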
```python
        reformatted_output_ids = [
            [] for _ in range(self.sampling_kwargs.get("beam_width", 1))
        ]
        for beam_idx, beam_len in enumerate(beam_lens):
            response = outputs[-1][beam_idx]
            if beam_len[0] != 0:
                reformatted_output_ids[beam_idx].append(
                    response.token_ids[: beam_len[0]]
                )
            for s, e in zip(beam_len[:-1], beam_len[1:]):
                reformatted_output_ids[beam_idx].append(response.token_ids[s:e])
            if len(response.token_ids) > beam_len[-1]:
                reformatted_output_ids[beam_idx].append(
                    response.token_ids[beam_len[-1] :]
                )
```
IndexError if generate_async yields no output.
If generate_async completes without yielding any items, outputs is empty and outputs[-1] at line 60 will raise an IndexError. Similarly, beam_len lists will be empty, causing beam_len[0] at line 61 to crash. Add an early return or guard for this edge case.
Proposed fix
```diff
+        if not outputs:
+            output_dict["output_ids"] = [
+                [] for _ in range(self.sampling_kwargs.get("beam_width", 1))
+            ]
+            output_dict["output_logits"] = None
+            output_dict["token_times"] = timing
+            return output_dict
+
         reformatted_output_ids = [
             [] for _ in range(self.sampling_kwargs.get("beam_width", 1))
         ]
```

🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/models/auto_deploy.py` around lines 56 -
70, The loop assumes generate_async produced outputs and non-empty beam_lens;
add a guard before using outputs[-1] and indexing beam_len: check if outputs is
empty or beam_lens is empty and return an appropriate empty structure (e.g.,
reformatted_output_ids or []) early, and inside the for loop skip/continue any
beam_len that is empty to avoid beam_len[0] / beam_len[-1] access; update
references in this block (outputs, beam_lens, reformatted_output_ids, beam_idx,
response) accordingly so no IndexError occurs when generate_async yields
nothing.
```python
    import sys
    import os

    spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench")
    sys.path.insert(0, spec_bench_path)
    from model.medusa.medusa_model import MedusaModel
    from model.medusa.kv_cache import initialize_past_key_values
    from model.medusa.utils import (
        generate_medusa_buffers,
        initialize_medusa,
        reset_medusa_mode,
        generate_candidates,
        tree_decoding,
        evaluate_posterior,
        update_inference_inputs,
    )
    from model.medusa.medusa_choices import mc_sim_7b_63
except ImportError as e:
    print(f"Medusa dependencies not found: {e}")
    MedusaModel = None
```
sys.path manipulation with a hardcoded CWD-relative path is fragile.
os.path.join(os.getcwd(), "Spec-Bench") at line 41 assumes the process is always launched from the repo root. This breaks when the working directory differs (e.g., CI, notebooks, installed package). It also permanently prepends to sys.path, risking module-name shadowing (e.g., the generic model package on line 43).
Consider making the Spec-Bench path configurable via an environment variable or constructor argument.
🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/models/specbench_medusa.py` around lines
38 - 57, Replace the brittle cwd-based sys.path insertion with a configurable
and robust lookup: read SPEC_BENCH_PATH from an environment variable first
(fallback to a path computed relative to this file using
os.path.dirname(__file__) and joining "..", "Spec-Bench"), and only modify
sys.path if that resolved path exists and is not already present (avoid blindly
inserting at index 0 to prevent module shadowing). Update the import block that
brings in MedusaModel, initialize_past_key_values, initialize_medusa,
generate_medusa_buffers, etc., to use that resolved path so tests/CI/notebook
runs can override via SPEC_BENCH_PATH and to avoid permanent, unsafe sys.path
manipulation.
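A sketch of that lookup; `SPEC_BENCH_PATH` is a proposed environment variable and the relative fallback is an assumption about where Spec-Bench would live next to this package:

```python
import os
import sys

# Prefer an explicit override, then fall back to a path relative to this file.
spec_bench_path = os.environ.get(
    "SPEC_BENCH_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "Spec-Bench"),
)
# Only extend sys.path when the directory exists and is not already present,
# and append (rather than insert at 0) to avoid shadowing installed packages.
if os.path.isdir(spec_bench_path) and spec_bench_path not in sys.path:
    sys.path.append(spec_bench_path)
```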
```python
        extra_params = {}
        if "allow_advanced_sampling" in EagleDecodingConfig.model_fields:
            extra_params["allow_advanced_sampling"] = kwargs.get(
                "allow_advanced_sampling", False
            )
        elif "allow_advanced_sampling" in kwargs:
            print(
                f"WARNING: allow_advanced_sampling was set but not supported for this tensorrt_llm version: {trtllm.__version__}"
            )
```
EagleDecodingConfig.model_fields access can raise AttributeError.
If EagleDecodingConfig doesn't have a model_fields attribute (e.g., older tensorrt_llm version without Pydantic v2), line 100 will raise AttributeError before the elif fallback at line 104 is ever reached. Additionally, line 106 references trtllm.__version__ on tensorrt_llm.bindings.executor, which may not expose a __version__ attribute.
Proposed fix
- if "allow_advanced_sampling" in EagleDecodingConfig.model_fields:
+ if hasattr(EagleDecodingConfig, "model_fields") and "allow_advanced_sampling" in EagleDecodingConfig.model_fields:
extra_params["allow_advanced_sampling"] = kwargs.get(
"allow_advanced_sampling", False
)
elif "allow_advanced_sampling" in kwargs:
print(
- f"WARNING: allow_advanced_sampling was set but not supported for this tensorrt_llm version: {trtllm.__version__}"
+ "WARNING: allow_advanced_sampling was set but not supported for this tensorrt_llm version."
)🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py` around lines
99 - 107, EagleDecodingConfig.model_fields access can raise AttributeError and
trtllm.__version__ may not exist; change the logic in the block around
EagleDecodingConfig.model_fields to first check safely with hasattr/getattr
(e.g., fields = getattr(EagleDecodingConfig, "model_fields", None)) and if
fields is truthy use the existing membership check to set
extra_params["allow_advanced_sampling"] from kwargs, otherwise fall back to
checking "allow_advanced_sampling" in kwargs and emit a warning that does not
assume trtllm.__version__ (use getattr(trtllm, "__version__", "unknown version")
or omit version), referencing the EagleDecodingConfig.model_fields,
extra_params, kwargs, and trtllm.__version__ symbols so the change is made in
the same function.
```python
    async def run(self, prompt_ids, max_length, end_id, sampling_kwargs):
        raise NotImplementedError()
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Find all files in the runners directory
find examples/specdec_bench/specdec_bench/runners -name "*.py" -type f | head -20
```

Repository: NVIDIA/Model-Optimizer

Length of output: 230

🏁 Script executed:

```bash
#!/bin/bash
# Search for SimpleRunner class definition and its run method
rg -A 30 "class SimpleRunner" examples/specdec_bench/specdec_bench/runners/
```

Repository: NVIDIA/Model-Optimizer

Length of output: 2177

🏁 Script executed:

```bash
#!/bin/bash
# Search for all async def run methods in runner files
rg -B 2 -A 5 "async def run" examples/specdec_bench/specdec_bench/runners/
```

Repository: NVIDIA/Model-Optimizer

Length of output: 1539

🏁 Script executed:

```bash
#!/bin/bash
# Check if there are any other BaseRunner subclasses
rg "class.*\(BaseRunner\)" examples/specdec_bench/specdec_bench/runners/
```

Repository: NVIDIA/Model-Optimizer

Length of output: 152

🏁 Script executed:

```bash
#!/bin/bash
# Check the __init__.py to see what's exported
cat examples/specdec_bench/specdec_bench/runners/__init__.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 806
Update SimpleRunner.run() signature to match base class.
The run() method signature in SimpleRunner still uses (prompt_ids, max_length, end_id, request_id, turn_id), but BaseRunner.run() now expects (prompt_ids, max_length, end_id, sampling_kwargs). Update the signature and extract request_id and turn_id from sampling_kwargs where needed for the model.run() call and process_metrics_step().
🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/runners/base.py` around lines 24 - 25,
SimpleRunner.run currently has the old signature; update its method signature to
match BaseRunner.run(prompt_ids, max_length, end_id, sampling_kwargs), then pull
request_id and turn_id from sampling_kwargs (e.g., request_id =
sampling_kwargs.get("request_id")) before calling model.run and before calling
process_metrics_step; ensure model.run is invoked with those extracted
request_id and turn_id where it previously used the positional args and pass
sampling_kwargs (or remaining sampling params) to the model call as appropriate
so behavior is unchanged.
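A sketch of the aligned method, assuming `request_id` and `turn_id` arrive inside `sampling_kwargs` as the prompt suggests; the `self.model` attribute name and the `model.run` argument order are also assumptions that depend on which model signature the PR settles on:

```python
# Sketch: SimpleRunner.run matching BaseRunner.run(prompt_ids, max_length, end_id, sampling_kwargs).
async def run(self, prompt_ids, max_length, end_id, sampling_kwargs):
    request_id = sampling_kwargs.get("request_id")
    turn_id = sampling_kwargs.get("turn_id")
    output = await self.model.run(prompt_ids, max_length, end_id, request_id, turn_id)
    self.process_metrics_step(output, request_id, turn_id)
    return output
```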
```diff
 def encode_chat(tokenizer, messages, chat_template_args={}, completions=False):
     if completions:
         return tokenizer.encode(messages[-1]["content"], add_special_tokens=False)
     return tokenizer.encode(
-        tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True),
+        tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True, **chat_template_args
+        ),
         add_special_tokens=False,
     )
```
Mutable default argument chat_template_args={} — classic Python pitfall.
Using a mutable dict as a default argument means all callers share the same object. If any code path mutates it (e.g., chat_template_args["key"] = val), the mutation persists across subsequent calls.
Proposed fix
```diff
-def encode_chat(tokenizer, messages, chat_template_args={}, completions=False):
+def encode_chat(tokenizer, messages, chat_template_args=None, completions=False):
+    if chat_template_args is None:
+        chat_template_args = {}
     if completions:
         return tokenizer.encode(messages[-1]["content"], add_special_tokens=False)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-def encode_chat(tokenizer, messages, chat_template_args={}, completions=False):
+def encode_chat(tokenizer, messages, chat_template_args=None, completions=False):
+    if chat_template_args is None:
+        chat_template_args = {}
     if completions:
         return tokenizer.encode(messages[-1]["content"], add_special_tokens=False)
     return tokenizer.encode(
         tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True, **chat_template_args
         ),
         add_special_tokens=False,
     )
```
🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/utils.py` around lines 24 - 32, The
function encode_chat uses a mutable default chat_template_args={} which can lead
to shared-state bugs; change the signature to default chat_template_args=None
and inside encode_chat set a local variable (e.g., args = {} if
chat_template_args is None else dict(chat_template_args)) so callers' dicts
aren’t mutated and you still pass args into tokenizer.apply_chat_template; keep
the completions branch behavior unchanged.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@           Coverage Diff           @@
##             main     #875   +/-   ##
=======================================
  Coverage   73.44%   73.44%
=======================================
  Files         197      197
  Lines       20657    20657
=======================================
  Hits        15172    15172
  Misses       5485     5485
```

☔ View full report in Codecov by Sentry.
Can you add some description of what this PR is for?
Added info to the top! |
What does this PR do?
Type of change: ?
Overview:
Addition of SpecBench Dataset
Addition of NVIDIA SPEED-Bench dataset, preproc scripts, and custom metrics aggregator
Addition of example of converting SpecBench Medusa to this FW
Addition of Initial TRTLLM AutoDeploy Specdec support
Updates to all frameworks for better perf (overlap/async scheduling etc)
Usage
# Add a code snippet demonstrating how to use this

Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation
Refactor
Chores