diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 80de190fedf..b0c26e3101f 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -329,6 +329,11 @@ - arg_meta: null kernel_name: impl::generic::quantized_add_asym8uxasym8u_asym8u_per_tensor_out +- func: cadence::quantized_mul.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_mul_per_tensor_out + - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 601d54fe49b..8127b180699 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -376,6 +376,10 @@ def register_fake( "quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)" ) +lib.define( + "quantized_mul.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, " + "int Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)" +) lib.define( "quantized_add_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " "float out_scale, int out_zero_point) -> (Tensor Z)" @@ -582,6 +586,10 @@ def register_fake( "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) 
out) -> Tensor(a!)" ) +lib.define( + "quantized_mul.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, " + "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_add_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " "float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" @@ -892,6 +900,22 @@ def quantized_add_meta( return X.new_empty(out_size, dtype=X.dtype) +@register_fake("cadence::quantized_mul.per_tensor") +def quantized_mul_per_tensor_meta( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + + out_size = torch.broadcast_shapes(X.size(), Y.size()) + return X.new_empty(out_size, dtype=X.dtype) + + @register_fake("cadence::quantized_add.per_tensor") def quantized_add_per_tensor_meta( X: torch.Tensor, diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index ed8b3ca60ae..2a6839bd5c3 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -403,6 +403,67 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor( ) +@impl_tracked(m, "quantized_mul.per_tensor") +def quantized_mul_per_tensor( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + """ + Multiplies two quantized tensors and returns another quantized tensor. 
The intuition + is that we want dequant(out) ~= dequant(X) * dequant(Y) + + If we do that math, we get + out_scale(out - out_zero_point) = X_scale(X - X_zero_point) * Y_scale(Y - Y_zero_point) + + Rearranging, we get + out = (X_scale(X - X_zero_point) * Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point + + Args: + - X: The first operand + - X_scale: The ratio between the sizes of X's floating point and quantized + ranges + - X_zero_point: The quantized mapping of zero for X + - Y: The second operand + - Y_scale: The ratio between the sizes of Y's floating point and quantized + ranges + - Y_zero_point: The quantized mapping of zero for Y + - out_scale: The ratio between the sizes of the output's floating point and + quantized ranges + - out_zero_point: The quantized mapping of zero for the output + """ + supported_dtypes = [torch.int8, torch.uint8] + if X.dtype != Y.dtype: + raise ValueError("X and Y dtypes need to match") + + dtype = X.dtype + if dtype not in supported_dtypes: + raise ValueError( + f"X and Y dtypes need to be in {supported_dtypes}. 
Got {dtype}" + ) + + if dtype == torch.uint8: + X = X.to(torch.int8) + Y = Y.to(torch.int8) + + dequant_X = X_scale * (X - X_zero_point) + dequant_Y = Y_scale * (Y - Y_zero_point) + + return quantize_per_tensor( + dequant_X * dequant_Y, + out_scale, + out_zero_point, + torch.iinfo(dtype).min, + torch.iinfo(dtype).max, + dtype, + ) + + def quantized_linear_common( src: torch.Tensor, weight: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 936aad4e585..222fb27bfcd 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -2715,6 +2715,48 @@ def test_quantized_add(self) -> None: out_zero_point, ) + @expand( + [ + # X=5, zp=4 → dequant=0.8*(5-4)=0.8; Y=5, zp=4 → dequant=0.8 + # mul=0.64; quantize: round(0.64/0.8)+4=1+4=5 + ("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 5, torch.int8), + ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 5, torch.uint8), + ] + ) + def test_quantized_mul( + self, + name: str, + X: int, + X_scale: float, + X_zero_point: int, + Y: int, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, + expected_value: int, + dtype: torch.dtype, + ) -> None: + X_tensor = torch.tensor([X], dtype=dtype) + Y_tensor = torch.tensor([Y], dtype=dtype) + expected_output = torch.tensor([expected_value], dtype=dtype) + + output = torch.ops.cadence.quantized_mul.per_tensor( + X_tensor, + X_scale, + X_zero_point, + Y_tensor, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + ) + + self.assertTrue( + torch.equal(output, expected_output), + f"Values don't match in {name}: got {output}, expected {expected_output}", + ) + def test_requantize(self) -> None: # Test requantize (default variant), just to make sure it runs since wrapper around per_tensor variant input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.int8) diff --git 
a/backends/cadence/generic/operators/op_quantized_mul.cpp b/backends/cadence/generic/operators/op_quantized_mul.cpp index 89fb2a5250d..fe3e44e6687 100644 --- a/backends/cadence/generic/operators/op_quantized_mul.cpp +++ b/backends/cadence/generic/operators/op_quantized_mul.cpp @@ -72,6 +72,43 @@ Tensor& quantized_mul_out( return out; } +Tensor& quantized_mul_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { +#define typed_quantized_mul_per_tensor(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_mul_<ctype>( \ + X, \ + static_cast<float>(X_scale), \ + static_cast<int32_t>(X_zero_point), \ + Y, \ + static_cast<float>(Y_scale), \ + static_cast<int32_t>(Y_zero_point), \ + static_cast<float>(out_scale), \ + static_cast<int32_t>(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_mul_per_tensor) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_mul_per_tensor + return out; +} + // Generate kernels that perform elementwise arithmetic on a quantized tensor, // and a scalar.
#define DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \ diff --git a/backends/cadence/generic/operators/op_quantized_mul.h b/backends/cadence/generic/operators/op_quantized_mul.h index 7ca8b2f1db0..2bae0012cf8 100644 --- a/backends/cadence/generic/operators/op_quantized_mul.h +++ b/backends/cadence/generic/operators/op_quantized_mul.h @@ -27,6 +27,18 @@ ::executorch::aten::Tensor& quantized_mul_out( int64_t out_zero_point, ::executorch::aten::Tensor& out); +::executorch::aten::Tensor& quantized_mul_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + ::executorch::aten::Tensor& quantized_mul_Scalar_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& X,