diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 498087794..945194c9f 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -699,7 +699,11 @@ def _setup_for_graph_store( # Extract supervision edge types and derive label edge types from the # ABLPInputNodes.labels dict (keyed by supervision edge type). self._supervision_edge_types = list(first_input.labels.keys()) - has_negatives = any(neg is not None for _, neg in first_input.labels.values()) + has_negatives = any( + negative_labels is not None + for ablp_input in input_nodes.values() + for _, negative_labels in ablp_input.labels.values() + ) self._positive_label_edge_types = [ message_passing_to_positive_label(et) for et in self._supervision_edge_types diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index da8db83e2..633d1db9f 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -34,16 +34,12 @@ from gigl.common.logger import Logger from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistSamplingProducer +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest from gigl.distributed.sampler import ABLPNodeSamplerInput from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.neighborloader import shard_nodes_by_process from gigl.src.common.types.graph_data import EdgeType, NodeType -from gigl.types.graph import ( - DEFAULT_HOMOGENEOUS_EDGE_TYPE, - DEFAULT_HOMOGENEOUS_NODE_TYPE, - FeatureInfo, - select_label_edge_types, -) +from gigl.types.graph import FeatureInfo, select_label_edge_types from gigl.utils.data_splitters import get_labels_for_anchor_nodes SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0 @@ -277,19 +273,13 @@ def get_edge_dir(self) -> Literal["in", "out"]: def get_node_ids( self, - rank: Optional[int] = None, - world_size: Optional[int] = None, - split: Optional[Union[Literal["train", "val", "test"], str]] = None, - node_type: Optional[NodeType] = None, + request: FetchNodesRequest, ) -> torch.Tensor: """Get the node ids from the dataset. Args: - rank: The rank of the process requesting node ids. Must be provided if world_size is provided. - world_size: The total number of processes in the distributed setup. Must be provided if rank is provided. - split: The split of the dataset to get node ids from. If provided, the dataset must have - `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. - node_type: The type of nodes to get node ids for. Must be provided if the dataset is heterogeneous. + request: The node-fetch request, including split, node type, + and either round-robin rank/world_size or a contiguous slice. Returns: The node ids. @@ -302,63 +292,40 @@ def get_node_ids( * If the node type is provided for a homogeneous dataset * If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided - Examples: - Suppose the dataset has 100 nodes total: train=[0..59], val=[60..79], test=[80..99]. - - Get all node ids (no split filtering): - - >>> server.get_node_ids() - tensor([0, 1, 2, ..., 99]) # All 100 nodes - - Get only training nodes: - - >>> server.get_node_ids(split="train") - tensor([0, 1, 2, ..., 59]) # 60 training nodes - - Shard all nodes across 4 processes (each gets ~25 nodes): - - >>> server.get_node_ids(rank=0, world_size=4) - tensor([0, 1, 2, ..., 24]) # First 25 of all 100 nodes - - Shard training nodes across 4 processes (each gets ~15 nodes): - - >>> server.get_node_ids(rank=0, world_size=4, split="train") - tensor([0, 1, 2, ..., 14]) # First 15 of the 60 training nodes - Note: When `split=None`, all nodes are queryable. This means nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split. """ - if (rank is None) ^ (world_size is None): - raise ValueError( - f"rank and world_size must be provided together. Received rank: {rank}, world_size: {world_size}" - ) - if split == "train": + request.validate() + if request.split == "train": nodes = self.dataset.train_node_ids - elif split == "val": + elif request.split == "val": nodes = self.dataset.val_node_ids - elif split == "test": + elif request.split == "test": nodes = self.dataset.test_node_ids - elif split is None: + elif request.split is None: nodes = self.dataset.node_ids else: raise ValueError( - f"Invalid split: {split}. Must be one of 'train', 'val', 'test', or None." + f"Invalid split: {request.split}. Must be one of 'train', 'val', 'test', or None." ) - if node_type is not None: + if request.node_type is not None: if not isinstance(nodes, abc.Mapping): raise ValueError( - f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] (e.g. a heterogeneous dataset), got {type(nodes)}" + f"node_type was provided as {request.node_type}, so node ids must be a dict[NodeType, torch.Tensor] " + f"(e.g. a heterogeneous dataset), got {type(nodes)}" ) - nodes = nodes[node_type] + nodes = nodes[request.node_type] elif not isinstance(nodes, torch.Tensor): raise ValueError( f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}." ) - if rank is not None and world_size is not None: - return shard_nodes_by_process(nodes, rank, world_size) + if request.server_slice is not None: + return request.server_slice.slice_tensor(nodes) + if request.rank is not None and request.world_size is not None: + return shard_nodes_by_process(nodes, request.rank, request.world_size) return nodes def get_edge_types(self) -> Optional[list[EdgeType]]: @@ -385,11 +352,7 @@ def get_node_types(self) -> Optional[list[NodeType]]: def get_ablp_input( self, - split: Union[Literal["train", "val", "test"], str], - rank: Optional[int] = None, - world_size: Optional[int] = None, - node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, - supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, + request: FetchABLPRequest, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Get the ABLP (Anchor Based Link Prediction) input for a specific rank in distributed processing. @@ -397,13 +360,9 @@ def get_ablp_input( e.g. if our compute cluster is of world size 4, and we have 2 storage nodes, then the world size this gets called with is 4, not 2. Args: - split: The split to get the training input for. - rank: The rank of the process requesting the training input. Defaults to None, in which case all nodes are returned. - Must be provided if world_size is provided. - world_size: The total number of processes in the distributed setup. Defaults to None, in which case all nodes are returned. - Must be provided if rank is provided. - node_type: The type of nodes to retrieve. Defaults to the default homogeneous node type. - supervision_edge_type: The edge type to use for the supervision. Defaults to the default homogeneous edge type. + request: The ABLP fetch request, including split, node type, + supervision edge type, and either round-robin rank/world_size + or a contiguous slice. Returns: A tuple containing the anchor nodes for the rank, the positive labels, and the negative labels. @@ -414,11 +373,18 @@ def get_ablp_input( Raises: ValueError: If the split is invalid. """ + request.validate() anchors = self.get_node_ids( - split=split, rank=rank, world_size=world_size, node_type=node_type + FetchNodesRequest( + split=request.split, + rank=request.rank, + world_size=request.world_size, + node_type=request.node_type, + server_slice=request.server_slice, + ) ) positive_label_edge_type, negative_label_edge_type = select_label_edge_types( - supervision_edge_type, self.dataset.get_edge_types() + request.supervision_edge_type, self.dataset.get_edge_types() ) positive_labels, negative_labels = get_labels_for_anchor_nodes( self.dataset, anchors, positive_label_edge_type, negative_label_edge_type diff --git a/gigl/distributed/graph_store/messages.py b/gigl/distributed/graph_store/messages.py new file mode 100644 index 000000000..07bce62e8 --- /dev/null +++ b/gigl/distributed/graph_store/messages.py @@ -0,0 +1,54 @@ +"""RPC request messages for graph-store fetch operations.""" + +from dataclasses import dataclass +from typing import Literal, Optional, Union + +from gigl.distributed.graph_store.sharding import ServerSlice +from gigl.src.common.types.graph_data import EdgeType, NodeType + + +@dataclass(frozen=True) +class FetchNodesRequest: + """Request for fetching node IDs from a storage server.""" + + rank: Optional[int] = None + world_size: Optional[int] = None + split: Optional[Union[Literal["train", "val", "test"], str]] = None + node_type: Optional[NodeType] = None + server_slice: Optional[ServerSlice] = None + + def validate(self) -> None: + """Validate that the request does not mix sharding modes.""" + if (self.rank is None) ^ (self.world_size is None): + raise ValueError( + "rank and world_size must be provided together. " + f"Received rank={self.rank}, world_size={self.world_size}" + ) + if self.server_slice is not None and ( + self.rank is not None or self.world_size is not None + ): + raise ValueError("server_slice cannot be combined with rank/world_size.") + + +@dataclass(frozen=True) +class FetchABLPRequest: + """Request for fetching ABLP input from a storage server.""" + + split: Union[Literal["train", "val", "test"], str] + node_type: NodeType + supervision_edge_type: EdgeType + rank: Optional[int] = None + world_size: Optional[int] = None + server_slice: Optional[ServerSlice] = None + + def validate(self) -> None: + """Validate that the request does not mix sharding modes.""" + if (self.rank is None) ^ (self.world_size is None): + raise ValueError( + "rank and world_size must be provided together. " + f"Received rank={self.rank}, world_size={self.world_size}" + ) + if self.server_slice is not None and ( + self.rank is not None or self.world_size is not None + ): + raise ValueError("server_slice cannot be combined with rank/world_size.") diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index 8a5689f4c..62e83a6bb 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -6,6 +6,12 @@ from gigl.common.logger import Logger from gigl.distributed.graph_store.compute import async_request_server, request_server from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest +from gigl.distributed.graph_store.sharding import ( + ServerSlice, + ShardStrategy, + compute_server_assignments, +) from gigl.distributed.utils.networking import get_free_ports from gigl.env.distributed import GraphStoreInfo from gigl.src.common.types.graph_data import EdgeType, NodeType @@ -19,6 +25,22 @@ logger = Logger() +def _validate_contiguous_args( + rank: Optional[int], + world_size: Optional[int], + shard_strategy: ShardStrategy, +) -> None: + """Validate contiguous sharding inputs; no-op for round-robin.""" + if shard_strategy != ShardStrategy.CONTIGUOUS: + return + + if rank is None or world_size is None: + raise ValueError( + "Both rank and world_size must be provided when using " + f"ShardStrategy.CONTIGUOUS. Got rank={rank}, world_size={world_size}" + ) + + class RemoteDistDataset: def __init__( self, @@ -159,34 +181,74 @@ def _infer_edge_type_if_homogeneous_with_label_edges( ) return edge_type + def _compute_assignments_if_needed( + self, + rank: Optional[int], + world_size: Optional[int], + shard_strategy: ShardStrategy, + ) -> Optional[dict[int, ServerSlice]]: + """Compute contiguous server assignments when that shard strategy is requested.""" + if shard_strategy != ShardStrategy.CONTIGUOUS: + return None + + assert rank is not None and world_size is not None + return compute_server_assignments( + num_servers=self.cluster_info.num_storage_nodes, + num_compute_nodes=world_size, + compute_rank=rank, + ) + def _fetch_node_ids( self, rank: Optional[int] = None, world_size: Optional[int] = None, node_type: Optional[NodeType] = None, split: Optional[Literal["train", "val", "test"]] = None, + assignments: Optional[dict[int, ServerSlice]] = None, ) -> dict[int, torch.Tensor]: """Fetches node ids from the storage nodes for the current compute node (machine).""" - futures: list[torch.futures.Future[torch.Tensor]] = [] node_type = self._infer_node_type_if_homogeneous_with_label_edges(node_type) - logger.info( - f"Getting node ids for rank {rank} / {world_size} with node type {node_type} and split {split}" - ) - - for server_rank in range(self.cluster_info.num_storage_nodes): - futures.append( - async_request_server( - server_rank, - DistServer.get_node_ids, + # Build per-server requests + requests: dict[int, FetchNodesRequest] = {} + if assignments is None: + for server_rank in range(self.cluster_info.num_storage_nodes): + requests[server_rank] = FetchNodesRequest( rank=rank, world_size=world_size, split=split, node_type=node_type, ) + else: + for server_rank, server_slice in assignments.items(): + requests[server_rank] = FetchNodesRequest( + split=split, + node_type=node_type, + server_slice=server_slice, + ) + + strategy = "CONTIGUOUS" if assignments is not None else "ROUND_ROBIN" + logger.info( + f"Fetching node ids via {strategy} for rank {rank} / {world_size} " + f"with node type {node_type} and split {split}. " + f"Requesting from servers: {sorted(requests.keys())}" + ) + + # Dispatch all futures + futures: dict[int, torch.futures.Future[torch.Tensor]] = { + server_rank: async_request_server( + server_rank, DistServer.get_node_ids, request ) - node_ids = torch.futures.wait_all(futures) - return {server_rank: node_ids for server_rank, node_ids in enumerate(node_ids)} + for server_rank, request in requests.items() + } + + # Collect results, filling empty tensors for unrequested servers + return { + server_rank: futures[server_rank].wait() + if server_rank in futures + else torch.empty(0, dtype=torch.long) + for server_rank in range(self.cluster_info.num_storage_nodes) + } def fetch_node_ids( self, @@ -194,6 +256,7 @@ def fetch_node_ids( world_size: Optional[int] = None, split: Optional[Literal["train", "val", "test"]] = None, node_type: Optional[NodeType] = None, + shard_strategy: ShardStrategy = ShardStrategy.CONTIGUOUS, ) -> dict[int, torch.Tensor]: """ Fetches node ids from the storage nodes for the current compute node (machine). @@ -202,64 +265,57 @@ def fetch_node_ids( filtered and sharded according to the provided arguments. Args: - rank (Optional[int]): The rank of the process requesting node ids. Must be provided if world_size is provided. - world_size (Optional[int]): The total number of processes in the distributed setup. Must be provided if rank is provided. + rank (Optional[int]): The requested shard rank. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + expects the compute-node rank. + world_size (Optional[int]): The requested shard world size. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + expects the compute-node world size. split (Optional[Literal["train", "val", "test"]]): The split of the dataset to get node ids from. If provided, the dataset must have `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. node_type (Optional[NodeType]): The type of nodes to get. Must be provided for heterogeneous datasets. + Must be None for labeled homogeneous graphs. + shard_strategy (ShardStrategy): Strategy for sharding node IDs across compute nodes. + ``CONTIGUOUS`` (default) assigns storage servers to compute nodes, returning empty tensors + for unassigned servers. + ``ROUND_ROBIN`` shards each server's nodes across the + requested rank/world_size on the storage server. + + Raises: + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. Returns: dict[int, torch.Tensor]: A dict mapping storage rank to node ids. Examples: - Suppose we have 2 storage nodes and 2 compute nodes, with 16 total nodes. - Nodes are partitioned across storage nodes, with splits defined as: - - Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7] - train=[0, 1, 2, 3], val=[4, 5], test=[6, 7] - Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15] - train=[8, 9, 10, 11], val=[12, 13], test=[14, 15] - - Get all nodes (no split filtering, no sharding): - - >>> dataset.fetch_node_ids() - { - 0: tensor([0, 1, 2, 3, 4, 5, 6, 7]), # All 8 nodes from storage rank 0 - 1: tensor([8, 9, 10, 11, 12, 13, 14, 15]) # All 8 nodes from storage rank 1 - } - - Shard all nodes across 2 compute nodes (compute rank 0 gets first half from each storage): - - >>> dataset.fetch_node_ids(rank=0, world_size=2) - { - 0: tensor([0, 1, 2, 3]), # First 4 of all 8 nodes from storage rank 0 - 1: tensor([8, 9, 10, 11]) # First 4 of all 8 nodes from storage rank 1 - } - - Get only training nodes (no sharding): - - >>> dataset.fetch_node_ids(split="train") - { - 0: tensor([0, 1, 2, 3]), # 4 training nodes from storage rank 0 - 1: tensor([8, 9, 10, 11]) # 4 training nodes from storage rank 1 - } - - Combine split and sharding (training nodes, sharded for compute rank 0): - - >>> dataset.fetch_node_ids(rank=0, world_size=2, split="train") - { - 0: tensor([0, 1]), # First 2 of 4 training nodes from storage rank 0 - 1: tensor([8, 9]) # First 2 of 4 training nodes from storage rank 1 - } + See :class:`~gigl.distributed.graph_store.sharding.ShardStrategy` for + concrete examples of how each strategy distributes node IDs across + compute nodes. Note: When `split=None`, all nodes are queryable. This means nodes from any split (train, val, or test) may be returned. This is useful when you need to sample neighbors during inference, as neighbor nodes may belong to any split. """ - return self._fetch_node_ids(rank, world_size, node_type, split) + _validate_contiguous_args( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) + assignments = self._compute_assignments_if_needed( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) + return self._fetch_node_ids( + rank=rank, + world_size=world_size, + node_type=node_type, + split=split, + assignments=assignments, + ) def fetch_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: """ @@ -306,34 +362,65 @@ def _fetch_ablp_input( world_size: Optional[int] = None, node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, + assignments: Optional[dict[int, ServerSlice]] = None, ) -> dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: """Fetches ABLP input from the storage nodes for the current compute node (machine).""" - futures: list[ - torch.futures.Future[ - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] - ] - ] = [] - logger.info( - f"Getting ABLP input for rank {rank} / {world_size} with node type {node_type}, " - f"split {split}, and supervision edge type {supervision_edge_type}" - ) - - for server_rank in range(self.cluster_info.num_storage_nodes): - futures.append( - async_request_server( - server_rank, - DistServer.get_ablp_input, + # Build per-server requests + requests: dict[int, FetchABLPRequest] = {} + if assignments is None: + for server_rank in range(self.cluster_info.num_storage_nodes): + requests[server_rank] = FetchABLPRequest( split=split, rank=rank, world_size=world_size, node_type=node_type, supervision_edge_type=supervision_edge_type, ) + else: + for server_rank, server_slice in assignments.items(): + requests[server_rank] = FetchABLPRequest( + split=split, + node_type=node_type, + supervision_edge_type=supervision_edge_type, + server_slice=server_slice, + ) + + strategy = "CONTIGUOUS" if assignments is not None else "ROUND_ROBIN" + logger.info( + f"Fetching ABLP input via {strategy} for rank {rank} / {world_size} " + f"with node type {node_type}, split {split}, and " + f"supervision edge type {supervision_edge_type}. " + f"Requesting from servers: {sorted(requests.keys())}" + ) + + # Dispatch all futures + futures: dict[ + int, + torch.futures.Future[ + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ], + ] = { + server_rank: async_request_server( + server_rank, DistServer.get_ablp_input, request ) - ablp_inputs = torch.futures.wait_all(futures) + for server_rank, request in requests.items() + } + + def _empty_ablp_result() -> ( + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ): + return ( + torch.empty(0, dtype=torch.long), + torch.empty((0, 0), dtype=torch.long), + None, + ) + + # Collect results, filling empty tuples for unrequested servers return { - server_rank: ablp_input - for server_rank, ablp_input in enumerate(ablp_inputs) + server_rank: futures[server_rank].wait() + if server_rank in futures + else _empty_ablp_result() + for server_rank in range(self.cluster_info.num_storage_nodes) } # TODO(#488) - support multiple supervision edge types @@ -344,9 +431,9 @@ def fetch_ablp_input( world_size: Optional[int] = None, anchor_node_type: Optional[NodeType] = None, supervision_edge_type: Optional[EdgeType] = None, + shard_strategy: ShardStrategy = ShardStrategy.CONTIGUOUS, ) -> dict[int, ABLPInputNodes]: - """ - Fetches ABLP (Anchor Based Link Prediction) input from the storage nodes. + """Fetches ABLP (Anchor Based Link Prediction) input from the storage nodes. The returned dict maps storage rank to an :class:`ABLPInputNodes` dataclass for that storage node. If (rank, world_size) is provided, the input will be @@ -359,10 +446,12 @@ def fetch_ablp_input( Args: split (Literal["train", "val", "test"]): The split to get the input for. - rank (Optional[int]): The rank of the process requesting the input. - Must be provided if world_size is provided. - world_size (Optional[int]): The total number of processes in the distributed setup. - Must be provided if rank is provided. + rank (Optional[int]): The requested shard rank. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + expects the compute-node rank. + world_size (Optional[int]): The requested shard world size. + ``ROUND_ROBIN`` forwards this to the storage server. ``CONTIGUOUS`` + expects the compute-node world size. anchor_node_type (Optional[NodeType]): The type of the anchor nodes to retrieve. Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs. @@ -371,35 +460,35 @@ def fetch_ablp_input( Must be provided for heterogeneous graphs. Must be None for labeled homogeneous graphs. Defaults to None. + shard_strategy (ShardStrategy): + Strategy for sharding ABLP input across compute nodes. + ``CONTIGUOUS`` (default) assigns storage servers to compute nodes, + producing empty tensors for unassigned servers. + ``ROUND_ROBIN`` shards each server's data across the + requested rank/world_size on the storage server. Returns: dict[int, ABLPInputNodes]: A dict mapping storage rank to an ABLPInputNodes containing: - - anchor_node_type: The node type of the anchor nodes, or None for labeled homogeneous. + - anchor_node_type: The node type of the anchor nodes, or ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` for labeled homogeneous. - anchor_nodes: 1D tensor of anchor node IDs for the split. - positive_labels: Dict mapping positive label EdgeType to a 2D tensor [N, M]. - negative_labels: Optional dict mapping negative label EdgeType to a 2D tensor [N, M]. - Examples: - Suppose we have 1 storage node with users [0, 1, 2, 3, 4] where: - train=[0, 1, 2], val=[3], test=[4] - And positive/negative labels defined for link prediction. - - Get training ABLP input (heterogeneous): - - >>> dataset.fetch_ablp_input(split="train", node_type=USER, supervision_edge_type=USER_TO_ITEM) - { - 0: ABLPInputNodes( - anchor_nodes=tensor([0, 1, 2]), - positive_labels={("user", "to_positive", "item"): tensor([[0, 1], [1, 2], [2, 3]])}, - anchor_node_type="user", - negative_labels={("user", "to_negative", "item"): tensor([[2], [3], [4]])}, - ) - } + Raises: + ValueError: If ``shard_strategy`` is ``CONTIGUOUS`` but ``rank`` or ``world_size`` is None. - For labeled homogeneous graphs, anchor_node_type will be DEFAULT_HOMOGENEOUS_NODE_TYPE. + Examples: + See :class:`~gigl.distributed.graph_store.sharding.ShardStrategy` for + concrete examples of how each strategy distributes data across + compute nodes. """ + _validate_contiguous_args( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) if (anchor_node_type is None) != (supervision_edge_type is None): raise ValueError( @@ -416,12 +505,18 @@ def fetch_ablp_input( evaluated_supervision_edge_type = supervision_edge_type del anchor_node_type, supervision_edge_type + assignments = self._compute_assignments_if_needed( + rank=rank, + world_size=world_size, + shard_strategy=shard_strategy, + ) raw_inputs = self._fetch_ablp_input( split=split, rank=rank, world_size=world_size, node_type=evaluated_anchor_node_type, supervision_edge_type=evaluated_supervision_edge_type, + assignments=assignments, ) return { server_rank: ABLPInputNodes( diff --git a/gigl/distributed/graph_store/sharding.py b/gigl/distributed/graph_store/sharding.py new file mode 100644 index 000000000..6b515ad5d --- /dev/null +++ b/gigl/distributed/graph_store/sharding.py @@ -0,0 +1,148 @@ +"""Graph-store-specific sharding helpers.""" + +from dataclasses import dataclass +from enum import Enum + +import torch + + +class ShardStrategy(Enum): + """Strategies for splitting remote graph-store inputs across compute nodes. + + When fetching node IDs (or ABLP input) from storage servers, the shard + strategy controls how data is divided among compute nodes. + + Suppose we have 2 storage nodes and 2 compute nodes, with 16 total nodes. + Nodes are partitioned across storage nodes, with splits defined as:: + + Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7] + train=[0, 1, 2, 3], val=[4, 5], test=[6, 7] + Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15] + train=[8, 9, 10, 11], val=[12, 13], test=[14, 15] + + **ROUND_ROBIN** — Each storage server independently shards its own nodes + across the requested ``world_size``. Every compute node contacts every + storage server and receives an interleaved slice:: + + # All training nodes, sharded across 2 compute nodes (round-robin): + >>> dataset.fetch_node_ids(rank=0, world_size=2, split="train") + { + 0: tensor([0, 2]), # Even-indexed training nodes from storage 0 + 1: tensor([8, 10]) # Even-indexed training nodes from storage 1 + } + >>> dataset.fetch_node_ids(rank=1, world_size=2, split="train") + { + 0: tensor([1, 3]), # Odd-indexed training nodes from storage 0 + 1: tensor([9, 11]) # Odd-indexed training nodes from storage 1 + } + + Both strategies also support ``split=None``, which returns all nodes + (train + val + test) from each storage server:: + + # Round-robin, all nodes (no split), sharded across 2 compute nodes: + >>> dataset.fetch_node_ids(rank=0, world_size=2) + { + 0: tensor([0, 2, 4, 6]), # Even-indexed nodes from storage 0 + 1: tensor([8, 10, 12, 14]) # Even-indexed nodes from storage 1 + } + + # Contiguous, all nodes (no split), sharded across 2 compute nodes: + >>> dataset.fetch_node_ids(rank=0, world_size=2, + ... shard_strategy=ShardStrategy.CONTIGUOUS) + { + 0: tensor([0, 1, 2, 3, 4, 5, 6, 7]), # All nodes from storage 0 + 1: tensor([]) # Nothing from storage 1 + } + + **CONTIGUOUS** — Storage servers are assigned to compute nodes in + contiguous blocks. Each compute node fetches *all* data from its + assigned server(s) and receives empty tensors for unassigned ones. + When servers outnumber compute nodes a server's data is fractionally + split; when compute nodes outnumber servers a compute node may own a + fraction of one server:: + + # All training nodes, sharded across 2 compute nodes (contiguous): + >>> dataset.fetch_node_ids(rank=0, world_size=2, split="train", + ... shard_strategy=ShardStrategy.CONTIGUOUS) + { + 0: tensor([0, 1, 2, 3]), # All training nodes from storage 0 + 1: tensor([]) # Nothing from storage 1 + } + >>> dataset.fetch_node_ids(rank=1, world_size=2, split="train", + ... shard_strategy=ShardStrategy.CONTIGUOUS) + { + 0: tensor([]), # Nothing from storage 0 + 1: tensor([8, 9, 10, 11]) # All training nodes from storage 1 + } + + # With 3 storage nodes and 2 compute nodes, server 1 is fractionally split: + >>> dataset.fetch_node_ids(rank=0, world_size=2, split="train", + ... shard_strategy=ShardStrategy.CONTIGUOUS) + { + 0: tensor([0, 1, 2, 3]), # All of storage 0 + 1: tensor([8, 9]), # First half of storage 1 + 2: tensor([]) # Nothing from storage 2 + } + """ + + ROUND_ROBIN = "round_robin" + CONTIGUOUS = "contiguous" + + +@dataclass(frozen=True) +class ServerSlice: + """The fraction of a storage server owned by one compute node.""" + + server_rank: int + start_num: int + start_den: int + end_num: int + end_den: int + + def slice_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + """Slice a tensor according to this server assignment.""" + total = len(tensor) + start_idx = total * self.start_num // self.start_den + end_idx = total * self.end_num // self.end_den + if start_idx == 0 and end_idx == total: + return tensor + return tensor[start_idx:end_idx] + + +def compute_server_assignments( + num_servers: int, + num_compute_nodes: int, + compute_rank: int, +) -> dict[int, ServerSlice]: + """Compute which servers, and which fractions of them, belong to one compute node.""" + if num_servers <= 0: + raise ValueError(f"num_servers must be positive, got {num_servers}") + if num_compute_nodes <= 0: + raise ValueError(f"num_compute_nodes must be positive, got {num_compute_nodes}") + if compute_rank < 0 or compute_rank >= num_compute_nodes: + raise ValueError( + f"compute_rank must be in [0, {num_compute_nodes}), got {compute_rank}" + ) + + seg_start = compute_rank * num_servers + seg_end = (compute_rank + 1) * num_servers + + assignments: dict[int, ServerSlice] = {} + for server_rank in range(num_servers): + server_start = server_rank * num_compute_nodes + server_end = (server_rank + 1) * num_compute_nodes + + overlap_start = max(seg_start, server_start) + overlap_end = min(seg_end, server_end) + if overlap_start >= overlap_end: + continue + + assignments[server_rank] = ServerSlice( + server_rank=server_rank, + start_num=overlap_start - server_start, + start_den=num_compute_nodes, + end_num=overlap_end - server_start, + end_den=num_compute_nodes, + ) + + return assignments diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 29ce30477..a5ed5753b 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -23,6 +23,7 @@ shutdown_compute_proccess, ) from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.graph_store.sharding import ShardStrategy from gigl.distributed.graph_store.storage_utils import ( build_storage_dataset, run_storage_server, @@ -239,6 +240,56 @@ def _run_compute_train_tests( count_tensor.item() == expected_batches ), f"Expected {expected_batches} total batches, got {count_tensor.item()}" + # --- CONTIGUOUS shard strategy tests --- + # With 2 servers and 2 compute nodes, rank R should get all of server R's + # nodes and an empty tensor for server (1-R). + contiguous_node_ids = remote_dist_dataset.fetch_node_ids( + split="train", + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + + # Assert structure: each rank owns exactly one server in the 2S/2C case + rank = cluster_info.compute_node_rank + other_rank = 1 - rank + assert ( + contiguous_node_ids[rank].numel() > 0 + ), f"Rank {rank} should have non-empty tensor for its own server" + assert ( + contiguous_node_ids[other_rank].numel() == 0 + ), f"Rank {rank} should have empty tensor for server {other_rank}" + + # Assert total node parity: CONTIGUOUS and ROUND_ROBIN should cover the same nodes + local_contiguous_nodes = torch.cat(list(contiguous_node_ids.values())) + local_round_robin_nodes = torch.cat(list(random_negative_input.values())) + + # Gather all nodes from all ranks + contiguous_gathered: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.long) + for _ in range(torch.distributed.get_world_size()) + ] + round_robin_gathered: list[torch.Tensor] = [ + torch.empty(0, dtype=torch.long) + for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather_object(contiguous_gathered, local_contiguous_nodes) + torch.distributed.all_gather_object(round_robin_gathered, local_round_robin_nodes) + + all_contiguous = torch.cat(contiguous_gathered).sort().values + all_round_robin = torch.cat(round_robin_gathered).sort().values + assert torch.equal(all_contiguous, all_round_robin), ( + f"CONTIGUOUS and ROUND_ROBIN must produce the same sorted node set. " + f"CONTIGUOUS: {all_contiguous[:10]}... ({all_contiguous.numel()} nodes), " + f"ROUND_ROBIN: {all_round_robin[:10]}... ({all_round_robin.numel()} nodes)" + ) + + torch.distributed.barrier() + logger.info( + f"Rank {torch.distributed.get_rank()} CONTIGUOUS: " + f"{local_contiguous_nodes.numel()} nodes from assigned server" + ) + shutdown_compute_proccess() @@ -316,9 +367,6 @@ def _run_compute_multiple_loaders_test( prefetch_size=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} ablp_loader_1 producers: ({ablp_loader_1._producer_id_list})" - ) ablp_loader_2 = DistABLPLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], @@ -329,9 +377,6 @@ def _run_compute_multiple_loaders_test( prefetch_size=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} ablp_loader_2 producers: ({ablp_loader_2._producer_id_list})" - ) neighbor_loader_1 = DistNeighborLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], @@ -341,9 +386,6 @@ def _run_compute_multiple_loaders_test( worker_concurrency=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} neighbor_loader_1 producers: ({neighbor_loader_1._producer_id_list})" - ) neighbor_loader_2 = DistNeighborLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], @@ -353,12 +395,6 @@ def _run_compute_multiple_loaders_test( worker_concurrency=2, batch_size=batch_size, ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} neighbor_loader_2 producers: ({neighbor_loader_2._producer_id_list})" - ) - logger.info( - f"Rank {torch.distributed.get_rank()} / {torch.distributed.get_world_size()} phase 1: loading batches from 4 parallel loaders" - ) torch.distributed.barrier() for ablp_batch_1, ablp_batch_2, neg_batch_1, neg_batch_2 in zip( ablp_loader_1, ablp_loader_2, neighbor_loader_1, neighbor_loader_2 diff --git a/tests/unit/distributed/dist_server_test.py b/tests/unit/distributed/dist_server_test.py index cf5ba20d0..cb10feb44 100644 --- a/tests/unit/distributed/dist_server_test.py +++ b/tests/unit/distributed/dist_server_test.py @@ -2,6 +2,8 @@ from absl.testing import absltest from gigl.distributed.graph_store import dist_server +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest +from gigl.distributed.graph_store.sharding import ServerSlice from gigl.src.common.types.graph_data import Relation from tests.test_assets.distributed.test_dataset import ( DEFAULT_HETEROGENEOUS_EDGE_INDICES, @@ -83,7 +85,7 @@ def test_get_node_ids_with_homogeneous_dataset(self) -> None: server = dist_server.DistServer(dataset) # Test with world_size=1, rank=0 (should get all nodes) - node_ids = server.get_node_ids(rank=0, world_size=1, node_type=None) + node_ids = server.get_node_ids(FetchNodesRequest(rank=0, world_size=1)) self.assertIsInstance(node_ids, torch.Tensor) self.assertEqual(node_ids.shape[0], 10) self.assert_tensor_equality(node_ids, torch.arange(10)) @@ -96,13 +98,17 @@ def test_get_node_ids_with_heterogeneous_dataset(self) -> None: server = dist_server.DistServer(dataset) # Test with USER node type - user_node_ids = server.get_node_ids(rank=0, world_size=1, node_type=USER) + user_node_ids = server.get_node_ids( + FetchNodesRequest(rank=0, world_size=1, node_type=USER) + ) self.assertIsInstance(user_node_ids, torch.Tensor) self.assertEqual(user_node_ids.shape[0], 5) self.assert_tensor_equality(user_node_ids, torch.arange(5)) # Test with STORY node type - story_node_ids = server.get_node_ids(rank=0, world_size=1, node_type=STORY) + story_node_ids = server.get_node_ids( + FetchNodesRequest(rank=0, world_size=1, node_type=STORY) + ) self.assertIsInstance(story_node_ids, torch.Tensor) self.assertEqual(story_node_ids.shape[0], 5) self.assert_tensor_equality(story_node_ids, torch.arange(5)) @@ -115,17 +121,17 @@ def test_get_node_ids_with_multiple_ranks(self) -> None: server = dist_server.DistServer(dataset) # Test with world_size=2 - rank_0_nodes = server.get_node_ids(rank=0, world_size=2, node_type=None) - rank_1_nodes = server.get_node_ids(rank=1, world_size=2, node_type=None) + rank_0_nodes = server.get_node_ids(FetchNodesRequest(rank=0, world_size=2)) + rank_1_nodes = server.get_node_ids(FetchNodesRequest(rank=1, world_size=2)) # Verify each rank gets different nodes self.assert_tensor_equality(rank_0_nodes, torch.arange(5)) self.assert_tensor_equality(rank_1_nodes, torch.arange(5, 10)) # Test with world_size=3 (uneven split) - rank_0_nodes = server.get_node_ids(rank=0, world_size=3, node_type=None) - rank_1_nodes = server.get_node_ids(rank=1, world_size=3, node_type=None) - rank_2_nodes = server.get_node_ids(rank=2, world_size=3, node_type=None) + rank_0_nodes = server.get_node_ids(FetchNodesRequest(rank=0, world_size=3)) + rank_1_nodes = server.get_node_ids(FetchNodesRequest(rank=1, world_size=3)) + rank_2_nodes = server.get_node_ids(FetchNodesRequest(rank=2, world_size=3)) self.assert_tensor_equality(rank_0_nodes, torch.arange(3)) self.assert_tensor_equality(rank_1_nodes, torch.arange(3, 6)) @@ -139,10 +145,10 @@ def test_get_node_ids_rank_world_size_must_be_provided_together(self) -> None: server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - server.get_node_ids(rank=0, world_size=None) + server.get_node_ids(FetchNodesRequest(rank=0, world_size=None)) with self.assertRaises(ValueError): - server.get_node_ids(rank=None, world_size=1) + server.get_node_ids(FetchNodesRequest(rank=None, world_size=1)) def test_get_node_ids_with_homogeneous_dataset_and_node_type(self) -> None: """Test get_node_ids with a homogeneous dataset and a node type raises error.""" @@ -151,7 +157,7 @@ def test_get_node_ids_with_homogeneous_dataset_and_node_type(self) -> None: ) server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - server.get_node_ids(rank=0, world_size=1, node_type=USER) + server.get_node_ids(FetchNodesRequest(rank=0, world_size=1, node_type=USER)) def test_get_node_ids_with_heterogeneous_dataset_and_no_node_type( self, @@ -162,7 +168,7 @@ def test_get_node_ids_with_heterogeneous_dataset_and_no_node_type( ) server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - server.get_node_ids(rank=0, world_size=1, node_type=None) + server.get_node_ids(FetchNodesRequest(rank=0, world_size=1)) def test_get_node_ids_with_train_split(self) -> None: """Test get_node_ids returns only training nodes when split='train'.""" @@ -178,7 +184,9 @@ def test_get_node_ids_with_train_split(self) -> None: ) server = dist_server.DistServer(dataset) - train_nodes = server.get_node_ids(node_type=USER, split="train") + train_nodes = server.get_node_ids( + FetchNodesRequest(node_type=USER, split="train") + ) self.assert_tensor_equality(train_nodes, torch.tensor([0, 1, 2])) def test_get_node_ids_with_val_split(self) -> None: @@ -195,7 +203,7 @@ def test_get_node_ids_with_val_split(self) -> None: ) server = dist_server.DistServer(dataset) - val_nodes = server.get_node_ids(node_type=USER, split="val") + val_nodes = server.get_node_ids(FetchNodesRequest(node_type=USER, split="val")) self.assert_tensor_equality(val_nodes, torch.tensor([3])) def test_get_node_ids_with_test_split(self) -> None: @@ -212,7 +220,9 @@ def test_get_node_ids_with_test_split(self) -> None: ) server = dist_server.DistServer(dataset) - test_nodes = server.get_node_ids(node_type=USER, split="test") + test_nodes = server.get_node_ids( + FetchNodesRequest(node_type=USER, split="test") + ) self.assert_tensor_equality(test_nodes, torch.tensor([4])) def test_get_node_ids_with_split_and_sharding(self) -> None: @@ -231,15 +241,58 @@ def test_get_node_ids_with_split_and_sharding(self) -> None: # Train split has [0, 1, 2], shard across 2 ranks rank_0_nodes = server.get_node_ids( - rank=0, world_size=2, node_type=USER, split="train" + FetchNodesRequest(rank=0, world_size=2, node_type=USER, split="train") ) rank_1_nodes = server.get_node_ids( - rank=1, world_size=2, node_type=USER, split="train" + FetchNodesRequest(rank=1, world_size=2, node_type=USER, split="train") ) self.assert_tensor_equality(rank_0_nodes, torch.tensor([0])) self.assert_tensor_equality(rank_1_nodes, torch.tensor([1, 2])) + def test_get_node_ids_with_server_slice(self) -> None: + """Test get_node_ids supports contiguous server-side slicing.""" + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, + ) + server = dist_server.DistServer(dataset) + + sliced_nodes = server.get_node_ids( + FetchNodesRequest( + server_slice=ServerSlice( + server_rank=0, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ) + ) + ) + + self.assert_tensor_equality(sliced_nodes, torch.arange(5, 10)) + + def test_get_node_ids_rejects_mixed_sharding_modes(self) -> None: + """Test get_node_ids rejects rank/world_size combined with server_slice.""" + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, + ) + server = dist_server.DistServer(dataset) + + with self.assertRaises(ValueError): + server.get_node_ids( + FetchNodesRequest( + rank=0, + world_size=2, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ) + ) + def test_get_node_ids_invalid_split(self) -> None: """Test get_node_ids raises ValueError with invalid split.""" dataset = create_homogeneous_dataset( @@ -248,7 +301,7 @@ def test_get_node_ids_invalid_split(self) -> None: server = dist_server.DistServer(dataset) with self.assertRaises(ValueError): - server.get_node_ids(split="invalid") + server.get_node_ids(FetchNodesRequest(split="invalid")) def test_get_edge_dir(self) -> None: """Test get_edge_dir with a dataset.""" @@ -337,11 +390,13 @@ def test_get_ablp_input(self) -> None: for split, expected_user_ids in split_to_user_ids.items(): with self.subTest(split=split): anchor_nodes, pos_labels, neg_labels = server.get_ablp_input( - split=split, - rank=0, - world_size=1, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split=split, + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Verify anchor nodes match expected users @@ -394,20 +449,24 @@ def test_get_ablp_input_multiple_ranks(self) -> None: # Note that the rank and world size here are for the process group we're *fetching for*, not the process group we're *fetching from*. # e.g. if our compute cluster is of world size 4, and we have 2 storage nodes, then the world size this gets called with is 4, not 2. anchor_nodes_0, pos_labels_0, neg_labels_0 = server.get_ablp_input( - split="train", - rank=0, - world_size=2, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="train", + rank=0, + world_size=2, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Get training input for rank 1 of 2 anchor_nodes_1, pos_labels_1, neg_labels_1 = server.get_ablp_input( - split="train", - rank=1, - world_size=2, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="train", + rank=1, + world_size=2, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Train nodes [0, 1, 2, 3] should be split across ranks @@ -452,11 +511,13 @@ def test_get_ablp_input_invalid_split(self) -> None: with self.assertRaises(ValueError): server.get_ablp_input( - split="invalid", - rank=0, - world_size=1, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="invalid", + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) def test_get_ablp_input_without_negative_labels(self) -> None: @@ -483,11 +544,13 @@ def test_get_ablp_input_without_negative_labels(self) -> None: server = dist_server.DistServer(dataset) anchor_nodes, pos_labels, neg_labels = server.get_ablp_input( - split="train", - rank=0, - world_size=1, - node_type=USER, - supervision_edge_type=USER_TO_STORY, + FetchABLPRequest( + split="train", + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) ) # Verify train split returns the expected users @@ -500,6 +563,85 @@ def test_get_ablp_input_without_negative_labels(self) -> None: # Negative labels should be None self.assertIsNone(neg_labels) + def test_get_ablp_input_with_server_slice(self) -> None: + """Test get_ablp_input computes labels only for the server-sliced anchors.""" + create_test_process_group() + positive_labels = { + 0: [0, 1], + 1: [1, 2], + 2: [2, 3], + 3: [3, 4], + 4: [4, 0], + } + negative_labels = { + 0: [2], + 1: [3], + 2: [4], + 3: [0], + 4: [1], + } + + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels=positive_labels, + negative_labels=negative_labels, + train_node_ids=[0, 1, 2, 3], + val_node_ids=[4], + test_node_ids=[], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + server = dist_server.DistServer(dataset) + + anchor_nodes, pos_labels, neg_labels = server.get_ablp_input( + FetchABLPRequest( + split="train", + node_type=USER, + supervision_edge_type=USER_TO_STORY, + server_slice=ServerSlice( + server_rank=0, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + ) + ) + + self.assert_tensor_equality(anchor_nodes, torch.tensor([2, 3])) + self.assert_tensor_equality(pos_labels, torch.tensor([[2, 3], [3, 4]]), dim=1) + assert neg_labels is not None + self.assert_tensor_equality(neg_labels, torch.tensor([[4], [0]])) + + def test_get_ablp_input_rejects_mixed_sharding_modes(self) -> None: + """Test get_ablp_input rejects rank/world_size combined with server_slice.""" + create_test_process_group() + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels={0: [0], 1: [0], 2: [0]}, + negative_labels=None, + train_node_ids=[0], + val_node_ids=[1], + test_node_ids=[2], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + server = dist_server.DistServer(dataset) + + with self.assertRaises(ValueError): + server.get_ablp_input( + FetchABLPRequest( + split="train", + rank=0, + world_size=1, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=1, + end_num=1, + end_den=1, + ), + ) + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/distributed/graph_store/messages_test.py b/tests/unit/distributed/graph_store/messages_test.py new file mode 100644 index 000000000..d10a02133 --- /dev/null +++ b/tests/unit/distributed/graph_store/messages_test.py @@ -0,0 +1,124 @@ +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest +from gigl.distributed.graph_store.sharding import ServerSlice +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from tests.test_assets.test_case import TestCase + + +class TestFetchNodesRequest(TestCase): + def test_validate_accepts_rank_world_size(self) -> None: + request = FetchNodesRequest(rank=0, world_size=2) + request.validate() + + def test_validate_accepts_server_slice(self) -> None: + request = FetchNodesRequest( + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ) + ) + request.validate() + + def test_validate_rejects_partial_rank_world_size(self) -> None: + with self.assertRaises(ValueError): + FetchNodesRequest(rank=0).validate() + + with self.assertRaises(ValueError): + FetchNodesRequest(world_size=2).validate() + + def test_validate_rejects_mixed_sharding_modes(self) -> None: + with self.assertRaises(ValueError): + FetchNodesRequest( + rank=0, + world_size=2, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ).validate() + + +class TestFetchABLPRequest(TestCase): + def test_validate_accepts_rank_world_size(self) -> None: + request = FetchABLPRequest( + split="train", + rank=0, + world_size=2, + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + ) + request.validate() + + def test_validate_accepts_server_slice(self) -> None: + request = FetchABLPRequest( + split="train", + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ) + request.validate() + + def test_validate_rejects_partial_rank_world_size(self) -> None: + with self.assertRaises(ValueError): + FetchABLPRequest( + split="train", + rank=0, + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + ).validate() + + with self.assertRaises(ValueError): + FetchABLPRequest( + split="train", + world_size=2, + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + ).validate() + + def test_validate_rejects_mixed_sharding_modes(self) -> None: + with self.assertRaises(ValueError): + FetchABLPRequest( + split="train", + rank=0, + world_size=2, + node_type=NodeType("user"), + supervision_edge_type=EdgeType( + src_node_type=NodeType("user"), + relation=Relation("to"), + dst_node_type=NodeType("story"), + ), + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ).validate() diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index d7ef907c3..31c3060af 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -1,4 +1,6 @@ -from typing import Optional +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from typing import Any, Final, Optional from unittest.mock import patch import torch @@ -9,8 +11,11 @@ import gigl.distributed.graph_store.dist_server as dist_server_module from gigl.common import LocalUri from gigl.distributed.graph_store.dist_server import DistServer, _call_func_on_server +from gigl.distributed.graph_store.messages import FetchABLPRequest, FetchNodesRequest from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.graph_store.sharding import ServerSlice, ShardStrategy from gigl.env.distributed import GraphStoreInfo +from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, DEFAULT_HOMOGENEOUS_NODE_TYPE, @@ -38,13 +43,36 @@ # Module-level test server instance used by mock functions _test_server: Optional[DistServer] = None - -def _mock_request_server(server_rank, func, *args, **kwargs): +# Shared ABLP label data used by multiple test classes +_DEFAULT_POSITIVE_LABELS: Final[dict[int, list[int]]] = { + 0: [0, 1], + 1: [1, 2], + 2: [2, 3], + 3: [3, 4], + 4: [4, 0], +} +_DEFAULT_NEGATIVE_LABELS: Final[dict[int, list[int]]] = { + 0: [2], + 1: [3], + 2: [4], + 3: [0], + 4: [1], +} +_DEFAULT_TRAIN_IDS: Final[list[int]] = [0, 1, 2] +_DEFAULT_VAL_IDS: Final[list[int]] = [3] +_DEFAULT_TEST_IDS: Final[list[int]] = [4] + + +def _mock_request_server( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any +) -> Any: """Mock request_server that routes through _call_func_on_server.""" return _call_func_on_server(func, *args, **kwargs) -def _mock_async_request_server(server_rank, func, *args, **kwargs): +def _mock_async_request_server( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any +) -> torch.futures.Future: """Mock async_request_server that routes through _call_func_on_server and returns a future.""" future: torch.futures.Future = torch.futures.Future() future.set_result(_call_func_on_server(func, *args, **kwargs)) @@ -75,7 +103,79 @@ def _create_mock_graph_store_info( return MockGraphStoreInfo(real_info, compute_node_rank) -class TestRemoteDistDataset(TestCase): +@contextmanager +def _patch_remote_requests( + async_side_effect: Callable[..., torch.futures.Future], + sync_side_effect: Callable[..., Any], +) -> Iterator[None]: + """Context manager that patches both async_request_server and request_server.""" + with ( + patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=async_side_effect, + ), + patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=sync_side_effect, + ), + ): + yield + + +def _create_server_with_splits( + edge_indices: Optional[dict] = None, + src_node_type: Optional[NodeType] = None, + dst_node_type: Optional[NodeType] = None, + supervision_edge_type: Optional[EdgeType] = None, +) -> None: + """Create a DistServer with a dataset that has train/val/test splits. + + Args: + edge_indices: Edge indices to use. Defaults to DEFAULT_HETEROGENEOUS_EDGE_INDICES. + src_node_type: Source node type for labeled homogeneous datasets. + dst_node_type: Destination node type for labeled homogeneous datasets. + supervision_edge_type: Supervision edge type for labeled homogeneous datasets. + """ + global _test_server + create_test_process_group() + + kwargs: dict[str, Any] = {} + if src_node_type is not None: + kwargs.update( + src_node_type=src_node_type, + dst_node_type=dst_node_type, + supervision_edge_type=supervision_edge_type, + ) + + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels=_DEFAULT_POSITIVE_LABELS, + negative_labels=_DEFAULT_NEGATIVE_LABELS, + train_node_ids=_DEFAULT_TRAIN_IDS, + val_node_ids=_DEFAULT_VAL_IDS, + test_node_ids=_DEFAULT_TEST_IDS, + edge_indices=edge_indices or DEFAULT_HETEROGENEOUS_EDGE_INDICES, + **kwargs, + ) + _test_server = DistServer(dataset) + dist_server_module._dist_server = _test_server + + +class RemoteDistDatasetTestBase(TestCase): + """Shared tearDown for all RemoteDistDataset test classes.""" + + def tearDown(self) -> None: + global _test_server + _test_server = None + dist_server_module._dist_server = None + if dist.is_initialized(): + dist.destroy_process_group() + + +@patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=_mock_request_server, +) +class TestRemoteDistDataset(RemoteDistDatasetTestBase): def setUp(self) -> None: global _test_server # 10 nodes in DEFAULT_HOMOGENEOUS_EDGE_INDEX ring graph @@ -87,17 +187,6 @@ def setUp(self) -> None: _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_graph_metadata_getters_homogeneous(self, mock_request): """Test fetch_node_feature_info, fetch_edge_feature_info, fetch_edge_dir, fetch_edge_types, fetch_node_types for homogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -112,7 +201,7 @@ def test_graph_metadata_getters_homogeneous(self, mock_request): self.assertIsNone(remote_dataset.fetch_edge_types()) self.assertIsNone(remote_dataset.fetch_node_types()) - def test_cluster_info_property(self): + def test_cluster_info_property(self, mock_request): cluster_info = _create_mock_graph_store_info( num_storage_nodes=3, num_compute_nodes=2 ) @@ -125,10 +214,6 @@ def test_cluster_info_property(self): "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, ) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_ids(self, mock_request, mock_async_request): """Test fetch_node_ids returns node ids, with optional sharding via rank/world_size.""" cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) @@ -147,10 +232,6 @@ def test_fetch_node_ids(self, mock_request, mock_async_request): result = remote_dataset.fetch_node_ids(rank=1, world_size=2) self.assert_tensor_equality(result[0], torch.arange(5, 10)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_homogeneous(self, mock_request): """Test fetch_node_partition_book returns the tensor partition book for homogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -162,10 +243,6 @@ def test_fetch_node_partition_book_homogeneous(self, mock_request): self.assertEqual(result.shape[0], 10) self.assert_tensor_equality(result, torch.zeros(10, dtype=torch.int64)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_edge_partition_book_homogeneous(self, mock_request): """Test fetch_edge_partition_book returns the tensor partition book for homogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -177,10 +254,6 @@ def test_fetch_edge_partition_book_homogeneous(self, mock_request): self.assertEqual(result.shape[0], 10) self.assert_tensor_equality(result, torch.zeros(10, dtype=torch.int64)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_homogeneous_rejects_node_type( self, mock_request ): @@ -192,7 +265,11 @@ def test_fetch_node_partition_book_homogeneous_rejects_node_type( remote_dataset.fetch_node_partition_book(node_type=USER) -class TestRemoteDistDatasetHeterogeneous(TestCase): +@patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=_mock_request_server, +) +class TestRemoteDistDatasetHeterogeneous(RemoteDistDatasetTestBase): def setUp(self) -> None: global _test_server # 5 users, 5 stories in DEFAULT_HETEROGENEOUS_EDGE_INDICES @@ -207,17 +284,6 @@ def setUp(self) -> None: _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_graph_metadata_getters_heterogeneous(self, mock_request): """Test fetch_node_feature_info, fetch_edge_dir, fetch_edge_types, fetch_node_types for heterogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -243,7 +309,7 @@ def test_graph_metadata_getters_heterogeneous(self, mock_request): "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, ) - def test_fetch_node_ids_with_node_type(self, mock_async_request): + def test_fetch_node_ids_with_node_type(self, mock_request, mock_async_request): """Test fetch_node_ids with node_type for heterogeneous graphs, with optional sharding.""" cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -264,10 +330,6 @@ def test_fetch_node_ids_with_node_type(self, mock_async_request): result = remote_dataset.fetch_node_ids(rank=1, world_size=2, node_type=USER) self.assert_tensor_equality(result[0], torch.arange(2, 5)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_heterogeneous(self, mock_request): """Test fetch_node_partition_book returns per-type partition books for heterogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -285,10 +347,6 @@ def test_fetch_node_partition_book_heterogeneous(self, mock_request): self.assertEqual(story_pb.shape[0], 5) self.assert_tensor_equality(story_pb, torch.zeros(5, dtype=torch.int64)) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_edge_partition_book_heterogeneous(self, mock_request): """Test fetch_edge_partition_book returns per-type partition books for heterogeneous graphs.""" cluster_info = _create_mock_graph_store_info() @@ -307,10 +365,6 @@ def test_fetch_edge_partition_book_heterogeneous(self, mock_request): ), ) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_node_partition_book_heterogeneous_requires_node_type( self, mock_request ): @@ -321,10 +375,6 @@ def test_fetch_node_partition_book_heterogeneous_requires_node_type( with self.assertRaises(ValueError): remote_dataset.fetch_node_partition_book() - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) def test_fetch_edge_partition_book_heterogeneous_requires_edge_type( self, mock_request ): @@ -336,54 +386,16 @@ def test_fetch_edge_partition_book_heterogeneous_requires_edge_type( remote_dataset.fetch_edge_partition_book() -class TestRemoteDistDatasetWithSplits(TestCase): +class TestRemoteDistDatasetWithSplits(RemoteDistDatasetTestBase): """Tests for fetch_node_ids with train/val/test splits.""" - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - - def _create_server_with_splits(self) -> None: - """Create a DistServer with a dataset that has train/val/test splits.""" - global _test_server - create_test_process_group() - - positive_labels = { - 0: [0, 1], - 1: [1, 2], - 2: [2, 3], - 3: [3, 4], - 4: [4, 0], - } - negative_labels = { - 0: [2], - 1: [3], - 2: [4], - 3: [0], - 4: [1], - } - - dataset = create_heterogeneous_dataset_for_ablp( - positive_labels=positive_labels, - negative_labels=negative_labels, - train_node_ids=[0, 1, 2], - val_node_ids=[3], - test_node_ids=[4], - edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, - ) - _test_server = DistServer(dataset) - dist_server_module._dist_server = _test_server - @patch( "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, ) def test_fetch_node_ids_with_splits(self, mock_async_request): """Test fetch_node_ids with train/val/test splits and optional sharding.""" - self._create_server_with_splits() + _create_server_with_splits() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -428,7 +440,7 @@ def test_fetch_node_ids_with_splits(self, mock_async_request): ) def test_fetch_ablp_input(self, mock_async_request): """Test fetch_ablp_input with train/val/test splits.""" - self._create_server_with_splits() + _create_server_with_splits() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -496,7 +508,7 @@ def test_fetch_ablp_input(self, mock_async_request): ) def test_fetch_ablp_input_with_sharding(self, mock_async_request): """Test fetch_ablp_input with sharding across compute nodes.""" - self._create_server_with_splits() + _create_server_with_splits() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -550,7 +562,7 @@ def test_fetch_ablp_input_with_sharding(self, mock_async_request): ) -class TestRemoteDistDatasetLabeledHomogeneous(TestCase): +class TestRemoteDistDatasetLabeledHomogeneous(RemoteDistDatasetTestBase): """Tests for datasets using DEFAULT_HOMOGENEOUS_NODE_TYPE / DEFAULT_HOMOGENEOUS_EDGE_TYPE. A 'labeled homogeneous' dataset is stored internally as heterogeneous @@ -559,50 +571,17 @@ class TestRemoteDistDatasetLabeledHomogeneous(TestCase): callers do not need to supply them explicitly. """ - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - - def _create_server_with_labeled_homogeneous_splits(self) -> None: - global _test_server - create_test_process_group() - - positive_labels = { - 0: [0, 1], - 1: [1, 2], - 2: [2, 3], - 3: [3, 4], - 4: [4, 0], - } - negative_labels = { - 0: [2], - 1: [3], - 2: [4], - 3: [0], - 4: [1], - } - edge_indices = { - DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( - [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] - ) - } + _LABELED_HOMOGENEOUS_EDGE_INDICES: Final = { + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) + } - dataset = create_heterogeneous_dataset_for_ablp( - positive_labels=positive_labels, - negative_labels=negative_labels, - train_node_ids=[0, 1, 2], - val_node_ids=[3], - test_node_ids=[4], - edge_indices=edge_indices, + def _create_labeled_homogeneous_server(self) -> None: + _create_server_with_splits( + edge_indices=self._LABELED_HOMOGENEOUS_EDGE_INDICES, src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, ) - _test_server = DistServer(dataset) - dist_server_module._dist_server = _test_server @patch( "gigl.distributed.graph_store.remote_dist_dataset.request_server", @@ -610,7 +589,7 @@ def _create_server_with_labeled_homogeneous_splits(self) -> None: ) def test_fetch_node_types_labeled_homogeneous(self, mock_request): """Test fetch_node_types returns DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info() remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -618,35 +597,27 @@ def test_fetch_node_types_labeled_homogeneous(self, mock_request): self.assertIsNotNone(node_types) self.assertIn(DEFAULT_HOMOGENEOUS_NODE_TYPE, node_types) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", - side_effect=_mock_async_request_server, - ) - @patch( - "gigl.distributed.graph_store.remote_dist_dataset.request_server", - side_effect=_mock_request_server, - ) - def test_fetch_node_ids_auto_detects_default_node_type( - self, mock_request, mock_async_request - ): + def test_fetch_node_ids_auto_detects_default_node_type(self): """Test fetch_node_ids without node_type auto-detects DEFAULT_HOMOGENEOUS_NODE_TYPE.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) - remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) - # No node_type provided: _fetch_node_ids should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE - self.assert_tensor_equality( - remote_dataset.fetch_node_ids(split="train")[0], - torch.tensor([0, 1, 2]), - ) - self.assert_tensor_equality( - remote_dataset.fetch_node_ids(split="val")[0], - torch.tensor([3]), - ) - self.assert_tensor_equality( - remote_dataset.fetch_node_ids(split="test")[0], - torch.tensor([4]), - ) + with _patch_remote_requests(_mock_async_request_server, _mock_request_server): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # No node_type provided: _fetch_node_ids should auto-detect DEFAULT_HOMOGENEOUS_NODE_TYPE + self.assert_tensor_equality( + remote_dataset.fetch_node_ids(split="train")[0], + torch.tensor([0, 1, 2]), + ) + self.assert_tensor_equality( + remote_dataset.fetch_node_ids(split="val")[0], + torch.tensor([3]), + ) + self.assert_tensor_equality( + remote_dataset.fetch_node_ids(split="test")[0], + torch.tensor([4]), + ) @patch( "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", @@ -654,7 +625,7 @@ def test_fetch_node_ids_auto_detects_default_node_type( ) def test_fetch_ablp_input_defaults_to_homogeneous_types(self, mock_async_request): """Test fetch_ablp_input without anchor_node_type/supervision_edge_type uses homogeneous defaults.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -685,7 +656,7 @@ def test_fetch_node_partition_book_auto_infers_default_node_type( self, mock_request ): """Test fetch_node_partition_book auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE when None.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info() remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -704,7 +675,7 @@ def test_fetch_edge_partition_book_auto_infers_default_edge_type( self, mock_request ): """Test fetch_edge_partition_book auto-infers DEFAULT_HOMOGENEOUS_EDGE_TYPE when None.""" - self._create_server_with_labeled_homogeneous_splits() + self._create_labeled_homogeneous_server() cluster_info = _create_mock_graph_store_info() remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) @@ -735,6 +706,593 @@ def test_fetch_ablp_input_mismatched_params_raises(self): ) +class TestRemoteDistDatasetContiguous(RemoteDistDatasetTestBase): + """Tests for fetch_node_ids and fetch_ablp_input with ShardStrategy.CONTIGUOUS.""" + + def _make_rank_aware_async_mock( + self, + server_data: dict[int, Any], + captured_requests: Optional[list[tuple[int, Any]]] = None, + ) -> Callable[..., torch.futures.Future]: + """Create an async mock that returns pre-set data per server rank. + + Args: + server_data: Maps server_rank to the value that server should return. + Can be a tensor (for node ID tests) or a tuple (for ABLP tests). + captured_requests: Optional list populated with ``(server_rank, request)`` + tuples for later assertions. + """ + + def _mock( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> torch.futures.Future: + assert not kwargs + assert len(args) == 1 + request = args[0] + if captured_requests is not None: + captured_requests.append((server_rank, request)) + + future: torch.futures.Future = torch.futures.Future() + response = server_data[server_rank] + if isinstance(request, FetchNodesRequest): + if request.server_slice is not None: + assert isinstance(response, torch.Tensor) + response = request.server_slice.slice_tensor(response) + elif isinstance(request, FetchABLPRequest): + if request.server_slice is not None: + anchors, positive_labels, negative_labels = response + response = ( + request.server_slice.slice_tensor(anchors), + request.server_slice.slice_tensor(positive_labels), + ( + request.server_slice.slice_tensor(negative_labels) + if negative_labels is not None + else None + ), + ) + future.set_result(response) + return future + + return _mock + + @staticmethod + def _mock_request_server_homogeneous( + server_rank: int, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: + """Mock request_server that returns None for node/edge types (homogeneous).""" + if func == DistServer.get_node_types: + return None + if func == DistServer.get_edge_types: + return None + return _mock_request_server(server_rank, func, *args, **kwargs) + + def test_fetch_node_ids_contiguous_even_split(self) -> None: + """CONTIGUOUS with 2 storage nodes and 2 compute nodes: each rank gets one server.""" + server_data: dict[int, torch.Tensor] = { + 0: torch.arange(10), + 1: torch.arange(10, 20), + } + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, num_compute_nodes=2 + ) + + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: gets all of server 0, empty from server 1 + result = ds.fetch_node_ids( + rank=0, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS + ) + self.assert_tensor_equality(result[0], torch.arange(10)) + self.assertEqual(result[1].numel(), 0) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) + + # Rank 1: empty from server 0, gets all of server 1 + captured_requests.clear() + result = ds.fetch_node_ids( + rank=1, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS + ) + self.assertEqual(result[0].numel(), 0) + self.assert_tensor_equality(result[1], torch.arange(10, 20)) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) + + def test_fetch_node_ids_contiguous_fractional_split(self) -> None: + """CONTIGUOUS with 3 storage nodes and 2 compute nodes: server 1 is fractionally split.""" + server_data: dict[int, torch.Tensor] = { + 0: torch.arange(10), + 1: torch.arange(10, 20), + 2: torch.arange(20, 30), + } + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=3, num_compute_nodes=2 + ) + + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: all of server 0, first half of server 1, nothing from server 2 + result = ds.fetch_node_ids( + rank=0, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS + ) + self.assert_tensor_equality(result[0], torch.arange(10)) + self.assert_tensor_equality(result[1], torch.arange(10, 15)) + self.assertEqual(result[2].numel(), 0) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 1, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ), + ), + ], + ) + + # Rank 1: nothing from server 0, second half of server 1, all of server 2 + captured_requests.clear() + result = ds.fetch_node_ids( + rank=1, world_size=2, shard_strategy=ShardStrategy.CONTIGUOUS + ) + self.assertEqual(result[0].numel(), 0) + self.assert_tensor_equality(result[1], torch.arange(15, 20)) + self.assert_tensor_equality(result[2], torch.arange(20, 30)) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=1, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 2, + FetchNodesRequest( + split=None, + node_type=None, + server_slice=ServerSlice( + server_rank=2, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ], + ) + + def test_with_split_filtering(self) -> None: + """CONTIGUOUS strategy with split='train' filtering.""" + server_data: dict[int, torch.Tensor] = { + 0: torch.tensor([0, 1, 2, 3]), + 1: torch.tensor([10, 11, 12, 13]), + } + mock_fn = self._make_rank_aware_async_mock(server_data) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, num_compute_nodes=2 + ) + + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + result = ds.fetch_node_ids( + rank=0, + world_size=2, + split="train", + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assert_tensor_equality(result[0], torch.tensor([0, 1, 2, 3])) + self.assertEqual(result[1].numel(), 0) + + def test_contiguous_requires_rank_and_world_size(self) -> None: + """CONTIGUOUS without rank/world_size raises ValueError.""" + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, num_compute_nodes=2 + ) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + with self.assertRaises(ValueError): + remote_dataset.fetch_node_ids( + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + remote_dataset.fetch_node_ids( + rank=0, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + remote_dataset.fetch_node_ids( + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + + def test_contiguous_labeled_homogeneous_auto_inference(self) -> None: + """CONTIGUOUS strategy auto-infers DEFAULT_HOMOGENEOUS_NODE_TYPE for labeled homogeneous datasets.""" + _create_server_with_splits( + edge_indices={ + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + ) + }, + src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + ) + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + + with _patch_remote_requests(_mock_async_request_server, _mock_request_server): + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + result = remote_dataset.fetch_node_ids( + rank=0, + world_size=1, + split="train", + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assert_tensor_equality( + result[0], + torch.tensor([0, 1, 2]), + ) + + def test_fetch_ablp_input_contiguous_even_split(self) -> None: + """ABLP CONTIGUOUS with 2 storage nodes and 2 compute nodes.""" + server_data: dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] = { + 0: ( + torch.tensor([0, 1, 2]), + torch.tensor([[0, 1], [1, 2], [2, 3]]), + torch.tensor([[4], [5], [6]]), + ), + 1: ( + torch.tensor([10, 11, 12]), + torch.tensor([[10, 11], [11, 12], [12, 13]]), + torch.tensor([[14], [15], [16]]), + ), + } + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, num_compute_nodes=2 + ) + + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: gets all of server 0, empty from server 1 + result = ds.fetch_ablp_input( + split="train", + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_0 = result[0] + self.assert_tensor_equality(ablp_0.anchor_nodes, torch.tensor([0, 1, 2])) + pos, neg = ablp_0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality(pos, torch.tensor([[0, 1], [1, 2], [2, 3]])) + assert neg is not None + self.assert_tensor_equality(neg, torch.tensor([[4], [5], [6]])) + + ablp_1 = result[1] + self.assertEqual(ablp_1.anchor_nodes.numel(), 0) + pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assertEqual(pos_1.numel(), 0) + self.assertIsNone(neg_1) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) + + # Rank 1: empty from server 0, gets all of server 1 + captured_requests.clear() + result = ds.fetch_ablp_input( + split="train", + rank=1, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_0 = result[0] + self.assertEqual(ablp_0.anchor_nodes.numel(), 0) + + ablp_1 = result[1] + self.assert_tensor_equality(ablp_1.anchor_nodes, torch.tensor([10, 11, 12])) + pos, neg = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos, torch.tensor([[10, 11], [11, 12], [12, 13]]) + ) + assert neg is not None + self.assert_tensor_equality(neg, torch.tensor([[14], [15], [16]])) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ) + ], + ) + + def test_fetch_ablp_input_contiguous_fractional_split(self) -> None: + """ABLP CONTIGUOUS with 3 storage nodes and 2 compute nodes: server 1 fractionally split.""" + server_data: dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] = { + 0: ( + torch.tensor([0, 1, 2, 3]), + torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]), + torch.tensor([[10], [11], [12], [13]]), + ), + 1: ( + torch.tensor([10, 11, 12, 13]), + torch.tensor([[10, 11], [11, 12], [12, 13], [13, 14]]), + torch.tensor([[20], [21], [22], [23]]), + ), + 2: ( + torch.tensor([20, 21, 22, 23]), + torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), + torch.tensor([[30], [31], [32], [33]]), + ), + } + captured_requests: list[tuple[int, Any]] = [] + mock_fn = self._make_rank_aware_async_mock(server_data, captured_requests) + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=3, num_compute_nodes=2 + ) + + with _patch_remote_requests(mock_fn, self._mock_request_server_homogeneous): + ds = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Rank 0: all of server 0, first half of server 1, nothing from server 2 + result = ds.fetch_ablp_input( + split="train", + rank=0, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + ablp_0 = result[0] + self.assert_tensor_equality(ablp_0.anchor_nodes, torch.tensor([0, 1, 2, 3])) + pos, neg = ablp_0.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos, torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]]) + ) + assert neg is not None + self.assert_tensor_equality(neg, torch.tensor([[10], [11], [12], [13]])) + + ablp_1 = result[1] + self.assert_tensor_equality(ablp_1.anchor_nodes, torch.tensor([10, 11])) + pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality(pos_1, torch.tensor([[10, 11], [11, 12]])) + assert neg_1 is not None + self.assert_tensor_equality(neg_1, torch.tensor([[20], [21]])) + + ablp_2 = result[2] + self.assertEqual(ablp_2.anchor_nodes.numel(), 0) + self.assertEqual( + captured_requests, + [ + ( + 0, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 1, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + ), + ), + ], + ) + + # Rank 1: nothing from server 0, second half of server 1, all of server 2 + captured_requests.clear() + result = ds.fetch_ablp_input( + split="train", + rank=1, + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + self.assertEqual(result[0].anchor_nodes.numel(), 0) + + ablp_1 = result[1] + self.assert_tensor_equality(ablp_1.anchor_nodes, torch.tensor([12, 13])) + pos_1, neg_1 = ablp_1.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality(pos_1, torch.tensor([[12, 13], [13, 14]])) + assert neg_1 is not None + self.assert_tensor_equality(neg_1, torch.tensor([[22], [23]])) + + ablp_2 = result[2] + self.assert_tensor_equality( + ablp_2.anchor_nodes, torch.tensor([20, 21, 22, 23]) + ) + pos_2, neg_2 = ablp_2.labels[DEFAULT_HOMOGENEOUS_EDGE_TYPE] + self.assert_tensor_equality( + pos_2, + torch.tensor([[20, 21], [21, 22], [22, 23], [23, 24]]), + ) + assert neg_2 is not None + self.assert_tensor_equality(neg_2, torch.tensor([[30], [31], [32], [33]])) + self.assertEqual( + captured_requests, + [ + ( + 1, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=1, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ( + 2, + FetchABLPRequest( + split="train", + node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type=DEFAULT_HOMOGENEOUS_EDGE_TYPE, + server_slice=ServerSlice( + server_rank=2, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + ), + ), + ], + ) + + def test_ablp_contiguous_requires_rank_and_world_size(self) -> None: + """ABLP CONTIGUOUS without rank/world_size raises ValueError.""" + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=2, num_compute_nodes=2 + ) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + with self.assertRaises(ValueError): + remote_dataset.fetch_ablp_input( + split="train", + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + remote_dataset.fetch_ablp_input( + split="train", + rank=0, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + with self.assertRaises(ValueError): + remote_dataset.fetch_ablp_input( + split="train", + world_size=2, + shard_strategy=ShardStrategy.CONTIGUOUS, + ) + + def _test_fetch_free_ports_on_storage_cluster( rank: int, world_size: int, @@ -777,20 +1335,13 @@ def _test_fetch_free_ports_on_storage_cluster( dist.destroy_process_group() -class TestGetFreePortsOnStorageCluster(TestCase): +class TestGetFreePortsOnStorageCluster(RemoteDistDatasetTestBase): def setUp(self) -> None: global _test_server dataset = create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - if dist.is_initialized(): - dist.destroy_process_group() - def test_fetch_free_ports_on_storage_cluster_distributed(self): """Test that free ports are correctly broadcast across all ranks.""" init_method = get_process_group_init_method() @@ -813,7 +1364,7 @@ def test_fetch_free_ports_fails_without_process_group(self): remote_dataset.fetch_free_ports_on_storage_cluster(num_ports=1) -class TestCallFuncOnServer(TestCase): +class TestCallFuncOnServer(RemoteDistDatasetTestBase): """Tests for the _call_func_on_server dispatch logic.""" def setUp(self) -> None: @@ -826,11 +1377,6 @@ def setUp(self) -> None: _test_server = DistServer(dataset) dist_server_module._dist_server = _test_server - def tearDown(self) -> None: - global _test_server - _test_server = None - dist_server_module._dist_server = None - def test_dispatches_server_method(self): """Test that _call_func_on_server correctly dispatches an unbound DistServer method.""" result = _call_func_on_server(DistServer.get_edge_dir) diff --git a/tests/unit/distributed/graph_store/sharding_test.py b/tests/unit/distributed/graph_store/sharding_test.py new file mode 100644 index 000000000..547524663 --- /dev/null +++ b/tests/unit/distributed/graph_store/sharding_test.py @@ -0,0 +1,149 @@ +import torch +from parameterized import param, parameterized + +from gigl.distributed.graph_store.sharding import ( + ServerSlice, + compute_server_assignments, +) +from tests.test_assets.test_case import TestCase + + +class TestComputeServerAssignments(TestCase): + @parameterized.expand( + [ + param( + "rank_0", + compute_rank=0, + expected_assignments={ + 0: ServerSlice( + server_rank=0, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + 1: ServerSlice( + server_rank=1, + start_num=0, + start_den=2, + end_num=1, + end_den=2, + ), + }, + ), + param( + "rank_1", + compute_rank=1, + expected_assignments={ + 1: ServerSlice( + server_rank=1, + start_num=1, + start_den=2, + end_num=2, + end_den=2, + ), + 2: ServerSlice( + server_rank=2, + start_num=0, + start_den=2, + end_num=2, + end_den=2, + ), + }, + ), + ] + ) + def test_fractional_boundary_assignment( + self, _, compute_rank: int, expected_assignments: dict[int, ServerSlice] + ) -> None: + assignments = compute_server_assignments( + num_servers=3, num_compute_nodes=2, compute_rank=compute_rank + ) + self.assertEqual(assignments, expected_assignments) + + def test_assignments_recombine_server_data(self) -> None: + tensor = torch.arange(7) + all_assignments = [ + compute_server_assignments( + num_servers=2, num_compute_nodes=5, compute_rank=rank + ) + for rank in range(5) + ] + + for server_rank in range(2): + combined = torch.cat( + [ + assignments[server_rank].slice_tensor(tensor) + for assignments in all_assignments + if server_rank in assignments + ] + ) + self.assert_tensor_equality(combined, tensor) + + @parameterized.expand( + [ + param( + "negative_servers", + num_servers=-1, + num_compute_nodes=2, + compute_rank=0, + ), + param( + "zero_servers", + num_servers=0, + num_compute_nodes=2, + compute_rank=0, + ), + param( + "negative_compute_nodes", + num_servers=2, + num_compute_nodes=-1, + compute_rank=0, + ), + param( + "zero_compute_nodes", + num_servers=2, + num_compute_nodes=0, + compute_rank=0, + ), + param( + "rank_too_large", + num_servers=2, + num_compute_nodes=2, + compute_rank=2, + ), + param( + "negative_rank", + num_servers=2, + num_compute_nodes=2, + compute_rank=-1, + ), + ] + ) + def test_validates_arguments( + self, _, num_servers: int, num_compute_nodes: int, compute_rank: int + ) -> None: + with self.assertRaises(ValueError): + compute_server_assignments( + num_servers=num_servers, + num_compute_nodes=num_compute_nodes, + compute_rank=compute_rank, + ) + + +class TestServerSlice(TestCase): + def test_full_tensor_returns_same_object(self) -> None: + tensor = torch.arange(10) + server_slice = ServerSlice( + server_rank=0, start_num=0, start_den=1, end_num=1, end_den=1 + ) + result = server_slice.slice_tensor(tensor) + self.assertEqual(result.data_ptr(), tensor.data_ptr()) + + def test_partial_slice_returns_requested_range(self) -> None: + tensor = torch.arange(10) + server_slice = ServerSlice( + server_rank=0, start_num=0, start_den=2, end_num=1, end_den=2 + ) + result = server_slice.slice_tensor(tensor) + self.assert_tensor_equality(result, torch.arange(5))