
[FX] Model Extraction Transformation Fixes #3982

Open
daniil-lyakhov wants to merge 9 commits into openvinotoolkit:develop from daniil-lyakhov:dl/fx/yolo26_bc_fix

Conversation

Collaborator

@daniil-lyakhov daniil-lyakhov commented Mar 13, 2026

Changes

  • TorchFX model extractor logic is updated to prevent capture of nodes that do not affect the target outputs (the extractor traverses the graph only upward from the start nodes and does not descend to node users).
  • Input and output port IDs were introduced in the PT extraction command.
  • Previously, extraction failed if the subgraph contained an original input node; now the original inputs are ignored.
  • get_attribute / set_attribute are used to work with the concat node correctly.
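The upward-only traversal described in the first bullet can be sketched with a plain-Python stand-in for torch.fx nodes (the `Node` dataclass and `traverse_graph_up` below are illustrative names, not the NNCF implementation):

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """Minimal stand-in for torch.fx.Node: only a name and input edges."""
    name: str
    all_input_nodes: list = field(default_factory=list)

def traverse_graph_up(start_nodes: list, stop_node_names: set) -> set:
    """Collect node names reachable by walking input edges only.

    Users of visited nodes are never followed, so branches that do not
    feed the target outputs stay out of the extracted subgraph.
    """
    visited = set()
    stack = list(start_nodes)
    while stack:
        node = stack.pop()
        if node.name in visited or node.name in stop_node_names:
            continue
        visited.add(node.name)
        stack.extend(node.all_input_nodes)
    return visited

# input -> conv -> relu; conv also feeds an unrelated user `side`
inp = Node("input")
conv = Node("conv", [inp])
relu = Node("relu", [conv])
side = Node("side", [conv])  # a *user* of conv: upward traversal never reaches it

print(sorted(traverse_graph_up([relu], {"input"})))  # → ['conv', 'relu']
```

Because the traversal only follows `all_input_nodes`, the `side` branch is excluded even though it shares `conv` with the captured path.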

Reason for changes

  • To enable the BC algorithm for the Yolo26 model in TorchFX

Related tickets

180772

Tests

  • tests/torch/fx/test_model_transformer.py is updated to cover all the fixes in the model transformer
  • tests/cross_fw/test_templates/test_bias_correction.py is updated with models that have skip connections through a concat/elementwise node to check that the BC algorithm works properly

NNCF/job/manual/job/post_training_quantization/835/ - green

@github-actions github-actions Bot added NNCF PT Pull requests that updates NNCF PyTorch NNCF PTQ Pull requests that updates NNCF PTQ labels Mar 13, 2026
@github-actions github-actions Bot added NNCF ONNX Pull requests that updates NNCF ONNX and removed NNCF PTQ Pull requests that updates NNCF PTQ labels Mar 24, 2026
@daniil-lyakhov daniil-lyakhov marked this pull request as ready for review March 25, 2026 15:48
@daniil-lyakhov daniil-lyakhov requested a review from a team as a code owner March 25, 2026 15:48
Copilot AI review requested due to automatic review settings March 25, 2026 15:48
Contributor

Copilot AI left a comment


Pull request overview

This PR updates the Torch FX subgraph extraction logic used by Bias Correction so that extracted subgraphs include only nodes that contribute to the requested outputs, properly handle multi-input ops (e.g., aten.cat), and support specifying extraction boundaries with explicit input/output port IDs.

Changes:

  • Extend PTModelExtractionCommand to accept (node_name, port_id) pairs for subgraph inputs/outputs and update BC backends + FX extraction tests accordingly.
  • Rework FX extraction traversal to walk dependencies “upwards” and avoid capturing unrelated nodes; ignore original FX placeholders in extracted graphs.
  • Add helper models and reference .dot graphs to cover concat/skip-connection cases and update Bias Correction template tests.
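The (node_name, port_id) boundary pairs can be illustrated with a small mock (this dataclass only mirrors the call shape visible in the updated tests; the real PTModelExtractionCommand lives in src/nncf/torch/graph/transformations/commands.py):

```python
from dataclasses import dataclass

@dataclass
class PTModelExtractionCommand:
    """Mock of the updated command: boundaries are (node_name, port_id) pairs."""
    input_ids: list   # [(node_name, input_port_id), ...] - subgraph beginnings
    output_ids: list  # [(node_name, output_port_id), ...] - subgraph ends

cmd = PTModelExtractionCommand([("conv2d", 0)], [("output", 0)])

# The transformer derives its stop set from the node names alone:
input_node_names = [name for name, _ in cmd.input_ids]
output_node_names = [name for name, _ in cmd.output_ids]
stop_nodes = set(input_node_names + output_node_names)
print(sorted(stop_nodes))  # → ['conv2d', 'output']
```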

Reviewed changes

Copilot reviewed 23 out of 23 changed files in this pull request and generated 5 comments.

File Description
tests/torch/fx/test_model_transformer.py Updates extraction test cases to use (node, port) IDs and adds new concat-with-input scenarios.
tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_conv2d_2.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_conv2d_1_add__add__1.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_add__add.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_add__1.dot Updates extracted-graph reference output to reflect new extraction traversal behavior.
tests/torch/data/fx/extracted/MultiBranchesConnectedModelconv2d_1_add__1.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/ConvolutionWithSeveralOutputsconv2d_output.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/ConvolutionWithSeveralOutputsconv2d_conv2d_output_conv2d.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/ConvolutionWithNotTensorBiasModelconv2d_output.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/ConvolutionWithNotTensorBiasModelconv2d_conv2d_output_conv2d.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/ConvolutionWithNotTensorBiasModelconv2d_conv2d.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/ConvolutionWithAllConstantInputsModelconv2d_conv2d.dot Updates extracted-graph reference output formatting and placeholder naming.
tests/torch/data/fx/extracted/ConvConcatWithLongPathToOrigInputscat_conv2d_cat.dot Adds new extracted-graph reference for long-path original-input concat scenario.
tests/torch/data/fx/extracted/ConvConcatWithInputModelcat_conv2d_cat.dot Adds new extracted-graph reference for concat-with-original-input scenario.
tests/onnx/quantization/test_bias_correction.py Skips ONNX BC for the new concat-with-input model due to known extraction limitation.
tests/cross_fw/test_templates/test_bias_correction.py Adds BC template test coverage for concat/add skip-connection models.
tests/cross_fw/test_templates/helpers.py Adds new helper models that exercise concat/add skip connections to original inputs.
src/nncf/torch/graph/transformations/commands.py Changes PT extraction command API to use (node_name, port_id) pairs.
src/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py Updates backend to pass (node, port) IDs into PT extraction command.
src/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py Updates backend to pass (node, port) IDs into PT extraction command.
src/nncf/quantization/algorithms/bias_correction/torch_fx_backend.py Updates backend to forward full (node, port) ID sets into PT extraction command.
src/nncf/experimental/torch/fx/node_utils.py Adds set_node_args and extends FX node-arg handling for aten.cat.
src/nncf/experimental/torch/fx/model_transformer.py Reworks FX subgraph extraction traversal and rewiring using port IDs and cat-aware arg accessors.

Comment on lines 66 to 70
def _traverse_graph_up(
input_nodes: list[torch.fx.Node],
stop_nodes: set[torch.fx.Node],
visited: set[torch.fx.Node],
) -> None:

Copilot AI Mar 25, 2026


Type hints for stop_nodes and visited are inconsistent with how they are used: the function stores in_node.name (a str) in visited and checks in_node.name in stop_nodes, but the annotations declare set[torch.fx.Node]. This breaks static type checking and makes the API unclear; annotate these as set[str] (or alternatively store actual torch.fx.Node objects consistently).

Collaborator Author


Done

Comment on lines +95 to +105
if node.target == torch.ops.aten.cat.default:
return node.args[0]
return node.args


Copilot AI Mar 25, 2026


get_node_args() is annotated to return tuple[Any, ...], but for aten.cat it returns node.args[0], which is typically a list of inputs. This mismatch can cause type-checking issues and confusion for callers; consider returning a Sequence[Any] (or tuple[Any, ...] | list[Any]) and updating set_node_args() accordingly.
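A plain-Python sketch of the accessor pair under discussion (FakeNode and the string target are stand-ins; the real helpers in node_utils.py operate on torch.fx.Node and compare against torch.ops.aten.cat.default):

```python
CAT_TARGET = "aten.cat.default"  # stand-in for torch.ops.aten.cat.default

class FakeNode:
    """Minimal stand-in for torch.fx.Node with target and args."""
    def __init__(self, target, args):
        self.target = target
        self.args = args

def get_node_args(node):
    # aten.cat packs its tensor inputs in a list at args[0],
    # so for cat this returns a list; otherwise the args tuple.
    if node.target == CAT_TARGET:
        return node.args[0]
    return node.args

def set_node_args(node, args):
    # Mirror get_node_args: write the list back into args[0] for cat.
    if node.target == CAT_TARGET:
        node.args = (list(args),) + node.args[1:]
    else:
        node.args = tuple(args)

cat = FakeNode(CAT_TARGET, (["relu_out", "input"], 1))
print(get_node_args(cat))   # → ['relu_out', 'input']
set_node_args(cat, ["a", "b"])
print(cat.args)             # → (['a', 'b'], 1)
```

The cat branch returning a list rather than a tuple is exactly why the `tuple[Any, ...]` annotation is too narrow.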

Collaborator Author


Done

Comment thread src/nncf/experimental/torch/fx/model_transformer.py Outdated
args = list(node_with_input.args)
args[0] = graph_input
node_with_input.args = tuple(args)
args = list(get_node_args(node_with_input))

Copilot AI Mar 25, 2026


When rewiring extracted-graph inputs, args[input_port_id] = graph_input can also fail with IndexError if the chosen input_port_id doesn't align with the operator's argument structure (notably for ops with non-tensor args). Please add bounds checking and an informative error (including node_with_input.target/op and len(args)) before assignment.

Suggested change
args = list(get_node_args(node_with_input))
args = list(get_node_args(node_with_input))
if not (0 <= input_port_id < len(args)):
    raise IndexError(
        f"Failed to rewire extracted-graph input for node "
        f"'{node_with_input.target}' (op='{node_with_input.op}'): "
        f"input_port_id={input_port_id} is out of bounds for args "
        f"of length {len(args)}."
    )

Collaborator Author


The IndexError is expected behavior in such a case.

Comment on lines +190 to +192
:param input_ids: List of the input IDs: pairs of node names and correspondent input port ids.
Each pair denotes the sub-graph beginning.
:param output_ids: List of the output IDs: pairs of node names and correspondent output port ids.

Copilot AI Mar 25, 2026


Docstring typo: "correspondent" should be "corresponding" (used twice in the input_ids/output_ids parameter descriptions).

Suggested change
:param input_ids: List of the input IDs: pairs of node names and correspondent input port ids.
Each pair denotes the sub-graph beginning.
:param output_ids: List of the output IDs: pairs of node names and correspondent output port ids.
:param input_ids: List of the input IDs: pairs of node names and corresponding input port ids.
Each pair denotes the sub-graph beginning.
:param output_ids: List of the output IDs: pairs of node names and corresponding output port ids.

Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.

Comment on lines 105 to +108
transformation = transformations[-1]
stop_nodes = set(transformation.input_node_names + transformation.output_node_names)
input_node_names = [name for name, _ in transformation.input_ids]
output_node_names = [name for name, _ in transformation.output_ids]
stop_nodes = set(input_node_names + output_node_names)

Copilot AI Mar 25, 2026


PTModelExtractionCommand.output_ids now includes (node_name, output_port_id), but _apply_model_extraction() ignores output_port_id entirely (it only extracts output_node_names). This can return an incorrect output when the requested output is a specific port of a multi-output FX node (e.g., split/max/min/topk where node.meta['val'] is a tuple/list and NNCF assigns distinct output_port_ids). Please use output_port_id when forming nodes_with_output (e.g., select the right element from the FX output node’s tuple output, and/or add/select a getitem for multi-output nodes) so the extracted subgraph returns the exact tensor for the requested output port.

Collaborator Author


All TorchFX nodes have only one output, at index zero, so this is not relevant to this backend.

input_ids: set[tuple[str, int]], output_ids: set[tuple[str, int]]
) -> PTModelExtractionCommand:
return PTModelExtractionCommand([inp_id[0] for inp_id in input_ids], [out_id[0] for out_id in output_ids])
return PTModelExtractionCommand(list(input_ids), list(output_ids))

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_extraction_command() receives input_ids/output_ids as set[...], but list(input_ids) / list(output_ids) makes the extraction order non-deterministic (set iteration order). This can lead to unstable placeholder insertion / output tuple ordering and flaky behavior. Consider sorting (e.g., by (node_name, port_id)) before constructing PTModelExtractionCommand to keep extracted graphs reproducible.

Suggested change
return PTModelExtractionCommand(list(input_ids), list(output_ids))
sorted_input_ids = sorted(input_ids, key=lambda item: (item[0], item[1]))
sorted_output_ids = sorted(output_ids, key=lambda item: (item[0], item[1]))
return PTModelExtractionCommand(sorted_input_ids, sorted_output_ids)

Comment thread tests/torch/fx/test_model_transformer.py
Comment thread src/nncf/torch/graph/transformations/commands.py
Collaborator

@anzr299 anzr299 left a comment


Minor

Comment thread src/nncf/experimental/torch/fx/model_transformer.py Outdated
Comment thread src/nncf/experimental/torch/fx/model_transformer.py
return self.conv_6(x)


class ConvConcatWithInputModel(nn.Module):
Collaborator


The name is incorrect: it does not reflect the model.

outside the subgraph boundary.

mode="cat":
input ──→ conv1 ──→ relu ──→ cat([relu_out, input], dim=1) ──→ conv2
Collaborator


Schema and docstring do not match this class.
Just out of interest: why does the output node point downward whereas the others point to the right?

mode="cat":
input ──→ conv1 ──→ relu ──→ cat([relu_out, input], dim=1) ──→ conv2
↑ |
input ──────────────────────────┘ ↓
Collaborator


How about defining it like this, without creating an unnecessary hierarchy:

import torch
import torch.nn as nn
import torch.nn.functional as F

# create_conv / set_torch_seed are assumed to come from the existing test helpers
def build_conv(in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
    conv = create_conv(in_channels, out_channels, kernel_size)
    with set_torch_seed():
        conv.weight.data = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        conv.bias.data = torch.randn([out_channels])
    return conv

class ConcatWithInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(4, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = torch.relu(x_1)
        x_combined = torch.cat([x_1, x], dim=1)
        return self.conv_2(x_combined)

class ConcatWithReluInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(4, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = F.relu(x_1)
        x_combined = torch.cat([x_1, F.relu(x)], dim=1)
        return self.conv_2(x_combined)


class AddWithInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(2, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = F.relu(x_1)
        x_combined = x_1 + x
        return self.conv_2(x_combined)

class AddWithReluInput(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = build_conv(2, 2, 1)
        self.conv_2 = build_conv(2, 2, 1)

    def forward(self, x):
        x_1 = self.conv_1(x)
        x_1 = F.relu(x_1)
        x_combined = x_1 + torch.relu(x)
        return self.conv_2(x_combined)

MultiBranchesConnectedModel, (1, 3, 3, 3), PTModelExtractionCommand([("conv2d", 0)], [("add__1", 0)])
),
ModelExtractionTestCase(
ConvConcatWithInputModel(mode="cat"),
Collaborator


Don't create an instance of the model as a parameter:

  • the instance is created at test-collection time and stays alive for the whole run
  • an extra, unnecessary condition has to be added
  • if you need to parametrize the object's init, you can use functools.partial

Also, ModelExtractionTestCase has an incorrect annotation.

ConvolutionWithNotTensorBiasModel, (1, 1, 3, 3), PTModelExtractionCommand(["conv2d"], ["output"])
ConvolutionWithNotTensorBiasModel,
(1, 1, 3, 3),
PTModelExtractionCommand([("conv2d", 0)], [("output", 0)]),
Collaborator


Are there no tests with port_id=1?

