-
Notifications
You must be signed in to change notification settings - Fork 80
[Multidevice] Add TMA bulk copy kernel and P2P transport option #6012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c2941a0
64c6162
a852aad
c32c130
ae0c760
13f72a5
52c2258
9b95f48
9e280a6
80b005e
6ffe84c
284211c
b4b0628
8485178
097beac
ff0ad0a
4b484a8
bcfe2bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,9 +6,13 @@ | |
| */ | ||
| // clang-format on | ||
| #include "multidevice/cuda_p2p.h" | ||
|
|
||
| #include <array> | ||
|
|
||
| #include "nvfuser_resources/alltoallv.h" | ||
| #include "nvfuser_resources/multicast.h" | ||
| #include "nvfuser_resources/multicast_reduce.h" | ||
| #include "nvfuser_resources/tma_copy.h" | ||
|
|
||
| #include "cuda_utils.h" | ||
| #include "multidevice/communication.h" | ||
|
|
@@ -42,6 +46,22 @@ P2pProtocol getP2pProtocol() { | |
| : P2pProtocol::Get; | ||
| } | ||
|
|
||
| std::ostream& operator<<(std::ostream& os, P2pTransport transport) { | ||
| switch (transport) { | ||
| case P2pTransport::CopyEngine: | ||
| return os << "CopyEngine"; | ||
| case P2pTransport::Tma: | ||
| return os << "Tma"; | ||
| } | ||
| std::unreachable(); | ||
| } | ||
|
|
||
| P2pTransport getP2pTransport() { | ||
| return hasEnableOptionArgument(EnableOption::P2pTransport, "Tma") | ||
| ? P2pTransport::Tma | ||
| : P2pTransport::CopyEngine; | ||
| } | ||
|
|
||
| void launchMulticastReduceKernel( | ||
| const void* mc_src, | ||
| void* dst, | ||
|
|
@@ -103,7 +123,6 @@ void launchAlltoallvKernel( | |
| std::string arch_arg = "--gpu-architecture=compute_" + | ||
| std::to_string(major) + std::to_string(minor); | ||
| std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"}; | ||
| // NVRTC needs CUDA headers to compile alltoallv.cu. | ||
| opts.push_back("-I/usr/local/cuda/include"); | ||
| opts.push_back("-I/usr/local/cuda/include/cccl"); | ||
|
|
||
|
|
@@ -291,7 +310,6 @@ void launchMulticastKernel( | |
| CUresult load_result = cuModuleLoadData(&module, ptx.data()); | ||
|
|
||
| if (load_result != CUDA_SUCCESS) { | ||
| // Fallback to extensive logging only on failure | ||
| constexpr size_t kLogSize = 8192; | ||
| std::array<char, kLogSize> error_log{}; | ||
| std::array<char, kLogSize> info_log{}; | ||
|
|
@@ -397,6 +415,81 @@ void launchMulticastKernel( | |
| nullptr)); | ||
| } | ||
|
|
||
| } // anonymous namespace | ||
|
|
||
| void launchTmaCopy(void* dst, const void* src, size_t size, CUstream stream) { | ||
| static CUmodule module = nullptr; | ||
| static CUfunction kernel = nullptr; | ||
|
|
||
| if (module == nullptr) { | ||
| nvrtcProgram prog = nullptr; | ||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( | ||
| &prog, | ||
| nvfuser_resources::tma_copy_cu, | ||
| "tma_copy.cu", | ||
| 0, | ||
| nullptr, | ||
| nullptr)); | ||
|
|
||
| int device = 0; | ||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); | ||
| cudaDeviceProp prop{}; | ||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); | ||
|
|
||
| NVF_CHECK( | ||
| prop.major >= 9, | ||
| "TMA transport requires Compute Capability >= 9.0 (Hopper+). " | ||
| "Current device ", | ||
| device, | ||
| " is Compute Capability ", | ||
| prop.major, | ||
| ".", | ||
| prop.minor); | ||
|
|
||
| std::string arch_arg = "--gpu-architecture=compute_" + | ||
| std::to_string(prop.major) + std::to_string(prop.minor); | ||
| std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"}; | ||
|
|
||
| nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); | ||
| if (res != NVRTC_SUCCESS) { | ||
| size_t logSize = 0; | ||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); | ||
| std::vector<char> log(logSize); | ||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); | ||
| NVF_ERROR(false, "TMA kernel compilation failed:\n", log.data()); | ||
| } | ||
|
|
||
| size_t ptxSize = 0; | ||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); | ||
| std::vector<char> ptx(ptxSize); | ||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); | ||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); | ||
|
|
||
| NVFUSER_CUDA_SAFE_CALL(cuModuleLoadData(&module, ptx.data())); | ||
| NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction(&kernel, module, "tma_copy_1d")); | ||
| } | ||
|
|
||
| NVF_CHECK( | ||
| size % 16 == 0, "TMA requires size (", size, ") to be a multiple of 16"); | ||
|
|
||
| constexpr int kDefaultSmem = 48 * 1024; | ||
| constexpr int kMbarrierBytes = 8; | ||
| constexpr int kMaxChunk = ((kDefaultSmem - kMbarrierBytes) / 16) * 16; | ||
|
|
||
| int total_bytes = static_cast<int>(size); | ||
| int max_chunk = kMaxChunk; | ||
| unsigned int num_blocks = | ||
| static_cast<unsigned int>((size + kMaxChunk - 1) / kMaxChunk); | ||
| int smem_size = kMaxChunk + static_cast<int>(sizeof(uint64_t)); | ||
|
|
||
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,bugprone-multi-level-implicit-pointer-conversion) | ||
| void* args[] = {&dst, &src, &total_bytes, &max_chunk}; | ||
| NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( | ||
| kernel, num_blocks, 1, 1, 32, 1, 1, smem_size, stream, args, nullptr)); | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would it be better if it launched one kernel that managed these chunks versus multiple? I'm not sure there's a clear performance gain in launching 1 vs. N kernels.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point! You are right that having a loop inside the kernel is a much better solution. What is even better is to launch a single kernel with many blocks, so that SMs and TCs can work in parallel. I'm implementing the latter solution. |
||
|
|
||
| namespace { | ||
|
|
||
| // We choose to duplicate the state of the semaphore on both the local and peer | ||
| // devices to avoid having cuStreamWaitValue32 poll on a remote buffer and pollute | ||
| // the network. This is a theoretical consideration that we have not proved or | ||
|
|
@@ -1132,12 +1225,17 @@ void recvPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) { | |
| (cuuint32_t)(IpcSemaphore::kInProgress), | ||
| CU_STREAM_WAIT_VALUE_EQ)); | ||
| // Get the data from the sender | ||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( | ||
| ipc_handles.local().ptr(), | ||
| ipc_handles.peer().ptr(), | ||
| count, | ||
| cudaMemcpyDeviceToDevice, | ||
| stream)); | ||
| if (getP2pTransport() == P2pTransport::Tma) { | ||
| launchTmaCopy( | ||
| ipc_handles.local().ptr(), ipc_handles.peer().ptr(), count, stream); | ||
| } else { | ||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( | ||
| ipc_handles.local().ptr(), | ||
| ipc_handles.peer().ptr(), | ||
| count, | ||
| cudaMemcpyDeviceToDevice, | ||
| stream)); | ||
| } | ||
| // Signals completion | ||
| WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle); | ||
| break; | ||
|
|
@@ -1185,12 +1283,17 @@ void sendPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) { | |
| (cuuint32_t)(IpcSemaphore::kInProgress), | ||
| CU_STREAM_WAIT_VALUE_EQ)); | ||
| // Put the data to the receiver | ||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( | ||
| ipc_handles.peer().ptr(), | ||
| ipc_handles.local().ptr(), | ||
| count, | ||
| cudaMemcpyDeviceToDevice, | ||
| stream)); | ||
| if (getP2pTransport() == P2pTransport::Tma) { | ||
| launchTmaCopy( | ||
| ipc_handles.peer().ptr(), ipc_handles.local().ptr(), count, stream); | ||
| } else { | ||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( | ||
| ipc_handles.peer().ptr(), | ||
| ipc_handles.local().ptr(), | ||
| count, | ||
| cudaMemcpyDeviceToDevice, | ||
| stream)); | ||
| } | ||
| WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle); | ||
| break; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Static initialization lacks thread-safety protection. Multiple threads calling
`launchTmaCopy` concurrently could race on the `module == nullptr` check (line 365), causing duplicate compilations or accessing partially-initialized state. Other kernels in this file (
`launchAlltoallvKernel`, `launchMulticastKernel`) have the same pattern. Consider adding mutex protection or using `std::call_once` for thread-safe lazy initialization if concurrent calls are possible.