diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 9f55c0ca568..52ff3c604ac 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -14,6 +14,7 @@ import torch.fx from executorch.backends.arm.common.debug import get_node_debug_info from executorch.backends.arm.common.type import ensure_type +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -172,6 +173,30 @@ def create_node( return node +def create_shape_node( + graph: torch.fx.Graph, + op_target: EdgeOpOverload, + args: tuple = (), + kwargs: Optional[dict] = None, + from_node: Optional[torch.fx.Node] = None, +): + """Adds a shape node to 'graph'. + + graph.inserting_before/after() should be used before the call to decide + where to insert the node. + + """ + node = create_node( + graph=graph, + op_target=op_target, + args=args, + kwargs=kwargs, + from_node=from_node, + ) + node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE + return node + + def insert_q_dq_pair( graph: torch.fx.Graph, anchor: torch.fx.Node, diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index cff241d33cf..4a864125faf 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -1,24 +1,35 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from typing import Set, Type +import sympy # type: ignore + import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, + create_shape_node, get_first_fake_tensor, ) from executorch.backends.arm.tosa.mapping import TosaSpecialDtype -from executorch.backends.arm.tosa.utils import get_resize_parameters from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult class RewriteUpsamplePass(ArmPass): - """Rewrite upsample2d nodes to TOSA.RESIZE nodes.""" + """Rewrite upsample2d nodes to TOSA.RESIZE nodes with appropriate + parameters. + + For constant parameters, CONST_SHAPE nodes are inserted for the scale, + offset, and border values. For symbolic parameters, the parameters are + directly passed to the TOSA.RESIZE node, and we rely on subsequent passes to + handle them correctly once symbolic shapes are delegated by the TOSA + backend. + + """ targeted_ops = ( exir_ops.edge.aten.upsample_nearest2d.vec, @@ -27,6 +38,89 @@ class RewriteUpsamplePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + @staticmethod + def get_resize_parameters_1d( + input_size: int | torch.SymInt, + output_size: int | torch.SymInt, + align_corners: bool, + ): + """Compute resize coefficients for a single spatial dimension. + + Args: + input_size (int | torch.SymInt): Input size for the axis, possibly + symbolic. + output_size (int | torch.SymInt): Output size for the axis, possibly + symbolic. + align_corners (bool): Whether the resize should align the corner + pixels. + + Returns: + tuple[int, int, int, int]: Numerator, denominator, offset, and border + terms encoded as integers. + + Raises: + RuntimeError: If symbolic shapes are used with ``align_corners`` or if + the computed ratio or border is not constant. + + """ + # We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky. 
+        if align_corners: +            if (not isinstance(input_size, int)) or (not isinstance(output_size, int)): +                raise RuntimeError( +                    "We do not support align_corners=True for symbolic shapes." +                ) + +        # SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead +        input_size = ( +            input_size.node._expr +            if isinstance(input_size, torch.SymInt) +            else input_size +        ) +        output_size = ( +            output_size.node._expr +            if isinstance(output_size, torch.SymInt) +            else output_size +        ) +        if align_corners and input_size > 1 and output_size > 1: +            scale_n = output_size - 1 +        else: +            scale_n = output_size +        if align_corners and input_size > 1 and output_size > 1: +            scale_d = input_size - 1 +        else: +            scale_d = input_size +        ratio = scale_n / scale_d +        if not sympy.sympify(ratio).is_constant(): +            raise RuntimeError( +                "Resize requires a constant ratio: " + str(ratio) + " is not constant!" +            ) +        gcd = sympy.gcd(scale_n, scale_d) +        scale_n = 2 * scale_n // gcd +        scale_d = 2 * scale_d // gcd +        # These should always be whole integers, based on the above calculations +        scale_n = int(scale_n.evalf()) +        scale_d = int(scale_d.evalf()) + +        if align_corners: +            offset = 0 +        else: +            # Half pixel centers so input and output sampling positions are offset by 1/2 pixel. +            offset = scale_d // 2 - scale_n // 2 + +        # Calculate border to maintain the correct output size. +        # Note that this should always result in a constant value, as the ratio is constant. +        border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset + +        if not sympy.sympify(border).is_constant(): +            raise RuntimeError( +                "Resize requires a constant border: " +                + str(border) +                + " is not constant!"
+ ) + + border = int(sympy.sympify(border).evalf()) + return scale_n, scale_d, offset, border + def call(self, graph_module): modified = False for node in graph_module.graph.nodes: @@ -39,14 +133,65 @@ def call(self, graph_module): resize_mode = "bilinear" else: x, output_size, scale_factors = node.args + # As per https://docs.pytorch.org/docs/stable/generated/torch.nn.Upsample.html + # align_corners is not valid for nearest mode. Default to False. align_corners = False resize_mode = "nearest" + input_size_yx = node.args[0].meta["val"].shape[2:] + output_size_yx = node.meta["val"].shape[2:] + + scale_y_n, scale_y_d, offset_y, border_y = ( + RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[0], output_size_yx[0], align_corners + ) + ) + scale_x_n, scale_x_d, offset_x, border_x = ( + RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[1], output_size_yx[1], align_corners + ) + ) + + scales = [ + scale_y_n, + scale_y_d, + scale_x_n, + scale_x_d, + ] with graph_module.graph.inserting_before(node): + if all(isinstance(s, int) for s in scales): + scale = create_shape_node( + graph_module.graph, + op_target=exir_ops.backend.tosa.CONST_SHAPE.default, + args=(scales,), + kwargs={}, + from_node=node, + ) + else: + scale = scales + offset = [offset_y, offset_x] + if all(isinstance(o, int) for o in offset): + offset = create_shape_node( + graph_module.graph, + op_target=exir_ops.backend.tosa.CONST_SHAPE.default, + args=(offset,), + kwargs={}, + from_node=node, + ) + border = [border_y, border_x] + if all(isinstance(b, int) for b in border): + border = create_shape_node( + graph_module.graph, + op_target=exir_ops.backend.tosa.CONST_SHAPE.default, + args=(border,), + kwargs={}, + from_node=node, + ) + tosa_resize_node = create_node( graph_module.graph, op_target=exir_ops.backend.tosa.RESIZE.default, - args=(x, output_size, align_corners, scale_factors), + args=(x, scale, offset, border), kwargs={"resize_mode": resize_mode}, from_node=node, 
inherit_qparams=True, @@ -57,18 +202,8 @@ def call(self, graph_module): if ( input_dtype == torch.int8 or input_dtype == torch.int16 ) and resize_mode == "bilinear": - input_size = get_first_fake_tensor(x).shape - input_size_xy = input_size[2:] - output_size = get_first_fake_tensor(node).shape - output_size_xy = output_size[2:] - scale_n_yx, _, _, _ = get_resize_parameters( - input_size_xy=input_size_xy, - output_size_xy=output_size_xy, - resize_mode=1, - align_corners=align_corners, - ) output_dtype = get_first_fake_tensor(node).dtype - output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) + output_scale = float(1 / (scale_y_n * scale_x_n)) with graph_module.graph.inserting_after(tosa_resize_node): rescale_node = create_node( graph_module.graph, diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 52cbf924fc4..0f32fbb52df 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -436,6 +436,9 @@ def _propagate_dim_order_to_shape_args(self, node: torch.fx.Node) -> None: raise RuntimeError( f"Conflicting dim orders {arg.meta['tosa_dim_order']} and {dim_order} for shape node {arg.name}" ) + if node.target == exir_ops.backend.tosa.RESIZE.default: + # RESIZE's shape input is expected to be in HW order, so we need to override the dim order to be the identity for it regardless of the user node's dim order. 
+ dim_order = tuple(range(len(arg.meta["val"]))) arg.meta["tosa_dim_order"] = dim_order self._propagate_dim_order_to_shape_args(arg) diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index c875ebe01e2..036c32a01da 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -15,11 +15,8 @@ ) from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import get_resize_parameters @register_node_visitor @@ -36,81 +33,12 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - validate_num_inputs(self.target, inputs, [3, 4]) - supported_input_dtypes = [ - ts.DType.INT8, - ts.DType.FP16, - ts.DType.FP32, - ts.DType.BF16, - ] - if self.tosa_spec.support_extension("int16"): - supported_input_dtypes.append(ts.DType.INT16) - if self.tosa_spec.support_extension("bf16"): - supported_input_dtypes.append(ts.DType.BF16) - validate_valid_dtype( - self.target, - [inputs[0]], - supported_input_dtypes, - self.tosa_spec, - ) - supported_output_dtypes = [ts.DType.FP16, ts.DType.FP32, ts.DType.BF16] + x, scales, offset, border = inputs + validate_num_inputs(self.target, inputs, [4]) if node.kwargs.get("resize_mode") == "bilinear": resize_mode = ts.ResizeMode.BILINEAR - align_corners = bool(node.args[2]) - supported_output_dtypes.append(ts.DType.INT32) - if self.tosa_spec.support_extension("int16"): - supported_output_dtypes.append(ts.DType.INT48) else: resize_mode = ts.ResizeMode.NEAREST - align_corners = False - validate_same_dtype(self.target, [inputs[0], output], ts) - supported_output_dtypes.append(ts.DType.INT8) - if self.tosa_spec.support_extension("int16"): - supported_output_dtypes.append(ts.DType.INT16) - validate_valid_dtype( - self.target, [output], supported_output_dtypes, self.tosa_spec - 
) - # tosa_shape output is NHWC, take HW - input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ - 1:3 - ] - output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3] - - # Align corners shouldn't make a difference for nearest upsampling. We set to False so - # half pixel centers are used for resize parameter logic. - scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( - input_size_yx, output_size_yx, resize_mode, align_corners=align_corners - ) - - def in_int16_range(x): - return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) - - if not in_int16_range(scale_n_yx): - raise ValueError("scale_n_yx is out of the int16 range") - if not in_int16_range(scale_d_yx): - raise ValueError("scale_d_yx is out of the int16 range") - if not in_int16_range(border_yx): - raise ValueError("border_yx is out of the int16 range") - - scale_n_vals = [int(v) for v in scale_n_yx.tolist()] - scale_d_vals = [int(v) for v in scale_d_yx.tolist()] - scales = [ - scale_n_vals[0], - scale_d_vals[0], - scale_n_vals[1], - scale_d_vals[1], - ] - scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, output.name + "_scales" - ) - offset = [int(v) for v in offset_yx.tolist()] - offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, output.name + "_offset" - ) - border = [int(v) for v in border_yx.tolist()] - border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, output.name + "_border" - ) attr = ts.TosaSerializerAttribute() attr.ResizeAttribute(resize_mode) @@ -119,10 +47,10 @@ def in_int16_range(x): tosa_graph, ts.Op.RESIZE, [ - inputs[0].name, - scales_tensor.name, - offset_tensor.name, - border_tensor.name, + x.name, + scales.name, + offset.name, + border.name, ], [output.name], attr, diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index 606440eff05..c4e720cd849 100644 --- a/backends/arm/tosa/dialect/ops/resize.py 
+++ b/backends/arm/tosa/dialect/ops/resize.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Literal, Optional +from typing import Literal import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError @@ -13,55 +13,90 @@ get_context_spec, TosaSpecification, ) -from executorch.exir.dialects._ops import ops as exir_ops -# Add kwarg instead? -@register_fake_tosa_op( - "RESIZE(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, str resize_mode) -> Tensor", # schema - TosaSpecification.all_versions_and_profiles(), # target TOSA specifications -) -def RESIZE( - x: torch.Tensor, - output_size: list[int] | None = None, - align_corners: Optional[bool] = False, - scale_factors: list[float] | None = None, - *, - resize_mode: Literal["nearest", "bilinear"], -) -> torch.Tensor: - tosa_spec = get_context_spec() - +def _validate_resize_mode(resize_mode: str) -> None: if resize_mode not in ("nearest", "bilinear"): raise TosaValueError(f"Unsupported resize mode {resize_mode} for TOSA RESIZE") - if x.dtype == torch.int8: + + +def _get_output_dtype( + dtype: torch.dtype, tosa_spec: TosaSpecification, resize_mode: str +) -> torch.dtype: + if dtype == torch.int8: if not tosa_spec.support_integer(): raise TosaValueError( f"TOSA spec {tosa_spec} doesn't support integers", op="RESIZE" ) - bilinear = resize_mode == "bilinear" - output_dtype = torch.int32 if bilinear else torch.int8 - elif x.dtype == torch.int16: + output_dtype = torch.int8 if resize_mode == "nearest" else torch.int32 + elif dtype == torch.int16: if not tosa_spec.support_integer(): raise TosaValueError( f"Context TOSA spec {tosa_spec} doesn't support int16", op="RESIZE" ) - output_dtype = x.dtype - elif x.dtype in (torch.float16, torch.float32, torch.bfloat16): + output_dtype = dtype + elif dtype in (torch.float16, torch.float32, torch.bfloat16): if not 
tosa_spec.support_float(): raise TosaValueError( f"TOSA spec {tosa_spec} doesn't support float", op="RESIZE" ) - if x.dtype == torch.bfloat16 and not tosa_spec.support_extension("bf16"): + if dtype == torch.bfloat16 and not tosa_spec.support_extension("bf16"): raise TosaValueError( f"TOSA spec {tosa_spec} doesn't support bf16", op="RESIZE" ) - output_dtype = x.dtype + output_dtype = dtype else: - raise TosaValueError(f"Unsupported input dtype {x.dtype} for TOSA RESIZE") + raise TosaValueError(f"Unsupported input dtype {dtype}", op="RESIZE") + return output_dtype + + +def _validate_resize_parameters(scale, border): + def in_int16_range(values): + return all((x >= -(2**15)) and (x <= 2**15 - 1) for x in values) + + if not in_int16_range(scale): + raise TosaValueError("scale is out of the int16 range", op="RESIZE") + if not in_int16_range(border): + raise TosaValueError("border is out of the int16 range", op="RESIZE") + + +@register_fake_tosa_op( + "RESIZE(Tensor input, SymInt[4] scale_factors, SymInt[2] offset, SymInt[2] border, *, str resize_mode) -> Tensor", # schema + TosaSpecification.all_versions_and_profiles(), # target TOSA specifications +) +def RESIZE( + x: torch.Tensor, + scale: list[torch.SymInt], + offset: list[torch.SymInt], + border: list[torch.SymInt], + *, + resize_mode: Literal["nearest", "bilinear"], +) -> torch.Tensor: + tosa_spec = get_context_spec() + + if x.dim() != 4: + raise TosaValueError( + f"Input tensor must be 4D, but got {x.dim()}D", op="RESIZE" + ) + _validate_resize_mode(resize_mode) + _validate_resize_parameters(scale, border) + output_dtype = _get_output_dtype(x.dtype, tosa_spec, resize_mode) - # Does it matter which one to use for fake tracing? 
- fake_aten_tensor = exir_ops.edge.aten.upsample_nearest2d.vec( - x, output_size, scale_factors - ) + input_shape = x.shape + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale + offset_y, offset_x = offset + border_y, border_x = border + H, W = input_shape[2], input_shape[3] + # RESIZE first upscales the input by an integer value, to "upscale space". + H_upscaled = (H - 1) * scale_y_n + # offset and border are provided in this scale, therefore adjust for these while in this space. + H_shifted = H_upscaled - offset_y + border_y + # Then, complete the RESIZE by downscaling with another integer value, approximating multiplication with a fraction. + OH = (H_shifted // scale_y_d) + 1 + # Mirror the same computation horizontally for the output width. + W_upscaled = (W - 1) * scale_x_n + W_shifted = W_upscaled - offset_x + border_x + OW = (W_shifted // scale_x_d) + 1 + fake_aten_tensor = torch.empty(size=(*input_shape[:2], OH, OW), dtype=output_dtype) - return fake_aten_tensor.to(output_dtype) + return fake_aten_tensor diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 6666e422039..602a9548791 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -9,8 +9,6 @@ import numpy as np -import sympy  # type: ignore - import torch import tosa_serializer as ts @@ -186,116 +184,3 @@ def tosa_shape(shape, dim_order): [-1 if isinstance(d, torch.SymInt) else d for d in reordered] ) return list(removed_symints) - - -def get_resize_parameters_1d( - input_size: int | torch.SymInt, - output_size: int | torch.SymInt, - resize_mode: int, - align_corners: bool, -): - """Compute resize coefficients for a single spatial dimension. - - Args: - input_size (int | torch.SymInt): Input size for the axis, possibly - symbolic. - output_size (int | torch.SymInt): Output size for the axis, possibly - symbolic. - resize_mode (int): Target resize mode defined by TOSA. - align_corners (bool): Whether the resize should align the corner - pixels.
- - Returns: - tuple[int, int, int, int]: Numerator, denominator, offset, and border - terms encoded as integers. - - Raises: - RuntimeError: If symbolic shapes are used with ``align_corners`` or if - the computed ratio or border is not constant. - - """ - # We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky. - if align_corners: - if (not isinstance(input_size, int)) or (not isinstance(output_size, int)): - raise RuntimeError( - "We do not support align_corners=True for symbolic shapes." - ) - - # SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead - input_size = ( - input_size.node._expr if isinstance(input_size, torch.SymInt) else input_size - ) - output_size = ( - output_size.node._expr if isinstance(output_size, torch.SymInt) else output_size - ) - if align_corners and input_size > 1 and output_size > 1: - scale_n = output_size - 1 - else: - scale_n = output_size - if align_corners and input_size > 1 and output_size > 1: - scale_d = input_size - 1 - else: - scale_d = input_size - ratio = scale_n / scale_d - if not sympy.sympify(ratio).is_constant(): - raise RuntimeError( - "Resize requires a constant ratio: " + str(ratio) + " is not constant!" - ) - gcd = sympy.gcd(scale_n, scale_d) - scale_n = 2 * scale_n // gcd - scale_d = 2 * scale_d // gcd - # These should always be whole integers, based on the above calculations - scale_n = int(scale_n.evalf()) - scale_d = int(scale_d.evalf()) - - if align_corners: - offset = 0 - else: - # Half pixel centers so input and output sampling positions are offset by 1/2 pixel. - offset = scale_d // 2 - scale_n // 2 - - # Calculate border to maintain the correct the output size. - # Note that this should always result in a constant value, as the ratio is constant. 
- border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset - - if not sympy.sympify(border).is_constant(): - raise RuntimeError( - "Resize requires a constant border: " + str(border) + " is not constant!" - ) - - border = int(sympy.sympify(border).evalf()) - return scale_n, scale_d, offset, border - - -def get_resize_parameters( - input_size_xy: tuple[int | torch.SymInt, int | torch.SymInt], - output_size_xy: tuple[int | torch.SymInt, int | torch.SymInt], - resize_mode: int, - align_corners: bool, -) -> tuple[torch.IntTensor, ...]: - """Calculate 2D resize parameters for TOSA emission. - - Args: - input_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height - and width of the input tensor. - output_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height - and width of the output tensor. - resize_mode (int): TOSA resize mode used for coefficient generation. - align_corners (bool): Whether to align corner pixels between input and - output. - - Returns: - tuple[torch.IntTensor, ...]: Four-element tuple of tensors describing - the scale numerator, scale denominator, offset, and border for Y - and X dimensions. - - """ - # Get the parameters for each dimension independently - y_params = get_resize_parameters_1d( - input_size_xy[0], output_size_xy[0], resize_mode, align_corners - ) - x_params = get_resize_parameters_1d( - input_size_xy[1], output_size_xy[1], resize_mode, align_corners - ) - # Combine them together, so we return four 2-element tensors (scale_n, scale_d, offset, border) - return tuple(map(torch.IntTensor, zip(y_params, x_params)))