Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions onnxscript/_internal/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
return live

def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
for s in reversed(block):
live_out = visit(s, live_out)
return live_out
Expand All @@ -167,28 +167,28 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
if isinstance(stmt, ast.If):
constant_cond = self.constant_if_condition(stmt)
if constant_cond is None:
live1 = visitBlock(stmt.body, live_out)
live2 = visitBlock(stmt.orelse, live_out)
live1 = visit_block(stmt.body, live_out)
live2 = visit_block(stmt.orelse, live_out)
return live1 | live2 | _used_vars(stmt.test)
elif constant_cond:
return visitBlock(stmt.body, live_out)
return visit_block(stmt.body, live_out)
else:
return visitBlock(stmt.orelse, live_out)
return visit_block(stmt.orelse, live_out)
if isinstance(stmt, ast.For):
p_loop_var = _get_loop_var(stmt, self._formatter)
prev = None
curr = live_out
while curr != prev:
prev = curr
curr = visitBlock(stmt.body, prev).difference({p_loop_var})
curr = visit_block(stmt.body, prev).difference({p_loop_var})
return curr
if isinstance(stmt, ast.While):
cond_vars = _used_vars(stmt.test)
prev = None
curr = live_out | cond_vars
while curr != prev:
prev = curr
curr = visitBlock(stmt.body, prev) | cond_vars
curr = visit_block(stmt.body, prev) | cond_vars
return curr
if isinstance(stmt, ast.Break):
# The following is sufficient for the current restricted usage, where
Expand Down Expand Up @@ -228,7 +228,7 @@ def exposed_uses(self, stmts: Sequence[ast.stmt]) -> set[str]:
(in the first statement). Hence x is included in the exposed_uses.
"""

def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
def visit_block(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
for stmt in reversed(block):
live_out = visit(stmt, live_out)
return live_out
Expand All @@ -243,13 +243,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
if isinstance(stmt, ast.If):
constant_cond = self.constant_if_condition(stmt)
if constant_cond is None:
live1 = visitBlock(stmt.body, live_out)
live2 = visitBlock(stmt.orelse, live_out)
live1 = visit_block(stmt.body, live_out)
live2 = visit_block(stmt.orelse, live_out)
return (live1 | live2) | _used_vars(stmt.test)
elif constant_cond:
return visitBlock(stmt.body, live_out)
return visit_block(stmt.body, live_out)
else:
return visitBlock(stmt.orelse, live_out)
return visit_block(stmt.orelse, live_out)
if ast_utils.is_print_call(stmt):
return live_out
if ast_utils.is_doc_string(stmt):
Expand All @@ -259,13 +259,13 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
# for loops that execute at least once.
loop_var_set = {_get_loop_var(stmt, self._formatter)}
used_after_loop = live_out.difference(loop_var_set)
used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set)
used_inside_loop = visit_block(stmt.body, set()).difference(loop_var_set)
used_in_loop_header = _used_vars(stmt.iter)
return used_inside_loop | used_in_loop_header | used_after_loop
if isinstance(stmt, ast.While):
# Analysis assumes loop may execute zero times. Results can be improved
# for loops that execute at least once.
used_inside_loop = visitBlock(stmt.body, set())
used_inside_loop = visit_block(stmt.body, set())
used_in_loop_header = _used_vars(stmt.test)
return used_inside_loop | used_in_loop_header | live_out
if isinstance(stmt, ast.Break):
Expand All @@ -281,7 +281,7 @@ def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")
)

return visitBlock(stmts, set())
return visit_block(stmts, set())

def outer_scope_variables(self, fun: ast.FunctionDef) -> set[str]:
"""Return the set of outer-scope variables used in a nested function.
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,15 @@ def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]:
argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is
castable, while X can serve as the target-type.
"""
return None if x is None or converter_.is_castable(x.name) else x
return None if x is None or converter_._is_castable(x.name) else x # pylint: disable=protected-access

def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
if x is None:
return None
if converter_.is_castable(x.name) and y is not None:
if converter_._is_castable(x.name) and y is not None: # pylint: disable=protected-access
# Polymorphic constant x is cast to the type of y:
x_cast = converter_.generate_unique_name(f"{x.name}_cast")
return converter_.emit1([x_cast], "CastLike", [x, y])
x_cast = converter_._generate_unique_name(f"{x.name}_cast") # pylint: disable=protected-access
return converter_._emit1([x_cast], "CastLike", [x, y]) # pylint: disable=protected-access
return x

return cast_inputs(get_type_info, cast_like, op_signature, args)
Loading
Loading