Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 26 additions & 19 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ jobs:
os:
- ubuntu-22.04
# - windows-2019
python: ['3.10', '3.11', '3.12']
torch_version: ['2.10.0']
python: ['3.10', '3.11', '3.12', '3.13', '3.14']
# torch_version: ['2.7.0', '2.8.0', '2.9.0', '2.10.0', '2.11.0']
torch_version: ['2.11.0']
cuda_short_version: ['126']
exclude:
- torch_version: '2.7.0'
python: '3.14'
- torch_version: '2.8.0'
python: '3.14'

uses: ./.github/workflows/wheels_build.yml
with:
Expand All @@ -22,21 +28,21 @@ jobs:
torch_version: ${{ matrix.torch_version }}
cuda_short_version: ${{ matrix.cuda_short_version }}

build-pypi:
# Single canonical build intended for PyPI: no local CUDA/torch suffix
strategy:
fail-fast: false
matrix:
os: ['ubuntu-22.04']
python: ['3.10', '3.11', '3.12']
# build-pypi:
# # Single canonical build intended for PyPI: no local CUDA/torch suffix
# strategy:
# fail-fast: false
# matrix:
# os: ['ubuntu-22.04']
# python: ['3.10', '3.11', '3.12', '3.13', '3.14']

uses: ./.github/workflows/wheels_build.yml
with:
os: ${{ matrix.os }}
python: ${{ matrix.python }}
torch_version: '2.10.0'
cuda_short_version: '128'
append_local_version: '0' # 0 to disable local version suffix
# uses: ./.github/workflows/wheels_build.yml
# with:
# os: ${{ matrix.os }}
# python: ${{ matrix.python }}
# torch_version: '2.9.0'
# cuda_short_version: '128'
# append_local_version: '0' # 0 to disable local version suffix

# publish to GitHub Release
# gh_release:
Expand Down Expand Up @@ -79,11 +85,12 @@ jobs:


consolidate-wheels:
needs: [build-local, build-pypi]
# needs: [build-local, build-pypi]
needs: [build-local]
runs-on: ubuntu-latest
steps:
- name: Download all wheel artifacts
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
path: dist

Expand All @@ -94,7 +101,7 @@ jobs:
ls -l consolidated_wheels

- name: Upload consolidated wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: built-wheels
path: consolidated_wheels
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/wheels_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ jobs:
sudo apt autoremove -y

- name: Recursive checkout
uses: actions/checkout@v3
uses: actions/checkout@v5
with:
submodules: recursive
path: "."
Expand Down Expand Up @@ -236,14 +236,14 @@ jobs:

- name: Upload artifact (local build)
if: ${{ inputs.append_local_version != '0' }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: ${{ inputs.os }}-py${{ inputs.python }}-torch${{ inputs.torch_version }}+cu${{ inputs.cuda_short_version }}
path: dist/*.whl

- name: Upload artifact (pypi build)
if: ${{ inputs.append_local_version == '0' }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: ${{ inputs.os }}-py${{ inputs.python }}
path: dist/*.whl
Expand Down
49 changes: 49 additions & 0 deletions src/sfast/csrc/operators/cublas/CUDABlas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>
#include <torch/version.h>

// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
Expand Down Expand Up @@ -226,7 +227,9 @@ cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -239,7 +242,9 @@ void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -252,7 +257,9 @@ void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -267,7 +274,9 @@ void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>))
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -282,7 +291,9 @@ void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand Down Expand Up @@ -311,7 +322,11 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {

cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5){
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
#else
if (at::globalContext().allowFP16ReductionCuBLAS()) {
#endif
at::Half falpha = alpha;
at::Half fbeta = beta;
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
Expand Down Expand Up @@ -350,7 +365,9 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
BGEMM_CHECK_ARGVALUES(at::BFloat16);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
Expand Down Expand Up @@ -383,7 +400,9 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -396,7 +415,9 @@ void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -410,7 +431,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -427,7 +450,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -443,7 +468,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand Down Expand Up @@ -490,12 +517,20 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#else
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
if (at::globalContext().allowFP16ReductionCuBLAS() != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
#else
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
#endif
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
#endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
#else
if (at::globalContext().allowFP16ReductionCuBLAS()) {
#endif
at::Half falpha = alpha;
at::Half fbeta = beta;
TORCH_CUDABLAS_CHECK(cublasGemmEx_(
Expand Down Expand Up @@ -606,7 +641,9 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
Expand All @@ -617,7 +654,11 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 2)
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
if (at::globalContext().allowBF16ReductionCuBLAS() != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
#else
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
#endif
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
Expand Down Expand Up @@ -1126,7 +1167,9 @@ void trsmBatched<c10::complex<double>>(
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
Expand All @@ -1145,7 +1188,9 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
// loss still happens on TF32. So we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
Expand All @@ -1160,7 +1205,9 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
Expand All @@ -1175,7 +1222,9 @@ void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
// loss still happens on TF32. So we disable it here.
NoTF32Guard disable_tf32;
// See Note [Writing Nondeterministic Operations]
#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
globalContext().alertCuBLASConfigNotDeterministic();
#endif
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t op = _cublasOpFromChar(trans);
_cublasAdjustLdLevel2(m, n, &lda);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/extension.h>
#include <torch/version.h>

#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>
Expand Down Expand Up @@ -486,7 +487,11 @@ torch::Tensor cutlass_linear_geglu(const torch::Tensor &input,
auto dispatch_bf16 = [&] {
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 2)
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
if (at::globalContext().allowBF16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
#else
if (at::globalContext().allowBF16ReductionCuBLAS()) {
#endif
output =
CutlassDualGemmLauncher<at::BFloat16, GemmGEGLUWrapper,
cutlass::epilogue::thread::GELU_taylor_fast,
Expand All @@ -506,7 +511,11 @@ torch::Tensor cutlass_linear_geglu(const torch::Tensor &input,
AT_DISPATCH_CASE(
at::kHalf,
[&] {
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
#else
if (at::globalContext().allowFP16ReductionCuBLAS()) {
#endif
output = CutlassDualGemmLauncher<
at::Half, GemmGEGLUWrapper,
cutlass::epilogue::thread::GELU_taylor_fast,
Expand Down