Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions examples/link_prediction/graph_store/heterogeneous_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- Uses `RemoteDistDataset` to connect to a remote graph store cluster
- Uses `init_compute_process` to initialize the compute node connection to storage
- Obtains cluster topology via `get_graph_store_info()` which returns `GraphStoreInfo`
- Uses `mp_sharing_dict` for efficient tensor sharing between local processes
- Each process fetches its own shard of data from the storage cluster

Standard mode (`heterogeneous_inference.py`):
- Uses `DistDataset` with `build_dataset_from_task_config_uri` where each node loads its partition
Expand Down Expand Up @@ -84,9 +84,7 @@
import gc
import os
import sys
import threading
import time
from collections.abc import MutableMapping
from dataclasses import dataclass
from typing import Union

Expand Down Expand Up @@ -145,8 +143,6 @@ class InferenceProcessArgs:
cluster_info (GraphStoreInfo): Cluster topology info for graph store mode, containing
information about storage and compute node ranks and addresses.
inference_node_type (NodeType): Node type that embeddings should be generated for.
mp_sharing_dict (MutableMapping[str, torch.Tensor]): Shared dictionary for efficient tensor
sharing between local processes.
model_state_dict_uri (Uri): URI to load the trained model state dict from.
hid_dim (int): Hidden dimension of the model.
out_dim (int): Output dimension of the model.
Expand All @@ -173,8 +169,6 @@ class InferenceProcessArgs:

# Data
inference_node_type: NodeType
mp_sharing_dict: MutableMapping[str, torch.Tensor]
mp_barrier: threading.Barrier

# Model
model_state_dict_uri: Uri
Expand Down Expand Up @@ -230,15 +224,17 @@ def _inference_process(
dataset = RemoteDistDataset(
args.cluster_info,
local_rank,
mp_sharing_dict=args.mp_sharing_dict,
mp_barrier=args.mp_barrier,
)
logger.info(
f"Local rank {local_rank} in machine {args.machine_rank} has rank {rank}/{world_size} and using device {device} for inference"
)

# Get the node ids on the current machine for the current node type
input_nodes = dataset.fetch_node_ids(node_type=args.inference_node_type)
input_nodes = dataset.fetch_node_ids(
node_type=args.inference_node_type,
rank=torch.distributed.get_rank(),
world_size=torch.distributed.get_world_size(),
)
logger.info(
f"Rank {rank} got input nodes of shapes: {[f'{rank}: {node.shape}' for rank, node in input_nodes.items()]}"
)
Expand Down Expand Up @@ -533,15 +529,12 @@ def _run_example_inference(
log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50"))

# When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine.
manager = torch.multiprocessing.Manager()
inference_args = InferenceProcessArgs(
local_world_size=num_inference_processes_per_machine,
machine_rank=cluster_info.compute_node_rank,
machine_world_size=cluster_info.num_compute_nodes,
cluster_info=cluster_info,
inference_node_type=inference_node_type,
mp_sharing_dict=manager.dict(),
mp_barrier=manager.Barrier(num_inference_processes_per_machine), # type: ignore[attr-defined]
model_state_dict_uri=model_uri,
hid_dim=hid_dim,
out_dim=out_dim,
Expand Down
15 changes: 2 additions & 13 deletions examples/link_prediction/graph_store/heterogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@
import os
import statistics
import sys
import threading
import time
from collections.abc import Iterator, MutableMapping
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Literal, Optional, Union

Expand Down Expand Up @@ -395,8 +394,6 @@ class TrainingProcessArgs:
# Distributed context
local_world_size: int
cluster_info: GraphStoreInfo
mp_sharing_dict: MutableMapping[str, torch.Tensor]
mp_barrier: threading.Barrier

# Data
supervision_edge_type: EdgeType
Expand Down Expand Up @@ -448,8 +445,6 @@ def _training_process(
dataset = RemoteDistDataset(
args.cluster_info,
local_rank,
mp_sharing_dict=args.mp_sharing_dict,
mp_barrier=args.mp_barrier,
)

rank = torch.distributed.get_rank()
Expand Down Expand Up @@ -945,20 +940,14 @@ def _run_example_training(
)
supervision_edge_type = supervision_edge_types[0]

# Step 4: Create shared dict and mp barrier for inter-process tensor sharing
manager = mp.Manager()
mp_sharing_dict = manager.dict()
mp_barrier = manager.Barrier(local_world_size) # type: ignore[attr-defined]
# Step 5: Spawn training processes
# Step 4: Spawn training processes
print("--- Launching training processes ...\n")
flush()
start_time = time.time()

training_args = TrainingProcessArgs(
local_world_size=local_world_size,
cluster_info=cluster_info,
mp_sharing_dict=mp_sharing_dict,
mp_barrier=mp_barrier,
supervision_edge_type=supervision_edge_type,
model_uri=model_uri,
eval_metrics_uri=eval_metrics_uri,
Expand Down
22 changes: 3 additions & 19 deletions examples/link_prediction/graph_store/homogeneous_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- Uses `RemoteDistDataset` to connect to a remote graph store cluster
- Uses `init_compute_process` to initialize the compute node connection to storage
- Obtains cluster topology via `get_graph_store_info()` which returns `GraphStoreInfo`
- Uses `mp_sharing_dict` for efficient tensor sharing between local processes
- Each process fetches its own shard of data from the storage cluster

Standard mode (`homogeneous_inference.py`):
- Uses `DistDataset` with `build_dataset_from_task_config_uri` where each node loads its partition
Expand Down Expand Up @@ -85,9 +85,7 @@
import gc
import os
import sys
import threading
import time
from collections.abc import MutableMapping
from dataclasses import dataclass
from typing import Union

Expand Down Expand Up @@ -152,8 +150,6 @@ class InferenceProcessArgs:
log_every_n_batch (int): Frequency to log batch information during inference.
inference_node_type (NodeType): Node type that embeddings should be generated for.
gbml_config_pb_wrapper (GbmlConfigPbWrapper): Wrapper containing GBML configuration.
mp_sharing_dict (MutableMapping[str, torch.Tensor]): Shared dictionary for efficient tensor
sharing between local processes.
"""

# Distributed context
Expand All @@ -178,8 +174,6 @@ class InferenceProcessArgs:
log_every_n_batch: int
inference_node_type: NodeType
gbml_config_pb_wrapper: GbmlConfigPbWrapper
mp_sharing_dict: MutableMapping[str, torch.Tensor]
mp_barrier: threading.Barrier


@torch.no_grad()
Expand Down Expand Up @@ -216,21 +210,16 @@ def _inference_process(
dataset = RemoteDistDataset(
args.cluster_info,
local_rank,
mp_sharing_dict=args.mp_sharing_dict,
mp_barrier=args.mp_barrier,
)
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
logger.info(
f"Local rank {local_rank} in machine {args.cluster_info.compute_node_rank} has rank {rank}/{world_size} and using device {device} for inference"
)

# We expect that each compute machine has the same input nodes.
# As such, we shard across the compute machine cluster.
# If this is not done, then all nodes will receive the same input nodes, which is not what we want.
input_nodes = dataset.fetch_node_ids(
rank=args.cluster_info.compute_node_rank,
world_size=args.cluster_info.num_compute_nodes,
rank=torch.distributed.get_rank(),
world_size=torch.distributed.get_world_size(),
)
logger.info(
f"Rank {rank} got input nodes of shapes: {[f'{rank}: {node.shape}' for rank, node in input_nodes.items()]}"
Expand Down Expand Up @@ -486,9 +475,6 @@ def _run_example_inference(
)
local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE

manager = mp.Manager()
mp_sharing_dict = manager.dict()
mp_barrier = manager.Barrier(local_world_size) # type: ignore[attr-defined]
if cluster_info.compute_node_rank == 0:
gcs_utils = GcsUtils()
num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(
Expand Down Expand Up @@ -547,8 +533,6 @@ def _run_example_inference(
log_every_n_batch=log_every_n_batch,
inference_node_type=graph_metadata.homogeneous_node_type,
gbml_config_pb_wrapper=gbml_config_pb_wrapper,
mp_sharing_dict=mp_sharing_dict,
mp_barrier=mp_barrier,
)
mp.spawn(
fn=_inference_process,
Expand Down
29 changes: 8 additions & 21 deletions examples/link_prediction/graph_store/homogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
| | ``master_ip_address`` extracted manually | ``get_graph_store_info()`` encapsulates all |
| | | topology |
+---------------------------+----------------------------------------------+----------------------------------------------+
| **Inter-process sharing** | N/A (each process loads own partition) | ``mp_sharing_dict`` for efficient tensor |
| | | sharing between local processes |
| **Inter-process sharing** | N/A (each process loads own partition) | Each process fetches its own shard from |
| | | the storage cluster |
+---------------------------+----------------------------------------------+----------------------------------------------+
| **Cleanup** | ``torch.distributed.destroy_process_group()`` | ``shutdown_compute_proccess()`` disconnects |
| | | from storage cluster |
Expand Down Expand Up @@ -122,9 +122,8 @@
import os
import statistics
import sys
import threading
import time
from collections.abc import Iterator, MutableMapping
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Literal, Optional

Expand Down Expand Up @@ -226,8 +225,8 @@ def _setup_dataloaders(
flush()
ablp_input = dataset.fetch_ablp_input(
split=split,
rank=cluster_info.compute_node_rank,
world_size=cluster_info.num_compute_nodes,
rank=torch.distributed.get_rank(),
world_size=torch.distributed.get_world_size(),
)

main_loader = DistABLPLoader(
Expand All @@ -252,8 +251,8 @@ def _setup_dataloaders(

# For the random negative loader, we get all node IDs from the storage cluster.
all_node_ids = dataset.fetch_node_ids(
rank=cluster_info.compute_node_rank,
world_size=cluster_info.num_compute_nodes,
rank=torch.distributed.get_rank(),
world_size=torch.distributed.get_world_size(),
)

random_negative_loader = DistNeighborLoader(
Expand Down Expand Up @@ -365,8 +364,6 @@ class TrainingProcessArgs:
Attributes:
local_world_size (int): Number of training processes spawned by each machine.
cluster_info (GraphStoreInfo): Cluster topology info for graph store mode.
mp_sharing_dict (MutableMapping[str, torch.Tensor]): Shared dictionary for efficient tensor
sharing between local processes.
model_uri (Uri): URI to save/load the trained model state dict.
eval_metrics_uri (Optional[Uri]): Destination URI for writing evaluation metrics in
KFP-compatible JSON format. If None, metrics are not written.
Expand All @@ -392,8 +389,6 @@ class TrainingProcessArgs:
# Distributed context
local_world_size: int
cluster_info: GraphStoreInfo
mp_sharing_dict: MutableMapping[str, torch.Tensor]
mp_barrier: threading.Barrier

# Model
model_uri: Uri
Expand Down Expand Up @@ -442,8 +437,6 @@ def _training_process(
dataset = RemoteDistDataset(
args.cluster_info,
local_rank,
mp_sharing_dict=args.mp_sharing_dict,
mp_barrier=args.mp_barrier,
)

rank = torch.distributed.get_rank()
Expand Down Expand Up @@ -925,20 +918,14 @@ def _run_example_training(

should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training

# Step 4: Create shared dict for inter-process tensor sharing
manager = mp.Manager()
mp_sharing_dict = manager.dict()
mp_barrier = manager.Barrier(local_world_size) # type: ignore[attr-defined]
# Step 5: Spawn training processes
# Step 4: Spawn training processes
logger.info("--- Launching training processes ...\n")
flush()
start_time = time.time()

training_args = TrainingProcessArgs(
local_world_size=local_world_size,
cluster_info=cluster_info,
mp_sharing_dict=mp_sharing_dict,
mp_barrier=mp_barrier,
model_uri=model_uri,
eval_metrics_uri=eval_metrics_uri,
hid_dim=hid_dim,
Expand Down
12 changes: 7 additions & 5 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,17 +647,19 @@ def _setup_for_graph_store(
node_feature_info = dataset.fetch_node_feature_info()
edge_feature_info = dataset.fetch_edge_feature_info()
edge_types = dataset.fetch_edge_types()
node_rank = dataset.cluster_info.compute_node_rank
compute_rank = torch.distributed.get_rank()

# Get sampling ports for compute-storage connections.
# One port per compute process (not per compute node) so that each
# process gets its own server-side sampling worker group.
sampling_ports = dataset.fetch_free_ports_on_storage_cluster(
num_ports=dataset.cluster_info.num_compute_nodes
num_ports=dataset.cluster_info.compute_cluster_world_size
)
sampling_port = sampling_ports[node_rank]
sampling_port = sampling_ports[compute_rank]
worker_key = (
f"compute_ablp_loader_rank_{node_rank}_worker_{self._instance_count}"
f"compute_ablp_loader_rank_{compute_rank}_worker_{self._instance_count}"
)
logger.info(f"rank: {torch.distributed.get_rank()}, worker_key: {worker_key}")
logger.info(f"rank: {compute_rank}, worker_key: {worker_key}")
worker_options = RemoteDistSamplingWorkerOptions(
server_rank=list(range(dataset.cluster_info.num_storage_nodes)),
num_workers=num_workers,
Expand Down
12 changes: 7 additions & 5 deletions gigl/distributed/distributed_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,16 +331,18 @@ def _setup_for_graph_store(
node_feature_info = dataset.fetch_node_feature_info()
edge_feature_info = dataset.fetch_edge_feature_info()
edge_types = dataset.fetch_edge_types()
node_rank = dataset.cluster_info.compute_node_rank
compute_rank = torch.distributed.get_rank()

# Get sampling ports for compute-storage connections.
# One port per compute process (not per compute node) so that each
# process gets its own server-side sampling worker group.
sampling_ports = dataset.fetch_free_ports_on_storage_cluster(
num_ports=dataset.cluster_info.num_compute_nodes
num_ports=dataset.cluster_info.compute_cluster_world_size
)
sampling_port = sampling_ports[node_rank]
sampling_port = sampling_ports[compute_rank]

worker_key = f"compute_rank_{node_rank}_worker_{self._instance_count}"
logger.info(f"Rank {torch.distributed.get_rank()} worker key: {worker_key}")
worker_key = f"compute_rank_{compute_rank}_worker_{self._instance_count}"
logger.info(f"Rank {compute_rank} worker key: {worker_key}")
worker_options = RemoteDistSamplingWorkerOptions(
server_rank=list(range(dataset.cluster_info.num_storage_nodes)),
num_workers=num_workers,
Expand Down
Loading