Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
abs,
add,
addmm,
alpha_dropout,
avg_pool2d,
bitwise_and,
bitwise_not,
Expand All @@ -10,6 +11,8 @@
clamp,
conv2d,
cos,
cosh,
diag,
div,
dropout,
eq,
Expand All @@ -31,12 +34,14 @@
relu,
rms_norm,
rotary_position_embedding,
round,
rsqrt,
scaled_dot_product_attention,
sigmoid,
silu,
sin,
softmax,
sort,
sub,
tanh,
)
Expand All @@ -45,6 +50,7 @@
"abs",
"add",
"addmm",
"alpha_dropout",
"avg_pool2d",
"bitwise_and",
"bitwise_not",
Expand All @@ -53,6 +59,8 @@
"clamp",
"conv2d",
"cos",
"cosh",
"diag",
"div",
"dropout",
"eq",
Expand All @@ -74,12 +82,14 @@
"relu",
"rms_norm",
"rotary_position_embedding",
"round",
"rsqrt",
"scaled_dot_product_attention",
"sigmoid",
"silu",
"sin",
"softmax",
"sort",
"sub",
"tanh",
]
28 changes: 28 additions & 0 deletions src/ntops/kernels/alpha_dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, a, b, sat, p, seed, output):
    # One uniform sample per element; an element survives when its sample
    # exceeds the drop probability p.
    survives = ntl.rand(seed, input.offsets()) > p
    # Surviving elements get the affine correction a * x + b; dropped ones
    # are replaced by the saturation value `sat`.
    output = ntl.where(survives, a * input + b, sat)  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for alpha dropout.

    The scalars a, b, sat, and p are passed as 0-d float64 tensors and the
    seed as a 0-d int64 tensor, matching `application`'s parameter order.
    """
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    f64_scalar = functools.partial(Tensor, 0, dtype=ninetoothed.float64)

    tensors = (
        Tensor(ndim, dtype=dtype),  # input
        f64_scalar(),  # a
        f64_scalar(),  # b
        f64_scalar(),  # sat
        f64_scalar(),  # p
        Tensor(0, dtype=ninetoothed.int64),  # seed
        Tensor(ndim, dtype=dtype),  # output
    )

    return arrangement_, application, tensors
19 changes: 19 additions & 0 deletions src/ntops/kernels/cosh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


def application(input, output):
    # Cast to float32 before calling libdevice.cosh; the store back into
    # `output` handles any conversion to the output dtype.
    result = libdevice.cosh(ntl.cast(input, ntl.float32))
    output = result  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    """Build the (arrangement, application, tensors) triple for cosh."""
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    input_spec = Tensor(ndim, dtype=dtype)
    output_spec = Tensor(ndim, dtype=dtype)

    return arrangement_, application, (input_spec, output_spec)
58 changes: 58 additions & 0 deletions src/ntops/kernels/diag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import functools

import ninetoothed
from ninetoothed import Symbol, Tensor


def arrangement_embed(input, output, stride=None, block_size=None):
    """Tile `input` contiguously; tile `output` with a per-element dilation.

    The dilated, strided output tiling scatters each contiguous input block
    along the output at `stride`-spaced positions (the diagonal write).
    """
    stride = Symbol("stride", constexpr=True) if stride is None else stride
    block_size = ninetoothed.block_size() if block_size is None else block_size

    tiled_input = input.tile((block_size,))
    tiled_output = output.tile(
        (block_size,), strides=(block_size * stride,), dilation=(stride,)
    )

    return tiled_input, tiled_output


def arrangement_extract(input, output, stride=None, block_size=None):
    """Tile `input` with a per-element dilation; tile `output` contiguously.

    Mirror of `arrangement_embed`: the dilated, strided input tiling gathers
    `stride`-spaced elements (the diagonal read) into contiguous output blocks.
    """
    stride = Symbol("stride", constexpr=True) if stride is None else stride
    block_size = ninetoothed.block_size() if block_size is None else block_size

    tiled_input = input.tile(
        (block_size,), strides=(block_size * stride,), dilation=(stride,)
    )
    tiled_output = output.tile((block_size,))

    return tiled_input, tiled_output


def application(input, output):
    # Plain element-wise copy: the diagonal embed/extract behavior comes
    # entirely from the arrangements above, not from this application.
    output = input  # noqa: F841


def premake_embed(stride=None, dtype=None, block_size=None):
    """Premake the diag kernel that writes a 1-D vector onto a diagonal."""
    arrangement_ = functools.partial(
        arrangement_embed, stride=stride, block_size=block_size
    )

    io_specs = (Tensor(1, dtype=dtype, other=0), Tensor(1, dtype=dtype))

    return arrangement_, application, io_specs


def premake_extract(stride=None, dtype=None, block_size=None):
    """Premake the diag kernel that reads a diagonal into a 1-D vector."""
    arrangement_ = functools.partial(
        arrangement_extract, stride=stride, block_size=block_size
    )

    io_specs = (Tensor(1, dtype=dtype, other=0), Tensor(1, dtype=dtype))

    return arrangement_, application, io_specs
35 changes: 35 additions & 0 deletions src/ntops/kernels/round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


def application(input, output):
    # Compute in float32; nearbyint rounds using the current rounding mode
    # (round-to-nearest-even by default), matching torch.round's half-to-even
    # behavior — presumably why it is used instead of round().
    as_fp32 = ntl.cast(input, ntl.float32)
    output = libdevice.nearbyint(as_fp32)  # noqa: F841


def application_with_decimals(input, factor, inv_factor, output):
    # Multiply in the input's original precision to match torch's behavior,
    # then round the scaled value in float32 and undo the scaling.
    shifted = input * ntl.cast(factor, input.dtype)
    output = libdevice.nearbyint(ntl.cast(shifted, ntl.float32)) * inv_factor  # noqa: F841


def premake(ndim, decimals=0, dtype=None, block_size=None):
    """Premake round; for non-zero decimals, scale and unscale around rounding.

    With `decimals == 0` the plain `application` is used; otherwise the
    scale factor and its inverse travel as 0-d float64 tensors into
    `application_with_decimals`.
    """
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    if decimals == 0:
        io_specs = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
        return arrangement_, application, io_specs

    tensors = (
        Tensor(ndim, dtype=dtype),
        Tensor(0, dtype=ninetoothed.float64),  # factor
        Tensor(0, dtype=ninetoothed.float64),  # inv_factor
        Tensor(ndim, dtype=dtype),
    )
    return arrangement_, application_with_decimals, tensors
74 changes: 74 additions & 0 deletions src/ntops/kernels/sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def _next_power_of_two(value):
if value <= 1:
return 1

return 1 << (value - 1).bit_length()


def application(input, values, indices, sort_size, descending):
    # Sort one row (`input[0]`) of length shape[0] (a power of two; only the
    # first `sort_size` lanes hold real data) and write the sorted values and
    # their original positions into `values[0]` / `indices[0]`.
    input_0 = input[0]
    offsets = ntl.arange(0, input_0.shape[0])
    valid = offsets < sort_size

    # Map float32 bit patterns to int32 keys whose signed order matches the
    # float order: for negatives (sign bit set, `>> 31` is all-ones) the
    # magnitude bits are flipped; positives pass through unchanged.
    sign_mask = ntl.cast(0x7FFFFFFF, ntl.int32)
    input_fp32 = ntl.cast(input_0, ntl.float32)
    encoded = ntl.cast(input_fp32, ntl.int32, bitcast=True)
    encoded = encoded ^ ((encoded >> 31) & sign_mask)

    # For descending order, invert the key so an ascending sort of ~encoded
    # yields descending values; undone after the sort below.
    if descending:
        encoded = ~encoded

    # Padding lanes get the maximum key so they sort to the end in either
    # direction (applied after the descending inversion on purpose).
    encoded = ntl.where(valid, encoded, ntl.cast(0x7FFFFFFF, ntl.int32))

    # Composite 64-bit key: encoded value in the high 32 bits, original index
    # in the low 32 bits. Equal values then tie-break by index, which makes
    # the sort stable and lets indices be recovered from the sorted keys.
    offsets = ntl.cast(offsets, ntl.int64)
    key = ((ntl.cast(encoded, ntl.int64) & ntl.cast(0xFFFFFFFF, ntl.int64)) << 32) | offsets
    sorted_key = ntl.sort(key)

    sorted_encoded = ntl.cast(sorted_key >> 32, ntl.int32)

    # Decode: undo the descending inversion, then undo the sign-magnitude
    # transform (self-inverse), and bitcast back to float32.
    if descending:
        sorted_encoded = ~sorted_encoded

    sorted_encoded = sorted_encoded ^ ((sorted_encoded >> 31) & sign_mask)
    sorted_values = ntl.cast(sorted_encoded, ntl.float32, bitcast=True)
    sorted_indices = sorted_key & ntl.cast(0xFFFFFFFF, ntl.int64)

    values[0] = ntl.cast(sorted_values, values[0].dtype)
    indices[0] = ntl.cast(sorted_indices, indices[0].dtype)


def premake(
    ndim,
    dim,
    sort_size,
    descending=False,
    stable=False,
    dtype=None,
    block_size=None,
):
    """Premake the sort kernel along `dim` over `sort_size` elements.

    The block size defaults to the next power of two >= `sort_size`, and
    `sort_size` / `descending` are baked in as constexpr 0-d tensors.
    """
    if block_size is None:
        block_size = _next_power_of_two(sort_size)

    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

    # `stable` exists only for torch.sort interface parity; the composite
    # (encoded value, index) key in `application` is already stable.
    _ = stable

    input_spec = Tensor(ndim, dtype=dtype, other=0)
    values_spec = Tensor(ndim, dtype=dtype)
    indices_spec = Tensor(ndim, dtype=ninetoothed.int64)
    sort_size_spec = Tensor(0, constexpr=True, value=sort_size)
    descending_spec = Tensor(0, constexpr=True, value=descending)

    tensors = (
        input_spec,
        values_spec,
        indices_spec,
        sort_size_spec,
        descending_spec,
    )

    return arrangement_, application, tensors
10 changes: 10 additions & 0 deletions src/ntops/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ntops.torch.abs import abs
from ntops.torch.add import add
from ntops.torch.addmm import addmm
from ntops.torch.alpha_dropout import alpha_dropout
from ntops.torch.avg_pool2d import avg_pool2d
from ntops.torch.bitwise_and import bitwise_and
from ntops.torch.bitwise_not import bitwise_not
Expand All @@ -9,6 +10,8 @@
from ntops.torch.clamp import clamp
from ntops.torch.conv2d import conv2d
from ntops.torch.cos import cos
from ntops.torch.cosh import cosh
from ntops.torch.diag import diag
from ntops.torch.div import div
from ntops.torch.dropout import dropout
from ntops.torch.eq import eq
Expand All @@ -31,19 +34,22 @@
from ntops.torch.relu import relu
from ntops.torch.rms_norm import rms_norm
from ntops.torch.rotary_position_embedding import rotary_position_embedding
from ntops.torch.round import round
from ntops.torch.rsqrt import rsqrt
from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
from ntops.torch.sigmoid import sigmoid
from ntops.torch.silu import silu
from ntops.torch.sin import sin
from ntops.torch.softmax import softmax
from ntops.torch.sort import sort
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh

__all__ = [
"abs",
"add",
"addmm",
"alpha_dropout",
"avg_pool2d",
"bitwise_and",
"bitwise_not",
Expand All @@ -52,6 +58,8 @@
"clamp",
"conv2d",
"cos",
"cosh",
"diag",
"div",
"dropout",
"eq",
Expand All @@ -74,12 +82,14 @@
"relu",
"rms_norm",
"rotary_position_embedding",
"round",
"rsqrt",
"scaled_dot_product_attention",
"sigmoid",
"silu",
"sin",
"softmax",
"sort",
"sub",
"tanh",
]
36 changes: 36 additions & 0 deletions src/ntops/torch/alpha_dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import math
import random

import torch

import ntops
from ntops.torch.utils import _cached_make

# SELU's negative saturation value alpha' = -lambda * alpha
# (lambda ~= 1.0507009873554805, alpha ~= 1.6732632423543772), used by
# alpha dropout to replace dropped activations.
_ALPHA_P = -1.7580993408473766


def alpha_dropout(input, p=0.5, training=False, inplace=False):
    """Apply alpha dropout to `input`, mirroring `torch.nn.functional.alpha_dropout`.

    Alpha dropout replaces dropped activations with the SELU saturation
    value and applies an affine correction (a, b) so the output keeps the
    input's mean and variance, preserving self-normalization.

    Args:
        input: The input tensor.
        p: Probability of an element being dropped; must be in [0, 1].
        training: When False (the default), the input is returned unchanged
            (cloned unless `inplace`).
        inplace: When True, write the result into `input` itself.

    Returns:
        The resulting tensor (`input` itself when `inplace` is True).

    Raises:
        ValueError: If `p` is outside [0, 1].
    """
    # Match torch's validation so out-of-range probabilities fail loudly
    # instead of producing nonsensical scale factors below.
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")

    if not training or p == 0:
        if inplace:
            return input
        else:
            return input.clone()

    # Affine correction keeping zero mean / unit variance after dropping to
    # the saturation value.
    # NOTE(review): p == 1 still divides by zero here (q == 0); confirm the
    # desired behavior for that degenerate case.
    q = 1.0 - p
    a = 1.0 / math.sqrt(q * (1.0 + p * _ALPHA_P * _ALPHA_P))
    b = -a * p * _ALPHA_P
    sat = a * _ALPHA_P + b  # value written to dropped elements

    # Fresh per-call seed for the kernel's RNG.
    seed = random.randrange(0, 2**31)

    if inplace:
        output = input
    else:
        output = torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.alpha_dropout.premake, input.ndim)

    kernel(input, a, b, sat, p, seed, output)

    return output
15 changes: 15 additions & 0 deletions src/ntops/torch/cosh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def cosh(input, *, out=None):
    """Element-wise hyperbolic cosine of `input`, written to `out` if given."""
    output = out if out is not None else torch.empty_like(input)

    kernel = _cached_make(ntops.kernels.cosh.premake, input.ndim)
    kernel(input, output)

    return output
Loading