Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -287,17 +287,41 @@ endif()

# Opt-in switch (default OFF). When ON, the if(NVTE_WITH_CUBLASMP) block that
# follows configures the transformer_engine target to build against cuBLASMp and NCCL.
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
# cuBLASMp build configuration for the transformer_engine target.
# Locate both libraries up front; REQUIRED stops configuration if either one
# cannot be found. cuBLASMp is searched under ${CUBLASMP_DIR} (lib/ suffix),
# NCCL only in the default search locations (lib/ suffix).
find_library(CUBLASMP_LIB NAMES cublasmp libcublasmp PATHS ${CUBLASMP_DIR} PATH_SUFFIXES lib REQUIRED)
find_library(NCCL_LIB NAMES nccl libnccl PATH_SUFFIXES lib REQUIRED)

# PRIVATE: the cuBLASMp headers and the NVTE_WITH_CUBLASMP macro apply only when
# compiling transformer_engine itself and are not propagated to consumers.
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
# Check NCCL version for cuBLASMp compatibility
if(NOT DEFINED NCCL_VERSION)
# No CMake variable named NCCL_VERSION exists (e.g. -DNCCL_VERSION=X.Y.Z was not
# passed). CMake does not promote environment variables to CMake variables on its
# own, so read the environment explicitly before giving up.
if(DEFINED ENV{NCCL_VERSION})
  set(NCCL_VERSION "$ENV{NCCL_VERSION}")
else()
  message(FATAL_ERROR "NCCL_VERSION is not set. Pass -DNCCL_VERSION=X.Y.Z to CMake or set the NCCL_VERSION environment variable. NCCL 2.29.3+ is required for cuBLASMp support.")
endif()
Comment on lines +291 to +292
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 CMake variable vs environment variable mismatch

if(NOT DEFINED NCCL_VERSION) checks for a CMake variable (normal or cache, e.g. one set with -DNCCL_VERSION=2.29.3), but the error message says "NCCL_VERSION environment variable not set". If a user reads this error and runs export NCCL_VERSION=2.29.3 in their shell, the error will persist, because CMake does not automatically promote environment variables to CMake variables.

Furthermore, setup.py (the standard build path via NVTE_WITH_CUBLASMP=1 pip install .) never passes -DNCCL_VERSION to CMake — so this check will always fail when building via setup.py unless the user also provides NVTE_CMAKE_EXTRA_ARGS="-DNCCL_VERSION=X.Y.Z". This is not obvious and entirely undocumented.

To read an environment variable in CMake the code should use $ENV{NCCL_VERSION} and handle both sources:

  # Allow NCCL_VERSION to be supplied as either a -D CMake variable or an env var.
  # An already-defined CMake variable takes precedence; the environment value is
  # copied in only when no CMake variable named NCCL_VERSION exists.
  if(NOT DEFINED NCCL_VERSION AND DEFINED ENV{NCCL_VERSION})
    set(NCCL_VERSION "$ENV{NCCL_VERSION}")
  endif()
  # Neither source provided a value: abort configuration and name both ways to set it.
  if(NOT DEFINED NCCL_VERSION)
    message(FATAL_ERROR "NCCL_VERSION is not set. Pass -DNCCL_VERSION=X.Y.Z to CMake "
            "or set the NCCL_VERSION environment variable. NCCL 2.29.3+ is required for cuBLASMp support.")
  endif()

Additionally, setup.py should be updated to read NCCL_VERSION from the environment and forward it as a CMake flag alongside the other cuBLASMp flags (lines 74-79).

endif()

# Validate NCCL_VERSION against the minimum NCCL release accepted for cuBLASMp.
#
# Input:   NCCL_VERSION - must begin with "X.Y.Z" (digits only); any text after
#          the third component is ignored by the check.
# Output:  NCCL_VERSION_MAJOR, NCCL_VERSION_MINOR, NCCL_VERSION_PATCH.
# Errors:  configuration stops with FATAL_ERROR when the value does not start
#          with X.Y.Z or when it is older than the minimum.
set(NVTE_CUBLASMP_MIN_NCCL_VERSION "2.29.3")

# Parse semantic version from NCCL_VERSION
string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)" NCCL_VERSION_MATCH "${NCCL_VERSION}")
if(NOT NCCL_VERSION_MATCH)
  message(FATAL_ERROR "Invalid NCCL_VERSION format: ${NCCL_VERSION}. Expected format: X.Y.Z")
endif()

# string(REGEX MATCH) leaves the capture groups in CMAKE_MATCH_<n>.
set(NCCL_VERSION_MAJOR "${CMAKE_MATCH_1}")
set(NCCL_VERSION_MINOR "${CMAKE_MATCH_2}")
set(NCCL_VERSION_PATCH "${CMAKE_MATCH_3}")

# Compare the matched X.Y.Z prefix component-wise with CMake's built-in version
# comparison instead of spelling out the major/minor/patch cases by hand.
if("${NCCL_VERSION_MATCH}" VERSION_LESS "${NVTE_CUBLASMP_MIN_NCCL_VERSION}")
  message(FATAL_ERROR "NCCL ${NVTE_CUBLASMP_MIN_NCCL_VERSION}+ is required for cuBLASMp tensor-parallel GEMMs, but found NCCL ${NCCL_VERSION}")
endif()

message(STATUS "NCCL version check passed: ${NCCL_VERSION}")

# Wire cuBLASMp and NCCL into the transformer_engine target.

# Feature macro and cuBLASMp headers are PRIVATE: used only when compiling
# transformer_engine itself.
target_compile_definitions(
  transformer_engine
  PRIVATE
    NVTE_WITH_CUBLASMP
)
target_include_directories(
  transformer_engine
  PRIVATE
    ${CUBLASMP_DIR}/include
)

# REQUIRED stops configuration when a library cannot be located.
# cuBLASMp is looked up under ${CUBLASMP_DIR}; NCCL has no path hint and is
# searched in the default locations. Both use the lib/ suffix.
find_library(
  CUBLASMP_LIB
  NAMES cublasmp libcublasmp
  PATHS ${CUBLASMP_DIR}
  PATH_SUFFIXES lib
  REQUIRED
)
find_library(
  NCCL_LIB
  NAMES nccl libnccl
  PATH_SUFFIXES lib
  REQUIRED
)

# PUBLIC: both libraries are also added to the link interface of transformer_engine.
target_link_libraries(
  transformer_engine
  PUBLIC
    ${NCCL_LIB}
    ${CUBLASMP_LIB}
)
message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
endif()
Expand Down
Loading