Skip to content
Merged
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ list(APPEND
"${CMAKE_CURRENT_SOURCE_DIR}/cmake"
)
find_package(CUDAToolkit REQUIRED)
find_package(MPI 3.0 REQUIRED COMPONENTS CXX)
find_package(Torch 2.6 REQUIRED CONFIG)

# Also, torch_python!
Expand All @@ -79,6 +78,7 @@ find_library(TORCH_PYTHON_LIBRARY
find_library(TORCH_PYTHON_LIBRARY torch_python REQUIRED)

if (DGRAPH_ENABLE_NVSHMEM)
find_package(MPI 3.0 REQUIRED COMPONENTS CXX)
find_package(NVSHMEM 2.5 REQUIRED MODULE)
endif ()

Expand Down
24 changes: 24 additions & 0 deletions DGraph/Communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from DGraph.distributed.nccl import NCCLBackendEngine

from DGraph.CommunicatorBase import CommunicatorBase
from typing import Tuple, Optional

SUPPORTED_BACKENDS = ["nccl", "mpi", "nvshmem"]

Expand Down Expand Up @@ -95,6 +96,13 @@ def get_local_tensor(

return masked_tensor

def alloc_buffer(
    self, size: Tuple[int, ...], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Allocate a buffer suitable for this backend's communication model.

    Args:
        size: Shape of the buffer to allocate.
        dtype: Element dtype of the buffer.
        device: Device the buffer should live on.

    Returns:
        An uninitialized tensor from the backend's allocator.  The default
        backend uses ``torch.empty``; the NVSHMEM backend overrides
        ``allocate_buffer`` with symmetric allocation.
    """
    # Fail fast before touching the backend, matching scatter/gather/barrier.
    self.__check_init()
    return self.__backend_engine.allocate_buffer(size, dtype, device)

def scatter(self, *args, **kwargs) -> torch.Tensor:
    """Forward a scatter operation to the active backend engine.

    Verifies the communicator is initialized, then passes all arguments
    through unchanged to the backend's ``scatter``.
    """
    self.__check_init()
    return self.__backend_engine.scatter(*args, **kwargs)
Expand All @@ -103,6 +111,22 @@ def gather(self, *args, **kwargs) -> torch.Tensor:
self.__check_init()
return self.__backend_engine.gather(*args, **kwargs)

def put(
    self,
    send_buffer: torch.Tensor,
    recv_buffer: torch.Tensor,
    send_offsets: torch.Tensor,
    recv_offsets: torch.Tensor,
    remote_offsets: Optional[torch.Tensor] = None,
) -> None:
    """Synchronously exchange buffer chunks between all ranks.

    ``send_buffer`` is chunked by ``send_offsets`` and each chunk is
    delivered into the corresponding rank's ``recv_buffer`` at
    ``recv_offsets``.  ``remote_offsets`` is only consulted by one-sided
    backends (e.g. NVSHMEM) as the write position into the remote rank's
    buffer; two-sided backends ignore it.  When this call returns,
    ``recv_buffer`` is fully populated and safe to read.
    """
    # Fail fast before touching the backend, matching scatter/gather/barrier.
    self.__check_init()
    return self.__backend_engine.put(
        send_buffer,
        recv_buffer,
        send_offsets,
        recv_offsets,
        remote_offsets=remote_offsets,
    )

def barrier(self) -> None:
    """Delegate to the backend's ``barrier`` after verifying initialization."""
    self.__check_init()
    self.__backend_engine.barrier()
Expand Down
119 changes: 5 additions & 114 deletions DGraph/data/ogbn_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from ogb.nodeproppred import NodePropPredDataset
from DGraph.data.graph import DistributedGraph
from DGraph.data.graph import get_round_robin_node_rank_map
import numpy as np
from DGraph.data.preprocess import process_homogenous_data
import os
import torch.distributed as dist


SUPPORTED_DATASETS = [
"ogbn-arxiv",
Expand All @@ -37,117 +37,6 @@
}


def node_renumbering(node_rank_placement) -> Tuple[torch.Tensor, torch.Tensor]:
"""The nodes are renumbered based on the rank mappings so the node features and
numbers are contiguous."""

contiguous_rank_mapping, renumbered_nodes = torch.sort(node_rank_placement)
return renumbered_nodes, contiguous_rank_mapping


def edge_renumbering(
edge_indices, renumbered_nodes, vertex_mapping, edge_features=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
src_indices = edge_indices[0, :]
dst_indices = edge_indices[1, :]
src_indices = renumbered_nodes[src_indices]
dst_indices = renumbered_nodes[dst_indices]

edge_src_rank_mapping = vertex_mapping[src_indices]
edge_dest_rank_mapping = vertex_mapping[dst_indices]

sorted_src_rank_mapping, sorted_indices = torch.sort(edge_src_rank_mapping)
dst_indices = dst_indices[sorted_indices]
src_indices = src_indices[sorted_indices]

sorted_dest_rank_mapping = edge_dest_rank_mapping[sorted_indices]

if edge_features is not None:
# Sort the edge features based on the sorted indices
edge_features = edge_features[sorted_indices]

return (
torch.stack([src_indices, dst_indices], dim=0),
sorted_src_rank_mapping,
sorted_dest_rank_mapping,
edge_features,
)


def process_homogenous_data(
graph_data,
labels,
rank: int,
world_Size: int,
split_idx: dict,
node_rank_placement: torch.Tensor,
*args,
**kwargs,
) -> DistributedGraph:
"""For processing homogenous graph with node features, edge index and labels"""
assert "node_feat" in graph_data, "Node features not found"
assert "edge_index" in graph_data, "Edge index not found"
assert "num_nodes" in graph_data, "Number of nodes not found"
assert graph_data["edge_feat"] is None, "Edge features not supported"

node_features = torch.Tensor(graph_data["node_feat"]).float()
edge_index = torch.Tensor(graph_data["edge_index"]).long()
num_nodes = graph_data["num_nodes"]
labels = torch.Tensor(labels).long()
# For bidirectional graphs the number of edges are double counted
num_edges = edge_index.shape[1]

assert node_rank_placement.shape[0] == num_nodes, "Node mapping mismatch"
assert "train" in split_idx, "Train mask not found"
assert "valid" in split_idx, "Validation mask not found"
assert "test" in split_idx, "Test mask not found"

train_nodes = torch.from_numpy(split_idx["train"])
valid_nodes = torch.from_numpy(split_idx["valid"])
test_nodes = torch.from_numpy(split_idx["test"])

# Renumber the nodes and edges to make them contiguous
renumbered_nodes, contiguous_rank_mapping = node_renumbering(node_rank_placement)
node_features = node_features[renumbered_nodes]

# Sanity check to make sure we placed the nodes in the correct spots

assert torch.all(node_rank_placement[renumbered_nodes] == contiguous_rank_mapping)

# First renumber the edges
# Then we calculate the location of the source and destination vertex of each edge
# based on the rank mapping
# Then we sort the edges based on the source vertex rank mapping
# When determining the location of the edge, we use the rank of the source vertex
# as the location of the edge

edge_index, edge_rank_mapping, edge_dest_rank_mapping, _ = edge_renumbering(
edge_index, renumbered_nodes, contiguous_rank_mapping, edge_features=None
)

train_nodes = renumbered_nodes[train_nodes]
valid_nodes = renumbered_nodes[valid_nodes]
test_nodes = renumbered_nodes[test_nodes]

labels = labels[renumbered_nodes]

graph_obj = DistributedGraph(
node_features=node_features,
edge_index=edge_index,
num_nodes=num_nodes,
num_edges=num_edges,
node_loc=contiguous_rank_mapping.long(),
edge_loc=edge_rank_mapping.long(),
edge_dest_rank_mapping=edge_dest_rank_mapping.long(),
world_size=world_Size,
labels=labels,
train_mask=train_nodes,
val_mask=valid_nodes,
test_mask=test_nodes,
)
return graph_obj


class DistributedOGBWrapper(Dataset):
def __init__(
self,
Expand Down Expand Up @@ -211,7 +100,9 @@ def __init__(
else:
if node_rank_placement is None:
if self._rank == 0:
print(f"Node rank placement not provided, generating a round robin placement")
print(
f"Node rank placement not provided, generating a round robin placement"
)
node_rank_placement = get_round_robin_node_rank_map(
graph_data["num_nodes"], self._world_size
)
Expand Down
114 changes: 114 additions & 0 deletions DGraph/data/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import torch
from typing import Optional, Tuple
from DGraph.data.graph import DistributedGraph


def node_renumbering(node_rank_placement) -> Tuple[torch.Tensor, torch.Tensor]:
    """Renumber nodes so that each rank's nodes become a contiguous block.

    Sorting the per-node rank assignment yields both the new node order
    (old node ids in their new positions) and the now-contiguous rank
    mapping.
    """
    sorted_ranks, new_order = torch.sort(node_rank_placement)
    return new_order, sorted_ranks


def edge_renumbering(
    edge_indices, renumbered_nodes, vertex_mapping, edge_features=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Renumber edge endpoints and order edges by the source vertex's rank.

    Endpoints are mapped through ``renumbered_nodes``, the owning rank of
    each endpoint is looked up in ``vertex_mapping``, and all per-edge data
    is permuted so edges are grouped by source-vertex rank.

    Returns:
        The reordered ``(2, E)`` edge index, the sorted source-rank mapping,
        the destination-rank mapping in the same order, and the edge
        features permuted to match (``None`` when no features were given).
    """
    src = renumbered_nodes[edge_indices[0, :]]
    dst = renumbered_nodes[edge_indices[1, :]]

    src_ranks = vertex_mapping[src]
    dst_ranks = vertex_mapping[dst]

    # Group the edges by the rank that owns the source vertex.
    src_ranks_sorted, perm = torch.sort(src_ranks)
    src = src[perm]
    dst = dst[perm]
    dst_ranks_sorted = dst_ranks[perm]

    if edge_features is not None:
        # Keep per-edge features aligned with the permuted edge list.
        edge_features = edge_features[perm]

    return (
        torch.stack([src, dst], dim=0),
        src_ranks_sorted,
        dst_ranks_sorted,
        edge_features,
    )


def process_homogenous_data(
    graph_data,
    labels,
    rank: int,
    world_Size: int,
    split_idx: dict,
    node_rank_placement: torch.Tensor,
    *args,
    **kwargs,
) -> DistributedGraph:
    """Build a DistributedGraph from a homogeneous graph.

    Nodes are renumbered so each rank owns a contiguous block, edges are
    reordered by the rank of their source vertex, and labels / split
    indices are permuted to follow the new numbering.

    Args:
        graph_data: OGB-style dict with "node_feat", "edge_index" and
            "num_nodes"; "edge_feat" must be None (not supported).
        labels: per-node labels, array-like of length ``num_nodes``.
        rank: calling rank (currently unused in this function).
        world_Size: number of ranks the graph is partitioned across.
        split_idx: dict with "train" / "valid" / "test" numpy index arrays.
        node_rank_placement: (num_nodes,) tensor assigning each node a rank.

    Returns:
        A DistributedGraph holding the renumbered graph, labels, and masks.
    """
    assert "node_feat" in graph_data, "Node features not found"
    assert "edge_index" in graph_data, "Edge index not found"
    assert "num_nodes" in graph_data, "Number of nodes not found"
    assert graph_data["edge_feat"] is None, "Edge features not supported"

    # Use as_tensor (not torch.Tensor) so integer arrays are never routed
    # through float32, which silently corrupts node ids above 2**24 on
    # large graphs (e.g. ogbn-papers100M).
    node_features = torch.as_tensor(graph_data["node_feat"]).float()
    edge_index = torch.as_tensor(graph_data["edge_index"]).long()
    num_nodes = graph_data["num_nodes"]
    labels = torch.as_tensor(labels).long()
    # For bidirectional graphs the number of edges are double counted
    num_edges = edge_index.shape[1]

    assert node_rank_placement.shape[0] == num_nodes, "Node mapping mismatch"
    assert "train" in split_idx, "Train mask not found"
    assert "valid" in split_idx, "Validation mask not found"
    assert "test" in split_idx, "Test mask not found"

    train_nodes = torch.from_numpy(split_idx["train"])
    valid_nodes = torch.from_numpy(split_idx["valid"])
    test_nodes = torch.from_numpy(split_idx["test"])

    # Renumber the nodes and edges to make them contiguous per rank.
    renumbered_nodes, contiguous_rank_mapping = node_renumbering(node_rank_placement)
    node_features = node_features[renumbered_nodes]

    # Sanity check to make sure we placed the nodes in the correct spots.
    assert torch.all(node_rank_placement[renumbered_nodes] == contiguous_rank_mapping)

    # First renumber the edges, then compute the rank of each edge's source
    # and destination vertex, then sort the edges by the source vertex's
    # rank.  An edge is considered located on its source vertex's rank.
    edge_index, edge_rank_mapping, edge_dest_rank_mapping, _ = edge_renumbering(
        edge_index, renumbered_nodes, contiguous_rank_mapping, edge_features=None
    )

    # NOTE(review): renumbered_nodes[j] is the OLD id stored at NEW position
    # j (forward permutation from the sort).  Indexing the split ids and
    # edge endpoints with it applies the forward permutation where the
    # inverse (old id -> new position) appears to be intended — confirm the
    # downstream consumers expect this convention.
    train_nodes = renumbered_nodes[train_nodes]
    valid_nodes = renumbered_nodes[valid_nodes]
    test_nodes = renumbered_nodes[test_nodes]

    labels = labels[renumbered_nodes]

    graph_obj = DistributedGraph(
        node_features=node_features,
        edge_index=edge_index,
        num_nodes=num_nodes,
        num_edges=num_edges,
        node_loc=contiguous_rank_mapping.long(),
        edge_loc=edge_rank_mapping.long(),
        edge_dest_rank_mapping=edge_dest_rank_mapping.long(),
        world_size=world_Size,
        labels=labels,
        train_mask=train_nodes,
        val_mask=valid_nodes,
        test_mask=test_nodes,
    )
    return graph_obj
37 changes: 36 additions & 1 deletion DGraph/distributed/Engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#
# SPDX-License-Identifier: (Apache-2.0)
import torch
from typing import Optional, Union
from typing import Optional, Union, Tuple


class BackendEngine(object):
Expand Down Expand Up @@ -64,6 +64,41 @@ def gather(
) -> torch.Tensor:
raise NotImplementedError

def put(
    self,
    send_buffer: torch.Tensor,
    recv_buffer: torch.Tensor,
    send_offsets: torch.Tensor,
    recv_offsets: torch.Tensor,
    remote_offsets: Optional[torch.Tensor] = None,
) -> None:
    """Synchronously exchange buffer chunks between all ranks.

    Implementations chunk ``send_buffer`` according to ``send_offsets``
    and deliver each chunk into the corresponding rank's ``recv_buffer``
    at ``recv_offsets``.  The call must not return until ``recv_buffer``
    is fully populated and safe to read.

    ``remote_offsets`` is only meaningful for one-sided backends, where
    entry ``i`` gives the write position into rank ``i``'s receive
    buffer; two-sided backends ignore it.

    Raises:
        NotImplementedError: always — concrete backends must override.
    """
    raise NotImplementedError

def allocate_buffer(
    self,
    size: Tuple[int, ...],
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Allocate an uninitialized communication buffer.

    The base implementation hands out ordinary device memory via
    ``torch.empty``; one-sided backends override this to return
    symmetric / registered memory instead.
    """
    buffer = torch.empty(size, dtype=dtype, device=device)
    return buffer

def finalize(self) -> None:
raise NotImplementedError

Expand Down
14 changes: 14 additions & 0 deletions DGraph/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,18 @@
Modules exported by this package:
- `Engine`: The DGraph communication engine used by the Communicator.
- `BackendEngine`: The abstract DGraph communication engine used by the Communicator.
- `HaloExchange`: Halo exchange class for communicating remote vertices
- `CommunicationPattern`: Dataclass for holding communication pattern information
"""
from DGraph.distributed.haloExchange import HaloExchange, DGraphMessagePassing
from DGraph.distributed.commInfo import (
CommunicationPattern,
build_communication_pattern,
)

__all__ = [
"HaloExchange",
"DGraphMessagePassing",
"CommunicationPattern",
"build_communication_pattern",
]
Loading