From cc9a518429b72373c4629e2f473d2f917918f07c Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 16 Mar 2026 16:41:56 +0100 Subject: [PATCH] Add support for template kernels in the NVCUDA backend --- kernel_tuner/backends/nvcuda.py | 24 +++++++++++++++++++----- kernel_tuner/core.py | 2 +- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/kernel_tuner/backends/nvcuda.py b/kernel_tuner/backends/nvcuda.py index 6729e683..23bd235a 100644 --- a/kernel_tuner/backends/nvcuda.py +++ b/kernel_tuner/backends/nvcuda.py @@ -154,10 +154,7 @@ def compile(self, kernel_instance): """ kernel_string = kernel_instance.kernel_string kernel_name = kernel_instance.name - - # mimic pycuda behavior to wrap kernel_string in extern "C" if not in kernel_string already - if 'extern "C"' not in kernel_string: - kernel_string = 'extern "C" {\n' + kernel_string + "\n}" + expression_name = str.encode(kernel_name) compiler_options = self.compiler_options_bytes if not any([b"--std=" in opt for opt in compiler_options]): @@ -171,20 +168,37 @@ def compile(self, kernel_instance): err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], []) try: + # Add the kernel as an expression. This is necessary for templated kernels to ensure that the + # compiler actually instantiates the kernel that we want to compile. 
+ cuda_error_check(err) + err = nvrtc.nvrtcAddNameExpression(program, expression_name) + + # Compile the program cuda_error_check(err) err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options) + + # Get the PTX cuda_error_check(err) err, size = nvrtc.nvrtcGetPTXSize(program) cuda_error_check(err) buff = b" " * size err = nvrtc.nvrtcGetPTX(program, buff) cuda_error_check(err) + + # Load the module err, self.current_module = driver.cuModuleLoadData(np.char.array(buff)) if err == driver.CUresult.CUDA_ERROR_INVALID_PTX: raise SkippableFailure("uses too much shared data") else: cuda_error_check(err) - err, self.func = driver.cuModuleGetFunction(self.current_module, str.encode(kernel_name)) + + # First, get the "lowered" name of the kernel (i.e., the name inside the PTX). + # Afterwards, we can use the lowered name to look up the kernel in the module. + err, lowered_name = nvrtc.nvrtcGetLoweredName(program, expression_name) + cuda_error_check(err) + err, self.func = driver.cuModuleGetFunction( + self.current_module, lowered_name + ) cuda_error_check(err) # get the number of registers per thread used in this kernel diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 5352ced7..2a71af40 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -707,7 +707,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose) ) # check for templated kernel - if kernel_source.lang in ["CUDA", "NVCUDA", "HIP"] and "<" in name and ">" in name: + if kernel_source.lang in ["CUDA", "HIP"] and "<" in name and ">" in name: kernel_string, name = wrap_templated_kernel(kernel_string, name) # Preprocess GPU arguments. Require for handling `Tunable` arguments