From 764ea971b74431ae62f6eef377ad1e2054370e73 Mon Sep 17 00:00:00 2001
From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Date: Fri, 20 Mar 2026 13:55:49 -0700
Subject: [PATCH] Update CMakeLists.txt

Require NCCL 2.29.3+ when NVTE_WITH_CUBLASMP is enabled. The NCCL
version is taken from -DNCCL_VERSION=X.Y.Z or, when that is not given,
from the NCCL_VERSION environment variable at configure time.

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
---
Review notes (ignored by git am, not part of the commit message):
 * The check used to test only the CMake variable NCCL_VERSION although
   the error text referred to the environment variable; the environment
   is now consulted as a fallback.
 * The whitespace-only rewrite of the existing find_library() lines was
   dropped; those lines are plain context again.
 * Context indentation was reconstructed from a whitespace-collapsed copy
   (assumed 4 spaces - TODO confirm); apply with --ignore-whitespace if
   the hunk is rejected. The post-image blob id on the index line is
   carried over from the original posting and is stale.

 transformer_engine/common/CMakeLists.txt | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index b9e2b907e0..f5625da22f 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -287,6 +287,29 @@ endif()
 
 option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
 if (NVTE_WITH_CUBLASMP)
+    # cuBLASMp tensor-parallel GEMMs require NCCL >= 2.29.3. Take the version
+    # from -DNCCL_VERSION=X.Y.Z if given, otherwise from the NCCL_VERSION
+    # environment variable (read once, at configure time).
+    if(NOT DEFINED NCCL_VERSION AND DEFINED ENV{NCCL_VERSION})
+        set(NCCL_VERSION "$ENV{NCCL_VERSION}")
+    endif()
+    if(NOT DEFINED NCCL_VERSION)
+        message(FATAL_ERROR "NCCL_VERSION is not set (pass -DNCCL_VERSION=X.Y.Z or export the NCCL_VERSION environment variable). NCCL 2.29.3+ is required for cuBLASMp support.")
+    endif()
+
+    # Keep only the leading X.Y.Z; anything after it (e.g. a "-1" suffix) is ignored.
+    string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" NCCL_VERSION_MATCH "${NCCL_VERSION}")
+    if("${NCCL_VERSION_MATCH}" STREQUAL "")
+        message(FATAL_ERROR "Invalid NCCL_VERSION format: ${NCCL_VERSION}. Expected format: X.Y.Z")
+    endif()
+
+    # VERSION_LESS compares the dotted components numerically.
+    if("${NCCL_VERSION_MATCH}" VERSION_LESS "2.29.3")
+        message(FATAL_ERROR "NCCL 2.29.3+ is required for cuBLASMp tensor-parallel GEMMs, but found NCCL ${NCCL_VERSION}")
+    endif()
+
+    message(STATUS "NCCL version check passed: ${NCCL_VERSION}")
+
     target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
     target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
     find_library(CUBLASMP_LIB