
Conversation


@IzzyPutterman IzzyPutterman commented Feb 10, 2026

What does this PR do?

Type of change: ?

Overview:
  • Addition of the SpecBench dataset
  • Addition of the NVIDIA SPEED-Bench dataset, preprocessing scripts, and a custom metrics aggregator
  • Addition of an example converting SpecBench Medusa to this framework
  • Addition of initial TRTLLM AutoDeploy speculative decoding support
  • Updates to all frameworks for better performance (overlap/async scheduling, etc.)

Usage

# Add a code snippet demonstrating how to use this
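
For reference, a representative invocation from the SPEED-Bench examples added to the README in this PR (qualitative split):

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/qualitative --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress
```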

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added SPEED-Bench dataset support with configurable throughput and qualitative configurations
    • Introduced SpecBench metrics with acceptance rate analysis and visualizations
    • Added progress bar during benchmark execution
    • New model implementations for auto-deployment and Medusa-style speculative decoding
    • Data preparation utility for benchmark datasets
    • Enhanced metrics with per-category analysis and performance charts
  • Documentation

    • Updated README with SPEED-Bench workflow and examples
    • New porting guide for integrating custom benchmark runners
  • Refactor

    • Streamlined model and runner interfaces for improved flexibility
    • Consolidated dataset implementations and removed deprecated base classes
  • Chores

    • Added required dependencies for data handling and visualizations

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
@IzzyPutterman IzzyPutterman requested a review from a team as a code owner February 10, 2026 19:07

coderabbitai bot commented Feb 10, 2026

📝 Walkthrough

This pull request integrates SPEED-Bench and SpecBench dataset support, introduces new model implementations (AutoDeployModel, SpecBenchMedusaModel), expands the CLI with dataset selection options, updates model and runner signatures for consistency, and adds corresponding metrics and data preparation utilities.

Changes

Cohort / File(s) Summary
Documentation
examples/specdec_bench/README.md, examples/specdec_bench/SPECBENCH_PORTING.md
Updated Docker image versions and added comprehensive SPEED-Bench workflow documentation including prerequisites, quickstart guides, context-length guidance, and licensing notes; introduced new porting guide for adapting Spec-Bench inference runners to specdec_bench Model interface.
Data Management & Preparation
examples/specdec_bench/prepare_data.py, examples/specdec_bench/requirements.txt, specdec_bench/datasets/base.py, specdec_bench/datasets/__init__.py
Added new data preparation orchestration script with CLI support for dataset selection and configuration; introduced new dependencies (datasets, rich, seaborn, tiktoken); expanded Request dataclass with question_id, category, output_turn_text fields; added prepare_data classmethod to Dataset base class; updated public exports to include SpecBench and SPEEDBench.
Dataset Infrastructure Removal
specdec_bench/datasets/base_hf.py
Deleted HuggingFace-based dataset loading framework (BaseHF, OpenOrca, OpenMathInstructv2, UltraChat classes).
Dataset Implementations
specdec_bench/datasets/mtbench.py, specdec_bench/datasets/random_token.py, specdec_bench/datasets/specbench.py, specdec_bench/datasets/speed.py
Updated MTBench and RandomToken with List type annotations and refactored preprocessing signatures; introduced new SpecBench dataset class for JSON-lines loading; added comprehensive SPEEDBench integration with multi-dataset external data resolution, prompt generation for multiple benchmarks (Bamboo, ChatRag, MMLU Pro, etc.), token-length enforcement, configuration parsing, and preparation workflow.
Benchmark Execution
examples/specdec_bench/run.py
Major refactoring: added progress tracking (tqdm_gather, show_progress flag), expanded run_loop signature with end_id, show_progress, completions, chat_template_args parameters; added dataset selection via CLI (--dataset, --dataset_path) with support for mtbench, random, specbench, speed; integrated SpecBench metrics insertion; added save_dir propagation to metrics; introduced new model entries (AUTO_DEPLOY, SPECBENCH_MEDUSA); enhanced error reporting with request indices and question_ids; added ignore_eos flag behavior with warnings.
Model Base Infrastructure
specdec_bench/models/base.py
Updated abstract run method signature from (prompt_ids, max_length, end_id, request_id, turn_id) to (prompt_ids, sampling_params, request_id, turn_id), consolidating sampling parameters.
New Model Implementations
specdec_bench/models/auto_deploy.py, specdec_bench/models/specbench_medusa.py
Added AutoDeployModel wrapping tensorrt_llm LLM with async generation, beam search handling, and speculative decoding support via DraftTargetDecodingConfig; introduced SpecBenchMedusaModel with Medusa-style iterative decoding, candidate generation, tree decoding, posterior evaluation, and per-step token timing.
Existing Model Updates
specdec_bench/models/vllm.py, specdec_bench/models/sglang.py, specdec_bench/models/trtllm_torch_api.py
Enhanced vLLM with num_speculative_tokens scaling, async_scheduling/enforce_eager flags, ignore_eos support, and improved draft model handling; updated SGLang with extended configuration options (ep_size, attention_backend, torch_compile, cuda_graph support, draft settings) and removed itertools.pairwise dependency; updated TRTLLM with max_seq_len/max_num_tokens parameters, EAGLE3 extra_params handling, and chunked_prefill defaults.
Package Initialization
specdec_bench/__init__.py, specdec_bench/models/__init__.py, specdec_bench/metrics/__init__.py, specdec_bench/runners/__init__.py
Updated copyright year; reorganized model exports (added AutoDeployModel, SpecBenchMedusaModel; removed base Model); updated metrics exports (added SpecBench, AATiming; removed Metric base); removed BaseRunner from runner exports.
Metrics
specdec_bench/metrics/base.py, specdec_bench/metrics/timing.py, specdec_bench/metrics/acceptance_rate.py, specdec_bench/metrics/aa_timing.py, specdec_bench/metrics/specbench.py
Enhanced base metric file writing with explicit read mode; updated Timing with "Number of Output Tokens" metric; adjusted AcceptanceRate (fixed clear() list initialization); switched AATiming encoding from cl100k_base to o200k_base; introduced comprehensive SpecBench metrics class extending AcceptanceRate with per-category aggregation, response-length binning, result visualization (3-panel matplotlib figure), and Rich table formatting.
Runners
specdec_bench/runners/base.py, specdec_bench/runners/simple.py
Updated BaseRunner.run signature to (prompt_ids, max_length, end_id, sampling_kwargs) removing request_id/turn_id; refactored metric processing to multiline form; adjusted SimpleRunner await formatting (no logic change).
Utilities
specdec_bench/utils.py
Enhanced get_tokenizer with trust_remote_code=True; extended encode_chat with chat_template_args and completions parameters for flexible message encoding; updated read_json to explicit read mode; refactored postprocess_gptoss to extract and trim final_message using multiple token delimiters.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant CLI as run.py
    participant Dataset as Dataset Layer
    participant Model as Model Instance
    participant Metrics as Metrics Collector
    participant Output as File Output
    
    User->>CLI: run_simple(args)
    CLI->>CLI: Parse args & load tokenizer
    CLI->>Dataset: Load dataset (speed/specbench/etc)
    Dataset-->>CLI: Return prepared data
    CLI->>CLI: Initialize model & metrics
    
    loop For each request in dataset
        CLI->>CLI: Encode prompt (chat_template_args)
        CLI->>Model: run(prompt_ids, sampling_params)
        Model-->>CLI: output_ids, token_times
        CLI->>Metrics: process_step(output, request_id)
        Metrics-->>Metrics: Aggregate acceptance_rate
    end
    
    CLI->>Metrics: process_final(text_outputs)
    Metrics->>Metrics: Compute statistics & visualizations
    Metrics->>Output: Write results (jsonl, json, png)
    Output-->>Metrics: File saved
    Metrics-->>CLI: Results processed
    CLI-->>User: Benchmark complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 18.18%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check — ❓ Inconclusive: The title 'SpecDec Bench: February Update' is vague and generic, using non-descriptive phrasing that does not convey meaningful information about the specific changes in the pull request. Resolution: replace the generic title with a more specific summary of the main change, such as 'Add SPEED-Bench dataset integration and SpecBench metrics' or 'Introduce SPEEDBench dataset and model implementations for speculative decoding benchmarking'.
✅ Passed checks (1 passed)
  • Description Check — ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

Note

Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (7)
examples/specdec_bench/specdec_bench/metrics/aa_timing.py (1)

49-63: ⚠️ Potential issue | 🟡 Minor

Potential division by zero when all requests have ≤2 timing entries.

If gen_tp_time is empty (all requests have len(times) <= 2), the code correctly guards with if gen_tp_time:. However, if self.timing is non-empty but all timing lists share the same start/end time (Line 53), end_time - start_time could be zero, causing a ZeroDivisionError on Line 53.

Proposed guard
         end_time = max([t[-1] for t in self.timing])
-        self.out["AA Output TPS"] = sum(self.total_tokens) / (end_time - start_time)
+        elapsed = end_time - start_time
+        if elapsed > 0:
+            self.out["AA Output TPS"] = sum(self.total_tokens) / elapsed
+        else:
+            self.out["AA Output TPS"] = 0.0
examples/specdec_bench/specdec_bench/metrics/base.py (1)

36-48: ⚠️ Potential issue | 🟡 Minor

write() lacks error handling for corrupt JSON files.

If the existing JSON file is malformed (e.g., from a previous interrupted write), json.load will raise a JSONDecodeError and the current metric data in self.out will be silently lost. Consider catching decode errors and falling back to starting a fresh list.

Proposed guard
         if self.out:
             filename = os.path.join(self.directory, f"{self.name}.json")
             if os.path.exists(filename):
-                with open(filename, "r") as json_file:
-                    existing_data = json.load(json_file)
-                existing_data.append(self.out)
+                try:
+                    with open(filename, "r") as json_file:
+                        existing_data = json.load(json_file)
+                    existing_data.append(self.out)
+                except (json.JSONDecodeError, IOError):
+                    existing_data = [self.out]
             else:
                 existing_data = [self.out]
examples/specdec_bench/specdec_bench/datasets/base.py (1)

29-31: ⚠️ Potential issue | 🟠 Major

output_turn_ids is a class variable, not a dataclass field.

Line 30 assigns output_turn_ids = None without a type annotation, so the @dataclass decorator treats it as a class variable shared across all instances — not an instance field. It won't appear in __init__, __repr__, or __eq__. If this is intended as a per-instance field (like output_turn_text on Line 31), it needs a type annotation.

Additionally, Line 31 uses lowercase list[str] while Line 26 uses List[str] — inconsistent style.

Proposed fix
     # not to be set by user
-    output_turn_ids = None
-    output_turn_text: list[str] = field(default_factory=list)
+    output_turn_ids: Optional[Any] = None
+    output_turn_text: List[str] = field(default_factory=list)
examples/specdec_bench/specdec_bench/models/base.py (1)

21-28: ⚠️ Potential issue | 🟠 Major

API signature mismatch: base class run() doesn't match any concrete implementation.

The base class now declares run(self, prompt_ids, sampling_params, request_id, turn_id), but all five concrete model implementations (vllm.py, auto_deploy.py, trtllm_torch_api.py, sglang.py, specbench_medusa.py) and the runner (simple.py) still use the old signature run(self, prompt_ids, max_length, end_id, request_id, turn_id). The runner calls the model with arguments in the old order, which will cause a runtime mismatch.

Either update all concrete models and callers to use sampling_params instead of max_length and end_id, or keep the base class aligned with the current implementations until a full migration is done.
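
For illustration, a minimal sketch of the consolidated base signature described in the porting guide; the attribute layout of sampling_params (max_tokens, stop_token_ids) is an assumption, and the real migration may differ:

```python
class Model:
    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        """prompt_ids: list of token IDs; sampling_params is assumed to carry
        the values previously passed as max_length / end_id
        (e.g. sampling_params.max_tokens, sampling_params.stop_token_ids)."""
        raise NotImplementedError

    def stop(self):
        pass
```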

examples/specdec_bench/specdec_bench/models/vllm.py (1)

100-105: ⚠️ Potential issue | 🟠 Major

ignore_eos is sticky — once set to True it's never cleared.

self.sampling_config is a shared mutable instance attribute. When end_id == -1, ignore_eos is set to True at line 105, but there's no else branch to reset it to False. All subsequent calls with a valid end_id will still have ignore_eos=True, silently ignoring EOS tokens.

Proposed fix
     async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         output_dict = {}
         self.sampling_config.max_tokens = max_length
         self.sampling_config.stop_token_ids = [end_id]
         if end_id == -1:
             self.sampling_config.ignore_eos = True
+        else:
+            self.sampling_config.ignore_eos = False
examples/specdec_bench/specdec_bench/models/sglang.py (1)

97-106: ⚠️ Potential issue | 🟠 Major

outputs = None leads to TypeError if async_generate yields no chunks.

If the async stream produces no chunks, outputs stays None and end_id == outputs[-1] at line 106 will raise TypeError: 'NoneType' object is not subscriptable. Initialize outputs as an empty list and add a guard:

Proposed fix
-        outputs = None
+        outputs = []
         result = await self.model.async_generate(
             sampling_params=self.sampling_config, input_ids=prompt_ids, stream=True
         )
         async for chunk in result:
             timing.append(time.perf_counter())
             outputs = chunk["output_ids"]
             beam_lens[0].append(chunk["meta_info"]["completion_tokens"])

-        if end_id == outputs[-1]:
+        if outputs and end_id == outputs[-1]:
examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py (1)

80-81: ⚠️ Potential issue | 🔴 Critical

Bug: clear() resets self.prompt_ar to a list instead of a dict.

__init__ (line 24) initializes self.prompt_ar = {}, and process_step (lines 28-34) uses it as a dict (keyed by request_id). After clear() is called, self.prompt_ar becomes [], and the next process_step call will fail when attempting dict-style keyed access on a list.

🐛 Proposed fix
     def clear(self):
-        self.prompt_ar = []
+        self.prompt_ar = {}
🤖 Fix all issues with AI agents
In `@examples/specdec_bench/README.md`:
- Line 84: The README example references a missing runtime_params file; either
add a new runtime_args_long_context.yaml with sensible runtime keys (e.g.,
max_sequence_length, batch_size, memory_pool_size, device_map, and any
TRTLLM-specific settings) and commit it, or remove the `--runtime_params
runtime_args_long_context.yaml` flag from the example command in the README (the
command shown passed to run.py) so the example works as-is; update README to
reflect the chosen approach and mention where to find or how to customize the
runtime params if you add the YAML.

In `@examples/specdec_bench/run.py`:
- Around line 138-148: When args.dataset == "random" the current uniform lookup
via datasets_available passes args.dataset_path into RandomToken but
RandomToken.__init__ requires a tokenizer; update the branch that handles
args.dataset (the datasets_available lookup) to special-case "random" and
instantiate datasets.RandomToken(tokenizer, args.dataset_path, **dataset_kwargs)
(or remove "random" from datasets_available if you prefer the latter) so
RandomToken is constructed with a tokenizer object rather than a path; ensure
you reference the RandomToken class and datasets_available map when making the
change.

In `@examples/specdec_bench/SPECBENCH_PORTING.md`:
- Around line 20-37: The Model interface's run signature (Model.run(self,
prompt_ids, sampling_params, request_id, turn_id)) conflicts with older examples
that use run(self, prompt_ids, max_length, end_id, request_id, turn_id); pick
the canonical API and make all examples consistent: either update the Step 4
implementation example and the testing example to call/run and implement
Model.run with (prompt_ids, sampling_params, request_id, turn_id) and adjust
their internals and docstrings to read sampling_params (containing
max_length/end_id), or change the interface to the older signature and update
the docstring to match; ensure Model.__init__, Model.run, the Step 4
implementation example, and the test example all use the same parameter names
and types (prompt_ids list, and a single sampling_params or explicit
max_length/end_id) and that returned dict keys and token type expectations
remain unchanged.

In `@examples/specdec_bench/specdec_bench/models/auto_deploy.py`:
- Around line 56-70: The loop assumes generate_async produced outputs and
non-empty beam_lens; add a guard before using outputs[-1] and indexing beam_len:
check if outputs is empty or beam_lens is empty and return an appropriate empty
structure (e.g., reformatted_output_ids or []) early, and inside the for loop
skip/continue any beam_len that is empty to avoid beam_len[0] / beam_len[-1]
access; update references in this block (outputs, beam_lens,
reformatted_output_ids, beam_idx, response) accordingly so no IndexError occurs
when generate_async yields nothing.

In `@examples/specdec_bench/specdec_bench/models/specbench_medusa.py`:
- Around line 38-57: Replace the brittle cwd-based sys.path insertion with a
configurable and robust lookup: read SPEC_BENCH_PATH from an environment
variable first (fallback to a path computed relative to this file using
os.path.dirname(__file__) and joining "..", "Spec-Bench"), and only modify
sys.path if that resolved path exists and is not already present (avoid blindly
inserting at index 0 to prevent module shadowing). Update the import block that
brings in MedusaModel, initialize_past_key_values, initialize_medusa,
generate_medusa_buffers, etc., to use that resolved path so tests/CI/notebook
runs can override via SPEC_BENCH_PATH and to avoid permanent, unsafe sys.path
manipulation.

In `@examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py`:
- Around line 99-107: EagleDecodingConfig.model_fields access can raise
AttributeError and trtllm.__version__ may not exist; change the logic in the
block around EagleDecodingConfig.model_fields to first check safely with
hasattr/getattr (e.g., fields = getattr(EagleDecodingConfig, "model_fields",
None)) and if fields is truthy use the existing membership check to set
extra_params["allow_advanced_sampling"] from kwargs, otherwise fall back to
checking "allow_advanced_sampling" in kwargs and emit a warning that does not
assume trtllm.__version__ (use getattr(trtllm, "__version__", "unknown version")
or omit version), referencing the EagleDecodingConfig.model_fields,
extra_params, kwargs, and trtllm.__version__ symbols so the change is made in
the same function.

In `@examples/specdec_bench/specdec_bench/runners/base.py`:
- Around line 24-25: SimpleRunner.run currently has the old signature; update
its method signature to match BaseRunner.run(prompt_ids, max_length, end_id,
sampling_kwargs), then pull request_id and turn_id from sampling_kwargs (e.g.,
request_id = sampling_kwargs.get("request_id")) before calling model.run and
before calling process_metrics_step; ensure model.run is invoked with those
extracted request_id and turn_id where it previously used the positional args
and pass sampling_kwargs (or remaining sampling params) to the model call as
appropriate so behavior is unchanged.

In `@examples/specdec_bench/specdec_bench/utils.py`:
- Around line 24-32: The function encode_chat uses a mutable default
chat_template_args={} which can lead to shared-state bugs; change the signature
to default chat_template_args=None and inside encode_chat set a local variable
(e.g., args = {} if chat_template_args is None else dict(chat_template_args)) so
callers' dicts aren’t mutated and you still pass args into
tokenizer.apply_chat_template; keep the completions branch behavior unchanged.
🟡 Minor comments (15)
examples/specdec_bench/specdec_bench/__init__.py-1-1 (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Copyright year appears to have regressed from 2025 to 2024.

Per the AI summary, this line was changed from 2025 to 2024, which moves the copyright year backwards. Given the current date (February 2026), the year should likely be 2025 or 2026, not 2024.

Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
examples/specdec_bench/README.md-44-44 (1)

44-44: ⚠️ Potential issue | 🟡 Minor

Clarify HuggingFace authentication requirement for SPEED-Bench dataset.

The SPEED-Bench dataset URL returns HTTP 401 (Unauthorized), indicating authentication is required. While the referenced models (Llama 3.3 70B and EAGLE3) are publicly accessible, the documentation should explicitly state that users need to authenticate with HuggingFace to access the SPEED-Bench dataset. Update the documentation to mention this requirement, either in the overview or in the data preparation section.
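
For example, the data preparation section could show a login step up front (a hedged sketch; the exact flow is up to the README author):

```bash
# Authenticate once so the datasets library can read the gated SPEED-Bench repo
huggingface-cli login          # or: export HF_TOKEN=<your token>
python3 prepare_data.py --dataset speed --config all
```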

examples/specdec_bench/README.md-6-7 (1)

6-7: ⚠️ Potential issue | 🟡 Minor

Update SGLang docker image reference for specificity. The SGLang docker image in the README references lmsysorg/sglang:v0.5.7, but the available published docker image tags include cuda variant suffixes (e.g., lmsysorg/sglang:v0.5.7-cu130-runtime). Clarify which specific variant is required for this benchmark, or document that the user should select the appropriate cuda version for their environment. The vLLM v0.15.0 and TensorRT-LLM 1.3.0rc2 image tags are correctly documented.

examples/specdec_bench/specdec_bench/datasets/base.py-43-43 (1)

43-43: ⚠️ Potential issue | 🟡 Minor

str | Path union syntax requires Python 3.10+ at runtime.

Line 43 uses str | Path which is only valid at runtime in Python 3.10+. The rest of the file uses Optional[...] and List[...] from typing, suggesting older Python compatibility is intended. Either add from __future__ import annotations at the top of the file, or use Union[str, Path].

Proposed fix (option A — future annotations)
+from __future__ import annotations
 from typing import Any, List, Optional
examples/specdec_bench/requirements.txt-1-4 (1)

1-4: ⚠️ Potential issue | 🟡 Minor

Update datasets and rich to latest versions for current releases.

All specified versions exist on PyPI. However, datasets (4.4.0) and rich (14.2.0) have newer releases available: datasets 4.5.0 (Jan 2026) and rich 14.3.2 (Feb 2026). Update these for the latest patches and features. seaborn 0.13.2 and tiktoken 0.12.0 are already at their latest versions.

examples/specdec_bench/specdec_bench/models/sglang.py-88-89 (1)

88-89: ⚠️ Potential issue | 🟡 Minor

Misleading docstring.

The docstring says "Synchronous version of run for use with asyncio.to_thread" but the method is async and uses await self.model.async_generate(...). This docstring appears to be copied from another model.

examples/specdec_bench/specdec_bench/models/sglang.py-40-42 (1)

40-42: ⚠️ Potential issue | 🟡 Minor

assert False is a poor substitute for raising an error.

assert False will be silently removed when Python is run with -O. Use raise NotImplementedError(...) instead:

Proposed fix
         elif speculative_algorithm == "NGRAM":
             speculative_algorithm = "LOOKAHEAD"
-            assert False, "Needs more work"
+            raise NotImplementedError("NGRAM/LOOKAHEAD speculative decoding is not yet supported for SGLANG")
examples/specdec_bench/specdec_bench/models/specbench_medusa.py-100-106 (1)

100-106: ⚠️ Potential issue | 🟡 Minor

draft_model_path defaults to None and is passed directly to MedusaModel.from_pretrained.

If draft_model_dir is not supplied in kwargs (line 78), self.draft_model_path will be None, likely causing from_pretrained to fail with an opaque error. Validate early:

Proposed fix
         self.draft_model_path = kwargs.get("draft_model_dir", None)
+        if self.draft_model_path is None:
+            raise ValueError("draft_model_dir is required for SpecBenchMedusaModel")
examples/specdec_bench/specdec_bench/models/auto_deploy.py-86-138 (1)

86-138: ⚠️ Potential issue | 🟡 Minor

No guard against LLM being None.

If tensorrt_llm is unavailable, LLM is None and LLM(**llm_kwargs) at line 137 will raise a cryptic TypeError: 'NoneType' object is not callable. Add an early check for a clear failure message.

Proposed fix
 def create_auto_deploy_model(
     model_path: str, max_concurrent_requests: int, kwargs: Dict[str, Any]
 ):
+    if LLM is None:
+        raise ImportError(
+            "tensorrt_llm._torch.auto_deploy is not installed. Cannot create AutoDeployModel."
+        )
     world_size = kwargs.get("world_size", kwargs.get("tensor_parallel_size", 1))
examples/specdec_bench/specdec_bench/models/specbench_medusa.py-160-205 (1)

160-205: ⚠️ Potential issue | 🟡 Minor

NameError on idx if self.max_steps is 0.

If self.max_steps == 0, the for loop body never executes, leaving idx undefined. Line 205 (idx + 1) then raises a NameError.

Proposed fix
         new_token = 0
+        num_steps = 0

         for idx in range(self.max_steps):
+            num_steps = idx + 1
             candidates, tree_candidates = generate_candidates(
                 ...
-        return input_ids, new_token, idx + 1, accept_length_list, timing
+        return input_ids, new_token, num_steps, accept_length_list, timing
examples/specdec_bench/specdec_bench/models/auto_deploy.py-19-25 (1)

19-25: ⚠️ Potential issue | 🟡 Minor

Missing SamplingParams fallback when import fails.

If tensorrt_llm is not installed, LLM is set to None but SamplingParams and DraftTargetDecodingConfig are left undefined. Any code path that reaches check_sampling_config (line 144) or create_auto_deploy_model with speculative_algorithm == "DRAFT_TARGET" (line 102) will raise a NameError instead of a clear error message.

Proposed fix
 except ImportError:
     print("tensorrt_llm._torch.auto_deploy is not installed.")
     LLM = None
+    SamplingParams = None
+    DraftTargetDecodingConfig = None
examples/specdec_bench/specdec_bench/models/specbench_medusa.py-278-293 (1)

278-293: ⚠️ Potential issue | 🟡 Minor

stop() can raise AttributeError if __init__ failed partway through.

If model loading fails after entering __init__ but before self.model is set (e.g., from_pretrained raises), calling stop() will hit hasattr(self.model, ...) on a non-existent attribute. The existing hasattr(self, "model") check at line 291 covers model deletion, but the earlier hasattr(self.model, ...) calls at lines 281/287 do not first check that self.model exists.

Proposed fix
     def stop(self):
         """Cleanup resources."""
+        if not hasattr(self, "model") or self.model is None:
+            return
         # Clear cached KV states to free memory
-        if hasattr(self.model, "past_key_values"):
+        if hasattr(self.model, "past_key_values"):
             del self.model.past_key_values
             del self.model.past_key_values_data
             del self.model.current_length_data
 
         # Clear medusa buffers
         if hasattr(self.model, "medusa_buffers"):
             del self.model.medusa_buffers
 
         # Move model to CPU or delete to free GPU memory
-        if hasattr(self, "model") and self.model is not None:
-            del self.model
-            torch.cuda.empty_cache()
+        del self.model
+        torch.cuda.empty_cache()
examples/specdec_bench/specdec_bench/datasets/speed.py-175-256 (1)

175-256: ⚠️ Potential issue | 🟡 Minor

_generate_stackselect_prompt: parameter answer is shadowed by loop variable, and answers_to_add_stop may include content beyond the token budget.

  1. Line 237 for i, answer in enumerate(answers) shadows the answer parameter (line 177), which is the correct answer designation (e.g., "A3"). After the loop, answer refers to the last iterated text, not the original parameter. The correct answer index was already captured on line 228, so it works, but it's confusing.

  2. answers_to_add_stop is initialized to 0 (line 236). If the very first non-correct answer exceeds the token budget and the loop breaks immediately, answers_to_add_stop remains 0, and answers[:1] still includes the first answer—even though it doesn't fit. Consider initializing to -1 or adding a guard.

  3. Line 179: random.seed(42) mutates global state. This is a side effect that could affect other code using the random module. Consider using a local random.Random(42) instance instead.
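
A small sketch of point 3, using a local RNG instead of mutating the global random module (the answers list is a hypothetical stand-in):

```python
import random

# Local, reproducible RNG; leaves the global random module state untouched.
rng = random.Random(42)
answers = ["A1", "A2", "A3", "A4"]  # hypothetical candidate answers
rng.shuffle(answers)
```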

examples/specdec_bench/specdec_bench/metrics/specbench.py-42-44 (1)

42-44: ⚠️ Potential issue | 🟡 Minor

Error message references seaborn but it's never imported or used.

The ImportError message instructs users to install seaborn, but the actual imports only use rich, matplotlib, and pandas. The "seaborn-v0_8" style on line 133 is a built-in matplotlib style and doesn't require the seaborn package. Either remove seaborn from the message or add an actual import seaborn if it's intended for future use.

Proposed fix
-            raise ImportError(
-                "Please install rich, matplotlib, seaborn, and pandas to use the SpecBench metric"
-            )
+            raise ImportError(
+                "Please install rich, matplotlib, and pandas to use the SpecBench metric"
+            )
examples/specdec_bench/run.py-59-70 (1)

59-70: ⚠️ Potential issue | 🟡 Minor

Mutable default argument chat_template_args={}.

Using a mutable default ({}) is a well-known Python pitfall. Although chat_template_args is only read here (passed to encode_chat), it's best to use None as the default to avoid any accidental mutation surprises.

Proposed fix
 async def run_loop(
     runner,
     dataset,
     tokenizer,
     output_length,
     postprocess,
     concurrency=10,
     end_id=-1,
     show_progress=False,
     completions=False,
-    chat_template_args={},
+    chat_template_args=None,
 ):

Then at the top of the function body:

     semaphore = asyncio.Semaphore(concurrency)
     max_length = output_length
+    if chat_template_args is None:
+        chat_template_args = {}
🧹 Nitpick comments (18)
examples/specdec_bench/README.md (1)

50-50: Fix markdown formatting issues for better rendering.

Multiple markdown formatting issues were detected by linters that should be corrected for consistent rendering:

  1. Line 50: Add a blank line before the code block
  2. Line 56: Add a blank line before the #### License heading
  3. Line 57: Remove trailing space at the end of the line
  4. Line 66: Add a blank line before the #### Qualitative split: heading
  5. Line 67: Add a blank line before the code block
  6. Line 71: Add a blank line before the #### Throughput split: heading
  7. Line 72: Add a blank line before the code block
  8. Line 77: Add a blank line before the code block
  9. Line 92: Add a final newline at the end of the file
📝 Proposed fix for markdown formatting
 2. Prepare the data using the provided script:
+
```bash
python3 prepare_data.py --dataset speed --config all
```

The data will be saved to the data/ directory, with each config type (qualitative, throughput_1k, ...) in its own directory.

License

-GOVERNING TERMS: This dataset is governed by the NVIDIA Evaluation Dataset License Agreement.
+GOVERNING TERMS: This dataset is governed by the NVIDIA Evaluation Dataset License Agreement.

ADDITIONAL INFORMATION: MIT for bigcode/humanevalpack, RUCAIBox/MMATH, RUCAIBox/BAMBOO and EQ-Bench. Apache 2.0 for Writing Bench and Spec-Bench. CC BY 4.0 for FBK-MT/MCIF. MIT and Apache 2.0 for tianyang/repobench_python_v1.1, JetBrains-Research/lca-project-level-code-completion and tianyang/repobench_java_v1.1.

NOTICE: For each dataset a user elects to use, the user is responsible for checking if the dataset license is fit for the intended purpose. The prepare_data.py script automatically fetches data from all the source datasets.

Additional details are in HuggingFace dataset repository.

Qualitative split:

python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/qualitative --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress

Throughput split:

python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_1k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress

For longer context (>8192 tokens), please use the following configuration when using TRTLLM:
+

engine_args:
  max_seq_len: 131072   # Model max context length (for Llama 3.3 70B)
  enable_chunked_prefill: true
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_16k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress --runtime_params runtime_args_long_context.yaml

Notes

The goal of this benchmark is to provide an easy way to configure, run, and compare speculative implementations across frameworks in an apples-to-apples manner.
This benchmark sends requests in a single-threaded fashion, so running at large concurrency (>256) may result in Python async scheduling delays and skewed metrics.
If larger concurrency is needed, it is recommended to fully deploy the model using vllm serve, python -m sglang.launch_server, or trtllm-serve (for vLLM, SGLang, or TRTLLM respectively) and
-use a more robust benchmarking client like NVIDIA AI Perf.
\ No newline at end of file
+use a more robust benchmarking client like NVIDIA AI Perf.




Also applies to: 56-57, 66-67, 71-72, 77-77, 92-92

examples/specdec_bench/specdec_bench/utils.py (1)

`20-21`: **`trust_remote_code=True` enables arbitrary code execution from model repos.**

This is a known security trade-off. It's common in ML workflows but worth noting — any model repository could execute arbitrary Python during tokenizer loading. Ensure this is intentional and documented.

examples/specdec_bench/specdec_bench/models/auto_deploy.py (2)

`37-75`: **`run()` and `check_sampling_config()` are near-identical copies of `TRTLLMPYTModel`.**

The entire `run()` method (lines 37–75) and `check_sampling_config()` (lines 141–154) are duplicated almost verbatim from `trtllm_torch_api.py`. Consider extracting the shared streaming-output-reformatting logic and the sampling config builder into a common utility to avoid divergence.

---

`77-83`: **`del self.model` doesn't guarantee GPU memory release.**

`del self.model` removes the reference but Python's garbage collector may not immediately free GPU resources. Consider calling `torch.cuda.empty_cache()` after deletion (as done in `SpecBenchMedusaModel.stop()`) or invoking any shutdown method the LLM object provides.
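
A minimal sketch of such a stop(), mirroring the cleanup pattern used in SpecBenchMedusaModel.stop() (the class name here is illustrative only):

```python
import gc

import torch


class AutoDeployModelSketch:
    """Illustrative stand-in; only the cleanup pattern matters."""

    def stop(self):
        if getattr(self, "model", None) is not None:
            del self.model            # drop the engine reference
            gc.collect()              # collect lingering cycles
            torch.cuda.empty_cache()  # return cached blocks to the driver
```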

examples/specdec_bench/specdec_bench/models/specbench_medusa.py (2)

`76-76`: **`assert` for input validation can be silently disabled with `python -O`.**

Use an explicit `if`/`raise ValueError` instead:

Proposed fix

-        assert max_concurrent_requests == 1, "Only support batch size 1 for now!"
+        if max_concurrent_requests != 1:
+            raise ValueError("SpecBenchMedusaModel only supports batch size 1.")

---

`200-203`: **EOS check converts the tensor to a Python list on every iteration.**

input_ids[0, input_len:].tolist() allocates a new list each step, growing linearly with generated tokens. For long sequences this adds overhead. Consider checking only the newly appended tokens or using a tensor-based check:

-            if end_id in input_ids[0, input_len:].tolist():
+            if (input_ids[0, input_len:] == end_id).any():
                 break
examples/specdec_bench/specdec_bench/models/vllm.py (1)

33-70: Repeated kwargs.get("speculative_algorithm", None) calls are verbose but correct.

Each branch re-evaluates kwargs.get(...). The pattern works but is noisy. Consider extracting to a local variable at the top, consistent with how sglang.py does it at line 35:

speculative_algorithm = kwargs.get("speculative_algorithm", None)
examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py (1)

40-50: use_draft_logits parameter is accepted but never used.

The __init__ signature includes use_draft_logits=False (line 46) but it's not stored or referenced anywhere in the class.

examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py (1)

89-89: Yoda condition is unconventional in Python.

"system" == messages[0]["role"] reads less naturally than the standard messages[0]["role"] == "system". This is a minor style nit.

examples/specdec_bench/specdec_bench/datasets/random_token.py (1)

22-26: LGTM — tokenizer is now passed explicitly rather than stored.

Clean refactor. Note that _preprocess(self, tokenizer) diverges from the base class signature _preprocess(self) in datasets/base.py (line 39). This is the same pattern used by MTBench._preprocess(self, path) and SpecBench._preprocess(self, path), so it's consistent within the PR, but the base class definition is now misleading.

Consider updating Dataset._preprocess in base.py to accept *args, **kwargs (or remove it entirely) so the contract matches actual usage.

Also applies to: 28-28
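
One possible relaxation of the base contract (a sketch; whether _preprocess should stay abstract at all is a separate call):

```python
class Dataset:
    def _preprocess(self, *args, **kwargs):
        """Subclasses may accept extra inputs such as a path or a tokenizer."""
        raise NotImplementedError
```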

examples/specdec_bench/specdec_bench/runners/base.py (1)

28-28: List comprehensions used purely for side effects.

process_metrics_final, process_metrics_step, and clear_metrics all use list comprehensions solely for their side effects, which builds and discards a list of None values. Conventional Python style prefers a for loop here. This is pre-existing, so feel free to defer.

♻️ Example for process_metrics_step
     def process_metrics_step(self, step_outputs, request_id, turn_id):
-        [
-            metric.process_step(step_outputs, request_id, turn_id)
-            for metric in self.metrics
-        ]
+        for metric in self.metrics:
+            metric.process_step(step_outputs, request_id, turn_id)

Also applies to: 31-34, 37-37

examples/specdec_bench/prepare_data.py (1)

66-71: --config choices are tightly coupled to SPEEDBench's config_type.

The --config choices use get_args(config_type) from speed.py, which won't be meaningful if other datasets are added to datasets_available. If you plan to support more datasets here, consider making config choices dataset-dependent (e.g., validated inside prepare_data() after the dataset is known, rather than statically in argparse).

For now this is fine since only "speed" is available.

examples/specdec_bench/specdec_bench/metrics/specbench.py (2)

65-68: Variable category_ar is shadowed from list to float.

On line 65 category_ar is the loop variable (a list of AR values), but on line 67 it's reassigned to a scalar (the mean). This works but hurts readability—a reader may expect category_ar to remain a list through the loop body.

Suggested clarity improvement
         for category_name, category_ar in per_category.items():
             if len(category_ar) > 0:
-                category_ar = mean(category_ar)
-                self.out["Category_AR"][category_name] = category_ar
+                self.out["Category_AR"][category_name] = mean(category_ar)

127-130: Remove the "Completely generated by Cursor" comment.

Tool-attribution comments are not useful for maintainers and may imply the code was not reviewed. Consider removing it.

examples/specdec_bench/run.py (2)

169-170: Verify SpecBench metrics are correctly triggered for --dataset specbench.

The condition args.specbench is not None or args.dataset == "speed" works because lines 349–350 set args.specbench = args.dataset_path when --dataset specbench. However, this relies on the cross-assignment at lines 349–352 which couples the new --dataset flag to the legacy flags. This works but is fragile—if someone adds a new dataset path without updating the cross-assignment block, metrics selection could silently fall through to the generic AcceptanceRate.

Suggested: use args.dataset directly in the metrics condition
-    if args.mtbench is not None:
-        metrics_list.insert(0, metrics.MTBench())
-    elif args.specbench is not None or args.dataset == "speed":
-        metrics_list.insert(0, metrics.SpecBench(requests=dataset.data))
+    if args.dataset == "mtbench" or args.mtbench is not None:
+        metrics_list.insert(0, metrics.MTBench())
+    elif args.dataset in ("specbench", "speed") or args.specbench is not None:
+        metrics_list.insert(0, metrics.SpecBench(requests=dataset.data))

341-352: Cross-assignment of args.specbench and args.mtbench from args.dataset_path is a leaky abstraction.

These lines (args.specbench = args.dataset_path, args.mtbench = args.dataset_path) exist solely to make the metrics selection on lines 167–170 work. This couples the new --dataset interface to the legacy flag-based interface. If the metrics condition is updated per the earlier suggestion to check args.dataset directly, this cross-assignment can be removed.

examples/specdec_bench/specdec_bench/datasets/speed.py (2)

340-344: Redundant re-encoding on line 343.

tokens is already computed on line 338 and current_num_tokens derived from it. Line 343 re-encodes the same prompt string unnecessarily.

Proposed fix
     def _pad_or_truncate_prompt(
         prompt: str, target_num_tokens: int, padding: str = "Answer now please.\n"
     ) -> str:
         encoder = tiktoken.get_encoding("o200k_base")
 
         tokens = encoder.encode(prompt, disallowed_special=())
         current_num_tokens = len(tokens)
 
         if current_num_tokens > target_num_tokens:
             # Truncate if too long
-            tokens = encoder.encode(prompt, disallowed_special=())
             return encoder.decode(tokens[:target_num_tokens])

716-753: Confusing variable name config_name_or_dataset_path_path.

The extra _path suffix on line 735 makes it awkward to read. Consider something like dataset_path or config_path.

```bash
python3 run.py --model_dir meta-llama/Llama-3.3-70B-Instruct --tokenizer meta-llama/Llama-3.3-70B-Instruct --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B --dataset speed --dataset_path data/speed/throughput_16k --tp_size 8 --ep_size 1 --draft_length 3 --output_length 4096 --engine TRTLLM --concurrency 32 --show_progress --runtime_params runtime_args_long_context.yaml
```

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if runtime_args_long_context.yaml exists
if [ -f "examples/specdec_bench/runtime_args_long_context.yaml" ]; then
  echo "✓ runtime_args_long_context.yaml exists"
  echo "Contents:"
  cat examples/specdec_bench/runtime_args_long_context.yaml
else
  echo "✗ runtime_args_long_context.yaml not found at examples/specdec_bench/"
  # Search for the file in the repository
  echo "Searching for runtime_args_long_context.yaml..."
  find . -name "runtime_args_long_context.yaml" 2>/dev/null
  # Also check what files exist in the examples/specdec_bench directory
  echo "Files in examples/specdec_bench/:"
  ls -la examples/specdec_bench/ 2>/dev/null | head -20
fi

Repository: NVIDIA/Model-Optimizer

Length of output: 724


Remove or provide the missing runtime_args_long_context.yaml file.

The README references --runtime_params runtime_args_long_context.yaml, but this file does not exist in the repository. Either create the file with appropriate configuration values or remove the parameter from the example command.
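
If the file is added, it could simply carry the engine_args block already shown in the README's long-context note (values below are illustrative, not a committed file):

```yaml
engine_args:
  max_seq_len: 131072          # model max context length (Llama 3.3 70B)
  enable_chunked_prefill: true
```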

🤖 Prompt for AI Agents
In `@examples/specdec_bench/README.md` at line 84, The README example references a
missing runtime_params file; either add a new runtime_args_long_context.yaml
with sensible runtime keys (e.g., max_sequence_length, batch_size,
memory_pool_size, device_map, and any TRTLLM-specific settings) and commit it,
or remove the `--runtime_params runtime_args_long_context.yaml` flag from the
example command in the README (the command shown passed to run.py) so the
example works as-is; update README to reflect the chosen approach and mention
where to find or how to customize the runtime params if you add the YAML.

Comment on lines +138 to +148
    if args.num_requests is not None:
        dataset_kwargs["num_samples"] = args.num_requests
    if args.dataset is not None:
        dataset = datasets_available[args.dataset](args.dataset_path, **dataset_kwargs)
    else:
        if args.mtbench is not None:
            dataset = datasets.MTBench(args.mtbench, **dataset_kwargs)
        elif args.random_isl is not None:
            dataset = datasets.RandomToken(tokenizer, args.random_isl, **dataset_kwargs)
        elif args.specbench is not None:
            dataset = datasets.SpecBench(args.specbench, **dataset_kwargs)

⚠️ Potential issue | 🟠 Major

--dataset random will break: RandomToken expects a tokenizer, not a path.

When args.dataset == "random", line 141 calls datasets.RandomToken(args.dataset_path, **dataset_kwargs), but RandomToken.__init__ expects a tokenizer object as its first argument (see the legacy path on line 146 which correctly passes tokenizer). This will fail at runtime.

Consider either special-casing "random" in the --dataset path (similar to lines 145–146), or removing it from datasets_available since it can't be used through the uniform interface.
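
A sketch of the special-case branch, reusing the names from the quoted snippet (untested):

```python
    if args.dataset is not None:
        if args.dataset == "random":
            # RandomToken expects a tokenizer, not a path
            dataset = datasets.RandomToken(tokenizer, args.dataset_path, **dataset_kwargs)
        else:
            dataset = datasets_available[args.dataset](args.dataset_path, **dataset_kwargs)
```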

🤖 Prompt for AI Agents
In `@examples/specdec_bench/run.py` around lines 138 - 148, When args.dataset ==
"random" the current uniform lookup via datasets_available passes
args.dataset_path into RandomToken but RandomToken.__init__ requires a
tokenizer; update the branch that handles args.dataset (the datasets_available
lookup) to special-case "random" and instantiate datasets.RandomToken(tokenizer,
args.dataset_path, **dataset_kwargs) (or remove "random" from datasets_available
if you prefer the latter) so RandomToken is constructed with a tokenizer object
rather than a path; ensure you reference the RandomToken class and
datasets_available map when making the change.

Comment on lines +20 to +37
```python
class Model:
    def __init__(self, model_dir, tokenizer, max_draft_length):
        raise NotImplementedError

    async def run(self, prompt_ids, sampling_params, request_id, turn_id):
        """
        prompt_ids: list of token IDs (not a tensor!)
        Returns dict with:
        - output_ids: list of list of token chunks per step [[chunk1, chunk2, ...]]
        - output_logits: optional logits (usually None)
        - token_times: list of timestamps per decoding step
        """
        raise NotImplementedError

    def stop(self):
        pass
```

⚠️ Potential issue | 🟠 Major

Interface definition contradicts implementation examples within the same document.

The Model interface at Line 25 declares run(self, prompt_ids, sampling_params, request_id, turn_id), but the Step 4 implementation example (Line 170) and the testing example (Lines 305-310) both use the old signature run(self, prompt_ids, max_length, end_id, request_id, turn_id). This inconsistency will confuse anyone following the porting guide.

Align all examples with whichever signature is chosen as the canonical API.

🤖 Prompt for AI Agents
In `@examples/specdec_bench/SPECBENCH_PORTING.md` around lines 20 - 37, The Model
interface's run signature (Model.run(self, prompt_ids, sampling_params,
request_id, turn_id)) conflicts with older examples that use run(self,
prompt_ids, max_length, end_id, request_id, turn_id); pick the canonical API and
make all examples consistent: either update the Step 4 implementation example
and the testing example to call/run and implement Model.run with (prompt_ids,
sampling_params, request_id, turn_id) and adjust their internals and docstrings
to read sampling_params (containing max_length/end_id), or change the interface
to the older signature and update the docstring to match; ensure Model.__init__,
Model.run, the Step 4 implementation example, and the test example all use the
same parameter names and types (prompt_ids list, and a single sampling_params or
explicit max_length/end_id) and that returned dict keys and token type
expectations remain unchanged.

Comment on lines +56 to +70
        reformatted_output_ids = [
            [] for _ in range(self.sampling_kwargs.get("beam_width", 1))
        ]
        for beam_idx, beam_len in enumerate(beam_lens):
            response = outputs[-1][beam_idx]
            if beam_len[0] != 0:
                reformatted_output_ids[beam_idx].append(
                    response.token_ids[: beam_len[0]]
                )
            for s, e in zip(beam_len[:-1], beam_len[1:]):
                reformatted_output_ids[beam_idx].append(response.token_ids[s:e])
            if len(response.token_ids) > beam_len[-1]:
                reformatted_output_ids[beam_idx].append(
                    response.token_ids[beam_len[-1] :]
                )

⚠️ Potential issue | 🟠 Major

IndexError if generate_async yields no output.

If generate_async completes without yielding any items, outputs is empty and outputs[-1] at line 60 will raise an IndexError. Similarly, beam_len lists will be empty, causing beam_len[0] at line 61 to crash. Add an early return or guard for this edge case.

Proposed fix
+        if not outputs:
+            output_dict["output_ids"] = [
+                [] for _ in range(self.sampling_kwargs.get("beam_width", 1))
+            ]
+            output_dict["output_logits"] = None
+            output_dict["token_times"] = timing
+            return output_dict
+
         reformatted_output_ids = [
             [] for _ in range(self.sampling_kwargs.get("beam_width", 1))
         ]
🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/models/auto_deploy.py` around lines 56 -
70, The loop assumes generate_async produced outputs and non-empty beam_lens;
add a guard before using outputs[-1] and indexing beam_len: check if outputs is
empty or beam_lens is empty and return an appropriate empty structure (e.g.,
reformatted_output_ids or []) early, and inside the for loop skip/continue any
beam_len that is empty to avoid beam_len[0] / beam_len[-1] access; update
references in this block (outputs, beam_lens, reformatted_output_ids, beam_idx,
response) accordingly so no IndexError occurs when generate_async yields
nothing.

Comment on lines +38 to +57
    import sys
    import os

    spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench")
    sys.path.insert(0, spec_bench_path)
    from model.medusa.medusa_model import MedusaModel
    from model.medusa.kv_cache import initialize_past_key_values
    from model.medusa.utils import (
        generate_medusa_buffers,
        initialize_medusa,
        reset_medusa_mode,
        generate_candidates,
        tree_decoding,
        evaluate_posterior,
        update_inference_inputs,
    )
    from model.medusa.medusa_choices import mc_sim_7b_63
except ImportError as e:
    print(f"Medusa dependencies not found: {e}")
    MedusaModel = None

⚠️ Potential issue | 🟠 Major

sys.path manipulation with a hardcoded CWD-relative path is fragile.

os.path.join(os.getcwd(), "Spec-Bench") at line 41 assumes the process is always launched from the repo root. This breaks when the working directory differs (e.g., CI, notebooks, installed package). It also permanently prepends to sys.path, risking module-name shadowing (e.g., the generic model package on line 43).

Consider making the Spec-Bench path configurable via an environment variable or constructor argument.
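
A sketch of an environment-variable override with a file-relative fallback (SPEC_BENCH_PATH is an assumed variable name, not an existing convention in this repo):

```python
import os
import sys

spec_bench_path = os.environ.get(
    "SPEC_BENCH_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "Spec-Bench"),
)
if os.path.isdir(spec_bench_path) and spec_bench_path not in sys.path:
    # Append rather than insert(0) to reduce the risk of shadowing installed packages.
    sys.path.append(spec_bench_path)
```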

🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/models/specbench_medusa.py` around lines
38 - 57, Replace the brittle cwd-based sys.path insertion with a configurable
and robust lookup: read SPEC_BENCH_PATH from an environment variable first
(fallback to a path computed relative to this file using
os.path.dirname(__file__) and joining "..", "Spec-Bench"), and only modify
sys.path if that resolved path exists and is not already present (avoid blindly
inserting at index 0 to prevent module shadowing). Update the import block that
brings in MedusaModel, initialize_past_key_values, initialize_medusa,
generate_medusa_buffers, etc., to use that resolved path so tests/CI/notebook
runs can override via SPEC_BENCH_PATH and to avoid permanent, unsafe sys.path
manipulation.

Comment on lines +99 to +107
        extra_params = {}
        if "allow_advanced_sampling" in EagleDecodingConfig.model_fields:
            extra_params["allow_advanced_sampling"] = kwargs.get(
                "allow_advanced_sampling", False
            )
        elif "allow_advanced_sampling" in kwargs:
            print(
                f"WARNING: allow_advanced_sampling was set but not supported for this tensorrt_llm version: {trtllm.__version__}"
            )

⚠️ Potential issue | 🟠 Major

EagleDecodingConfig.model_fields access can raise AttributeError.

If EagleDecodingConfig doesn't have a model_fields attribute (e.g., older tensorrt_llm version without Pydantic v2), line 100 will raise AttributeError before the elif fallback at line 104 is ever reached. Additionally, line 106 references trtllm.__version__ on tensorrt_llm.bindings.executor, which may not expose a __version__ attribute.

Proposed fix
-        if "allow_advanced_sampling" in EagleDecodingConfig.model_fields:
+        if hasattr(EagleDecodingConfig, "model_fields") and "allow_advanced_sampling" in EagleDecodingConfig.model_fields:
             extra_params["allow_advanced_sampling"] = kwargs.get(
                 "allow_advanced_sampling", False
             )
         elif "allow_advanced_sampling" in kwargs:
             print(
-                f"WARNING: allow_advanced_sampling was set but not supported for this tensorrt_llm version: {trtllm.__version__}"
+                "WARNING: allow_advanced_sampling was set but not supported for this tensorrt_llm version."
             )
🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py` around lines
99 - 107, EagleDecodingConfig.model_fields access can raise AttributeError and
trtllm.__version__ may not exist; change the logic in the block around
EagleDecodingConfig.model_fields to first check safely with hasattr/getattr
(e.g., fields = getattr(EagleDecodingConfig, "model_fields", None)) and if
fields is truthy use the existing membership check to set
extra_params["allow_advanced_sampling"] from kwargs, otherwise fall back to
checking "allow_advanced_sampling" in kwargs and emit a warning that does not
assume trtllm.__version__ (use getattr(trtllm, "__version__", "unknown version")
or omit version), referencing the EagleDecodingConfig.model_fields,
extra_params, kwargs, and trtllm.__version__ symbols so the change is made in
the same function.
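
Spelled out, the getattr-based variant from the prompt might look like this (a sketch that assumes the same surrounding scope, i.e. EagleDecodingConfig, extra_params, kwargs, and trtllm already bound as in the snippet above):

        # getattr-based feature probe: degrades gracefully when model_fields
        # (Pydantic v2) or trtllm.__version__ is absent.
        fields = getattr(EagleDecodingConfig, "model_fields", None)
        if fields and "allow_advanced_sampling" in fields:
            extra_params["allow_advanced_sampling"] = kwargs.get("allow_advanced_sampling", False)
        elif "allow_advanced_sampling" in kwargs:
            version = getattr(trtllm, "__version__", "unknown version")
            print(
                "WARNING: allow_advanced_sampling was set but not supported "
                f"for this tensorrt_llm version: {version}"
            )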

Comment on lines +24 to 25
    async def run(self, prompt_ids, max_length, end_id, sampling_kwargs):
        raise NotImplementedError()

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find all files in the runners directory
find examples/specdec_bench/specdec_bench/runners -name "*.py" -type f | head -20



🏁 Script executed:

#!/bin/bash
# Search for SimpleRunner class definition and its run method
rg -A 30 "class SimpleRunner" examples/specdec_bench/specdec_bench/runners/



🏁 Script executed:

#!/bin/bash
# Search for all async def run methods in runner files
rg -B 2 -A 5 "async def run" examples/specdec_bench/specdec_bench/runners/



🏁 Script executed:

#!/bin/bash
# Check if there are any other BaseRunner subclasses
rg "class.*\(BaseRunner\)" examples/specdec_bench/specdec_bench/runners/



🏁 Script executed:

#!/bin/bash
# Check the __init__.py to see what's exported
cat examples/specdec_bench/specdec_bench/runners/__init__.py



Update SimpleRunner.run() signature to match the base class.

The run() method signature in SimpleRunner still uses (prompt_ids, max_length, end_id, request_id, turn_id), but BaseRunner.run() now expects (prompt_ids, max_length, end_id, sampling_kwargs). Update the signature and extract request_id and turn_id from sampling_kwargs where needed for the model.run() call and process_metrics_step(); a sketch follows the agent prompt below.

🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/runners/base.py` around lines 24 - 25,
SimpleRunner.run currently has the old signature; update its method signature to
match BaseRunner.run(prompt_ids, max_length, end_id, sampling_kwargs), then pull
request_id and turn_id from sampling_kwargs (e.g., request_id =
sampling_kwargs.get("request_id")) before calling model.run and before calling
process_metrics_step; ensure model.run is invoked with those extracted
request_id and turn_id where it previously used the positional args and pass
sampling_kwargs (or remaining sampling params) to the model call as appropriate
so behavior is unchanged.
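
In sketch form, the updated SimpleRunner entry point would start like this (only the extraction is shown; the key names follow the .get("request_id") pattern in the prompt above, and the elided body is assumed to keep its existing model.run() and process_metrics_step() calls unchanged):

    async def run(self, prompt_ids, max_length, end_id, sampling_kwargs):
        # Pull the identifiers that used to be positional parameters out of
        # sampling_kwargs so the rest of the body can stay unchanged.
        request_id = sampling_kwargs.get("request_id")
        turn_id = sampling_kwargs.get("turn_id")
        ...  # existing body: pass request_id/turn_id to model.run() and
             # process_metrics_step() where the old positional args were used.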

Comment on lines +24 to 32
def encode_chat(tokenizer, messages, chat_template_args={}, completions=False):
    if completions:
        return tokenizer.encode(messages[-1]["content"], add_special_tokens=False)
    return tokenizer.encode(
        tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, **chat_template_args
        ),
        add_special_tokens=False,
    )

⚠️ Potential issue | 🟠 Major

Mutable default argument chat_template_args={} — classic Python pitfall.

Using a mutable dict as a default argument means all callers share the same object. If any code path mutates it (e.g., chat_template_args["key"] = val), the mutation persists across subsequent calls.

Proposed fix
-def encode_chat(tokenizer, messages, chat_template_args={}, completions=False):
+def encode_chat(tokenizer, messages, chat_template_args=None, completions=False):
+    if chat_template_args is None:
+        chat_template_args = {}
     if completions:
         return tokenizer.encode(messages[-1]["content"], add_special_tokens=False)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def encode_chat(tokenizer, messages, chat_template_args=None, completions=False):
    if chat_template_args is None:
        chat_template_args = {}
    if completions:
        return tokenizer.encode(messages[-1]["content"], add_special_tokens=False)
    return tokenizer.encode(
        tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, **chat_template_args
        ),
        add_special_tokens=False,
    )
🤖 Prompt for AI Agents
In `@examples/specdec_bench/specdec_bench/utils.py` around lines 24-32, the
function encode_chat uses a mutable default chat_template_args={}, which can lead
to shared-state bugs; change the signature to default chat_template_args=None
and inside encode_chat set a local variable (e.g., args = {} if
chat_template_args is None else dict(chat_template_args)) so callers' dicts
aren’t mutated and you still pass args into tokenizer.apply_chat_template; keep
the completions branch behavior unchanged.
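
For anyone unfamiliar with the pitfall, a minimal standalone illustration of the sharing behavior (toy code, unrelated to the specdec_bench sources):

def add_tag(tag, tags={}):
    # "tags" is created once, at function definition time, and reused forever.
    tags[tag] = True
    return tags

print(add_tag("a"))  # {'a': True}
print(add_tag("b"))  # {'a': True, 'b': True}  <- 'a' leaked from the first call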

@codecov

codecov bot commented Feb 10, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.44%. Comparing base (5e43b2a) to head (e748f73).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #875   +/-   ##
=======================================
  Coverage   73.44%   73.44%           
=======================================
  Files         197      197           
  Lines       20657    20657           
=======================================
  Hits        15172    15172           
  Misses       5485     5485           

☔ View full report in Codecov by Sentry.

@yeyu-nvidia
Contributor

Can you add some description of what this PR is for?

@IzzyPutterman
Contributor Author

Can you add some description of what this PR is for?

Added info to the top!
