diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index f6934ef..d677e11 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -2,6 +2,7 @@
     abs,
     add,
     addmm,
+    alpha_dropout,
     avg_pool2d,
     bitwise_and,
     bitwise_not,
@@ -10,6 +11,8 @@
     clamp,
     conv2d,
     cos,
+    cosh,
+    diag,
     div,
     dropout,
     eq,
@@ -31,12 +34,14 @@
     relu,
     rms_norm,
     rotary_position_embedding,
+    round,
     rsqrt,
     scaled_dot_product_attention,
     sigmoid,
     silu,
     sin,
     softmax,
+    sort,
     sub,
     tanh,
 )
@@ -45,6 +50,7 @@
     "abs",
     "add",
     "addmm",
+    "alpha_dropout",
     "avg_pool2d",
     "bitwise_and",
     "bitwise_not",
@@ -53,6 +59,8 @@
     "clamp",
     "conv2d",
     "cos",
+    "cosh",
+    "diag",
     "div",
     "dropout",
     "eq",
@@ -74,12 +82,14 @@
     "relu",
     "rms_norm",
     "rotary_position_embedding",
+    "round",
     "rsqrt",
     "scaled_dot_product_attention",
     "sigmoid",
     "silu",
     "sin",
     "softmax",
+    "sort",
     "sub",
     "tanh",
 ]
diff --git a/src/ntops/kernels/alpha_dropout.py b/src/ntops/kernels/alpha_dropout.py
new file mode 100644
index 0000000..4fba567
--- /dev/null
+++ b/src/ntops/kernels/alpha_dropout.py
@@ -0,0 +1,28 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, a, b, sat, p, seed, output):
+    keep = ntl.rand(seed, input.offsets()) > p
+    output = ntl.where(keep, a * input + b, sat)  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.int64),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/cosh.py b/src/ntops/kernels/cosh.py
new file mode 100644
index 0000000..252d7b0
--- /dev/null
+++ b/src/ntops/kernels/cosh.py
@@ -0,0 +1,19 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+from ninetoothed.language import libdevice
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    output = libdevice.cosh(ntl.cast(input, ntl.float32))  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/diag.py b/src/ntops/kernels/diag.py
new file mode 100644
index 0000000..0469d8e
--- /dev/null
+++ b/src/ntops/kernels/diag.py
@@ -0,0 +1,58 @@
+import functools
+
+import ninetoothed
+from ninetoothed import Symbol, Tensor
+
+
+def arrangement_embed(input, output, stride=None, block_size=None):
+    if stride is None:
+        stride = Symbol("stride", constexpr=True)
+
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    input_arranged = input.tile((block_size,))
+    output_arranged = output.tile(
+        (block_size,), strides=(block_size * stride,), dilation=(stride,)
+    )
+
+    return input_arranged, output_arranged
+
+
+def arrangement_extract(input, output, stride=None, block_size=None):
+    if stride is None:
+        stride = Symbol("stride", constexpr=True)
+
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    input_arranged = input.tile(
+        (block_size,), strides=(block_size * stride,), dilation=(stride,)
+    )
+    output_arranged = output.tile((block_size,))
+
+    return input_arranged, output_arranged
+
+
+def application(input, output):
+    output = input  # noqa: F841
+
+
+def premake_embed(stride=None, dtype=None, block_size=None):
+    arrangement_ = functools.partial(
+        arrangement_embed, stride=stride, block_size=block_size
+    )
+
+    tensors = (Tensor(1, dtype=dtype, other=0), Tensor(1, dtype=dtype))
+
+    return arrangement_, application, tensors
+
+
+def premake_extract(stride=None, dtype=None, block_size=None):
+    arrangement_ = functools.partial(
+        arrangement_extract, stride=stride, block_size=block_size
+    )
+
+    tensors = (Tensor(1, dtype=dtype, other=0), Tensor(1, dtype=dtype))
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/round.py b/src/ntops/kernels/round.py
new file mode 100644
index 0000000..3d95f9a
--- /dev/null
+++ b/src/ntops/kernels/round.py
@@ -0,0 +1,35 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+from ninetoothed.language import libdevice
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    output = libdevice.nearbyint(ntl.cast(input, ntl.float32))  # noqa: F841
+
+
+def application_with_decimals(input, factor, inv_factor, output):
+    scaled = input * ntl.cast(
+        factor, input.dtype
+    )  # multiply at input's original precision to match torch behavior
+    output = libdevice.nearbyint(ntl.cast(scaled, ntl.float32)) * inv_factor  # noqa: F841
+
+
+def premake(ndim, decimals=0, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    if decimals == 0:
+        tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+        return arrangement_, application, tensors
+    else:
+        tensors = (
+            Tensor(ndim, dtype=dtype),
+            Tensor(0, dtype=ninetoothed.float64),
+            Tensor(0, dtype=ninetoothed.float64),
+            Tensor(ndim, dtype=dtype),
+        )
+        return arrangement_, application_with_decimals, tensors
diff --git a/src/ntops/kernels/sort.py b/src/ntops/kernels/sort.py
new file mode 100644
index 0000000..53d7e18
--- /dev/null
+++ b/src/ntops/kernels/sort.py
@@ -0,0 +1,74 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.reduction import arrangement
+
+
+def _next_power_of_two(value):
+    if value <= 1:
+        return 1
+
+    return 1 << (value - 1).bit_length()
+
+
+def application(input, values, indices, sort_size, descending):
+    input_0 = input[0]
+    offsets = ntl.arange(0, input_0.shape[0])
+    valid = offsets < sort_size
+
+    sign_mask = ntl.cast(0x7FFFFFFF, ntl.int32)
+    input_fp32 = ntl.cast(input_0, ntl.float32)
+    encoded = ntl.cast(input_fp32, ntl.int32, bitcast=True)
+    encoded = encoded ^ ((encoded >> 31) & sign_mask)  # flip payload bits of negatives so int order matches float order
+
+    if descending:
+        encoded = ~encoded
+
+    encoded = ntl.where(valid, encoded, ntl.cast(0x7FFFFFFF, ntl.int32))  # padding lanes get the max key, so they sort last
+
+    offsets = ntl.cast(offsets, ntl.int64)
+    key = ((ntl.cast(encoded, ntl.int64) & ntl.cast(0xFFFFFFFF, ntl.int64)) << 32) | offsets
+    sorted_key = ntl.sort(key)  # value in high 32 bits, index in low 32: ties break on original index (stable)
+
+    sorted_encoded = ntl.cast(sorted_key >> 32, ntl.int32)
+
+    if descending:
+        sorted_encoded = ~sorted_encoded
+
+    sorted_encoded = sorted_encoded ^ ((sorted_encoded >> 31) & sign_mask)  # undo the order-preserving encoding
+    sorted_values = ntl.cast(sorted_encoded, ntl.float32, bitcast=True)
+    sorted_indices = sorted_key & ntl.cast(0xFFFFFFFF, ntl.int64)
+
+    values[0] = ntl.cast(sorted_values, values[0].dtype)
+    indices[0] = ntl.cast(sorted_indices, indices[0].dtype)
+
+
+def premake(
+    ndim,
+    dim,
+    sort_size,
+    descending=False,
+    stable=False,
+    dtype=None,
+    block_size=None,
+):
+    if block_size is None:
+        block_size = _next_power_of_two(sort_size)
+
+    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
+
+    # `stable` is kept for `torch.sort` interface parity. Current key design is stable.
+    _ = stable
+
+    tensors = (
+        Tensor(ndim, dtype=dtype, other=0),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=ninetoothed.int64),
+        Tensor(0, constexpr=True, value=sort_size),
+        Tensor(0, constexpr=True, value=descending),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 82fc596..83d08a2 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -1,6 +1,7 @@
 from ntops.torch.abs import abs
 from ntops.torch.add import add
 from ntops.torch.addmm import addmm
+from ntops.torch.alpha_dropout import alpha_dropout
 from ntops.torch.avg_pool2d import avg_pool2d
 from ntops.torch.bitwise_and import bitwise_and
 from ntops.torch.bitwise_not import bitwise_not
@@ -9,6 +10,8 @@
 from ntops.torch.clamp import clamp
 from ntops.torch.conv2d import conv2d
 from ntops.torch.cos import cos
+from ntops.torch.cosh import cosh
+from ntops.torch.diag import diag
 from ntops.torch.div import div
 from ntops.torch.dropout import dropout
 from ntops.torch.eq import eq
@@ -31,12 +34,14 @@
 from ntops.torch.relu import relu
 from ntops.torch.rms_norm import rms_norm
 from ntops.torch.rotary_position_embedding import rotary_position_embedding
+from ntops.torch.round import round
 from ntops.torch.rsqrt import rsqrt
 from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
 from ntops.torch.sigmoid import sigmoid
 from ntops.torch.silu import silu
 from ntops.torch.sin import sin
 from ntops.torch.softmax import softmax
+from ntops.torch.sort import sort
 from ntops.torch.sub import sub
 from ntops.torch.tanh import tanh
 
@@ -44,6 +49,7 @@
     "abs",
     "add",
     "addmm",
+    "alpha_dropout",
     "avg_pool2d",
     "bitwise_and",
     "bitwise_not",
@@ -52,6 +58,8 @@
     "clamp",
     "conv2d",
     "cos",
+    "cosh",
+    "diag",
     "div",
     "dropout",
     "eq",
@@ -74,12 +82,14 @@
     "relu",
     "rms_norm",
     "rotary_position_embedding",
+    "round",
     "rsqrt",
     "scaled_dot_product_attention",
     "sigmoid",
     "silu",
     "sin",
     "softmax",
+    "sort",
     "sub",
     "tanh",
 ]
diff --git a/src/ntops/torch/alpha_dropout.py b/src/ntops/torch/alpha_dropout.py
new file mode 100644
index 0000000..860f47b
--- /dev/null
+++ b/src/ntops/torch/alpha_dropout.py
@@ -0,0 +1,36 @@
+import math
+import random
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+# SELU saturation value: -lambda * alpha
+_ALPHA_P = -1.7580993408473766
+
+
+def alpha_dropout(input, p=0.5, training=False, inplace=False):
+    if not training or p == 0:
+        if inplace:
+            return input
+        else:
+            return input.clone()
+
+    q = 1.0 - p
+    a = 1.0 / math.sqrt(q * (1.0 + p * _ALPHA_P * _ALPHA_P))
+    b = -a * p * _ALPHA_P
+    sat = a * _ALPHA_P + b
+
+    seed = random.randrange(0, 2**31)
+
+    if inplace:
+        output = input
+    else:
+        output = torch.empty_like(input)
+
+    kernel = _cached_make(ntops.kernels.alpha_dropout.premake, input.ndim)
+
+    kernel(input, a, b, sat, p, seed, output)
+
+    return output
diff --git a/src/ntops/torch/cosh.py b/src/ntops/torch/cosh.py
new file mode 100644
index 0000000..546de0f
--- /dev/null
+++ b/src/ntops/torch/cosh.py
@@ -0,0 +1,15 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def cosh(input, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    kernel = _cached_make(ntops.kernels.cosh.premake, input.ndim)
+
+    kernel(input, out)
+
+    return out
diff --git a/src/ntops/torch/diag.py b/src/ntops/torch/diag.py
new file mode 100644
index 0000000..a1f9add
--- /dev/null
+++ b/src/ntops/torch/diag.py
@@ -0,0 +1,62 @@
+import builtins
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def diag(input, diagonal=0):
+    if input.ndim == 1:
+        return _diag_embed(input, diagonal)
+    elif input.ndim == 2:
+        return _diag_extract(input, diagonal)
+    else:
+        raise ValueError(f"Input must be 1-D or 2-D, but got {input.ndim}-D.")
+
+
+def _diag_embed(input, diagonal):
+    n = input.shape[0]
+    size = n + builtins.abs(diagonal)
+    output = torch.zeros((size, size), dtype=input.dtype, device=input.device)
+
+    if n == 0:
+        return output
+
+    output_flat = output.view(-1)
+
+    if diagonal >= 0:
+        start = diagonal
+    else:
+        start = (-diagonal) * size
+
+    stride = size + 1
+
+    kernel = _cached_make(ntops.kernels.diag.premake_embed, stride=stride)
+    kernel(input, output_flat[start:])
+
+    return output
+
+
+def _diag_extract(input, diagonal):
+    m, n = input.shape
+
+    if diagonal >= 0:
+        diag_len = max(min(m, n - diagonal), 0)
+        start = diagonal
+    else:
+        diag_len = max(min(m + diagonal, n), 0)
+        start = (-diagonal) * n
+
+    output = torch.empty(diag_len, dtype=input.dtype, device=input.device)
+
+    if diag_len == 0:
+        return output
+
+    input_flat = input.contiguous().view(-1)
+    stride = n + 1
+
+    kernel = _cached_make(ntops.kernels.diag.premake_extract, stride=stride)
+    kernel(input_flat[start:], output)
+
+    return output
diff --git a/src/ntops/torch/round.py b/src/ntops/torch/round.py
new file mode 100644
index 0000000..2496767
--- /dev/null
+++ b/src/ntops/torch/round.py
@@ -0,0 +1,20 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def round(input, decimals=0, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    if decimals == 0:
+        kernel = _cached_make(ntops.kernels.round.premake, input.ndim)
+        kernel(input, out)
+    else:
+        factor = 10.0**decimals
+        inv_factor = 1.0 / factor
+        kernel = _cached_make(ntops.kernels.round.premake, input.ndim, decimals=True)  # NOTE(review): `decimals=True` only selects the nonzero-decimals kernel variant; presumably intentional so all nonzero values share one cached kernel — confirm against `_cached_make`
+        kernel(input, factor, inv_factor, out)
+
+    return out
diff --git a/src/ntops/torch/sort.py b/src/ntops/torch/sort.py
new file mode 100644
index 0000000..306059f
--- /dev/null
+++ b/src/ntops/torch/sort.py
@@ -0,0 +1,42 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def sort(input, dim=-1, descending=False, stable=False, *, out=None):
+    assert input.device.type == "cuda", "`input` must be on CUDA."
+    assert input.ndim > 0, "`input` must have at least one dimension."
+    assert input.dtype in (torch.float16, torch.bfloat16, torch.float32), (
+        "`input.dtype` must be one of float16, bfloat16, or float32."
+    )
+
+    if dim < 0:
+        dim += input.ndim
+
+    if dim < 0 or dim >= input.ndim:
+        raise IndexError(
+            f"Dimension out of range (expected to be in range of [{-input.ndim}, {input.ndim - 1}], but got {dim})"
+        )
+
+    sort_size = input.shape[dim]
+
+    assert sort_size > 0, "`input.shape[dim]` must be greater than 0."
+
+    if out is None:
+        values = torch.empty_like(input)
+        indices = torch.empty_like(input, dtype=torch.int64)
+    else:
+        values, indices = out
+
+    kernel = _cached_make(
+        ntops.kernels.sort.premake,
+        input.ndim,
+        dim,
+        sort_size=sort_size,
+        descending=descending,
+    )
+
+    kernel(input, values, indices, sort_size, descending)
+
+    return torch.return_types.sort((values, indices))
diff --git a/tests/test_alpha_dropout.py b/tests/test_alpha_dropout.py
new file mode 100644
index 0000000..187e84a
--- /dev/null
+++ b/tests/test_alpha_dropout.py
@@ -0,0 +1,64 @@
+import math
+import random
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+_ALPHA_P = -1.7580993408473766
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_alpha_dropout(shape, dtype, device, rtol, atol):
+    input = torch.randn(shape, dtype=dtype, device=device)
+    p = random.uniform(0.1, 0.5)
+
+    ninetoothed_output = ntops.torch.alpha_dropout(input, p=p, training=True)
+    reference_output = F.alpha_dropout(input, p=p, training=True)
+
+    # 1. Shape must match.
+    assert ninetoothed_output.shape == reference_output.shape
+
+    # 2. Compute expected affine parameters.
+    q = 1.0 - p
+    a = 1.0 / math.sqrt(q * (1.0 + p * _ALPHA_P * _ALPHA_P))
+    b = -a * p * _ALPHA_P
+    sat = a * _ALPHA_P + b
+
+    # 3. Drop ratios should be close to each other.
+    ninetoothed_drop_ratio = (
+        torch.isclose(
+            ninetoothed_output, torch.full_like(ninetoothed_output, sat), atol=atol
+        )
+        .float()
+        .mean()
+        .item()
+    )
+    reference_drop_ratio = (
+        torch.isclose(
+            reference_output, torch.full_like(reference_output, sat), atol=atol
+        )
+        .float()
+        .mean()
+        .item()
+    )
+
+    assert abs(ninetoothed_drop_ratio - reference_drop_ratio) < 0.1
+
+    # 4. Kept elements should satisfy the same affine transform.
+    kept_mask = ~torch.isclose(
+        ninetoothed_output, torch.full_like(ninetoothed_output, sat), atol=atol
+    )
+    expected_kept = a * input[kept_mask].float() + b
+    actual_kept = ninetoothed_output[kept_mask].float()
+
+    assert torch.allclose(actual_kept, expected_kept, rtol=rtol, atol=atol)
+
+    # 5. training=False should return input unchanged.
+    output_eval = ntops.torch.alpha_dropout(input, p=p, training=False)
+    assert torch.equal(output_eval, input)
diff --git a/tests/test_cosh.py b/tests/test_cosh.py
new file mode 100644
index 0000000..c25e959
--- /dev/null
+++ b/tests/test_cosh.py
@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_cosh(shape, dtype, device, rtol, atol):
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.cosh(input)
+    reference_output = torch.cosh(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_diag.py b/tests/test_diag.py
new file mode 100644
index 0000000..3fddbc0
--- /dev/null
+++ b/tests/test_diag.py
@@ -0,0 +1,98 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize(
+    "n, diagonal",
+    [
+        (1, 0),
+        (5, 0),
+        (5, 1),
+        (5, -1),
+        (5, 3),
+        (5, -3),
+        (10, 0),
+        (10, 5),
+        (10, -5),
+    ],
+)
+def test_diag_1d(n, diagonal, dtype):
+    device = "cuda"
+    input = torch.randn(n, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.diag(input, diagonal)
+    reference_output = torch.diag(input, diagonal)
+
+    assert torch.allclose(ninetoothed_output, reference_output)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize("diagonal", [0, 3, -3])
+def test_diag_1d_empty_input(diagonal, dtype):
+    device = "cuda"
+    input = torch.empty((0,), dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.diag(input, diagonal)
+    reference_output = torch.diag(input, diagonal)
+
+    assert ninetoothed_output.shape == reference_output.shape
+    assert torch.allclose(ninetoothed_output, reference_output)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize(
+    "shape, diagonal",
+    [
+        ((5, 5), 0),
+        ((5, 5), 1),
+        ((5, 5), -1),
+        ((5, 5), 4),
+        ((5, 5), -4),
+        ((3, 5), 0),
+        ((3, 5), 1),
+        ((3, 5), -1),
+        ((5, 3), 0),
+        ((5, 3), 1),
+        ((5, 3), -1),
+        ((10, 10), 0),
+        ((10, 10), 3),
+        ((10, 10), -3),
+    ],
+)
+def test_diag_2d(shape, diagonal, dtype):
+    device = "cuda"
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.diag(input, diagonal)
+    reference_output = torch.diag(input, diagonal)
+
+    assert torch.allclose(ninetoothed_output, reference_output)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize(
+    "shape, diagonal",
+    [
+        ((3, 5), 5),
+        ((3, 5), -3),
+        ((5, 3), 3),
+        ((5, 3), -5),
+    ],
+)
+def test_diag_2d_out_of_range_diagonal(shape, diagonal, dtype):
+    device = "cuda"
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.diag(input, diagonal)
+    reference_output = torch.diag(input, diagonal)
+
+    assert ninetoothed_output.shape == reference_output.shape
+    assert torch.allclose(ninetoothed_output, reference_output)
diff --git a/tests/test_round.py b/tests/test_round.py
new file mode 100644
index 0000000..0179f98
--- /dev/null
+++ b/tests/test_round.py
@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_round(shape, dtype, device, rtol, atol):
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.round(input)
+    reference_output = torch.round(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_sort.py b/tests/test_sort.py
new file mode 100644
index 0000000..0aa9828
--- /dev/null
+++ b/tests/test_sort.py
@@ -0,0 +1,61 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("stable", (False, True))
+@pytest.mark.parametrize("descending", (False, True))
+@pytest.mark.parametrize(*generate_arguments())
+def test_sort(shape, dtype, device, rtol, atol, descending, stable):
+    input = torch.randn(shape, dtype=dtype, device=device)
+    dim = random.randint(-input.ndim, input.ndim - 1)
+
+    ninetoothed_output = ntops.torch.sort(
+        input, dim=dim, descending=descending, stable=stable
+    )
+    reference_output = torch.sort(input, dim=dim, descending=descending, stable=stable)
+
+    assert torch.allclose(
+        ninetoothed_output.values, reference_output.values, rtol=rtol, atol=atol
+    )
+    assert torch.equal(ninetoothed_output.indices, reference_output.indices)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("descending", (False, True))
+def test_sort_stable_with_duplicate_values(descending):
+    input = torch.randint(-4, 5, (16, 33), dtype=torch.int32, device="cuda").to(
+        torch.float32
+    )
+
+    ninetoothed_output = ntops.torch.sort(
+        input, dim=-1, descending=descending, stable=True
+    )
+    reference_output = torch.sort(input, dim=-1, descending=descending, stable=True)
+
+    assert torch.equal(ninetoothed_output.indices, reference_output.indices)
+    assert torch.equal(ninetoothed_output.values, reference_output.values)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("descending", (False, True))
+def test_sort_with_out(descending):
+    input = torch.randn((19, 23), dtype=torch.float16, device="cuda")
+    values = torch.empty_like(input)
+    indices = torch.empty_like(input, dtype=torch.int64)
+
+    ninetoothed_output = ntops.torch.sort(
+        input, dim=-1, descending=descending, out=(values, indices)
+    )
+    reference_output = torch.sort(input, dim=-1, descending=descending)
+
+    assert ninetoothed_output.values.data_ptr() == values.data_ptr()
+    assert ninetoothed_output.indices.data_ptr() == indices.data_ptr()
+    assert torch.equal(values, reference_output.values)
+    assert torch.equal(indices, reference_output.indices)