Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
14fd212
Initial implementation of symmetric memory backend for PyTorch
saivishal1999 Feb 27, 2026
5646c03
Initial changes to add pytorch symmetric memory backend
saivishal1999 Feb 27, 2026
14816aa
Initial pytorch symmetric memory backend changes
saivishal1999 Mar 2, 2026
6996d05
Merge branch 'main' into symmetric-memory-pytorch-backends
nsarka Mar 3, 2026
49d669c
Initial review comments
saivishal1999 Mar 9, 2026
8962475
Alloc, rendezvous passing
saivishal1999 Mar 16, 2026
62c6945
Merge branch 'main' into symmetric-memory-pytorch-backends
saivishal1999 Mar 16, 2026
67181c8
multicast pending
saivishal1999 Mar 17, 2026
eea57d8
all backends passing
saivishal1999 Mar 20, 2026
a9ddffd
delete build file
saivishal1999 Mar 20, 2026
f9cac71
Merge branch 'main' into symmetric-memory-pytorch-backends
saivishal1999 Mar 20, 2026
8e62ccc
Lint errors and review comments
saivishal1999 Mar 24, 2026
1be0134
fix 3 lint errors
saivishal1999 Mar 24, 2026
3596301
Fix clang-tidy errors
saivishal1999 Mar 25, 2026
9b05915
Fixing outdated lint errors
saivishal1999 Mar 25, 2026
6147139
Add torch distributed guard
saivishal1999 Mar 25, 2026
b5a2418
Merge branch 'main' into symmetric-memory-pytorch-backends
saivishal1999 Mar 25, 2026
af128e4
Address pending review cmnts
saivishal1999 Apr 3, 2026
828573d
Add mocks for c10d
saivishal1999 Apr 10, 2026
294a867
Merge branch 'main' into symmetric-memory-pytorch-backends
saivishal1999 Apr 10, 2026
6a5d3c3
Fix missing guard for process_groups
saivishal1999 Apr 10, 2026
2908e70
include mock header for non distributed build
saivishal1999 Apr 10, 2026
aa4f9c1
Remove guard comments
saivishal1999 Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion csrc/multidevice/c10d_mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@

#pragma once

#include <ATen/ATen.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/ivalue.h>
#include <c10/core/Device.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/intrusive_ptr.h>

namespace c10d {
Expand All @@ -34,7 +38,7 @@ class Work : public torch::CustomClassHolder {
};

struct ReduceOp : torch::CustomClassHolder {
enum RedOpType {
enum RedOpType : std::uint8_t {
SUM,
AVG,
PRODUCT,
Expand Down Expand Up @@ -211,4 +215,58 @@ class TCPStore : public torch::CustomClassHolder {
}
};

// Minimal mock of c10d::ProcessGroup: a ref-counted but otherwise empty
// class, so non-distributed builds can declare, store, and pass
// process-group handles (e.g. Communicator::process_groups_) without
// linking against c10d. Never exercised at runtime.
class ProcessGroup : public torch::CustomClassHolder {
 public:
};

// Mock of c10d::resolve_process_group: always hands back a fresh, empty
// ProcessGroup regardless of the requested name. The parameter name is
// commented out to avoid -Wunused-parameter warnings in the mock build.
inline c10::intrusive_ptr<ProcessGroup> resolve_process_group(
    const std::string& /*group_name*/) {
  return c10::make_intrusive<ProcessGroup>();
}

// Mock of c10d::register_process_group: a no-op, since there is no global
// group registry in non-distributed builds. Parameter names are commented
// out to avoid -Wunused-parameter warnings.
inline void register_process_group(
    const std::string& /*group_name*/,
    const c10::intrusive_ptr<ProcessGroup>& /*group*/) {}

// Mock of c10d::unregister_process_group: no-op counterpart to
// register_process_group (name commented out to silence -Wunused-parameter).
inline void unregister_process_group(const std::string& /*group_name*/) {}

} // namespace c10d

namespace c10d::symmetric_memory {

// Mock of c10d::symmetric_memory::SymmetricMemory for builds without torch
// distributed. Only needs to compile; callers are expected to be guarded so
// this never runs.
class SymmetricMemory : public torch::CustomClassHolder {
 public:
  ~SymmetricMemory() override = default;

  // The real backend reports whether multicast is usable; the mock always
  // says no.
  virtual bool has_multicast_support() {
    return false;
  }

  // No multicast support in the mock, so there is no multicast mapping to
  // expose.
  virtual void* get_multicast_ptr() {
    return nullptr;
  }

  // Returns a placeholder tensor of the requested shape/dtype. The peer rank
  // is irrelevant here, so its name is commented out to avoid
  // -Wunused-parameter noise.
  at::Tensor get_remote_tensor(
      int /*peer*/,
      c10::IntArrayRef sizes,
      c10::ScalarType dtype) {
    return at::empty(sizes, at::TensorOptions().dtype(dtype));
  }
};

// Mock of c10d::symmetric_memory::set_backend: selecting a transport backend
// is meaningless without c10d, so this is a no-op.
inline void set_backend(const std::string&) {}

// Mock of c10d's symmetric-memory allocator. Returns an ordinary tensor with
// the requested size/stride/dtype. The device, group name, and allocation id
// only matter for a real peer-to-peer allocation, so they are ignored here
// (names commented out to silence -Wunused-parameter).
inline at::Tensor empty_strided_p2p(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device /*device*/,
    const std::optional<std::string>& /*group_name*/,
    std::optional<uint64_t> /*alloc_id*/) {
  // Use empty_strided so the mock honors the caller-requested layout instead
  // of silently dropping `stride`.
  return at::empty_strided(size, stride, at::TensorOptions().dtype(dtype));
}

// Mock rendezvous: returns a fresh mock SymmetricMemory handle. The real
// implementation exchanges allocation handles across the group; here nothing
// is shared, so both arguments are unused (names commented out to avoid
// -Wunused-parameter warnings).
inline c10::intrusive_ptr<SymmetricMemory> rendezvous(
    const at::Tensor& /*tensor*/,
    const std::optional<std::string>& /*group_name*/ = std::nullopt) {
  return c10::make_intrusive<SymmetricMemory>();
}

} // namespace c10d::symmetric_memory
55 changes: 45 additions & 10 deletions csrc/multidevice/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <numeric>

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/exception.h>
#ifdef USE_C10D_NCCL
Expand Down Expand Up @@ -121,7 +122,8 @@ bool parseEnv(
}

// retrieves master port
if ((env = std::getenv("NVFUSER_MASTER_PORT")) != nullptr) {
env = std::getenv("NVFUSER_MASTER_PORT");
if (env != nullptr) {
master_port = std::atoi(env);
} else {
LOG(INFO) << "The environment variable NVFUSER_MASTER_PORT has not been "
Expand Down Expand Up @@ -248,10 +250,10 @@ void waitForDebuggerAtRanks(
std::cerr << "Process " << pid
<< " is waiting for the debugger. To continue debugging, "
<< "start gdb, `attach " << pid
<< "`, `set var waiting=false`, and `fini`." << std::endl;
<< "`, `set var waiting=false`, and `fini`.\n";
while (waiting) { // Please change `waiting` in the debugger.
}
std::cerr << "Process " << getpid() << " finished waiting." << std::endl;
std::cerr << "Process " << getpid() << " finished waiting.\n";
}

if (communicator->is_available()) {
Expand Down Expand Up @@ -331,6 +333,13 @@ Communicator& Communicator::getInstance() {
return *communicator;
}

// Registers `pg` under `name` with c10d's global process-group registry and
// caches it in process_groups_ so that cleanup() can later unregister every
// group this Communicator created.
void Communicator::registerProcessGroup(
    const std::string& name,
    const c10::intrusive_ptr<c10d::ProcessGroup>& pg) {
  c10d::register_process_group(name, pg);
  process_groups_[name] = pg;
}

void Communicator::cleanup() {
static bool cleaned_up = false;
NVF_CHECK(
Expand All @@ -349,19 +358,25 @@ void Communicator::cleanup() {

store_ = nullptr;

#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
#if defined(NVFUSER_DISTRIBUTED)
#if defined(USE_C10D_NCCL)
// Sort backends to work around a NCCL bug (nvbugs/4889623). Closing backends
// in different orders between ranks have been causing a hang.
std::vector<std::pair<std::string, c10::intrusive_ptr<c10d::Backend>>>
keyed_backends(backends_.begin(), backends_.end());
std::sort(keyed_backends.begin(), keyed_backends.end());
std::ranges::sort(keyed_backends.begin(), keyed_backends.end());
for (auto& [key, backend] : keyed_backends) {
// Call shutdown before destructing a ProcessGroupNCCL as instructed by
// https://github.com/pytorch/pytorch/blob/e62073d7997c9e63896cb5289ffd0874a8cc1838/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1164-L1170.
if (auto* pg_nccl = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get())) {
pg_nccl->shutdown();
}
}
#endif
for (const auto& entry : process_groups_) {
c10d::unregister_process_group(entry.first);
}
process_groups_.clear();
Comment on lines +376 to +379
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 process_groups_ cleanup guard mismatch — compile error when NVFUSER_DISTRIBUTED is set without USE_DISTRIBUTED

process_groups_ is declared in communicator.h under #if defined(NVFUSER_DISTRIBUTED) && defined(USE_DISTRIBUTED), but the cleanup loop here lives under the broader #if defined(NVFUSER_DISTRIBUTED) (without the USE_DISTRIBUTED guard). When a build defines NVFUSER_DISTRIBUTED but not USE_DISTRIBUTED, process_groups_ does not exist as a member, yet this code tries to iterate over it — a hard compile error.

c10d::unregister_process_group (from GroupRegistry.hpp) is already included under #ifdef NVFUSER_DISTRIBUTED, so fixing just the guard on these lines is sufficient:

Suggested change
for (const auto& entry : process_groups_) {
c10d::unregister_process_group(entry.first);
}
process_groups_.clear();
#if defined(USE_DISTRIBUTED)
for (const auto& entry : process_groups_) {
c10d::unregister_process_group(entry.first);
}
process_groups_.clear();
#endif

(The surrounding #if defined(NVFUSER_DISTRIBUTED) / #endif already provides the outer distributed guard.)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is relevant -- but the real fix is to replace USE_DISTRIBUTED by NVFUSER_DISTRIBUTED wherever it appears (we don't use USE_DISTRIBUTED in the codebase)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. You're right—USE_DISTRIBUTED should be replaced with NVFUSER_DISTRIBUTED everywhere. The codebase currently has process_groups_ guarded by both macros, but the cleanup code only checks one, which causes the mismatch.

USE_DISTRIBUTED appears in 4 files:

  1. csrc/multidevice/communicator.h (line 14, line 162) — the header where process_groups_ is declared
  2. csrc/multidevice/communicator.cpp (line 414) — where process_groups_ is populated
  3. csrc/multidevice/symmetric_tensor.cpp (multiple locations) — where symmetric memory uses USE_DISTRIBUTED
  4. csrc/multidevice/symmetric_tensor.h (line 88) — where torch_symm_handle_ is declared

All of these #if defined(NVFUSER_DISTRIBUTED) && defined(USE_DISTRIBUTED) guards should become #if defined(NVFUSER_DISTRIBUTED).

The Python files (python/utils.py, python/tools/gen_nvfuser_version.py) and CMakeLists.txt also reference USE_DISTRIBUTED for build configuration, so those would need updating too depending on your build system changes.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a build "build_image_ARM_igpu_20_TNVF" which was failing - it was setting NVFUSER_DISTRIBUTED to ON but USE_DISTRIBUTED to False during cmake configuration. So I had to explicitly add USE_DISTRIBUTED checks. I'll add the missing guard here.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I didnt know
cc @wujingyue to make sure you are aware of this

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for tagging me.

the real fix is to replace USE_DISTRIBUTED by NVFUSER_DISTRIBUTED wherever it appears (we don't use USE_DISTRIBUTED in the codebase)

That sounds right.

I'm not sure why NVFUSER_DISTRIBUTED was created in the first place -- it was there before I joined the project. We could use USE_DISTRIBUTED throughout the codebase. I guess NVFUSER_DISTRIBUTED gives an extra control of turning nvFuser distributed off even when pytorch distributed is on. But I'm not sure that use case is relevant.

it was setting NVFUSER_DISTRIBUTED to ON but USE_DISTRIBUTED to False

Are you sure about this given

cmake_dependent_option(NVFUSER_DISTRIBUTED "" ON "USE_DISTRIBUTED" OFF)
?

Also note

// nvFuser is sometimes built on a pytorch without c10d. When that
// happens, c10d isn't linked, NVFUSER_DISTRIBUTED is undefined and the
// multi-GPU component of nvFuser is expected to be disabled.
//
// Instead of adding `#ifdef NVFUSER_DISTRIBUTED` in too many places, this file
// provides a buildable mock implementation of c10d to keep nvFuser code less
// divergent. This implementation won't run because tests and user code are
// guarded by Communicator::is_available.
. I think we can avoid the #if by defining a mock.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wujingyue I saw that USE_DISTRIBUTED was false and NVFUSER_DISTRIBUTED was on in logs here https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/287606349/raw.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@saivishal1999 The other guard still looks missing here? Did you attempt the mock approach -- that should ideally work

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm adding mocks and removing these guards in future commits, builds/tests are passing.

#endif
backends_.clear();
}
Expand All @@ -382,16 +397,16 @@ c10d::Backend* Communicator::getBackendForTeam(
// generate a string key which is unique to the team
// create the team and cache it
std::string team_key = prefix + getTeamKey(team, b);
// check that the caller's rank belongs to the requested team
auto rank_it = std::ranges::find(team.begin(), team.end(), deviceId());
if (rank_it == team.end()) {
return nullptr;
}
// check if backend associated with the team is present in the cache
if (backends_.find(team_key) ==
backends_.end()) { // create the backend and cache it
#ifdef NVFUSER_DISTRIBUTED
backends_[team_key] = [&]() -> c10::intrusive_ptr<c10d::Backend> {
// check that the caller's rank belongs to the requested team
auto rank_it = std::find(team.begin(), team.end(), deviceId());
if (rank_it == team.end()) {
return nullptr;
}
// retrieve the caller's rank index/position in the team
RankType team_rank = std::distance(team.begin(), rank_it);
return createBackend(
Expand All @@ -404,6 +419,26 @@ c10d::Backend* Communicator::getBackendForTeam(
backends_[team_key] = nullptr;
#endif
}
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_DISTRIBUTED)
if (process_groups_.find(team_key) == process_groups_.end()) {
if (b == CommunicatorBackend::kNccl) {
RankType team_rank = std::distance(team.begin(), rank_it);

auto pg = c10::make_intrusive<c10d::ProcessGroup>(
c10::make_intrusive<c10d::PrefixStore>(team_key, store_),
team_rank,
static_cast<int>(team.size()));
pg->setBackend(
c10::DeviceType::CUDA,
c10d::ProcessGroup::BackendType::NCCL,
backends_[team_key]);
pg->setDefaultBackend(c10d::ProcessGroup::BackendType::NCCL);
pg->setGroupName(team_key);

registerProcessGroup(team_key, pg);
}
Comment thread
saivishal1999 marked this conversation as resolved.
Comment on lines +422 to +439
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain why we need this change? I am not sure to understand the logic and motivation. It seems like an old artifact -- process_groups_ doesn't seem to be read anywhere. Please clarify

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this to keep track of process groups registered by fuser's symmem so that they can be unregistered during cleanup and also to keep track if the group is already registered or not. in the next commit you'll see that i'll use this variable's keys(to read) and during cleanup

Comment thread
saivishal1999 marked this conversation as resolved.
}
#endif
return backends_.at(team_key).get();
}

Expand Down
11 changes: 10 additions & 1 deletion csrc/multidevice/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
#include <ATen/core/ivalue.h>
#include <c10/util/intrusive_ptr.h>

#ifdef NVFUSER_DISTRIBUTED
#if defined(NVFUSER_DISTRIBUTED)
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this header should always be present, no?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this header when I added the process_groups_ variable, so the same guard is used. It wasn't needed before my changes

#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#else
Expand Down Expand Up @@ -110,6 +111,10 @@ class NVF_API Communicator {
c10d::Backend* getWorld(
std::optional<CommunicatorBackend> backend = std::nullopt);

void registerProcessGroup(
const std::string& name,
const c10::intrusive_ptr<c10d::ProcessGroup>& pg);

// returns if a backend is available for creation
bool isBackendAvailable(CommunicatorBackend backend) const {
if (backend == CommunicatorBackend::kUcc) {
Expand Down Expand Up @@ -153,6 +158,10 @@ class NVF_API Communicator {
c10::intrusive_ptr<c10d::TCPStore> store_;
// cache for the created backends. The keys are strings generated from Teams
std::unordered_map<std::string, c10::intrusive_ptr<c10d::Backend>> backends_;
// c10d process-group wrappers registered for symmetric-memory rendezvous.
// Keeps track of the process groups created for the rendezvous.
std::unordered_map<std::string, c10::intrusive_ptr<c10d::ProcessGroup>>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please make sure c10d_mock.h is up to date to avoid compilation issue in the non-distributed mode

Also, can you explain (and add a comment in the code) why we need ProcessGroup here?

process_groups_;
Comment thread
saivishal1999 marked this conversation as resolved.
};

} // namespace nvfuser
46 changes: 34 additions & 12 deletions csrc/multidevice/ipc_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ int createIpcSocket(const std::string& path) {
int sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
NVF_CHECK(sockfd >= 0, "Failed to create socket: ", strerror(errno));

struct sockaddr_un addr;
struct sockaddr_un addr{};
setupSockAddr(addr, path);

// For abstract namespace, len is usually calculated specifically, but for
Expand Down Expand Up @@ -69,31 +69,34 @@ void sendFd(
int sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
NVF_CHECK(sockfd >= 0, "Failed to create socket: ", strerror(errno));

struct sockaddr_un addr;
struct sockaddr_un addr{};
setupSockAddr(addr, path);
socklen_t addrlen = sizeof(addr.sun_family) + path.length();

// Simple retry loop for connection
int ret = -1;
for (int i = 0; i < 100; ++i) {
ret = connect(sockfd, (struct sockaddr*)&addr, addrlen);
if (ret == 0)
if (ret == 0) {
break;
}
usleep(10000); // 10ms
}
if (ret < 0) {
close(sockfd);
NVF_CHECK(false, "Failed to connect to ", path, ": ", strerror(errno));
}

struct msghdr msg = {0};
struct cmsghdr* cmsg;
struct msghdr msg{};
struct cmsghdr* cmsg = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
char buf[CMSG_SPACE(sizeof(int))];

// If no header data, send at least one byte
char dummy = '.';
struct iovec iov;
struct iovec iov{};
if (header_data && header_len > 0) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
iov.iov_base = const_cast<void*>(header_data);
iov.iov_len = header_len;
} else {
Expand Down Expand Up @@ -121,21 +124,22 @@ void sendFd(
}

int recvFd(int socket_fd, void* header_data, size_t header_len) {
struct sockaddr_un client_addr;
struct sockaddr_un client_addr{};
socklen_t client_len = sizeof(client_addr);
int client_fd =
accept(socket_fd, (struct sockaddr*)&client_addr, &client_len);
NVF_CHECK(client_fd >= 0, "Failed to accept connection: ", strerror(errno));

struct msghdr msg = {0};
struct cmsghdr* cmsg;
struct msghdr msg{};
struct cmsghdr* cmsg = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
char buf[CMSG_SPACE(sizeof(int))];

// If header_len > 0, we expect that much data.
// Note: recvmsg might return fewer bytes if strict requirements aren't met,
// but for local unix sockets with small payloads, it usually delivers all.
char dummy;
struct iovec iov;
char dummy = '.';
struct iovec iov{};
if (header_data && header_len > 0) {
iov.iov_base = header_data;
iov.iov_len = header_len;
Expand Down Expand Up @@ -168,7 +172,7 @@ int recvFd(int socket_fd, void* header_data, size_t header_len) {

int recv_fd = -1;
cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg != NULL && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
if (cmsg != nullptr && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
memcpy(&recv_fd, CMSG_DATA(cmsg), sizeof(int));
}
Expand All @@ -191,4 +195,22 @@ MulticastProtocol getMulticastProtocol() {
return MulticastProtocol::BatchMemcpy;
}

// Maps the SymmetricMemoryBackend enable-option argument to a backend
// choice. Falls back to the native implementation when the option is
// disabled or carries no recognized argument.
SymmetricMemoryBackend getSymmetricMemoryBackend() {
  // Guard clause: without the enable option, always use the native path.
  if (!isOptionEnabled(EnableOption::SymmetricMemoryBackend)) {
    return SymmetricMemoryBackend::Native;
  }
  if (hasEnableOptionArgument(
          EnableOption::SymmetricMemoryBackend, "pytorch_nccl")) {
    return SymmetricMemoryBackend::PyTorchNccl;
  }
  if (hasEnableOptionArgument(
          EnableOption::SymmetricMemoryBackend, "pytorch_nvshmem")) {
    return SymmetricMemoryBackend::PyTorchNvshmem;
  }
  if (hasEnableOptionArgument(
          EnableOption::SymmetricMemoryBackend, "pytorch_cuda")) {
    return SymmetricMemoryBackend::PyTorchCuda;
  }
  // Option enabled but with no recognized argument: keep the default.
  return SymmetricMemoryBackend::Native;
}

} // namespace nvfuser
16 changes: 15 additions & 1 deletion csrc/multidevice/ipc_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,24 @@ const T& fromBytes(const std::vector<uint8_t>& bytes) {

// IPC Utils for sharing file descriptors

enum class MulticastProtocol { Memcpy, Multimem, BatchMemcpy };
enum class MulticastProtocol : uint8_t { Memcpy, Multimem, BatchMemcpy };

MulticastProtocol getMulticastProtocol();

// Backend for symmetric memory allocation and rendezvous.
// Native: Fuser's own CUDA VMM + IPC implementation (default, maintained).
// PyTorch*: Use PyTorch's symmetric memory
// (torch.distributed._symmetric_memory) with the given transport backend (Nccl,
// Nvshmem, or Cuda).
enum class SymmetricMemoryBackend : uint8_t {
  Native, // nvFuser's in-house CUDA VMM + IPC path (default)
  PyTorchNccl, // PyTorch symmetric memory, NCCL transport
  PyTorchNvshmem, // PyTorch symmetric memory, NVSHMEM transport
  PyTorchCuda, // PyTorch symmetric memory, CUDA transport
};

// Returns the backend selected via the SymmetricMemoryBackend enable option;
// Native when the option is off or its argument is unrecognized.
SymmetricMemoryBackend getSymmetricMemoryBackend();

// Creates a listening Unix domain socket bound to path.
// If path starts with '@', it uses the abstract namespace (replaced with \0).
// Returns the socket file descriptor.
Expand Down
Loading
Loading