Skip to content

Commit b59d9db

Browse files
committed
feat(ops): implement CausalSoftmax operator with Hygon backend.
1 parent abde23a commit b59d9db

3 files changed

Lines changed: 87 additions & 2 deletions

File tree

src/cuda/causal_softmax/kernel.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#ifndef INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
22
#define INFINI_OPS_CUDA_CAUSAL_SOFTMAX_KERNEL_H_
33

4+
#include <algorithm>
45
#include <cstdint>
6+
#include <type_traits>
57

68
#include "base/causal_softmax.h"
79
#include "cuda/causal_softmax/kernel.cuh"
@@ -11,6 +13,17 @@
1113

1214
namespace infini::ops {
1315

16+
namespace causal_softmax::detail {
17+
18+
template <typename Backend, typename = void>
19+
struct MaxBlockSize : std::integral_constant<int, 2048> {};
20+
21+
template <typename Backend>
22+
struct MaxBlockSize<Backend, std::void_t<decltype(Backend::max_block_size)>>
23+
: std::integral_constant<int, Backend::max_block_size> {};
24+
25+
} // namespace causal_softmax::detail
26+
1427
template <typename Backend>
1528
class CudaCausalSoftmax : public CausalSoftmax {
1629
public:
@@ -32,10 +45,13 @@ class CudaCausalSoftmax : public CausalSoftmax {
3245
std::abort();
3346
}
3447

35-
int block_size = GetOptimalBlockSize();
48+
constexpr int kMaxBlockSize =
49+
causal_softmax::detail::MaxBlockSize<Backend>::value;
50+
int block_size = std::min(GetOptimalBlockSize(), kMaxBlockSize);
3651

3752
DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
38-
AllCudaBlockSizes>(
53+
SupportedCudaBlockSizesType<
54+
causal_softmax::detail::MaxBlockSize<Backend>::value>>(
3955
// TODO: Output dtype should use the one passed in during construction.
4056
{static_cast<int64_t>(out.dtype()), block_size},
4157
[&](auto list_tag) {

src/cuda/kernel_commons.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,38 @@ namespace infini::ops {
2828

2929
using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>;
3030

31+
template <int max_block_size>
32+
struct SupportedCudaBlockSizes;
33+
34+
template <>
35+
struct SupportedCudaBlockSizes<2048> {
36+
using type = AllCudaBlockSizes;
37+
};
38+
39+
template <>
40+
struct SupportedCudaBlockSizes<1024> {
41+
using type = List<128, 256, 512, 1024>;
42+
};
43+
44+
template <>
45+
struct SupportedCudaBlockSizes<512> {
46+
using type = List<128, 256, 512>;
47+
};
48+
49+
template <>
50+
struct SupportedCudaBlockSizes<256> {
51+
using type = List<128, 256>;
52+
};
53+
54+
template <>
55+
struct SupportedCudaBlockSizes<128> {
56+
using type = List<128>;
57+
};
58+
59+
template <int max_block_size>
60+
using SupportedCudaBlockSizesType =
61+
typename SupportedCudaBlockSizes<max_block_size>::type;
62+
3163
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_HYGON)
3264
// Cache `cudaDeviceProp` per device, initialized once at first access.
3365
class DevicePropertyCache {

src/hygon/causal_softmax/kernel.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
2+
#define INFINI_OPS_HYGON_CAUSAL_SOFTMAX_KERNEL_H_
3+
4+
#include <utility>
5+
6+
// clang-format off
7+
#include <cuda_runtime.h>
8+
// clang-format on
9+
10+
// clang-format off
11+
#include "hygon/device_.h"
12+
// clang-format on
13+
14+
#include "cuda/causal_softmax/kernel.h"
15+
16+
namespace infini::ops {
17+
18+
namespace causal_softmax {
19+
20+
struct HygonBackend {
21+
using stream_t = cudaStream_t;
22+
23+
static constexpr int max_block_size = 256;
24+
};
25+
26+
} // namespace causal_softmax
27+
28+
template <>
29+
class Operator<CausalSoftmax, Device::Type::kHygon>
30+
: public CudaCausalSoftmax<causal_softmax::HygonBackend> {
31+
public:
32+
using CudaCausalSoftmax<causal_softmax::HygonBackend>::CudaCausalSoftmax;
33+
};
34+
35+
} // namespace infini::ops
36+
37+
#endif

0 commit comments

Comments
 (0)