Cuda-graph capturable Dispatch and combine #6031

Open
samnordmann wants to merge 17 commits into main from dispatch_combine/remove_tcp_sync

Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Mar 9, 2026

Replace TCPStore-based synchronization and CPU barriers in the kCuda backend of Dispatch / Combine with a fully graph-capturable implementation:

  • Binary semaphore protocol and counts exchange via GPU RDMA reads (replaces TCPStore)
  • Over-allocated recv buffers [C=T*R] to avoid data-dependent shapes. MoeCombine IR node carries num_tokens as an attribute to allocate the output (this could be removed when we support pre-allocated output buffers)
  • Cached symmetric memory handles (SymMemForAlltoallv) with static buffer allocation -- buffers are allocated and "rendezvous-ed" once and reused; re-allocation is NVF_CHECK-guarded because captured CUDA graphs hold baked pointers to these buffers
  • at::bincount replaced with scatter_add because bincount has hidden CPU-GPU syncs
  • New method SymmetricTensor::remotePointersTensor to pack all the remote pointers into a GPU buffer for device-initiated comms. Changes the signature of alltoallvWithCudaBackend accordingly.
  • New test: DispatchCombineCudaGraphTest captures dispatch+combine
    into a CUDAGraph and exercises replay

The NCCL backend path is unchanged and not graph-capturable.
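The replay-safety of the single-slot semaphore protocol can be sketched host-side. This is a simplified model under stated assumptions, not the actual implementation, which issues these steps as batched CUDA stream memory operations (cuStreamBatchMemOp WRITE_VALUE/WAIT_VALUE); the names kIdle/kInProgress follow the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side model of the single-slot EQ+reset semaphore protocol.
enum Sem : uint32_t { kIdle = 0, kInProgress = 1 };

struct Rank {
  // One slot per peer, living in this rank's symmetric memory.
  std::vector<uint32_t> slots;
  explicit Rank(int world) : slots(world, kIdle) {}
};

// One barrier round across all ranks, phase by phase.
void runBarrier(std::vector<Rank>& ranks) {
  const int world = static_cast<int>(ranks.size());
  // Phase 1, batchSignal: each rank writes kInProgress into its slot on
  // every peer.
  for (int me = 0; me < world; ++me)
    for (int r = 0; r < world; ++r)
      ranks[r].slots[me] = kInProgress;
  // Phase 2, batchWait(EQ kInProgress): each rank waits on its own slots
  // until every peer has signaled.
  for (int me = 0; me < world; ++me)
    for (int r = 0; r < world; ++r)
      assert(ranks[me].slots[r] == kInProgress);
  // Phase 3, batchReset: each rank restores its own slots to kIdle, so the
  // very same captured op sequence is valid for the next replay.
  for (int me = 0; me < world; ++me)
    for (int r = 0; r < world; ++r)
      ranks[me].slots[r] = kIdle;
}
```

Replaying runBarrier any number of times starts and ends with every slot at kIdle, which is why no epoch counter needs to be baked into the captured graph.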

@samnordmann
Collaborator Author

!test

@greptile-apps
Contributor

greptile-apps bot commented Mar 9, 2026

Greptile Summary

This PR replaces TCPStore-based synchronization and CPU barriers in the kCuda backend of MoeDispatch / MoeCombine with a fully CUDA-graph-capturable implementation. The key changes are:

  • Binary semaphore protocol via SymMemForAlltoallv: a single-slot per-pair EQ+reset design (batchSignal(kInProgress) → batchWait(EQ) → batchReset(kIdle)) eliminates epoch counters and is correct for unlimited graph replays.
  • Over-allocated recv buffers [C = T×R] to avoid data-dependent shapes, with MoeCombine carrying num_tokens as a new IR attribute for output sizing.
  • Cached symmetric memory handles (SymMemForAlltoallv) with a static allocation model — buffers are allocated and rendezvous-ed once and reused; reallocation after first use is explicitly NVF_CHECK-guarded.
  • bincount → scatter_add to remove a hidden CPU-GPU sync.
  • SymmetricTensor::remotePointersTensor() to pack remote pointers into a GPU buffer, enabling graph capture of the alltoallvWithCudaBackend call.
  • New DispatchCombineCudaGraphTest that captures dispatch+combine and replays 5 times.

The NCCL path is unchanged. Several previously-discussed concerns (semaphore FLUSH flags for NVLink visibility, global communicator singleton in SymMemForAlltoallv, max_send_bytes over-provisioning, missing recv_ptrs_gpu validation) remain open and are worth revisiting before merging.

Confidence Score: 3/5

Multiple open P1 concerns from prior rounds remain unaddressed; not safe to merge without resolving them.

The single-slot EQ+reset semaphore design is a correct fix for the epoch-baking problem from the prior round. However, (1) CU_STREAM_WRITE_VALUE_DEFAULT lacks FLUSH semantics for NVLink memory-model correctness, (2) SymMemForAlltoallv uses the global Communicator singleton ignoring the caller's communicator, (3) max_send_bytes over-provisions the alltoallv kernel grid, and (4) recv_ptrs_gpu has no size/device validation — all flagged as P1 in prior threads and still present.

csrc/multidevice/ipc_handle.cpp (flush flags, singleton communicator), csrc/multidevice/dispatch_combine.cpp (max_send_bytes, device cache key), csrc/multidevice/cuda_p2p.cpp (recv_ptrs_gpu validation)

Important Files Changed

  • csrc/multidevice/ipc_handle.cpp: New SymMemForAlltoallv class: alltoallv context with counts-exchange semaphores, completion barrier, and named recv-buffer cache. Multiple correctness concerns documented in previous threads (singleton communicator, 32-bit ops on 64-bit allocation, cache not validating dtype/extra_sizes).
  • csrc/multidevice/dispatch_combine.cpp: Core rewrite to a graph-capturable CUDA backend: replaces TCPStore sync + CPU bincount with GPU semaphores + scatter_add, caches symmetric memory via SymMemForAlltoallv, over-allocates recv buffers [C=T*R]. Previously flagged: missing max_send_bytes per-rank max, device key not part of cache key.
  • csrc/multidevice/ipc_handle.h: New SymMemForAlltoallv class declaration: combines sync-buffer, semaphores, and named recv handles. Single-slot EQ+reset semaphore protocol is correct for unlimited replays.
  • csrc/multidevice/cuda_p2p.cpp: alltoallvWithCudaBackend now accepts at::Tensor recv_ptrs_gpu instead of std::vector; divisibility NVF_CHECKs and elem_stride computation are preserved.
  • tests/cpp/test_multidevice_dispatch_combine.cpp: New DispatchCombineCudaGraphTest captures dispatch+combine into a CUDA graph and replays 5 times; warms up symmetric memory caches before capture. The comment describing the semaphore protocol is inaccurate relative to the implementation.

Sequence Diagram

sequenceDiagram
    participant RankA as Rank A
    participant RankB as Rank B

    Note over RankA,RankB: Phase 1 – counts exchange (prepareAlltoallvMetadataGpu)
    RankA->>RankA: cudaMemcpyAsync(sync_buf ← send_counts)
    RankB->>RankB: cudaMemcpyAsync(sync_buf ← send_counts)
    RankA->>RankB: batchSignal(countsSem, kInProgress)
    RankB->>RankA: batchSignal(countsSem, kInProgress)
    RankA->>RankA: batchWait(countsSem, EQ kInProgress)
    RankB->>RankB: batchWait(countsSem, EQ kInProgress)
    RankA->>RankB: cudaMemcpyAsync P2P read syncRemotePtr(B)
    RankB->>RankA: cudaMemcpyAsync P2P read syncRemotePtr(A)
    RankA->>RankA: batchReset(countsSem, kIdle)
    RankB->>RankB: batchReset(countsSem, kIdle)

    Note over RankA,RankB: Phase 2 – alltoallv payload transfer
    RankA->>RankB: alltoallvWithCudaBackend (x, topk_idx, topk_weights, src_idx)
    RankB->>RankA: alltoallvWithCudaBackend (x, topk_idx, topk_weights, src_idx)

    Note over RankA,RankB: Phase 3 – completion barrier (doneBarrier)
    RankA->>RankB: batchSignal(doneSem, kInProgress)
    RankB->>RankA: batchSignal(doneSem, kInProgress)
    RankA->>RankA: batchWait(doneSem, EQ kInProgress)
    RankB->>RankB: batchWait(doneSem, EQ kInProgress)
    RankA->>RankA: batchReset(doneSem, kIdle)
    RankB->>RankB: batchReset(doneSem, kIdle)

Reviews (10): Last reviewed commit: "Merge branch 'main' of github.com:NVIDIA..."

@samnordmann
Collaborator Author

!test

Comment on lines 978 to 979
const int64_t elem_stride =
metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
Contributor


Divisibility guard removed — silent wrong elem_stride if mismatched

The PR removes the checks:

NVF_CHECK(
    metadata.max_send_total == 0 ||
        send.numel() % metadata.max_send_total == 0, ...);
NVF_CHECK(
    metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, ...);

elem_stride is computed as send.numel() / metadata.max_send_total. If send.numel() is not divisible by max_send_total (e.g. because a caller passes mismatched metadata), integer truncation silently gives a wrong stride. Every send_offsets, send_counts, and recv_offsets is then scaled by this wrong value before being passed to the kernel, producing corrupted data without any error. The checks were cheap and provided essential diagnostic value; removing them for graph-capturability does not improve performance because they are CPU-side and never execute inside a captured region.
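A minimal sketch of the guarded stride computation this comment argues for, with the tensor/metadata fields replaced by scalar parameters (numel and max_send_total stand in for send.numel() and metadata.max_send_total):

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Stand-alone version of the elem_stride computation with the
// divisibility guard restored.
int64_t elemStride(int64_t numel, int64_t max_send_total) {
  if (max_send_total == 0) {
    return 1;
  }
  // Cheap CPU-side guard: it runs at launch time, never inside a captured
  // graph region, so it costs nothing for graph-capturability.
  if (numel % max_send_total != 0) {
    throw std::runtime_error(
        "send.numel() must be divisible by max_send_total");
  }
  return numel / max_send_total;
}
```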

Comment on lines 415 to 430
const at::Tensor& src_idx,
const at::Tensor& n_tokens_to_rank,
const at::Tensor& n_tokens_from_rank,
int64_t num_tokens,
Communicator* communicator,
CommunicatorBackend backend) {
NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator.");
NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA.");
const bool has_topk_weights = topk_weights.numel() > 0;
if (has_topk_weights) {
NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA.");
NVF_CHECK(
topk_weights.is_floating_point(),
"Combine topk_weights must be floating point.");
NVF_CHECK(
topk_weights.dim() == 2 && topk_weights.size(0) == x.size(0) &&
topk_weights.size(1) == 1,
"topk_weights must be shape [T, 1], got: ",
topk_weights.sizes());
}
NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA.");
NVF_CHECK(
n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA.");
NVF_CHECK(
n_tokens_from_rank.is_cuda(), "Combine n_tokens_from_rank must be CUDA.");
NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden].");
NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D.");
NVF_CHECK(
n_tokens_to_rank.is_cuda() && n_tokens_from_rank.is_cuda(),
"Combine count tensors must be on CUDA.");
NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D.");
NVF_CHECK_EQ(
src_idx.size(0), x.size(0), "src_idx size must match x first dimension.");
NVF_CHECK_EQ(
n_tokens_to_rank.numel(),
communicator->size(),
"n_tokens_to_rank must match world size.");
NVF_CHECK_EQ(
n_tokens_from_rank.numel(),
communicator->size(),
"n_tokens_from_rank must match world size.");

// Reconstruct source ranks from per-rank counts. alltoall_base concatenates
// received chunks in rank order, so this matches the receive layout.
auto src_rank = at::arange(
n_tokens_from_rank.numel(),
at::TensorOptions().dtype(at::kLong).device(x.device()))
.repeat_interleave(n_tokens_from_rank.to(at::kLong));
NVF_CHECK_EQ(
src_rank.size(0),
x.size(0),
"Reconstructed src_rank must match x first dimension.");
// Sort by source rank so alltoall can send contiguous chunks per rank.
auto sorted_indices = at::argsort(src_rank);
auto send_x = x.index_select(0, sorted_indices);
auto send_src_idx = src_idx.index_select(0, sorted_indices);

Contributor


Several input validations removed from doMoeCombine

The following checks that existed in the old code were removed and are not replaced:

  • NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D.") — a 2D src_idx would still pass the size(0) == x.size(0) check and cause a silent runtime error inside index_copy_.
  • NVF_CHECK_EQ(n_tokens_to_rank.numel(), communicator->size(), ...) and the equivalent for n_tokens_from_rank — without these, toSplitSizes (NCCL path) silently operates on a wrong-sized tensor, and prepareAlltoallvMetadataGpu (CUDA path) reads/writes W entries from a tensor that may have fewer.

These checks are CPU-side and do not execute during graph capture, so removing them provides no graph-capturability benefit.

Comment on lines 490 to 496

// Scatter by original token index to restore local order.
auto combined_x = at::empty({total_recv, hidden}, x.options());
combined_x.index_copy_(0, recv_src_idx, recv_x);
auto combined_x = at::zeros({num_tokens, hidden}, x.options());
combined_x.index_copy_(
0,
rs.buffer.narrow(0, 0, num_tokens),
rx.buffer.narrow(0, 0, num_tokens));

Contributor


at::zeros initialization of combined_x silently masks index errors

combined_x is initialized to all-zeros, then index_copy_ is expected to fill all num_tokens rows. If rs.buffer.narrow(0, 0, num_tokens) ever contains duplicate indices or misses some positions (due to a bug in the alltoallv or send_counts mismatch), the affected rows will silently be zero rather than triggering any error. Using at::empty plus an assertion that each index in [0, num_tokens) appears exactly once would make the contract explicit.

If zeros are intentional as a safety net, a brief comment explaining the guarantee (that all num_tokens positions are covered exactly once) would clarify the intent.
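The exactly-once contract can be expressed as a small check. A sketch in plain C++, assuming the index list has been staged to the CPU; since that staging is a CPU-GPU sync, this would run outside graph capture, e.g. as a one-time warmup assertion:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true iff `idx` is a permutation of [0, num_tokens), i.e. every
// destination row of combined_x would be written exactly once by index_copy_.
bool coversExactlyOnce(const std::vector<int64_t>& idx, int64_t num_tokens) {
  if (static_cast<int64_t>(idx.size()) != num_tokens) {
    return false;
  }
  std::vector<char> seen(num_tokens, 0);
  for (int64_t i : idx) {
    if (i < 0 || i >= num_tokens || seen[i]) {
      return false;  // out-of-range or duplicate index
    }
    seen[i] = 1;
  }
  return true;
}
```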

Comment on lines 356 to 370
ctx,
n_tokens_to_rank,
/*max_send_total=*/num_tokens,
/*max_send_bytes=*/num_tokens,
/*max_recv=*/capacity,
stream);
auto n_tokens_from_rank = metadata.recv_counts;
const int64_t total_recv = metadata.total_recv;
const int64_t max_recv = metadata.max_recv;

// Allocate symmetric buffers for send/recv payloads.
auto send_x_sym = SymmetricTensor::allocate(
{metadata.max_send_total, hidden}, x.scalar_type(), x.device());
send_x_sym.narrow(0, 0, num_tokens).copy_(send_x);
auto send_topk_idx_sym = SymmetricTensor::allocate(
{metadata.max_send_total, topk_idx.size(1)},

auto& rx = ctx.recv("x", capacity, {hidden}, x.scalar_type(), x.device());
auto& ri = ctx.recv(
"topk_idx",
capacity,
{topk_idx.size(1)},
topk_idx.scalar_type(),
x.device());
Contributor


DispatchResult from CUDA backend returns full capacity buffers — callers must not narrow by recv_counts

recv_x, recv_topk_idx, recv_topk_weights, and recv_src_idx all have shape [C, ...] = [T*R, ...] with only the first V = sum(n_tokens_from_rank) rows valid. Any downstream operation that passes these tensors to kernel code treating the first dimension as the actual token count (e.g., a linear layer, softmax, or gradient accumulation) will silently process T*R - V garbage rows.

The header now documents this contract (DispatchResult comments), but the HostIrEvaluator that schedules the fusion kernel between dispatch and combine receives these over-allocated tensors with no enforcement that the kernel only operates on valid rows. Consider adding a runtime check in the evaluator path or document that any kernel scheduled between dispatch and combine must accept and correctly ignore padding rows.
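The valid-rows contract can be stated concretely. A sketch with plain integers (num_tokens is T, and n_tokens_from_rank the per-peer receive counts):

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Number of valid rows V in an over-allocated recv buffer of capacity
// C = T * R; rows [V, C) are padding that downstream kernels must ignore.
int64_t validRows(
    const std::vector<int64_t>& n_tokens_from_rank, int64_t num_tokens) {
  const int64_t world_size = static_cast<int64_t>(n_tokens_from_rank.size());
  const int64_t capacity = num_tokens * world_size;  // C = T * R
  const int64_t valid = std::accumulate(
      n_tokens_from_rank.begin(), n_tokens_from_rank.end(), int64_t{0});
  // Each peer can send at most T tokens, so V <= C by construction.
  assert(valid <= capacity);
  return valid;
}
```

A consumer that cannot tolerate padding would narrow the first dimension to validRows(...) before use.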

@samnordmann
Collaborator Author

!test

Comment on lines +365 to +368
auto& entry = recv_entries_[name];
if (entry.sym && entry.cached_first_dim >= first_dim) {
return entry.handle;
}
Contributor


Cache hit doesn't validate dtype or extra_sizes

The cache returns the existing RecvHandle whenever entry.cached_first_dim >= first_dim, but does not check that dtype or extra_sizes match the originally-allocated buffer. If the same logical name (e.g., "x") were ever called with a different scalar type or a different extra dimension (e.g., hidden size changes), the kernel would silently receive a buffer with the wrong element type or stride, leading to data corruption.

While the current callers always pass consistent types for a given name, this implicit contract is not enforced. A defensive check would make it explicit:

if (entry.sym && entry.cached_first_dim >= first_dim) {
  NVF_CHECK(
      entry.cached_dtype == dtype,
      "SymMemForAlltoallv::recv: buffer '", name,
      "' dtype mismatch (cached ", entry.cached_dtype, " vs requested ", dtype, ").");
  // similarly for extra_sizes
  return entry.handle;
}

At minimum, storing and asserting on dtype on every cache hit would catch mismatches early.

Comment on lines +338 to +356
SymMemForAlltoallv::SymMemForAlltoallv(
at::Device device,
const std::string& tag)
: tag_(tag) {
Communicator& comm = Communicator::getInstance();
world_size_ = comm.size();
my_rank_ = comm.deviceId();

sync_buf_ = SymmetricTensor::allocate({world_size_ + 2}, at::kLong, device);
sync_buf_.zero_();

sync_sym_ = std::make_unique<SymmetricTensor>(sync_buf_);
sync_sym_->setupRemoteHandles(tag + "_sync");

sync_ptrs_.resize(world_size_);
for (int64_t r = 0; r < world_size_; r++) {
sync_ptrs_[r] =
reinterpret_cast<CUdeviceptr>(sync_sym_->remoteTensor(r).data_ptr());
}
Contributor


SymMemForAlltoallv uses global communicator singleton, ignoring the caller's communicator

The constructor captures world size and rank from Communicator::getInstance(), but both doMoeDispatch and doMoeCombine accept an arbitrary Communicator* communicator parameter. If a caller passes a non-singleton communicator (e.g., a sub-communicator for a subset of ranks), the SymMemForAlltoallv context will be initialised with the wrong world size and rank. This leads to incorrect sync-buffer layout (wrong number of semaphore/count slots), wrong send_count reads in prepareAlltoallvMetadataGpu (which iterates ctx.worldSize()), and potential out-of-bounds access into sync_ptrs_.

The getOrCreateAlltoallv helper does not accept a communicator, so there is currently no way for the caller's communicator to propagate down to the context. The communicator (or at least its size() and deviceId()) should be threaded through to the constructor.

Comment on lines +291 to +294
auto n_tokens_to_rank =
at::zeros({world_size}, gpu_long_opts)
.scatter_add(
0, rank_for_token_long, at::ones({num_tokens}, gpu_long_opts));
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scatter_add silently accepts out-of-range indices — regression from bincount

rank_for_token_long = floor_divide(topk_idx_long, experts_per_rank). If any token's expert ID is invalid (e.g., >= num_experts), the resulting rank index will be >= world_size. PyTorch's scatter_add with an output of size world_size will silently write to a position outside the allocated tensor, corrupting adjacent memory and producing wrong n_tokens_to_rank counts without any diagnostic.

The old at::bincount(rank_for_token_cpu, {}, world_size) would have produced a tensor longer than world_size if values exceeded world_size - 1, which at least makes the anomaly detectable downstream (e.g., the earlier NVF_CHECK_EQ for tensor sizes would have fired). Adding an explicit bounds check before scatter_add preserves the error-detection property without a CPU-GPU sync:

NVF_CHECK(
    rank_for_token_long.min().item<int64_t>() >= 0 &&
    rank_for_token_long.max().item<int64_t>() < world_size,
    "topk_idx maps to an out-of-range rank.");

Note: this check would be a CPU-GPU sync and would need to be performed outside the captured graph region (e.g., as a one-time warmup assertion).
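A host-side model of the bounds-checked counting (a sketch only; the GPU path does this with scatter_add on device tensors, and rank_for_token here stands in for the floor-divided topk_idx):

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Counts tokens per destination rank, rejecting out-of-range ranks instead
// of silently writing past a world_size-sized output the way an unchecked
// scatter_add would.
std::vector<int64_t> countTokensPerRank(
    const std::vector<int64_t>& rank_for_token, int64_t world_size) {
  std::vector<int64_t> counts(world_size, 0);
  for (int64_t r : rank_for_token) {
    if (r < 0 || r >= world_size) {
      throw std::runtime_error("topk_idx maps to an out-of-range rank");
    }
    counts[r] += 1;
  }
  return counts;
}
```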

@samnordmann
Collaborator Author

!test

Comment on lines +346 to +356
sync_buf_ = SymmetricTensor::allocate({3 * world_size_}, at::kLong, device);
sync_buf_.zero_();

sync_sym_ = std::make_unique<SymmetricTensor>(sync_buf_);
sync_sym_->setupRemoteHandles(tag + "_sync");

sync_ptrs_.resize(world_size_);
for (int64_t r = 0; r < world_size_; r++) {
sync_ptrs_[r] =
reinterpret_cast<CUdeviceptr>(sync_sym_->remoteTensor(r).data_ptr());
}
Contributor


sync_buf_ allocated as int64 but semaphore slots use 32-bit stream ops

sync_buf_ is allocated with at::kLong (8 bytes/slot). The semaphore addresses are computed as sync_ptrs_[rank] + N * sizeof(int64_t), placing each semaphore 8 bytes apart. However, CU_STREAM_MEM_OP_WRITE_VALUE_32 and CU_STREAM_MEM_OP_WAIT_VALUE_32 operate on 4-byte quantities.

On little-endian NVIDIA hardware this happens to work — the 32-bit write goes to the lower 4 bytes of the 8-byte slot and the upper 4 bytes remain zero — but it is a type mismatch that relies on:

  1. The GPU being little-endian.
  2. The upper 4 bytes of each semaphore slot never being touched by any 64-bit op.

A more explicit design would allocate the semaphore region as at::kInt (or in a separate tensor), so that WRITE/WAIT_VALUE_32 addresses map directly to element boundaries. Alternatively, computing semaphore addresses with sizeof(int32_t) strides within a dedicated 32-bit allocation would make the intent and layout unambiguous.
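The little-endian dependency can be demonstrated in isolation. In this sketch, a 4-byte memcpy at the slot's base address stands in for CU_STREAM_MEM_OP_WRITE_VALUE_32 landing on an 8-byte at::kLong slot (little-endian host assumed):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Emulates a 32-bit stream write landing at the base of an 8-byte
// semaphore slot: only the low 4 bytes change.
uint64_t write32IntoInt64Slot(uint64_t slot, uint32_t value) {
  std::memcpy(&slot, &value, sizeof(uint32_t));
  return slot;
}
```

On little-endian hardware, write32IntoInt64Slot(0, 1) equals 1 only because the upper 4 bytes stay zero; any 64-bit op that dirties the upper half breaks the equality, which is exactly the fragility the comment describes.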

Comment on lines +489 to +502
void SymMemForAlltoallv::doneBarrier(CUstream stream) {
batchSignal(
stream,
static_cast<cuuint32_t>(IpcSemaphore::kInProgress),
&SymMemForAlltoallv::doneSemAddr);
batchWait(
stream,
static_cast<cuuint32_t>(IpcSemaphore::kInProgress),
&SymMemForAlltoallv::doneSemAddr);
batchReset(
stream,
static_cast<cuuint32_t>(IpcSemaphore::kIdle),
&SymMemForAlltoallv::doneSemAddr);
}
Contributor


doneBarrier resets own slots but write ordering across ranks is unspecified

The barrier sequence is:

  1. batchSignal(kInProgress) — write to peers' done_sem slots
  2. batchWait(kInProgress) — wait for peers to write to MY done_sem slots
  3. batchReset(kIdle) — reset MY done_sem slots

Step 3 resets doneSemAddr(my_rank, r) (my own memory). However, the batchReset uses CU_STREAM_WRITE_VALUE_DEFAULT. Between step 3 on rank A and step 1 of the NEXT replay on rank B, there is no formal guarantee that the reset write at A is visible before rank B's next signal arrives. Because CU_STREAM_WAIT_VALUE_EQ is used (not GEQ), correctness relies on the reset reaching the memory before the next iteration's batchSignal on the peer, which the doneBarrier's completion semantics guarantee (peers can only start the next replay after their own barrier completes). This is correct by the protocol ordering, but the dependency is non-obvious and would benefit from a short comment explaining why reset on one rank cannot race with the next signal from the peer.

Comment on lines 976 to 979
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
(int64_t)recv_ptrs.size() == metadata.world_size,
"recv_ptrs size must match world size.");

auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
ptrs[rank] =
static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
}
auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());

const int64_t elem_stride =
metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
NVF_CHECK(
metadata.max_send_total == 0 ||
send.numel() % metadata.max_send_total == 0,
Contributor


P1 Missing validation for recv_ptrs_gpu — no size or device check

The old call-site accepted const std::vector<void*>& recv_ptrs and explicitly verified:

NVF_CHECK(
    (int64_t)recv_ptrs.size() == metadata.world_size,
    "recv_ptrs size must match world size.");

It also coerced the pointer table to the send device via .to(send.device()).

The new at::Tensor recv_ptrs_gpu has neither check: if it has fewer than world_size entries the kernel silently reads garbage pointers; if it lives on the wrong device the launch will fault. remotePointersTensor() always produces a [world_size] tensor on the right device by construction, but the API contract is now implicit and fragile for any future caller. Consider adding:

Suggested change:

// Before:
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
    (int64_t)recv_ptrs.size() == metadata.world_size,
    "recv_ptrs size must match world size.");
auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
  ptrs[rank] =
      static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
}
auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());
const int64_t elem_stride =
    metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
NVF_CHECK(
    metadata.max_send_total == 0 ||
        send.numel() % metadata.max_send_total == 0,

// After:
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
    recv_ptrs_gpu.is_cuda() && recv_ptrs_gpu.device() == send.device(),
    "recv_ptrs_gpu must be a CUDA tensor on the same device as send.");
NVF_CHECK(
    recv_ptrs_gpu.dim() == 1 &&
        recv_ptrs_gpu.size(0) == metadata.world_size,
    "recv_ptrs_gpu must have shape [world_size].");

Comment on lines +91 to +112

auto counts_matrix = at::empty({W, W}, gpu_opts);
for (int64_t r = 0; r < W; r++) {
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
counts_matrix[r].data_ptr<int64_t>(),
reinterpret_cast<void*>(a2av.syncRemotePtr(r)),
W * sizeof(int64_t),
cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream)));
}

a2av.resetCountsSem(stream);

auto recv_counts = counts_matrix.select(1, my_rank).contiguous();

auto send_offsets = at::zeros({W}, gpu_opts);
if (W > 1) {
send_offsets.narrow(0, 1, W - 1)
.copy_(send_counts.cumsum(0).narrow(0, 0, W - 1));
}

at::Tensor recv_offsets = my_rank > 0
Contributor


P1 Counts read from peers before waitCountsReady ensures NVLink visibility

The order in prepareAlltoallvMetadataGpu is correct (waitCountsReady precedes the cudaMemcpyAsync loop), so the architectural concern here is memory-model ordering rather than statement ordering.

signalCountsReady uses CU_STREAM_WRITE_VALUE_DEFAULT, which does not provide release/flush semantics for preceding device-memory writes visible over NVLink. Per CUDA documentation, CU_STREAM_WRITE_VALUE_FLUSH is required to guarantee that the cudaMemcpyAsync that copied send_counts into sync_buf_ (step 1) is visible to a remote peer that later observes the semaphore value via CU_STREAM_WAIT_VALUE_EQ.

Without the flush flag, the counts-ready signal from peer B could be seen by rank A before B's sync_buf_ write is observable through NVLink. In practice NVLink ordering has made this work, but it violates the documented memory model. The same issue applies to doneBarrier's batchSignal call (ipc_handle.cpp, the signalCountsReady and batchSignal helpers).

// Change in batchSignal:
ops[idx].writeValue.flags = CU_STREAM_WRITE_VALUE_FLUSH;  // release semantics

@x41lakazam
Collaborator

Generally, models have a capacity factor built into their architecture that caps how many tokens any expert can receive. Maybe it would be helpful to pass it as a parameter of doMoeDispatch so that max_recv_tokens becomes ceil(T * capacity_factor).
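A sketch of the suggested sizing; capacity_factor is a hypothetical parameter, not part of the current doMoeDispatch signature:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// max_recv_tokens = ceil(T * capacity_factor), capping recv-buffer capacity
// by the model's architectural limit instead of the worst case T * R.
int64_t maxRecvTokens(int64_t num_tokens, double capacity_factor) {
  return static_cast<int64_t>(
      std::ceil(static_cast<double>(num_tokens) * capacity_factor));
}
```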

@samnordmann
Collaborator Author

Generally, models have a capacity factor built into their architecture that caps how many tokens any expert can receive. Maybe it would be helpful to pass it as a parameter of doMoeDispatch so that max_recv_tokens becomes ceil(T * capacity_factor).

Definitely agree! Let me mark that as a TODO for the next PR.

@samnordmann
Collaborator Author

!test

@samnordmann
Collaborator Author

The last commit addresses the many linter issues that appeared because of the CI update

@samnordmann
Collaborator Author

!test

Collaborator

@wujingyue wujingyue left a comment


Thanks -- it's important to cuda-graph capture.

// Scatter by original token index to restore local order.
auto combined_x = at::empty({total_recv, hidden}, x.options());
combined_x.index_copy_(0, recv_src_idx, recv_x);
auto combined_x = at::zeros({num_tokens, hidden}, x.options());
Collaborator


IIRC, Communications take a pre-allocated output:

auto* allocate =
    IrBuilder::create<hir::Allocate>(out, out->getMemoryType());

So the output buffer size (or an argument that's used to compute the output buffer size) should be somehow passed to Allocate instead of MoeCombine or MoeDispatch. What am I missing?

Collaborator Author

@samnordmann samnordmann Mar 24, 2026


The allocation here takes place at runtime inside dispatch/combine (i.e. these ops don't support pre-allocated buffers), to account for the purely dynamic flavor of MoE (which is not CUDA-graph capturable), where the allocation size is determined by routing.

Here, in the CUDA-graph-capturable version, you are right that the allocation is static and therefore could be hoisted out by making the ops support pre-allocated buffers. We should probably add that (now or in a future PR), but this will probably require a different interface for each of the two flavors.

Wdyt ?

Collaborator


The allocation here takes place at runtime inside the dispatch/combine

That's where I'm confused. IIUC, even for dispatch, the output buffer should be pre-allocated to the max (i.e. min(n_experts_per_gpu, k)*s) for SOL. The dispatch op itself or the underlying kernel shouldn't allocate.

(I ask this because I would like to reuse MoeDispatch and MoeCombine for FusionExecutorCache and would like to understand where the code is going. When the end state is clear, I'm often flexible with intermediate states not being perfect.)

Collaborator Author


That's where I'm confused. IIUC, even for dispatch, the allocation of the output buffer should be pre-allocated to the max (i.e. min(n_experts_per_gpu,k)*s) for SOL.

There exist many different implementations and flavors. I agree that a static allocation is much better for speed, but the memory footprint will be worse. That's a trade-off. Since speed + CUDA-graph compatibility are my priorities, I'm OK to make this assumption, i.e., make the allocation purely static and have the ops accept pre-allocated buffers.

I'm actually in favor of that solution. Or we can keep both with two API variants. Lmk what you think.

Collaborator


however the memory footprint will be worse

Yes, that's the struggle. IIUC, when it comes to CUDA graph (the motivation of this PR), SOTA is to use static-size buffers with some other techniques to alleviate memory pressure, e.g. https://arxiv.org/html/2603.07685v2 section 4.3.7.

make the allocation purely static and make the ops accept pre-allocated buffers. I'm actually in favor of that solution.

Yes, I'm in favor of that as well. Do you plan to do so in this PR or separately?

Collaborator Author


Yes, that's the struggle. IIUC, when it comes to CUDA graph (the motivation of this PR), SOTA is to use static-size buffers with some other techniques to alleviate memory pressure, e.g. https://arxiv.org/html/2603.07685v2 section 4.3.7.

That's interesting. EP load balancing (along with a capacity factor for dispatch/combine) is another technique.

Yes, I'm in favor of that as well. Do you plan to do so in this PR or separately?

I'd rather do it in the next PR

Collaborator


cc @Priya2698 I stopped reviewing the details because I expect the implementation to change a lot with pre-allocation. I personally don't mind merging this to enable CUDA graph and iterating on the implementation later. In the worst case, we can reuse the kernels but package them differently.

@wujingyue wujingyue requested a review from Priya2698 March 23, 2026 22:03
@samnordmann
Collaborator Author

!test
