From b9be302b847828c3ceaeb193a8541921f52dc6c5 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 13 Mar 2026 00:18:54 +0000 Subject: [PATCH 01/13] update --- .../graph_store/remote_dist_dataset.py | 294 +++++++++--- gigl/distributed/utils/__init__.py | 2 + gigl/distributed/utils/neighborloader.py | 178 +++++++ .../graph_store_integration_test.py | 55 ++- .../graph_store/remote_dist_dataset_test.py | 453 ++++++++++++++++++ .../distributed/utils/neighborloader_test.py | 243 ++++++++++ 6 files changed, 1135 insertions(+), 90 deletions(-) diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index 927552d9b..84dbb1eac 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -11,6 +11,10 @@ from gigl.common.logger import Logger from gigl.distributed.graph_store.compute import async_request_server, request_server from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.utils.neighborloader import ( + ShardStrategy, + compute_server_assignments, +) from gigl.distributed.utils.networking import get_free_ports from gigl.env.distributed import GraphStoreInfo from gigl.src.common.types.graph_data import EdgeType, NodeType @@ -227,12 +231,73 @@ def _fetch_node_ids( node_ids = torch.futures.wait_all(futures) return {server_rank: node_ids for server_rank, node_ids in enumerate(node_ids)} + def _fetch_node_ids_by_server( + self, + rank: int, + world_size: int, + node_type: Optional[NodeType] = None, + split: Optional[Literal["train", "val", "test"]] = None, + ) -> dict[int, torch.Tensor]: + """Fetches node ids using contiguous server assignment. + + Each compute node is assigned a contiguous range of servers. Only + assigned servers are RPCed; unassigned servers get empty tensors. + Boundary servers are sliced fractionally when servers don't divide + evenly across compute nodes. 
+ + Args: + rank: The rank of the compute node requesting node ids. + world_size: The total number of compute nodes. + node_type: The type of nodes to get. Must be provided for heterogeneous datasets. + split: The split of the dataset to get node ids from. + + Returns: + A dict mapping every server rank to a tensor of node ids. + """ + node_type = self._infer_node_type_if_homogeneous_with_label_edges(node_type) + + assignments = compute_server_assignments( + num_servers=self.cluster_info.num_storage_nodes, + num_compute_nodes=world_size, + compute_rank=rank, + ) + + logger.info( + f"Getting node ids via CONTIGUOUS strategy for rank {rank} / {world_size} " + f"with node type {node_type} and split {split}. " + f"Assigned servers: {list(assignments.keys())}" + ) + + # RPC only assigned servers (fetch ALL nodes, no server-side sharding) + futures: dict[int, torch.futures.Future[torch.Tensor]] = {} + for server_rank in assignments: + futures[server_rank] = async_request_server( + server_rank, + DistServer.get_node_ids, + rank=None, + world_size=None, + split=split, + node_type=node_type, + ) + + # Build result: slice assigned servers, empty tensors for unassigned + result: dict[int, torch.Tensor] = {} + for server_rank in range(self.cluster_info.num_storage_nodes): + if server_rank in futures: + all_nodes = futures[server_rank].wait() + result[server_rank] = assignments[server_rank].slice_tensor(all_nodes) + else: + result[server_rank] = torch.empty(0, dtype=torch.long) + + return result + def fetch_node_ids( self, rank: Optional[int] = None, world_size: Optional[int] = None, split: Optional[Literal["train", "val", "test"]] = None, node_type: Optional[NodeType] = None, + shard_strategy: ShardStrategy = ShardStrategy.ROUND_ROBIN, ) -> dict[int, torch.Tensor]: """ Fetches node ids from the storage nodes for the current compute node (machine). 
@@ -248,50 +313,21 @@ def fetch_node_ids( If provided, the dataset must have `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. node_type (Optional[NodeType]): The type of nodes to get. Must be provided for heterogeneous datasets. + shard_strategy (ShardStrategy): Strategy for sharding node IDs across compute nodes. + ``ROUND_ROBIN`` (default) shards each server's nodes across all compute nodes. + ``CONTIGUOUS`` assigns entire servers to compute nodes, producing empty tensors + for unassigned servers. ``CONTIGUOUS`` requires both ``rank`` and ``world_size``. + + Raises: + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. Returns: dict[int, torch.Tensor]: A dict mapping storage rank to node ids. Examples: - Suppose we have 2 storage nodes and 2 compute nodes, with 16 total nodes. - Nodes are partitioned across storage nodes, with splits defined as: - - Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7] - train=[0, 1, 2, 3], val=[4, 5], test=[6, 7] - Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15] - train=[8, 9, 10, 11], val=[12, 13], test=[14, 15] - - Get all nodes (no split filtering, no sharding): - - >>> dataset.fetch_node_ids() - { - 0: tensor([0, 1, 2, 3, 4, 5, 6, 7]), # All 8 nodes from storage rank 0 - 1: tensor([8, 9, 10, 11, 12, 13, 14, 15]) # All 8 nodes from storage rank 1 - } - - Shard all nodes across 2 compute nodes (compute rank 0 gets first half from each storage): - - >>> dataset.fetch_node_ids(rank=0, world_size=2) - { - 0: tensor([0, 1, 2, 3]), # First 4 of all 8 nodes from storage rank 0 - 1: tensor([8, 9, 10, 11]) # First 4 of all 8 nodes from storage rank 1 - } - - Get only training nodes (no sharding): - - >>> dataset.fetch_node_ids(split="train") - { - 0: tensor([0, 1, 2, 3]), # 4 training nodes from storage rank 0 - 1: tensor([8, 9, 10, 11]) # 4 training nodes from storage rank 1 - } - - Combine split and sharding (training nodes, sharded for compute rank 0): - - >>> 
dataset.fetch_node_ids(rank=0, world_size=2, split="train") - { - 0: tensor([0, 1]), # First 2 of 4 training nodes from storage rank 0 - 1: tensor([8, 9]) # First 2 of 4 training nodes from storage rank 1 - } + See :class:`~gigl.distributed.utils.neighborloader.ShardStrategy` for + concrete examples of how each strategy distributes node IDs across + compute nodes. Note: When `split=None`, all nodes are queryable. This means nodes from any split @@ -304,6 +340,21 @@ def fetch_node_ids( `mp_sharing_dict` to the `RemoteDistDataset` constructor. """ + if shard_strategy == ShardStrategy.CONTIGUOUS: + if rank is None or world_size is None: + raise ValueError( + "Both rank and world_size must be provided when using " + f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" + ) + + def _do_fetch() -> dict[int, torch.Tensor]: + if shard_strategy == ShardStrategy.CONTIGUOUS: + assert rank is not None and world_size is not None + return self._fetch_node_ids_by_server( + rank, world_size, node_type, split + ) + return self._fetch_node_ids(rank, world_size, node_type, split) + def server_key(server_rank: int) -> str: return f"node_ids_from_server_{server_rank}" @@ -314,7 +365,7 @@ def server_key(server_rank: int) -> str: logger.info( f"Compute rank {torch.distributed.get_rank()} is getting node ids from storage nodes" ) - node_ids = self._fetch_node_ids(rank, world_size, node_type, split) + node_ids = _do_fetch() for server_rank, node_id in node_ids.items(): node_id.share_memory_() self._mp_sharing_dict[server_key(server_rank)] = node_id @@ -335,7 +386,7 @@ def server_key(server_rank: int) -> str: gc.collect() return node_ids else: - return self._fetch_node_ids(rank, world_size, node_type, split) + return _do_fetch() def fetch_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: """ @@ -412,6 +463,92 @@ def _fetch_ablp_input( for server_rank, ablp_input in enumerate(ablp_inputs) } + def _fetch_ablp_input_by_server( + self, + split: Literal["train", 
"val", "test"], + rank: int, + world_size: int, + node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, + ) -> dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: + """Fetches ABLP input using contiguous server assignment. + + Each compute node is assigned a contiguous range of servers. Only + assigned servers are RPCed; unassigned servers get empty tensors. + Boundary servers are sliced fractionally when servers don't divide + evenly across compute nodes. + + Args: + split: The split of the dataset to get ABLP input from. + rank: The rank of the compute node requesting ABLP input. + world_size: The total number of compute nodes. + node_type: The type of anchor nodes to retrieve. + supervision_edge_type: The edge type for supervision. + + Returns: + A dict mapping every server rank to a tuple of + (anchors, positive_labels, negative_labels). + """ + assignments = compute_server_assignments( + num_servers=self.cluster_info.num_storage_nodes, + num_compute_nodes=world_size, + compute_rank=rank, + ) + + logger.info( + f"Getting ABLP input via CONTIGUOUS strategy for rank {rank} / {world_size} " + f"with node type {node_type}, split {split}, and " + f"supervision edge type {supervision_edge_type}. 
" + f"Assigned servers: {list(assignments.keys())}" + ) + + # RPC only assigned servers (fetch ALL ABLP data, no server-side sharding) + futures: dict[ + int, + torch.futures.Future[ + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ], + ] = {} + for server_rank in assignments: + futures[server_rank] = async_request_server( + server_rank, + DistServer.get_ablp_input, + split=split, + rank=None, + world_size=None, + node_type=node_type, + supervision_edge_type=supervision_edge_type, + ) + + # Build result: slice assigned servers, empty tensors for unassigned + result: dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] = {} + for server_rank in range(self.cluster_info.num_storage_nodes): + if server_rank in futures: + anchors, positive_labels, negative_labels = futures[server_rank].wait() + server_slice = assignments[server_rank] + sliced_anchors = server_slice.slice_tensor(anchors) + sliced_positive = server_slice.slice_tensor(positive_labels) + sliced_negative = ( + server_slice.slice_tensor(negative_labels) + if negative_labels is not None + else None + ) + result[server_rank] = ( + sliced_anchors, + sliced_positive, + sliced_negative, + ) + else: + result[server_rank] = ( + torch.empty(0, dtype=torch.long), + torch.empty(0, dtype=torch.long), + None, + ) + + return result + # TODO(#488) - support multiple supervision edge types def fetch_ablp_input( self, @@ -420,9 +557,9 @@ def fetch_ablp_input( world_size: Optional[int] = None, anchor_node_type: Optional[NodeType] = None, supervision_edge_type: Optional[EdgeType] = None, + shard_strategy: ShardStrategy = ShardStrategy.ROUND_ROBIN, ) -> dict[int, ABLPInputNodes]: - """ - Fetches ABLP (Anchor Based Link Prediction) input from the storage nodes. + """Fetches ABLP (Anchor Based Link Prediction) input from the storage nodes. The returned dict maps storage rank to an :class:`ABLPInputNodes` dataclass for that storage node. 
If (rank, world_size) is provided, the input will be @@ -447,6 +584,11 @@ def fetch_ablp_input( Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs. Defaults to None. + shard_strategy (ShardStrategy): Strategy for sharding ABLP input across compute + nodes. ``ROUND_ROBIN`` (default) shards each server's data across all compute + nodes. ``CONTIGUOUS`` assigns entire servers to compute nodes, producing empty + tensors for unassigned servers. ``CONTIGUOUS`` requires both ``rank`` and + ``world_size``. Returns: dict[int, ABLPInputNodes]: @@ -456,24 +598,13 @@ def fetch_ablp_input( - positive_labels: Dict mapping positive label EdgeType to a 2D tensor [N, M]. - negative_labels: Optional dict mapping negative label EdgeType to a 2D tensor [N, M]. - Examples: - Suppose we have 1 storage node with users [0, 1, 2, 3, 4] where: - train=[0, 1, 2], val=[3], test=[4] - And positive/negative labels defined for link prediction. - - Get training ABLP input (heterogeneous): - - >>> dataset.fetch_ablp_input(split="train", node_type=USER, supervision_edge_type=USER_TO_ITEM) - { - 0: ABLPInputNodes( - anchor_nodes=tensor([0, 1, 2]), - positive_labels={("user", "to_positive", "item"): tensor([[0, 1], [1, 2], [2, 3]])}, - anchor_node_type="user", - negative_labels={("user", "to_negative", "item"): tensor([[2], [3], [4]])}, - ) - } + Raises: + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. - For labeled homogeneous graphs, anchor_node_type will be DEFAULT_HOMOGENEOUS_NODE_TYPE. + Examples: + See :class:`~gigl.distributed.utils.neighborloader.ShardStrategy` for + concrete examples of how each strategy distributes data across + compute nodes. Note: The GLT sampling engine expects all processes on a given compute machine to have @@ -482,6 +613,13 @@ def fetch_ablp_input( `mp_sharing_dict` to the `RemoteDistDataset` constructor. 
""" + if shard_strategy == ShardStrategy.CONTIGUOUS: + if rank is None or world_size is None: + raise ValueError( + "Both rank and world_size must be provided when using " + f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" + ) + if (anchor_node_type is None) != (supervision_edge_type is None): raise ValueError( f"anchor_node_type and supervision_edge_type must both be provided or both be None, received: " @@ -521,6 +659,26 @@ def wrap_ablp_input( }, ) + def _do_fetch_ablp() -> ( + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] + ): + if shard_strategy == ShardStrategy.CONTIGUOUS: + assert rank is not None and world_size is not None + return self._fetch_ablp_input_by_server( + split=split, + rank=rank, + world_size=world_size, + node_type=evaluated_anchor_node_type, + supervision_edge_type=evaluated_supervision_edge_type, + ) + return self._fetch_ablp_input( + split=split, + rank=rank, + world_size=world_size, + node_type=evaluated_anchor_node_type, + supervision_edge_type=evaluated_supervision_edge_type, + ) + if self._mp_sharing_dict is not None: assert self._mp_barrier is not None if self._local_rank == 0: @@ -528,13 +686,7 @@ def wrap_ablp_input( logger.info( f"Compute rank {torch.distributed.get_rank()} is getting ABLP input from storage nodes" ) - raw_ablp_inputs = self._fetch_ablp_input( - split=split, - rank=rank, - world_size=world_size, - node_type=evaluated_anchor_node_type, - supervision_edge_type=evaluated_supervision_edge_type, - ) + raw_ablp_inputs = _do_fetch_ablp() for server_rank, ( anchors, positive_labels, @@ -587,13 +739,7 @@ def wrap_ablp_input( gc.collect() return returned_ablp_inputs else: - raw_inputs = self._fetch_ablp_input( - split=split, - rank=rank, - world_size=world_size, - node_type=evaluated_anchor_node_type, - supervision_edge_type=evaluated_supervision_edge_type, - ) + raw_inputs = _do_fetch_ablp() return { server_rank: wrap_ablp_input( anchor_node_type=evaluated_anchor_node_type, diff 
--git a/gigl/distributed/utils/__init__.py b/gigl/distributed/utils/__init__.py index 04a986658..63aeb8f50 100644 --- a/gigl/distributed/utils/__init__.py +++ b/gigl/distributed/utils/__init__.py @@ -4,6 +4,7 @@ __all__ = [ "GraphStoreInfo", + "ShardStrategy", "get_available_device", "get_free_port", "get_free_ports", @@ -24,6 +25,7 @@ get_process_group_name, init_neighbor_loader_worker, ) +from .neighborloader import ShardStrategy from .networking import ( GraphStoreInfo, get_free_port, diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index fdac550bc..7f76447fc 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -26,6 +26,184 @@ class SamplingClusterSetup(Enum): GRAPH_STORE = "graph_store" +class ShardStrategy(Enum): + """Strategy for sharding node IDs across compute nodes. + + Controls how data from storage servers is distributed to compute nodes. + Both strategies produce the same total coverage (every node appears on + exactly one compute node), but differ in which servers each compute node + communicates with. + + Attributes: + ROUND_ROBIN: Each compute node gets a slice of nodes from every server. + Server-side sharding via rank/world_size. This is the current default. + CONTIGUOUS: Assign entire servers to compute nodes. Each compute node + only gets nodes from its assigned servers, with empty tensors for + the rest. Boundary servers are split fractionally when servers + don't divide evenly across compute nodes. 
+ + Examples: + **2 storage nodes, 2 compute nodes** (even split): + + Suppose each server holds 10 node IDs:: + + Server 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + Server 1: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + + ``ROUND_ROBIN`` — every compute node gets a slice from *every* server:: + + Compute 0 (rank=0, world_size=2): + {0: [0,1,2,3,4], 1: [10,11,12,13,14]} + Compute 1 (rank=1, world_size=2): + {0: [5,6,7,8,9], 1: [15,16,17,18,19]} + + ``CONTIGUOUS`` — each compute node gets *entire* servers:: + + Compute 0 (rank=0, world_size=2): + {0: [0,1,2,3,4,5,6,7,8,9], 1: []} # all of server 0 + Compute 1 (rank=1, world_size=2): + {0: [], 1: [10,11,12,13,14,15,16,17,18,19]} # all of server 1 + + **3 storage nodes, 2 compute nodes** (fractional boundary): + + Server 1 is split at the boundary — compute 0 gets the first half, + compute 1 gets the second half:: + + Server 0: [0..9], Server 1: [10..19], Server 2: [20..29] + + Compute 0 (rank=0): {0: [0..9], 1: [10..14], 2: []} + Compute 1 (rank=1): {0: [], 1: [15..19], 2: [20..29]} + + See Also: + :func:`compute_server_assignments` for the assignment algorithm. + """ + + ROUND_ROBIN = "round_robin" + CONTIGUOUS = "contiguous" + + +@dataclass(frozen=True) +class ServerSlice: + """A compute node's ownership of a single server's nodes. + + Fractions are represented as exact rationals (numerator, denominator) + to avoid floating-point boundary errors. For a server with N nodes, + the slice is ``tensor[N * start_num // start_den : N * end_num // end_den]``. + + Args: + server_rank: The rank of the storage server. + start_num: Numerator of the start fraction. + start_den: Denominator of the start fraction. + end_num: Numerator of the end fraction. + end_den: Denominator of the end fraction. + """ + + server_rank: int + start_num: int + start_den: int + end_num: int + end_den: int + + def slice_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + """Slice a 1D tensor according to this assignment's rational bounds. 
+ + Uses integer division (N * num // den) for exact, deterministic + index computation. Returns a ``.clone()`` for partial slices to avoid + retaining full backing storage when used with ``share_memory_()``. + + Args: + tensor: A 1D tensor of node IDs from the server. + + Returns: + The sliced portion of the tensor. + """ + total = len(tensor) + start_idx = total * self.start_num // self.start_den + end_idx = total * self.end_num // self.end_den + if start_idx == 0 and end_idx == total: + return tensor + return tensor[start_idx:end_idx].clone() + + +def compute_server_assignments( + num_servers: int, + num_compute_nodes: int, + compute_rank: int, +) -> dict[int, ServerSlice]: + """Compute which servers (and what fraction) a compute node owns. + + Uses integer arithmetic throughout. Compute rank R owns the server + range ``[R * S / C, (R+1) * S / C)`` where boundaries are rational + numbers with denominator C. For each server s in ``[0, S)``, the overlap + with this range determines the ServerSlice fractions. + + Only servers with non-zero overlap are included in the returned dict. + + Args: + num_servers: Total number of storage servers (S). + num_compute_nodes: Total number of compute nodes (C). + compute_rank: Rank of the current compute node (R). + + Returns: + A dict mapping server rank to the ``ServerSlice`` describing the + fraction of that server owned by this compute node. + + Raises: + ValueError: If any argument is invalid (negative values, + rank >= num_compute_nodes, or zero servers/compute nodes). 
+ + Examples: + >>> compute_server_assignments(num_servers=4, num_compute_nodes=2, compute_rank=0) + {0: ServerSlice(server_rank=0, ...), 1: ServerSlice(server_rank=1, ...)} + + >>> compute_server_assignments(num_servers=3, num_compute_nodes=2, compute_rank=1) + {1: ServerSlice(server_rank=1, ...), 2: ServerSlice(server_rank=2, ...)} + """ + if num_servers <= 0: + raise ValueError(f"num_servers must be positive, got {num_servers}") + if num_compute_nodes <= 0: + raise ValueError(f"num_compute_nodes must be positive, got {num_compute_nodes}") + if compute_rank < 0 or compute_rank >= num_compute_nodes: + raise ValueError( + f"compute_rank must be in [0, {num_compute_nodes}), got {compute_rank}" + ) + + S = num_servers + C = num_compute_nodes + R = compute_rank + + # Segment boundaries (as numerators with denominator C): + # start = R * S, end = (R + 1) * S + seg_start = R * S + seg_end = (R + 1) * S + + assignments: dict[int, ServerSlice] = {} + for s in range(S): + # Server s spans [s * C, (s + 1) * C) in numerator-space with denominator C + server_start = s * C + server_end = (s + 1) * C + + overlap_start = max(seg_start, server_start) + overlap_end = min(seg_end, server_end) + + if overlap_start >= overlap_end: + continue + + # Fraction of server s: [(overlap_start - s*C) / C, (overlap_end - s*C) / C) + start_num = overlap_start - server_start + end_num = overlap_end - server_start + + assignments[s] = ServerSlice( + server_rank=s, + start_num=start_num, + start_den=C, + end_num=end_num, + end_den=C, + ) + + return assignments + + @dataclass(frozen=True) class DatasetSchema: """ diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 61539f95c..ac8471330 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -28,7 +28,7 @@ 
build_storage_dataset, run_storage_server, ) -from gigl.distributed.utils.neighborloader import shard_nodes_by_process +from gigl.distributed.utils.neighborloader import ShardStrategy, shard_nodes_by_process from gigl.distributed.utils.networking import get_free_port, get_free_ports from gigl.distributed.utils.partition_book import build_partition_book, get_ids_on_rank from gigl.env.distributed import ( @@ -287,6 +287,44 @@ def _run_compute_train_tests( count_tensor.item() == expected_batches ), f"Expected {expected_batches} total batches, got {count_tensor.item()}" + # --- CONTIGUOUS shard strategy tests --- + # With 2 servers and 2 compute nodes, rank R should get all of server R's + # nodes and an empty tensor for server (1-R). + contiguous_node_ids = remote_dist_dataset.fetch_node_ids( + split="train", + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + + # Assert structure: each rank owns exactly one server in the 2S/2C case + rank = cluster_info.compute_node_rank + other_rank = 1 - rank + assert ( + contiguous_node_ids[rank].numel() > 0 + ), f"Rank {rank} should have non-empty tensor for its own server" + assert ( + contiguous_node_ids[other_rank].numel() == 0 + ), f"Rank {rank} should have empty tensor for server {other_rank}" + + # Assert total node parity: CONTIGUOUS and ROUND_ROBIN should cover the same nodes + local_contiguous_count = sum(t.numel() for t in contiguous_node_ids.values()) + local_round_robin_count = sum(t.numel() for t in random_negative_input.values()) + contiguous_total = torch.tensor(local_contiguous_count, dtype=torch.int64) + round_robin_total = torch.tensor(local_round_robin_count, dtype=torch.int64) + torch.distributed.all_reduce(contiguous_total, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(round_robin_total, op=torch.distributed.ReduceOp.SUM) + assert contiguous_total.item() == round_robin_total.item(), ( + f"CONTIGUOUS total 
({contiguous_total.item()}) must equal " + f"ROUND_ROBIN total ({round_robin_total.item()})" + ) + + torch.distributed.barrier() + logger.info( + f"Rank {torch.distributed.get_rank()} CONTIGUOUS: " + f"{local_contiguous_count} nodes from assigned server" + ) + shutdown_compute_proccess() @@ -370,9 +408,6 @@ def _run_compute_multiple_loaders_test( prefetch_size=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} ablp_loader_1 producers: ({ablp_loader_1._producer_id_list})" - ) ablp_loader_2 = DistABLPLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], @@ -383,9 +418,6 @@ def _run_compute_multiple_loaders_test( prefetch_size=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} ablp_loader_2 producers: ({ablp_loader_2._producer_id_list})" - ) neighbor_loader_1 = DistNeighborLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], @@ -395,9 +427,6 @@ def _run_compute_multiple_loaders_test( worker_concurrency=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} neighbor_loader_1 producers: ({neighbor_loader_1._producer_id_list})" - ) neighbor_loader_2 = DistNeighborLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], @@ -407,12 +436,6 @@ def _run_compute_multiple_loaders_test( worker_concurrency=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} neighbor_loader_2 producers: ({neighbor_loader_2._producer_id_list})" - ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} phase 1: loading batches from 4 parallel loaders" - ) torch.distributed.barrier() for ablp_batch_1, ablp_batch_2, neg_batch_1, neg_batch_2 in zip( ablp_loader_1, ablp_loader_2, neighbor_loader_1, neighbor_loader_2 diff --git 
a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index 389332476..7b0af4d20 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -10,6 +10,7 @@ from gigl.common import LocalUri from gigl.distributed.graph_store.dist_server import DistServer, _call_func_on_server from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils.neighborloader import ShardStrategy from gigl.env.distributed import GraphStoreInfo from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, @@ -755,6 +756,458 @@ def test_fetch_ablp_input_mismatched_params_raises(self): ) +class TestRemoteDistDatasetContiguous(TestCase): + """Tests for fetch_node_ids and fetch_ablp_input with ShardStrategy.CONTIGUOUS.""" + + def tearDown(self) -> None: + global _test_server + _test_server = None + dist_server_module._dist_server = None + if dist.is_initialized(): + dist.destroy_process_group() + + def _make_rank_aware_async_mock( + self, server_data: dict[int, dict[str, torch.Tensor]] + ): + """Create an async mock that returns different node IDs per server rank. + + Args: + server_data: Maps server_rank to a dict of + ``{"all": tensor, "train": tensor, ...}`` where ``"all"`` + is the full node set and split keys are optional. 
+ """ + + def _mock(server_rank, func, *args, **kwargs): + split = kwargs.get("split") + data = server_data[server_rank] + key = split if split is not None and split in data else "all" + future: torch.futures.Future = torch.futures.Future() + future.set_result(data[key]) + return future + + return _mock + + @staticmethod + def _mock_request_server_homogeneous(server_rank, func, *args, **kwargs): + """Mock request_server that returns None for node/edge types (homogeneous).""" + if func == DistServer.get_node_types: + return None + if func == DistServer.get_edge_types: + return None + return _mock_request_server(server_rank, func, *args, **kwargs) + + def test_even_split_2_servers_2_compute(self): + """2 servers, 2 compute nodes: each gets one server fully.""" + server_data = { + 0: {"all": torch.arange(10)}, + 1: {"all": torch.arange(10, 20)}, + } + mock_fn = self._make_rank_aware_async_mock(server_data) + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=mock_fn, + ), + patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=self._mock_request_server_homogeneous, + ), + ): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0 gets server 0 fully, server 1 empty + result_0 = remote_dataset.fetch_node_ids( + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assert_tensor_equality(result_0[0], torch.arange(10)) + self.assertEqual(len(result_0[1]), 0) + + # Rank 1 gets server 0 empty, server 1 fully + result_1 = remote_dataset.fetch_node_ids( + rank=1, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assertEqual(len(result_1[0]), 0) + self.assert_tensor_equality(result_1[1], torch.arange(10, 20)) + + def test_fractional_split_3_servers_2_compute(self): + """3 servers, 2 compute nodes: server 1 is split at boundary.""" + 
server_data = { + 0: {"all": torch.arange(10)}, + 1: {"all": torch.arange(10, 20)}, + 2: {"all": torch.arange(20, 30)}, + } + mock_fn = self._make_rank_aware_async_mock(server_data) + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=3) + + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=mock_fn, + ), + patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=self._mock_request_server_homogeneous, + ), + ): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: server 0 fully, server 1 first half, server 2 empty + result_0 = remote_dataset.fetch_node_ids( + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assert_tensor_equality(result_0[0], torch.arange(10)) + # Server 1: 10 * 0 // 2 = 0, 10 * 1 // 2 = 5 → [10, 11, 12, 13, 14] + self.assert_tensor_equality(result_0[1], torch.arange(10, 15)) + self.assertEqual(len(result_0[2]), 0) + + # Rank 1: server 0 empty, server 1 second half, server 2 fully + result_1 = remote_dataset.fetch_node_ids( + rank=1, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assertEqual(len(result_1[0]), 0) + # Server 1: 10 * 1 // 2 = 5, 10 * 2 // 2 = 10 → [15, 16, 17, 18, 19] + self.assert_tensor_equality(result_1[1], torch.arange(15, 20)) + self.assert_tensor_equality(result_1[2], torch.arange(20, 30)) + + def test_with_split_filtering(self): + """CONTIGUOUS strategy with split='train' filtering.""" + server_data = { + 0: {"all": torch.arange(10), "train": torch.tensor([0, 1, 2, 3])}, + 1: {"all": torch.arange(10, 20), "train": torch.tensor([10, 11, 12, 13])}, + } + mock_fn = self._make_rank_aware_async_mock(server_data) + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=mock_fn, + ), + patch( + 
"gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=self._mock_request_server_homogeneous, + ), + ): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + result_0 = remote_dataset.fetch_node_ids( + rank=0, + world_size=2, + split="train", + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assert_tensor_equality(result_0[0], torch.tensor([0, 1, 2, 3])) + self.assertEqual(len(result_0[1]), 0) + + def test_contiguous_requires_rank_and_world_size(self): + """CONTIGUOUS without rank/world_size raises ValueError.""" + cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + with self.assertRaises(ValueError): + remote_dataset.fetch_node_ids( + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + remote_dataset.fetch_node_ids( + rank=0, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + remote_dataset.fetch_node_ids( + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + + def test_contiguous_labeled_homogeneous_auto_inference(self): + """CONTIGUOUS strategy auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" + global _test_server + create_test_process_group() + + positive_labels = { + 0: [0, 1], + 1: [1, 2], + 2: [2, 3], + 3: [3, 4], + 4: [4, 0], + } + negative_labels = { + 0: [2], + 1: [3], + 2: [4], + 3: [0], + 4: [1], + } + edge_indices = { + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + ) + } + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels=positive_labels, + negative_labels=negative_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=edge_indices, + src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + ) + _test_server = 
DistServer(dataset) + dist_server_module._dist_server = _test_server + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ), + patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=_mock_request_server, + ), + ): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # No node_type: should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE + result = remote_dataset.fetch_node_ids( + rank=0, + world_size=1, + split="train", + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assert_tensor_equality( + result[0], + torch.tensor([0, 1, 2]), + ) + + def _make_rank_aware_ablp_async_mock( + self, + server_data: dict[ + int, + dict[ + str, + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], + ], + ], + ): + """Create an async mock that returns different ABLP data per server rank. + + Args: + server_data: Maps server_rank to a dict of + ``{"all": (anchors, pos, neg), "train": (anchors, pos, neg), ...}`` + where ``"all"`` is the full data and split keys are optional. 
+ """ + + def _mock(server_rank, func, *args, **kwargs): + split = kwargs.get("split") + data = server_data[server_rank] + key = split if split is not None and split in data else "all" + future: torch.futures.Future = torch.futures.Future() + future.set_result(data[key]) + return future + + return _mock + + def test_ablp_even_split_2_servers_2_compute(self): + """ABLP CONTIGUOUS: 2 servers, 2 compute nodes — each gets one server fully.""" + neg_0: Optional[torch.Tensor] = torch.tensor([[4], [5], [6]]) + neg_1: Optional[torch.Tensor] = torch.tensor([[14], [15], [16]]) + server_data = { + 0: { + "train": ( + torch.tensor([0, 1, 2]), + torch.tensor([[0, 1], [1, 2], [2, 3]]), + neg_0, + ), + }, + 1: { + "train": ( + torch.tensor([10, 11, 12]), + torch.tensor([[10, 11], [11, 12], [12, 13]]), + neg_1, + ), + }, + } + mock_fn = self._make_rank_aware_ablp_async_mock(server_data) + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=mock_fn, + ), + patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=self._mock_request_server_homogeneous, + ), + ): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0 gets server 0 fully, server 1 empty + result_0 = remote_dataset.fetch_ablp_input( + split="train", + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_0_s0 = result_0[0] + self.assert_tensor_equality(ablp_0_s0.anchor_nodes, torch.tensor([0, 1, 2])) + pos_0, neg_0 = ablp_0_s0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality(pos_0, torch.tensor([[0, 1], [1, 2], [2, 3]])) + assert neg_0 is not None + self.assert_tensor_equality(neg_0, torch.tensor([[4], [5], [6]])) + # Server 1 should be empty for rank 0 + ablp_0_s1 = result_0[1] + self.assertEqual(len(ablp_0_s1.anchor_nodes), 0) + + # Rank 1 gets server 0 empty, server 1 fully + result_1 
= remote_dataset.fetch_ablp_input( + split="train", + rank=1, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_1_s0 = result_1[0] + self.assertEqual(len(ablp_1_s0.anchor_nodes), 0) + ablp_1_s1 = result_1[1] + self.assert_tensor_equality( + ablp_1_s1.anchor_nodes, torch.tensor([10, 11, 12]) + ) + pos_1, neg_1 = ablp_1_s1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos_1, torch.tensor([[10, 11], [11, 12], [12, 13]]) + ) + assert neg_1 is not None + self.assert_tensor_equality(neg_1, torch.tensor([[14], [15], [16]])) + + def test_ablp_fractional_split_3_servers_2_compute(self): + """ABLP CONTIGUOUS: 3 servers, 2 compute nodes — server 1 split at boundary.""" + # Each server has 4 anchors with 2D positive labels and 2D negative labels + neg_s0: Optional[torch.Tensor] = torch.tensor([[10], [11], [12], [13]]) + neg_s1: Optional[torch.Tensor] = torch.tensor([[20], [21], [22], [23]]) + neg_s2: Optional[torch.Tensor] = torch.tensor([[30], [31], [32], [33]]) + server_data = { + 0: { + "train": ( + torch.tensor([0, 1, 2, 3]), + torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), + neg_s0, + ), + }, + 1: { + "train": ( + torch.tensor([10, 11, 12, 13]), + torch.tensor([[10, 11], [11, 12], [12, 13], [13, 14]]), + neg_s1, + ), + }, + 2: { + "train": ( + torch.tensor([20, 21, 22, 23]), + torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), + neg_s2, + ), + }, + } + mock_fn = self._make_rank_aware_ablp_async_mock(server_data) + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=3) + + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=mock_fn, + ), + patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=self._mock_request_server_homogeneous, + ), + ): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: server 0 fully, server 1 first half (2 of 4), server 2 empty + result_0 = 
remote_dataset.fetch_ablp_input( + split="train", + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_0_s0 = result_0[0] + self.assert_tensor_equality( + ablp_0_s0.anchor_nodes, torch.tensor([0, 1, 2, 3]) + ) + # Server 1: 4 * 0 // 2 = 0, 4 * 1 // 2 = 2 → first 2 + ablp_0_s1 = result_0[1] + self.assert_tensor_equality(ablp_0_s1.anchor_nodes, torch.tensor([10, 11])) + pos_0_s1, neg_0_s1 = ablp_0_s1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality(pos_0_s1, torch.tensor([[10, 11], [11, 12]])) + assert neg_0_s1 is not None + self.assert_tensor_equality(neg_0_s1, torch.tensor([[20], [21]])) + ablp_0_s2 = result_0[2] + self.assertEqual(len(ablp_0_s2.anchor_nodes), 0) + + # Rank 1: server 0 empty, server 1 second half (2 of 4), server 2 fully + result_1 = remote_dataset.fetch_ablp_input( + split="train", + rank=1, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_1_s0 = result_1[0] + self.assertEqual(len(ablp_1_s0.anchor_nodes), 0) + # Server 1: 4 * 1 // 2 = 2, 4 * 2 // 2 = 4 → last 2 + ablp_1_s1 = result_1[1] + self.assert_tensor_equality(ablp_1_s1.anchor_nodes, torch.tensor([12, 13])) + pos_1_s1, neg_1_s1 = ablp_1_s1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality(pos_1_s1, torch.tensor([[12, 13], [13, 14]])) + assert neg_1_s1 is not None + self.assert_tensor_equality(neg_1_s1, torch.tensor([[22], [23]])) + ablp_1_s2 = result_1[2] + self.assert_tensor_equality( + ablp_1_s2.anchor_nodes, torch.tensor([20, 21, 22, 23]) + ) + + def test_ablp_contiguous_requires_rank_and_world_size(self): + """ABLP CONTIGUOUS without rank/world_size raises ValueError.""" + cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + with self.assertRaises(ValueError): + remote_dataset.fetch_ablp_input( + split="train", + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + 
remote_dataset.fetch_ablp_input( + split="train", + rank=0, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + remote_dataset.fetch_ablp_input( + split="train", + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + + def _test_fetch_free_ports_on_storage_cluster( rank: int, world_size: int, diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index 603b2dadb..5fa281f69 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -7,6 +7,9 @@ from torch_geometric.typing import EdgeType from gigl.distributed.utils.neighborloader import ( + ServerSlice, + ShardStrategy, + compute_server_assignments, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -490,5 +493,245 @@ def test_set_custom_features_heterogeneous(self): ) +class TestShardStrategy(TestCase): + """Tests for ShardStrategy enum values.""" + + def test_enum_values(self): + self.assertEqual(ShardStrategy.ROUND_ROBIN.value, "round_robin") + self.assertEqual(ShardStrategy.CONTIGUOUS.value, "contiguous") + + +class TestComputeServerAssignments(TestCase): + """Tests for compute_server_assignments and ServerSlice.""" + + def test_even_split_4_servers_2_compute(self): + """4 servers, 2 compute nodes: each gets 2 full servers.""" + # Rank 0: servers 0, 1 + assignments_0 = compute_server_assignments( + num_servers=4, num_compute_nodes=2, compute_rank=0 + ) + self.assertEqual(set(assignments_0.keys()), {0, 1}) + self.assertEqual( + assignments_0[0], + ServerSlice(server_rank=0, start_num=0, start_den=2, end_num=2, end_den=2), + ) + self.assertEqual( + assignments_0[1], + ServerSlice(server_rank=1, start_num=0, start_den=2, end_num=2, end_den=2), + ) + + # Rank 1: servers 2, 3 + assignments_1 = compute_server_assignments( + num_servers=4, num_compute_nodes=2, compute_rank=1 + ) + 
self.assertEqual(set(assignments_1.keys()), {2, 3}) + + def test_fractional_split_3_servers_2_compute(self): + """3 servers, 2 compute nodes: server 1 is split at boundary.""" + # Rank 0: [0, 1.5) → server 0 fully, server 1 first half + assignments_0 = compute_server_assignments( + num_servers=3, num_compute_nodes=2, compute_rank=0 + ) + self.assertEqual(set(assignments_0.keys()), {0, 1}) + # Server 0: full (start_num=0, end_num=2, den=2) + self.assertEqual(assignments_0[0].start_num, 0) + self.assertEqual(assignments_0[0].end_num, 2) + # Server 1: first half (start_num=0, end_num=1, den=2) + self.assertEqual(assignments_0[1].start_num, 0) + self.assertEqual(assignments_0[1].end_num, 1) + + # Rank 1: [1.5, 3) → server 1 second half, server 2 fully + assignments_1 = compute_server_assignments( + num_servers=3, num_compute_nodes=2, compute_rank=1 + ) + self.assertEqual(set(assignments_1.keys()), {1, 2}) + # Server 1: second half (start_num=1, end_num=2, den=2) + self.assertEqual(assignments_1[1].start_num, 1) + self.assertEqual(assignments_1[1].end_num, 2) + # Server 2: full + self.assertEqual(assignments_1[2].start_num, 0) + self.assertEqual(assignments_1[2].end_num, 2) + + def test_1_server_2_compute(self): + """1 server, 2 compute nodes: both share one server.""" + assignments_0 = compute_server_assignments( + num_servers=1, num_compute_nodes=2, compute_rank=0 + ) + self.assertEqual(set(assignments_0.keys()), {0}) + self.assertEqual(assignments_0[0].start_num, 0) + self.assertEqual(assignments_0[0].end_num, 1) + self.assertEqual(assignments_0[0].start_den, 2) + + assignments_1 = compute_server_assignments( + num_servers=1, num_compute_nodes=2, compute_rank=1 + ) + self.assertEqual(set(assignments_1.keys()), {0}) + self.assertEqual(assignments_1[0].start_num, 1) + self.assertEqual(assignments_1[0].end_num, 2) + + def test_more_compute_than_servers(self): + """2 servers, 5 compute nodes: some compute nodes share a server.""" + all_assignments: list[dict[int, 
ServerSlice]] = [] + for rank in range(5): + assignments = compute_server_assignments( + num_servers=2, num_compute_nodes=5, compute_rank=rank + ) + all_assignments.append(assignments) + # Each rank should have at most 2 servers + self.assertLessEqual(len(assignments), 2) + + # Verify recombination invariant for both servers + for server in range(2): + tensor = torch.arange(100) + slices: list[torch.Tensor] = [] + for rank_assignments in all_assignments: + if server in rank_assignments: + slices.append(rank_assignments[server].slice_tensor(tensor)) + combined = torch.cat(slices) + self.assert_tensor_equality(combined, tensor) + + def test_single_compute_gets_all_servers(self): + """1 compute node should get all servers fully.""" + assignments = compute_server_assignments( + num_servers=3, num_compute_nodes=1, compute_rank=0 + ) + self.assertEqual(set(assignments.keys()), {0, 1, 2}) + for s in range(3): + self.assertEqual(assignments[s].start_num, 0) + self.assertEqual(assignments[s].end_num, 1) + self.assertEqual(assignments[s].end_den, 1) + + def test_recombination_invariant_even(self): + """Concatenating all ranks' slices for a server reproduces the original tensor.""" + tensor = torch.arange(20) + for server in range(4): + slices: list[torch.Tensor] = [] + for rank in range(2): + assignments = compute_server_assignments( + num_servers=4, num_compute_nodes=2, compute_rank=rank + ) + if server in assignments: + slices.append(assignments[server].slice_tensor(tensor)) + combined = torch.cat(slices) if slices else torch.empty(0, dtype=torch.long) + # Each server is fully owned by exactly one rank in the even case + if slices: + self.assert_tensor_equality(combined, tensor) + + def test_recombination_invariant_fractional(self): + """Fractional split: concatenating all ranks' slices reproduces the original tensor.""" + tensor = torch.arange(10) + for server in range(3): + slices: list[torch.Tensor] = [] + for rank in range(2): + assignments = 
compute_server_assignments( + num_servers=3, num_compute_nodes=2, compute_rank=rank + ) + if server in assignments: + slices.append(assignments[server].slice_tensor(tensor)) + combined = torch.cat(slices) + self.assert_tensor_equality(combined, tensor) + + def test_validation_negative_servers(self): + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=-1, num_compute_nodes=2, compute_rank=0 + ) + + def test_validation_zero_servers(self): + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=0, num_compute_nodes=2, compute_rank=0 + ) + + def test_validation_negative_compute_nodes(self): + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=2, num_compute_nodes=-1, compute_rank=0 + ) + + def test_validation_zero_compute_nodes(self): + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=2, num_compute_nodes=0, compute_rank=0 + ) + + def test_validation_rank_too_large(self): + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=2, num_compute_nodes=2, compute_rank=2 + ) + + def test_validation_negative_rank(self): + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=2, num_compute_nodes=2, compute_rank=-1 + ) + + +class TestServerSlice(TestCase): + """Tests for ServerSlice.slice_tensor.""" + + def test_full_tensor_no_clone(self): + """Full tensor (start=0, end=total) returns the same object, no clone.""" + tensor = torch.arange(10) + server_slice = ServerSlice( + server_rank=0, start_num=0, start_den=1, end_num=1, end_den=1 + ) + result = server_slice.slice_tensor(tensor) + self.assertTrue(result.data_ptr() == tensor.data_ptr()) + + def test_partial_slice_clones(self): + """Partial slice returns a clone (different data_ptr).""" + tensor = torch.arange(10) + server_slice = ServerSlice( + server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 + ) + result = server_slice.slice_tensor(tensor) + 
self.assert_tensor_equality(result, torch.arange(5)) + self.assertNotEqual(result.data_ptr(), tensor.data_ptr()) + + def test_second_half_slice(self): + """Second half slice works correctly.""" + tensor = torch.arange(10) + server_slice = ServerSlice( + server_rank=0, start_num=1, start_den=2, end_num=2, end_den=2 + ) + result = server_slice.slice_tensor(tensor) + self.assert_tensor_equality(result, torch.arange(5, 10)) + + def test_empty_tensor(self): + """Slicing an empty tensor returns an empty tensor.""" + tensor = torch.empty(0, dtype=torch.long) + server_slice = ServerSlice( + server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 + ) + result = server_slice.slice_tensor(tensor) + self.assertEqual(len(result), 0) + + def test_odd_sized_tensor_fractional(self): + """Odd-sized tensor with fractional split uses integer division correctly.""" + tensor = torch.arange(7) # 7 elements + # First half: 7 * 0 // 2 = 0, 7 * 1 // 2 = 3 + first_half = ServerSlice( + server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 + ) + # Second half: 7 * 1 // 2 = 3, 7 * 2 // 2 = 7 + second_half = ServerSlice( + server_rank=0, start_num=1, start_den=2, end_num=2, end_den=2 + ) + self.assert_tensor_equality(first_half.slice_tensor(tensor), torch.arange(3)) + self.assert_tensor_equality( + second_half.slice_tensor(tensor), torch.arange(3, 7) + ) + # Recombination + combined = torch.cat( + [ + first_half.slice_tensor(tensor), + second_half.slice_tensor(tensor), + ] + ) + self.assert_tensor_equality(combined, tensor) + + if __name__ == "__main__": absltest.main() From b437700620cf43a99e6f2daf7c05e1508523aec1 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 23 Mar 2026 19:26:40 +0000 Subject: [PATCH 02/13] test --- .../distributed/utils/neighborloader_test.py | 322 ++++++------------ 1 file changed, 113 insertions(+), 209 deletions(-) diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index 
dc7d8f4b0..787ec68f8 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -12,7 +12,6 @@ ) from gigl.distributed.utils.neighborloader import ( ServerSlice, - ShardStrategy, compute_server_assignments, extract_edge_type_metadata, extract_metadata, @@ -499,195 +498,140 @@ def test_set_custom_features_heterogeneous(self): ) -class TestShardStrategy(TestCase): - """Tests for ShardStrategy enum values.""" - - def test_enum_values(self): - self.assertEqual(ShardStrategy.ROUND_ROBIN.value, "round_robin") - self.assertEqual(ShardStrategy.CONTIGUOUS.value, "contiguous") - - class TestComputeServerAssignments(TestCase): - """Tests for compute_server_assignments and ServerSlice.""" - - def test_even_split_4_servers_2_compute(self): - """4 servers, 2 compute nodes: each gets 2 full servers.""" - # Rank 0: servers 0, 1 - assignments_0 = compute_server_assignments( - num_servers=4, num_compute_nodes=2, compute_rank=0 - ) - self.assertEqual(set(assignments_0.keys()), {0, 1}) - self.assertEqual( - assignments_0[0], - ServerSlice(server_rank=0, start_num=0, start_den=2, end_num=2, end_den=2), - ) - self.assertEqual( - assignments_0[1], - ServerSlice(server_rank=1, start_num=0, start_den=2, end_num=2, end_den=2), - ) - - # Rank 1: servers 2, 3 - assignments_1 = compute_server_assignments( - num_servers=4, num_compute_nodes=2, compute_rank=1 - ) - self.assertEqual(set(assignments_1.keys()), {2, 3}) - - def test_fractional_split_3_servers_2_compute(self): - """3 servers, 2 compute nodes: server 1 is split at boundary.""" - # Rank 0: [0, 1.5) → server 0 fully, server 1 first half - assignments_0 = compute_server_assignments( - num_servers=3, num_compute_nodes=2, compute_rank=0 - ) - self.assertEqual(set(assignments_0.keys()), {0, 1}) - # Server 0: full (start_num=0, end_num=2, den=2) - self.assertEqual(assignments_0[0].start_num, 0) - self.assertEqual(assignments_0[0].end_num, 2) - # Server 1: first half 
(start_num=0, end_num=1, den=2) - self.assertEqual(assignments_0[1].start_num, 0) - self.assertEqual(assignments_0[1].end_num, 1) - - # Rank 1: [1.5, 3) → server 1 second half, server 2 fully - assignments_1 = compute_server_assignments( - num_servers=3, num_compute_nodes=2, compute_rank=1 - ) - self.assertEqual(set(assignments_1.keys()), {1, 2}) - # Server 1: second half (start_num=1, end_num=2, den=2) - self.assertEqual(assignments_1[1].start_num, 1) - self.assertEqual(assignments_1[1].end_num, 2) - # Server 2: full - self.assertEqual(assignments_1[2].start_num, 0) - self.assertEqual(assignments_1[2].end_num, 2) - - def test_1_server_2_compute(self): - """1 server, 2 compute nodes: both share one server.""" - assignments_0 = compute_server_assignments( - num_servers=1, num_compute_nodes=2, compute_rank=0 - ) - self.assertEqual(set(assignments_0.keys()), {0}) - self.assertEqual(assignments_0[0].start_num, 0) - self.assertEqual(assignments_0[0].end_num, 1) - self.assertEqual(assignments_0[0].start_den, 2) - - assignments_1 = compute_server_assignments( - num_servers=1, num_compute_nodes=2, compute_rank=1 - ) - self.assertEqual(set(assignments_1.keys()), {0}) - self.assertEqual(assignments_1[0].start_num, 1) - self.assertEqual(assignments_1[0].end_num, 2) - - def test_more_compute_than_servers(self): - """2 servers, 5 compute nodes: some compute nodes share a server.""" - all_assignments: list[dict[int, ServerSlice]] = [] - for rank in range(5): - assignments = compute_server_assignments( - num_servers=2, num_compute_nodes=5, compute_rank=rank - ) - all_assignments.append(assignments) - # Each rank should have at most 2 servers - self.assertLessEqual(len(assignments), 2) - - # Verify recombination invariant for both servers - for server in range(2): - tensor = torch.arange(100) - slices: list[torch.Tensor] = [] - for rank_assignments in all_assignments: - if server in rank_assignments: - slices.append(rank_assignments[server].slice_tensor(tensor)) - combined = 
torch.cat(slices) - self.assert_tensor_equality(combined, tensor) - - def test_single_compute_gets_all_servers(self): - """1 compute node should get all servers fully.""" + @parameterized.expand( + [ + param( + "rank_0", + compute_rank=0, + expected_assignments={ + 0: ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + 1: ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + }, + ), + param( + "rank_1", + compute_rank=1, + expected_assignments={ + 1: ServerSlice( + server_rank=1, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + 2: ServerSlice( + server_rank=2, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + }, + ), + ] + ) + def test_fractional_boundary_assignment( + self, _, compute_rank: int, expected_assignments: dict[int, ServerSlice] + ) -> None: assignments = compute_server_assignments( - num_servers=3, num_compute_nodes=1, compute_rank=0 - ) - self.assertEqual(set(assignments.keys()), {0, 1, 2}) - for s in range(3): - self.assertEqual(assignments[s].start_num, 0) - self.assertEqual(assignments[s].end_num, 1) - self.assertEqual(assignments[s].end_den, 1) - - def test_recombination_invariant_even(self): - """Concatenating all ranks' slices for a server reproduces the original tensor.""" - tensor = torch.arange(20) - for server in range(4): - slices: list[torch.Tensor] = [] - for rank in range(2): - assignments = compute_server_assignments( - num_servers=4, num_compute_nodes=2, compute_rank=rank - ) - if server in assignments: - slices.append(assignments[server].slice_tensor(tensor)) - combined = torch.cat(slices) if slices else torch.empty(0, dtype=torch.long) - # Each server is fully owned by exactly one rank in the even case - if slices: - self.assert_tensor_equality(combined, tensor) - - def test_recombination_invariant_fractional(self): - """Fractional split: concatenating all ranks' slices reproduces the original tensor.""" - tensor = 
torch.arange(10) - for server in range(3): - slices: list[torch.Tensor] = [] - for rank in range(2): - assignments = compute_server_assignments( - num_servers=3, num_compute_nodes=2, compute_rank=rank - ) - if server in assignments: - slices.append(assignments[server].slice_tensor(tensor)) - combined = torch.cat(slices) - self.assert_tensor_equality(combined, tensor) - - def test_validation_negative_servers(self): - with self.assertRaises(ValueError): - compute_server_assignments( - num_servers=-1, num_compute_nodes=2, compute_rank=0 - ) + num_servers=3, num_compute_nodes=2, compute_rank=compute_rank + ) - def test_validation_zero_servers(self): - with self.assertRaises(ValueError): - compute_server_assignments( - num_servers=0, num_compute_nodes=2, compute_rank=0 - ) + self.assertEqual(assignments, expected_assignments) - def test_validation_negative_compute_nodes(self): - with self.assertRaises(ValueError): - compute_server_assignments( - num_servers=2, num_compute_nodes=-1, compute_rank=0 - ) - - def test_validation_zero_compute_nodes(self): - with self.assertRaises(ValueError): + def test_assignments_recombine_server_data(self) -> None: + tensor = torch.arange(7) + all_assignments = [ compute_server_assignments( - num_servers=2, num_compute_nodes=0, compute_rank=0 + num_servers=2, num_compute_nodes=5, compute_rank=rank ) + for rank in range(5) + ] - def test_validation_rank_too_large(self): - with self.assertRaises(ValueError): - compute_server_assignments( - num_servers=2, num_compute_nodes=2, compute_rank=2 + for server_rank in range(2): + combined = torch.cat( + [ + assignments[server_rank].slice_tensor(tensor) + for assignments in all_assignments + if server_rank in assignments + ] ) + self.assert_tensor_equality(combined, tensor) - def test_validation_negative_rank(self): + @parameterized.expand( + [ + param( + "negative_servers", + num_servers=-1, + num_compute_nodes=2, + compute_rank=0, + ), + param( + "zero_servers", + num_servers=0, + 
num_compute_nodes=2, + compute_rank=0, + ), + param( + "negative_compute_nodes", + num_servers=2, + num_compute_nodes=-1, + compute_rank=0, + ), + param( + "zero_compute_nodes", + num_servers=2, + num_compute_nodes=0, + compute_rank=0, + ), + param( + "rank_too_large", + num_servers=2, + num_compute_nodes=2, + compute_rank=2, + ), + param( + "negative_rank", + num_servers=2, + num_compute_nodes=2, + compute_rank=-1, + ), + ] + ) + def test_validates_arguments( + self, _, num_servers: int, num_compute_nodes: int, compute_rank: int + ) -> None: with self.assertRaises(ValueError): compute_server_assignments( - num_servers=2, num_compute_nodes=2, compute_rank=-1 + num_servers=num_servers, + num_compute_nodes=num_compute_nodes, + compute_rank=compute_rank, ) class TestServerSlice(TestCase): - """Tests for ServerSlice.slice_tensor.""" - - def test_full_tensor_no_clone(self): - """Full tensor (start=0, end=total) returns the same object, no clone.""" + def test_full_tensor_returns_same_object(self) -> None: tensor = torch.arange(10) server_slice = ServerSlice( server_rank=0, start_num=0, start_den=1, end_num=1, end_den=1 ) result = server_slice.slice_tensor(tensor) - self.assertTrue(result.data_ptr() == tensor.data_ptr()) + self.assertEqual(result.data_ptr(), tensor.data_ptr()) - def test_partial_slice_clones(self): - """Partial slice returns a clone (different data_ptr).""" + def test_partial_slice_clones_requested_range(self) -> None: tensor = torch.arange(10) server_slice = ServerSlice( server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 @@ -696,47 +640,7 @@ def test_partial_slice_clones(self): self.assert_tensor_equality(result, torch.arange(5)) self.assertNotEqual(result.data_ptr(), tensor.data_ptr()) - def test_second_half_slice(self): - """Second half slice works correctly.""" - tensor = torch.arange(10) - server_slice = ServerSlice( - server_rank=0, start_num=1, start_den=2, end_num=2, end_den=2 - ) - result = server_slice.slice_tensor(tensor) - 
self.assert_tensor_equality(result, torch.arange(5, 10)) - - def test_empty_tensor(self): - """Slicing an empty tensor returns an empty tensor.""" - tensor = torch.empty(0, dtype=torch.long) - server_slice = ServerSlice( - server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 - ) - result = server_slice.slice_tensor(tensor) - self.assertEqual(len(result), 0) - def test_odd_sized_tensor_fractional(self): - """Odd-sized tensor with fractional split uses integer division correctly.""" - tensor = torch.arange(7) # 7 elements - # First half: 7 * 0 // 2 = 0, 7 * 1 // 2 = 3 - first_half = ServerSlice( - server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 - ) - # Second half: 7 * 1 // 2 = 3, 7 * 2 // 2 = 7 - second_half = ServerSlice( - server_rank=0, start_num=1, start_den=2, end_num=2, end_den=2 - ) - self.assert_tensor_equality(first_half.slice_tensor(tensor), torch.arange(3)) - self.assert_tensor_equality( - second_half.slice_tensor(tensor), torch.arange(3, 7) - ) - # Recombination - combined = torch.cat( - [ - first_half.slice_tensor(tensor), - second_half.slice_tensor(tensor), - ] - ) - self.assert_tensor_equality(combined, tensor) class ExtractMetadataTest(TestCase): def setUp(self): self._device = torch.device("cpu") From 9dc44a4b8b76a53cb299dbb80565a13aa7ec4074 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 23 Mar 2026 21:11:14 +0000 Subject: [PATCH 03/13] update --- .../graph_store/remote_dist_dataset_test.py | 468 +++++++----------- 1 file changed, 171 insertions(+), 297 deletions(-) diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index 7b0af4d20..d6eda651c 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -1,4 +1,5 @@ -from typing import Optional +from contextlib import contextmanager +from typing import Final, Optional from unittest.mock import 
patch import torch @@ -12,6 +13,7 @@ from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ShardStrategy from gigl.env.distributed import GraphStoreInfo +from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, DEFAULT_HOMOGENEOUS_NODE_TYPE, @@ -39,6 +41,25 @@ # Module-level test server instance used by mock functions _test_server: Optional[DistServer] = None +# Shared ABLP label data used by multiple test classes +_DEFAULT_POSITIVE_LABELS: Final[dict[int, list[int]]] = { + 0: [0, 1], + 1: [1, 2], + 2: [2, 3], + 3: [3, 4], + 4: [4, 0], +} +_DEFAULT_NEGATIVE_LABELS: Final[dict[int, list[int]]] = { + 0: [2], + 1: [3], + 2: [4], + 3: [0], + 4: [1], +} +_DEFAULT_TRAIN_IDS: Final[list[int]] = [0, 1, 2] +_DEFAULT_VAL_IDS: Final[list[int]] = [3] +_DEFAULT_TEST_IDS: Final[list[int]] = [4] + def _mock_request_server(server_rank, func, *args, **kwargs): """Mock request_server that routes through _call_func_on_server.""" @@ -76,7 +97,76 @@ def _create_mock_graph_store_info( return MockGraphStoreInfo(real_info, compute_node_rank) -class TestRemoteDistDataset(TestCase): +@contextmanager +def _patch_remote_requests(async_side_effect, sync_side_effect): + """Context manager that patches both async_request_server and request_server.""" + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=async_side_effect, + ), + patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=sync_side_effect, + ), + ): + yield + + +def _create_server_with_splits( + edge_indices: Optional[dict] = None, + src_node_type: Optional[NodeType] = None, + dst_node_type: Optional[NodeType] = None, + supervision_edge_type: Optional[EdgeType] = None, +) -> None: + """Create a DistServer with a dataset that has train/val/test splits. + + Args: + edge_indices: Edge indices to use. 
Defaults to DEFAULT_HETEROGENEOUS_EDGE_INDICES. + src_node_type: Source node type for labeled homogeneous datasets. + dst_node_type: Destination node type for labeled homogeneous datasets. + supervision_edge_type: Supervision edge type for labeled homogeneous datasets. + """ + global _test_server + create_test_process_group() + + kwargs: dict = {} + if src_node_type is not None: + kwargs.update( + src_node_type=src_node_type, + dst_node_type=dst_node_type, + supervision_edge_type=supervision_edge_type, + ) + + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels=_DEFAULT_POSITIVE_LABELS, + negative_labels=_DEFAULT_NEGATIVE_LABELS, + train_node_ids=_DEFAULT_TRAIN_IDS, + val_node_ids=_DEFAULT_VAL_IDS, + test_node_ids=_DEFAULT_TEST_IDS, + edge_indices=edge_indices or DEFAULT_HETEROGENEOUS_EDGE_INDICES, + **kwargs, + ) + _test_server = DistServer(dataset) + dist_server_module._dist_server = _test_server + + +class RemoteDistDatasetTestBase(TestCase): + """Shared tearDown for all RemoteDistDataset test classes.""" + + def tearDown(self) -> None: + global _test_server + _test_server = None + dist_server_module._dist_server = None + if dist.is_initialized(): + dist.destroy_process_group() + + +@patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=_mock_request_server, +) +class TestRemoteDistDataset(RemoteDistDatasetTestBase): def setUp(self) -> None: global _test_server # 10 nodes in DEFAULT_HOMOGENEOUS_EDGE_INDEX ring graph @@ -88,17 +178,6 @@ def setUp(self) -> None: _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_graph_metadata_getters_homogeneous(self, mock_request): """Test 
fetch_node_feature_info, fetch_edge_feature_info, fetch_edge_dir, fetch_edge_types, fetch_node_types for homogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -113,7 +192,7 @@ def test_graph_metadata_getters_homogeneous(self, mock_request): self.assertIsNone(remote_dataset.fetch_edge_types()) self.assertIsNone(remote_dataset.fetch_node_types()) - def test_init_rejects_non_dict_proxy_for_mp_sharing_dict(self): + def test_init_rejects_non_dict_proxy_for_mp_sharing_dict(self, mock_request): cluster_info = _create_mock_graph_store_info() with self.assertRaises(ValueError): @@ -123,7 +202,7 @@ def test_init_rejects_non_dict_proxy_for_mp_sharing_dict(self): mp_sharing_dict=dict(), # Regular dict should fail ) - def test_init_rejects_non_barrier_for_mp_barrier(self): + def test_init_rejects_non_barrier_for_mp_barrier(self, mock_request): cluster_info = _create_mock_graph_store_info() with self.assertRaises(ValueError): @@ -133,7 +212,7 @@ def test_init_rejects_non_barrier_for_mp_barrier(self): mp_sharing_dict=mp.Manager().dict(), ) - def test_cluster_info_property(self): + def test_cluster_info_property(self, mock_request): cluster_info = _create_mock_graph_store_info( num_storage_nodes=3, num_compute_nodes=2 ) @@ -146,10 +225,6 @@ def test_cluster_info_property(self): "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, ) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_ids(self, mock_request, mock_async_request): """Test fetch_node_ids returns node ids, with optional sharding via rank/world_size.""" cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) @@ -168,10 +243,6 @@ def test_fetch_node_ids(self, mock_request, mock_async_request): result = remote_dataset.fetch_node_ids(rank=1, world_size=2) self.assert_tensor_equality(result[0], torch.arange(5, 10)) - @patch( - 
"gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_homogeneous(self, mock_request): """Test fetch_node_partition_book returns the tensor partition book for homogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -183,10 +254,6 @@ def test_fetch_node_partition_book_homogeneous(self, mock_request): self.assertEqual(result.shape[0], 10) self.assert_tensor_equality(result, torch.zeros(10, dtype=torch.int64)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_edge_partition_book_homogeneous(self, mock_request): """Test fetch_edge_partition_book returns the tensor partition book for homogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -198,10 +265,6 @@ def test_fetch_edge_partition_book_homogeneous(self, mock_request): self.assertEqual(result.shape[0], 10) self.assert_tensor_equality(result, torch.zeros(10, dtype=torch.int64)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_homogeneous_rejects_node_type( self, mock_request ): @@ -213,7 +276,11 @@ def test_fetch_node_partition_book_homogeneous_rejects_node_type( remote_dataset.fetch_node_partition_book(node_type=USER) -class TestRemoteDistDatasetHeterogeneous(TestCase): +@patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=_mock_request_server, +) +class TestRemoteDistDatasetHeterogeneous(RemoteDistDatasetTestBase): def setUp(self) -> None: global _test_server # 5 users, 5 stories in DEFAULT_HETEROGENEOUS_EDGE_INDICES @@ -228,17 +295,6 @@ def setUp(self) -> None: _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if 
dist.is_initialized(): - dist.destroy_process_group() - - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_graph_metadata_getters_heterogeneous(self, mock_request): """Test fetch_node_feature_info, fetch_edge_dir, fetch_edge_types, fetch_node_types for heterogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -264,7 +320,7 @@ def test_graph_metadata_getters_heterogeneous(self, mock_request): "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, ) - def test_fetch_node_ids_with_node_type(self, mock_async_request): + def test_fetch_node_ids_with_node_type(self, mock_request, mock_async_request): """Test fetch_node_ids with node_type for heterogeneous graphs, with optional sharding.""" cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -285,10 +341,6 @@ def test_fetch_node_ids_with_node_type(self, mock_async_request): result = remote_dataset.fetch_node_ids(rank=1, world_size=2, node_type=USER) self.assert_tensor_equality(result[0], torch.arange(2, 5)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_heterogeneous(self, mock_request): """Test fetch_node_partition_book returns per-type partition books for heterogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -306,10 +358,6 @@ def test_fetch_node_partition_book_heterogeneous(self, mock_request): self.assertEqual(story_pb.shape[0], 5) self.assert_tensor_equality(story_pb, torch.zeros(5, dtype=torch.int64)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_edge_partition_book_heterogeneous(self, mock_request): """Test fetch_edge_partition_book returns per-type partition 
books for heterogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -328,10 +376,6 @@ def test_fetch_edge_partition_book_heterogeneous(self, mock_request): ), ) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_heterogeneous_requires_node_type( self, mock_request ): @@ -342,10 +386,6 @@ def test_fetch_node_partition_book_heterogeneous_requires_node_type( with self.assertRaises(ValueError): remote_dataset.fetch_node_partition_book() - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_edge_partition_book_heterogeneous_requires_edge_type( self, mock_request ): @@ -357,54 +397,16 @@ def test_fetch_edge_partition_book_heterogeneous_requires_edge_type( remote_dataset.fetch_edge_partition_book() -class TestRemoteDistDatasetWithSplits(TestCase): +class TestRemoteDistDatasetWithSplits(RemoteDistDatasetTestBase): """Tests for fetch_node_ids with train/val/test splits.""" - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - - def _create_server_with_splits(self) -> None: - """Create a DistServer with a dataset that has train/val/test splits.""" - global _test_server - create_test_process_group() - - positive_labels = { - 0: [0, 1], - 1: [1, 2], - 2: [2, 3], - 3: [3, 4], - 4: [4, 0], - } - negative_labels = { - 0: [2], - 1: [3], - 2: [4], - 3: [0], - 4: [1], - } - - dataset = create_heterogeneous_dataset_for_ablp( - positive_labels=positive_labels, - negative_labels=negative_labels, - train_node_ids=[0, 1, 2], - val_node_ids=[3], - test_node_ids=[4], - edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, - ) - _test_server = DistServer(dataset) - dist_server_module._dist_server = _test_server - @patch( 
"gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, ) def test_fetch_node_ids_with_splits(self, mock_async_request): """Test fetch_node_ids with train/val/test splits and optional sharding.""" - self._create_server_with_splits() + _create_server_with_splits() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -449,7 +451,7 @@ def test_fetch_node_ids_with_splits(self, mock_async_request): ) def test_fetch_ablp_input(self, mock_async_request): """Test fetch_ablp_input with train/val/test splits.""" - self._create_server_with_splits() + _create_server_with_splits() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -517,7 +519,7 @@ def test_fetch_ablp_input(self, mock_async_request): ) def test_fetch_ablp_input_with_sharding(self, mock_async_request): """Test fetch_ablp_input with sharding across compute nodes.""" - self._create_server_with_splits() + _create_server_with_splits() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -571,7 +573,7 @@ def test_fetch_ablp_input_with_sharding(self, mock_async_request): ) -class TestRemoteDistDatasetLabeledHomogeneous(TestCase): +class TestRemoteDistDatasetLabeledHomogeneous(RemoteDistDatasetTestBase): """Tests for datasets using DEFAULT_HOMOGENEOUS_NODE_TYPE / DEFAULT_HOMOGENEOUS_EDGE_TYPE. A 'labeled homogeneous' dataset is stored internally as heterogeneous @@ -580,50 +582,19 @@ class TestRemoteDistDatasetLabeledHomogeneous(TestCase): callers do not need to supply them explicitly. 
""" - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - - def _create_server_with_labeled_homogeneous_splits(self) -> None: - global _test_server - create_test_process_group() - - positive_labels = { - 0: [0, 1], - 1: [1, 2], - 2: [2, 3], - 3: [3, 4], - 4: [4, 0], - } - negative_labels = { - 0: [2], - 1: [3], - 2: [4], - 3: [0], - 4: [1], - } - edge_indices = { - DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( - [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] - ) - } + _LABELED_HOMOGENEOUS_EDGE_INDICES: Final = { + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + ) + } - dataset = create_heterogeneous_dataset_for_ablp( - positive_labels=positive_labels, - negative_labels=negative_labels, - train_node_ids=[0, 1, 2], - val_node_ids=[3], - test_node_ids=[4], - edge_indices=edge_indices, + def _create_labeled_homogeneous_server(self) -> None: + _create_server_with_splits( + edge_indices=self._LABELED_HOMOGENEOUS_EDGE_INDICES, src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, ) - _test_server = DistServer(dataset) - dist_server_module._dist_server = _test_server @patch( "gigl.distributed.graph_store.remote_dist_dataset.request_server", @@ -631,7 +602,7 @@ def _create_server_with_labeled_homogeneous_splits(self) -> None: ) def test_fetch_node_types_labeled_homogeneous(self, mock_request): """Test fetch_node_types returns DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info() remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -639,35 +610,29 @@ def test_fetch_node_types_labeled_homogeneous(self, mock_request): self.assertIsNotNone(node_types) 
self.assertIn(DEFAULT_HOMOGENEOUS_NODE_TYPE, node_types) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=_mock_async_request_server, - ) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) - def test_fetch_node_ids_auto_detects_default_node_type( - self, mock_request, mock_async_request - ): + def test_fetch_node_ids_auto_detects_default_node_type(self): """Test fetch_node_ids without node_type auto-detects DEFAULT_HOMOGENEOUS_NODE_TYPE.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) - # No node_type provided: _fetch_node_ids should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE - self.assert_tensor_equality( - remote_dataset.fetch_node_ids(split="train")[0], - torch.tensor([0, 1, 2]), - ) - self.assert_tensor_equality( - remote_dataset.fetch_node_ids(split="val")[0], - torch.tensor([3]), - ) - self.assert_tensor_equality( - remote_dataset.fetch_node_ids(split="test")[0], - torch.tensor([4]), - ) + with _patch_remote_requests(_mock_async_request_server, _mock_request_server): + remote_dataset = RemoteDistDataset( + cluster_info=cluster_info, local_rank=0 + ) + + # No node_type provided: _fetch_node_ids should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE + self.assert_tensor_equality( + remote_dataset.fetch_node_ids(split="train")[0], + torch.tensor([0, 1, 2]), + ) + self.assert_tensor_equality( + remote_dataset.fetch_node_ids(split="val")[0], + torch.tensor([3]), + ) + self.assert_tensor_equality( + remote_dataset.fetch_node_ids(split="test")[0], + torch.tensor([4]), + ) @patch( "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", @@ -675,7 +640,7 @@ def test_fetch_node_ids_auto_detects_default_node_type( ) def 
test_fetch_ablp_input_defaults_to_homogeneous_types(self, mock_async_request): """Test fetch_ablp_input without anchor_node_type/supervision_edge_type uses homogeneous defaults.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -706,7 +671,7 @@ def test_fetch_node_partition_book_auto_infers_default_node_type( self, mock_request ): """Test fetch_node_partition_book auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE when None.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info() remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -725,7 +690,7 @@ def test_fetch_edge_partition_book_auto_infers_default_edge_type( self, mock_request ): """Test fetch_edge_partition_book auto-infers DEFAULT_HOMOGENEOUS_EDGE_TYPE when None.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info() remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -756,16 +721,9 @@ def test_fetch_ablp_input_mismatched_params_raises(self): ) -class TestRemoteDistDatasetContiguous(TestCase): +class TestRemoteDistDatasetContiguous(RemoteDistDatasetTestBase): """Tests for fetch_node_ids and fetch_ablp_input with ShardStrategy.CONTIGUOUS.""" - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - def _make_rank_aware_async_mock( self, server_data: dict[int, dict[str, torch.Tensor]] ): @@ -803,20 +761,12 @@ def test_even_split_2_servers_2_compute(self): 1: {"all": torch.arange(10, 20)}, } mock_fn = self._make_rank_aware_async_mock(server_data) - cluster_info = 
_create_mock_graph_store_info(num_storage_nodes=2) - with ( - patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=mock_fn, - ), - patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=self._mock_request_server_homogeneous, - ), - ): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + remote_dataset = RemoteDistDataset( + cluster_info=cluster_info, local_rank=0 + ) # Rank 0 gets server 0 fully, server 1 empty result_0 = remote_dataset.fetch_node_ids( @@ -844,20 +794,12 @@ def test_fractional_split_3_servers_2_compute(self): 2: {"all": torch.arange(20, 30)}, } mock_fn = self._make_rank_aware_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=3) - with ( - patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=mock_fn, - ), - patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=self._mock_request_server_homogeneous, - ), - ): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + remote_dataset = RemoteDistDataset( + cluster_info=cluster_info, local_rank=0 + ) # Rank 0: server 0 fully, server 1 first half, server 2 empty result_0 = remote_dataset.fetch_node_ids( @@ -888,20 +830,12 @@ def test_with_split_filtering(self): 1: {"all": torch.arange(10, 20), "train": torch.tensor([10, 11, 12, 13])}, } mock_fn = self._make_rank_aware_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) - with ( - patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=mock_fn, - ), - patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=self._mock_request_server_homogeneous, - ), 
- ): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + remote_dataset = RemoteDistDataset( + cluster_info=cluster_info, local_rank=0 + ) result_0 = remote_dataset.fetch_node_ids( rank=0, @@ -934,55 +868,23 @@ def test_contiguous_requires_rank_and_world_size(self): def test_contiguous_labeled_homogeneous_auto_inference(self): """CONTIGUOUS strategy auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" - global _test_server - create_test_process_group() - - positive_labels = { - 0: [0, 1], - 1: [1, 2], - 2: [2, 3], - 3: [3, 4], - 4: [4, 0], - } - negative_labels = { - 0: [2], - 1: [3], - 2: [4], - 3: [0], - 4: [1], - } - edge_indices = { - DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( - [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] - ) - } - dataset = create_heterogeneous_dataset_for_ablp( - positive_labels=positive_labels, - negative_labels=negative_labels, - train_node_ids=[0, 1, 2], - val_node_ids=[3], - test_node_ids=[4], - edge_indices=edge_indices, + _create_server_with_splits( + edge_indices={ + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + ) + }, src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, ) - _test_server = DistServer(dataset) - dist_server_module._dist_server = _test_server cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) - with ( - patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=_mock_async_request_server, - ), - patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ), - ): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + with _patch_remote_requests(_mock_async_request_server, _mock_request_server): + remote_dataset = RemoteDistDataset( + 
cluster_info=cluster_info, local_rank=0 + ) # No node_type: should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE result = remote_dataset.fetch_node_ids( @@ -1045,20 +947,12 @@ def test_ablp_even_split_2_servers_2_compute(self): }, } mock_fn = self._make_rank_aware_ablp_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) - with ( - patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=mock_fn, - ), - patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=self._mock_request_server_homogeneous, - ), - ): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + remote_dataset = RemoteDistDataset( + cluster_info=cluster_info, local_rank=0 + ) # Rank 0 gets server 0 fully, server 1 empty result_0 = remote_dataset.fetch_ablp_input( @@ -1127,20 +1021,12 @@ def test_ablp_fractional_split_3_servers_2_compute(self): }, } mock_fn = self._make_rank_aware_ablp_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=3) - with ( - patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=mock_fn, - ), - patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=self._mock_request_server_homogeneous, - ), - ): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + remote_dataset = RemoteDistDataset( + cluster_info=cluster_info, local_rank=0 + ) # Rank 0: server 0 fully, server 1 first half (2 of 4), server 2 empty result_0 = remote_dataset.fetch_ablp_input( @@ -1250,20 +1136,13 @@ def _test_fetch_free_ports_on_storage_cluster( dist.destroy_process_group() -class TestGetFreePortsOnStorageCluster(TestCase): +class 
TestGetFreePortsOnStorageCluster(RemoteDistDatasetTestBase): def setUp(self) -> None: global _test_server dataset = create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - def test_fetch_free_ports_on_storage_cluster_distributed(self): """Test that free ports are correctly broadcast across all ranks.""" init_method = get_process_group_init_method() @@ -1286,7 +1165,7 @@ def test_fetch_free_ports_fails_without_process_group(self): remote_dataset.fetch_free_ports_on_storage_cluster(num_ports=1) -class TestCallFuncOnServer(TestCase): +class TestCallFuncOnServer(RemoteDistDatasetTestBase): """Tests for the _call_func_on_server dispatch logic.""" def setUp(self) -> None: @@ -1299,11 +1178,6 @@ def setUp(self) -> None: _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - def test_dispatches_server_method(self): """Test that _call_func_on_server correctly dispatches an unbound DistServer method.""" result = _call_func_on_server(DistServer.get_edge_dir) From 1de868a2ee7c1c8dbb24a5cf47482498e2786a3d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 23 Mar 2026 23:58:29 +0000 Subject: [PATCH 04/13] maybe cleanup --- gigl/distributed/dist_ablp_neighborloader.py | 6 +- .../graph_store/remote_dist_dataset.py | 330 ++++++------ gigl/distributed/graph_store/sharding.py | 72 +++ gigl/distributed/utils/__init__.py | 2 - gigl/distributed/utils/neighborloader.py | 178 ------- .../graph_store_integration_test.py | 3 +- .../dist_ablp_neighborloader_test.py | 55 ++ .../graph_store/remote_dist_dataset_test.py | 485 ++++++++++-------- 
.../distributed/graph_store/sharding_test.py | 150 ++++++ .../distributed/utils/neighborloader_test.py | 145 ------ 10 files changed, 717 insertions(+), 709 deletions(-) create mode 100644 gigl/distributed/graph_store/sharding.py create mode 100644 tests/unit/distributed/graph_store/sharding_test.py diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 498087794..945194c9f 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -699,7 +699,11 @@ def _setup_for_graph_store( # Extract supervision edge types and derive label edge types from the # ABLPInputNodes.labels dict (keyed by supervision edge type). self._supervision_edge_types = list(first_input.labels.keys()) - has_negatives = any(neg is not None for _, neg in first_input.labels.values()) + has_negatives = any( + negative_labels is not None + for ablp_input in input_nodes.values() + for _, negative_labels in ablp_input.labels.values() + ) self._positive_label_edge_types = [ message_passing_to_positive_label(et) for et in self._supervision_edge_types diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index 88899d464..1c37f91ce 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -6,7 +6,8 @@ from gigl.common.logger import Logger from gigl.distributed.graph_store.compute import async_request_server, request_server from gigl.distributed.graph_store.dist_server import DistServer -from gigl.distributed.utils.neighborloader import ( +from gigl.distributed.graph_store.sharding import ( + ServerSlice, ShardStrategy, compute_server_assignments, ) @@ -163,65 +164,78 @@ def _infer_edge_type_if_homogeneous_with_label_edges( ) return edge_type + def _validate_contiguous_args( + self, + rank: Optional[int], + world_size: Optional[int], + shard_strategy: ShardStrategy, + ) 
-> tuple[Optional[int], Optional[int]]: + """Validate contiguous sharding inputs and preserve round-robin inputs unchanged.""" + if shard_strategy != ShardStrategy.CONTIGUOUS: + return rank, world_size + + if rank is None or world_size is None: + raise ValueError( + "Both rank and world_size must be provided when using " + f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" + ) + if world_size != self.cluster_info.num_compute_nodes: + raise ValueError( + "ShardStrategy.CONTIGUOUS expects world_size to equal " + "cluster_info.num_compute_nodes. " + f"Got world_size={world_size}, " + f"cluster_info.num_compute_nodes={self.cluster_info.num_compute_nodes}" + ) + return rank, world_size + + def _compute_assignments_if_needed( + self, + rank: Optional[int], + world_size: Optional[int], + shard_strategy: ShardStrategy, + ) -> Optional[dict[int, ServerSlice]]: + """Compute contiguous server assignments when that shard strategy is requested.""" + if shard_strategy != ShardStrategy.CONTIGUOUS: + return None + + assert rank is not None and world_size is not None + return compute_server_assignments( + num_servers=self.cluster_info.num_storage_nodes, + num_compute_nodes=world_size, + compute_rank=rank, + ) + def _fetch_node_ids( self, rank: Optional[int] = None, world_size: Optional[int] = None, node_type: Optional[NodeType] = None, split: Optional[Literal["train", "val", "test"]] = None, + assignments: Optional[dict[int, ServerSlice]] = None, ) -> dict[int, torch.Tensor]: """Fetches node ids from the storage nodes for the current compute node (machine).""" - futures: list[torch.futures.Future[torch.Tensor]] = [] node_type = self._infer_node_type_if_homogeneous_with_label_edges(node_type) - logger.info( - f"Getting node ids for rank {rank} / {world_size} with node type {node_type} and split {split}" - ) - - for server_rank in range(self.cluster_info.num_storage_nodes): - futures.append( - async_request_server( - server_rank, - DistServer.get_node_ids, - 
rank=rank, - world_size=world_size, - split=split, - node_type=node_type, - ) + if assignments is None: + logger.info( + f"Getting node ids for rank {rank} / {world_size} with node type {node_type} and split {split}" ) + futures: list[torch.futures.Future[torch.Tensor]] = [] + for server_rank in range(self.cluster_info.num_storage_nodes): + futures.append( + async_request_server( + server_rank, + DistServer.get_node_ids, + rank=rank, + world_size=world_size, + split=split, + node_type=node_type, + ) + ) node_ids = torch.futures.wait_all(futures) - return {server_rank: node_ids for server_rank, node_ids in enumerate(node_ids)} - - def _fetch_node_ids_by_server( - self, - rank: int, - world_size: int, - node_type: Optional[NodeType] = None, - split: Optional[Literal["train", "val", "test"]] = None, - ) -> dict[int, torch.Tensor]: - """Fetches node ids using contiguous server assignment. - - Each compute node is assigned a contiguous range of servers. Only - assigned servers are RPCed; unassigned servers get empty tensors. - Boundary servers are sliced fractionally when servers don't divide - evenly across compute nodes. - - Args: - rank: The rank of the compute node requesting node ids. - world_size: The total number of compute nodes. - node_type: The type of nodes to get. Must be provided for heterogeneous datasets. - split: The split of the dataset to get node ids from. - - Returns: - A dict mapping every server rank to a tensor of node ids. 
- """ - node_type = self._infer_node_type_if_homogeneous_with_label_edges(node_type) - - assignments = compute_server_assignments( - num_servers=self.cluster_info.num_storage_nodes, - num_compute_nodes=world_size, - compute_rank=rank, - ) + return { + server_rank: node_ids for server_rank, node_ids in enumerate(node_ids) + } logger.info( f"Getting node ids via CONTIGUOUS strategy for rank {rank} / {world_size} " @@ -229,10 +243,9 @@ def _fetch_node_ids_by_server( f"Assigned servers: {list(assignments.keys())}" ) - # RPC only assigned servers (fetch ALL nodes, no server-side sharding) - futures: dict[int, torch.futures.Future[torch.Tensor]] = {} + assigned_futures: dict[int, torch.futures.Future[torch.Tensor]] = {} for server_rank in assignments: - futures[server_rank] = async_request_server( + assigned_futures[server_rank] = async_request_server( server_rank, DistServer.get_node_ids, rank=None, @@ -241,15 +254,13 @@ def _fetch_node_ids_by_server( node_type=node_type, ) - # Build result: slice assigned servers, empty tensors for unassigned result: dict[int, torch.Tensor] = {} for server_rank in range(self.cluster_info.num_storage_nodes): - if server_rank in futures: - all_nodes = futures[server_rank].wait() + if server_rank in assigned_futures: + all_nodes = assigned_futures[server_rank].wait() result[server_rank] = assignments[server_rank].slice_tensor(all_nodes) else: result[server_rank] = torch.empty(0, dtype=torch.long) - return result def fetch_node_ids( @@ -267,26 +278,32 @@ def fetch_node_ids( filtered and sharded according to the provided arguments. Args: - rank (Optional[int]): The rank of the process requesting node ids. Must be provided if world_size is provided. - world_size (Optional[int]): The total number of processes in the distributed setup. Must be provided if rank is provided. + rank (Optional[int]): The requested shard rank. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + expects the compute-node rank. 
+ world_size (Optional[int]): The requested shard world size. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + requires ``world_size == cluster_info.num_compute_nodes``. split (Optional[Literal["train", "val", "test"]]): The split of the dataset to get node ids from. If provided, the dataset must have `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. node_type (Optional[NodeType]): The type of nodes to get. Must be provided for heterogeneous datasets. shard_strategy (ShardStrategy): Strategy for sharding node IDs across compute nodes. - ``ROUND_ROBIN`` (default) shards each server's nodes across all compute nodes. - ``CONTIGUOUS`` assigns entire servers to compute nodes, producing empty tensors - for unassigned servers. ``CONTIGUOUS`` requires both ``rank`` and ``world_size``. + ``ROUND_ROBIN`` (default) shards each server's nodes across the + requested rank/world_size on the storage server. ``CONTIGUOUS`` + assigns storage servers to compute nodes, returning empty tensors + for unassigned servers. Raises: - ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None, + or if ``world_size`` does not match ``cluster_info.num_compute_nodes``. Returns: dict[int, torch.Tensor]: A dict mapping storage rank to node ids. Examples: - See :class:`~gigl.distributed.utils.neighborloader.ShardStrategy` for + See :class:`~gigl.distributed.graph_store.sharding.ShardStrategy` for concrete examples of how each strategy distributes node IDs across compute nodes. @@ -295,14 +312,23 @@ def fetch_node_ids( (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split. 
""" - if shard_strategy == ShardStrategy.CONTIGUOUS: - if rank is None or world_size is None: - raise ValueError( - "Both rank and world_size must be provided when using " - f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" - ) - return self._fetch_node_ids_by_server(rank, world_size, node_type, split) - return self._fetch_node_ids(rank, world_size, node_type, split) + rank, world_size = self._validate_contiguous_args( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) + assignments = self._compute_assignments_if_needed( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) + return self._fetch_node_ids( + rank=rank, + world_size=world_size, + node_type=node_type, + split=split, + assignments=assignments, + ) def fetch_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: """ @@ -349,67 +375,37 @@ def _fetch_ablp_input( world_size: Optional[int] = None, node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, + assignments: Optional[dict[int, ServerSlice]] = None, ) -> dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: """Fetches ABLP input from the storage nodes for the current compute node (machine).""" - futures: list[ - torch.futures.Future[ - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] - ] - ] = [] - logger.info( - f"Getting ABLP input for rank {rank} / {world_size} with node type {node_type}, " - f"split {split}, and supervision edge type {supervision_edge_type}" - ) + if assignments is None: + futures: list[ + torch.futures.Future[ + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] + ] = [] + logger.info( + f"Getting ABLP input for rank {rank} / {world_size} with node type {node_type}, " + f"split {split}, and supervision edge type {supervision_edge_type}" + ) - for server_rank in range(self.cluster_info.num_storage_nodes): - futures.append( - async_request_server( - 
server_rank, - DistServer.get_ablp_input, - split=split, - rank=rank, - world_size=world_size, - node_type=node_type, - supervision_edge_type=supervision_edge_type, + for server_rank in range(self.cluster_info.num_storage_nodes): + futures.append( + async_request_server( + server_rank, + DistServer.get_ablp_input, + split=split, + rank=rank, + world_size=world_size, + node_type=node_type, + supervision_edge_type=supervision_edge_type, + ) ) - ) ablp_inputs = torch.futures.wait_all(futures) - return { - server_rank: ablp_input - for server_rank, ablp_input in enumerate(ablp_inputs) - } - - def _fetch_ablp_input_by_server( - self, - split: Literal["train", "val", "test"], - rank: int, - world_size: int, - node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, - supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, - ) -> dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: - """Fetches ABLP input using contiguous server assignment. - - Each compute node is assigned a contiguous range of servers. Only - assigned servers are RPCed; unassigned servers get empty tensors. - Boundary servers are sliced fractionally when servers don't divide - evenly across compute nodes. - - Args: - split: The split of the dataset to get ABLP input from. - rank: The rank of the compute node requesting ABLP input. - world_size: The total number of compute nodes. - node_type: The type of anchor nodes to retrieve. - supervision_edge_type: The edge type for supervision. - - Returns: - A dict mapping every server rank to a tuple of - (anchors, positive_labels, negative_labels). 
- """ - assignments = compute_server_assignments( - num_servers=self.cluster_info.num_storage_nodes, - num_compute_nodes=world_size, - compute_rank=rank, - ) + return { + server_rank: ablp_input + for server_rank, ablp_input in enumerate(ablp_inputs) + } logger.info( f"Getting ABLP input via CONTIGUOUS strategy for rank {rank} / {world_size} " @@ -418,15 +414,14 @@ def _fetch_ablp_input_by_server( f"Assigned servers: {list(assignments.keys())}" ) - # RPC only assigned servers (fetch ALL ABLP data, no server-side sharding) - futures: dict[ + assigned_futures: dict[ int, torch.futures.Future[ tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] ], ] = {} for server_rank in assignments: - futures[server_rank] = async_request_server( + assigned_futures[server_rank] = async_request_server( server_rank, DistServer.get_ablp_input, split=split, @@ -436,13 +431,14 @@ def _fetch_ablp_input_by_server( supervision_edge_type=supervision_edge_type, ) - # Build result: slice assigned servers, empty tensors for unassigned result: dict[ int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] ] = {} for server_rank in range(self.cluster_info.num_storage_nodes): - if server_rank in futures: - anchors, positive_labels, negative_labels = futures[server_rank].wait() + if server_rank in assigned_futures: + anchors, positive_labels, negative_labels = assigned_futures[ + server_rank + ].wait() server_slice = assignments[server_rank] sliced_anchors = server_slice.slice_tensor(anchors) sliced_positive = server_slice.slice_tensor(positive_labels) @@ -459,10 +455,9 @@ def _fetch_ablp_input_by_server( else: result[server_rank] = ( torch.empty(0, dtype=torch.long), - torch.empty(0, dtype=torch.long), + torch.empty((0, 0), dtype=torch.long), None, ) - return result # TODO(#488) - support multiple supervision edge types @@ -488,10 +483,12 @@ def fetch_ablp_input( Args: split (Literal["train", "val", "test"]): The split to get the input for. 
- rank (Optional[int]): The rank of the process requesting the input. - Must be provided if world_size is provided. - world_size (Optional[int]): The total number of processes in the distributed setup. - Must be provided if rank is provided. + rank (Optional[int]): The requested shard rank. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + expects the compute-node rank. + world_size (Optional[int]): The requested shard world size. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + requires ``world_size == cluster_info.num_compute_nodes``. anchor_node_type (Optional[NodeType]): The type of the anchor nodes to retrieve. Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs. @@ -501,35 +498,34 @@ def fetch_ablp_input( Must be None for labeled homogeneous graphs. Defaults to None. shard_strategy (ShardStrategy): Strategy for sharding ABLP input across compute - nodes. ``ROUND_ROBIN`` (default) shards each server's data across all compute - nodes. ``CONTIGUOUS`` assigns entire servers to compute nodes, producing empty - tensors for unassigned servers. ``CONTIGUOUS`` requires both ``rank`` and - ``world_size``. + nodes. ``ROUND_ROBIN`` (default) shards each server's data across the + requested rank/world_size on the storage server. ``CONTIGUOUS`` assigns + storage servers to compute nodes, producing empty tensors for unassigned + servers. Returns: dict[int, ABLPInputNodes]: A dict mapping storage rank to an ABLPInputNodes containing: - - anchor_node_type: The node type of the anchor nodes, or None for labeled homogeneous. + - anchor_node_type: The node type of the anchor nodes, or ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` for labeled homogeneous. - anchor_nodes: 1D tensor of anchor node IDs for the split. - positive_labels: Dict mapping positive label EdgeType to a 2D tensor [N, M]. - negative_labels: Optional dict mapping negative label EdgeType to a 2D tensor [N, M]. 
Raises: - ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None, + or if ``world_size`` does not match ``cluster_info.num_compute_nodes``. Examples: - See :class:`~gigl.distributed.utils.neighborloader.ShardStrategy` for + See :class:`~gigl.distributed.graph_store.sharding.ShardStrategy` for concrete examples of how each strategy distributes data across compute nodes. """ - - if shard_strategy == ShardStrategy.CONTIGUOUS: - if rank is None or world_size is None: - raise ValueError( - "Both rank and world_size must be provided when using " - f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" - ) + rank, world_size = self._validate_contiguous_args( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) if (anchor_node_type is None) != (supervision_edge_type is None): raise ValueError( @@ -546,23 +542,19 @@ def fetch_ablp_input( evaluated_supervision_edge_type = supervision_edge_type del anchor_node_type, supervision_edge_type - if shard_strategy == ShardStrategy.CONTIGUOUS: - assert rank is not None and world_size is not None - raw_inputs = self._fetch_ablp_input_by_server( - split=split, - rank=rank, - world_size=world_size, - node_type=evaluated_anchor_node_type, - supervision_edge_type=evaluated_supervision_edge_type, - ) - else: - raw_inputs = self._fetch_ablp_input( - split=split, - rank=rank, - world_size=world_size, - node_type=evaluated_anchor_node_type, - supervision_edge_type=evaluated_supervision_edge_type, - ) + assignments = self._compute_assignments_if_needed( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) + raw_inputs = self._fetch_ablp_input( + split=split, + rank=rank, + world_size=world_size, + node_type=evaluated_anchor_node_type, + supervision_edge_type=evaluated_supervision_edge_type, + assignments=assignments, + ) return { server_rank: ABLPInputNodes( 
anchor_node_type=evaluated_anchor_node_type, diff --git a/gigl/distributed/graph_store/sharding.py b/gigl/distributed/graph_store/sharding.py new file mode 100644 index 000000000..5d0fd5913 --- /dev/null +++ b/gigl/distributed/graph_store/sharding.py @@ -0,0 +1,72 @@ +"""Graph-store-specific sharding helpers.""" + +from dataclasses import dataclass +from enum import Enum + +import torch + + +class ShardStrategy(Enum): + """Strategies for splitting remote graph-store inputs across compute nodes.""" + + ROUND_ROBIN = "round_robin" + CONTIGUOUS = "contiguous" + + +@dataclass(frozen=True) +class ServerSlice: + """The fraction of a storage server owned by one compute node.""" + + server_rank: int + start_num: int + start_den: int + end_num: int + end_den: int + + def slice_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + """Slice a tensor according to this server assignment.""" + total = len(tensor) + start_idx = total * self.start_num // self.start_den + end_idx = total * self.end_num // self.end_den + if start_idx == 0 and end_idx == total: + return tensor + return tensor[start_idx:end_idx].clone() + + +def compute_server_assignments( + num_servers: int, + num_compute_nodes: int, + compute_rank: int, +) -> dict[int, ServerSlice]: + """Compute which servers, and which fractions of them, belong to one compute node.""" + if num_servers <= 0: + raise ValueError(f"num_servers must be positive, got {num_servers}") + if num_compute_nodes <= 0: + raise ValueError(f"num_compute_nodes must be positive, got {num_compute_nodes}") + if compute_rank < 0 or compute_rank >= num_compute_nodes: + raise ValueError( + f"compute_rank must be in [0, {num_compute_nodes}), got {compute_rank}" + ) + + seg_start = compute_rank * num_servers + seg_end = (compute_rank + 1) * num_servers + + assignments: dict[int, ServerSlice] = {} + for server_rank in range(num_servers): + server_start = server_rank * num_compute_nodes + server_end = (server_rank + 1) * num_compute_nodes + + overlap_start = 
max(seg_start, server_start) + overlap_end = min(seg_end, server_end) + if overlap_start >= overlap_end: + continue + + assignments[server_rank] = ServerSlice( + server_rank=server_rank, + start_num=overlap_start - server_start, + start_den=num_compute_nodes, + end_num=overlap_end - server_start, + end_den=num_compute_nodes, + ) + + return assignments diff --git a/gigl/distributed/utils/__init__.py b/gigl/distributed/utils/__init__.py index 63aeb8f50..04a986658 100644 --- a/gigl/distributed/utils/__init__.py +++ b/gigl/distributed/utils/__init__.py @@ -4,7 +4,6 @@ __all__ = [ "GraphStoreInfo", - "ShardStrategy", "get_available_device", "get_free_port", "get_free_ports", @@ -25,7 +24,6 @@ get_process_group_name, init_neighbor_loader_worker, ) -from .neighborloader import ShardStrategy from .networking import ( GraphStoreInfo, get_free_port, diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index cc589d46b..1e6198499 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -28,184 +28,6 @@ class SamplingClusterSetup(Enum): GRAPH_STORE = "graph_store" -class ShardStrategy(Enum): - """Strategy for sharding node IDs across compute nodes. - - Controls how data from storage servers is distributed to compute nodes. - Both strategies produce the same total coverage (every node appears on - exactly one compute node), but differ in which servers each compute node - communicates with. - - Attributes: - ROUND_ROBIN: Each compute node gets a slice of nodes from every server. - Server-side sharding via rank/world_size. This is the current default. - CONTIGUOUS: Assign entire servers to compute nodes. Each compute node - only gets nodes from its assigned servers, with empty tensors for - the rest. Boundary servers are split fractionally when servers - don't divide evenly across compute nodes. 
- - Examples: - **2 storage nodes, 2 compute nodes** (even split): - - Suppose each server holds 10 node IDs:: - - Server 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - Server 1: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] - - ``ROUND_ROBIN`` — every compute node gets a slice from *every* server:: - - Compute 0 (rank=0, world_size=2): - {0: [0,1,2,3,4], 1: [10,11,12,13,14]} - Compute 1 (rank=1, world_size=2): - {0: [5,6,7,8,9], 1: [15,16,17,18,19]} - - ``CONTIGUOUS`` — each compute node gets *entire* servers:: - - Compute 0 (rank=0, world_size=2): - {0: [0,1,2,3,4,5,6,7,8,9], 1: []} # all of server 0 - Compute 1 (rank=1, world_size=2): - {0: [], 1: [10,11,12,13,14,15,16,17,18,19]} # all of server 1 - - **3 storage nodes, 2 compute nodes** (fractional boundary): - - Server 1 is split at the boundary — compute 0 gets the first half, - compute 1 gets the second half:: - - Server 0: [0..9], Server 1: [10..19], Server 2: [20..29] - - Compute 0 (rank=0): {0: [0..9], 1: [10..14], 2: []} - Compute 1 (rank=1): {0: [], 1: [15..19], 2: [20..29]} - - See Also: - :func:`compute_server_assignments` for the assignment algorithm. - """ - - ROUND_ROBIN = "round_robin" - CONTIGUOUS = "contiguous" - - -@dataclass(frozen=True) -class ServerSlice: - """A compute node's ownership of a single server's nodes. - - Fractions are represented as exact rationals (numerator, denominator) - to avoid floating-point boundary errors. For a server with N nodes, - the slice is ``tensor[N * start_num // start_den : N * end_num // end_den]``. - - Args: - server_rank: The rank of the storage server. - start_num: Numerator of the start fraction. - start_den: Denominator of the start fraction. - end_num: Numerator of the end fraction. - end_den: Denominator of the end fraction. - """ - - server_rank: int - start_num: int - start_den: int - end_num: int - end_den: int - - def slice_tensor(self, tensor: torch.Tensor) -> torch.Tensor: - """Slice a 1D tensor according to this assignment's rational bounds. 
- - Uses integer division (N * num // den) for exact, deterministic - index computation. Returns a ``.clone()`` for partial slices to avoid - retaining full backing storage when used with ``share_memory_()``. - - Args: - tensor: A 1D tensor of node IDs from the server. - - Returns: - The sliced portion of the tensor. - """ - total = len(tensor) - start_idx = total * self.start_num // self.start_den - end_idx = total * self.end_num // self.end_den - if start_idx == 0 and end_idx == total: - return tensor - return tensor[start_idx:end_idx].clone() - - -def compute_server_assignments( - num_servers: int, - num_compute_nodes: int, - compute_rank: int, -) -> dict[int, ServerSlice]: - """Compute which servers (and what fraction) a compute node owns. - - Uses integer arithmetic throughout. Compute rank R owns the server - range ``[R * S / C, (R+1) * S / C)`` where boundaries are rational - numbers with denominator C. For each server s in ``[0, S)``, the overlap - with this range determines the ServerSlice fractions. - - Only servers with non-zero overlap are included in the returned dict. - - Args: - num_servers: Total number of storage servers (S). - num_compute_nodes: Total number of compute nodes (C). - compute_rank: Rank of the current compute node (R). - - Returns: - A dict mapping server rank to the ``ServerSlice`` describing the - fraction of that server owned by this compute node. - - Raises: - ValueError: If any argument is invalid (negative values, - rank >= num_compute_nodes, or zero servers/compute nodes). 
- - Examples: - >>> compute_server_assignments(num_servers=4, num_compute_nodes=2, compute_rank=0) - {0: ServerSlice(server_rank=0, ...), 1: ServerSlice(server_rank=1, ...)} - - >>> compute_server_assignments(num_servers=3, num_compute_nodes=2, compute_rank=1) - {1: ServerSlice(server_rank=1, ...), 2: ServerSlice(server_rank=2, ...)} - """ - if num_servers <= 0: - raise ValueError(f"num_servers must be positive, got {num_servers}") - if num_compute_nodes <= 0: - raise ValueError(f"num_compute_nodes must be positive, got {num_compute_nodes}") - if compute_rank < 0 or compute_rank >= num_compute_nodes: - raise ValueError( - f"compute_rank must be in [0, {num_compute_nodes}), got {compute_rank}" - ) - - S = num_servers - C = num_compute_nodes - R = compute_rank - - # Segment boundaries (as numerators with denominator C): - # start = R * S, end = (R + 1) * S - seg_start = R * S - seg_end = (R + 1) * S - - assignments: dict[int, ServerSlice] = {} - for s in range(S): - # Server s spans [s * C, (s + 1) * C) in numerator-space with denominator C - server_start = s * C - server_end = (s + 1) * C - - overlap_start = max(seg_start, server_start) - overlap_end = min(seg_end, server_end) - - if overlap_start >= overlap_end: - continue - - # Fraction of server s: [(overlap_start - s*C) / C, (overlap_end - s*C) / C) - start_num = overlap_start - server_start - end_num = overlap_end - server_start - - assignments[s] = ServerSlice( - server_rank=s, - start_num=start_num, - start_den=C, - end_num=end_num, - end_den=C, - ) - - return assignments - - @dataclass(frozen=True) class DatasetSchema: """ diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 85c9de2e6..9deef0155 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -23,11 +23,12 @@ 
shutdown_compute_proccess, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.graph_store.sharding import ShardStrategy from gigl.distributed.graph_store.storage_utils import ( build_storage_dataset, run_storage_server, ) -from gigl.distributed.utils.neighborloader import ShardStrategy, shard_nodes_by_process +from gigl.distributed.utils.neighborloader import shard_nodes_by_process from gigl.distributed.utils.networking import get_free_port, get_free_ports from gigl.distributed.utils.partition_book import build_partition_book, get_ids_on_rank from gigl.env.distributed import ( diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index fd14aadd5..c73465733 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -1,6 +1,7 @@ import unittest from collections import defaultdict from typing import Literal, Optional, Union +from unittest.mock import Mock, patch import torch import torch.multiprocessing as mp @@ -28,6 +29,7 @@ ) from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, GraphPartitionData, PartitionOutput, message_passing_to_negative_label, @@ -36,6 +38,7 @@ to_homogeneous, ) from gigl.utils.data_splitters import DistNodeAnchorLinkSplitter +from gigl.utils.sampling import ABLPInputNodes from tests.test_assets.distributed.utils import ( assert_tensor_equality, create_test_process_group, @@ -424,6 +427,58 @@ def tearDown(self): torch.distributed.destroy_process_group() super().tearDown() + def test_graph_store_setup_detects_negatives_across_all_servers(self) -> None: + loader = DistABLPLoader.__new__(DistABLPLoader) + loader._instance_count = 0 + loader._shutdowned = True + + dataset = Mock() + dataset.cluster_info.num_storage_nodes = 2 + dataset.cluster_info.compute_cluster_world_size = 2 + 
dataset.cluster_info.storage_cluster_master_ip = "127.0.0.1" + dataset.fetch_node_feature_info.return_value = None + dataset.fetch_edge_feature_info.return_value = None + dataset.fetch_edge_types.return_value = None + dataset.fetch_free_ports_on_storage_cluster.return_value = [12345, 12346] + + input_nodes = { + 0: ABLPInputNodes( + anchor_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + anchor_nodes=torch.empty(0, dtype=torch.long), + labels={ + DEFAULT_HOMOGENEOUS_EDGE_TYPE: ( + torch.empty((0, 0), dtype=torch.long), + None, + ) + }, + ), + 1: ABLPInputNodes( + anchor_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + anchor_nodes=torch.tensor([1]), + labels={ + DEFAULT_HOMOGENEOUS_EDGE_TYPE: ( + torch.tensor([[2]]), + torch.tensor([[3]]), + ) + }, + ), + } + + with patch("torch.distributed.get_rank", return_value=0): + input_data, _, _ = loader._setup_for_graph_store( + input_nodes=input_nodes, + dataset=dataset, + num_workers=1, + ) + + self.assertEqual(loader._negative_label_edge_types, [_NEGATIVE_EDGE_TYPE]) + self.assertEqual(input_data[0].negative_label_by_edge_types, {}) + self.assertIn(_NEGATIVE_EDGE_TYPE, input_data[1].negative_label_by_edge_types) + self.assert_tensor_equality( + input_data[0].positive_label_by_edge_types[_POSITIVE_EDGE_TYPE], + torch.empty((0, 0), dtype=torch.long), + ) + @parameterized.expand( [ param( diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index b7ef0b173..7f56c89a8 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -1,17 +1,18 @@ from contextlib import contextmanager -from typing import Final, Optional +from typing import Final, Literal, Optional from unittest.mock import patch import torch import torch.distributed as dist import torch.multiprocessing as mp from absl.testing import absltest +from parameterized import param, parameterized import 
gigl.distributed.graph_store.dist_server as dist_server_module from gigl.common import LocalUri from gigl.distributed.graph_store.dist_server import DistServer, _call_func_on_server from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.utils.neighborloader import ShardStrategy +from gigl.distributed.graph_store.sharding import ShardStrategy from gigl.env.distributed import GraphStoreInfo from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import ( @@ -563,9 +564,7 @@ class TestRemoteDistDatasetLabeledHomogeneous(RemoteDistDatasetTestBase): """ _LABELED_HOMOGENEOUS_EDGE_INDICES: Final = { - DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( - [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] - ) + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) } def _create_labeled_homogeneous_server(self) -> None: @@ -596,9 +595,7 @@ def test_fetch_node_ids_auto_detects_default_node_type(self): cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) with _patch_remote_requests(_mock_async_request_server, _mock_request_server): - remote_dataset = RemoteDistDataset( - cluster_info=cluster_info, local_rank=0 - ) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) # No node_type provided: _fetch_node_ids should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE self.assert_tensor_equality( @@ -734,74 +731,87 @@ def _mock_request_server_homogeneous(server_rank, func, *args, **kwargs): return None return _mock_request_server(server_rank, func, *args, **kwargs) - def test_even_split_2_servers_2_compute(self): - """2 servers, 2 compute nodes: each gets one server fully.""" - server_data = { - 0: {"all": torch.arange(10)}, - 1: {"all": torch.arange(10, 20)}, - } - mock_fn = self._make_rank_aware_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) - - with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - 
remote_dataset = RemoteDistDataset( - cluster_info=cluster_info, local_rank=0 - ) - - # Rank 0 gets server 0 fully, server 1 empty - result_0 = remote_dataset.fetch_node_ids( - rank=0, - world_size=2, - shard_strategy=ShardStrategy.CONTIGUOUS, - ) - self.assert_tensor_equality(result_0[0], torch.arange(10)) - self.assertEqual(len(result_0[1]), 0) - - # Rank 1 gets server 0 empty, server 1 fully - result_1 = remote_dataset.fetch_node_ids( - rank=1, - world_size=2, + def _assert_contiguous_node_ids( + self, + remote_dataset: RemoteDistDataset, + world_size: int, + expected_by_rank: dict[int, dict[int, torch.Tensor]], + split: Optional[Literal["train", "val", "test"]] = None, + ) -> None: + for rank, expected_by_server in expected_by_rank.items(): + result = remote_dataset.fetch_node_ids( + rank=rank, + world_size=world_size, + split=split, shard_strategy=ShardStrategy.CONTIGUOUS, ) - self.assertEqual(len(result_1[0]), 0) - self.assert_tensor_equality(result_1[1], torch.arange(10, 20)) - - def test_fractional_split_3_servers_2_compute(self): - """3 servers, 2 compute nodes: server 1 is split at boundary.""" - server_data = { - 0: {"all": torch.arange(10)}, - 1: {"all": torch.arange(10, 20)}, - 2: {"all": torch.arange(20, 30)}, - } + self.assertEqual(set(result.keys()), set(expected_by_server.keys())) + for server_rank, expected_tensor in expected_by_server.items(): + self.assert_tensor_equality(result[server_rank], expected_tensor) + + @parameterized.expand( + [ + param( + "even_split", + num_storage_nodes=2, + server_data={ + 0: {"all": torch.arange(10)}, + 1: {"all": torch.arange(10, 20)}, + }, + expected_by_rank={ + 0: { + 0: torch.arange(10), + 1: torch.empty(0, dtype=torch.long), + }, + 1: { + 0: torch.empty(0, dtype=torch.long), + 1: torch.arange(10, 20), + }, + }, + ), + param( + "fractional_split", + num_storage_nodes=3, + server_data={ + 0: {"all": torch.arange(10)}, + 1: {"all": torch.arange(10, 20)}, + 2: {"all": torch.arange(20, 30)}, + }, + 
expected_by_rank={ + 0: { + 0: torch.arange(10), + 1: torch.arange(10, 15), + 2: torch.empty(0, dtype=torch.long), + }, + 1: { + 0: torch.empty(0, dtype=torch.long), + 1: torch.arange(15, 20), + 2: torch.arange(20, 30), + }, + }, + ), + ] + ) + def test_fetch_node_ids_contiguous( + self, + _, + num_storage_nodes: int, + server_data: dict[int, dict[str, torch.Tensor]], + expected_by_rank: dict[int, dict[int, torch.Tensor]], + ) -> None: mock_fn = self._make_rank_aware_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=3) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=num_storage_nodes, + num_compute_nodes=2, + ) with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - remote_dataset = RemoteDistDataset( - cluster_info=cluster_info, local_rank=0 - ) - - # Rank 0: server 0 fully, server 1 first half, server 2 empty - result_0 = remote_dataset.fetch_node_ids( - rank=0, - world_size=2, - shard_strategy=ShardStrategy.CONTIGUOUS, - ) - self.assert_tensor_equality(result_0[0], torch.arange(10)) - # Server 1: 10 * 0 // 2 = 0, 10 * 1 // 2 = 5 → [10, 11, 12, 13, 14] - self.assert_tensor_equality(result_0[1], torch.arange(10, 15)) - self.assertEqual(len(result_0[2]), 0) - - # Rank 1: server 0 empty, server 1 second half, server 2 fully - result_1 = remote_dataset.fetch_node_ids( - rank=1, + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + self._assert_contiguous_node_ids( + remote_dataset=remote_dataset, world_size=2, - shard_strategy=ShardStrategy.CONTIGUOUS, + expected_by_rank=expected_by_rank, ) - self.assertEqual(len(result_1[0]), 0) - # Server 1: 10 * 1 // 2 = 5, 10 * 2 // 2 = 10 → [15, 16, 17, 18, 19] - self.assert_tensor_equality(result_1[1], torch.arange(15, 20)) - self.assert_tensor_equality(result_1[2], torch.arange(20, 30)) def test_with_split_filtering(self): """CONTIGUOUS strategy with split='train' filtering.""" @@ -810,12 +820,13 @@ def 
test_with_split_filtering(self): 1: {"all": torch.arange(10, 20), "train": torch.tensor([10, 11, 12, 13])}, } mock_fn = self._make_rank_aware_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, + num_compute_nodes=2, + ) with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - remote_dataset = RemoteDistDataset( - cluster_info=cluster_info, local_rank=0 - ) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) result_0 = remote_dataset.fetch_node_ids( rank=0, @@ -828,7 +839,10 @@ def test_with_split_filtering(self): def test_contiguous_requires_rank_and_world_size(self): """CONTIGUOUS without rank/world_size raises ValueError.""" - cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, + num_compute_nodes=2, + ) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) with self.assertRaises(ValueError): @@ -845,6 +859,12 @@ def test_contiguous_requires_rank_and_world_size(self): world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS, ) + with self.assertRaises(ValueError): + remote_dataset.fetch_node_ids( + rank=0, + world_size=3, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) def test_contiguous_labeled_homogeneous_auto_inference(self): """CONTIGUOUS strategy auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" @@ -862,9 +882,7 @@ def test_contiguous_labeled_homogeneous_auto_inference(self): cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) with _patch_remote_requests(_mock_async_request_server, _mock_request_server): - remote_dataset = RemoteDistDataset( - cluster_info=cluster_info, local_rank=0 - ) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) # No node_type: should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE result = 
remote_dataset.fetch_node_ids( @@ -906,153 +924,187 @@ def _mock(server_rank, func, *args, **kwargs): return _mock - def test_ablp_even_split_2_servers_2_compute(self): - """ABLP CONTIGUOUS: 2 servers, 2 compute nodes — each gets one server fully.""" - neg_0: Optional[torch.Tensor] = torch.tensor([[4], [5], [6]]) - neg_1: Optional[torch.Tensor] = torch.tensor([[14], [15], [16]]) - server_data = { - 0: { - "train": ( - torch.tensor([0, 1, 2]), - torch.tensor([[0, 1], [1, 2], [2, 3]]), - neg_0, - ), - }, - 1: { - "train": ( - torch.tensor([10, 11, 12]), - torch.tensor([[10, 11], [11, 12], [12, 13]]), - neg_1, - ), - }, - } - mock_fn = self._make_rank_aware_ablp_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) - - with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - remote_dataset = RemoteDistDataset( - cluster_info=cluster_info, local_rank=0 - ) - - # Rank 0 gets server 0 fully, server 1 empty - result_0 = remote_dataset.fetch_ablp_input( - split="train", - rank=0, - world_size=2, - shard_strategy=ShardStrategy.CONTIGUOUS, - ) - ablp_0_s0 = result_0[0] - self.assert_tensor_equality(ablp_0_s0.anchor_nodes, torch.tensor([0, 1, 2])) - pos_0, neg_0 = ablp_0_s0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality(pos_0, torch.tensor([[0, 1], [1, 2], [2, 3]])) - assert neg_0 is not None - self.assert_tensor_equality(neg_0, torch.tensor([[4], [5], [6]])) - # Server 1 should be empty for rank 0 - ablp_0_s1 = result_0[1] - self.assertEqual(len(ablp_0_s1.anchor_nodes), 0) - - # Rank 1 gets server 0 empty, server 1 fully - result_1 = remote_dataset.fetch_ablp_input( + def _assert_contiguous_ablp_inputs( + self, + remote_dataset: RemoteDistDataset, + world_size: int, + expected_by_rank: dict[ + int, + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + ], + ) -> None: + for rank, expected_by_server in expected_by_rank.items(): + result = remote_dataset.fetch_ablp_input( 
split="train", - rank=1, - world_size=2, + rank=rank, + world_size=world_size, shard_strategy=ShardStrategy.CONTIGUOUS, ) - ablp_1_s0 = result_1[0] - self.assertEqual(len(ablp_1_s0.anchor_nodes), 0) - ablp_1_s1 = result_1[1] - self.assert_tensor_equality( - ablp_1_s1.anchor_nodes, torch.tensor([10, 11, 12]) - ) - pos_1, neg_1 = ablp_1_s1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality( - pos_1, torch.tensor([[10, 11], [11, 12], [12, 13]]) - ) - assert neg_1 is not None - self.assert_tensor_equality(neg_1, torch.tensor([[14], [15], [16]])) - - def test_ablp_fractional_split_3_servers_2_compute(self): - """ABLP CONTIGUOUS: 3 servers, 2 compute nodes — server 1 split at boundary.""" - # Each server has 4 anchors with 2D positive labels and 2D negative labels - neg_s0: Optional[torch.Tensor] = torch.tensor([[10], [11], [12], [13]]) - neg_s1: Optional[torch.Tensor] = torch.tensor([[20], [21], [22], [23]]) - neg_s2: Optional[torch.Tensor] = torch.tensor([[30], [31], [32], [33]]) - server_data = { - 0: { - "train": ( - torch.tensor([0, 1, 2, 3]), - torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), - neg_s0, - ), - }, - 1: { - "train": ( - torch.tensor([10, 11, 12, 13]), - torch.tensor([[10, 11], [11, 12], [12, 13], [13, 14]]), - neg_s1, - ), - }, - 2: { - "train": ( - torch.tensor([20, 21, 22, 23]), - torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), - neg_s2, - ), - }, - } + self.assertEqual(set(result.keys()), set(expected_by_server.keys())) + for server_rank, ( + expected_anchors, + expected_positive, + expected_negative, + ) in expected_by_server.items(): + ablp_input = result[server_rank] + self.assert_tensor_equality(ablp_input.anchor_nodes, expected_anchors) + positive, negative = ablp_input.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality(positive, expected_positive) + if expected_negative is None: + self.assertIsNone(negative) + else: + assert negative is not None + self.assert_tensor_equality(negative, 
expected_negative) + + @parameterized.expand( + [ + param( + "even_split", + num_storage_nodes=2, + server_data={ + 0: { + "train": ( + torch.tensor([0, 1, 2]), + torch.tensor([[0, 1], [1, 2], [2, 3]]), + torch.tensor([[4], [5], [6]]), + ), + }, + 1: { + "train": ( + torch.tensor([10, 11, 12]), + torch.tensor([[10, 11], [11, 12], [12, 13]]), + torch.tensor([[14], [15], [16]]), + ), + }, + }, + expected_by_rank={ + 0: { + 0: ( + torch.tensor([0, 1, 2]), + torch.tensor([[0, 1], [1, 2], [2, 3]]), + torch.tensor([[4], [5], [6]]), + ), + 1: ( + torch.empty(0, dtype=torch.long), + torch.empty((0, 0), dtype=torch.long), + None, + ), + }, + 1: { + 0: ( + torch.empty(0, dtype=torch.long), + torch.empty((0, 0), dtype=torch.long), + None, + ), + 1: ( + torch.tensor([10, 11, 12]), + torch.tensor([[10, 11], [11, 12], [12, 13]]), + torch.tensor([[14], [15], [16]]), + ), + }, + }, + ), + param( + "fractional_split", + num_storage_nodes=3, + server_data={ + 0: { + "train": ( + torch.tensor([0, 1, 2, 3]), + torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), + torch.tensor([[10], [11], [12], [13]]), + ), + }, + 1: { + "train": ( + torch.tensor([10, 11, 12, 13]), + torch.tensor([[10, 11], [11, 12], [12, 13], [13, 14]]), + torch.tensor([[20], [21], [22], [23]]), + ), + }, + 2: { + "train": ( + torch.tensor([20, 21, 22, 23]), + torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), + torch.tensor([[30], [31], [32], [33]]), + ), + }, + }, + expected_by_rank={ + 0: { + 0: ( + torch.tensor([0, 1, 2, 3]), + torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), + torch.tensor([[10], [11], [12], [13]]), + ), + 1: ( + torch.tensor([10, 11]), + torch.tensor([[10, 11], [11, 12]]), + torch.tensor([[20], [21]]), + ), + 2: ( + torch.empty(0, dtype=torch.long), + torch.empty((0, 0), dtype=torch.long), + None, + ), + }, + 1: { + 0: ( + torch.empty(0, dtype=torch.long), + torch.empty((0, 0), dtype=torch.long), + None, + ), + 1: ( + torch.tensor([12, 13]), + torch.tensor([[12, 13], [13, 14]]), + 
torch.tensor([[22], [23]]), + ), + 2: ( + torch.tensor([20, 21, 22, 23]), + torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), + torch.tensor([[30], [31], [32], [33]]), + ), + }, + }, + ), + ] + ) + def test_fetch_ablp_input_contiguous( + self, + _, + num_storage_nodes: int, + server_data: dict[ + int, + dict[ + str, + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], + ], + ], + expected_by_rank: dict[ + int, + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + ], + ) -> None: mock_fn = self._make_rank_aware_ablp_async_mock(server_data) - cluster_info = _create_mock_graph_store_info(num_storage_nodes=3) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=num_storage_nodes, + num_compute_nodes=2, + ) with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - remote_dataset = RemoteDistDataset( - cluster_info=cluster_info, local_rank=0 - ) - - # Rank 0: server 0 fully, server 1 first half (2 of 4), server 2 empty - result_0 = remote_dataset.fetch_ablp_input( - split="train", - rank=0, + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + self._assert_contiguous_ablp_inputs( + remote_dataset=remote_dataset, world_size=2, - shard_strategy=ShardStrategy.CONTIGUOUS, - ) - ablp_0_s0 = result_0[0] - self.assert_tensor_equality( - ablp_0_s0.anchor_nodes, torch.tensor([0, 1, 2, 3]) - ) - # Server 1: 4 * 0 // 2 = 0, 4 * 1 // 2 = 2 → first 2 - ablp_0_s1 = result_0[1] - self.assert_tensor_equality(ablp_0_s1.anchor_nodes, torch.tensor([10, 11])) - pos_0_s1, neg_0_s1 = ablp_0_s1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality(pos_0_s1, torch.tensor([[10, 11], [11, 12]])) - assert neg_0_s1 is not None - self.assert_tensor_equality(neg_0_s1, torch.tensor([[20], [21]])) - ablp_0_s2 = result_0[2] - self.assertEqual(len(ablp_0_s2.anchor_nodes), 0) - - # Rank 1: server 0 empty, server 1 second half (2 of 4), server 2 fully - result_1 = remote_dataset.fetch_ablp_input( - 
split="train", - rank=1, - world_size=2, - shard_strategy=ShardStrategy.CONTIGUOUS, - ) - ablp_1_s0 = result_1[0] - self.assertEqual(len(ablp_1_s0.anchor_nodes), 0) - # Server 1: 4 * 1 // 2 = 2, 4 * 2 // 2 = 4 → last 2 - ablp_1_s1 = result_1[1] - self.assert_tensor_equality(ablp_1_s1.anchor_nodes, torch.tensor([12, 13])) - pos_1_s1, neg_1_s1 = ablp_1_s1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality(pos_1_s1, torch.tensor([[12, 13], [13, 14]])) - assert neg_1_s1 is not None - self.assert_tensor_equality(neg_1_s1, torch.tensor([[22], [23]])) - ablp_1_s2 = result_1[2] - self.assert_tensor_equality( - ablp_1_s2.anchor_nodes, torch.tensor([20, 21, 22, 23]) + expected_by_rank=expected_by_rank, ) def test_ablp_contiguous_requires_rank_and_world_size(self): """ABLP CONTIGUOUS without rank/world_size raises ValueError.""" - cluster_info = _create_mock_graph_store_info(num_storage_nodes=2) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, + num_compute_nodes=2, + ) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) with self.assertRaises(ValueError): @@ -1072,6 +1124,13 @@ def test_ablp_contiguous_requires_rank_and_world_size(self): world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS, ) + with self.assertRaises(ValueError): + remote_dataset.fetch_ablp_input( + split="train", + rank=0, + world_size=3, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) def _test_fetch_free_ports_on_storage_cluster( diff --git a/tests/unit/distributed/graph_store/sharding_test.py b/tests/unit/distributed/graph_store/sharding_test.py new file mode 100644 index 000000000..7fcf69493 --- /dev/null +++ b/tests/unit/distributed/graph_store/sharding_test.py @@ -0,0 +1,150 @@ +import torch +from parameterized import param, parameterized + +from gigl.distributed.graph_store.sharding import ( + ServerSlice, + compute_server_assignments, +) +from tests.test_assets.test_case import TestCase + + +class TestComputeServerAssignments(TestCase): 
+ @parameterized.expand( + [ + param( + "rank_0", + compute_rank=0, + expected_assignments={ + 0: ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + 1: ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + }, + ), + param( + "rank_1", + compute_rank=1, + expected_assignments={ + 1: ServerSlice( + server_rank=1, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + 2: ServerSlice( + server_rank=2, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + }, + ), + ] + ) + def test_fractional_boundary_assignment( + self, _, compute_rank: int, expected_assignments: dict[int, ServerSlice] + ) -> None: + assignments = compute_server_assignments( + num_servers=3, num_compute_nodes=2, compute_rank=compute_rank + ) + self.assertEqual(assignments, expected_assignments) + + def test_assignments_recombine_server_data(self) -> None: + tensor = torch.arange(7) + all_assignments = [ + compute_server_assignments( + num_servers=2, num_compute_nodes=5, compute_rank=rank + ) + for rank in range(5) + ] + + for server_rank in range(2): + combined = torch.cat( + [ + assignments[server_rank].slice_tensor(tensor) + for assignments in all_assignments + if server_rank in assignments + ] + ) + self.assert_tensor_equality(combined, tensor) + + @parameterized.expand( + [ + param( + "negative_servers", + num_servers=-1, + num_compute_nodes=2, + compute_rank=0, + ), + param( + "zero_servers", + num_servers=0, + num_compute_nodes=2, + compute_rank=0, + ), + param( + "negative_compute_nodes", + num_servers=2, + num_compute_nodes=-1, + compute_rank=0, + ), + param( + "zero_compute_nodes", + num_servers=2, + num_compute_nodes=0, + compute_rank=0, + ), + param( + "rank_too_large", + num_servers=2, + num_compute_nodes=2, + compute_rank=2, + ), + param( + "negative_rank", + num_servers=2, + num_compute_nodes=2, + compute_rank=-1, + ), + ] + ) + def test_validates_arguments( + self, _, num_servers: int, 
num_compute_nodes: int, compute_rank: int + ) -> None: + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=num_servers, + num_compute_nodes=num_compute_nodes, + compute_rank=compute_rank, + ) + + +class TestServerSlice(TestCase): + def test_full_tensor_returns_same_object(self) -> None: + tensor = torch.arange(10) + server_slice = ServerSlice( + server_rank=0, start_num=0, start_den=1, end_num=1, end_den=1 + ) + result = server_slice.slice_tensor(tensor) + self.assertEqual(result.data_ptr(), tensor.data_ptr()) + + def test_partial_slice_clones_requested_range(self) -> None: + tensor = torch.arange(10) + server_slice = ServerSlice( + server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 + ) + result = server_slice.slice_tensor(tensor) + self.assert_tensor_equality(result, torch.arange(5)) + self.assertNotEqual(result.data_ptr(), tensor.data_ptr()) diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index 787ec68f8..ed20cb289 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -11,8 +11,6 @@ POSITIVE_LABEL_METADATA_KEY, ) from gigl.distributed.utils.neighborloader import ( - ServerSlice, - compute_server_assignments, extract_edge_type_metadata, extract_metadata, labeled_to_homogeneous, @@ -498,149 +496,6 @@ def test_set_custom_features_heterogeneous(self): ) -class TestComputeServerAssignments(TestCase): - @parameterized.expand( - [ - param( - "rank_0", - compute_rank=0, - expected_assignments={ - 0: ServerSlice( - server_rank=0, - start_num=0, - start_den=2, - end_num=2, - end_den=2, - ), - 1: ServerSlice( - server_rank=1, - start_num=0, - start_den=2, - end_num=1, - end_den=2, - ), - }, - ), - param( - "rank_1", - compute_rank=1, - expected_assignments={ - 1: ServerSlice( - server_rank=1, - start_num=1, - start_den=2, - end_num=2, - end_den=2, - ), - 2: ServerSlice( - server_rank=2, - 
start_num=0, - start_den=2, - end_num=2, - end_den=2, - ), - }, - ), - ] - ) - def test_fractional_boundary_assignment( - self, _, compute_rank: int, expected_assignments: dict[int, ServerSlice] - ) -> None: - assignments = compute_server_assignments( - num_servers=3, num_compute_nodes=2, compute_rank=compute_rank - ) - - self.assertEqual(assignments, expected_assignments) - - def test_assignments_recombine_server_data(self) -> None: - tensor = torch.arange(7) - all_assignments = [ - compute_server_assignments( - num_servers=2, num_compute_nodes=5, compute_rank=rank - ) - for rank in range(5) - ] - - for server_rank in range(2): - combined = torch.cat( - [ - assignments[server_rank].slice_tensor(tensor) - for assignments in all_assignments - if server_rank in assignments - ] - ) - self.assert_tensor_equality(combined, tensor) - - @parameterized.expand( - [ - param( - "negative_servers", - num_servers=-1, - num_compute_nodes=2, - compute_rank=0, - ), - param( - "zero_servers", - num_servers=0, - num_compute_nodes=2, - compute_rank=0, - ), - param( - "negative_compute_nodes", - num_servers=2, - num_compute_nodes=-1, - compute_rank=0, - ), - param( - "zero_compute_nodes", - num_servers=2, - num_compute_nodes=0, - compute_rank=0, - ), - param( - "rank_too_large", - num_servers=2, - num_compute_nodes=2, - compute_rank=2, - ), - param( - "negative_rank", - num_servers=2, - num_compute_nodes=2, - compute_rank=-1, - ), - ] - ) - def test_validates_arguments( - self, _, num_servers: int, num_compute_nodes: int, compute_rank: int - ) -> None: - with self.assertRaises(ValueError): - compute_server_assignments( - num_servers=num_servers, - num_compute_nodes=num_compute_nodes, - compute_rank=compute_rank, - ) - - -class TestServerSlice(TestCase): - def test_full_tensor_returns_same_object(self) -> None: - tensor = torch.arange(10) - server_slice = ServerSlice( - server_rank=0, start_num=0, start_den=1, end_num=1, end_den=1 - ) - result = server_slice.slice_tensor(tensor) - 
self.assertEqual(result.data_ptr(), tensor.data_ptr()) - - def test_partial_slice_clones_requested_range(self) -> None: - tensor = torch.arange(10) - server_slice = ServerSlice( - server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 - ) - result = server_slice.slice_tensor(tensor) - self.assert_tensor_equality(result, torch.arange(5)) - self.assertNotEqual(result.data_ptr(), tensor.data_ptr()) - - class ExtractMetadataTest(TestCase): def setUp(self): self._device = torch.device("cpu") From d44847e64fde4ac095ff2df447ad93b5f7a538c2 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 16:41:16 +0000 Subject: [PATCH 05/13] Add detailed docstring with examples to ShardStrategy enum Expand the one-line docstring to include concrete examples showing how ROUND_ROBIN and CONTIGUOUS strategies distribute node IDs across compute nodes, including split filtering and fractional server assignment. Co-Authored-By: Claude Opus 4.6 --- gigl/distributed/graph_store/sharding.py | 78 +++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/gigl/distributed/graph_store/sharding.py b/gigl/distributed/graph_store/sharding.py index 5d0fd5913..523d04b9f 100644 --- a/gigl/distributed/graph_store/sharding.py +++ b/gigl/distributed/graph_store/sharding.py @@ -7,7 +7,83 @@ class ShardStrategy(Enum): - """Strategies for splitting remote graph-store inputs across compute nodes.""" + """Strategies for splitting remote graph-store inputs across compute nodes. + + When fetching node IDs (or ABLP input) from storage servers, the shard + strategy controls how data is divided among compute nodes. + + Suppose we have 2 storage nodes and 2 compute nodes, with 16 total nodes. 
+    Nodes are partitioned across storage nodes, with splits defined as::
+
+        Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7]
+            train=[0, 1, 2, 3], val=[4, 5], test=[6, 7]
+        Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15]
+            train=[8, 9, 10, 11], val=[12, 13], test=[14, 15]
+
+    **ROUND_ROBIN** — Each storage server independently shards its own nodes
+    across the requested ``world_size``. Every compute node contacts every
+    storage server and receives an interleaved slice::
+
+        # All training nodes, sharded across 2 compute nodes (round-robin):
+        >>> dataset.fetch_node_ids(rank=0, world_size=2, split="train")
+        {
+            0: tensor([0, 2]),   # Even-indexed training nodes from storage 0
+            1: tensor([8, 10])   # Even-indexed training nodes from storage 1
+        }
+        >>> dataset.fetch_node_ids(rank=1, world_size=2, split="train")
+        {
+            0: tensor([1, 3]),   # Odd-indexed training nodes from storage 0
+            1: tensor([9, 11])   # Odd-indexed training nodes from storage 1
+        }
+
+    Both strategies (ROUND_ROBIN above, CONTIGUOUS below) accept ``split=None``,
+    which returns all nodes (train + val + test) from each storage server::
+
+        # Round-robin, all nodes (no split), sharded across 2 compute nodes:
+        >>> dataset.fetch_node_ids(rank=0, world_size=2)
+        {
+            0: tensor([0, 2, 4, 6]),     # Even-indexed nodes from storage 0
+            1: tensor([8, 10, 12, 14])   # Even-indexed nodes from storage 1
+        }
+
+        # Contiguous, all nodes (no split), sharded across 2 compute nodes:
+        >>> dataset.fetch_node_ids(rank=0, world_size=2,
+        ...                        shard_strategy=ShardStrategy.CONTIGUOUS)
+        {
+            0: tensor([0, 1, 2, 3, 4, 5, 6, 7]),   # All nodes from storage 0
+            1: tensor([])                          # Nothing from storage 1
+        }
+
+    **CONTIGUOUS** — Storage servers are assigned to compute nodes in
+    contiguous blocks. Each compute node fetches *all* data from its
+    assigned server(s) and receives empty tensors for unassigned ones.
+ When servers outnumber compute nodes a server's data is fractionally + split; when compute nodes outnumber servers a compute node may own a + fraction of one server:: + + # All training nodes, sharded across 2 compute nodes (contiguous): + >>> dataset.fetch_node_ids(rank=0, world_size=2, split="train", + ... shard_strategy=ShardStrategy.CONTIGUOUS) + { + 0: tensor([0, 1, 2, 3]), # All training nodes from storage 0 + 1: tensor([]) # Nothing from storage 1 + } + >>> dataset.fetch_node_ids(rank=1, world_size=2, split="train", + ... shard_strategy=ShardStrategy.CONTIGUOUS) + { + 0: tensor([]), # Nothing from storage 0 + 1: tensor([8, 9, 10, 11]) # All training nodes from storage 1 + } + + # With 3 storage nodes and 2 compute nodes, server 1 is fractionally split: + >>> dataset.fetch_node_ids(rank=0, world_size=2, split="train", + ... shard_strategy=ShardStrategy.CONTIGUOUS) + { + 0: tensor([0, 1, 2, 3]), # All of storage 0 + 1: tensor([8, 9]), # First half of storage 1 + 2: tensor([]) # Nothing from storage 2 + } + """ ROUND_ROBIN = "round_robin" CONTIGUOUS = "contiguous" From 4023152d6de986637037068f40993691dce51e78 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 16:45:32 +0000 Subject: [PATCH 06/13] Extract _validate_contiguous_args to free function, remove world_size check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The world_size != num_compute_nodes validation was unnecessarily restrictive — callers may legitimately pass a different world_size. Also extract the validator to a module-level function since it no longer needs self. 
Co-Authored-By: Claude Opus 4.6 --- .../graph_store/remote_dist_dataset.py | 54 ++++++++----------- .../graph_store/remote_dist_dataset_test.py | 13 ----- 2 files changed, 22 insertions(+), 45 deletions(-) diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index 1c37f91ce..4c1465c47 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -24,6 +24,22 @@ logger = Logger() +def _validate_contiguous_args( + rank: Optional[int], + world_size: Optional[int], + shard_strategy: ShardStrategy, +) -> None: + """Validate contiguous sharding inputs; no-op for round-robin.""" + if shard_strategy != ShardStrategy.CONTIGUOUS: + return + + if rank is None or world_size is None: + raise ValueError( + "Both rank and world_size must be provided when using " + f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" + ) + + class RemoteDistDataset: def __init__( self, @@ -164,30 +180,6 @@ def _infer_edge_type_if_homogeneous_with_label_edges( ) return edge_type - def _validate_contiguous_args( - self, - rank: Optional[int], - world_size: Optional[int], - shard_strategy: ShardStrategy, - ) -> tuple[Optional[int], Optional[int]]: - """Validate contiguous sharding inputs and preserve round-robin inputs unchanged.""" - if shard_strategy != ShardStrategy.CONTIGUOUS: - return rank, world_size - - if rank is None or world_size is None: - raise ValueError( - "Both rank and world_size must be provided when using " - f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" - ) - if world_size != self.cluster_info.num_compute_nodes: - raise ValueError( - "ShardStrategy.CONTIGUOUS expects world_size to equal " - "cluster_info.num_compute_nodes. 
" - f"Got world_size={world_size}, " - f"cluster_info.num_compute_nodes={self.cluster_info.num_compute_nodes}" - ) - return rank, world_size - def _compute_assignments_if_needed( self, rank: Optional[int], @@ -283,7 +275,7 @@ def fetch_node_ids( expects the compute-node rank. world_size (Optional[int]): The requested shard world size. ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` - requires ``world_size == cluster_info.num_compute_nodes``. + expects the compute-node world size. split (Optional[Literal["train", "val", "test"]]): The split of the dataset to get node ids from. If provided, the dataset must have `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. @@ -296,8 +288,7 @@ def fetch_node_ids( for unassigned servers. Raises: - ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None, - or if ``world_size`` does not match ``cluster_info.num_compute_nodes``. + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. Returns: dict[int, torch.Tensor]: A dict mapping storage rank to node ids. @@ -312,7 +303,7 @@ def fetch_node_ids( (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split. """ - rank, world_size = self._validate_contiguous_args( + _validate_contiguous_args( rank=rank, world_size=world_size, shard_strategy=shard_strategy, @@ -488,7 +479,7 @@ def fetch_ablp_input( expects the compute-node rank. world_size (Optional[int]): The requested shard world size. ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` - requires ``world_size == cluster_info.num_compute_nodes``. + expects the compute-node world size. anchor_node_type (Optional[NodeType]): The type of the anchor nodes to retrieve. Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs. 
@@ -512,8 +503,7 @@ def fetch_ablp_input( - negative_labels: Optional dict mapping negative label EdgeType to a 2D tensor [N, M]. Raises: - ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None, - or if ``world_size`` does not match ``cluster_info.num_compute_nodes``. + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. Examples: See :class:`~gigl.distributed.graph_store.sharding.ShardStrategy` for @@ -521,7 +511,7 @@ def fetch_ablp_input( compute nodes. """ - rank, world_size = self._validate_contiguous_args( + _validate_contiguous_args( rank=rank, world_size=world_size, shard_strategy=shard_strategy, diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index 7f56c89a8..05bb7478d 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -859,12 +859,6 @@ def test_contiguous_requires_rank_and_world_size(self): world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS, ) - with self.assertRaises(ValueError): - remote_dataset.fetch_node_ids( - rank=0, - world_size=3, - shard_strategy=ShardStrategy.CONTIGUOUS, - ) def test_contiguous_labeled_homogeneous_auto_inference(self): """CONTIGUOUS strategy auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" @@ -1124,13 +1118,6 @@ def test_ablp_contiguous_requires_rank_and_world_size(self): world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS, ) - with self.assertRaises(ValueError): - remote_dataset.fetch_ablp_input( - split="train", - rank=0, - world_size=3, - shard_strategy=ShardStrategy.CONTIGUOUS, - ) def _test_fetch_free_ports_on_storage_cluster( From dfd8f944205dc386d071ee8b75995655733654ba Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 16:47:20 +0000 Subject: [PATCH 07/13] Remove unnecessary .clone() from 
ServerSlice.slice_tensor The sliced tensor holds a reference to the original, but in the contiguous flow the original is a local variable that goes out of scope, so the slice effectively owns the data. Removing clone() avoids an unnecessary copy. Co-Authored-By: Claude Opus 4.6 --- gigl/distributed/graph_store/sharding.py | 2 +- tests/unit/distributed/graph_store/sharding_test.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gigl/distributed/graph_store/sharding.py b/gigl/distributed/graph_store/sharding.py index 523d04b9f..6b515ad5d 100644 --- a/gigl/distributed/graph_store/sharding.py +++ b/gigl/distributed/graph_store/sharding.py @@ -106,7 +106,7 @@ def slice_tensor(self, tensor: torch.Tensor) -> torch.Tensor: end_idx = total * self.end_num // self.end_den if start_idx == 0 and end_idx == total: return tensor - return tensor[start_idx:end_idx].clone() + return tensor[start_idx:end_idx] def compute_server_assignments( diff --git a/tests/unit/distributed/graph_store/sharding_test.py b/tests/unit/distributed/graph_store/sharding_test.py index 7fcf69493..547524663 100644 --- a/tests/unit/distributed/graph_store/sharding_test.py +++ b/tests/unit/distributed/graph_store/sharding_test.py @@ -140,11 +140,10 @@ def test_full_tensor_returns_same_object(self) -> None: result = server_slice.slice_tensor(tensor) self.assertEqual(result.data_ptr(), tensor.data_ptr()) - def test_partial_slice_clones_requested_range(self) -> None: + def test_partial_slice_returns_requested_range(self) -> None: tensor = torch.arange(10) server_slice = ServerSlice( server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 ) result = server_slice.slice_tensor(tensor) self.assert_tensor_equality(result, torch.arange(5)) - self.assertNotEqual(result.data_ptr(), tensor.data_ptr()) From d336096699715bb72df95d7317362dd044b6faed Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 16:47:38 +0000 Subject: [PATCH 08/13] Upgrade integration test to compare actual 
node IDs, not just counts Replace the all_reduce count comparison with all_gather + sorted tensor comparison to catch cases where counts match but actual node IDs differ between CONTIGUOUS and ROUND_ROBIN strategies. Co-Authored-By: Claude Opus 4.6 --- .../graph_store_integration_test.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 9deef0155..9c0e8eec6 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -261,15 +261,27 @@ def _run_compute_train_tests( ), f"Rank {rank} should have empty tensor for server {other_rank}" # Assert total node parity: CONTIGUOUS and ROUND_ROBIN should cover the same nodes - local_contiguous_count = sum(t.numel() for t in contiguous_node_ids.values()) - local_round_robin_count = sum(t.numel() for t in random_negative_input.values()) - contiguous_total = torch.tensor(local_contiguous_count, dtype=torch.int64) - round_robin_total = torch.tensor(local_round_robin_count, dtype=torch.int64) - torch.distributed.all_reduce(contiguous_total, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(round_robin_total, op=torch.distributed.ReduceOp.SUM) - assert contiguous_total.item() == round_robin_total.item(), ( - f"CONTIGUOUS total ({contiguous_total.item()}) must equal " - f"ROUND_ROBIN total ({round_robin_total.item()})" + local_contiguous_nodes = torch.cat(list(contiguous_node_ids.values())) + local_round_robin_nodes = torch.cat(list(random_negative_input.values())) + + # Gather all nodes from all ranks + contiguous_gathered: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.long) + for _ in range(torch.distributed.get_world_size()) + ] + round_robin_gathered: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.long) + 
for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather_object(contiguous_gathered, local_contiguous_nodes) + torch.distributed.all_gather_object(round_robin_gathered, local_round_robin_nodes) + + all_contiguous = torch.cat(contiguous_gathered).sort().values + all_round_robin = torch.cat(round_robin_gathered).sort().values + assert torch.equal(all_contiguous, all_round_robin), ( + f"CONTIGUOUS and ROUND_ROBIN must produce the same sorted node set. " + f"CONTIGUOUS: {all_contiguous[:10]}... ({all_contiguous.numel()} nodes), " + f"ROUND_ROBIN: {all_round_robin[:10]}... ({all_round_robin.numel()} nodes)" ) torch.distributed.barrier() From 836a6e7a7c6cb125b7a1c761e262081b2d9ff606 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 16:54:49 +0000 Subject: [PATCH 09/13] Simplify TestRemoteDistDatasetContiguous test class - Merge _make_rank_aware_async_mock and _make_rank_aware_ablp_async_mock into a single generic helper - Remove _assert_contiguous_node_ids and _assert_contiguous_ablp_inputs helpers, inline assertions directly in tests - Replace @parameterized.expand with separate named test methods for better readability - Fix stale variable reference in integration test log line Co-Authored-By: Claude Opus 4.6 --- .../graph_store_integration_test.py | 2 +- .../graph_store/remote_dist_dataset_test.py | 518 ++++++++---------- 2 files changed, 227 insertions(+), 293 deletions(-) diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 9c0e8eec6..a5ed5753b 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -287,7 +287,7 @@ def _run_compute_train_tests( torch.distributed.barrier() logger.info( f"Rank {torch.distributed.get_rank()} CONTIGUOUS: " - f"{local_contiguous_count} nodes from assigned server" + 
f"{local_contiguous_nodes.numel()} nodes from assigned server" ) shutdown_compute_proccess() diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index 05bb7478d..b9003cc64 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -1,12 +1,12 @@ +from collections.abc import Callable from contextlib import contextmanager -from typing import Final, Literal, Optional +from typing import Any, Final, Optional from unittest.mock import patch import torch import torch.distributed as dist import torch.multiprocessing as mp from absl.testing import absltest -from parameterized import param, parameterized import gigl.distributed.graph_store.dist_server as dist_server_module from gigl.common import LocalUri @@ -702,28 +702,28 @@ class TestRemoteDistDatasetContiguous(RemoteDistDatasetTestBase): """Tests for fetch_node_ids and fetch_ablp_input with ShardStrategy.CONTIGUOUS.""" def _make_rank_aware_async_mock( - self, server_data: dict[int, dict[str, torch.Tensor]] - ): - """Create an async mock that returns different node IDs per server rank. + self, server_data: dict[int, Any] + ) -> Callable[..., torch.futures.Future]: + """Create an async mock that returns pre-set data per server rank. Args: - server_data: Maps server_rank to a dict of - ``{"all": tensor, "train": tensor, ...}`` where ``"all"`` - is the full node set and split keys are optional. + server_data: Maps server_rank to the value that server should return. + Can be a tensor (for node ID tests) or a tuple (for ABLP tests). 
""" - def _mock(server_rank, func, *args, **kwargs): - split = kwargs.get("split") - data = server_data[server_rank] - key = split if split is not None and split in data else "all" + def _mock( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> torch.futures.Future: future: torch.futures.Future = torch.futures.Future() - future.set_result(data[key]) + future.set_result(server_data[server_rank]) return future return _mock @staticmethod - def _mock_request_server_homogeneous(server_rank, func, *args, **kwargs): + def _mock_request_server_homogeneous( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: """Mock request_server that returns None for node/edge types (homogeneous).""" if func == DistServer.get_node_types: return None @@ -731,117 +731,92 @@ def _mock_request_server_homogeneous(server_rank, func, *args, **kwargs): return None return _mock_request_server(server_rank, func, *args, **kwargs) - def _assert_contiguous_node_ids( - self, - remote_dataset: RemoteDistDataset, - world_size: int, - expected_by_rank: dict[int, dict[int, torch.Tensor]], - split: Optional[Literal["train", "val", "test"]] = None, - ) -> None: - for rank, expected_by_server in expected_by_rank.items(): - result = remote_dataset.fetch_node_ids( - rank=rank, - world_size=world_size, - split=split, - shard_strategy=ShardStrategy.CONTIGUOUS, + def test_fetch_node_ids_contiguous_even_split(self) -> None: + """CONTIGUOUS with 2 storage nodes and 2 compute nodes: each rank gets one server.""" + server_data: dict[int, torch.Tensor] = { + 0: torch.arange(10), + 1: torch.arange(10, 20), + } + mock_fn = self._make_rank_aware_async_mock(server_data) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, num_compute_nodes=2 + ) + + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: gets all of server 0, empty from server 1 + 
result = ds.fetch_node_ids( + rank=0, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS ) - self.assertEqual(set(result.keys()), set(expected_by_server.keys())) - for server_rank, expected_tensor in expected_by_server.items(): - self.assert_tensor_equality(result[server_rank], expected_tensor) - - @parameterized.expand( - [ - param( - "even_split", - num_storage_nodes=2, - server_data={ - 0: {"all": torch.arange(10)}, - 1: {"all": torch.arange(10, 20)}, - }, - expected_by_rank={ - 0: { - 0: torch.arange(10), - 1: torch.empty(0, dtype=torch.long), - }, - 1: { - 0: torch.empty(0, dtype=torch.long), - 1: torch.arange(10, 20), - }, - }, - ), - param( - "fractional_split", - num_storage_nodes=3, - server_data={ - 0: {"all": torch.arange(10)}, - 1: {"all": torch.arange(10, 20)}, - 2: {"all": torch.arange(20, 30)}, - }, - expected_by_rank={ - 0: { - 0: torch.arange(10), - 1: torch.arange(10, 15), - 2: torch.empty(0, dtype=torch.long), - }, - 1: { - 0: torch.empty(0, dtype=torch.long), - 1: torch.arange(15, 20), - 2: torch.arange(20, 30), - }, - }, - ), - ] - ) - def test_fetch_node_ids_contiguous( - self, - _, - num_storage_nodes: int, - server_data: dict[int, dict[str, torch.Tensor]], - expected_by_rank: dict[int, dict[int, torch.Tensor]], - ) -> None: + self.assert_tensor_equality(result[0], torch.arange(10)) + self.assertEqual(result[1].numel(), 0) + + # Rank 1: empty from server 0, gets all of server 1 + result = ds.fetch_node_ids( + rank=1, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS + ) + self.assertEqual(result[0].numel(), 0) + self.assert_tensor_equality(result[1], torch.arange(10, 20)) + + def test_fetch_node_ids_contiguous_fractional_split(self) -> None: + """CONTIGUOUS with 3 storage nodes and 2 compute nodes: server 1 is fractionally split.""" + server_data: dict[int, torch.Tensor] = { + 0: torch.arange(10), + 1: torch.arange(10, 20), + 2: torch.arange(20, 30), + } mock_fn = self._make_rank_aware_async_mock(server_data) cluster_info = 
_create_mock_graph_store_info( - num_storage_nodes=num_storage_nodes, - num_compute_nodes=2, + num_storage_nodes=3, num_compute_nodes=2 ) with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) - self._assert_contiguous_node_ids( - remote_dataset=remote_dataset, - world_size=2, - expected_by_rank=expected_by_rank, + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: all of server 0, first half of server 1, nothing from server 2 + result = ds.fetch_node_ids( + rank=0, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS ) + self.assert_tensor_equality(result[0], torch.arange(10)) + self.assert_tensor_equality(result[1], torch.arange(10, 15)) + self.assertEqual(result[2].numel(), 0) - def test_with_split_filtering(self): + # Rank 1: nothing from server 0, second half of server 1, all of server 2 + result = ds.fetch_node_ids( + rank=1, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS + ) + self.assertEqual(result[0].numel(), 0) + self.assert_tensor_equality(result[1], torch.arange(15, 20)) + self.assert_tensor_equality(result[2], torch.arange(20, 30)) + + def test_with_split_filtering(self) -> None: """CONTIGUOUS strategy with split='train' filtering.""" - server_data = { - 0: {"all": torch.arange(10), "train": torch.tensor([0, 1, 2, 3])}, - 1: {"all": torch.arange(10, 20), "train": torch.tensor([10, 11, 12, 13])}, + server_data: dict[int, torch.Tensor] = { + 0: torch.tensor([0, 1, 2, 3]), + 1: torch.tensor([10, 11, 12, 13]), } mock_fn = self._make_rank_aware_async_mock(server_data) cluster_info = _create_mock_graph_store_info( - num_storage_nodes=2, - num_compute_nodes=2, + num_storage_nodes=2, num_compute_nodes=2 ) with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) - 
result_0 = remote_dataset.fetch_node_ids( + result = ds.fetch_node_ids( rank=0, world_size=2, split="train", shard_strategy=ShardStrategy.CONTIGUOUS, ) - self.assert_tensor_equality(result_0[0], torch.tensor([0, 1, 2, 3])) - self.assertEqual(len(result_0[1]), 0) + self.assert_tensor_equality(result[0], torch.tensor([0, 1, 2, 3])) + self.assertEqual(result[1].numel(), 0) - def test_contiguous_requires_rank_and_world_size(self): + def test_contiguous_requires_rank_and_world_size(self) -> None: """CONTIGUOUS without rank/world_size raises ValueError.""" cluster_info = _create_mock_graph_store_info( - num_storage_nodes=2, - num_compute_nodes=2, + num_storage_nodes=2, num_compute_nodes=2 ) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -860,7 +835,7 @@ def test_contiguous_requires_rank_and_world_size(self): shard_strategy=ShardStrategy.CONTIGUOUS, ) - def test_contiguous_labeled_homogeneous_auto_inference(self): + def test_contiguous_labeled_homogeneous_auto_inference(self) -> None: """CONTIGUOUS strategy auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" _create_server_with_splits( edge_indices={ @@ -878,7 +853,6 @@ def test_contiguous_labeled_homogeneous_auto_inference(self): with _patch_remote_requests(_mock_async_request_server, _mock_request_server): remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) - # No node_type: should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE result = remote_dataset.fetch_node_ids( rank=0, world_size=1, @@ -890,214 +864,174 @@ def test_contiguous_labeled_homogeneous_auto_inference(self): torch.tensor([0, 1, 2]), ) - def _make_rank_aware_ablp_async_mock( - self, + def test_fetch_ablp_input_contiguous_even_split(self) -> None: + """ABLP CONTIGUOUS with 2 storage nodes and 2 compute nodes.""" server_data: dict[ - int, - dict[ - str, - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], - ], - ], - ): - """Create an async mock that returns different ABLP 
data per server rank. + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] = { + 0: ( + torch.tensor([0, 1, 2]), + torch.tensor([[0, 1], [1, 2], [2, 3]]), + torch.tensor([[4], [5], [6]]), + ), + 1: ( + torch.tensor([10, 11, 12]), + torch.tensor([[10, 11], [11, 12], [12, 13]]), + torch.tensor([[14], [15], [16]]), + ), + } + mock_fn = self._make_rank_aware_async_mock(server_data) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, num_compute_nodes=2 + ) - Args: - server_data: Maps server_rank to a dict of - ``{"all": (anchors, pos, neg), "train": (anchors, pos, neg), ...}`` - where ``"all"`` is the full data and split keys are optional. - """ + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) - def _mock(server_rank, func, *args, **kwargs): - split = kwargs.get("split") - data = server_data[server_rank] - key = split if split is not None and split in data else "all" - future: torch.futures.Future = torch.futures.Future() - future.set_result(data[key]) - return future + # Rank 0: gets all of server 0, empty from server 1 + result = ds.fetch_ablp_input( + split="train", + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_0 = result[0] + self.assert_tensor_equality(ablp_0.anchor_nodes, torch.tensor([0, 1, 2])) + pos, neg = ablp_0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos, torch.tensor([[0, 1], [1, 2], [2, 3]]) + ) + assert neg is not None + self.assert_tensor_equality(neg, torch.tensor([[4], [5], [6]])) - return _mock + ablp_1 = result[1] + self.assertEqual(ablp_1.anchor_nodes.numel(), 0) + pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assertEqual(pos_1.numel(), 0) + self.assertIsNone(neg_1) - def _assert_contiguous_ablp_inputs( - self, - remote_dataset: RemoteDistDataset, - world_size: int, - expected_by_rank: dict[ - int, - dict[int, tuple[torch.Tensor, 
torch.Tensor, Optional[torch.Tensor]]], - ], - ) -> None: - for rank, expected_by_server in expected_by_rank.items(): - result = remote_dataset.fetch_ablp_input( + # Rank 1: empty from server 0, gets all of server 1 + result = ds.fetch_ablp_input( split="train", - rank=rank, - world_size=world_size, + rank=1, + world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS, ) - self.assertEqual(set(result.keys()), set(expected_by_server.keys())) - for server_rank, ( - expected_anchors, - expected_positive, - expected_negative, - ) in expected_by_server.items(): - ablp_input = result[server_rank] - self.assert_tensor_equality(ablp_input.anchor_nodes, expected_anchors) - positive, negative = ablp_input.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality(positive, expected_positive) - if expected_negative is None: - self.assertIsNone(negative) - else: - assert negative is not None - self.assert_tensor_equality(negative, expected_negative) - - @parameterized.expand( - [ - param( - "even_split", - num_storage_nodes=2, - server_data={ - 0: { - "train": ( - torch.tensor([0, 1, 2]), - torch.tensor([[0, 1], [1, 2], [2, 3]]), - torch.tensor([[4], [5], [6]]), - ), - }, - 1: { - "train": ( - torch.tensor([10, 11, 12]), - torch.tensor([[10, 11], [11, 12], [12, 13]]), - torch.tensor([[14], [15], [16]]), - ), - }, - }, - expected_by_rank={ - 0: { - 0: ( - torch.tensor([0, 1, 2]), - torch.tensor([[0, 1], [1, 2], [2, 3]]), - torch.tensor([[4], [5], [6]]), - ), - 1: ( - torch.empty(0, dtype=torch.long), - torch.empty((0, 0), dtype=torch.long), - None, - ), - }, - 1: { - 0: ( - torch.empty(0, dtype=torch.long), - torch.empty((0, 0), dtype=torch.long), - None, - ), - 1: ( - torch.tensor([10, 11, 12]), - torch.tensor([[10, 11], [11, 12], [12, 13]]), - torch.tensor([[14], [15], [16]]), - ), - }, - }, + ablp_0 = result[0] + self.assertEqual(ablp_0.anchor_nodes.numel(), 0) + + ablp_1 = result[1] + self.assert_tensor_equality( + ablp_1.anchor_nodes, torch.tensor([10, 11, 12]) + ) + 
pos, neg = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos, torch.tensor([[10, 11], [11, 12], [12, 13]]) + ) + assert neg is not None + self.assert_tensor_equality(neg, torch.tensor([[14], [15], [16]])) + + def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: + """ABLP CONTIGUOUS with 3 storage nodes and 2 compute nodes: server 1 fractionally split.""" + server_data: dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] = { + 0: ( + torch.tensor([0, 1, 2, 3]), + torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), + torch.tensor([[10], [11], [12], [13]]), ), - param( - "fractional_split", - num_storage_nodes=3, - server_data={ - 0: { - "train": ( - torch.tensor([0, 1, 2, 3]), - torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), - torch.tensor([[10], [11], [12], [13]]), - ), - }, - 1: { - "train": ( - torch.tensor([10, 11, 12, 13]), - torch.tensor([[10, 11], [11, 12], [12, 13], [13, 14]]), - torch.tensor([[20], [21], [22], [23]]), - ), - }, - 2: { - "train": ( - torch.tensor([20, 21, 22, 23]), - torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), - torch.tensor([[30], [31], [32], [33]]), - ), - }, - }, - expected_by_rank={ - 0: { - 0: ( - torch.tensor([0, 1, 2, 3]), - torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), - torch.tensor([[10], [11], [12], [13]]), - ), - 1: ( - torch.tensor([10, 11]), - torch.tensor([[10, 11], [11, 12]]), - torch.tensor([[20], [21]]), - ), - 2: ( - torch.empty(0, dtype=torch.long), - torch.empty((0, 0), dtype=torch.long), - None, - ), - }, - 1: { - 0: ( - torch.empty(0, dtype=torch.long), - torch.empty((0, 0), dtype=torch.long), - None, - ), - 1: ( - torch.tensor([12, 13]), - torch.tensor([[12, 13], [13, 14]]), - torch.tensor([[22], [23]]), - ), - 2: ( - torch.tensor([20, 21, 22, 23]), - torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), - torch.tensor([[30], [31], [32], [33]]), - ), - }, - }, + 1: ( + torch.tensor([10, 11, 12, 13]), + torch.tensor([[10, 11], [11, 12], 
[12, 13], [13, 14]]), + torch.tensor([[20], [21], [22], [23]]), ), - ] - ) - def test_fetch_ablp_input_contiguous( - self, - _, - num_storage_nodes: int, - server_data: dict[ - int, - dict[ - str, - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], - ], - ], - expected_by_rank: dict[ - int, - dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], - ], - ) -> None: - mock_fn = self._make_rank_aware_ablp_async_mock(server_data) + 2: ( + torch.tensor([20, 21, 22, 23]), + torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), + torch.tensor([[30], [31], [32], [33]]), + ), + } + mock_fn = self._make_rank_aware_async_mock(server_data) cluster_info = _create_mock_graph_store_info( - num_storage_nodes=num_storage_nodes, - num_compute_nodes=2, + num_storage_nodes=3, num_compute_nodes=2 ) with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) - self._assert_contiguous_ablp_inputs( - remote_dataset=remote_dataset, + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: all of server 0, first half of server 1, nothing from server 2 + result = ds.fetch_ablp_input( + split="train", + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_0 = result[0] + self.assert_tensor_equality( + ablp_0.anchor_nodes, torch.tensor([0, 1, 2, 3]) + ) + pos, neg = ablp_0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos, torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]) + ) + assert neg is not None + self.assert_tensor_equality( + neg, torch.tensor([[10], [11], [12], [13]]) + ) + + ablp_1 = result[1] + self.assert_tensor_equality( + ablp_1.anchor_nodes, torch.tensor([10, 11]) + ) + pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos_1, torch.tensor([[10, 11], [11, 12]]) + ) + assert neg_1 is not None + self.assert_tensor_equality(neg_1, torch.tensor([[20], 
[21]])) + + ablp_2 = result[2] + self.assertEqual(ablp_2.anchor_nodes.numel(), 0) + + # Rank 1: nothing from server 0, second half of server 1, all of server 2 + result = ds.fetch_ablp_input( + split="train", + rank=1, world_size=2, - expected_by_rank=expected_by_rank, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assertEqual(result[0].anchor_nodes.numel(), 0) + + ablp_1 = result[1] + self.assert_tensor_equality( + ablp_1.anchor_nodes, torch.tensor([12, 13]) + ) + pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos_1, torch.tensor([[12, 13], [13, 14]]) + ) + assert neg_1 is not None + self.assert_tensor_equality(neg_1, torch.tensor([[22], [23]])) + + ablp_2 = result[2] + self.assert_tensor_equality( + ablp_2.anchor_nodes, torch.tensor([20, 21, 22, 23]) + ) + pos_2, neg_2 = ablp_2.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos_2, + torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), + ) + assert neg_2 is not None + self.assert_tensor_equality( + neg_2, torch.tensor([[30], [31], [32], [33]]) ) - def test_ablp_contiguous_requires_rank_and_world_size(self): + def test_ablp_contiguous_requires_rank_and_world_size(self) -> None: """ABLP CONTIGUOUS without rank/world_size raises ValueError.""" cluster_info = _create_mock_graph_store_info( - num_storage_nodes=2, - num_compute_nodes=2, + num_storage_nodes=2, num_compute_nodes=2 ) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) From 60e6786629cd2feda53bf7c26bd0d64a1a07cae2 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 16:58:25 +0000 Subject: [PATCH 10/13] Add missing type annotations to test helper functions Annotate _mock_request_server, _mock_async_request_server, _patch_remote_requests, and _create_server_with_splits kwargs with proper type hints. Add Callable, Iterator, and Any imports. 
Co-Authored-By: Claude Opus 4.6 --- .../graph_store/remote_dist_dataset_test.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index b9003cc64..ece0b44a7 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Callable, Iterator from contextlib import contextmanager from typing import Any, Final, Optional from unittest.mock import patch @@ -62,12 +62,16 @@ _DEFAULT_TEST_IDS: Final[list[int]] = [4] -def _mock_request_server(server_rank, func, *args, **kwargs): +def _mock_request_server( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any +) -> Any: """Mock request_server that routes through _call_func_on_server.""" return _call_func_on_server(func, *args, **kwargs) -def _mock_async_request_server(server_rank, func, *args, **kwargs): +def _mock_async_request_server( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any +) -> torch.futures.Future: """Mock async_request_server that routes through _call_func_on_server and returns a future.""" future: torch.futures.Future = torch.futures.Future() future.set_result(_call_func_on_server(func, *args, **kwargs)) @@ -99,7 +103,10 @@ def _create_mock_graph_store_info( @contextmanager -def _patch_remote_requests(async_side_effect, sync_side_effect): +def _patch_remote_requests( + async_side_effect: Callable[..., torch.futures.Future], + sync_side_effect: Callable[..., Any], +) -> Iterator[None]: """Context manager that patches both async_request_server and request_server.""" with ( patch( @@ -131,7 +138,7 @@ def _create_server_with_splits( global _test_server create_test_process_group() - kwargs: dict = {} + kwargs: dict[str, Any] = {} if src_node_type is not 
None: kwargs.update( src_node_type=src_node_type, From c20558270d368dd3572a89936fbe76be6b4020ef Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 17:16:41 +0000 Subject: [PATCH 11/13] update --- .../dist_ablp_neighborloader_test.py | 55 ------------------- .../graph_store/remote_dist_dataset_test.py | 36 +++--------- 2 files changed, 9 insertions(+), 82 deletions(-) diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index c73465733..fd14aadd5 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -1,7 +1,6 @@ import unittest from collections import defaultdict from typing import Literal, Optional, Union -from unittest.mock import Mock, patch import torch import torch.multiprocessing as mp @@ -29,7 +28,6 @@ ) from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, - DEFAULT_HOMOGENEOUS_NODE_TYPE, GraphPartitionData, PartitionOutput, message_passing_to_negative_label, @@ -38,7 +36,6 @@ to_homogeneous, ) from gigl.utils.data_splitters import DistNodeAnchorLinkSplitter -from gigl.utils.sampling import ABLPInputNodes from tests.test_assets.distributed.utils import ( assert_tensor_equality, create_test_process_group, @@ -427,58 +424,6 @@ def tearDown(self): torch.distributed.destroy_process_group() super().tearDown() - def test_graph_store_setup_detects_negatives_across_all_servers(self) -> None: - loader = DistABLPLoader.__new__(DistABLPLoader) - loader._instance_count = 0 - loader._shutdowned = True - - dataset = Mock() - dataset.cluster_info.num_storage_nodes = 2 - dataset.cluster_info.compute_cluster_world_size = 2 - dataset.cluster_info.storage_cluster_master_ip = "127.0.0.1" - dataset.fetch_node_feature_info.return_value = None - dataset.fetch_edge_feature_info.return_value = None - dataset.fetch_edge_types.return_value = None - dataset.fetch_free_ports_on_storage_cluster.return_value = 
[12345, 12346] - - input_nodes = { - 0: ABLPInputNodes( - anchor_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, - anchor_nodes=torch.empty(0, dtype=torch.long), - labels={ - DEFAULT_HOMOGENEOUS_EDGE_TYPE: ( - torch.empty((0, 0), dtype=torch.long), - None, - ) - }, - ), - 1: ABLPInputNodes( - anchor_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, - anchor_nodes=torch.tensor([1]), - labels={ - DEFAULT_HOMOGENEOUS_EDGE_TYPE: ( - torch.tensor([[2]]), - torch.tensor([[3]]), - ) - }, - ), - } - - with patch("torch.distributed.get_rank", return_value=0): - input_data, _, _ = loader._setup_for_graph_store( - input_nodes=input_nodes, - dataset=dataset, - num_workers=1, - ) - - self.assertEqual(loader._negative_label_edge_types, [_NEGATIVE_EDGE_TYPE]) - self.assertEqual(input_data[0].negative_label_by_edge_types, {}) - self.assertIn(_NEGATIVE_EDGE_TYPE, input_data[1].negative_label_by_edge_types) - self.assert_tensor_equality( - input_data[0].positive_label_by_edge_types[_POSITIVE_EDGE_TYPE], - torch.empty((0, 0), dtype=torch.long), - ) - @parameterized.expand( [ param( diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index ece0b44a7..d05a91947 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -905,9 +905,7 @@ def test_fetch_ablp_input_contiguous_even_split(self) -> None: ablp_0 = result[0] self.assert_tensor_equality(ablp_0.anchor_nodes, torch.tensor([0, 1, 2])) pos, neg = ablp_0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality( - pos, torch.tensor([[0, 1], [1, 2], [2, 3]]) - ) + self.assert_tensor_equality(pos, torch.tensor([[0, 1], [1, 2], [2, 3]])) assert neg is not None self.assert_tensor_equality(neg, torch.tensor([[4], [5], [6]])) @@ -928,9 +926,7 @@ def test_fetch_ablp_input_contiguous_even_split(self) -> None: self.assertEqual(ablp_0.anchor_nodes.numel(), 0) ablp_1 = 
result[1] - self.assert_tensor_equality( - ablp_1.anchor_nodes, torch.tensor([10, 11, 12]) - ) + self.assert_tensor_equality(ablp_1.anchor_nodes, torch.tensor([10, 11, 12])) pos, neg = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] self.assert_tensor_equality( pos, torch.tensor([[10, 11], [11, 12], [12, 13]]) @@ -975,26 +971,18 @@ def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: shard_strategy=ShardStrategy.CONTIGUOUS, ) ablp_0 = result[0] - self.assert_tensor_equality( - ablp_0.anchor_nodes, torch.tensor([0, 1, 2, 3]) - ) + self.assert_tensor_equality(ablp_0.anchor_nodes, torch.tensor([0, 1, 2, 3])) pos, neg = ablp_0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] self.assert_tensor_equality( pos, torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]) ) assert neg is not None - self.assert_tensor_equality( - neg, torch.tensor([[10], [11], [12], [13]]) - ) + self.assert_tensor_equality(neg, torch.tensor([[10], [11], [12], [13]])) ablp_1 = result[1] - self.assert_tensor_equality( - ablp_1.anchor_nodes, torch.tensor([10, 11]) - ) + self.assert_tensor_equality(ablp_1.anchor_nodes, torch.tensor([10, 11])) pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality( - pos_1, torch.tensor([[10, 11], [11, 12]]) - ) + self.assert_tensor_equality(pos_1, torch.tensor([[10, 11], [11, 12]])) assert neg_1 is not None self.assert_tensor_equality(neg_1, torch.tensor([[20], [21]])) @@ -1011,13 +999,9 @@ def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: self.assertEqual(result[0].anchor_nodes.numel(), 0) ablp_1 = result[1] - self.assert_tensor_equality( - ablp_1.anchor_nodes, torch.tensor([12, 13]) - ) + self.assert_tensor_equality(ablp_1.anchor_nodes, torch.tensor([12, 13])) pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] - self.assert_tensor_equality( - pos_1, torch.tensor([[12, 13], [13, 14]]) - ) + self.assert_tensor_equality(pos_1, torch.tensor([[12, 13], [13, 14]])) assert neg_1 is not None 
self.assert_tensor_equality(neg_1, torch.tensor([[22], [23]])) @@ -1031,9 +1015,7 @@ def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), ) assert neg_2 is not None - self.assert_tensor_equality( - neg_2, torch.tensor([[30], [31], [32], [33]]) - ) + self.assert_tensor_equality(neg_2, torch.tensor([[30], [31], [32], [33]])) def test_ablp_contiguous_requires_rank_and_world_size(self) -> None: """ABLP CONTIGUOUS without rank/world_size raises ValueError.""" From be6ce6c6e1b2eb9aa1c894f52791265e1502691c Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 17:55:57 +0000 Subject: [PATCH 12/13] update --- gigl/distributed/graph_store/dist_server.py | 98 +++---- gigl/distributed/graph_store/messages.py | 54 ++++ .../graph_store/remote_dist_dataset.py | 175 +++++------- tests/unit/distributed/dist_server_test.py | 228 ++++++++++++--- .../distributed/graph_store/messages_test.py | 124 ++++++++ .../graph_store/remote_dist_dataset_test.py | 264 +++++++++++++++++- 6 files changed, 725 insertions(+), 218 deletions(-) create mode 100644 gigl/distributed/graph_store/messages.py create mode 100644 tests/unit/distributed/graph_store/messages_test.py diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index da8db83e2..633d1db9f 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -34,16 +34,12 @@ from gigl.common.logger import Logger from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistSamplingProducer +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest from gigl.distributed.sampler import ABLPNodeSamplerInput from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.neighborloader import shard_nodes_by_process from gigl.src.common.types.graph_data import EdgeType, NodeType 
-from gigl.types.graph import ( - DEFAULT_HOMOGENEOUS_EDGE_TYPE, - DEFAULT_HOMOGENEOUS_NODE_TYPE, - FeatureInfo, - select_label_edge_types, -) +from gigl.types.graph import FeatureInfo, select_label_edge_types from gigl.utils.data_splitters import get_labels_for_anchor_nodes SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0 @@ -277,19 +273,13 @@ def get_edge_dir(self) -> Literal["in", "out"]: def get_node_ids( self, - rank: Optional[int] = None, - world_size: Optional[int] = None, - split: Optional[Union[Literal["train", "val", "test"], str]] = None, - node_type: Optional[NodeType] = None, + request: FetchNodesRequest, ) -> torch.Tensor: """Get the node ids from the dataset. Args: - rank: The rank of the process requesting node ids. Must be provided if world_size is provided. - world_size: The total number of processes in the distributed setup. Must be provided if rank is provided. - split: The split of the dataset to get node ids from. If provided, the dataset must have - `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. - node_type: The type of nodes to get node ids for. Must be provided if the dataset is heterogeneous. + request: The node-fetch request, including split, node type, + and either round-robin rank/world_size or a contiguous slice. Returns: The node ids. @@ -302,63 +292,40 @@ def get_node_ids( * If the node type is provided for a homogeneous dataset * If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided - Examples: - Suppose the dataset has 100 nodes total: train=[0..59], val=[60..79], test=[80..99]. 
- - Get all node ids (no split filtering): - - >>> server.get_node_ids() - tensor([0, 1, 2, ..., 99]) # All 100 nodes - - Get only training nodes: - - >>> server.get_node_ids(split="train") - tensor([0, 1, 2, ..., 59]) # 60 training nodes - - Shard all nodes across 4 processes (each gets ~25 nodes): - - >>> server.get_node_ids(rank=0, world_size=4) - tensor([0, 1, 2, ..., 24]) # First 25 of all 100 nodes - - Shard training nodes across 4 processes (each gets ~15 nodes): - - >>> server.get_node_ids(rank=0, world_size=4, split="train") - tensor([0, 1, 2, ..., 14]) # First 15 of the 60 training nodes - Note: When `split=None`, all nodes are queryable. This means nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split. """ - if (rank is None) ^ (world_size is None): - raise ValueError( - f"rank and world_size must be provided together. Received rank: {rank}, world_size: {world_size}" - ) - if split == "train": + request.validate() + if request.split == "train": nodes = self.dataset.train_node_ids - elif split == "val": + elif request.split == "val": nodes = self.dataset.val_node_ids - elif split == "test": + elif request.split == "test": nodes = self.dataset.test_node_ids - elif split is None: + elif request.split is None: nodes = self.dataset.node_ids else: raise ValueError( - f"Invalid split: {split}. Must be one of 'train', 'val', 'test', or None." + f"Invalid split: {request.split}. Must be one of 'train', 'val', 'test', or None." ) - if node_type is not None: + if request.node_type is not None: if not isinstance(nodes, abc.Mapping): raise ValueError( - f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] (e.g. a heterogeneous dataset), got {type(nodes)}" + f"node_type was provided as {request.node_type}, so node ids must be a dict[NodeType, torch.Tensor] " + f"(e.g. 
a heterogeneous dataset), got {type(nodes)}" ) - nodes = nodes[node_type] + nodes = nodes[request.node_type] elif not isinstance(nodes, torch.Tensor): raise ValueError( f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}." ) - if rank is not None and world_size is not None: - return shard_nodes_by_process(nodes, rank, world_size) + if request.server_slice is not None: + return request.server_slice.slice_tensor(nodes) + if request.rank is not None and request.world_size is not None: + return shard_nodes_by_process(nodes, request.rank, request.world_size) return nodes def get_edge_types(self) -> Optional[list[EdgeType]]: @@ -385,11 +352,7 @@ def get_node_types(self) -> Optional[list[NodeType]]: def get_ablp_input( self, - split: Union[Literal["train", "val", "test"], str], - rank: Optional[int] = None, - world_size: Optional[int] = None, - node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, - supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, + request: FetchABLPRequest, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Get the ABLP (Anchor Based Link Prediction) input for a specific rank in distributed processing. @@ -397,13 +360,9 @@ def get_ablp_input( e.g. if our compute cluster is of world size 4, and we have 2 storage nodes, then the world size this gets called with is 4, not 2. Args: - split: The split to get the training input for. - rank: The rank of the process requesting the training input. Defaults to None, in which case all nodes are returned. - Must be provided if world_size is provided. - world_size: The total number of processes in the distributed setup. Defaults to None, in which case all nodes are returned. - Must be provided if rank is provided. - node_type: The type of nodes to retrieve. Defaults to the default homogeneous node type. - supervision_edge_type: The edge type to use for the supervision. Defaults to the default homogeneous edge type. 
+ request: The ABLP fetch request, including split, node type, + supervision edge type, and either round-robin rank/world_size + or a contiguous slice. Returns: A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels. @@ -414,11 +373,18 @@ def get_ablp_input( Raises: ValueError: If the split is invalid. """ + request.validate() anchors = self.get_node_ids( - split=split, rank=rank, world_size=world_size, node_type=node_type + FetchNodesRequest( + split=request.split, + rank=request.rank, + world_size=request.world_size, + node_type=request.node_type, + server_slice=request.server_slice, + ) ) positive_label_edge_type, negative_label_edge_type = select_label_edge_types( - supervision_edge_type, self.dataset.get_edge_types() + request.supervision_edge_type, self.dataset.get_edge_types() ) positive_labels, negative_labels = get_labels_for_anchor_nodes( self.dataset, anchors, positive_label_edge_type, negative_label_edge_type diff --git a/gigl/distributed/graph_store/messages.py b/gigl/distributed/graph_store/messages.py new file mode 100644 index 000000000..07bce62e8 --- /dev/null +++ b/gigl/distributed/graph_store/messages.py @@ -0,0 +1,54 @@ +"""RPC request messages for graph-store fetch operations.""" + +from dataclasses import dataclass +from typing import Literal, Optional, Union + +from gigl.distributed.graph_store.sharding import ServerSlice +from gigl.src.common.types.graph_data import EdgeType, NodeType + + +@dataclass(frozen=True) +class FetchNodesRequest: + """Request for fetching node IDs from a storage server.""" + + rank: Optional[int] = None + world_size: Optional[int] = None + split: Optional[Union[Literal["train", "val", "test"], str]] = None + node_type: Optional[NodeType] = None + server_slice: Optional[ServerSlice] = None + + def validate(self) -> None: + """Validate that the request does not mix sharding modes.""" + if (self.rank is None) ^ (self.world_size is None): + raise ValueError( + "rank and world_size 
must be provided together. " + f"Received rank={self.rank}, world_size={self.world_size}" + ) + if self.server_slice is not None and ( + self.rank is not None or self.world_size is not None + ): + raise ValueError("server_slice cannot be combined with rank/world_size.") + + +@dataclass(frozen=True) +class FetchABLPRequest: + """Request for fetching ABLP input from a storage server.""" + + split: Union[Literal["train", "val", "test"], str] + node_type: NodeType + supervision_edge_type: EdgeType + rank: Optional[int] = None + world_size: Optional[int] = None + server_slice: Optional[ServerSlice] = None + + def validate(self) -> None: + """Validate that the request does not mix sharding modes.""" + if (self.rank is None) ^ (self.world_size is None): + raise ValueError( + "rank and world_size must be provided together. " + f"Received rank={self.rank}, world_size={self.world_size}" + ) + if self.server_slice is not None and ( + self.rank is not None or self.world_size is not None + ): + raise ValueError("server_slice cannot be combined with rank/world_size.") diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index 4c1465c47..425d608cb 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -6,6 +6,7 @@ from gigl.common.logger import Logger from gigl.distributed.graph_store.compute import async_request_server, request_server from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest from gigl.distributed.graph_store.sharding import ( ServerSlice, ShardStrategy, @@ -208,52 +209,46 @@ def _fetch_node_ids( """Fetches node ids from the storage nodes for the current compute node (machine).""" node_type = self._infer_node_type_if_homogeneous_with_label_edges(node_type) + # Build per-server requests + requests: dict[int, FetchNodesRequest] = {} if 
assignments is None: - logger.info( - f"Getting node ids for rank {rank} / {world_size} with node type {node_type} and split {split}" - ) - futures: list[torch.futures.Future[torch.Tensor]] = [] for server_rank in range(self.cluster_info.num_storage_nodes): - futures.append( - async_request_server( - server_rank, - DistServer.get_node_ids, - rank=rank, - world_size=world_size, - split=split, - node_type=node_type, - ) + requests[server_rank] = FetchNodesRequest( + rank=rank, + world_size=world_size, + split=split, + node_type=node_type, + ) + else: + for server_rank, server_slice in assignments.items(): + requests[server_rank] = FetchNodesRequest( + split=split, + node_type=node_type, + server_slice=server_slice, ) - node_ids = torch.futures.wait_all(futures) - return { - server_rank: node_ids for server_rank, node_ids in enumerate(node_ids) - } + strategy = "CONTIGUOUS" if assignments is not None else "ROUND_ROBIN" logger.info( - f"Getting node ids via CONTIGUOUS strategy for rank {rank} / {world_size} " + f"Fetching node ids via {strategy} for rank {rank} / {world_size} " f"with node type {node_type} and split {split}. 
" - f"Assigned servers: {list(assignments.keys())}" + f"Requesting from servers: {sorted(requests.keys())}" ) - assigned_futures: dict[int, torch.futures.Future[torch.Tensor]] = {} - for server_rank in assignments: - assigned_futures[server_rank] = async_request_server( - server_rank, - DistServer.get_node_ids, - rank=None, - world_size=None, - split=split, - node_type=node_type, + # Dispatch all futures + futures: dict[int, torch.futures.Future[torch.Tensor]] = { + server_rank: async_request_server( + server_rank, DistServer.get_node_ids, request ) + for server_rank, request in requests.items() + } - result: dict[int, torch.Tensor] = {} - for server_rank in range(self.cluster_info.num_storage_nodes): - if server_rank in assigned_futures: - all_nodes = assigned_futures[server_rank].wait() - result[server_rank] = assignments[server_rank].slice_tensor(all_nodes) - else: - result[server_rank] = torch.empty(0, dtype=torch.long) - return result + # Collect results, filling empty tensors for unrequested servers + return { + server_rank: futures[server_rank].wait() + if server_rank in futures + else torch.empty(0, dtype=torch.long) + for server_rank in range(self.cluster_info.num_storage_nodes) + } def fetch_node_ids( self, @@ -369,87 +364,63 @@ def _fetch_ablp_input( assignments: Optional[dict[int, ServerSlice]] = None, ) -> dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: """Fetches ABLP input from the storage nodes for the current compute node (machine).""" + # Build per-server requests + requests: dict[int, FetchABLPRequest] = {} if assignments is None: - futures: list[ - torch.futures.Future[ - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] - ] - ] = [] - logger.info( - f"Getting ABLP input for rank {rank} / {world_size} with node type {node_type}, " - f"split {split}, and supervision edge type {supervision_edge_type}" - ) - for server_rank in range(self.cluster_info.num_storage_nodes): - futures.append( - async_request_server( - 
server_rank, - DistServer.get_ablp_input, - split=split, - rank=rank, - world_size=world_size, - node_type=node_type, - supervision_edge_type=supervision_edge_type, - ) + requests[server_rank] = FetchABLPRequest( + split=split, + rank=rank, + world_size=world_size, + node_type=node_type, + supervision_edge_type=supervision_edge_type, + ) + else: + for server_rank, server_slice in assignments.items(): + requests[server_rank] = FetchABLPRequest( + split=split, + node_type=node_type, + supervision_edge_type=supervision_edge_type, + server_slice=server_slice, ) - ablp_inputs = torch.futures.wait_all(futures) - return { - server_rank: ablp_input - for server_rank, ablp_input in enumerate(ablp_inputs) - } + strategy = "CONTIGUOUS" if assignments is not None else "ROUND_ROBIN" logger.info( - f"Getting ABLP input via CONTIGUOUS strategy for rank {rank} / {world_size} " + f"Fetching ABLP input via {strategy} for rank {rank} / {world_size} " f"with node type {node_type}, split {split}, and " f"supervision edge type {supervision_edge_type}. 
" - f"Assigned servers: {list(assignments.keys())}" + f"Requesting from servers: {sorted(requests.keys())}" ) - assigned_futures: dict[ + # Dispatch all futures + futures: dict[ int, torch.futures.Future[ tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] ], - ] = {} - for server_rank in assignments: - assigned_futures[server_rank] = async_request_server( - server_rank, - DistServer.get_ablp_input, - split=split, - rank=None, - world_size=None, - node_type=node_type, - supervision_edge_type=supervision_edge_type, + ] = { + server_rank: async_request_server( + server_rank, DistServer.get_ablp_input, request ) + for server_rank, request in requests.items() + } - result: dict[ - int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] - ] = {} - for server_rank in range(self.cluster_info.num_storage_nodes): - if server_rank in assigned_futures: - anchors, positive_labels, negative_labels = assigned_futures[ - server_rank - ].wait() - server_slice = assignments[server_rank] - sliced_anchors = server_slice.slice_tensor(anchors) - sliced_positive = server_slice.slice_tensor(positive_labels) - sliced_negative = ( - server_slice.slice_tensor(negative_labels) - if negative_labels is not None - else None - ) - result[server_rank] = ( - sliced_anchors, - sliced_positive, - sliced_negative, - ) - else: - result[server_rank] = ( - torch.empty(0, dtype=torch.long), - torch.empty((0, 0), dtype=torch.long), - None, - ) - return result + def _empty_ablp_result() -> ( + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ): + return ( + torch.empty(0, dtype=torch.long), + torch.empty((0, 0), dtype=torch.long), + None, + ) + + # Collect results, filling empty tuples for unrequested servers + return { + server_rank: futures[server_rank].wait() + if server_rank in futures + else _empty_ablp_result() + for server_rank in range(self.cluster_info.num_storage_nodes) + } # TODO(#488) - support multiple supervision edge types def fetch_ablp_input( diff --git 
a/tests/unit/distributed/dist_server_test.py b/tests/unit/distributed/dist_server_test.py index cf5ba20d0..cb10feb44 100644 --- a/tests/unit/distributed/dist_server_test.py +++ b/tests/unit/distributed/dist_server_test.py @@ -2,6 +2,8 @@ from absl.testing import absltest from gigl.distributed.graph_store import dist_server +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest +from gigl.distributed.graph_store.sharding import ServerSlice from gigl.src.common.types.graph_data import Relation from tests.test_assets.distributed.test_dataset import ( DEFAULT_HETEROGENEOUS_EDGE_INDICES, @@ -83,7 +85,7 @@ def test_get_node_ids_with_homogeneous_dataset(self) -> None: server = dist_server.DistServer(dataset) # Test with world_size=1, rank=0 (should get all nodes) - node_ids = server.get_node_ids(rank=0, world_size=1, node_type=None) + node_ids = server.get_node_ids(FetchNodesRequest(rank=0, world_size=1)) self.assertIsInstance(node_ids, torch.Tensor) self.assertEqual(node_ids.shape[0], 10) self.assert_tensor_equality(node_ids, torch.arange(10)) @@ -96,13 +98,17 @@ def test_get_node_ids_with_heterogeneous_dataset(self) -> None: server = dist_server.DistServer(dataset) # Test with USER node type - user_node_ids = server.get_node_ids(rank=0, world_size=1, node_type=USER) + user_node_ids = server.get_node_ids( + FetchNodesRequest(rank=0, world_size=1, node_type=USER) + ) self.assertIsInstance(user_node_ids, torch.Tensor) self.assertEqual(user_node_ids.shape[0], 5) self.assert_tensor_equality(user_node_ids, torch.arange(5)) # Test with STORY node type - story_node_ids = server.get_node_ids(rank=0, world_size=1, node_type=STORY) + story_node_ids = server.get_node_ids( + FetchNodesRequest(rank=0, world_size=1, node_type=STORY) + ) self.assertIsInstance(story_node_ids, torch.Tensor) self.assertEqual(story_node_ids.shape[0], 5) self.assert_tensor_equality(story_node_ids, torch.arange(5)) @@ -115,17 +121,17 @@ def 
test_get_node_ids_with_multiple_ranks(self) -> None: server = dist_server.DistServer(dataset) # Test with world_size=2 - rank_0_nodes = server.get_node_ids(rank=0, world_size=2, node_type=None) - rank_1_nodes = server.get_node_ids(rank=1, world_size=2, node_type=None) + rank_0_nodes = server.get_node_ids(FetchNodesRequest(rank=0, world_size=2)) + rank_1_nodes = server.get_node_ids(FetchNodesRequest(rank=1, world_size=2)) # Verify each rank gets different nodes self.assert_tensor_equality(rank_0_nodes, torch.arange(5)) self.assert_tensor_equality(rank_1_nodes, torch.arange(5, 10)) # Test with world_size=3 (uneven split) - rank_0_nodes = server.get_node_ids(rank=0, world_size=3, node_type=None) - rank_1_nodes = server.get_node_ids(rank=1, world_size=3, node_type=None) - rank_2_nodes = server.get_node_ids(rank=2, world_size=3, node_type=None) + rank_0_nodes = server.get_node_ids(FetchNodesRequest(rank=0, world_size=3)) + rank_1_nodes = server.get_node_ids(FetchNodesRequest(rank=1, world_size=3)) + rank_2_nodes = server.get_node_ids(FetchNodesRequest(rank=2, world_size=3)) self.assert_tensor_equality(rank_0_nodes, torch.arange(3)) self.assert_tensor_equality(rank_1_nodes, torch.arange(3, 6)) @@ -139,10 +145,10 @@ def test_get_node_ids_rank_world_size_must_be_provided_together(self) -> None: server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - server.get_node_ids(rank=0, world_size=None) + server.get_node_ids(FetchNodesRequest(rank=0, world_size=None)) with self.assertRaises(ValueError): - server.get_node_ids(rank=None, world_size=1) + server.get_node_ids(FetchNodesRequest(rank=None, world_size=1)) def test_get_node_ids_with_homogeneous_dataset_and_node_type(self) -> None: """Test get_node_ids with a homogeneous dataset and a node type raises error.""" @@ -151,7 +157,7 @@ def test_get_node_ids_with_homogeneous_dataset_and_node_type(self) -> None: ) server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - 
server.get_node_ids(rank=0, world_size=1, node_type=USER) + server.get_node_ids(FetchNodesRequest(rank=0, world_size=1, node_type=USER)) def test_get_node_ids_with_heterogeneous_dataset_and_no_node_type( self, @@ -162,7 +168,7 @@ def test_get_node_ids_with_heterogeneous_dataset_and_no_node_type( ) server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - server.get_node_ids(rank=0, world_size=1, node_type=None) + server.get_node_ids(FetchNodesRequest(rank=0, world_size=1)) def test_get_node_ids_with_train_split(self) -> None: """Test get_node_ids returns only training nodes when split='train'.""" @@ -178,7 +184,9 @@ def test_get_node_ids_with_train_split(self) -> None: ) server = dist_server.DistServer(dataset) - train_nodes = server.get_node_ids(node_type=USER, split="train") + train_nodes = server.get_node_ids( + FetchNodesRequest(node_type=USER, split="train") + ) self.assert_tensor_equality(train_nodes, torch.tensor([0, 1, 2])) def test_get_node_ids_with_val_split(self) -> None: @@ -195,7 +203,7 @@ def test_get_node_ids_with_val_split(self) -> None: ) server = dist_server.DistServer(dataset) - val_nodes = server.get_node_ids(node_type=USER, split="val") + val_nodes = server.get_node_ids(FetchNodesRequest(node_type=USER, split="val")) self.assert_tensor_equality(val_nodes, torch.tensor([3])) def test_get_node_ids_with_test_split(self) -> None: @@ -212,7 +220,9 @@ def test_get_node_ids_with_test_split(self) -> None: ) server = dist_server.DistServer(dataset) - test_nodes = server.get_node_ids(node_type=USER, split="test") + test_nodes = server.get_node_ids( + FetchNodesRequest(node_type=USER, split="test") + ) self.assert_tensor_equality(test_nodes, torch.tensor([4])) def test_get_node_ids_with_split_and_sharding(self) -> None: @@ -231,15 +241,58 @@ def test_get_node_ids_with_split_and_sharding(self) -> None: # Train split has [0, 1, 2], shard across 2 ranks rank_0_nodes = server.get_node_ids( - rank=0, world_size=2, node_type=USER, 
split="train" + FetchNodesRequest(rank=0, world_size=2, node_type=USER, split="train") ) rank_1_nodes = server.get_node_ids( - rank=1, world_size=2, node_type=USER, split="train" + FetchNodesRequest(rank=1, world_size=2, node_type=USER, split="train") ) self.assert_tensor_equality(rank_0_nodes, torch.tensor([0])) self.assert_tensor_equality(rank_1_nodes, torch.tensor([1, 2])) + def test_get_node_ids_with_server_slice(self) -> None: + """Test get_node_ids supports contiguous server-side slicing.""" + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, + ) + server = dist_server.DistServer(dataset) + + sliced_nodes = server.get_node_ids( + FetchNodesRequest( + server_slice=ServerSlice( + server_rank=0, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ) + ) + ) + + self.assert_tensor_equality(sliced_nodes, torch.arange(5, 10)) + + def test_get_node_ids_rejects_mixed_sharding_modes(self) -> None: + """Test get_node_ids rejects rank/world_size combined with server_slice.""" + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, + ) + server = dist_server.DistServer(dataset) + + with self.assertRaises(ValueError): + server.get_node_ids( + FetchNodesRequest( + rank=0, + world_size=2, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ) + ) + def test_get_node_ids_invalid_split(self) -> None: """Test get_node_ids raises ValueError with invalid split.""" dataset = create_homogeneous_dataset( @@ -248,7 +301,7 @@ def test_get_node_ids_invalid_split(self) -> None: server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - server.get_node_ids(split="invalid") + server.get_node_ids(FetchNodesRequest(split="invalid")) def test_get_edge_dir(self) -> None: """Test get_edge_dir with a dataset.""" @@ -337,11 +390,13 @@ def test_get_ablp_input(self) -> None: for split, expected_user_ids in split_to_user_ids.items(): with 
self.subTest(split=split): anchor_nodes, pos_labels, neg_labels = server.get_ablp_input( - split=split, - rank=0, - world_size=1, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split=split, + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Verify anchor nodes match expected users @@ -394,20 +449,24 @@ def test_get_ablp_input_multiple_ranks(self) -> None: # Note that the rank and world size here are for the process group we're *fetching for*, not the process group we're *fetching from*. # e.g. if our compute cluster is of world size 4, and we have 2 storage nodes, then the world size this gets called with is 4, not 2. anchor_nodes_0, pos_labels_0, neg_labels_0 = server.get_ablp_input( - split="train", - rank=0, - world_size=2, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="train", + rank=0, + world_size=2, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Get training input for rank 1 of 2 anchor_nodes_1, pos_labels_1, neg_labels_1 = server.get_ablp_input( - split="train", - rank=1, - world_size=2, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="train", + rank=1, + world_size=2, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Train nodes [0, 1, 2, 3] should be split across ranks @@ -452,11 +511,13 @@ def test_get_ablp_input_invalid_split(self) -> None: with self.assertRaises(ValueError): server.get_ablp_input( - split="invalid", - rank=0, - world_size=1, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="invalid", + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) def test_get_ablp_input_without_negative_labels(self) -> None: @@ -483,11 +544,13 @@ def test_get_ablp_input_without_negative_labels(self) -> None: server = dist_server.DistServer(dataset) anchor_nodes, pos_labels, neg_labels = server.get_ablp_input( - 
split="train", - rank=0, - world_size=1, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="train", + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Verify train split returns the expected users @@ -500,6 +563,85 @@ def test_get_ablp_input_without_negative_labels(self) -> None: # Negative labels should be None self.assertIsNone(neg_labels) + def test_get_ablp_input_with_server_slice(self) -> None: + """Test get_ablp_input computes labels only for the server-sliced anchors.""" + create_test_process_group() + positive_labels = { + 0: [0, 1], + 1: [1, 2], + 2: [2, 3], + 3: [3, 4], + 4: [4, 0], + } + negative_labels = { + 0: [2], + 1: [3], + 2: [4], + 3: [0], + 4: [1], + } + + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels=positive_labels, + negative_labels=negative_labels, + train_node_ids=[0, 1, 2, 3], + val_node_ids=[4], + test_node_ids=[], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + server = dist_server.DistServer(dataset) + + anchor_nodes, pos_labels, neg_labels = server.get_ablp_input( + FetchABLPRequest( + split="train", + node_type=USER, + supervision_edge_type=USER_TO_STORY, + server_slice=ServerSlice( + server_rank=0, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + ) + ) + + self.assert_tensor_equality(anchor_nodes, torch.tensor([2, 3])) + self.assert_tensor_equality(pos_labels, torch.tensor([[2, 3], [3, 4]]), dim=1) + assert neg_labels is not None + self.assert_tensor_equality(neg_labels, torch.tensor([[4], [0]])) + + def test_get_ablp_input_rejects_mixed_sharding_modes(self) -> None: + """Test get_ablp_input rejects rank/world_size combined with server_slice.""" + create_test_process_group() + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels={0: [0], 1: [0], 2: [0]}, + negative_labels=None, + train_node_ids=[0], + val_node_ids=[1], + test_node_ids=[2], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + server = 
dist_server.DistServer(dataset) + + with self.assertRaises(ValueError): + server.get_ablp_input( + FetchABLPRequest( + split="train", + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=1, + end_num=1, + end_den=1, + ), + ) + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/distributed/graph_store/messages_test.py b/tests/unit/distributed/graph_store/messages_test.py new file mode 100644 index 000000000..d10a02133 --- /dev/null +++ b/tests/unit/distributed/graph_store/messages_test.py @@ -0,0 +1,124 @@ +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest +from gigl.distributed.graph_store.sharding import ServerSlice +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from tests.test_assets.test_case import TestCase + + +class TestFetchNodesRequest(TestCase): + def test_validate_accepts_rank_world_size(self) -> None: + request = FetchNodesRequest(rank=0, world_size=2) + request.validate() + + def test_validate_accepts_server_slice(self) -> None: + request = FetchNodesRequest( + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ) + ) + request.validate() + + def test_validate_rejects_partial_rank_world_size(self) -> None: + with self.assertRaises(ValueError): + FetchNodesRequest(rank=0).validate() + + with self.assertRaises(ValueError): + FetchNodesRequest(world_size=2).validate() + + def test_validate_rejects_mixed_sharding_modes(self) -> None: + with self.assertRaises(ValueError): + FetchNodesRequest( + rank=0, + world_size=2, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ).validate() + + +class TestFetchABLPRequest(TestCase): + def test_validate_accepts_rank_world_size(self) -> None: + request = FetchABLPRequest( + split="train", + rank=0, + world_size=2, + 
node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + ) + request.validate() + + def test_validate_accepts_server_slice(self) -> None: + request = FetchABLPRequest( + split="train", + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ) + request.validate() + + def test_validate_rejects_partial_rank_world_size(self) -> None: + with self.assertRaises(ValueError): + FetchABLPRequest( + split="train", + rank=0, + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + ).validate() + + with self.assertRaises(ValueError): + FetchABLPRequest( + split="train", + world_size=2, + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + ).validate() + + def test_validate_rejects_mixed_sharding_modes(self) -> None: + with self.assertRaises(ValueError): + FetchABLPRequest( + split="train", + rank=0, + world_size=2, + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ).validate() diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index d05a91947..31c3060af 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -11,8 +11,9 @@ import 
gigl.distributed.graph_store.dist_server as dist_server_module from gigl.common import LocalUri from gigl.distributed.graph_store.dist_server import DistServer, _call_func_on_server +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.graph_store.sharding import ShardStrategy +from gigl.distributed.graph_store.sharding import ServerSlice, ShardStrategy from gigl.env.distributed import GraphStoreInfo from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import ( @@ -709,20 +710,47 @@ class TestRemoteDistDatasetContiguous(RemoteDistDatasetTestBase): """Tests for fetch_node_ids and fetch_ablp_input with ShardStrategy.CONTIGUOUS.""" def _make_rank_aware_async_mock( - self, server_data: dict[int, Any] + self, + server_data: dict[int, Any], + captured_requests: Optional[list[tuple[int, Any]]] = None, ) -> Callable[..., torch.futures.Future]: """Create an async mock that returns pre-set data per server rank. Args: server_data: Maps server_rank to the value that server should return. Can be a tensor (for node ID tests) or a tuple (for ABLP tests). + captured_requests: Optional list populated with ``(server_rank, request)`` + tuples for later assertions. 
""" def _mock( server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any ) -> torch.futures.Future: + assert not kwargs + assert len(args) == 1 + request = args[0] + if captured_requests is not None: + captured_requests.append((server_rank, request)) + future: torch.futures.Future = torch.futures.Future() - future.set_result(server_data[server_rank]) + response = server_data[server_rank] + if isinstance(request, FetchNodesRequest): + if request.server_slice is not None: + assert isinstance(response, torch.Tensor) + response = request.server_slice.slice_tensor(response) + elif isinstance(request, FetchABLPRequest): + if request.server_slice is not None: + anchors, positive_labels, negative_labels = response + response = ( + request.server_slice.slice_tensor(anchors), + request.server_slice.slice_tensor(positive_labels), + ( + request.server_slice.slice_tensor(negative_labels) + if negative_labels is not None + else None + ), + ) + future.set_result(response) return future return _mock @@ -744,7 +772,8 @@ def test_fetch_node_ids_contiguous_even_split(self) -> None: 0: torch.arange(10), 1: torch.arange(10, 20), } - mock_fn = self._make_rank_aware_async_mock(server_data) + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) cluster_info = _create_mock_graph_store_info( num_storage_nodes=2, num_compute_nodes=2 ) @@ -758,13 +787,52 @@ def test_fetch_node_ids_contiguous_even_split(self) -> None: ) self.assert_tensor_equality(result[0], torch.arange(10)) self.assertEqual(result[1].numel(), 0) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) # Rank 1: empty from server 0, gets all of server 1 + captured_requests.clear() result = ds.fetch_node_ids( rank=1, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS ) 
self.assertEqual(result[0].numel(), 0) self.assert_tensor_equality(result[1], torch.arange(10, 20)) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) def test_fetch_node_ids_contiguous_fractional_split(self) -> None: """CONTIGUOUS with 3 storage nodes and 2 compute nodes: server 1 is fractionally split.""" @@ -773,7 +841,8 @@ def test_fetch_node_ids_contiguous_fractional_split(self) -> None: 1: torch.arange(10, 20), 2: torch.arange(20, 30), } - mock_fn = self._make_rank_aware_async_mock(server_data) + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) cluster_info = _create_mock_graph_store_info( num_storage_nodes=3, num_compute_nodes=2 ) @@ -788,14 +857,81 @@ def test_fetch_node_ids_contiguous_fractional_split(self) -> None: self.assert_tensor_equality(result[0], torch.arange(10)) self.assert_tensor_equality(result[1], torch.arange(10, 15)) self.assertEqual(result[2].numel(), 0) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 1, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ), + ), + ], + ) # Rank 1: nothing from server 0, second half of server 1, all of server 2 + captured_requests.clear() result = ds.fetch_node_ids( rank=1, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS ) self.assertEqual(result[0].numel(), 0) self.assert_tensor_equality(result[1], torch.arange(15, 20)) self.assert_tensor_equality(result[2], torch.arange(20, 30)) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchNodesRequest( + split=None, 
+ node_type=None, + server_slice=ServerSlice( + server_rank=1, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 2, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=2, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ], + ) def test_with_split_filtering(self) -> None: """CONTIGUOUS strategy with split='train' filtering.""" @@ -887,7 +1023,8 @@ def test_fetch_ablp_input_contiguous_even_split(self) -> None: torch.tensor([[14], [15], [16]]), ), } - mock_fn = self._make_rank_aware_async_mock(server_data) + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) cluster_info = _create_mock_graph_store_info( num_storage_nodes=2, num_compute_nodes=2 ) @@ -914,8 +1051,29 @@ def test_fetch_ablp_input_contiguous_even_split(self) -> None: pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] self.assertEqual(pos_1.numel(), 0) self.assertIsNone(neg_1) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) # Rank 1: empty from server 0, gets all of server 1 + captured_requests.clear() result = ds.fetch_ablp_input( split="train", rank=1, @@ -933,6 +1091,26 @@ def test_fetch_ablp_input_contiguous_even_split(self) -> None: ) assert neg is not None self.assert_tensor_equality(neg, torch.tensor([[14], [15], [16]])) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) def 
test_fetch_ablp_input_contiguous_fractional_split(self) -> None: """ABLP CONTIGUOUS with 3 storage nodes and 2 compute nodes: server 1 fractionally split.""" @@ -955,7 +1133,8 @@ def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: torch.tensor([[30], [31], [32], [33]]), ), } - mock_fn = self._make_rank_aware_async_mock(server_data) + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) cluster_info = _create_mock_graph_store_info( num_storage_nodes=3, num_compute_nodes=2 ) @@ -988,8 +1167,44 @@ def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: ablp_2 = result[2] self.assertEqual(ablp_2.anchor_nodes.numel(), 0) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 1, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ), + ), + ], + ) # Rank 1: nothing from server 0, second half of server 1, all of server 2 + captured_requests.clear() result = ds.fetch_ablp_input( split="train", rank=1, @@ -1016,6 +1231,41 @@ def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: ) assert neg_2 is not None self.assert_tensor_equality(neg_2, torch.tensor([[30], [31], [32], [33]])) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=1, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 2, + FetchABLPRequest( + split="train", + 
node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=2, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ], + ) def test_ablp_contiguous_requires_rank_and_world_size(self) -> None: """ABLP CONTIGUOUS without rank/world_size raises ValueError.""" From 597e22b3c62e4565171b5ed3632c8a90e798f83d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Mar 2026 18:02:27 +0000 Subject: [PATCH 13/13] update --- .../graph_store/remote_dist_dataset.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index 425d608cb..62e83a6bb 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -256,7 +256,7 @@ def fetch_node_ids( world_size: Optional[int] = None, split: Optional[Literal["train", "val", "test"]] = None, node_type: Optional[NodeType] = None, - shard_strategy: ShardStrategy = ShardStrategy.ROUND_ROBIN, + shard_strategy: ShardStrategy = ShardStrategy.CONTIGUOUS, ) -> dict[int, torch.Tensor]: """ Fetches node ids from the storage nodes for the current compute node (machine). @@ -276,11 +276,12 @@ def fetch_node_ids( If provided, the dataset must have `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. node_type (Optional[NodeType]): The type of nodes to get. Must be provided for heterogeneous datasets. + Must be None for labeled homogeneous graphs. shard_strategy (ShardStrategy): Strategy for sharding node IDs across compute nodes. - ``ROUND_ROBIN`` (default) shards each server's nodes across the - requested rank/world_size on the storage server. ``CONTIGUOUS`` - assigns storage servers to compute nodes, returning empty tensors + ``CONTIGUOUS`` (default) assigns storage servers to compute nodes, returning empty tensors for unassigned servers. 
+ ``ROUND_ROBIN`` shards each server's nodes across the + requested rank/world_size on the storage server. Raises: ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. @@ -430,7 +431,7 @@ def fetch_ablp_input( world_size: Optional[int] = None, anchor_node_type: Optional[NodeType] = None, supervision_edge_type: Optional[EdgeType] = None, - shard_strategy: ShardStrategy = ShardStrategy.ROUND_ROBIN, + shard_strategy: ShardStrategy = ShardStrategy.CONTIGUOUS, ) -> dict[int, ABLPInputNodes]: """Fetches ABLP (Anchor Based Link Prediction) input from the storage nodes. @@ -459,11 +460,12 @@ def fetch_ablp_input( Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs. Defaults to None. - shard_strategy (ShardStrategy): Strategy for sharding ABLP input across compute - nodes. ``ROUND_ROBIN`` (default) shards each server's data across the - requested rank/world_size on the storage server. ``CONTIGUOUS`` assigns - storage servers to compute nodes, producing empty tensors for unassigned - servers. + shard_strategy (ShardStrategy): + Strategy for sharding ABLP input across compute nodes. + ``CONTIGUOUS`` (default) assigns storage servers to compute nodes, + producing empty tensors for unassigned servers. + ``ROUND_ROBIN`` shards each server's data across the + requested rank/world_size on the storage server. Returns: dict[int, ABLPInputNodes]: