
[Multidevice] Add TMA bulk copy kernel and P2P transport option #6012

Merged
samnordmann merged 18 commits into main from tma_integration on Apr 1, 2026
Conversation

@samnordmann (Collaborator) commented Feb 25, 2026

  • Add a Hopper TMA (cp.async.bulk) copy kernel (csrc/multidevice/tma_copy.cu) compiled at runtime via NVRTC, and wire it as an alternative P2P data transport alongside the existing copy-engine (cudaMemcpyAsync) path.
  • Add a P2pTransport option (NVFUSER_ENABLE=p2p_transport(tma)) that switches sendPost/recvPost in cuda_p2p.cpp between copy-engine (default) and TMA.
  • Note that using the TMA transport from the host is a bit artificial, since the data needs to be staged through shared memory. But it is still good to have as a reference, and it might actually be faster than other transports.

@github-actions commented
Description

  • Add Hopper TMA (Tensor Memory Accelerator) bulk copy kernel compiled at runtime via NVRTC

  • Add P2pTransport option (NVFUSER_ENABLE=p2p_transport(tma)) to switch between copy-engine and TMA

  • Integrate TMA transport into existing sendPost/recvPost functions in cuda_p2p.cpp

  • Simplify and improve test coverage for TMA copy across local, P2P, and multicast scenarios

Changes walkthrough

Relevant files:

Enhancement
  • csrc/multidevice/cuda_p2p.cpp — Add TMA kernel compilation and P2P transport integration (+142/-12)
      • Added TMA copy kernel compilation and launch logic with NVRTC
      • Added getP2pTransport() function to read the NVFUSER_ENABLE option
      • Modified recvPost() and sendPost() to conditionally use TMA vs copy-engine
      • Added operator<< for the P2pTransport enum
      • Implemented chunked TMA copy handling for large transfers via shared memory
  • csrc/multidevice/cuda_p2p.h — Add TMA transport declarations and enum (+14/-2)
      • Added P2pTransport enum with CopyEngine and Tma options
      • Added launchTmaCopy() function declaration
      • Added getP2pTransport() and operator<< declarations

Configuration changes
  • csrc/options.cpp — Register p2p_transport enable option (+1/-0)
      • Added "p2p_transport" option mapping to EnableOption::P2pTransport
  • csrc/options.h — Add P2pTransport to the enable-options enum (+1/-0)
      • Added P2pTransport to the EnableOption enum

Tests
  • tests/cpp/test_multidevice_tma.cpp — Simplify TMA tests using the production kernel launcher (+10/-122)
      • Removed NVRTC compilation helpers; tests now use the production launchTmaCopy()
      • Simplified tests for local-device, P2P, and multicast TMA copy scenarios
      • Maintained test coverage for TMA bulk copy functionality

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Thread Safety

    The static module and kernel variables in launchTmaCopy() are not thread-safe. Multiple threads could simultaneously enter the initialization block, potentially causing race conditions during NVRTC compilation and CUDA module loading. Consider adding mutex protection or std::call_once for thread-safe initialization.

    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    
    if (module == nullptr) {
      nvrtcProgram prog;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
          &prog,
          nvfuser_resources::tma_copy_cu,
          "tma_copy.cu",
          0,
          nullptr,
          nullptr));
    
      int device = 0;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device));
      cudaDeviceProp prop;
      NVFUSER_CUDA_RT_SAFE_CALL(
          cudaGetDeviceProperties(&prop, device));
    
      NVF_CHECK(
          prop.major >= 9,
          "TMA transport requires Compute Capability >= 9.0 (Hopper+). "
          "Current device ",
          device,
          " is Compute Capability ",
          prop.major,
          ".",
          prop.minor);
    
      std::string arch_arg = "--gpu-architecture=compute_" +
          std::to_string(prop.major) + std::to_string(prop.minor);
      std::vector<const char*> opts = {
          arch_arg.c_str(), "--std=c++17"};
    
      nvrtcResult res =
          nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
      if (res != NVRTC_SUCCESS) {
        size_t logSize;
        NVFUSER_NVRTC_SAFE_CALL(
            nvrtcGetProgramLogSize(prog, &logSize));
        std::vector<char> log(logSize);
        NVFUSER_NVRTC_SAFE_CALL(
            nvrtcGetProgramLog(prog, log.data()));
        NVF_ERROR(
            false, "TMA kernel compilation failed:\n", log.data());
      }
    
      size_t ptxSize;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
      std::vector<char> ptx(ptxSize);
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
      NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
    
      NVFUSER_CUDA_SAFE_CALL(
          cuModuleLoadData(&module, ptx.data()));
      NVFUSER_CUDA_SAFE_CALL(
          cuModuleGetFunction(&kernel, module, "tma_copy_1d"));
    }
    Resource Cleanup

    If NVRTC compilation fails (lines 398-407), the nvrtcProgram is never destroyed (NVF_ERROR throws before nvrtcDestroyProgram is reached), and there is likewise no cleanup path if cuModuleLoadData or cuModuleGetFunction fail. Consider adding proper error handling so that nvrtcDestroyProgram runs on all error paths, to prevent resource leaks.

    nvrtcResult res =
        nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
    if (res != NVRTC_SUCCESS) {
      size_t logSize;
      NVFUSER_NVRTC_SAFE_CALL(
          nvrtcGetProgramLogSize(prog, &logSize));
      std::vector<char> log(logSize);
      NVFUSER_NVRTC_SAFE_CALL(
          nvrtcGetProgramLog(prog, log.data()));
      NVF_ERROR(
          false, "TMA kernel compilation failed:\n", log.data());
    }
    
    size_t ptxSize;
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
    std::vector<char> ptx(ptxSize);
    NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
    NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
    
    NVFUSER_CUDA_SAFE_CALL(
        cuModuleLoadData(&module, ptx.data()));
    NVFUSER_CUDA_SAFE_CALL(
        cuModuleGetFunction(&kernel, module, "tma_copy_1d"));
    Performance Validation

    The PR lacks performance comparison data between TMA and copy engine transports. Consider adding benchmarking results or performance metrics to validate that TMA provides expected benefits over the default copy engine, especially for different transfer sizes and patterns.

          if (getP2pTransport() == P2pTransport::Tma) {
            launchTmaCopy(
                ipc_handles.local().ptr(),
                ipc_handles.peer().ptr(),
                count,
                stream);
          } else {
            NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
                ipc_handles.local().ptr(),
                ipc_handles.peer().ptr(),
                count,
                cudaMemcpyDeviceToDevice,
                stream));
          }
          // Signals completion
          WriteValue32ToLocalAndPeer(stream, ipc_handles, IpcSemaphore::kIdle);
          break;
        }
        case P2pProtocol::Put: {
          WriteValue32ToLocalAndPeer(
              stream, ipc_handles, IpcSemaphore::kInProgress);
          break;
        }
        default:
          NVF_ERROR("Invalid P2P protocol: ", protocol);
      }
    }
    
    void recvWait(const P2pIpcHandle& ipc_handles, CUstream stream) {
      P2pProtocol protocol = getP2pProtocol();
      switch (protocol) {
        case P2pProtocol::Put:
          NVFUSER_CUDA_SAFE_CALL(cuStreamWaitValue32(
              stream,
              reinterpret_cast<CUdeviceptr>(ipc_handles.local().semaphore()),
              (cuuint32_t)(IpcSemaphore::kIdle),
              CU_STREAM_WAIT_VALUE_EQ));
          break;
        case P2pProtocol::Get:
          break;
        default:
          NVF_ERROR("Invalid P2P protocol: ", protocol);
      }
    }
    
    void sendPost(const P2pIpcHandle& ipc_handles, int64_t count, CUstream stream) {
      P2pProtocol protocol = getP2pProtocol();
      switch (protocol) {
        case P2pProtocol::Get:
          // signal to self and peer that transfer is in progress
          WriteValue32ToLocalAndPeer(
              stream, ipc_handles, IpcSemaphore::kInProgress);
          break;
        case P2pProtocol::Put: {
          // wait for receiver to be ready
          NVFUSER_CUDA_SAFE_CALL(cuStreamWaitValue32(
              stream,
              reinterpret_cast<CUdeviceptr>(ipc_handles.local().semaphore()),
              (cuuint32_t)(IpcSemaphore::kInProgress),
              CU_STREAM_WAIT_VALUE_EQ));
          // Put the data to the receiver
          if (getP2pTransport() == P2pTransport::Tma) {
            launchTmaCopy(
                ipc_handles.peer().ptr(),
                ipc_handles.local().ptr(),
                count,
                stream);
          } else {
            NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
                ipc_handles.peer().ptr(),
                ipc_handles.local().ptr(),
                count,
                cudaMemcpyDeviceToDevice,
                stream));
          }

greptile-apps bot (Contributor) commented Feb 25, 2026

    Greptile Summary

    This PR adds a Hopper TMA (cp.async.bulk) bulk-copy kernel compiled at runtime via NVRTC, and wires it as an optional P2P data transport (NVFUSER_ENABLE=p2p_transport(tma)) alongside the existing cudaMemcpyAsync copy-engine path. The kernel is refactored from a single-block design to a multi-block design, launching one block per ~48 KB chunk so that all chunks can execute concurrently across SMs. sendPost/recvPost in cuda_p2p.cpp now branch on getP2pTransport() to choose between the two transports.

    Key changes:

    • runtime/tma_copy.cu — kernel gains total_bytes/max_chunk params; per-block offset computed from blockIdx.x
    • csrc/multidevice/cuda_p2p.cpp — new launchTmaCopy() (lazily NVRTC-compiled, Hopper-gated) and getP2pTransport()
    • csrc/multidevice/cuda_p2p.h — P2pTransport enum + declarations
    • csrc/options.{h,cpp} — EnableOption::P2pTransport + "p2p_transport" env-var key
    • tests/cpp/test_multidevice_tma.cpp — duplicated local helper removed; tests now call launchTmaCopy directly

    Issues found:

    • P1launchTmaCopy casts size_t size to int total_bytes (line 479 of cuda_p2p.cpp). For transfers larger than INT_MAX (~2 GB) this is undefined behaviour; the kernel receives a wrong (potentially negative) byte count for the last block, silently corrupting data. The header comment promises "arbitrarily large sizes", so the parameter should be widened to int64_t in both the host launcher and the kernel.
    • P2 — The file-level comment in runtime/tma_copy.cu says integration into cuda_p2p is a "future PR" — this PR is that future PR; the comment is stale and should be updated.

    Confidence Score: 4/5

    Safe to merge for transfers ≤ 2 GB; the int truncation of total_bytes is a real data-corruption risk for larger payloads that should be fixed first.

    One confirmed P1 logic defect: static_cast<int>(size) overflows for sizes > INT_MAX, causing the kernel to copy a wrong number of bytes in the last block without any diagnostic. The function signature (size_t) and the header-level documentation both invite callers who can legitimately exceed 2 GB. Everything else — the multi-block refactor, the TMA PTX, the option wiring, and the test cleanup — looks correct.

    csrc/multidevice/cuda_p2p.cpp (line 479) and the matching int total_bytes parameter in runtime/tma_copy.cu need the int → int64_t fix.

    Important Files Changed

    Filename Overview
    csrc/multidevice/cuda_p2p.cpp Adds launchTmaCopy (NVRTC-compiled, lazily initialized) and getP2pTransport(), and wires them into sendPost/recvPost; contains a P1 int truncation bug for transfers > 2 GB.
    runtime/tma_copy.cu Refactored to multi-block execution (new total_bytes/max_chunk params); kernel logic is correct but the header comment is stale and should be updated to reflect current integration status.
    csrc/multidevice/cuda_p2p.h Adds P2pTransport enum, getP2pTransport(), and launchTmaCopy() declarations; clean and consistent with existing style.
    csrc/options.h Adds P2pTransport to EnableOption enum with documentation comment; no issues.
    csrc/options.cpp Registers "p2p_transport" key in the option map; no issues.
    tests/cpp/test_multidevice_tma.cpp Removes local launchTmaCopy1D helper (now using the shared launchTmaCopy from cuda_p2p.h); tests are unchanged in coverage and correctness.

    Sequence Diagram

    sequenceDiagram
        participant Host
        participant GPU_Sender as GPU (Sender)
        participant GPU_Receiver as GPU (Receiver)
    
        Note over Host: getP2pTransport() == Tma?
    
        alt P2pProtocol::Put + TMA transport
            Host->>GPU_Sender: sendPost → launchTmaCopy(peer.ptr, local.ptr, count)
            Note over GPU_Sender: Block 0…N each handle one chunk
            GPU_Sender->>GPU_Sender: cp.async.bulk GMEM(local)→SMEM
            GPU_Sender->>GPU_Receiver: cp.async.bulk SMEM→GMEM(peer.ptr)
            Host->>GPU_Receiver: recvWait (cuStreamWaitValue32 kIdle)
        else P2pProtocol::Get + TMA transport
            Host->>GPU_Receiver: recvPost → launchTmaCopy(local.ptr, peer.ptr, count)
            Note over GPU_Receiver: Block 0…N each handle one chunk
            GPU_Receiver->>GPU_Sender: cp.async.bulk GMEM(peer.ptr)→SMEM
            GPU_Receiver->>GPU_Receiver: cp.async.bulk SMEM→GMEM(local.ptr)
            Host->>GPU_Sender: sendWait (cuStreamWaitValue32 kIdle)
        else CopyEngine transport (unchanged)
            Host->>GPU_Sender: cudaMemcpyAsync(peer.ptr, local.ptr)
        end
    

    Comments Outside Diff (1)

    1. csrc/multidevice/cuda_p2p.cpp, line 479-482 (link)

      size silently truncated to int for transfers > 2 GB

      total_bytes = static_cast<int>(size) overflows (undefined behaviour in C++) when size > INT_MAX (~2 GB). The num_blocks calculation on line 482 correctly uses size_t, so the right number of blocks is launched, but the kernel receives a wrong (potentially negative) total_bytes. The last block then computes num_bytes = total_bytes - offset with bad arithmetic, producing a garbage copy-length and corrupting data silently.

      The header comment on launchTmaCopy says it "handles arbitrarily large sizes by chunking", which makes larger-than-2 GB callers a reasonable expectation. The kernel parameter should be widened to int64_t (or the host should add an explicit NVF_CHECK(size <= INT_MAX, ...) to fail fast until the kernel is updated).

      Note: the kernel signature (int total_bytes) must be updated to int64_t in runtime/tma_copy.cu as well, and the total_bytes - offset arithmetic there should stay as long long throughout.

    Reviews (12): Last reviewed commit: "remove int64_t"

    @greptile-apps bot left a comment: 5 files reviewed, 1 comment

    Comment on lines +362 to +363
    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    Static initialization lacks thread-safety protection. Multiple threads calling launchTmaCopy concurrently could race on the module == nullptr check (line 365), causing duplicate compilations or accessing partially-initialized state.

    Other kernels in this file (launchAlltoallvKernel, launchMulticastKernel) have the same pattern. Consider adding mutex protection or using std::call_once for thread-safe lazy initialization if concurrent calls are possible.

    @samnordmann (Collaborator, Author) commented:
    !test

    Base automatically changed from tma_p2p to main March 23, 2026 17:02
    @samnordmann (Collaborator, Author) commented:

    !test

    src_bytes += chunk;
    remaining -= chunk;
    }
    }
    A Member commented:
    Would it be better if it launched one kernel that managed these chunks versus multiple? I'm not sure if there's a clear performance gain by launching 1 vs N kernels.

    samnordmann (Collaborator, Author) replied:
    Good point! You are right that having a loop inside the kernel is a much better solution. What is even better is to launch a single kernel with many blocks, so that SMs and TCs can work in parallel. I'm implementing the latter solution.
    A follow-up optimization would be to use SW pipelining and double buffering to overlap the TMA loads and stores.

    @samnordmann (Collaborator, Author) commented:
    !test

    @samnordmann samnordmann requested a review from wujingyue March 24, 2026 10:46
    @wujingyue (Collaborator) left a comment:
    Some nits before I review the kernel

    @samnordmann samnordmann requested a review from wujingyue March 30, 2026 13:40
    @samnordmann (Collaborator, Author) commented:
    !test

    @samnordmann (Collaborator, Author) commented:
    !test

    @samnordmann (Collaborator, Author) commented:
    !test

    @samnordmann (Collaborator, Author) commented:
    !test

    @samnordmann samnordmann merged commit 771eddd into main Apr 1, 2026
    57 of 58 checks passed
    @samnordmann samnordmann deleted the tma_integration branch April 1, 2026 09:56
    3 participants