[FX] Model Extraction Transformation Fixes #3982
daniil-lyakhov wants to merge 9 commits into openvinotoolkit:develop
Conversation
Pull request overview
This PR updates the Torch FX subgraph extraction logic used by Bias Correction so that extracted subgraphs include only nodes that contribute to the requested outputs, properly handle multi-input ops (e.g., aten.cat), and support specifying extraction boundaries with explicit input/output port IDs.
Changes:
- Extend PTModelExtractionCommand to accept (node_name, port_id) pairs for subgraph inputs/outputs and update BC backends + FX extraction tests accordingly.
- Rework FX extraction traversal to walk dependencies "upwards" and avoid capturing unrelated nodes; ignore original FX placeholders in extracted graphs.
- Add helper models and reference .dot graphs to cover concat/skip-connection cases and update Bias Correction template tests.
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/torch/fx/test_model_transformer.py | Updates extraction test cases to use (node, port) IDs and adds new concat-with-input scenarios. |
| tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_conv2d_2.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_conv2d_1_add__add__1.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_add__add.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_add__1.dot | Updates extracted-graph reference output to reflect new extraction traversal behavior. |
| tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_1_add__1.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/ConvolutionWithSeveralOutputsconv2d_output.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/ConvolutionWithSeveralOutputsconv2d_conv2d_output_conv2d.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/ConvolutionWithNotTensorBiasModelconv2d_output.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/ConvolutionWithNotTensorBiasModelconv2d_conv2d_output_conv2d.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/ConvolutionWithNotTensorBiasModelconv2d_conv2d.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/ConvolutionWithAllConstantInputsModelconv2d_conv2d.dot | Updates extracted-graph reference output formatting and placeholder naming. |
| tests/torch/data/fx/extracted/ConvConcatWithLongPathToOrigInputscat_conv2d_cat.dot | Adds new extracted-graph reference for long-path original-input concat scenario. |
| tests/torch/data/fx/extracted/ConvConcatWithInputModelcat_conv2d_cat.dot | Adds new extracted-graph reference for concat-with-original-input scenario. |
| tests/onnx/quantization/test_bias_correction.py | Skips ONNX BC for the new concat-with-input model due to known extraction limitation. |
| tests/cross_fw/test_templates/test_bias_correction.py | Adds BC template test coverage for concat/add skip-connection models. |
| tests/cross_fw/test_templates/helpers.py | Adds new helper models that exercise concat/add skip connections to original inputs. |
| src/nncf/torch/graph/transformations/commands.py | Changes PT extraction command API to use (node_name, port_id) pairs. |
| src/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py | Updates backend to pass (node, port) IDs into PT extraction command. |
| src/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py | Updates backend to pass (node, port) IDs into PT extraction command. |
| src/nncf/quantization/algorithms/bias_correction/torch_fx_backend.py | Updates backend to forward full (node, port) ID sets into PT extraction command. |
| src/nncf/experimental/torch/fx/node_utils.py | Adds set_node_args and extends FX node-arg handling for aten.cat. |
| src/nncf/experimental/torch/fx/model_transformer.py | Reworks FX subgraph extraction traversal and rewiring using port IDs and cat-aware arg accessors. |
def _traverse_graph_up(
    input_nodes: list[torch.fx.Node],
    stop_nodes: set[torch.fx.Node],
    visited: set[torch.fx.Node],
) -> None:
Type hints for stop_nodes and visited are inconsistent with how they are used: the function stores in_node.name (a str) in visited and checks in_node.name in stop_nodes, but the annotations declare set[torch.fx.Node]. This breaks static type checking and makes the API unclear; annotate these as set[str] (or alternatively store actual torch.fx.Node objects consistently).
    if node.target == torch.ops.aten.cat.default:
        return node.args[0]
    return node.args
get_node_args() is annotated to return tuple[Any, ...], but for aten.cat it returns node.args[0], which is typically a list of inputs. This mismatch can cause type-checking issues and confusion for callers; consider returning a Sequence[Any] (or tuple[Any, ...] | list[Any]) and updating set_node_args() accordingly.
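The shape difference can be sketched with a hypothetical stand-in node (CAT_TARGET and FakeNode below are illustrative, not the real torch.fx API):

```python
from dataclasses import dataclass
from typing import Any, Sequence

CAT_TARGET = "aten.cat.default"  # stand-in for torch.ops.aten.cat.default


@dataclass
class FakeNode:
    target: str
    args: tuple


def get_node_args(node: FakeNode) -> Sequence[Any]:
    # For cat, the tensor inputs live in a *list* at args[0];
    # for other ops, args itself is the tuple of inputs.
    if node.target == CAT_TARGET:
        return node.args[0]
    return node.args


cat = FakeNode(CAT_TARGET, (["relu_out", "input"], 1))   # cat(tensors, dim)
conv = FakeNode("aten.conv2d.default", ("x", "w", "b"))

assert list(get_node_args(cat)) == ["relu_out", "input"]  # list, not tuple
assert get_node_args(conv) == ("x", "w", "b")
```

Annotating the return type as Sequence[Any] covers both branches without lying to the type checker.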
args = list(node_with_input.args)
args[0] = graph_input
node_with_input.args = tuple(args)
args = list(get_node_args(node_with_input))
When rewiring extracted-graph inputs, args[input_port_id] = graph_input can also fail with IndexError if the chosen input_port_id doesn't align with the operator's argument structure (notably for ops with non-tensor args). Please add bounds checking and an informative error (including node_with_input.target/op and len(args)) before assignment.
args = list(get_node_args(node_with_input))
if not (0 <= input_port_id < len(args)):
    raise IndexError(
        f"Failed to rewire extracted-graph input for node "
        f"'{node_with_input.target}' (op='{node_with_input.op}'): "
        f"input_port_id={input_port_id} is out of bounds for args "
        f"of length {len(args)}."
    )
The IndexError is expected behavior in that case.
:param input_ids: List of the input IDs: pairs of node names and correspondent input port ids.
    Each pair denotes the sub-graph beginning.
:param output_ids: List of the output IDs: pairs of node names and correspondent output port ids.
Docstring typo: "correspondent" should be "corresponding" (used twice in the input_ids/output_ids parameter descriptions).
:param input_ids: List of the input IDs: pairs of node names and corresponding input port ids.
    Each pair denotes the sub-graph beginning.
:param output_ids: List of the output IDs: pairs of node names and corresponding output port ids.
transformation = transformations[-1]
input_node_names = [name for name, _ in transformation.input_ids]
output_node_names = [name for name, _ in transformation.output_ids]
stop_nodes = set(input_node_names + output_node_names)
PTModelExtractionCommand.output_ids now includes (node_name, output_port_id), but _apply_model_extraction() ignores output_port_id entirely (it only extracts output_node_names). This can return an incorrect output when the requested output is a specific port of a multi-output FX node (e.g., split/max/min/topk where node.meta['val'] is a tuple/list and NNCF assigns distinct output_port_ids). Please use output_port_id when forming nodes_with_output (e.g., select the right element from the FX output node’s tuple output, and/or add/select a getitem for multi-output nodes) so the extracted subgraph returns the exact tensor for the requested output port.
All Torch FX nodes have only one output, at index zero, so this is not relevant for this backend.
    input_ids: set[tuple[str, int]], output_ids: set[tuple[str, int]]
) -> PTModelExtractionCommand:
    return PTModelExtractionCommand(list(input_ids), list(output_ids))
model_extraction_command() receives input_ids/output_ids as set[...], but list(input_ids) / list(output_ids) makes the extraction order non-deterministic (set iteration order). This can lead to unstable placeholder insertion / output tuple ordering and flaky behavior. Consider sorting (e.g., by (node_name, port_id)) before constructing PTModelExtractionCommand to keep extracted graphs reproducible.
sorted_input_ids = sorted(input_ids, key=lambda item: (item[0], item[1]))
sorted_output_ids = sorted(output_ids, key=lambda item: (item[0], item[1]))
return PTModelExtractionCommand(sorted_input_ids, sorted_output_ids)
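The determinism point can be illustrated with plain tuples; Python sorts them lexicographically, so the explicit key is equivalent to the default ordering:

```python
# Set iteration order is not guaranteed to be stable across runs/processes,
# so sorting the (node_name, port_id) pairs keeps extraction reproducible.
input_ids = {("conv2d", 1), ("add_", 1), ("add_", 0)}
ordered = sorted(input_ids)  # same as key=lambda item: (item[0], item[1])
assert ordered == [("add_", 0), ("add_", 1), ("conv2d", 1)]
```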
        return self.conv_6(x)


class ConvConcatWithInputModel(nn.Module):
The name is incorrect: it does not reflect what the model does.
    outside the subgraph boundary.

    mode="cat":
    input ──→ conv1 ──→ relu ──→ cat([relu_out, input], dim=1) ──→ conv2
Schema and docstring do not match this class.
Just out of interest, why does the output node point downward whereas the others point to the right?
    mode="cat":
    input ──→ conv1 ──→ relu ──→ cat([relu_out, input], dim=1) ──→ conv2
                                   ↑                                 |
    input ─────────────────────────┘                                 ↓
How about defining it like this, without creating an unnecessary hierarchy:
def build_conv(in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
    conv = create_conv(in_channels, out_channels, kernel_size)
    with set_torch_seed():
        conv.weight.data = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        conv.bias.data = torch.randn([out_channels])
    return conv


class ConcatWithInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(4, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = torch.relu(x_1)
        x_combined = torch.cat([x_1, x], dim=1)
        return self.conv_2(x_combined)


class ConcatWithReluInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(4, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = F.relu(x_1)
        x_combined = torch.cat([x_1, F.relu(x)], dim=1)
        return self.conv_2(x_combined)


class AddWithInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(2, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = F.relu(x_1)
        x_combined = x_1 + x
        return self.conv_2(x_combined)


class AddWithReluInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(2, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = F.relu(x_1)
        x_combined = x_1 + torch.relu(x)
        return self.conv_2(x_combined)

    MultiBranchesConnectedModel, (1, 3, 3, 3), PTModelExtractionCommand([("conv2d", 0)], [("add__1", 0)])
),
ModelExtractionTestCase(
    ConvConcatWithInputModel(mode="cat"),
Don't create an instance of the model as a parameter:
- the instance is created at test-case collection time and stays alive the whole time
- it requires an extra, unnecessary condition
- if you need to parametrize object initialization, you can use functools.partial

Also, ModelExtractionTestCase got an incorrect annotation.
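The functools.partial idea can be sketched with a hypothetical model class (DummyModel below is illustrative, not one of the test models):

```python
import functools


class DummyModel:
    def __init__(self, mode: str = "cat"):
        self.mode = mode


# Store a factory in the test-case table instead of a live instance;
# the model is only constructed when the test body calls the factory.
factory = functools.partial(DummyModel, mode="cat")
model = factory()
assert isinstance(model, DummyModel) and model.mode == "cat"
```

This keeps test parametrization declarative and avoids keeping model instances alive for the whole test session.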
ConvolutionWithNotTensorBiasModel,
(1, 1, 3, 3),
PTModelExtractionCommand([("conv2d", 0)], [("output", 0)]),
Are there no tests with port_id=1?
Changes
Before:
Reason for changes
Related tickets
180772
Tests
- NNCF/job/manual/job/post_training_quantization/835/ - green