diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index fb93bc703f..52806d621d 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations +import logging from typing import Sequence, TypeVar, Union __all__ = [ @@ -48,6 +49,8 @@ _remove_optional_bias, ) +logger = logging.getLogger(__name__) + _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( *_no_op.rules, # TODO: merge this rule into constant folding? @@ -82,7 +85,8 @@ def __init__( def call(self, model: ir.Model) -> ir.passes.PassResult: count = self.rules.apply_to_model(model) if count: - print(f"Applied {count} of general pattern rewrite rules.") + logger.info("Applied %s of general pattern rewrite rules.", count) + return ir.passes.PassResult(model, bool(count))