From bd06f4c18d97dfc4b317dc958d7d07169473a548 Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Thu, 19 Mar 2026 13:17:58 +0800 Subject: [PATCH 1/2] Update backend_cuda.hpp to support sm_89 machine --- .../extension/backend/backend_cuda.hpp | 359 +++++++++++++----- 1 file changed, 272 insertions(+), 87 deletions(-) diff --git a/openequivariance/openequivariance/extension/backend/backend_cuda.hpp b/openequivariance/openequivariance/extension/backend/backend_cuda.hpp index 84a5e6f..e52e2fe 100644 --- a/openequivariance/openequivariance/extension/backend/backend_cuda.hpp +++ b/openequivariance/openequivariance/extension/backend/backend_cuda.hpp @@ -6,37 +6,44 @@ #include #include #include -#include +#include +// ===== NEW: caching includes ===== +#include +#include +#include +#include +#include +// ===== END NEW ===== using namespace std; using Stream = cudaStream_t; -#define NVRTC_SAFE_CALL(x) \ -do { \ - nvrtcResult result = x; \ - if (result != NVRTC_SUCCESS) { \ - std::cerr << "\nerror: " #x " failed with error " \ - << nvrtcGetErrorString(result) << '\n'; \ - exit(1); \ - } \ -} while(0) - -#define CUDA_SAFE_CALL(x) \ -do { \ - CUresult result = x; \ - if (result != CUDA_SUCCESS) { \ - const char *msg; \ - cuGetErrorName(result, &msg); \ - std::cerr << "\nerror: " #x " failed with error " \ - << msg << '\n'; \ - exit(1); \ - } \ -} while(0) +#define NVRTC_SAFE_CALL(x) \ + do { \ + nvrtcResult result = x; \ + if (result != NVRTC_SUCCESS) { \ + std::cerr << "\nerror: " #x " failed with error " \ + << nvrtcGetErrorString(result) << '\n'; \ + exit(1); \ + } \ + } while(0) + +#define CUDA_SAFE_CALL(x) \ + do { \ + CUresult result = x; \ + if (result != CUDA_SUCCESS) { \ + const char *msg; \ + cuGetErrorName(result, &msg); \ + std::cerr << "\nerror: " #x " failed with error " \ + << msg << '\n'; \ + exit(1); \ + } \ + } while(0) #define CUDA_ERRCHK(ans) { gpuAssert((ans), __FILE__, __LINE__); } inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) { - if (code != cudaSuccess) + if (code != cudaSuccess) { fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); if (abort) exit(code); @@ -45,22 +52,22 @@ inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=t class CUDA_Allocator { public: - static void* gpu_alloc (size_t size) { + static void* gpu_alloc(size_t size) { void* ptr; - CUDA_ERRCHK( cudaMalloc((void**) &ptr, size )) + CUDA_ERRCHK( cudaMalloc((void**) &ptr, size) ) return ptr; } - static void gpu_free (void* ptr) { - CUDA_ERRCHK( cudaFree(ptr)) + static void gpu_free(void* ptr) { + CUDA_ERRCHK( cudaFree(ptr) ) } - static void copy_host_to_device (void* host, void* device, size_t size) { - CUDA_ERRCHK( cudaMemcpy(device, host, size, cudaMemcpyHostToDevice)); + static void copy_host_to_device(void* host, void* device, size_t size) { + CUDA_ERRCHK( cudaMemcpy(device, host, size, cudaMemcpyHostToDevice) ); } - static void copy_device_to_host (void* host, void* device, size_t size) { - CUDA_ERRCHK( cudaMemcpy(host, device, size, cudaMemcpyDeviceToHost)); + static void copy_device_to_host(void* host, void* device, size_t size) { + CUDA_ERRCHK( cudaMemcpy(host, device, size, cudaMemcpyDeviceToHost) ); } }; @@ -68,7 +75,7 @@ class GPUTimer { cudaEvent_t start_evt, stop_evt; public: - GPUTimer() { + GPUTimer() { cudaEventCreate(&start_evt); cudaEventCreate(&stop_evt); } @@ -82,18 +89,17 @@ class GPUTimer { cudaEventRecord(stop_evt); cudaEventSynchronize(stop_evt); cudaEventElapsedTime(&time_millis, start_evt, stop_evt); - return time_millis; + return time_millis; } void clear_L2_cache() { size_t element_count = 25000000; - - int* ptr = (int*) (CUDA_Allocator::gpu_alloc(element_count * sizeof(int))); + int* ptr = (int*)(CUDA_Allocator::gpu_alloc(element_count * sizeof(int))); CUDA_ERRCHK(cudaMemset(ptr, 42, element_count * sizeof(int))) CUDA_Allocator::gpu_free(ptr); cudaDeviceSynchronize(); } - + ~GPUTimer() { cudaEventDestroy(start_evt); cudaEventDestroy(stop_evt); @@ -102,15 +108,15 @@ class GPUTimer { class __attribute__((visibility("default"))) DeviceProp { public: - std::string name; + std::string name; int warpsize; int major, minor; int multiprocessorCount; int maxSharedMemPerBlock; - int maxSharedMemoryPerMultiprocessor; + int maxSharedMemoryPerMultiprocessor; DeviceProp(int device_id) { - cudaDeviceProp prop; + cudaDeviceProp prop; cudaGetDeviceProperties(&prop, device_id); name = std::string(prop.name); CUDA_ERRCHK(cudaDeviceGetAttribute(&maxSharedMemoryPerMultiprocessor, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)); @@ -136,20 +142,20 @@ class __attribute__((visibility("default"))) KernelLaunchConfig { KernelLaunchConfig(uint32_t num_blocks, uint32_t num_threads_per_block, uint32_t smem) : num_blocks(num_blocks), num_threads(num_threads_per_block), - smem(smem) + smem(smem) { } KernelLaunchConfig(int64_t num_blocks_i, int64_t num_threads_i, int64_t smem_i) : KernelLaunchConfig( static_cast(num_blocks_i), static_cast(num_threads_i), - static_cast(smem_i)) + static_cast(smem_i)) { } }; /* -* This page is a useful resource on NVRTC: -* https://docs.nvidia.com/cuda/nvrtc/index.html#example-using-nvrtcgettypename -*/ + * This page is a useful resource on NVRTC: + * https://docs.nvidia.com/cuda/nvrtc/index.html#example-using-nvrtcgettypename + */ class __attribute__((visibility("default"))) CUJITKernel { private: @@ -157,28 +163,180 @@ class __attribute__((visibility("default"))) CUJITKernel { bool compiled = false; char* code = nullptr; + size_t codeSize = 0; // ===== NEW: stored as member for caching ===== int cu_major, cu_minor; CUlibrary library; - vector supported_archs; + vector supported_archs; vector kernel_names; vector kernels; + vector lowered_names; // ===== NEW: stored for caching ===== + + // ===== NEW: Cache helper methods ===== + + // Get cache directory from CUDA_CACHE_PATH env var (or default) + static std::string get_cache_dir() { + const char* env = std::getenv("CUDA_CACHE_PATH"); + if (env && std::strlen(env) > 0) { + return std::string(env); + } + return "./cuda_jit_cache"; + } + + // FNV-1a 64-bit hash + static std::string fnv1a_hash(const std::string& input) { + uint64_t hash = 14695981039346656037ULL; + for (unsigned char c : input) { + hash ^= static_cast(c); + hash *= 1099511628211ULL; + } + char buf[17]; + std::snprintf(buf, sizeof(buf), "%016llx", (unsigned long long)hash); + return std::string(buf); + } + + // Compute a unique cache key from source + arch + kernel names + NVRTC version + std::string compute_cache_key(const std::string& sm_flag) { + int nvrtc_major, nvrtc_minor; + nvrtcVersion(&nvrtc_major, &nvrtc_minor); + + std::string combined; + combined += "cache_v1\n"; + combined += "nvrtc=" + std::to_string(nvrtc_major) + "." + std::to_string(nvrtc_minor) + "\n"; + combined += "arch=" + sm_flag + "\n"; + combined += "src_len=" + std::to_string(kernel_plaintext.size()) + "\n"; + combined += kernel_plaintext; + combined += "\n"; + for (const auto& name : kernel_names) { + combined += "kern=" + name + "\n"; + } + + return "sm_" + std::to_string(cu_major) + std::to_string(cu_minor) + + "_" + fnv1a_hash(combined); + } + + // Create directories recursively (POSIX) + static void mkdir_recursive(const std::string& path) { + std::string current; + for (size_t i = 0; i < path.size(); i++) { + current += path[i]; + if (path[i] == '/') { + ::mkdir(current.c_str(), 0755); + } + } + if (!current.empty()) { + ::mkdir(current.c_str(), 0755); + } + } + + // Try to load cached CUBIN + lowered names from disk. + // Returns true on success (kernels are ready), false on cache miss or any error. + bool try_load_from_cache(const std::string& cubin_path, const std::string& names_path) { + // 1. Open both files + std::ifstream cubin_file(cubin_path, std::ios::binary | std::ios::ate); + if (!cubin_file.is_open()) return false; + + std::ifstream names_file(names_path); + if (!names_file.is_open()) return false; + + // 2. Read lowered names + std::vector cached_names; + std::string line; + while (std::getline(names_file, line)) { + if (!line.empty()) { + cached_names.push_back(line); + } + } + + if (cached_names.size() != kernel_names.size()) { + return false; // Mismatch: cache is stale or corrupted + } + + // 3. Read CUBIN binary + size_t size = static_cast(cubin_file.tellg()); + if (size == 0) return false; + cubin_file.seekg(0, std::ios::beg); + + char* cached_code = new char[size]; + cubin_file.read(cached_code, size); + if (!cubin_file.good()) { + delete[] cached_code; + return false; + } + + // 4. Load CUBIN into CUDA + CUDA_SAFE_CALL(cuInit(0)); + + CUresult load_result = cuLibraryLoadData(&library, cached_code, 0, 0, 0, 0, 0, 0); + if (load_result != CUDA_SUCCESS) { + delete[] cached_code; + return false; // CUBIN incompatible (e.g., driver update), will recompile + } + + // 5. Resolve kernel handles + for (size_t i = 0; i < cached_names.size(); i++) { + CUkernel k; + CUresult r = cuLibraryGetKernel(&k, library, cached_names[i].c_str()); + if (r != CUDA_SUCCESS) { + kernels.clear(); + cuLibraryUnload(library); + delete[] cached_code; + return false; + } + kernels.push_back(k); + } + + // Success — store state + code = cached_code; + codeSize = size; + lowered_names = cached_names; + + std::cerr << "[CUJITKernel] Cache HIT: loaded " << kernel_names.size() + << " kernel(s) from " << cubin_path << std::endl; + return true; + } + + // Save CUBIN + lowered names to disk for future runs. + void save_to_cache(const std::string& cubin_path, const std::string& names_path) { + std::string dir = cubin_path.substr(0, cubin_path.find_last_of('/')); + if (!dir.empty()) { + mkdir_recursive(dir); + } + + std::ofstream cubin_file(cubin_path, std::ios::binary); + if (cubin_file.is_open()) { + cubin_file.write(code, static_cast(codeSize)); + cubin_file.close(); + } + + std::ofstream names_file(names_path); + if (names_file.is_open()) { + for (const auto& name : lowered_names) { + names_file << name << "\n"; + } + names_file.close(); + } + + std::cerr << "[CUJITKernel] Cache SAVE: wrote " << kernel_names.size() + << " kernel(s) to " << cubin_path << std::endl; + } + + // ===== END NEW ===== public: string kernel_plaintext; CUJITKernel(string plaintext) : kernel_plaintext(plaintext) { - - int num_supported_archs; + + int num_supported_archs; NVRTC_SAFE_CALL( nvrtcGetNumSupportedArchs(&num_supported_archs)); - - supported_archs.resize(num_supported_archs); + + supported_archs.resize(num_supported_archs); NVRTC_SAFE_CALL( - nvrtcGetSupportedArchs(supported_archs.data())); - + nvrtcGetSupportedArchs(supported_archs.data())); NVRTC_SAFE_CALL( nvrtcCreateProgram( &prog, // prog @@ -196,7 +354,7 @@ class __attribute__((visibility("default"))) CUJITKernel { } void compile(vector kernel_names_i, vector> template_param_list, int opt_level=3) { - DeviceProp dp(0); // We only query the first device on the system at the moment + DeviceProp dp(0); cu_major = dp.major; cu_minor = dp.minor; @@ -210,31 +368,30 @@ class __attribute__((visibility("default"))) CUJITKernel { int device_arch = cu_major * 10 + cu_minor; if (std::find(supported_archs.begin(), supported_archs.end(), device_arch) == supported_archs.end()){ - int nvrtc_version_major, nvrtc_version_minor; + int nvrtc_version_major, nvrtc_version_minor; NVRTC_SAFE_CALL( - nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor)); + nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor)); - throw std::runtime_error("NVRTC version " - + std::to_string(nvrtc_version_major) - + "." - + std::to_string(nvrtc_version_minor) - + " does not support device architecture " + throw std::runtime_error("NVRTC version " + + std::to_string(nvrtc_version_major) + + "." + + std::to_string(nvrtc_version_minor) + + " does not support device architecture " + std::to_string(device_arch) - ); + ); } for(unsigned int kernel = 0; kernel < kernel_names_i.size(); kernel++) { string kernel_name = kernel_names_i[kernel]; vector &template_params = template_param_list[kernel]; - // Step 1: Generate kernel names from the template parameters if(template_params.size() == 0) { kernel_names.push_back(kernel_name); } else { std::string result = kernel_name + "<"; for(unsigned int i = 0; i < template_params.size(); i++) { - result += std::to_string(template_params[i]); + result += std::to_string(template_params[i]); if(i != template_params.size() - 1) { result += ","; } @@ -242,26 +399,43 @@ class __attribute__((visibility("default"))) CUJITKernel { result += ">"; kernel_names.push_back(result); } - } - + std::string sm = "-arch=sm_" + std::to_string(cu_major) + std::to_string(cu_minor); + // ===== NEW: Try loading from cache before NVRTC compilation ===== + { + std::string cache_key = compute_cache_key(sm); + std::string cache_dir = get_cache_dir(); + std::string cubin_cache = cache_dir + "/" + cache_key + ".cubin"; + std::string names_cache = cache_dir + "/" + cache_key + ".names"; + + if (try_load_from_cache(cubin_cache, names_cache)) { + compiled = true; + return; // Skip NVRTC compilation entirely! + } + + std::cerr << "[CUJITKernel] Cache MISS for sm_" + << cu_major << cu_minor + << ", compiling with NVRTC..." << std::endl; + } + // ===== END NEW ===== + std::vector opts = { "--std=c++17", sm.c_str(), "--split-compile=0", "--use_fast_math" - }; + }; // ========================================================= - // Step 2: Add name expressions, compile + // Step 2: Add name expressions, compile for(size_t i = 0; i < kernel_names.size(); ++i) NVRTC_SAFE_CALL(nvrtcAddNameExpression(prog, kernel_names[i].c_str())); - nvrtcResult compileResult = nvrtcCompileProgram(prog, // prog - static_cast(opts.size()), // numOptions - opts.data()); // options + nvrtcResult compileResult = nvrtcCompileProgram(prog, + static_cast(opts.size()), + opts.data()); size_t logSize; NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); @@ -270,15 +444,14 @@ class __attribute__((visibility("default"))) CUJITKernel { if (compileResult != NVRTC_SUCCESS) { throw std::logic_error("NVRTC Fail, log: " + std::string(log)); - } + } delete[] log; compiled = true; // ========================================================= - // Step 3: Get PTX, initialize device, context, and module + // Step 3: Get CUBIN, initialize device, context, and module - size_t codeSize; - NVRTC_SAFE_CALL(nvrtcGetCUBINSize(prog, &codeSize)); + NVRTC_SAFE_CALL(nvrtcGetCUBINSize(prog, &codeSize)); // ===== CHANGED: use member ===== code = new char[codeSize]; NVRTC_SAFE_CALL(nvrtcGetCUBIN(prog, code)); @@ -289,20 +462,32 @@ class __attribute__((visibility("default"))) CUJITKernel { const char *name; NVRTC_SAFE_CALL(nvrtcGetLoweredName( - prog, - kernel_names[i].c_str(), // name expression - &name // lowered name - )); + prog, + kernel_names[i].c_str(), + &name + )); + + lowered_names.push_back(std::string(name)); // ===== NEW: store lowered name ===== kernels.emplace_back(); CUDA_SAFE_CALL(cuLibraryGetKernel(&(kernels[i]), library, name)); } + + // ===== NEW: Save to cache for future runs ===== + { + std::string cache_key = compute_cache_key(sm); + std::string cache_dir = get_cache_dir(); + std::string cubin_cache = cache_dir + "/" + cache_key + ".cubin"; + std::string names_cache = cache_dir + "/" + cache_key + ".names"; + save_to_cache(cubin_cache, names_cache); + } + // ===== END NEW ===== } void set_max_smem(int kernel_id, uint32_t max_smem_bytes) { if(!compiled) throw std::logic_error("JIT object has not been compiled!"); - if(kernel_id >= kernels.size()) + if(kernel_id >= (int)kernels.size()) throw std::logic_error("Kernel index out of range!"); int device_count; @@ -323,10 +508,10 @@ class __attribute__((visibility("default"))) CUJITKernel { } void execute(int kernel_id, void* args[], KernelLaunchConfig config) { - if(kernel_id >= kernels.size()) + if(kernel_id >= (int)kernels.size()) throw std::logic_error("Kernel index out of range!"); - CUcontext pctx = NULL; + CUcontext pctx = NULL; CUDA_SAFE_CALL(cuCtxGetCurrent(&pctx)); if(pctx == NULL) { @@ -339,19 +524,19 @@ class __attribute__((visibility("default"))) CUJITKernel { } CUDA_SAFE_CALL( - cuLaunchKernel( (CUfunction) (kernels[kernel_id]), - config.num_blocks, 1, 1, // grid dim - config.num_threads, 1, 1, // block dim - config.smem, config.hStream, // shared mem and stream - args, NULL) // arguments - ); + cuLaunchKernel( (CUfunction)(kernels[kernel_id]), + config.num_blocks, 1, 1, + config.num_threads, 1, 1, + config.smem, config.hStream, + args, NULL) + ); } ~CUJITKernel() { if(compiled) { auto result = cuLibraryUnload(library); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { - std::cout << "Failed to unload CUDA library, error code: " << ((int) result) << std::endl; + std::cout << "Failed to unload CUDA library, error code: " << ((int) result) << std::endl; } delete[] code; From ba51a5b1aafc5ad6a35c78849bff128454b9bbf4 Mon Sep 17 00:00:00 2001 From: Franklalalala <42742342+Franklalalala@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:26:26 +0800 Subject: [PATCH 2/2] Update backend_cuda.hpp --- .../extension/backend/backend_cuda.hpp | 362 +++++++++++------- 1 file changed, 223 insertions(+), 139 deletions(-) diff --git a/openequivariance/openequivariance/extension/backend/backend_cuda.hpp b/openequivariance/openequivariance/extension/backend/backend_cuda.hpp index e52e2fe..d4906a3 100644 --- a/openequivariance/openequivariance/extension/backend/backend_cuda.hpp +++ b/openequivariance/openequivariance/extension/backend/backend_cuda.hpp @@ -7,13 +7,30 @@ #include #include #include -// ===== NEW: caching includes ===== + +// ===== Caching & threading includes ===== #include #include +#include // ← 修复: std::snprintf, std::rename, std::remove +#include // ← 修复: PRIx64 跨平台格式化 #include #include -#include -// ===== END NEW ===== +#include +#include + +#ifdef _WIN32 + #include + #include + #define GET_PID _getpid + #define MKDIR(path) _mkdir(path) +#else + #include + #include + #include + #define GET_PID getpid + #define MKDIR(path) ::mkdir(path, 0755) +#endif +// ======================================== using namespace std; using Stream = cudaStream_t; @@ -153,8 +170,14 @@ class __attribute__((visibility("default"))) KernelLaunchConfig { }; /* - * This page is a useful resource on NVRTC: - * https://docs.nvidia.com/cuda/nvrtc/index.html#example-using-nvrtcgettypename + * CUBIN Caching for NVRTC JIT Compilation + * ======================================== + * - Cache location: $OEQ_CACHE_PATH > $HOME/.cache/openequivariance > /tmp/oeq_jit_cache + * - Disable: OEQ_DISABLE_CACHE=1 + * - Clear: rm -rf ~/.cache/openequivariance/ + * - Cache key: FNV-1a hash of (NVRTC version + compile opts + source + kernel names + arch) + * - Format: .names file stores "original_name\tlowered_name" per line for collision detection + * - Writes use temp file + rename() for crash safety */ class __attribute__((visibility("default"))) CUJITKernel { @@ -163,7 +186,7 @@ class __attribute__((visibility("default"))) CUJITKernel { bool compiled = false; char* code = nullptr; - size_t codeSize = 0; // ===== NEW: stored as member for caching ===== + size_t codeSize = 0; int cu_major, cu_minor; CUlibrary library; @@ -172,40 +195,76 @@ class __attribute__((visibility("default"))) CUJITKernel { vector kernel_names; vector kernels; - vector lowered_names; // ===== NEW: stored for caching ===== + vector lowered_names; + + // =================== Cache & Context Helpers =================== - // ===== NEW: Cache helper methods ===== + // Ensure a CUDA context is bound on the current thread. + // Critical for multi-threaded usage where worker threads + // don't automatically have a CUDA context. + void ensure_cuda_context() { + CUcontext pctx = NULL; + CUDA_SAFE_CALL(cuCtxGetCurrent(&pctx)); + if (pctx == NULL) { + int device_id; + CUDA_ERRCHK(cudaGetDevice(&device_id)); + CUdevice dev; + CUDA_SAFE_CALL(cuDeviceGet(&dev, device_id)); + CUDA_SAFE_CALL(cuDevicePrimaryCtxRetain(&pctx, dev)); + CUDA_SAFE_CALL(cuCtxSetCurrent(pctx)); + } + } + + // Check if caching is enabled (default: yes). + static bool cache_enabled() { + static int result = -1; + if (result < 0) { + const char* env = std::getenv("OEQ_DISABLE_CACHE"); + result = (env && (std::strcmp(env, "1") == 0 || + std::strcmp(env, "true") == 0)) ? 0 : 1; + } + return result != 0; + } - // Get cache directory from CUDA_CACHE_PATH env var (or default) + // Get cache directory. Priority: + // $OEQ_CACHE_PATH > $HOME/.cache/openequivariance > /tmp/oeq_jit_cache static std::string get_cache_dir() { - const char* env = std::getenv("CUDA_CACHE_PATH"); - if (env && std::strlen(env) > 0) { + const char* env = std::getenv("OEQ_CACHE_PATH"); + if (env && std::strlen(env) > 0) return std::string(env); - } - return "./cuda_jit_cache"; + const char* home = std::getenv("HOME"); + if (home && std::strlen(home) > 0) + return std::string(home) + "/.cache/openequivariance"; + return "/tmp/oeq_jit_cache"; } - // FNV-1a 64-bit hash - static std::string fnv1a_hash(const std::string& input) { + // FNV-1a 64-bit hash → 16-char hex string. + static std::string fnv1a_hash_hex(const std::string& input) { uint64_t hash = 14695981039346656037ULL; for (unsigned char c : input) { hash ^= static_cast(c); hash *= 1099511628211ULL; } char buf[17]; - std::snprintf(buf, sizeof(buf), "%016llx", (unsigned long long)hash); + std::snprintf(buf, sizeof(buf), "%016" PRIx64, hash); return std::string(buf); } - // Compute a unique cache key from source + arch + kernel names + NVRTC version - std::string compute_cache_key(const std::string& sm_flag) { + // Build a deterministic cache key from ALL compilation inputs. + std::string compute_cache_key(const std::vector& compile_opts) { int nvrtc_major, nvrtc_minor; nvrtcVersion(&nvrtc_major, &nvrtc_minor); std::string combined; - combined += "cache_v1\n"; - combined += "nvrtc=" + std::to_string(nvrtc_major) + "." + std::to_string(nvrtc_minor) + "\n"; - combined += "arch=" + sm_flag + "\n"; + combined.reserve(kernel_plaintext.size() + 512); + combined += "cache_v2\n"; + combined += "nvrtc=" + std::to_string(nvrtc_major) + "." + + std::to_string(nvrtc_minor) + "\n"; + for (const char* opt : compile_opts) { + combined += "opt="; + combined += opt; + combined += "\n"; + } combined += "src_len=" + std::to_string(kernel_plaintext.size()) + "\n"; combined += kernel_plaintext; combined += "\n"; @@ -213,120 +272,151 @@ class __attribute__((visibility("default"))) CUJITKernel { combined += "kern=" + name + "\n"; } - return "sm_" + std::to_string(cu_major) + std::to_string(cu_minor) - + "_" + fnv1a_hash(combined); + return "sm" + std::to_string(cu_major) + std::to_string(cu_minor) + + "_" + fnv1a_hash_hex(combined); } - // Create directories recursively (POSIX) + // mkdir -p equivalent. static void mkdir_recursive(const std::string& path) { std::string current; for (size_t i = 0; i < path.size(); i++) { current += path[i]; - if (path[i] == '/') { - ::mkdir(current.c_str(), 0755); - } - } - if (!current.empty()) { - ::mkdir(current.c_str(), 0755); + if (path[i] == '/' && current.size() > 1) + MKDIR(current.c_str()); } + if (!current.empty() && current != "/") + MKDIR(current.c_str()); } // Try to load cached CUBIN + lowered names from disk. - // Returns true on success (kernels are ready), false on cache miss or any error. - bool try_load_from_cache(const std::string& cubin_path, const std::string& names_path) { - // 1. Open both files + // Returns true on success (kernels vector populated), false on miss/error. + bool try_load_from_cache(const std::string& cubin_path, + const std::string& names_path) { std::ifstream cubin_file(cubin_path, std::ios::binary | std::ios::ate); if (!cubin_file.is_open()) return false; std::ifstream names_file(names_path); if (!names_file.is_open()) return false; - // 2. Read lowered names - std::vector cached_names; - std::string line; - while (std::getline(names_file, line)) { - if (!line.empty()) { - cached_names.push_back(line); + // Parse names file: each line is "original_name\tlowered_name" + std::vector cached_orig; + std::vector cached_lowered; + { + std::string line; + while (std::getline(names_file, line)) { + if (line.empty()) continue; + auto tab = line.find('\t'); + if (tab == std::string::npos) return false; // corrupted + cached_orig.push_back(line.substr(0, tab)); + cached_lowered.push_back(line.substr(tab + 1)); } } + names_file.close(); + + if (cached_orig.size() != kernel_names.size()) return false; - if (cached_names.size() != kernel_names.size()) { - return false; // Mismatch: cache is stale or corrupted + // Verify original kernel names match (guards against hash collisions) + for (size_t i = 0; i < kernel_names.size(); i++) { + if (cached_orig[i] != kernel_names[i]) return false; } - // 3. Read CUBIN binary - size_t size = static_cast(cubin_file.tellg()); + // Read CUBIN binary + auto size = static_cast(cubin_file.tellg()); if (size == 0) return false; cubin_file.seekg(0, std::ios::beg); - char* cached_code = new char[size]; - cubin_file.read(cached_code, size); - if (!cubin_file.good()) { - delete[] cached_code; - return false; - } + char* buf = new char[size]; + cubin_file.read(buf, static_cast(size)); + bool read_ok = cubin_file.good(); + cubin_file.close(); + if (!read_ok) { delete[] buf; return false; } - // 4. Load CUBIN into CUDA + // Load CUBIN into CUDA CUDA_SAFE_CALL(cuInit(0)); + ensure_cuda_context(); - CUresult load_result = cuLibraryLoadData(&library, cached_code, 0, 0, 0, 0, 0, 0); + CUresult load_result = cuLibraryLoadData(&library, buf, 0, 0, 0, 0, 0, 0); if (load_result != CUDA_SUCCESS) { - delete[] cached_code; - return false; // CUBIN incompatible (e.g., driver update), will recompile + // CUBIN incompatible (driver update etc.) — will recompile + delete[] buf; + return false; } - // 5. Resolve kernel handles - for (size_t i = 0; i < cached_names.size(); i++) { + // Resolve kernel handles + for (size_t i = 0; i < cached_lowered.size(); i++) { CUkernel k; - CUresult r = cuLibraryGetKernel(&k, library, cached_names[i].c_str()); + CUresult r = cuLibraryGetKernel(&k, library, cached_lowered[i].c_str()); if (r != CUDA_SUCCESS) { kernels.clear(); cuLibraryUnload(library); - delete[] cached_code; + delete[] buf; return false; } kernels.push_back(k); } - // Success — store state - code = cached_code; + // Success + code = buf; codeSize = size; - lowered_names = cached_names; - - std::cerr << "[CUJITKernel] Cache HIT: loaded " << kernel_names.size() - << " kernel(s) from " << cubin_path << std::endl; + lowered_names = std::move(cached_lowered); return true; } - // Save CUBIN + lowered names to disk for future runs. - void save_to_cache(const std::string& cubin_path, const std::string& names_path) { + // Save CUBIN + lowered names to disk via atomic temp+rename. + void save_to_cache(const std::string& cubin_path, + const std::string& names_path) { std::string dir = cubin_path.substr(0, cubin_path.find_last_of('/')); - if (!dir.empty()) { - mkdir_recursive(dir); - } + if (!dir.empty()) mkdir_recursive(dir); + + // Unique temp suffix: PID + thread ID + std::stringstream ss; + ss << GET_PID() << "_" << std::this_thread::get_id(); + std::string tmp_suffix = ".tmp." + ss.str(); + + std::string tmp_cubin = cubin_path + tmp_suffix; + std::string tmp_names = names_path + tmp_suffix; - std::ofstream cubin_file(cubin_path, std::ios::binary); - if (cubin_file.is_open()) { - cubin_file.write(code, static_cast(codeSize)); - cubin_file.close(); + // Write CUBIN + { + std::ofstream f(tmp_cubin, std::ios::binary); + if (!f.is_open()) return; + f.write(code, static_cast(codeSize)); + f.close(); + if (f.fail()) { std::remove(tmp_cubin.c_str()); return; } } - std::ofstream names_file(names_path); - if (names_file.is_open()) { - for (const auto& name : lowered_names) { - names_file << name << "\n"; + // Write names: "original_name\tlowered_name" per line + { + std::ofstream f(tmp_names); + if (!f.is_open()) { std::remove(tmp_cubin.c_str()); return; } + for (size_t i = 0; i < kernel_names.size(); i++) { + f << kernel_names[i] << "\t" << lowered_names[i] << "\n"; + } + f.close(); + if (f.fail()) { + std::remove(tmp_cubin.c_str()); + std::remove(tmp_names.c_str()); + return; } - names_file.close(); } - std::cerr << "[CUJITKernel] Cache SAVE: wrote " << kernel_names.size() - << " kernel(s) to " << cubin_path << std::endl; + // Atomic rename + if (std::rename(tmp_cubin.c_str(), cubin_path.c_str()) != 0) { + std::remove(tmp_cubin.c_str()); + std::remove(tmp_names.c_str()); + return; + } + if (std::rename(tmp_names.c_str(), names_path.c_str()) != 0) { + std::remove(tmp_names.c_str()); + // cubin already renamed; next load will fail on names → recompile. Safe. + } } - // ===== END NEW ===== + // =================== End Cache Helpers =================== public: string kernel_plaintext; + CUJITKernel(string plaintext) : kernel_plaintext(plaintext) { @@ -339,18 +429,18 @@ class __attribute__((visibility("default"))) CUJITKernel { nvrtcGetSupportedArchs(supported_archs.data())); NVRTC_SAFE_CALL( - nvrtcCreateProgram( &prog, // prog - kernel_plaintext.c_str(), // buffer - "kernel.cu", // name - 0, // numHeaders - NULL, // headers - NULL)); // includeNames + nvrtcCreateProgram( &prog, + kernel_plaintext.c_str(), + "kernel.cu", + 0, + NULL, + NULL)); } void compile(string kernel_name, const vector template_params, int opt_level=3) { - vector kernel_names = {kernel_name}; + vector kernel_names_local = {kernel_name}; vector> template_param_list = {template_params}; - compile(kernel_names, template_param_list); + compile(kernel_names_local, template_param_list, opt_level); } void compile(vector kernel_names_i, vector> template_param_list, int opt_level=3) { @@ -358,16 +448,16 @@ class __attribute__((visibility("default"))) CUJITKernel { cu_major = dp.major; cu_minor = dp.minor; - if(compiled) { + if (compiled) { throw std::logic_error("JIT object has already been compiled!"); } - if(kernel_names_i.size() != template_param_list.size()) { + if (kernel_names_i.size() != template_param_list.size()) { throw std::logic_error("Kernel names and template parameters must have the same size!"); } int device_arch = cu_major * 10 + cu_minor; - if (std::find(supported_archs.begin(), supported_archs.end(), device_arch) == supported_archs.end()){ + if (std::find(supported_archs.begin(), supported_archs.end(), device_arch) == supported_archs.end()) { int nvrtc_version_major, nvrtc_version_minor; NVRTC_SAFE_CALL( nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor)); @@ -381,20 +471,19 @@ class __attribute__((visibility("default"))) CUJITKernel { ); } - for(unsigned int kernel = 0; kernel < kernel_names_i.size(); kernel++) { + // Step 1: Generate kernel names from template parameters + for (unsigned int kernel = 0; kernel < kernel_names_i.size(); kernel++) { string kernel_name = kernel_names_i[kernel]; vector &template_params = template_param_list[kernel]; - if(template_params.size() == 0) { + if (template_params.size() == 0) { kernel_names.push_back(kernel_name); - } - else { + } else { std::string result = kernel_name + "<"; - for(unsigned int i = 0; i < template_params.size(); i++) { + for (unsigned int i = 0; i < template_params.size(); i++) { result += std::to_string(template_params[i]); - if(i != template_params.size() - 1) { + if (i != template_params.size() - 1) result += ","; - } } result += ">"; kernel_names.push_back(result); @@ -403,34 +492,30 @@ class __attribute__((visibility("default"))) CUJITKernel { std::string sm = "-arch=sm_" + std::to_string(cu_major) + std::to_string(cu_minor); - // ===== NEW: Try loading from cache before NVRTC compilation ===== - { - std::string cache_key = compute_cache_key(sm); + std::vector opts = { + "--std=c++17", + sm.c_str(), + "--split-compile=0", + "--use_fast_math" + }; + + // ===== Try loading from cache ===== + if (cache_enabled()) { + std::string cache_key = compute_cache_key(opts); std::string cache_dir = get_cache_dir(); std::string cubin_cache = cache_dir + "/" + cache_key + ".cubin"; std::string names_cache = cache_dir + "/" + cache_key + ".names"; if (try_load_from_cache(cubin_cache, names_cache)) { compiled = true; - return; // Skip NVRTC compilation entirely! + return; } - - std::cerr << "[CUJITKernel] Cache MISS for sm_" - << cu_major << cu_minor - << ", compiling with NVRTC..." << std::endl; } - // ===== END NEW ===== - - std::vector opts = { - "--std=c++17", - sm.c_str(), - "--split-compile=0", - "--use_fast_math" - }; + // ===== End cache check ===== // ========================================================= // Step 2: Add name expressions, compile - for(size_t i = 0; i < kernel_names.size(); ++i) + for (size_t i = 0; i < kernel_names.size(); ++i) NVRTC_SAFE_CALL(nvrtcAddNameExpression(prog, kernel_names[i].c_str())); nvrtcResult compileResult = nvrtcCompileProgram(prog, @@ -443,59 +528,58 @@ class __attribute__((visibility("default"))) CUJITKernel { NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log)); if (compileResult != NVRTC_SUCCESS) { - throw std::logic_error("NVRTC Fail, log: " + std::string(log)); + std::string log_str(log); + delete[] log; + throw std::logic_error("NVRTC Fail, log: " + log_str); } delete[] log; compiled = true; // ========================================================= - // Step 3: Get CUBIN, initialize device, context, and module + // Step 3: Get CUBIN, load module - NVRTC_SAFE_CALL(nvrtcGetCUBINSize(prog, &codeSize)); // ===== CHANGED: use member ===== + NVRTC_SAFE_CALL(nvrtcGetCUBINSize(prog, &codeSize)); code = new char[codeSize]; NVRTC_SAFE_CALL(nvrtcGetCUBIN(prog, code)); CUDA_SAFE_CALL(cuInit(0)); + ensure_cuda_context(); CUDA_SAFE_CALL(cuLibraryLoadData(&library, code, 0, 0, 0, 0, 0, 0)); for (size_t i = 0; i < kernel_names.size(); i++) { const char *name; - NVRTC_SAFE_CALL(nvrtcGetLoweredName( prog, kernel_names[i].c_str(), - &name - )); - - lowered_names.push_back(std::string(name)); // ===== NEW: store lowered name ===== + &name)); + lowered_names.push_back(std::string(name)); kernels.emplace_back(); CUDA_SAFE_CALL(cuLibraryGetKernel(&(kernels[i]), library, name)); } - // ===== NEW: Save to cache for future runs ===== - { - std::string cache_key = compute_cache_key(sm); + // ===== Save to cache ===== + if (cache_enabled()) { + std::string cache_key = compute_cache_key(opts); std::string cache_dir = get_cache_dir(); - std::string cubin_cache = cache_dir + "/" + cache_key + ".cubin"; - std::string names_cache = cache_dir + "/" + cache_key + ".names"; - save_to_cache(cubin_cache, names_cache); + save_to_cache(cache_dir + "/" + cache_key + ".cubin", + cache_dir + "/" + cache_key + ".names"); } - // ===== END NEW ===== + // ===== End save ===== } void set_max_smem(int kernel_id, uint32_t max_smem_bytes) { - if(!compiled) + if (!compiled) throw std::logic_error("JIT object has not been compiled!"); - if(kernel_id >= (int)kernels.size()) + if (kernel_id >= (int)kernels.size()) throw std::logic_error("Kernel index out of range!"); int device_count; CUDA_SAFE_CALL(cuDeviceGetCount(&device_count)); - for(int i = 0; i < device_count; i++) { + for (int i = 0; i < device_count; i++) { DeviceProp dp(i); - if(dp.major == cu_major && dp.minor == cu_minor) { + if (dp.major == cu_major && dp.minor == cu_minor) { CUdevice dev; CUDA_SAFE_CALL(cuDeviceGet(&dev, i)); CUDA_SAFE_CALL(cuKernelSetAttribute( @@ -508,13 +592,13 @@ class __attribute__((visibility("default"))) CUJITKernel { } void execute(int kernel_id, void* args[], KernelLaunchConfig config) { - if(kernel_id >= (int)kernels.size()) + if (kernel_id >= (int)kernels.size()) throw std::logic_error("Kernel index out of range!"); CUcontext pctx = NULL; CUDA_SAFE_CALL(cuCtxGetCurrent(&pctx)); - if(pctx == NULL) { + if (pctx == NULL) { int device_id; CUdevice dev; CUDA_ERRCHK(cudaGetDevice(&device_id)); @@ -533,12 +617,12 @@ class __attribute__((visibility("default"))) CUJITKernel { } ~CUJITKernel() { - if(compiled) { + if (compiled) { auto result = cuLibraryUnload(library); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { - std::cout << "Failed to unload CUDA library, error code: " << ((int) result) << std::endl; + std::cout << "Failed to unload CUDA library, error code: " + << ((int) result) << std::endl; } - delete[] code; } NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));