Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch.fx
from executorch.backends.arm.common.debug import get_node_debug_info
from executorch.backends.arm.common.type import ensure_type
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
Expand Down Expand Up @@ -172,6 +173,30 @@ def create_node(
return node


def create_shape_node(
    graph: torch.fx.Graph,
    op_target: EdgeOpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    from_node: Optional[torch.fx.Node] = None,
):
    """Insert a node into 'graph' and tag it as carrying TOSA shape data.

    The insertion point is chosen by the caller via
    graph.inserting_before()/graph.inserting_after() prior to this call.

    """
    shape_node = create_node(
        graph=graph,
        op_target=op_target,
        args=args,
        kwargs=kwargs,
        from_node=from_node,
    )
    # Tag the node so later lowering treats it as the TOSA SHAPE dtype.
    shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
    return shape_node


def insert_q_dq_pair(
graph: torch.fx.Graph,
anchor: torch.fx.Node,
Expand Down
165 changes: 150 additions & 15 deletions backends/arm/_passes/rewrite_upsample.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import sympy # type: ignore

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
create_shape_node,
get_first_fake_tensor,
)
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.arm.tosa.utils import get_resize_parameters
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RewriteUpsamplePass(ArmPass):
"""Rewrite upsample2d nodes to TOSA.RESIZE nodes."""
"""Rewrite upsample2d nodes to TOSA.RESIZE nodes with appropriate
parameters.

For constant parameters, CONST_SHAPE nodes are inserted for the scale,
offset, and border values. For symbolic parameters, the parameters are
directly passed to the TOSA.RESIZE node, and we rely on subsequent passes to
handle them correctly once symbolic shapes are delegated by the TOSA
backend.

"""

targeted_ops = (
exir_ops.edge.aten.upsample_nearest2d.vec,
Expand All @@ -27,6 +38,89 @@ class RewriteUpsamplePass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

@staticmethod
def get_resize_parameters_1d(
input_size: int | torch.SymInt,
output_size: int | torch.SymInt,
align_corners: bool,
):
"""Compute resize coefficients for a single spatial dimension.

Args:
input_size (int | torch.SymInt): Input size for the axis, possibly
symbolic.
output_size (int | torch.SymInt): Output size for the axis, possibly
symbolic.
align_corners (bool): Whether the resize should align the corner
pixels.

Returns:
tuple[int, int, int, int]: Numerator, denominator, offset, and border
terms encoded as integers.

Raises:
RuntimeError: If symbolic shapes are used with ``align_corners`` or if
the computed ratio or border is not constant.

"""
# We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky.
if align_corners:
if (not isinstance(input_size, int)) or (not isinstance(output_size, int)):
raise RuntimeError(
"We do not support align_corners=True for symbolic shapes."
)

# SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead
input_size = (
input_size.node._expr
if isinstance(input_size, torch.SymInt)
else input_size
)
output_size = (
output_size.node._expr
if isinstance(output_size, torch.SymInt)
else output_size
)
if align_corners and input_size > 1 and output_size > 1:
scale_n = output_size - 1
else:
scale_n = output_size
if align_corners and input_size > 1 and output_size > 1:
scale_d = input_size - 1
else:
scale_d = input_size
ratio = scale_n / scale_d
if not sympy.sympify(ratio).is_constant():
raise RuntimeError(
"Resize requires a constant ratio: " + str(ratio) + " is not constant!"
)
gcd = sympy.gcd(scale_n, scale_d)
scale_n = 2 * scale_n // gcd
scale_d = 2 * scale_d // gcd
# These should always be whole integers, based on the above calculations
scale_n = int(scale_n.evalf())
scale_d = int(scale_d.evalf())

if align_corners:
offset = 0
else:
# Half pixel centers so input and output sampling positions are offset by 1/2 pixel.
offset = scale_d // 2 - scale_n // 2

# Calculate border to maintain the correct the output size.
# Note that this should always result in a constant value, as the ratio is constant.
border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

if not sympy.sympify(border).is_constant():
raise RuntimeError(
"Resize requires a constant border: "
+ str(border)
+ " is not constant!"
)

border = int(sympy.sympify(border).evalf())
return scale_n, scale_d, offset, border

def call(self, graph_module):
modified = False
for node in graph_module.graph.nodes:
Expand All @@ -39,14 +133,65 @@ def call(self, graph_module):
resize_mode = "bilinear"
else:
x, output_size, scale_factors = node.args
# As per https://docs.pytorch.org/docs/stable/generated/torch.nn.Upsample.html
# align_corners is not valid for nearest mode. Default to False.
align_corners = False
resize_mode = "nearest"

input_size_yx = node.args[0].meta["val"].shape[2:]
output_size_yx = node.meta["val"].shape[2:]

scale_y_n, scale_y_d, offset_y, border_y = (
RewriteUpsamplePass.get_resize_parameters_1d(
input_size_yx[0], output_size_yx[0], align_corners
)
)
scale_x_n, scale_x_d, offset_x, border_x = (
RewriteUpsamplePass.get_resize_parameters_1d(
input_size_yx[1], output_size_yx[1], align_corners
)
)

scales = [
scale_y_n,
scale_y_d,
scale_x_n,
scale_x_d,
]
with graph_module.graph.inserting_before(node):
if all(isinstance(s, int) for s in scales):
scale = create_shape_node(
graph_module.graph,
op_target=exir_ops.backend.tosa.CONST_SHAPE.default,
args=(scales,),
kwargs={},
from_node=node,
)
else:
scale = scales
offset = [offset_y, offset_x]
if all(isinstance(o, int) for o in offset):
offset = create_shape_node(
graph_module.graph,
op_target=exir_ops.backend.tosa.CONST_SHAPE.default,
args=(offset,),
kwargs={},
from_node=node,
)
border = [border_y, border_x]
if all(isinstance(b, int) for b in border):
border = create_shape_node(
graph_module.graph,
op_target=exir_ops.backend.tosa.CONST_SHAPE.default,
args=(border,),
kwargs={},
from_node=node,
)

tosa_resize_node = create_node(
graph_module.graph,
op_target=exir_ops.backend.tosa.RESIZE.default,
args=(x, output_size, align_corners, scale_factors),
args=(x, scale, offset, border),
kwargs={"resize_mode": resize_mode},
from_node=node,
inherit_qparams=True,
Expand All @@ -57,18 +202,8 @@ def call(self, graph_module):
if (
input_dtype == torch.int8 or input_dtype == torch.int16
) and resize_mode == "bilinear":
input_size = get_first_fake_tensor(x).shape
input_size_xy = input_size[2:]
output_size = get_first_fake_tensor(node).shape
output_size_xy = output_size[2:]
scale_n_yx, _, _, _ = get_resize_parameters(
input_size_xy=input_size_xy,
output_size_xy=output_size_xy,
resize_mode=1,
align_corners=align_corners,
)
output_dtype = get_first_fake_tensor(node).dtype
output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))
output_scale = float(1 / (scale_y_n * scale_x_n))
with graph_module.graph.inserting_after(tosa_resize_node):
rescale_node = create_node(
graph_module.graph,
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ def _propagate_dim_order_to_shape_args(self, node: torch.fx.Node) -> None:
raise RuntimeError(
f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}"
)
if node.target == exir_ops.backend.tosa.RESIZE.default:
# RESIZE's shape input is expected to be in HW order, so we need to override the dim order to be the identity for it regardless of the user node's dim order.
dim_order = tuple(range(len(arg.meta["val"])))
arg.meta["tosa_dim_order"] = dim_order
self._propagate_dim_order_to_shape_args(arg)

Expand Down
84 changes: 6 additions & 78 deletions backends/arm/operators/op_tosa_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.utils import get_resize_parameters


@register_node_visitor
Expand All @@ -36,81 +33,12 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, [3, 4])
supported_input_dtypes = [
ts.DType.INT8,
ts.DType.FP16,
ts.DType.FP32,
ts.DType.BF16,
]
if self.tosa_spec.support_extension("int16"):
supported_input_dtypes.append(ts.DType.INT16)
if self.tosa_spec.support_extension("bf16"):
supported_input_dtypes.append(ts.DType.BF16)
validate_valid_dtype(
self.target,
[inputs[0]],
supported_input_dtypes,
self.tosa_spec,
)
supported_output_dtypes = [ts.DType.FP16, ts.DType.FP32, ts.DType.BF16]
x, scales, offset, border = inputs
validate_num_inputs(self.target, inputs, [4])
if node.kwargs.get("resize_mode") == "bilinear":
resize_mode = ts.ResizeMode.BILINEAR
align_corners = bool(node.args[2])
supported_output_dtypes.append(ts.DType.INT32)
if self.tosa_spec.support_extension("int16"):
supported_output_dtypes.append(ts.DType.INT48)
else:
resize_mode = ts.ResizeMode.NEAREST
align_corners = False
validate_same_dtype(self.target, [inputs[0], output], ts)
supported_output_dtypes.append(ts.DType.INT8)
if self.tosa_spec.support_extension("int16"):
supported_output_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
self.target, [output], supported_output_dtypes, self.tosa_spec
)
# tosa_shape output is NHWC, take HW
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
1:3
]
output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3]

# Align corners shouldn't make a difference for nearest upsampling. We set to False so
# half pixel centers are used for resize parameter logic.
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
input_size_yx, output_size_yx, resize_mode, align_corners=align_corners
)

def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

if not in_int16_range(scale_n_yx):
raise ValueError("scale_n_yx is out of the int16 range")
if not in_int16_range(scale_d_yx):
raise ValueError("scale_d_yx is out of the int16 range")
if not in_int16_range(border_yx):
raise ValueError("border_yx is out of the int16 range")

scale_n_vals = [int(v) for v in scale_n_yx.tolist()]
scale_d_vals = [int(v) for v in scale_d_yx.tolist()]
scales = [
scale_n_vals[0],
scale_d_vals[0],
scale_n_vals[1],
scale_d_vals[1],
]
scales_tensor = tosa_graph.addConst(
[len(scales)], ts.DType.SHAPE, scales, output.name + "_scales"
)
offset = [int(v) for v in offset_yx.tolist()]
offset_tensor = tosa_graph.addConst(
[len(offset)], ts.DType.SHAPE, offset, output.name + "_offset"
)
border = [int(v) for v in border_yx.tolist()]
border_tensor = tosa_graph.addConst(
[len(border)], ts.DType.SHAPE, border, output.name + "_border"
)
attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(resize_mode)

Expand All @@ -119,10 +47,10 @@ def in_int16_range(x):
tosa_graph,
ts.Op.RESIZE,
[
inputs[0].name,
scales_tensor.name,
offset_tensor.name,
border_tensor.name,
x.name,
scales.name,
offset.name,
border.name,
],
[output.name],
attr,
Expand Down
Loading
Loading