diff --git a/onnxscript/rewriter/rules/common/_fuse_relus_clips.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py index 5d294cdbd7..b47d2d9b6f 100644 --- a/onnxscript/rewriter/rules/common/_fuse_relus_clips.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py @@ -74,7 +74,9 @@ def extract_min_max(self, node: ir.Node): min_clip = min_input.const_value.numpy() if len(node.inputs) > 2: - max_clip = node.inputs[2].const_value.numpy() + max_clip = node.inputs[2] + if max_clip is not None: + max_clip = max_clip.const_value.numpy() return min_clip, max_clip, dtype diff --git a/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py index df2d669930..0d36f19ce4 100644 --- a/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py @@ -9,12 +9,7 @@ import parameterized from onnx_ir.passes.common import onnx_checker, shape_inference -from onnxscript.rewriter import ( - MatchingTracer, - MatchStatus, - RewriteRule, - testing, -) +from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing from onnxscript.rewriter.rules.common import _fuse_relus_clips from onnxscript.rewriter.rules.common._fuse_relus_clips import ( successive_clip_relu_rule, @@ -206,6 +201,35 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): """) self.run_test(model, expected_op_types=["Clip"]) + @parameterized.parameterized.expand( + [ + ( + "relu_then_clip", + """ + x1 = Relu(X) + Y = Clip(x1,min,"") + """, + ), + ( + "clip_then_relu", + """ + x1 = Clip(X,min,"") + Y = Relu(x1) + """, + ), + ] + ) + def test_successful_fuse_successive_relu_clip_no_max(self, _, nodes): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + + {{ + {nodes} + }} + """) + self.run_test(model, expected_op_types=["Clip"]) + @parameterized.parameterized.expand( [ (