Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,11 @@ def _setup_for_graph_store(
# Extract supervision edge types and derive label edge types from the
# ABLPInputNodes.labels dict (keyed by supervision edge type).
self._supervision_edge_types = list(first_input.labels.keys())
has_negatives = any(neg is not None for _, neg in first_input.labels.values())
has_negatives = any(
negative_labels is not None
for ablp_input in input_nodes.values()
for _, negative_labels in ablp_input.labels.values()
)

self._positive_label_edge_types = [
message_passing_to_positive_label(et) for et in self._supervision_edge_types
Expand Down
98 changes: 32 additions & 66 deletions gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,12 @@
from gigl.common.logger import Logger
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.dist_sampling_producer import DistSamplingProducer
from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.sampler_options import SamplerOptions
from gigl.distributed.utils.neighborloader import shard_nodes_by_process
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.types.graph import (
DEFAULT_HOMOGENEOUS_EDGE_TYPE,
DEFAULT_HOMOGENEOUS_NODE_TYPE,
FeatureInfo,
select_label_edge_types,
)
from gigl.types.graph import FeatureInfo, select_label_edge_types
from gigl.utils.data_splitters import get_labels_for_anchor_nodes

SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0
Expand Down Expand Up @@ -277,19 +273,13 @@ def get_edge_dir(self) -> Literal["in", "out"]:

def get_node_ids(
self,
rank: Optional[int] = None,
world_size: Optional[int] = None,
split: Optional[Union[Literal["train", "val", "test"], str]] = None,
node_type: Optional[NodeType] = None,
request: FetchNodesRequest,
) -> torch.Tensor:
"""Get the node ids from the dataset.

Args:
rank: The rank of the process requesting node ids. Must be provided if world_size is provided.
world_size: The total number of processes in the distributed setup. Must be provided if rank is provided.
split: The split of the dataset to get node ids from. If provided, the dataset must have
`train_node_ids`, `val_node_ids`, and `test_node_ids` properties.
node_type: The type of nodes to get node ids for. Must be provided if the dataset is heterogeneous.
request: The node-fetch request, including split, node type,
and either round-robin rank/world_size or a contiguous slice.

Returns:
The node ids.
Expand All @@ -302,63 +292,40 @@ def get_node_ids(
* If the node type is provided for a homogeneous dataset
* If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided

Examples:
Suppose the dataset has 100 nodes total: train=[0..59], val=[60..79], test=[80..99].

Get all node ids (no split filtering):

>>> server.get_node_ids()
tensor([0, 1, 2, ..., 99]) # All 100 nodes

Get only training nodes:

>>> server.get_node_ids(split="train")
tensor([0, 1, 2, ..., 59]) # 60 training nodes

Shard all nodes across 4 processes (each gets ~25 nodes):

>>> server.get_node_ids(rank=0, world_size=4)
tensor([0, 1, 2, ..., 24]) # First 25 of all 100 nodes

Shard training nodes across 4 processes (each gets ~15 nodes):

>>> server.get_node_ids(rank=0, world_size=4, split="train")
tensor([0, 1, 2, ..., 14]) # First 15 of the 60 training nodes

Note: When `split=None`, all nodes are queryable. This means nodes from any
split (train, val, or test) may be returned. This is useful when you need
to sample neighbors during inference, as neighbor nodes may belong to any split.
"""
if (rank is None) ^ (world_size is None):
raise ValueError(
f"rank and world_size must be provided together. Received rank: {rank}, world_size: {world_size}"
)
if split == "train":
request.validate()
if request.split == "train":
nodes = self.dataset.train_node_ids
elif split == "val":
elif request.split == "val":
nodes = self.dataset.val_node_ids
elif split == "test":
elif request.split == "test":
nodes = self.dataset.test_node_ids
elif split is None:
elif request.split is None:
nodes = self.dataset.node_ids
else:
raise ValueError(
f"Invalid split: {split}. Must be one of 'train', 'val', 'test', or None."
f"Invalid split: {request.split}. Must be one of 'train', 'val', 'test', or None."
)

if node_type is not None:
if request.node_type is not None:
if not isinstance(nodes, abc.Mapping):
raise ValueError(
f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] (e.g. a heterogeneous dataset), got {type(nodes)}"
f"node_type was provided as {request.node_type}, so node ids must be a dict[NodeType, torch.Tensor] "
f"(e.g. a heterogeneous dataset), got {type(nodes)}"
)
nodes = nodes[node_type]
nodes = nodes[request.node_type]
elif not isinstance(nodes, torch.Tensor):
raise ValueError(
f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}."
)

if rank is not None and world_size is not None:
return shard_nodes_by_process(nodes, rank, world_size)
if request.server_slice is not None:
return request.server_slice.slice_tensor(nodes)
if request.rank is not None and request.world_size is not None:
return shard_nodes_by_process(nodes, request.rank, request.world_size)
return nodes

def get_edge_types(self) -> Optional[list[EdgeType]]:
Expand All @@ -385,25 +352,17 @@ def get_node_types(self) -> Optional[list[NodeType]]:

def get_ablp_input(
self,
split: Union[Literal["train", "val", "test"], str],
rank: Optional[int] = None,
world_size: Optional[int] = None,
node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE,
supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE,
request: FetchABLPRequest,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Get the ABLP (Anchor Based Link Prediction) input for a specific rank in distributed processing.

Note: rank and world_size here are for the process group we're *fetching for*, not the process group we're *fetching from*.
e.g. if our compute cluster is of world size 4, and we have 2 storage nodes, then the world size this gets called with is 4, not 2.

Args:
split: The split to get the training input for.
rank: The rank of the process requesting the training input. Defaults to None, in which case all nodes are returned.
Must be provided if world_size is provided.
world_size: The total number of processes in the distributed setup. Defaults to None, in which case all nodes are returned.
Must be provided if rank is provided.
node_type: The type of nodes to retrieve. Defaults to the default homogeneous node type.
supervision_edge_type: The edge type to use for the supervision. Defaults to the default homogeneous edge type.
request: The ABLP fetch request, including split, node type,
supervision edge type, and either round-robin rank/world_size
or a contiguous slice.

Returns:
A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels.
Expand All @@ -414,11 +373,18 @@ def get_ablp_input(
Raises:
ValueError: If the split is invalid.
"""
request.validate()
anchors = self.get_node_ids(
split=split, rank=rank, world_size=world_size, node_type=node_type
FetchNodesRequest(
split=request.split,
rank=request.rank,
world_size=request.world_size,
node_type=request.node_type,
server_slice=request.server_slice,
)
)
positive_label_edge_type, negative_label_edge_type = select_label_edge_types(
supervision_edge_type, self.dataset.get_edge_types()
request.supervision_edge_type, self.dataset.get_edge_types()
)
positive_labels, negative_labels = get_labels_for_anchor_nodes(
self.dataset, anchors, positive_label_edge_type, negative_label_edge_type
Expand Down
54 changes: 54 additions & 0 deletions gigl/distributed/graph_store/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""RPC request messages for graph-store fetch operations."""

from dataclasses import dataclass
from typing import Literal, Optional, Union

from gigl.distributed.graph_store.sharding import ServerSlice
from gigl.src.common.types.graph_data import EdgeType, NodeType


def _validate_sharding_mode(
    rank: Optional[int],
    world_size: Optional[int],
    server_slice: Optional["ServerSlice"],
) -> None:
    """Shared request validation: reject partial or mixed sharding modes.

    A fetch request may shard either round-robin (``rank`` + ``world_size``,
    both set) or by a contiguous ``server_slice`` — never both, and never a
    half-specified round-robin pair.

    Args:
        rank: Round-robin rank of the requesting process, or None.
        world_size: Round-robin world size, or None.
        server_slice: Contiguous slice of the node tensor, or None.

    Raises:
        ValueError: If exactly one of ``rank``/``world_size`` is provided, or
            if ``server_slice`` is combined with ``rank``/``world_size``.
    """
    # XOR: exactly one of the pair being None means the caller forgot the other.
    if (rank is None) ^ (world_size is None):
        raise ValueError(
            "rank and world_size must be provided together. "
            f"Received rank={rank}, world_size={world_size}"
        )
    if server_slice is not None and (rank is not None or world_size is not None):
        raise ValueError("server_slice cannot be combined with rank/world_size.")


@dataclass(frozen=True)
class FetchNodesRequest:
    """Request for fetching node IDs from a storage server.

    Attributes:
        rank: Round-robin rank of the process fetching nodes. Must be set
            together with ``world_size``; defaults to None (no round-robin
            sharding).
        world_size: Total number of fetching processes for round-robin
            sharding. Must be set together with ``rank``; defaults to None.
        split: Dataset split to read node IDs from ("train", "val", "test"),
            or None to read all node IDs.
        node_type: Node type to fetch for heterogeneous datasets; None for
            homogeneous datasets.
        server_slice: Contiguous slice of the node tensor to return.
            Mutually exclusive with ``rank``/``world_size``.
    """

    rank: Optional[int] = None
    world_size: Optional[int] = None
    split: Optional[Union[Literal["train", "val", "test"], str]] = None
    node_type: Optional[NodeType] = None
    server_slice: Optional[ServerSlice] = None

    def validate(self) -> None:
        """Validate that the request does not mix sharding modes.

        Raises:
            ValueError: If only one of ``rank``/``world_size`` is set, or if
                ``server_slice`` is combined with ``rank``/``world_size``.
        """
        _validate_sharding_mode(self.rank, self.world_size, self.server_slice)


@dataclass(frozen=True)
class FetchABLPRequest:
    """Request for fetching ABLP input from a storage server.

    Attributes:
        split: Dataset split to fetch anchors from ("train", "val", "test").
        node_type: Node type of the anchor nodes.
        supervision_edge_type: Edge type used for supervision labels.
        rank: Round-robin rank of the fetching process. Must be set together
            with ``world_size``; defaults to None.
        world_size: Total number of fetching processes. Must be set together
            with ``rank``; defaults to None.
        server_slice: Contiguous slice of the anchor tensor to return.
            Mutually exclusive with ``rank``/``world_size``.
    """

    split: Union[Literal["train", "val", "test"], str]
    node_type: NodeType
    supervision_edge_type: EdgeType
    rank: Optional[int] = None
    world_size: Optional[int] = None
    server_slice: Optional[ServerSlice] = None

    def validate(self) -> None:
        """Validate that the request does not mix sharding modes.

        Raises:
            ValueError: If only one of ``rank``/``world_size`` is set, or if
                ``server_slice`` is combined with ``rank``/``world_size``.
        """
        _validate_sharding_mode(self.rank, self.world_size, self.server_slice)
Loading