Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 117 additions & 14 deletions csrc/multidevice/cuda_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
*/
// clang-format on
#include "multidevice/cuda_p2p.h"

#include <array>
#include <limits>

#include "nvfuser_resources/alltoallv.h"
#include "nvfuser_resources/multicast.h"
#include "nvfuser_resources/multicast_reduce.h"
#include "nvfuser_resources/tma_copy.h"

#include "cuda_utils.h"
#include "multidevice/communication.h"
Expand Down Expand Up @@ -42,6 +46,22 @@ P2pProtocol getP2pProtocol() {
: P2pProtocol::Get;
}

// Streams the human-readable name of a P2pTransport value.
std::ostream& operator<<(std::ostream& os, P2pTransport transport) {
  const char* name = nullptr;
  switch (transport) {
    case P2pTransport::CopyEngine:
      name = "CopyEngine";
      break;
    case P2pTransport::Tma:
      name = "Tma";
      break;
  }
  if (name == nullptr) {
    // All enumerators are handled above; an unknown value is a logic error.
    std::unreachable();
  }
  return os << name;
}

// Selects the P2P data transport. Returns Tma only when explicitly requested
// via the p2p_transport enable option; otherwise defaults to the copy engine.
P2pTransport getP2pTransport() {
  if (hasEnableOptionArgument(EnableOption::P2pTransport, "Tma")) {
    return P2pTransport::Tma;
  }
  return P2pTransport::CopyEngine;
}

void launchMulticastReduceKernel(
const void* mc_src,
void* dst,
Expand Down Expand Up @@ -103,7 +123,6 @@ void launchAlltoallvKernel(
std::string arch_arg = "--gpu-architecture=compute_" +
std::to_string(major) + std::to_string(minor);
std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"};
// NVRTC needs CUDA headers to compile alltoallv.cu.
opts.push_back("-I/usr/local/cuda/include");
opts.push_back("-I/usr/local/cuda/include/cccl");

Expand Down Expand Up @@ -291,7 +310,6 @@ void launchMulticastKernel(
CUresult load_result = cuModuleLoadData(&module, ptx.data());

if (load_result != CUDA_SUCCESS) {
// Fallback to extensive logging only on failure
constexpr size_t kLogSize = 8192;
std::array<char, kLogSize> error_log{};
std::array<char, kLogSize> info_log{};
Expand Down Expand Up @@ -397,6 +415,81 @@ void launchMulticastKernel(
nullptr));
}

} // anonymous namespace

void launchTmaCopy(void* dst, const void* src, size_t size, CUstream stream) {
static CUmodule module = nullptr;
static CUfunction kernel = nullptr;
Comment on lines +421 to +422
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Static initialization lacks thread-safety protection. Multiple threads calling launchTmaCopy concurrently could race on the module == nullptr check (line 365), causing duplicate compilations or accessing partially-initialized state.

Other kernels in this file (launchAlltoallvKernel, launchMulticastKernel) have the same pattern. Consider adding mutex protection or using std::call_once for thread-safe lazy initialization if concurrent calls are possible.


if (module == nullptr) {
nvrtcProgram prog = nullptr;
NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
&prog,
nvfuser_resources::tma_copy_cu,
"tma_copy.cu",
0,
nullptr,
nullptr));

int device = 0;
NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device));
cudaDeviceProp prop{};
NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device));

NVF_CHECK(
prop.major >= 9,
"TMA transport requires Compute Capability >= 9.0 (Hopper+). "
"Current device ",
device,
" is Compute Capability ",
prop.major,
".",
prop.minor);

std::string arch_arg = "--gpu-architecture=compute_" +
std::to_string(prop.major) + std::to_string(prop.minor);
std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"};

nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
if (res != NVRTC_SUCCESS) {
size_t logSize = 0;
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize));
std::vector<char> log(logSize);
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data()));
NVF_ERROR(false, "TMA kernel compilation failed:\n", log.data());
}

size_t ptxSize = 0;
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
std::vector<char> ptx(ptxSize);
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));

NVFUSER_CUDA_SAFE_CALL(cuModuleLoadData(&module, ptx.data()));
NVFUSER_CUDA_SAFE_CALL(cuModuleGetFunction(&kernel, module, "tma_copy_1d"));
}

NVF_CHECK(
size % 16 == 0, "TMA requires size (", size, ") to be a multiple of 16");

constexpr int kDefaultSmem = 48 * 1024;
constexpr int kMbarrierBytes = 8;
constexpr int kMaxChunk = ((kDefaultSmem - kMbarrierBytes) / 16) * 16;

int total_bytes = static_cast<int>(size);
int max_chunk = kMaxChunk;
unsigned int num_blocks =
static_cast<unsigned int>((size + kMaxChunk - 1) / kMaxChunk);
int smem_size = kMaxChunk + static_cast<int>(sizeof(uint64_t));

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,bugprone-multi-level-implicit-pointer-conversion)
void* args[] = {&dst, &src, &total_bytes, &max_chunk};
NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel(
kernel, num_blocks, 1, 1, 32, 1, 1, smem_size, stream, args, nullptr));
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better if it launched one kernel that managed these chunks versus multiple? I'm not sure if there's a clear performance gain by launching 1 vs N kernels.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! You are right that having a loop inside the kernel is a much better solution. What is even better is to launch a single kernel with many blocks, so that SMs and TCs can work in parallel. I'm implementing the latter solution.
A follow-up optimization would be to use SW pipelining and double-buffering to overlap the Tma loads and stores


namespace {

// We choose duplicate the state of the semaphore on both the local and peer
// devices to avoid cuStreamWaitValue32 to poll on a remote buffer and pollutes
// the network. This is a theoretical consideration that we have not proved or
Expand Down Expand Up @@ -1132,12 +1225,17 @@ void recvPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) {
(cuuint32_t)(IpcSemaphore::kInProgress),
CU_STREAM_WAIT_VALUE_EQ));
// Get the data from the sender
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.local().ptr(),
ipc_handles.peer().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
if (getP2pTransport() == P2pTransport::Tma) {
launchTmaCopy(
ipc_handles.local().ptr(), ipc_handles.peer().ptr(), count, stream);
} else {
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.local().ptr(),
ipc_handles.peer().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
}
// Signals completion
WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle);
break;
Expand Down Expand Up @@ -1185,12 +1283,17 @@ void sendPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) {
(cuuint32_t)(IpcSemaphore::kInProgress),
CU_STREAM_WAIT_VALUE_EQ));
// Put the data to the receiver
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.peer().ptr(),
ipc_handles.local().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
if (getP2pTransport() == P2pTransport::Tma) {
launchTmaCopy(
ipc_handles.peer().ptr(), ipc_handles.local().ptr(), count, stream);
} else {
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
ipc_handles.peer().ptr(),
ipc_handles.local().ptr(),
count,
cudaMemcpyDeviceToDevice,
stream));
}
WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle);
break;
}
Expand Down
11 changes: 11 additions & 0 deletions csrc/multidevice/cuda_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ P2pProtocol getP2pProtocol();

std::ostream& operator<<(std::ostream& os, P2pProtocol protocol);

enum class P2pTransport : std::uint8_t { CopyEngine, Tma };

P2pTransport getP2pTransport();

std::ostream& operator<<(std::ostream& os, P2pTransport transport);

//! TMA 1D bulk copy: GMEM(src) -> SMEM -> GMEM(dst).
//! Compiled at runtime via NVRTC from runtime/tma_copy.cu.
//! Chunks the copy across thread blocks to fit shared memory. `size` must be
//! a multiple of 16 and fit in a signed 32-bit int (kernel parameter width).
void launchTmaCopy(void* dst, const void* src, size_t size, CUstream stream);

void recvPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream);

void recvWait(const P2pIpcHandle& ipc_handles, CUstream stream);
Expand Down
1 change: 1 addition & 0 deletions csrc/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"insert_resharding_after", EnableOption::InsertReshardingAfter},
{"fast_math", EnableOption::FastMath},
{"p2p_protocol", EnableOption::P2pProtocol},
{"p2p_transport", EnableOption::P2pTransport},
{"multicast_protocol", EnableOption::MulticastProtocol},
};
return available_options;
Expand Down
2 changes: 2 additions & 0 deletions csrc/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ enum class EnableOption : std::uint8_t {
InsertReshardingAfter, //! Insert resharding set after the expression
FastMath, //! Enable fast math optimizations (--use_fast_math)
P2pProtocol, //! Prescribe P2P protocol: put|get
P2pTransport, //! Prescribe P2P data transport: CopyEngine|Tma (default:
//! CopyEngine)
MulticastProtocol, //! Prescribe multicast protocol:
//! memcpy|multimem|batch_memcpy
EndOfOption //! Placeholder for counting the number of elements
Expand Down
28 changes: 20 additions & 8 deletions runtime/tma_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
//
// GMEM(src) --[TMA load]--> SMEM --[TMA store]--> GMEM(dst)
//
// A single elected thread (thread 0) drives both phases:
// The host launches ceil(total_bytes / max_chunk) blocks. Each block
// copies one chunk of the data, so all chunks can execute concurrently
// across SMs. Thread 0 in each block drives both TMA phases:
// 1. mbarrier.init (arrival count = 1)
// 2. mbarrier.arrive.expect_tx (announce expected bytes)
// 3. cp.async.bulk.shared::cluster.global (TMA load, async)
Expand All @@ -35,11 +37,25 @@
// Dynamic shared memory layout (128-byte aligned):
// [0, num_bytes) : staging buffer
// [num_bytes, num_bytes+8) : mbarrier (uint64_t)
//
// TODO: Proposition C — multi-stage TMA pipelining with
// double-buffered shared memory could further improve throughput
// by overlapping TMA loads and stores within each block. Explore
// if profiling shows TMA engine utilization is a bottleneck.

extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d(
void* __restrict__ dst,
const void* __restrict__ src,
int num_bytes) {
int total_bytes,
int max_chunk) {
long long offset = static_cast<long long>(blockIdx.x) * max_chunk;
int num_bytes = static_cast<int>(
min(static_cast<long long>(max_chunk),
static_cast<long long>(total_bytes) - offset));

const char* block_src = static_cast<const char*>(src) + offset;
char* block_dst = static_cast<char*>(dst) + offset;

extern __shared__ __align__(128) unsigned char smem[];

unsigned long long* mbar =
Expand All @@ -57,23 +73,20 @@ extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d(
__syncwarp();

if (threadIdx.x == 0) {
// Announce expected transaction bytes on the mbarrier
asm volatile(
"mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::"r"(
mbar_addr),
"r"(num_bytes));

// TMA Load: GMEM -> SMEM (async, completed via mbarrier)
asm volatile(
"cp.async.bulk.shared::cluster.global"
".mbarrier::complete_tx::bytes"
" [%0], [%1], %2, [%3];\n" ::"r"(smem_addr),
"l"(src),
"l"(block_src),
"r"(num_bytes),
"r"(mbar_addr)
: "memory");

// Block until the mbarrier phase flips (TMA load completed)
asm volatile(
"{\n"
".reg .pred P1;\n"
Expand All @@ -86,10 +99,9 @@ extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d(
"}" ::"r"(mbar_addr),
"r"(0));

// TMA Store: SMEM -> GMEM
asm volatile(
"cp.async.bulk.global.shared::cta.bulk_group"
" [%0], [%1], %2;\n" ::"l"(dst),
" [%0], [%1], %2;\n" ::"l"(block_dst),
"r"(smem_addr),
"r"(num_bytes)
: "memory");
Expand Down
Loading
Loading