diff --git a/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp b/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp index cbf18e0..47d4b98 100644 --- a/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp +++ b/include/sofieBLAS/backends/cuda/sofieBLAS_cublas.hpp @@ -49,7 +49,6 @@ struct PairEq { class BlasCuda { cublasLtHandle_t ltHandle = nullptr; - cublasHandle_t handle = nullptr; cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatmulPreference_t preference = nullptr; void *d_workspace = nullptr; @@ -72,7 +71,6 @@ class BlasCuda { BlasCuda(alpaka::QueueCudaRtNonBlocking &queue) : m_queue{queue} { stream = static_cast(m_queue.getNativeHandle()); CHECK_CUBLAS(cublasLtCreate(<Handle)); - CHECK_CUBLAS(cublasCreate(&handle)); heuristic = {}; CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); @@ -118,10 +116,10 @@ class BlasCuda { } } - void AddLayoutConfig(std::size_t m, std::size_t n, std::size_t k) { - CheckAndAddLayout(k, m); - CheckAndAddLayout(k, n); - CheckAndAddLayout(m, n); + void AddLayoutConfig(std::size_t m, std::size_t n, std::size_t k, std::size_t lda, std::size_t ldb, std::size_t ldc) { + CheckAndAddLayout(k, m, lda); + CheckAndAddLayout(k, n, ldb); + CheckAndAddLayout(m, n, ldc); } template @@ -171,7 +169,6 @@ gemm(char transa, char transb, const unsigned int m, 1, &localHeuristic, &returnedResults)); - if (returnedResults == 0) { cublasLtMatmulDescDestroy(localDesc); std::cerr << "No suitable cuBLASLt algorithm found!\n"; @@ -238,7 +235,8 @@ gemmrelu(char transa, char transb, const unsigned int m, 1, &localHeuristic, &error_flag)); - + std::cout << "Requested workspace: " + << localHeuristic.workspaceSize << std::endl; if (error_flag == 0) { cublasLtMatmulDescDestroy(localDesc); std::cerr << "No suitable cuBLASLt algorithm found!\n"; @@ -313,11 +311,10 @@ gemmrelu(char transa, char transb, const unsigned int m, private: alpaka::QueueCudaRtNonBlocking m_queue; - void CheckAndAddLayout(size_t rows, size_t cols) { + void CheckAndAddLayout(size_t rows, size_t cols, size_t ld) { auto key = std::make_pair(rows, cols); if (LayoutStore.find(key) == LayoutStore.end()) { cublasLtMatrixLayout_t temp = nullptr; - size_t ld = rows; CHECK_CUBLAS( cublasLtMatrixLayoutCreate(&temp, CUDA_R_32F, rows, cols, ld)); LayoutStore.emplace(key, temp);