Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions onnxoptimizer/passes/fuse_consecutive_squeezes.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
!GetValueFromAttrOrInput(n, kaxes, 1, axes_2)) {
return false;
}
if (std::any_of(axes_1.begin(), axes_1.end(), [](int64_t v) { return v < 0; }) ||

Check warning on line 47 in onnxoptimizer/passes/fuse_consecutive_squeezes.h

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Replace with the version of "std::ranges::any_of" that takes a range.

See more on https://sonarcloud.io/project/issues?id=onnx_optimizer&issues=AZyZ23rtEFps_QDQtVPx&open=AZyZ23rtEFps_QDQtVPx&pullRequest=293
std::any_of(axes_2.begin(), axes_2.end(), [](int64_t v) { return v < 0; })) {

Check warning on line 48 in onnxoptimizer/passes/fuse_consecutive_squeezes.h

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Replace with the version of "std::ranges::any_of" that takes a range.

See more on https://sonarcloud.io/project/issues?id=onnx_optimizer&issues=AZyZ23rtEFps_QDQtVPy&open=AZyZ23rtEFps_QDQtVPy&pullRequest=293
return false;
}

std::vector<int64_t> &ret = composed_axes;
ret.clear();
Expand Down
12 changes: 12 additions & 0 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2872,6 +2872,18 @@ def test_fuse_consecutive_squeezes_multi_uses(self): # type: () -> None
if init.name == optimized_model.graph.node[2].input[1]:
assert list(to_array(init)) == [0, 1, 4, 5, 6]

def test_fuse_consecutive_squeezes_negative_axes(self): # type: () -> None
graph = parser.parse_graph("""
agraph (float[5, 7, 1, 1] X) => (float[5, 7] Z)
{
Axes = Constant<value=int64[1]{-1}> ()
Y = Squeeze (X, Axes)
Z = Squeeze (Y, Axes)
}
""")
optimized_model = self._optimized(graph, ["fuse_consecutive_squeezes"])
assert len(optimized_model.graph.node) == 3

@pytest.mark.xfail
def test_fuse_consecutive_softmax_log_axis(self): # type: () -> None
for axis in range(3):
Expand Down
Loading