
Integrate Automated QDQ autotuner - part 3.2 #838

Open
willg-nv wants to merge 5 commits into NVIDIA:main from willg-nv:dev-willg-integrate-auto-qdq-placement-part3.2

Conversation

@willg-nv
Contributor

@willg-nv willg-nv commented Feb 2, 2026

What does this PR do?

This PR implements the QDQAutotuner class, which drives the main autotuner workflow.

The workflow is (see the sketch below):

  1. use RegionSearch to build regions
  2. generate QDQ ONNX models and evaluate their performance
  3. save the best model
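
A minimal driver sketch of that loop (hypothetical usage only: the constructor argument, the regions attribute, and measure_latency_ms are assumptions for illustration; the method names follow the walkthrough further down):

    from modelopt.onnx.quantization.autotune import Config, QDQAutotuner

    def measure_latency_ms() -> float:
        # Placeholder: benchmark the candidate model produced by generate()
        return 0.0

    autotuner = QDQAutotuner("model.onnx")   # hypothetical constructor argument
    autotuner.initialize(Config())           # load model and initialize state

    for region in autotuner.regions:         # regions built by CombinedRegionSearch
        autotuner.set_profile_region(region)
        for _ in range(16):                  # placeholder candidate budget per region
            autotuner.generate()             # build a candidate Q/DQ insertion scheme
            autotuner.submit(measure_latency_ms())  # record the measured latency

    onnx_bytes = autotuner.export_onnx(best=True)   # export the best-performing model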

This PR is part 2/4 of #703.

PR 3.1: #837
PR 3.2: #838
PR 3.3: #839

Overview: ?

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Not in this part.
  • Did you add or update any necessary documentation?: No, documentation will be updated in part 4.
  • Did you update Changelog?: No, the changelog will be updated when all changes are ready.

Additional Information

Summary by CodeRabbit

  • New Features
    • Introduced ONNX Q/DQ autotuning framework with automatic region discovery and pattern-based optimization.
    • Added model profiling and quantization scheme generation capabilities.
    • Enabled state persistence and quantization model export functionality.
    • Introduced configuration management for quantization parameters and profiling workflows.


@willg-nv willg-nv requested a review from a team as a code owner February 2, 2026 02:58
@willg-nv willg-nv requested a review from ajrasane February 2, 2026 02:58
@copy-pr-bot

copy-pr-bot bot commented Feb 2, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 2, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Introduces a new ONNX quantization autotuning module that enables automatic Q/DQ (Quantize/Dequantize) node insertion and optimization using pattern-based region analysis. Provides a comprehensive framework for discovering optimal insertion points, profiling schemes, and exporting quantized models.

Changes

Module Initialization: modelopt/onnx/quantization/autotune/__init__.py
Exposes public API surface: QDQAutotuner class, configuration/exception types (Config, InsertionScheme, PatternSchemes, RegionType), insertion point abstractions, and utility classes (PatternCache, Region, RegionPattern, CombinedRegionSearch).

Core Autotuner Implementation: modelopt/onnx/quantization/autotune/autotuner.py
Implements QDQAutotunerBase and QDQAutotuner with region discovery, pattern-based Q/DQ insertion logic, profiling workflow, state management, graph mutation, insertion point resolution, and ONNX export capabilities. Supports scheme generation, convergence tracking, FP8 conversion, and pattern cache integration.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Autotuner as QDQAutotuner
    participant RegionSearch as CombinedRegionSearch
    participant Profiler as Profiling System
    participant Inserter as Q/DQ Insertion
    participant Exporter as ONNX Exporter

    User->>Autotuner: initialize(config, pattern_cache)
    Autotuner->>Autotuner: Load model & init state

    User->>RegionSearch: discover regions
    RegionSearch-->>Autotuner: return regions

    loop For each region
        User->>Autotuner: set_profile_region(region)
        Autotuner->>Autotuner: Commit profiling outcomes
        Autotuner->>Profiler: Prepare region-pattern pairs
        
        loop Generate candidates
            User->>Autotuner: generate()
            Autotuner->>Inserter: Build insertion scheme
            Inserter->>Inserter: Insert Q/DQ nodes
            User->>Autotuner: submit(latency_ms)
            Autotuner->>Autotuner: Track performance metrics
        end
    end

    User->>Autotuner: export_onnx(best=True)
    Autotuner->>Inserter: Apply best scheme
    Inserter->>Exporter: Finalize Q/DQ graph
    Exporter-->>User: return quantized ONNX bytes

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check (✅ Passed): The title 'Integrate Automated QDQ autotuner - part 3.2' accurately describes the PR's main objective: integrating the QDQAutotuner class implementation into the codebase as part 3.2 of a larger feature series.
  • Docstring Coverage (✅ Passed): No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.



Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@modelopt/onnx/quantization/autotune/autotuner.py`:
- Around line 1024-1029: The try/except around graph.cleanup().toposort()
swallows all exceptions (except Exception as e) and merely logs a warning, which
can hide serious graph corruption; update the handler in autotuner.py to either
catch only expected exception types (e.g., specific cleanup/toposort exceptions)
or log the error and re-raise it so execution stops on unexpected failures —
locate the graph.cleanup().toposort() call and replace the broad except with
either a narrowed except for known recoverable exceptions or add a raise after
logger.warning/failure log to propagate the error.
- Line 622: Remove the redundant local import "from datetime import datetime"
(the one added at line with the single import statement) in autotuner.py; the
module already imports datetime at the top of the file, so delete this local
import to avoid duplication and potential shadowing (look for the statement
"from datetime import datetime" inside the function or block and remove it).
- Around line 912-918: The zero-point arrays q_zp_values (and the corresponding
dq_zp_values) are created with a hardcoded dtype np.int8 which can mismatch the
QuantizeLinear/DequantizeLinear output type when quant_type is "uint8" or other
types; update their construction to use the same dtype as the computed
quant_dtype instead of np.int8 so q_zp_values and dq_zp_values match the
quantized output element type used when building q_inputs and dq_inputs (refer
to q_scale_values, q_zp_values, q_inputs and the corresponding dq_* variables to
locate where to change the dtype).
- Around line 1013-1021: The import of get_tensor_consumer_node_indices is wrong
and causes an import error; replace that import with get_tensor_consumer_nodes
and update any usage names accordingly (the code that uses tensor_users_map
already expects a defaultdict(list) so no KeyError handling is needed).
Specifically, change the symbol imported from
modelopt.onnx.quantization.graph_utils from get_tensor_consumer_node_indices to
get_tensor_consumer_nodes and ensure tensor_users_map is assigned from
get_tensor_consumer_nodes(...) where used in the autotuner (references:
get_tensor_consumer_node_indices, get_tensor_consumer_nodes, tensor_users_map).
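
For illustration, the zero-point dtype fix above boils down to something like this minimal sketch (helper and variable names reuse those in the prompt; the surrounding scale and input construction is omitted):

    import numpy as np

    def make_zero_points(quant_dtype: np.dtype) -> tuple[np.ndarray, np.ndarray]:
        # Hypothetical helper: zero points follow the resolved quantization dtype
        # (np.int8, np.uint8, ...) instead of being hardcoded to np.int8.
        q_zp_values = np.zeros(1, dtype=quant_dtype)
        dq_zp_values = np.zeros(1, dtype=quant_dtype)
        return q_zp_values, dq_zp_values
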
🧹 Nitpick comments (4)
modelopt/onnx/quantization/autotune/autotuner.py (4)

229-229: Consider defining config attributes explicitly.

Using getattr(self.config, "maximum_generation_attempts", 100) with defaults (also seen at lines 718-719 and 744) suggests these attributes may not be formally defined on the Config class. This pattern makes it harder to discover available configuration options.

💡 Suggestion

Consider adding these attributes to the Config class with documented defaults rather than relying on getattr fallbacks:

# In Config class
maximum_generation_attempts: int = 100
top_percent_to_mutate: float = 0.1
minimum_schemes_to_mutate: int = 1
maximum_mutations: int = 3

333-335: Replace assertions with explicit checks for runtime validation.

Assertions on lines 333-335 (and similarly at line 314) are used for validating runtime conditions. Since assertions can be disabled with python -O, these should be explicit checks for production code.

🛡️ Proposed fix
-                full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph)
-                assert full_insertion_scheme is not None
-                all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme)
-                assert isinstance(all_region_ips, set)
+                full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph)
+                if full_insertion_scheme is None:
+                    logger.warning(f"Failed to get full insertion scheme for region {region.id}")
+                    continue
+                all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme)
+                if not isinstance(all_region_ips, set):
+                    raise TypeError(f"Expected set from pattern.matches, got {type(all_region_ips)}")

972-985: Assertions used for critical runtime validation.

These assertions validate critical invariants (node index bounds, input index bounds, tensor name matching) but can be disabled with python -O. Consider using explicit checks with ValueError/IndexError for production safety.

🛡️ Proposed fix
             if node_index is not None:
-                assert node_index < len(graph.nodes), "Node index out of range"
+                if node_index >= len(graph.nodes):
+                    raise IndexError(f"Node index {node_index} out of range (max: {len(graph.nodes) - 1})")
                 target_node = graph.nodes[node_index]
-                assert input_index is not None, "Input index must be set when node index is set"
-                assert input_index < len(target_node.inputs), (
-                    f"Input index out of range for node {target_node.name}"
-                )
+                if input_index is None:
+                    raise ValueError("Input index must be set when node index is set")
+                if input_index >= len(target_node.inputs):
+                    raise IndexError(f"Input index {input_index} out of range for node {target_node.name}")
                 original_tensor = target_node.inputs[input_index]
-                assert tensor_name == original_tensor.name, (
-                    f"Tensor name mismatch for node {target_node.name} input {input_index}"
-                )
+                if tensor_name != original_tensor.name:
+                    raise ValueError(f"Tensor name mismatch: expected '{tensor_name}', got '{original_tensor.name}'")
             else:
-                assert tensor_name in tensor_map, f"Tensor {tensor_name} not found in tensor map"
-                assert input_index is None, "Input index must be None when node index is None"
+                if tensor_name not in tensor_map:
+                    raise KeyError(f"Tensor {tensor_name} not found in tensor map")
+                if input_index is not None:
+                    raise ValueError("Input index must be None when node index is None")

1042-1049: Consider iterative approach for deep region hierarchies.

_visit_region_recursively uses recursion which could hit Python's stack limit for very deep region hierarchies. While this is unlikely for typical ONNX models, an iterative approach would be more robust.

♻️ Iterative alternative
def _visit_region_recursively(self, region: Region) -> list[Region]:
    """Iteratively traverse region hierarchy and collect all regions."""
    regions = []
    stack = [region]
    while stack:
        current = stack.pop()
        regions.append(current)
        stack.extend(current.get_children())
    return regions

Comment on lines +1013 to +1021
        else:
            users = tensor_users_map[tensor_name]
            for user_index in users:
                user_node = graph.nodes[user_index]
                for i, input_tensor in enumerate(user_node.inputs):
                    if hasattr(input_tensor, "name") and input_tensor.name == tensor_name:
                        user_node.inputs[i] = dq_node.outputs[0]
                        break

            logger.debug(f" Q/DQ inserted: tensor '{tensor_name}' → {len(users)} users")
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's read the file around line 1014 to understand the context
sed -n '990,1030p' modelopt/onnx/quantization/autotune/autotuner.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2033


🏁 Script executed:

# Look for how tensor_users_map is created and populated
rg -n -B 5 -A 10 'tensor_users_map' modelopt/onnx/quantization/autotune/autotuner.py | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 2593


🏁 Script executed:

# Search for get_tensor_consumer_node_indices to understand the mapping logic
rg -n -A 20 'def get_tensor_consumer_node_indices' modelopt/onnx/quantization/autotune/autotuner.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for get_tensor_consumer_node_indices in the entire codebase
rg -n 'def get_tensor_consumer_node_indices' --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Also look for the import of get_tensor_consumer_node_indices in autotuner.py
rg -n 'get_tensor_consumer_node_indices|from.*import' modelopt/onnx/quantization/autotune/autotuner.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 832


🏁 Script executed:

# Find the definition of get_tensor_consumer_node_indices
rg -n -A 20 'def get_tensor_consumer_node_indices' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look for the graph_utils file
fd 'graph_utils.py' --type f

Repository: NVIDIA/Model-Optimizer

Length of output: 107


🏁 Script executed:

# Now read the function definition from graph_utils.py
sed -n '1,50p' modelopt/onnx/quantization/graph_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2180


🏁 Script executed:

# Search for get_tensor_consumer_node_indices in graph_utils.py
rg -n -A 25 'def get_tensor_consumer_node_indices' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for get_tensor_consumer_node_indices without the specific line number restriction
rg 'def get_tensor_consumer_node_indices' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Let's get the line number and then read that section
rg -n 'get_tensor_consumer_node_indices' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for get_tensor_consumer_node_indices in the entire codebase
rg -l 'def get_tensor_consumer_node_indices' --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Also search the common.py file since it was imported
rg -n -A 25 'def get_tensor_consumer_node_indices' modelopt/onnx/quantization/autotune/common.py

Repository: NVIDIA/Model-Optimizer

Length of output: 151


🏁 Script executed:

# Let's find the exact line number and read that section of graph_utils.py
wc -l modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 112


🏁 Script executed:

# Search for the function with grep to get line numbers
grep -n 'def get_tensor_consumer_node_indices' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# List the imports in autotuner.py to see what is imported from common
sed -n '30,45p' modelopt/onnx/quantization/autotune/autotuner.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 687


🏁 Script executed:

# Search for get_tensor_consumer_node_indices more broadly with a case-insensitive search
grep -i 'tensor_consumer_node_indices' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Let's try with rg using a different pattern
rg 'tensor_consumer_node_indices' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Maybe it's defined with a different name, let's search for "consumer" in the file
rg -n 'consumer' modelopt/onnx/quantization/graph_utils.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1489


🏁 Script executed:

# Search for get_tensor_consumer_node_indices across entire repo
rg 'def get_tensor_consumer_node_indices' --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Let's check if it's an alias or imported from somewhere else
rg 'get_tensor_consumer_node_indices' modelopt/ --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 446


🏁 Script executed:

# Let's search more carefully in graph_utils.py - maybe it's a wrapper or has different indentation
grep -n 'consumer_node_indices\|def.*consumer' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 145


🏁 Script executed:

# Let's look at the entire graph_utils.py file to find all function definitions
rg -n '^def ' modelopt/onnx/quantization/graph_utils.py | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 1610


🏁 Script executed:

# Maybe the function is aliased or it's in a different form
grep -n 'tensor_users_map\|def.*tensor' modelopt/onnx/quantization/graph_utils.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 603


🏁 Script executed:

# Read the get_tensor_consumer_nodes function to understand what it returns
sed -n '283,310p' modelopt/onnx/quantization/graph_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1130


🏁 Script executed:

# Let's also check if there's an __all__ export or if the function is imported from somewhere else in graph_utils
head -50 modelopt/onnx/quantization/graph_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2180


🏁 Script executed:

# Search for any assignment or alias that could be get_tensor_consumer_node_indices
rg 'get_tensor_consumer_node_indices\s*=' modelopt/onnx/quantization/graph_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if get_tensor_consumer_node_indices is imported from gs (onnx_graphsurgeon)
rg -n 'from.*import.*get_tensor_consumer_node_indices' modelopt/

Repository: NVIDIA/Model-Optimizer

Length of output: 201


🏁 Script executed:

# Let's check the end of graph_utils.py to see if there's any aliasing
tail -50 modelopt/onnx/quantization/graph_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2434


🏁 Script executed:

# Search in the middle sections of graph_utils.py
sed -n '1450,1550p' modelopt/onnx/quantization/graph_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 4500


🏁 Script executed:

# Let's check if get_tensor_consumer_node_indices exists anywhere in the repo
find . -name "*.py" -type f -exec grep -l "def get_tensor_consumer_node_indices" {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check the exact import line and what it should be importing
sed -n '40,50p' modelopt/onnx/quantization/autotune/autotuner.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 601


🏁 Script executed:

# Let's also check if the function exists by searching more broadly
rg 'get_tensor_consumer' --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1271


🏁 Script executed:

# Check if there's any definition of get_tensor_consumer_node_indices anywhere
rg 'def get_tensor_consumer_node_indices|get_tensor_consumer_node_indices\s*=' --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Let's look at what get_tensor_consumer_nodes returns and understand the data structure
sed -n '283,302p' modelopt/onnx/quantization/graph_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 842


🏁 Script executed:

# Let's check the line in autotuner.py where tensor_users_map is used after assignment
sed -n '960,975p' modelopt/onnx/quantization/autotune/autotuner.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 967


🏁 Script executed:

# The import says get_tensor_consumer_node_indices but only get_tensor_consumer_nodes exists
# Let's search for any wrapper or check if it might be defined somewhere we haven't looked
find . -name "*.py" -type f -exec grep -l "get_tensor_consumer_node_indices" {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 116


🏁 Script executed:

# Let's check if there's a conftest or __init__ that might have this function
find . -name "__init__.py" -exec grep -l "get_tensor_consumer_node_indices" {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Let's check the graph_utils.py file around where get_tensor_consumer_nodes is defined for any related functions
sed -n '270,350p' modelopt/onnx/quantization/graph_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 3189


🏁 Script executed:

# Let's check the graph object used in autotuner.py - it might be using onnx_graphsurgeon which has different API
sed -n '50,70p' modelopt/onnx/quantization/autotune/autotuner.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1204


🏁 Script executed:

# Let's look at what Graph object is being used
rg -n 'from.*import.*Graph|import.*Graph' modelopt/onnx/quantization/autotune/autotuner.py | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check the actual usage context - what graph object is passed to get_tensor_consumer_node_indices
sed -n '950,965p' modelopt/onnx/quantization/autotune/autotuner.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 796


Import error: get_tensor_consumer_node_indices does not exist.

The function get_tensor_consumer_node_indices imported at line 47 is not defined in modelopt.onnx.quantization.graph_utils. The available function is get_tensor_consumer_nodes, which returns a defaultdict(list).

If get_tensor_consumer_nodes is the intended function, the KeyError concern is invalid—defaultdict returns an empty list for missing keys, so line 1014 would not raise an exception. However, the import itself must be corrected to resolve the module-level import error.

🤖 Prompt for AI Agents
In `@modelopt/onnx/quantization/autotune/autotuner.py` around lines 1013 - 1021,
The import of get_tensor_consumer_node_indices is wrong and causes an import
error; replace that import with get_tensor_consumer_nodes and update any usage
names accordingly (the code that uses tensor_users_map already expects a
defaultdict(list) so no KeyError handling is needed). Specifically, change the
symbol imported from modelopt.onnx.quantization.graph_utils from
get_tensor_consumer_node_indices to get_tensor_consumer_nodes and ensure
tensor_users_map is assigned from get_tensor_consumer_nodes(...) where used in
the autotuner (references: get_tensor_consumer_node_indices,
get_tensor_consumer_nodes, tensor_users_map).

@willg-nv willg-nv force-pushed the dev-willg-integrate-auto-qdq-placement-part3.2 branch 2 times, most recently from b5032ed to 1ffcf7f on February 3, 2026 01:55
Signed-off-by: Will Guo <willg@nvidia.com>
Signed-off-by: Will Guo <willg@nvidia.com>
@willg-nv willg-nv force-pushed the dev-willg-integrate-auto-qdq-placement-part3.2 branch from 1ffcf7f to bd18dfa on February 9, 2026 08:36
Signed-off-by: Will Guo <willg@nvidia.com>
Signed-off-by: Will Guo <willg@nvidia.com>
Signed-off-by: Will Guo <willg@nvidia.com>
@ajrasane
Copy link
Contributor

/ok to test b02fef1

Comment on lines +158 to +161
        if not self.initialized:
            raise AutotunerNotInitializedError(
                "QDQAutotunerBase not initialized. Call initialize() first."
            )
Contributor

You can create this:

def _requires_init(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.initialized:
            raise AutotunerNotInitializedError(
                f"{type(self).__name__} not initialized. Call initialize() first."
            )
        return method(self, *args, **kwargs)
    return wrapper

And call it before the respective methods.
This can be replaced in 5 other instances.
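
Usage would then look like this (generate is only an example method; functools must be imported at module level for the decorator itself):

    class QDQAutotunerBase:
        @_requires_init
        def generate(self):
            ...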

Comment on lines +908 to +932
        scheme = InsertionScheme()
        base_node_points = {(p.node_index, p.input_index) for p in base_scheme.node_inputs}
        scheme.node_inputs = self._mutate_insertion_points(
            base_node_points, full_insertion_scheme.node_inputs, "node input points", max_mutations
        )

        base_region_composite_points = {
            (p.region_index, p.input_index) for p in base_scheme.child_region_inputs
        }
        scheme.child_region_inputs = self._mutate_insertion_points(
            base_region_composite_points,
            full_insertion_scheme.child_region_inputs,
            "region composite points",
            max_mutations,
        )

        base_region_output_points = {
            (p.region_index, p.node_index, p.output_index) for p in base_scheme.region_outputs
        }
        scheme.region_outputs = self._mutate_insertion_points(
            base_region_output_points,
            full_insertion_scheme.region_outputs,
            "region output points",
            max_mutations,
        )
Contributor

@ajrasane ajrasane Feb 12, 2026

Suggested change
_MUTATION_SPECS = [
    ("node_inputs", "node input points", lambda p: (p.node_index, p.input_index)),
    ("child_region_inputs", "region composite points", lambda p: (p.region_index, p.input_index)),
    ("region_outputs", "region output points", lambda p: (p.region_index, p.node_index, p.output_index)),
]

scheme = InsertionScheme()
for attr, label, key_fn in _MUTATION_SPECS:
    base_keys = {key_fn(p) for p in getattr(base_scheme, attr)}
    setattr(scheme, attr, self._mutate_insertion_points(
        base_keys, getattr(full_insertion_scheme, attr), label, max_mutations
    ))

With this, you can also remove the key_fn inside _mutate_insertion_points

Comment on lines +366 to +429
        for region in self.regions:
            pattern = RegionPattern.from_region(region, self.graph)
            logger.debug(f"Region {region.id} (level {region.level})")
            logger.debug(f" → Pattern signature: {pattern.signature}")

            matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None)
            current_scheme = matched.best_scheme if matched else None

            if matched:
                if current_scheme:
                    logger.debug(
                        f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)"
                    )
                else:
                    logger.debug(" → Matched profiled pattern but no valid schemes")

            if current_scheme is None:
                current_scheme = self.current_profile_pattern_schemes
                if current_scheme is None or pattern != current_scheme.pattern:
                    pass
                elif best:
                    current_scheme = current_scheme.best_scheme
                else:
                    scheme_index = self.current_insertion_scheme_index
                    if scheme_index is not None:
                        assert scheme_index < len(current_scheme.schemes), (
                            f"Invalid scheme index: {scheme_index}"
                        )
                        current_scheme = current_scheme.schemes[scheme_index]
                        logger.debug(f" → Using current pattern scheme #{scheme_index}")

            if current_scheme is None and self.pattern_cache is not None:
                pattern_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature)
                if pattern_schemes is not None:
                    schemes = pattern_schemes.schemes
                    if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled:
                        current_scheme = schemes[0]
                        logger.debug(" → Using imported pattern from cache")

            if current_scheme is None:
                logger.debug(" → No scheme available, skipping")
                continue

            full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph)
            assert full_insertion_scheme is not None
            all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme)
            assert isinstance(all_region_ips, set)
            resolved_insertion_points.difference_update(all_region_ips)
            excluded_tensors = all_region_ips - resolved_insertion_points
            if excluded_tensors:
                logger.debug(
                    f" → Excluded {len(excluded_tensors)} overlapping insertion points"
                )

            new_ips = pattern.matches(region, self.graph, current_scheme)
            if new_ips:
                resolved_insertion_points.update(new_ips)
                matched_regions += 1
                logger.debug(f" → Added {len(new_ips)} insertion points")

        logger.debug(
            f"Matched {matched_regions}/{len(self.regions)} regions, "
            f"total {len(resolved_insertion_points)} unique insertion points"
        )
Contributor

Suggested change
    def _resolve_scheme_for_region(self, region, pattern, best: bool) -> InsertionScheme | None:
        """Find the best applicable scheme for a region."""
        # 1. Try profiled patterns
        matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None)
        if matched and matched.best_scheme:
            return matched.best_scheme
        # 2. Try current profile pattern
        if self.current_profile_pattern_schemes and pattern == self.current_profile_pattern_schemes.pattern:
            if best:
                return self.current_profile_pattern_schemes.best_scheme
            idx = self.current_insertion_scheme_index
            if idx is not None:
                return self.current_profile_pattern_schemes.schemes[idx]
        # 3. Try pattern cache
        if self.pattern_cache:
            ps = self.pattern_cache.get_pattern_schemes(pattern.signature)
            if ps and len(ps.schemes) == 1 and not ps.schemes[0].is_profiled:
                return ps.schemes[0]
        return None

        new_graph.toposort()
        return new_graph

    def _get_quant_dtype(self, quant_type: str) -> np.dtype:
Contributor

Could you unify this and _get_dq_output_dtype as:

_DTYPE_MAP = {
    "int8": np.int8, "uint8": np.uint8,
    "float16": np.float16, "float32": np.float32,
}

def _resolve_dtype(self, dtype_str: str, default=np.int8) -> np.dtype:
    if dtype_str == "fp8":
        return getattr(np, "float8_e4m3fn", np.uint8)
    if hasattr(np, "bfloat16") and dtype_str == "bfloat16":
        return np.bfloat16
    return self._DTYPE_MAP.get(dtype_str, default)
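
Call sites would then collapse to something like the following sketch (output_dtype and the np.float32 default are assumptions for illustration):

    quant_dtype = self._resolve_dtype(quant_type)                       # replaces _get_quant_dtype
    dq_dtype = self._resolve_dtype(output_dtype, default=np.float32)    # replaces _get_dq_output_dtype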

        dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_")
        # Determine scale dtype from output_dtype (fp16/tf32/fp32)
        # Scale should match the precision of the original I/O tensor
        dtype_map = {"float16": np.float16, "float32": np.float32}
Contributor

Can we move this to module/class level?
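
For example, a module-level constant (name hypothetical) would avoid rebuilding the dict on every insertion:

    import numpy as np

    _SCALE_DTYPE_MAP = {"float16": np.float16, "float32": np.float32}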

Comment on lines +982 to +1004
    def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]:
        """Build mapping from tensor names to tensor objects."""
        tensor_map = {}

        for node in graph.nodes:
            for output in node.outputs:
                if hasattr(output, "name") and output.name:
                    tensor_map[output.name] = output

        for input_tensor in graph.inputs:
            if hasattr(input_tensor, "name") and input_tensor.name:
                tensor_map[input_tensor.name] = input_tensor

        for node in graph.nodes:
            for input_tensor in node.inputs:
                if (
                    isinstance(input_tensor, gs.Constant)
                    and hasattr(input_tensor, "name")
                    and input_tensor.name
                ):
                    tensor_map[input_tensor.name] = input_tensor

        return tensor_map
Contributor

def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]:
    tensor_map = {t.name: t for t in graph.inputs if hasattr(t, "name") and t.name}
    for node in graph.nodes:
        for t in node.outputs:
            if hasattr(t, "name") and t.name:
                tensor_map[t.name] = t
        for t in node.inputs:
            if isinstance(t, gs.Constant) and hasattr(t, "name") and t.name:
                tensor_map[t.name] = t
    return tensor_map

With this, we iterate over the graph's nodes only once.

Comment on lines +783 to +792
    def _is_region_profiled(self, region: Region) -> bool:
        """Check if a region's pattern has already been fully profiled."""

        def match_pattern(pattern: PatternSchemes, region: Region) -> bool:
            """Check if a pattern matches a region."""
            if pattern.pattern is None or not pattern.pattern.matches(region, self.graph):
                return False
            return not any(not scheme.is_profiled for scheme in pattern.schemes)

        return any(match_pattern(pattern, region) for pattern in self.profiled_patterns)
Contributor

Suggested change
    def _is_region_profiled(self, region: Region) -> bool:
        return any(
            p.pattern is not None
            and p.pattern.matches(region, self.graph)
            and all(s.is_profiled for s in p.schemes)
            for p in self.profiled_patterns
        )


        self.initialized = True

    def set_profile_region(self, region: Region | None, commit: bool = True) -> None:
Contributor

Can we split this into _commit_current_pattern() and _seed_from_cache(pattern), and keep set_profile_region as the orchestrator?
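
A rough sketch of that split (helper names as suggested above; bodies elided since the current implementation is not shown here):

    def set_profile_region(self, region: Region | None, commit: bool = True) -> None:
        """Orchestrate switching the profiling region."""
        if commit:
            self._commit_current_pattern()   # fold outcomes of the previous region into profiled_patterns
        if region is None:
            return
        pattern = RegionPattern.from_region(region, self.graph)
        self._seed_from_cache(pattern)       # pre-populate candidate schemes from the pattern cache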

        super().initialize(config, pattern_cache)
        self._search_regions()

    def _visit_region_recursively(self, region: Region) -> list[Region]:
Contributor

Could you make this a static method?
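
For example (using the get_children() accessor from the iterative variant suggested earlier):

    @staticmethod
    def _visit_region_recursively(region: Region) -> list[Region]:
        """Collect a region and all of its descendants; no instance state is needed."""
        regions = [region]
        for child in region.get_children():
            regions.extend(QDQAutotuner._visit_region_recursively(child))
        return regions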

Comment on lines +1285 to +1315
        search = CombinedRegionSearch(
            self.graph,
            maximum_sequence_region_size=self.config.maximum_sequence_region_size,
            minimum_topdown_search_size=self.config.minimum_topdown_search_size,
        )
        self.regions = search.search_regions()

        self._reassign_region_ids(self.regions)
        logger.debug(f"Found {len(self.regions)} top-level regions")

        all_regions = []
        for region in self.regions:
            all_regions.extend(self._visit_region_recursively(region))

        logger.debug(f"Flattened hierarchy to {len(all_regions)} total regions")

        leaf_regions = [region for region in all_regions if region.type == RegionType.LEAF]
        other_regions = [region for region in all_regions if region.type != RegionType.LEAF]

        all_regions = leaf_regions + other_regions
        self.regions = all_regions

        num_leaf = sum(1 for r in self.regions if r.type == RegionType.LEAF)
        num_composite = sum(1 for r in self.regions if r.type == RegionType.COMPOSITE)
        num_root = sum(1 for r in self.regions if r.type == RegionType.ROOT)

        logger.info(
            f"Discovery complete: {len(self.regions)} regions "
            f"({num_leaf} LEAF, {num_composite} COMPOSITE, {num_root} ROOT)"
        )
        logger.debug("Regions prioritized: LEAF regions first for profiling")
Contributor

Can we simplify this to:

        self.regions = search.search_regions()
        self._reassign_region_ids(self.regions)
        logger.debug(f"Found {len(self.regions)} top-level regions")

        # Flatten hierarchy: collect all regions (root + descendants)
        all_regions = []
        for region in self.regions:
            all_regions.extend(self._visit_region_recursively(region))

        # Stable sort: LEAF first, everything else after
        all_regions.sort(key=lambda r: r.type != RegionType.LEAF)
        self.regions = all_regions

        # Count by type using Counter for a single pass
        from collections import Counter
        type_counts = Counter(r.type for r in self.regions)

        logger.info(
            f"Discovery complete: {len(self.regions)} regions "
            f"({type_counts[RegionType.LEAF]} LEAF, "
            f"{type_counts[RegionType.COMPOSITE]} COMPOSITE, "
            f"{type_counts[RegionType.ROOT]} ROOT)"
        )
        logger.debug("Regions prioritized: LEAF regions first for profiling")
