From e6f380c85196f866db94a4d8af0c86ff5b7ade96 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 4 Mar 2026 18:45:07 +0800 Subject: [PATCH 1/6] Add `celu` operator --- src/ntops/kernels/__init__.py | 2 ++ src/ntops/kernels/celu.py | 23 +++++++++++++++++++++++ src/ntops/torch/__init__.py | 2 ++ src/ntops/torch/celu.py | 17 +++++++++++++++++ tests/test_celu.py | 20 ++++++++++++++++++++ 5 files changed, 64 insertions(+) create mode 100644 src/ntops/kernels/celu.py create mode 100644 src/ntops/torch/celu.py create mode 100644 tests/test_celu.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..9b14d62 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -7,6 +7,7 @@ bitwise_not, bitwise_or, bmm, + celu, clamp, conv2d, cos, @@ -50,6 +51,7 @@ "bitwise_not", "bitwise_or", "bmm", + "celu", "clamp", "conv2d", "cos", diff --git a/src/ntops/kernels/celu.py b/src/ntops/kernels/celu.py new file mode 100644 index 0000000..6ba511d --- /dev/null +++ b/src/ntops/kernels/celu.py @@ -0,0 +1,23 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, alpha, output): + output = max(0.0, input) + min(0.0, alpha * (ntl.exp(input / alpha) - 1)) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=ninetoothed.float64), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..128255e 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -6,6 +6,7 @@ from ntops.torch.bitwise_not import bitwise_not from ntops.torch.bitwise_or import bitwise_or from ntops.torch.bmm import bmm +from ntops.torch.celu import celu from ntops.torch.clamp import clamp from ntops.torch.conv2d import conv2d from ntops.torch.cos import cos @@ -49,6 +50,7 @@ "bitwise_not", "bitwise_or", "bmm", + "celu", "clamp", "conv2d", "cos", diff --git a/src/ntops/torch/celu.py b/src/ntops/torch/celu.py new file mode 100644 index 0000000..d3d661b --- /dev/null +++ b/src/ntops/torch/celu.py @@ -0,0 +1,17 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def celu(input, alpha=1.0, inplace=False): + if inplace: + output = input + else: + output = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.celu.premake, input.ndim) + + kernel(input, alpha, output) + + return output diff --git a/tests/test_celu.py b/tests/test_celu.py new file mode 100644 index 0000000..680cccb --- /dev/null +++ b/tests/test_celu.py @@ -0,0 +1,20 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("inplace", (False, True)) +@pytest.mark.parametrize(*generate_arguments()) +def test_celu(shape, inplace, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + alpha = 1.0 + + ninetoothed_output = ntops.torch.celu(input, alpha, inplace) + reference_output = F.celu(input, alpha, inplace) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) From a5ff100a9dfb00ee67343469310b83e5ab17f8cd Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 4 Mar 2026 19:41:21 +0800 Subject: [PATCH 2/6] Add `threshold` operator --- src/ntops/kernels/__init__.py | 2 ++ src/ntops/kernels/threshold.py | 24 ++++++++++++++++++++++++ src/ntops/torch/__init__.py | 2 ++ src/ntops/torch/threshold.py | 17 +++++++++++++++++ tests/test_threshold.py | 22 ++++++++++++++++++++++ 5 files changed, 67 insertions(+) create mode 100644 src/ntops/kernels/threshold.py create mode 100644 src/ntops/torch/threshold.py create mode 100644 tests/test_threshold.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 9b14d62..4415fc6 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -40,6 +40,7 @@ softmax, sub, tanh, + threshold, ) __all__ = [ @@ -84,4 +85,5 @@ "softmax", "sub", "tanh", + "threshold", ] diff --git a/src/ntops/kernels/threshold.py b/src/ntops/kernels/threshold.py new file mode 100644 index 0000000..2744871 --- /dev/null +++ b/src/ntops/kernels/threshold.py @@ -0,0 +1,24 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, threshold, value, output): + output = ntl.where(input > threshold, input, value) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=ninetoothed.float64), + Tensor(0, dtype=ninetoothed.float64), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 128255e..00c2299 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -40,6 +40,7 @@ from ntops.torch.softmax import softmax from ntops.torch.sub import sub from ntops.torch.tanh import tanh +from ntops.torch.threshold import threshold __all__ = [ "abs", @@ -84,4 +85,5 @@ "softmax", "sub", "tanh", + "threshold", ] diff --git a/src/ntops/torch/threshold.py b/src/ntops/torch/threshold.py new file mode 100644 index 0000000..f671391 --- /dev/null +++ b/src/ntops/torch/threshold.py @@ -0,0 +1,17 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def threshold(input, threshold, value, inplace=False): + if inplace: + output = input + else: + output = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.threshold.premake, input.ndim) + + kernel(input, threshold, value, output) + + return output diff --git a/tests/test_threshold.py b/tests/test_threshold.py new file mode 100644 index 0000000..73ddb63 --- /dev/null +++ b/tests/test_threshold.py @@ -0,0 +1,22 @@ +import random + +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_threshold(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + threshold = random.uniform(-1, 1) + value = random.uniform(0, 1) + + ninetoothed_output = ntops.torch.threshold(input, threshold, value) + reference_output = F.threshold(input, threshold, value) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) From ba43310f7d1e353e479cb908c9809a069a116285 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 11 Mar 2026 16:46:48 +0800 Subject: [PATCH 3/6] Add `instance_norm` operator --- src/ntops/kernels/__init__.py | 2 + src/ntops/kernels/instance_norm.py | 222 +++++++++++++++++++++++++++++ src/ntops/torch/__init__.py | 2 + src/ntops/torch/instance_norm.py | 82 +++++++++++ tests/test_instance_norm.py | 72 ++++++++++ 5 files changed, 380 insertions(+) create mode 100644 src/ntops/kernels/instance_norm.py create mode 100644 src/ntops/torch/instance_norm.py create mode 100644 tests/test_instance_norm.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 4415fc6..d7d0e85 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -18,6 +18,7 @@ ge, gelu, gt, + instance_norm, isinf, isnan, layer_norm, @@ -63,6 +64,7 @@ "ge", "gelu", "gt", + "instance_norm", "isinf", "isnan", "layer_norm", diff --git a/src/ntops/kernels/instance_norm.py b/src/ntops/kernels/instance_norm.py new file mode 100644 index 0000000..dce77d0 --- /dev/null +++ b/src/ntops/kernels/instance_norm.py @@ -0,0 +1,222 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement as reduction_arrangement + + +def arrangement( + input, + running_mean, + running_var, + tmp_mean, + tmp_var, + weight, + bias, + momentum, + eps, + output, + num_normalized_elements, + use_input_stats, + tracking_running_stats, + dims, + block_size=None, +): + def _arrange_per_channel_tensor(tensor): + arranged = tensor.tile((1,)) + arranged.dtype = arranged.dtype.squeeze(0) + arranged = arranged.unsqueeze(0) + arranged = arranged.expand((input.shape[0], -1)) + + return arranged + + input_arranged, output_arranged = reduction_arrangement( + input, output, dim=dims, block_size=block_size + ) + running_mean_arranged = _arrange_per_channel_tensor(running_mean) + running_var_arranged = _arrange_per_channel_tensor(running_var) + tmp_mean_arranged = _arrange_per_channel_tensor(tmp_mean) + tmp_var_arranged = _arrange_per_channel_tensor(tmp_var) + weight_arranged = _arrange_per_channel_tensor(weight) + bias_arranged = _arrange_per_channel_tensor(bias) + momentum_arranged = momentum + eps_arranged = eps + num_normalized_elements_arranged = num_normalized_elements + + if use_input_stats: + if tracking_running_stats: + return ( + input_arranged, + running_mean_arranged, + running_var_arranged, + tmp_mean_arranged, + tmp_var_arranged, + weight_arranged, + bias_arranged, + momentum_arranged, + eps_arranged, + output_arranged, + num_normalized_elements_arranged, + ) + else: + return ( + input_arranged, + weight_arranged, + bias_arranged, + eps_arranged, + output_arranged, + num_normalized_elements_arranged, + ) + + return ( + input_arranged, + running_mean_arranged, + running_var_arranged, + weight_arranged, + bias_arranged, + eps_arranged, + output_arranged, + ) + + +def application_without_tracking( + input, + weight, + bias, + eps, + output, + num_normalized_elements, +): + _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + _mean += ntl.cast(input[i], ntl.float32) + + mean = ntl.sum(_mean, 0) / num_normalized_elements + + _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + diff = ntl.cast(input[i], ntl.float32) - mean + diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0) + _var += diff * diff + + var = ntl.sum(_var, 0) / num_normalized_elements + + application_with_mean_var(input, mean, var, weight, bias, eps, output) + + +def application_with_tracking( + input, + running_mean, + running_var, + tmp_mean, + tmp_var, + weight, + bias, + momentum, + eps, + output, + num_normalized_elements, +): + _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + _mean += ntl.cast(input[i], ntl.float32) + + mean = ntl.sum(_mean, 0) / num_normalized_elements + + _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + diff = ntl.cast(input[i], ntl.float32) - mean + diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0) + _var += diff * diff + + var = ntl.sum(_var, 0) / num_normalized_elements + + ntl.atomic_add( + tmp_mean.source.data_ptr() + tmp_mean.offsets(0), ntl.cast(mean, ntl.float32) + ) + ntl.atomic_add( + tmp_var.source.data_ptr() + tmp_mean.offsets(0), ntl.cast(var, ntl.float32) + ) + + ntl.debug_barrier() + + if input[0].offsets(0) == 0: + tmp_mean = tmp_mean / input.source.shape[0] + tmp_var = tmp_var / input.source.shape[0] + + running_mean = running_mean * (1 - momentum) + tmp_mean * momentum + running_var = running_var * (1 - momentum) + tmp_var * momentum + + application_with_mean_var(input, mean, var, weight, bias, eps, output) + + +def application_with_mean_var( + input, + mean, + var, + weight, + bias, + eps, + output, +): + std = ntl.sqrt(var + eps) + + for i in range(input.shape[0]): + output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight + bias + + +def premake( + ndim, + use_input_stats, + tracking_running_stats, + num_normalized_elements, + dtype=None, + block_size=None, +): + dims = tuple(reversed(range(2, ndim))) + + arrangement_ = functools.partial( + arrangement, + use_input_stats=use_input_stats, + tracking_running_stats=tracking_running_stats, + dims=dims, + block_size=block_size, + ) + + input = Tensor(ndim, other=0, dtype=dtype) + running_mean, running_var, tmp_mean, tmp_var, weight, bias = ( + Tensor(1, dtype=dtype) for _ in range(6) + ) + momentum, eps = (Tensor(0, dtype=ninetoothed.float64) for _ in range(2)) + output = Tensor(ndim, dtype=dtype) + num_normalized_elements = Tensor(0, constexpr=True, value=num_normalized_elements) + + if use_input_stats: + if tracking_running_stats: + application = application_with_tracking + else: + application = application_without_tracking + else: + application = application_with_mean_var + + tensors = ( + input, + running_mean, + running_var, + tmp_mean, + tmp_var, + weight, + bias, + momentum, + eps, + output, + num_normalized_elements, + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 00c2299..a05e08a 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -17,6 +17,7 @@ from ntops.torch.ge import ge from ntops.torch.gelu import gelu from ntops.torch.gt import gt +from ntops.torch.instance_norm import instance_norm from ntops.torch.isinf import isinf from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm @@ -62,6 +63,7 @@ "ge", "gelu", "gt", + "instance_norm", "isinf", "isnan", "layer_norm", diff --git a/src/ntops/torch/instance_norm.py b/src/ntops/torch/instance_norm.py new file mode 100644 index 0000000..92a81f5 --- /dev/null +++ b/src/ntops/torch/instance_norm.py @@ -0,0 +1,82 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def instance_norm( + input, + running_mean=None, + running_var=None, + weight=None, + bias=None, + use_input_stats=True, + momentum=0.1, + eps=1e-05, +): + if weight is None: + weight = torch.ones(input.shape[1], device=input.device, dtype=input.dtype) + + if bias is None: + bias = torch.zeros(input.shape[1], device=input.device, dtype=input.dtype) + + tracking_running_stats = False + + if not use_input_stats: + assert running_mean is not None and running_var is not None, ( + "`running_mean` and `running_var` must be provided when `use_input_stats=False`." + ) + assert running_mean.shape == (input.shape[1],) and running_var.shape == ( + input.shape[1], + ), "`running_mean` and `running_var` must have shape (C,)" + else: + if running_mean is not None and running_var is not None: + assert running_mean.shape == (input.shape[1],) and running_var.shape == ( + input.shape[1], + ), "`running_mean` and `running_var` must have shape (C,)" + tracking_running_stats = True + tmp_mean = torch.zeros_like(running_mean) + tmp_var = torch.zeros_like(running_var) + + output = torch.empty_like(input) + + num_normalized_elements = math.prod(input.shape[2:]) + kernel = _cached_make( + ntops.kernels.instance_norm.premake, + input.ndim, + use_input_stats, + tracking_running_stats, + num_normalized_elements, + block_size=32, + ) + + if use_input_stats: + if tracking_running_stats: + kernel( + input, + running_mean, + running_var, + tmp_mean, + tmp_var, + weight, + bias, + momentum, + eps, + output, + num_normalized_elements, + ) + else: + kernel( + input, + weight, + bias, + eps, + output, + num_normalized_elements, + ) + else: + kernel(input, running_mean, running_var, weight, bias, eps, output) + + return output diff --git a/tests/test_instance_norm.py b/tests/test_instance_norm.py new file mode 100644 index 0000000..52a745c --- /dev/null +++ b/tests/test_instance_norm.py @@ -0,0 +1,72 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("eps", (1e-8, 1e-5, 1e-3)) +@pytest.mark.parametrize("bias_is_none", (False, True)) +@pytest.mark.parametrize("weight_is_none", (False, True)) +@pytest.mark.parametrize("use_input_stats", (False, True)) +@pytest.mark.parametrize("track_running_stats", (False, True)) +@pytest.mark.parametrize(*generate_arguments()) +def test_instance_norm( + shape, + dtype, + device, + rtol, + atol, + weight_is_none, + bias_is_none, + use_input_stats, + track_running_stats, + eps, +): + # if len(shape) < 3: + # pytest.skip(reason="InstanceNorm requires at least 3D input.") + + while len(shape) < 3: + shape.insert(0, 1) + + input = torch.randn(shape, dtype=dtype, device=device) + + if weight_is_none: + weight = None + else: + weight = torch.randn(shape[1], dtype=dtype, device=device) + + if bias_is_none: + bias = None + else: + bias = torch.randn(shape[1], dtype=dtype, device=device) + + if use_input_stats and not track_running_stats: + running_mean = None + running_var = None + else: + running_mean = torch.randn(shape[1], dtype=dtype, device=device) + running_var = torch.randn(shape[1], dtype=dtype, device=device).abs() + 1 + + ninetoothed_output = ntops.torch.instance_norm( + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + eps=eps, + ) + reference_output = torch.nn.functional.instance_norm( + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + eps=eps, + ) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) From d5d00fbf41b4e04a2d49168c4ce95d58bede9f42 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Wed, 11 Mar 2026 17:09:25 +0800 Subject: [PATCH 4/6] Assert `allclose` for `running_mean` in `tests/test_instance_norm.py` --- tests/test_instance_norm.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/tests/test_instance_norm.py b/tests/test_instance_norm.py index 52a745c..96682d0 100644 --- a/tests/test_instance_norm.py +++ b/tests/test_instance_norm.py @@ -25,9 +25,6 @@ def test_instance_norm( track_running_stats, eps, ): - # if len(shape) < 3: - # pytest.skip(reason="InstanceNorm requires at least 3D input.") - while len(shape) < 3: shape.insert(0, 1) @@ -44,16 +41,25 @@ def test_instance_norm( bias = torch.randn(shape[1], dtype=dtype, device=device) if use_input_stats and not track_running_stats: - running_mean = None - running_var = None + reference_running_mean = None + reference_running_var = None + ninetoothed_running_mean = None + ninetoothed_running_var = None else: - running_mean = torch.randn(shape[1], dtype=dtype, device=device) - running_var = torch.randn(shape[1], dtype=dtype, device=device).abs() + 1 + reference_running_mean = torch.randn(shape[1], dtype=dtype, device=device) + reference_running_var = torch.randn(shape[1], dtype=dtype, device=device).abs() + + if use_input_stats: + ninetoothed_running_mean = reference_running_mean.clone() + ninetoothed_running_var = reference_running_var.clone() + else: + ninetoothed_running_mean = reference_running_mean + ninetoothed_running_var = reference_running_var ninetoothed_output = ntops.torch.instance_norm( input, - running_mean=running_mean, - running_var=running_var, + running_mean=ninetoothed_running_mean, + running_var=ninetoothed_running_var, weight=weight, bias=bias, use_input_stats=use_input_stats, @@ -61,8 +67,8 @@ def test_instance_norm( ) reference_output = torch.nn.functional.instance_norm( input, - running_mean=running_mean, - running_var=running_var, + running_mean=reference_running_mean, + running_var=reference_running_var, weight=weight, bias=bias, use_input_stats=use_input_stats, @@ -70,3 +76,10 @@ def test_instance_norm( ) assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) + + if use_input_stats and track_running_stats: + assert torch.allclose( + ninetoothed_running_mean, reference_running_mean, rtol=rtol, atol=atol + ) + # TODO: The running var is not close. + # assert torch.allclose(ninetoothed_running_var, reference_running_var, rtol=rtol, atol=atol) From d4ee060ce81f02fa6ed95e2569f90dfc73abed00 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Mon, 16 Mar 2026 20:06:07 +0800 Subject: [PATCH 5/6] Add `msort` operator --- src/ntops/kernels/__init__.py | 2 ++ src/ntops/kernels/msort.py | 46 +++++++++++++++++++++++++++++++++++ src/ntops/torch/__init__.py | 2 ++ src/ntops/torch/msort.py | 15 ++++++++++++ tests/test_msort.py | 17 +++++++++++++ 5 files changed, 82 insertions(+) create mode 100644 src/ntops/kernels/msort.py create mode 100644 src/ntops/torch/msort.py create mode 100644 tests/test_msort.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index d7d0e85..e72246d 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -26,6 +26,7 @@ lt, max_pool2d, mm, + msort, mul, ne, neg, @@ -72,6 +73,7 @@ "lt", "max_pool2d", "mm", + "msort", "mul", "ne", "neg", diff --git a/src/ntops/kernels/msort.py b/src/ntops/kernels/msort.py new file mode 100644 index 0000000..e91f354 --- /dev/null +++ b/src/ntops/kernels/msort.py @@ -0,0 +1,46 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + + +def arrangement(input, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + ndim = input.ndim + dim = 0 + + non_target_dims = tuple(i for i in range(input.ndim) if i != dim) + + def _arrangement(input): + arranged = input.permute(non_target_dims + (dim,)) + + if ndim == 1: + arranged = arranged.unsqueeze(0) + arranged = arranged.flatten(end_dim=-1) + + arranged = arranged.tile((1, -1)) + arranged.dtype = arranged.dtype.squeeze(0) + + return arranged + + return _arrangement(input), _arrangement(output) + + +def application(input, output): + output = ntl.sort(input) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor( + ndim, dtype=dtype, other=float("inf"), shape_options={"constexpr": True} + ), + Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index a05e08a..9719b16 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -26,6 +26,7 @@ from ntops.torch.matmul import matmul from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm +from ntops.torch.msort import msort from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg @@ -72,6 +73,7 @@ "matmul", "max_pool2d", "mm", + "msort", "mul", "ne", "neg", diff --git a/src/ntops/torch/msort.py b/src/ntops/torch/msort.py new file mode 100644 index 0000000..5889f19 --- /dev/null +++ b/src/ntops/torch/msort.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def msort(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.msort.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/tests/test_msort.py b/tests/test_msort.py new file mode 100644 index 0000000..a2754d2 --- /dev/null +++ b/tests/test_msort.py @@ -0,0 +1,17 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_msort(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.msort(input) + reference_output = torch.msort(input) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) From 65b304ca1737d6f9706d3ea8f43a3cdadd66d890 Mon Sep 17 00:00:00 2001 From: xuzhengzhong Date: Mon, 30 Mar 2026 16:19:12 +0800 Subject: [PATCH 6/6] Refactor `instance_norm` operator to reduce in PyTorch instead of using tl.atomic_add in Triton --- src/ntops/kernels/instance_norm.py | 158 +++++++++-------------------- src/ntops/torch/instance_norm.py | 70 ++++++------- tests/test_instance_norm.py | 5 +- 3 files changed, 82 insertions(+), 151 deletions(-) diff --git a/src/ntops/kernels/instance_norm.py b/src/ntops/kernels/instance_norm.py index dce77d0..a8a486f 100644 --- a/src/ntops/kernels/instance_norm.py +++ b/src/ntops/kernels/instance_norm.py @@ -9,22 +9,23 @@ def arrangement( input, + mean, + var, running_mean, running_var, - tmp_mean, - tmp_var, weight, bias, - momentum, eps, output, num_normalized_elements, use_input_stats, - tracking_running_stats, dims, block_size=None, ): - def _arrange_per_channel_tensor(tensor): + if block_size is None: + block_size = ninetoothed.block_size() + + def _arrange_channel_tensor(tensor): arranged = tensor.tile((1,)) arranged.dtype = arranged.dtype.squeeze(0) arranged = arranged.unsqueeze(0) @@ -32,91 +33,53 @@ def _arrange_per_channel_tensor(tensor): return arranged + def _arrange_mean_or_var(tensor): + arranged = tensor.tile((1, 1)) + arranged.dtype = arranged.dtype.squeeze((0, 1)) + + return arranged + input_arranged, output_arranged = reduction_arrangement( input, output, dim=dims, block_size=block_size ) - running_mean_arranged = _arrange_per_channel_tensor(running_mean) - running_var_arranged = _arrange_per_channel_tensor(running_var) - tmp_mean_arranged = _arrange_per_channel_tensor(tmp_mean) - tmp_var_arranged = _arrange_per_channel_tensor(tmp_var) - weight_arranged = _arrange_per_channel_tensor(weight) - bias_arranged = _arrange_per_channel_tensor(bias) - momentum_arranged = momentum + mean_arranged = _arrange_mean_or_var(mean) + var_arranged = _arrange_mean_or_var(var) + running_mean_arranged = _arrange_channel_tensor(running_mean) + running_var_arranged = _arrange_channel_tensor(running_var) + weight_arranged = _arrange_channel_tensor(weight) + bias_arranged = _arrange_channel_tensor(bias) eps_arranged = eps num_normalized_elements_arranged = num_normalized_elements if use_input_stats: - if tracking_running_stats: - return ( - input_arranged, - running_mean_arranged, - running_var_arranged, - tmp_mean_arranged, - tmp_var_arranged, - weight_arranged, - bias_arranged, - momentum_arranged, - eps_arranged, - output_arranged, - num_normalized_elements_arranged, - ) - else: - return ( - input_arranged, - weight_arranged, - bias_arranged, - eps_arranged, - output_arranged, - num_normalized_elements_arranged, - ) - - return ( - input_arranged, - running_mean_arranged, - running_var_arranged, - weight_arranged, - bias_arranged, - eps_arranged, - output_arranged, - ) - - -def application_without_tracking( - input, - weight, - bias, - eps, - output, - num_normalized_elements, -): - _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32) - - for i in range(input.shape[0]): - _mean += ntl.cast(input[i], ntl.float32) - - mean = ntl.sum(_mean, 0) / num_normalized_elements - - _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32) - - for i in range(input.shape[0]): - diff = ntl.cast(input[i], ntl.float32) - mean - diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0) - _var += diff * diff - - var = ntl.sum(_var, 0) / num_normalized_elements - - application_with_mean_var(input, mean, var, weight, bias, eps, output) - - -def application_with_tracking( + return ( + input_arranged, + mean_arranged, + var_arranged, + weight_arranged, + bias_arranged, + eps_arranged, + output_arranged, + num_normalized_elements_arranged, + ) + else: + return ( + input_arranged, + running_mean_arranged, + running_var_arranged, + weight_arranged, + bias_arranged, + eps_arranged, + output_arranged, + ) + + +def application_using_input_stats( input, - running_mean, - running_var, - tmp_mean, - tmp_var, + mean, + var, weight, bias, - momentum, eps, output, num_normalized_elements, @@ -137,22 +100,6 @@ def application_with_tracking( var = ntl.sum(_var, 0) / num_normalized_elements - ntl.atomic_add( - tmp_mean.source.data_ptr() + tmp_mean.offsets(0), ntl.cast(mean, ntl.float32) - ) - ntl.atomic_add( - tmp_var.source.data_ptr() + tmp_mean.offsets(0), ntl.cast(var, ntl.float32) - ) - - ntl.debug_barrier() - - if input[0].offsets(0) == 0: - tmp_mean = tmp_mean / input.source.shape[0] - tmp_var = tmp_var / input.source.shape[0] - - running_mean = running_mean * (1 - momentum) + tmp_mean * momentum - running_var = running_var * (1 - momentum) + tmp_var * momentum - application_with_mean_var(input, mean, var, weight, bias, eps, output) @@ -174,7 +121,6 @@ def application_with_mean_var( def premake( ndim, use_input_stats, - tracking_running_stats, num_normalized_elements, dtype=None, block_size=None, @@ -184,36 +130,30 @@ def premake( arrangement_ = functools.partial( arrangement, use_input_stats=use_input_stats, - tracking_running_stats=tracking_running_stats, dims=dims, block_size=block_size, ) input = Tensor(ndim, other=0, dtype=dtype) - running_mean, running_var, tmp_mean, tmp_var, weight, bias = ( - Tensor(1, dtype=dtype) for _ in range(6) - ) - momentum, eps = (Tensor(0, dtype=ninetoothed.float64) for _ in range(2)) + mean, var = (Tensor(2, dtype=dtype) for _ in range(2)) + running_mean, running_var, weight, bias = (Tensor(1, dtype=dtype) for _ in range(4)) + eps = Tensor(0, dtype=ninetoothed.float64) output = Tensor(ndim, dtype=dtype) num_normalized_elements = Tensor(0, constexpr=True, value=num_normalized_elements) if use_input_stats: - if tracking_running_stats: - application = application_with_tracking - else: - application = application_without_tracking + application = application_using_input_stats else: application = application_with_mean_var tensors = ( input, + mean, + var, running_mean, running_var, - tmp_mean, - tmp_var, weight, bias, - momentum, eps, output, num_normalized_elements, diff --git a/src/ntops/torch/instance_norm.py b/src/ntops/torch/instance_norm.py index 92a81f5..7c09901 100644 --- a/src/ntops/torch/instance_norm.py +++ b/src/ntops/torch/instance_norm.py @@ -22,23 +22,11 @@ def instance_norm( if bias is None: bias = torch.zeros(input.shape[1], device=input.device, dtype=input.dtype) - tracking_running_stats = False + has_running_stats = running_mean is not None and running_var is not None - if not use_input_stats: - assert running_mean is not None and running_var is not None, ( - "`running_mean` and `running_var` must be provided when `use_input_stats=False`." - ) - assert running_mean.shape == (input.shape[1],) and running_var.shape == ( - input.shape[1], - ), "`running_mean` and `running_var` must have shape (C,)" - else: - if running_mean is not None and running_var is not None: - assert running_mean.shape == (input.shape[1],) and running_var.shape == ( - input.shape[1], - ), "`running_mean` and `running_var` must have shape (C,)" - tracking_running_stats = True - tmp_mean = torch.zeros_like(running_mean) - tmp_var = torch.zeros_like(running_var) + if use_input_stats: + mean = torch.empty(input.shape[:2], device=input.device, dtype=input.dtype) + var = torch.empty(input.shape[:2], device=input.device, dtype=input.dtype) output = torch.empty_like(input) @@ -47,35 +35,37 @@ def instance_norm( ntops.kernels.instance_norm.premake, input.ndim, use_input_stats, - tracking_running_stats, num_normalized_elements, - block_size=32, + dtype=input.dtype, ) if use_input_stats: - if tracking_running_stats: - kernel( - input, - running_mean, - running_var, - tmp_mean, - tmp_var, - weight, - bias, - momentum, - eps, - output, - num_normalized_elements, - ) - else: - kernel( - input, - weight, - bias, - eps, - output, - num_normalized_elements, + kernel( + input, + mean, + var, + weight, + bias, + eps, + output, + num_normalized_elements, + ) + + # We reduce in PyTorch instead of using tl.atomic_add in Triton because: + # 1. Triton blocks cannot synchronize to safely apply the momentum update after all additions finish. + # 2. N blocks atomically adding to the same C addresses creates severe memory contention. + if has_running_stats: + batch_mean = mean.mean(0) + avg_vars = var.mean(0) + + unbiased_var = ( + (avg_vars) * num_normalized_elements / (num_normalized_elements - 1) + if num_normalized_elements > 1 + else avg_vars ) + + running_mean.mul_(1 - momentum).add_(momentum * batch_mean) + running_var.mul_(1 - momentum).add_(momentum * unbiased_var) else: kernel(input, running_mean, running_var, weight, bias, eps, output) diff --git a/tests/test_instance_norm.py b/tests/test_instance_norm.py index 96682d0..b3cbbca 100644 --- a/tests/test_instance_norm.py +++ b/tests/test_instance_norm.py @@ -81,5 +81,6 @@ def test_instance_norm( assert torch.allclose( ninetoothed_running_mean, reference_running_mean, rtol=rtol, atol=atol ) - # TODO: The running var is not close. - # assert torch.allclose(ninetoothed_running_var, reference_running_var, rtol=rtol, atol=atol) + assert torch.allclose( + ninetoothed_running_var, reference_running_var, rtol=rtol, atol=atol + )