Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@
- arg_meta: null
kernel_name: impl::generic::quantized_add_asym8uxasym8u_asym8u_per_tensor_out

# Per-tensor overload of quantized_mul: scales and zero points are passed as
# scalars instead of tensors; dispatches to the generic C++ kernel.
- func: cadence::quantized_mul.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::generic::quantized_mul_per_tensor_out

- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
24 changes: 24 additions & 0 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ def register_fake(
"quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
"Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
# Functional schema for the per-tensor variant of quantized_mul: the scales and
# zero points are Python scalars rather than tensors.
lib.define(
    "quantized_mul.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
    "int Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
"quantized_add_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
"float out_scale, int out_zero_point) -> (Tensor Z)"
Expand Down Expand Up @@ -582,6 +586,10 @@ def register_fake(
"quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
"Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
# Out variant of quantized_mul.per_tensor; must stay in sync with the
# functional schema and with backends/cadence/aot/functions.yaml.
lib.define(
    "quantized_mul.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
    "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_add_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
"float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
Expand Down Expand Up @@ -892,6 +900,22 @@ def quantized_add_meta(
return X.new_empty(out_size, dtype=X.dtype)


@register_fake("cadence::quantized_mul.per_tensor")
def quantized_mul_per_tensor_meta(
    X: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_scale: float,
    Y_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    """Shape/dtype meta kernel: output is broadcast(X, Y)-shaped with X's dtype.

    The quantization scalars do not affect the output shape, so only the two
    tensor operands matter here.
    """
    broadcast_size = torch.broadcast_shapes(X.size(), Y.size())
    return X.new_empty(broadcast_size, dtype=X.dtype)


@register_fake("cadence::quantized_add.per_tensor")
def quantized_add_per_tensor_meta(
X: torch.Tensor,
Expand Down
61 changes: 61 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,67 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor(
)


@impl_tracked(m, "quantized_mul.per_tensor")
def quantized_mul_per_tensor(
    X: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_scale: float,
    Y_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    """
    Multiplies two quantized tensors and returns another quantized tensor. The intuition
    is that we want dequant(out) ~= dequant(X) * dequant(Y)

    If we do that math, we get
    out_scale(out - out_zero_point) = X_scale(X - X_zero_point) * Y_scale(Y - Y_zero_point)

    Rearranging, we get
    out = (X_scale(X - X_zero_point) * Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point

    Args:
    - X: The first operand
    - X_scale: The ratio between the sizes of X's floating point and quantized
        ranges
    - X_zero_point: The quantized mapping of zero for X
    - Y: The second operand
    - Y_scale: The ratio between the sizes of Y's floating point and quantized
        ranges
    - Y_zero_point: The quantized mapping of zero for Y
    - out_scale: The ratio between the sizes of the output's floating point and
        quantized ranges
    - out_zero_point: The quantized mapping of zero for the output

    Raises:
        ValueError: if X and Y dtypes differ, or are not int8/uint8.
    """
    supported_dtypes = [torch.int8, torch.uint8]
    if X.dtype != Y.dtype:
        raise ValueError("X and Y dtypes need to match")

    dtype = X.dtype
    if dtype not in supported_dtypes:
        raise ValueError(
            f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}"
        )

    # Widen to int32 before subtracting the zero points so the integer
    # arithmetic cannot wrap: a uint8 value > 127 reinterpreted as int8 would
    # flip sign (e.g. 200 -> -56), and an int8 subtraction like (-100 - 100)
    # would underflow, since tensor-minus-Python-int stays in the tensor dtype.
    dequant_X = X_scale * (X.to(torch.int32) - X_zero_point)
    dequant_Y = Y_scale * (Y.to(torch.int32) - Y_zero_point)

    # Requantize the floating-point product into the output's quantized range,
    # clamping to the full representable range of the (shared) input dtype.
    return quantize_per_tensor(
        dequant_X * dequant_Y,
        out_scale,
        out_zero_point,
        torch.iinfo(dtype).min,
        torch.iinfo(dtype).max,
        dtype,
    )


def quantized_linear_common(
src: torch.Tensor,
weight: torch.Tensor,
Expand Down
42 changes: 42 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,6 +2715,48 @@ def test_quantized_add(self) -> None:
out_zero_point,
)

@expand(
    [
        # X=5, zp=4 → dequant=0.8*(5-4)=0.8; Y=5, zp=4 → dequant=0.8
        # mul=0.64; quantize: round(0.64/0.8)+4=1+4=5
        ("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 5, torch.int8),
        ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 5, torch.uint8),
    ]
)
def test_quantized_mul(
    self,
    name: str,
    X: int,
    X_scale: float,
    X_zero_point: int,
    Y: int,
    Y_scale: float,
    Y_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    expected_value: int,
    dtype: torch.dtype,
) -> None:
    """Checks quantized_mul.per_tensor against a hand-computed scalar case."""
    quantized_x = torch.tensor([X], dtype=dtype)
    quantized_y = torch.tensor([Y], dtype=dtype)

    op_args = (
        quantized_x,
        X_scale,
        X_zero_point,
        quantized_y,
        Y_scale,
        Y_zero_point,
        out_scale,
        out_zero_point,
    )
    output = torch.ops.cadence.quantized_mul.per_tensor(*op_args)
    expected_output = torch.tensor([expected_value], dtype=dtype)

    self.assertTrue(
        torch.equal(output, expected_output),
        f"Values don't match in {name}: got {output}, expected {expected_output}",
    )

def test_requantize(self) -> None:
# Test requantize (default variant), just to make sure it runs since wrapper around per_tensor variant
input_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.int8)
Expand Down
37 changes: 37 additions & 0 deletions backends/cadence/generic/operators/op_quantized_mul.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -72,6 +72,43 @@
return out;
}

// Out-variant kernel for cadence::quantized_mul.per_tensor: writes the
// quantized product of X and Y into `out`. Dispatches on out's dtype to the
// typed quantized_mul_<ctype> implementation, narrowing the double/int64_t
// schema scalars to the float/int32_t types the typed kernel takes.
Tensor& quantized_mul_per_tensor_out(
    ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& X,
    double X_scale,
    int64_t X_zero_point,
    const Tensor& Y,
    double Y_scale,
    int64_t Y_zero_point,
    double out_scale,
    int64_t out_zero_point,
    Tensor& out) {
// Expands to one switch case per supported quantized dtype; used with
// ET_FORALL_CADENCE_QUANTIZED_TYPES below.
#define typed_quantized_mul_per_tensor(ctype, dtype) \
  case ScalarType::dtype: {                          \
    quantized_mul_<ctype>(                           \
        X,                                           \
        static_cast<float>(X_scale),                 \
        static_cast<int32_t>(X_zero_point),          \
        Y,                                           \
        static_cast<float>(Y_scale),                 \
        static_cast<int32_t>(Y_zero_point),          \
        static_cast<float>(out_scale),               \
        static_cast<int32_t>(out_zero_point),        \
        out);                                        \
    break;                                           \
  }

  ScalarType dtype = out.scalar_type();
  switch (dtype) {
    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_mul_per_tensor)
    default:
      // Debug-only check: unsupported dtypes are a programming error here,
      // since the yaml schema restricts what can reach this kernel.
      ET_DCHECK_MSG(
          false, "Unhandled dtype %s", torch::executor::toString(dtype));
  }
#undef typed_quantized_mul_per_tensor
  return out;
}

// Generate kernels that perform elementwise arithmetic on a quantized tensor,
// and a scalar.
#define DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \
Expand Down
12 changes: 12 additions & 0 deletions backends/cadence/generic/operators/op_quantized_mul.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ ::executorch::aten::Tensor& quantized_mul_out(
int64_t out_zero_point,
::executorch::aten::Tensor& out);

// Per-tensor variant of quantized_mul_out: the scales and zero points are
// plain scalars instead of tensors. Writes the quantized product into `out`
// and returns it.
::executorch::aten::Tensor& quantized_mul_per_tensor_out(
    ::executorch::runtime::KernelRuntimeContext& ctx,
    const ::executorch::aten::Tensor& X,
    double X_scale,
    int64_t X_zero_point,
    const ::executorch::aten::Tensor& Y,
    double Y_scale,
    int64_t Y_zero_point,
    double out_scale,
    int64_t out_zero_point,
    ::executorch::aten::Tensor& out);

::executorch::aten::Tensor& quantized_mul_Scalar_out(
::executorch::runtime::KernelRuntimeContext& ctx,
const ::executorch::aten::Tensor& X,
Expand Down
Loading