diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py index f75e46baaa..32cc5cbdd0 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/analysis.py @@ -153,7 +153,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: return live def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: for s in reversed(block): live_out = visit(s, live_out) return live_out @@ -167,20 +167,20 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: if isinstance(stmt, ast.If): constant_cond = self.constant_if_condition(stmt) if constant_cond is None: - live1 = visitBlock(stmt.body, live_out) - live2 = visitBlock(stmt.orelse, live_out) + live1 = visit_block(stmt.body, live_out) + live2 = visit_block(stmt.orelse, live_out) return live1 | live2 | _used_vars(stmt.test) elif constant_cond: - return visitBlock(stmt.body, live_out) + return visit_block(stmt.body, live_out) else: - return visitBlock(stmt.orelse, live_out) + return visit_block(stmt.orelse, live_out) if isinstance(stmt, ast.For): p_loop_var = _get_loop_var(stmt, self._formatter) prev = None curr = live_out while curr != prev: prev = curr - curr = visitBlock(stmt.body, prev).difference({p_loop_var}) + curr = visit_block(stmt.body, prev).difference({p_loop_var}) return curr if isinstance(stmt, ast.While): cond_vars = _used_vars(stmt.test) @@ -188,7 +188,7 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: curr = live_out | cond_vars while curr != prev: prev = curr - curr = visitBlock(stmt.body, prev) | cond_vars + curr = visit_block(stmt.body, prev) | cond_vars return curr if isinstance(stmt, ast.Break): # The following is sufficient for the current restricted usage, where @@ -228,7 +228,7 @@ def exposed_uses(self, stmts: Sequence[ast.stmt]) -> set[str]: (in the first statement). Hence x is included in the exposed_uses. """ - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: for stmt in reversed(block): live_out = visit(stmt, live_out) return live_out @@ -243,13 +243,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: if isinstance(stmt, ast.If): constant_cond = self.constant_if_condition(stmt) if constant_cond is None: - live1 = visitBlock(stmt.body, live_out) - live2 = visitBlock(stmt.orelse, live_out) + live1 = visit_block(stmt.body, live_out) + live2 = visit_block(stmt.orelse, live_out) return (live1 | live2) | _used_vars(stmt.test) elif constant_cond: - return visitBlock(stmt.body, live_out) + return visit_block(stmt.body, live_out) else: - return visitBlock(stmt.orelse, live_out) + return visit_block(stmt.orelse, live_out) if ast_utils.is_print_call(stmt): return live_out if ast_utils.is_doc_string(stmt): @@ -259,13 +259,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: # for loops that execute at least once. loop_var_set = {_get_loop_var(stmt, self._formatter)} used_after_loop = live_out.difference(loop_var_set) - used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set) + used_inside_loop = visit_block(stmt.body, set()).difference(loop_var_set) used_in_loop_header = _used_vars(stmt.iter) return used_inside_loop | used_in_loop_header | used_after_loop if isinstance(stmt, ast.While): # Analysis assumes loop may execute zero times. Results can be improved # for loops that execute at least once. - used_inside_loop = visitBlock(stmt.body, set()) + used_inside_loop = visit_block(stmt.body, set()) used_in_loop_header = _used_vars(stmt.test) return used_inside_loop | used_in_loop_header | live_out if isinstance(stmt, ast.Break): @@ -281,7 +281,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") ) - return visitBlock(stmts, set()) + return visit_block(stmts, set()) def outer_scope_variables(self, fun: ast.FunctionDef) -> set[str]: """Return the set of outer-scope variables used in a nested function. diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index a97a88c14d..1177882abc 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -187,15 +187,15 @@ def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]: argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is castable, while X can serve as the target-type. """ - return None if x is None or converter_.is_castable(x.name) else x + return None if x is None or converter_._is_castable(x.name) else x # pylint: disable=protected-access def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]: if x is None: return None - if converter_.is_castable(x.name) and y is not None: + if converter_._is_castable(x.name) and y is not None: # pylint: disable=protected-access # Polymorphic constant x is cast to the type of y: - x_cast = converter_.generate_unique_name(f"{x.name}_cast") - return converter_.emit1([x_cast], "CastLike", [x, y]) + x_cast = converter_._generate_unique_name(f"{x.name}_cast") # pylint: disable=protected-access + return converter_._emit1([x_cast], "CastLike", [x, y]) # pylint: disable=protected-access return x return cast_inputs(get_type_info, cast_like, op_signature, args) diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 468aa41675..490017bab0 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -37,10 +37,6 @@ # Python-to-IR converter: -def not_allowed(construct): - return f"{construct}not supported." - - class TranslationError(Exception): def __init__(self, *args: object) -> None: super().__init__(*args) @@ -194,7 +190,7 @@ def __init__( self._analyzer: analysis.AstAnalyzer | None = None self._castable: set[str] = set() - def is_castable(self, var_name: str) -> bool: + def _is_castable(self, var_name: str) -> bool: """Returns True if the variable with the given name represents a polymorphic constant.""" return var_name in self._castable @@ -221,7 +217,7 @@ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None: opset.domain != self.default_opset_.domain or opset.version != self.default_opset_.version ): - self.fail( + self._fail( node, f"Two distincts opset were used ({opset} != {self.default_opset_})." ) else: @@ -264,10 +260,10 @@ def _message(self, node: ast.AST, error_msg: str) -> str: """Constructs an error _message containing source information about an ast node.""" return self._source_of(node).msg(error_msg) - def warn(self, node: ast.AST, error_msg: str) -> None: + def _warn(self, node: ast.AST, error_msg: str) -> None: warn(self._message(node, error_msg)) - def fail(self, node: ast.AST, error_msg: str) -> NoReturn: + def _fail(self, node: ast.AST, error_msg: str) -> NoReturn: fail(self._message(node, error_msg)) # Name resolution and namescopes: This component handles the following aspects: @@ -319,7 +315,7 @@ def _lookup( raise ValueError(info.msg(f"Unbound name: {name}.")) return None - def generate_unique_name(self, candidate: str = "tmp") -> str: + def _generate_unique_name(self, candidate: str = "tmp") -> str: # TODO(justinchuby): Can we reduce the O complexity of this function? r = candidate while r in self._used_vars: @@ -354,19 +350,19 @@ def _to_onnx_var( ) -> ir.Value: if isinstance(val, values.AttrRef): # promote attribute to value - result_name = self.generate_unique_name(target or "tmp") + result_name = self._generate_unique_name(target or "tmp") attr = self._to_onnx_attr_ref(val, info) - result = self.emit( + result = self._emit( [result_name], values.Op(self.default_opset, "Constant"), [], [attr] ) if val.as_bool: # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. - result_as_bool = self.generate_unique_name(result_name + "_as_bool") + result_as_bool = self._generate_unique_name(result_name + "_as_bool") cast_attr = ir.AttrInt64("to", onnx_types.BOOL.dtype) self._castable.add(result_as_bool) - return self.emit1( + return self._emit1( [result_as_bool], values.Op(self.default_opset, "Cast"), [result], @@ -388,7 +384,7 @@ def _to_onnx_var( def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> ir.Value: return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) - def emit( + def _emit( self, outputs: Sequence[str], callee: values.Op | str, @@ -416,8 +412,8 @@ def emit( return output_values if len(output_values) > 1 else output_values[0] - def emit1(self, *args, **kwargs) -> ir.Value: - r = self.emit(*args, **kwargs) + def _emit1(self, *args, **kwargs) -> ir.Value: + r = self._emit(*args, **kwargs) if not isinstance(r, ir.Value): raise TypeError(f"Expected single ONNX IR Value, got {type(r)!r}.") return r @@ -443,23 +439,23 @@ def _emit_const( suggested_name = f"int64_m{abs(pyvalue[0])}_1d" else: suggested_name = "const" - ovar = self.generate_unique_name(suggested_name) + ovar = self._generate_unique_name(suggested_name) try: tensor = ir.tensor(pyvalue, name=ovar) except Exception as exc: # pylint: disable=broad-exception-caught - self.fail( + self._fail( info.ast_node, f"Failed to convert constant value {pyvalue!r} to ONNX tensor: {exc}", ) attr = ir.AttrTensor("value", tensor) self._castable.add(ovar) - return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) + return self._emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) def _emit_copy(self, original_var: ir.Value, suggested_name: str) -> ir.Value: """Emits a copy statement, using the ONNX Identity operator.""" - new_var = self.generate_unique_name(suggested_name) - return self.emit([new_var], "Identity", [original_var]) + new_var = self._generate_unique_name(suggested_name) + return self._emit([new_var], "Identity", [original_var]) def _is_constant_expr(self, node: ast.AST) -> None: if isinstance(node, ast.UnaryOp): @@ -507,7 +503,7 @@ def _eval_constant_expr(self, expr: ast.expr) -> PyValue: def _get_type_annotation(self, annotation: ast.expr) -> ta.TypeAnnotationValue | None: typeinfo = self._eval_constant_expr(annotation) if not ta.is_valid_type(typeinfo): - self.warn( + self._warn( annotation, "Unsupported type annotation.", ) @@ -540,7 +536,7 @@ def _translate_attr( attr_name, attr.type, value=None, ref_attr_name=attr.name ) if attr_meta is not None and (attr.type != attr_meta.type): - self.fail( + self._fail( expr, f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'", ) @@ -553,7 +549,7 @@ def _translate_attr( for pyvar, previous in irfunction.outer_scope_variables: current = self._lookup(pyvar, self._source_of(expr)) if current.value != previous.value: - self.fail( + self._fail( expr, f"Outer scope variable '{pyvar}' referenced by function " f"'{expr.id!r}' modified.", @@ -561,7 +557,7 @@ def _translate_attr( val = irfunction.graph if isinstance(val, ir.Value): - self.fail(expr, f"Cannot use ir.Value '{expr.id}' as an attribute.") + self._fail(expr, f"Cannot use ir.Value '{expr.id}' as an attribute.") else: # Treat as a constant python-value, to be converted below. pass @@ -579,7 +575,7 @@ def _translate_attr( # in a NodeProto. if val is None: if attr_meta and attr_meta.required: - self.fail(expr, f"Attribute '{attr_name}' is required.") + self._fail(expr, f"Attribute '{attr_name}' is required.") return None attr_type = attr_meta.type if attr_meta else None if attr_type == ir.AttributeType.TENSOR: @@ -589,7 +585,7 @@ def _translate_attr( def _translate_docstring(self, node: ast.Expr) -> None: if not isinstance(node.value, ast.Constant): - self.fail(node, "Docstring expression must be a constant.") + self._fail(node, "Docstring expression must be a constant.") self._current_fn.doc_string = node.value.value def _translate_expr(self, node: ast.AST, target: PreferredName | None = None) -> ir.Value: @@ -620,8 +616,8 @@ def _translate_expr(self, node: ast.AST, target: PreferredName | None = None) -> callee, args, attrs = r target = "tmp" if target is None else target assert isinstance(target, str) - result = self.generate_unique_name(target) - return self.emit1([result], callee, args, attrs) + result = self._generate_unique_name(target) + return self._emit1([result], callee, args, attrs) def _translate_opt_expr(self, node: ast.expr) -> ir.Value | None: """Translation of an expression where "None" is permitted (eg., for an optional argument). @@ -674,7 +670,7 @@ def _translate_subscript_expr( var_name = var.name if target is None: target = f"{var_name}_subscripted" - target = self.generate_unique_name(target) + target = self._generate_unique_name(target) indices = ast_utils.normalize_subscript_expr(node) info = self._source_of(node.slice) @@ -715,8 +711,8 @@ def translate_slice_component( raise RuntimeError(f"Slice component type must be int, not {type(cst)}") else: value = self._translate_expr(node_arg) - reshaped = self.generate_unique_name(f"{value.name}_reshaped") - reshaped_value = self.emit1( + reshaped = self._generate_unique_name(f"{value.name}_reshaped") + reshaped_value = self._emit1( [reshaped], values.Op(self.default_opset, "Reshape"), [value, one_1d()], @@ -767,7 +763,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value non_scalar_indices.append((axis, elt)) if not (sliced_indices or scalar_indices or non_scalar_indices): # Edge case: no index specified. Eg. A[:, :] - return self.emit1([target], "Identity", [var_name]) + return self._emit1([target], "Identity", [var_name]) if sliced_indices or len(scalar_indices) > 1: # We emit a Slice operation if we have any indices like 1:5:2 or if the number of @@ -803,17 +799,17 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value if len(starts) > 1: axis_0_attr = ir.AttrInt64("axis", 0) - start_name = self.generate_unique_name(f"{var_name}_start") - start_value = self.emit([start_name], "Concat", starts, [axis_0_attr]) + start_name = self._generate_unique_name(f"{var_name}_start") + start_value = self._emit([start_name], "Concat", starts, [axis_0_attr]) - end_name = self.generate_unique_name(f"{var_name}_end") - end_value = self.emit([end_name], "Concat", ends, [axis_0_attr]) + end_name = self._generate_unique_name(f"{var_name}_end") + end_value = self._emit([end_name], "Concat", ends, [axis_0_attr]) - axes_name = self.generate_unique_name(f"{var_name}_axis") - axes_value = self.emit([axes_name], "Concat", axes, [axis_0_attr]) + axes_name = self._generate_unique_name(f"{var_name}_axis") + axes_value = self._emit([axes_name], "Concat", axes, [axis_0_attr]) - steps_name = self.generate_unique_name(f"{var_name}_step") - steps_value = self.emit([steps_name], "Concat", steps, [axis_0_attr]) + steps_name = self._generate_unique_name(f"{var_name}_step") + steps_value = self._emit([steps_name], "Concat", steps, [axis_0_attr]) else: start_value = starts[0] end_value = ends[0] @@ -821,8 +817,8 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value steps_value = steps[0] if squeezed_axes: - sliced_name = self.generate_unique_name(f"{var_name}_sliced") - sliced_value = self.emit( + sliced_name = self._generate_unique_name(f"{var_name}_sliced") + sliced_value = self._emit( [sliced_name], "Slice", [var, start_value, end_value, axes_value, steps_value], @@ -830,18 +826,18 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) if non_scalar_indices: # use temporary to store result of squeeze - result_name = self.generate_unique_name(f"{var_name}_squeezed") + result_name = self._generate_unique_name(f"{var_name}_squeezed") else: # store squeezed result in final target result_name = target - result = self.emit([result_name], "Squeeze", [sliced_value, squeezed_axes]) + result = self._emit([result_name], "Squeeze", [sliced_value, squeezed_axes]) else: if non_scalar_indices: # use temporary to store result of Slice - result_name = self.generate_unique_name(f"{var_name}_sliced") + result_name = self._generate_unique_name(f"{var_name}_sliced") else: # store result of Slice in final target result_name = target slice_inputs = [var, start_value, end_value, axes_value, steps_value] - result = self.emit1([result_name], "Slice", slice_inputs) + result = self._emit1([result_name], "Slice", slice_inputs) else: result = var non_scalar_indices.extend(scalar_indices) @@ -856,10 +852,10 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value # use Gather to perform indexing # Assign gathered value to either temporary or final target if axis != last_axis: # use temporary to store result of Gather - gathered = self.generate_unique_name(f"{var_name}_axis_{axis}") + gathered = self._generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target - result = self.emit1([gathered], "Gather", [result, index_value], [axis_attr]) + result = self._emit1([gathered], "Gather", [result, index_value], [axis_attr]) return result @@ -948,8 +944,8 @@ def _translate_compare_expr(self, node): op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal") left, right = self._cast_like_binary_expression(op, left, right) if opname == "NotEqual": - tmp = self.generate_unique_name() - tmp_value = self.emit1([tmp], op, [left, right]) + tmp = self._generate_unique_name() + tmp_value = self._emit1([tmp], op, [left, right]) not_op = values.Op(self.default_opset, "Not") return not_op, [tmp_value], [] @@ -958,7 +954,6 @@ def _translate_compare_expr(self, node): def _translate_name_expr(self, node: ast.Name) -> ir.Value: return self._py_var_to_onnx_var(node.id, self._source_of(node)) - # pylint: disable=inconsistent-return-statements def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset: """Return an Opset""" if isinstance(node, ast.Name): @@ -968,21 +963,17 @@ def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset: elif isinstance(val, builder.OpBuilder): # Convert OpBuilder to Opset for compatibility return values.Opset(val.domain, val.version) - self.fail(node, f"'{node.id}' is not an instance of type Opset but {type(val)}.") + self._fail(node, f"'{node.id}' is not an instance of type Opset but {type(val)}.") elif isinstance(node, ast.Attribute): - self.fail(node, "Nested module unimplemented.") # TODO + self._fail(node, "Nested module unimplemented.") # TODO else: - self.fail(node, "Invalid opset expression.") + self._fail(node, "Invalid opset expression.") - # pylint: enable=inconsistent-return-statements def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable=R1710 """Return an Op""" if isinstance(node, ast.Attribute): module = self._translate_opset_expr(node.value) self._set_default_opset(module, node) - opname = node.attr - if opname in module: - return values.Op(module, node.attr) return values.Op(module, node.attr) if isinstance(node, ast.Name): function_name = node.id @@ -996,7 +987,7 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable f"The ONNX graph may not work." ) return values.Op(self.default_opset, function_name) - self.fail(node, "Invalid callee") + self._fail(node, "Invalid callee") def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None: """Statement translation: A single Python statement is mapped into a @@ -1046,7 +1037,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: elif isinstance(lhs, ast.Tuple): # Assignments of the form "x, y, z = op.SomeOp(...)" if not isinstance(rhs, ast.Call): - self.fail( + self._fail( rhs, f"RHS must be a Call expression for unpacking, found: '{type(rhs)!r}'", ) @@ -1054,11 +1045,13 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: def generate_onnx_name(x: ast.AST): if not isinstance(x, ast.Name): - self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'") - return self.generate_unique_name(x.id) + self._fail( + x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'" + ) + return self._generate_unique_name(x.id) output_names = [generate_onnx_name(x) for x in lhs.elts] - outputs = self.emit(output_names, callee, inputs, attrs) + outputs = self._emit(output_names, callee, inputs, attrs) if isinstance(outputs, ir.Value): outputs = [outputs] for x, output in zip(lhs.elts, outputs): @@ -1067,7 +1060,7 @@ def generate_onnx_name(x: ast.AST): values.SymbolValue(output, self._source_of(x)), ) else: - self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") + self._fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") if isinstance(stmt, ast.Assign): targets = stmt.targets @@ -1075,7 +1068,7 @@ def generate_onnx_name(x: ast.AST): targets = [stmt.target] if len(targets) != 1: # Assignments of the form "x = y = SomeExpression" - self.fail(stmt, "Multi-assignment not supported.") + self._fail(stmt, "Multi-assignment not supported.") lhs = targets[0] rhs = stmt.value if isinstance(rhs, ast.Tuple): @@ -1083,10 +1076,10 @@ def generate_onnx_name(x: ast.AST): if not isinstance(lhs, ast.Tuple): # Assignments of the form "single_var = Expression1, Expression2". # We do not support tuple-typed variables. - self.fail(lhs, f"Left term must be a tuple not '{type(lhs)!r}'.") + self._fail(lhs, f"Left term must be a tuple not '{type(lhs)!r}'.") # Parallel assignments of the form "x, y = Expression1, Expression2" if len(lhs.elts) != len(rhs.elts): - self.fail( + self._fail( stmt, "Expected same number of elements on lhs and rhs of assignments." ) for p, r in zip(lhs.elts, rhs.elts): @@ -1155,26 +1148,26 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: live_defs = list(live_def_set) test = self._translate_expr(stmt.test, "cond") lineno = self._source_of(stmt).lineno - thenGraph = self._translate_block(stmt.body, f"thenGraph_{lineno}", live_defs) - thenAttr = ir.AttrGraph("then_branch", thenGraph) - elseGraph = self._translate_block(stmt.orelse, f"elseGraph_{lineno}", live_defs) - elseAttr = ir.AttrGraph("else_branch", elseGraph) + then_graph = self._translate_block(stmt.body, f"thenGraph_{lineno}", live_defs) + then_attr = ir.AttrGraph("then_branch", then_graph) + else_graph = self._translate_block(stmt.orelse, f"elseGraph_{lineno}", live_defs) + else_attr = ir.AttrGraph("else_branch", else_graph) def rename(x): - return self.generate_unique_name(x) + return self._generate_unique_name(x) # no break condition renamed = [rename(x) for x in live_defs] if not renamed: - self.fail(stmt, "A subgraph for a test do not have any output variable.") + self._fail(stmt, "A subgraph for a test do not have any output variable.") if renamed == [test.name]: - self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") - if_outputs = self.emit( + self._fail(stmt, f"Input and output cannot be the same {renamed!r}.") + if_outputs = self._emit( renamed, values.Op(self.default_opset, "If"), [test], - [thenAttr, elseAttr], + [then_attr, else_attr], ) if isinstance(if_outputs, ir.Value): if_outputs = [if_outputs] @@ -1188,23 +1181,23 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: # loop-variable if isinstance(loop_stmt, ast.For): if not isinstance(loop_stmt.target, ast.Name): - self.fail(loop_stmt, "For loop target must be a single variable.") + self._fail(loop_stmt, "For loop target must be a single variable.") python_loop_var_name = loop_stmt.target.id # iter iter = loop_stmt.iter assert isinstance(iter, ast.Call), "Loop bound not a call." if not isinstance(iter.func, ast.Name): - self.fail(loop_stmt, f"Unsupported loop bound {iter.func!r}.") + self._fail(loop_stmt, f"Unsupported loop bound {iter.func!r}.") if iter.func.id != "range": - self.fail( + self._fail( loop_stmt, "Unsupported loop bound, only function 'range' is allowed." ) if not iter.args or len(iter.args) != 1: - self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") + self._fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." o_loop_bound = self._translate_expr(iter.args[0], "loop_bound") onnx_cond_var = make_value( - self.generate_unique_name("cond_in"), + self._generate_unique_name("cond_in"), onnx_types.BOOL, self._source_of(loop_stmt), ) @@ -1214,7 +1207,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: elif isinstance(loop_stmt, ast.While): test = loop_stmt.test if not isinstance(test, ast.Name): - self.fail( + self._fail( loop_stmt, "Unexpected condition type {type(loop_stmt)!r} for a while loop, " "it should be 'while :'.", @@ -1222,7 +1215,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: python_loop_var_name = "infinite_loop" o_loop_bound = None i_cond_var = make_value( - self.generate_unique_name(test.id), + self._generate_unique_name(test.id), onnx_types.BOOL, self._source_of(loop_stmt), ) @@ -1232,7 +1225,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: # we need to go through all the instructions to see # which instruction defines the condition test.id else: - self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") + self._fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body exposed_uses = self.analyzer.exposed_uses(loop_stmt.body) vars_def_in_loop = self.analyzer.assigned_vars(loop_stmt.body) @@ -1247,7 +1240,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: # build loop_body self._enter_scope("loop_body", loop_stmt) - onnx_loop_var_name = self.generate_unique_name(python_loop_var_name) + onnx_loop_var_name = self._generate_unique_name(python_loop_var_name) onnx_loop_var = make_value( onnx_loop_var_name, onnx_types.INT64, @@ -1262,7 +1255,7 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: self._current_fn.append_parameter(i_cond_var) for pv in loop_state_vars: - onnx_var_name = self.generate_unique_name(pv) + onnx_var_name = self._generate_unique_name(pv) parameter = make_value(onnx_var_name, None, self._source_of(loop_stmt)) self._current_fn.append_parameter(parameter) self._bind( @@ -1281,18 +1274,18 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: # This instruction must be the last of the loop body. if isinstance(s, ast.If) and len(s.body) == 1 and isinstance(s.body[0], ast.Break): if not isinstance(s.test, ast.Name): - self.fail( + self._fail( s, f"Instruction break can be introduced with test but it must be " f"if : break. However condition is of type " f"{type(s.test)!r}.", ) if i != len(loop_stmt.body) - 1: - self.fail(s, "Instruction break must be the last one of the loop.") + self._fail(s, "Instruction break must be the last one of the loop.") current_scope = self._current_scope() if s.test.id not in current_scope: - self.fail( + self._fail( loop_stmt, f"Unable to find condition variable {s.test.id!r} in known " f"variables {list(current_scope)!r}.", @@ -1306,15 +1299,15 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: # Loop while current_scope = self._current_scope() if cond_while not in current_scope: - self.fail( + self._fail( loop_stmt, f"Unable to find condition variable {cond_while} in known " f"variables {list(current_scope)!r}.", ) onnx_cond_var = current_scope[cond_while].value - cond_out = self.emit1( - [self.generate_unique_name("cond_out")], + cond_out = self._emit1( + [self._generate_unique_name("cond_out")], values.Op(self.default_opset, operator_name), [condition_name or onnx_cond_var], [], @@ -1339,11 +1332,11 @@ def _translate_loop_stmt(self, loop_stmt: ast.For | ast.While) -> None: info = self._source_of(loop_stmt) def rename(x): - r = self.generate_unique_name(x) + r = self._generate_unique_name(x) return r onnx_output_names = [rename(x) for x in outputs] - loop_outputs = self.emit( + loop_outputs = self._emit( onnx_output_names, "Loop", inputs, @@ -1383,7 +1376,7 @@ def _translate_block( python_var_value = scope[python_var] break if python_var_value is None: - self.fail( + self._fail( stmts[0], f"ir.Value {python_var} is not assigned a value along a conditional " f"branch, known variables: {list(self._locals)}.", @@ -1451,7 +1444,7 @@ def _translate_function_signature_common( invalid = False for t in self.returntype: if not ta.is_valid_type(t): - self.warn( + self._warn( fn.returns, f"Unsupported type annotation for return value {t}.", )