74 commits
789045c  Add VertexAiMultiPoolConfig to support multiple worker pools (kmonte, Sep 24, 2025)
13e30c4  Merge branch 'main' into kmonte/add-multipool-vai (kmontemayor2-sc, Sep 25, 2025)
76f94ce  Merge branch 'main' into kmonte/add-multipool-vai (kmonte, Sep 26, 2025)
d810e45  typo (kmonte, Sep 26, 2025)
5ff13b2  to more explicit configs (kmonte, Sep 26, 2025)
b4d35ef  wip (kmonte, Sep 30, 2025)
5a27057  wip (kmonte, Sep 30, 2025)
f8c4ab7  works (kmonte, Oct 1, 2025)
3a6768f  Merge branch 'main' into kmonte/launch-multipool-vai (kmonte, Oct 1, 2025)
6ea5b3f  tests (kmonte, Oct 1, 2025)
90d651c  remove (kmonte, Oct 1, 2025)
94037af  fix typecheck (kmonte, Oct 1, 2025)
298d19d  comments (kmontemayor2-sc, Oct 3, 2025)
7ba700b  Merge branch 'main' into kmonte/launch-multipool-vai (kmonte, Oct 6, 2025)
1752d91  Add get_graph_store_info to setup graph store clusters (kmonte, Oct 6, 2025)
6429617  add intergration tests for get_graph_store_info (kmonte, Oct 6, 2025)
603ca6a  [AUTOMATED] Update dep.vars, and other relevant files with new image … (github-actions[bot], Oct 6, 2025)
d78e938  bleg (kmonte, Oct 6, 2025)
de00fd2  Merge branch 'kmonte/multipool-utils' of https://github.com/Snapchat/… (kmonte, Oct 6, 2025)
fc9d0d0  wip (kmonte, Oct 6, 2025)
704cbfd  Merge branch 'main' into kmonte/multipool-utils (kmonte, Oct 7, 2025)
fb91f1a  bleh (kmonte, Oct 7, 2025)
c2b5607  fix (kmonte, Oct 7, 2025)
694d72b  Nightly (kmonte, Oct 7, 2025)
74d8df1  Add utils to parse VAI CLUSTER_SPEC (kmonte, Oct 7, 2025)
de1de6a  comments (kmonte, Oct 7, 2025)
fc0dca4  rename (kmonte, Oct 7, 2025)
d3319d6  fixes (kmonte, Oct 7, 2025)
0905664  fixes (kmonte, Oct 7, 2025)
112d0ad  fix (kmonte, Oct 7, 2025)
f621bc7  address comments (kmonte, Oct 8, 2025)
fee17c1  Merge branch 'main' into kmonte/parse-cluster-spec (kmonte, Oct 8, 2025)
9b99706  reword (kmonte, Oct 8, 2025)
b9a766c  Merge branch 'main' into kmonte/multipool-utils (kmonte, Oct 8, 2025)
c4ec660  Merge branch 'kmonte/parse-cluster-spec' into kmonte/multipool-utils (kmonte, Oct 8, 2025)
86eb8eb  merges (kmonte, Oct 8, 2025)
026301e  merge (kmonte, Oct 9, 2025)
af12a00  fix (kmonte, Oct 9, 2025)
2c13526  test fix (kmonte, Oct 9, 2025)
a3dea31  fix test (kmonte, Oct 9, 2025)
aaceeee  fixes (kmonte, Oct 10, 2025)
1805a8d  [AUTOMATED] Bumped version to v0.0.10 (github-actions[bot], Oct 10, 2025)
11390ea  fix (kmonte, Oct 10, 2025)
a30e0b0  Merge branch 'release/v0.0.10' into kmonte/multipool-utils (kmonte, Oct 10, 2025)
e897ec2  Merge branch 'main' into kmonte/multipool-utils (kmonte, Nov 3, 2025)
94ada90  cleanup (kmonte, Nov 3, 2025)
f74c604  address comments (kmonte, Nov 4, 2025)
8b5b120  ble (kmonte, Nov 4, 2025)
ac76340  Properly set VAI env vars (kmonte, Nov 4, 2025)
40c755c  add test (kmonte, Nov 4, 2025)
21c8c03  Revert "ble" (kmonte, Nov 4, 2025)
0ec4a91  some vai cluster checks (kmonte, Nov 5, 2025)
527e040  Merge branch 'main' into kmonte/multipool-utils (kmontemayor2-sc, Nov 5, 2025)
e8df2c5  remove jon (kmonte, Nov 5, 2025)
8e4a43b  fix (kmonte, Nov 5, 2025)
7ec3f05  fix (kmonte, Nov 5, 2025)
c6e4746  bump to 0.0.11 (kmonte, Nov 5, 2025)
41cd07f  more docs (kmonte, Nov 5, 2025)
009e001  Merge branch 'main' into kmonte/server-client-scratch-v2 (kmonte, Nov 5, 2025)
8150d8d  Merge branch 'kmonte/multipool-utils' into kmonte/server-client-scrat… (kmonte, Nov 5, 2025)
4df5fa1  Revert "Merge branch 'kmonte/multipool-utils' into kmonte/server-clie… (kmonte, Nov 5, 2025)
c64a8ca  Merge branch 'kmonte/vai-env-vars' into kmonte/server-client-scratch-v2 (kmonte, Nov 5, 2025)
8a7be1d  wip (kmonte, Nov 5, 2025)
9f2350a  slop (kmonte, Nov 5, 2025)
89d7b19  blegh (kmonte, Nov 5, 2025)
5ae5b09  merge (kmonte, Nov 6, 2025)
9c629e2  Merge branch 'main' into kmonte/server-client-scratch-v2 (kmonte, Nov 6, 2025)
21f20a9  wip (kmonte, Nov 6, 2025)
ce14e00  bleh (kmonte, Nov 6, 2025)
4960441  Merge branch 'main' into kmonte/server-client-scratch-v2 (kmonte, Nov 7, 2025)
8788c2e  wip (kmonte, Nov 7, 2025)
2818a9f  wip2 (kmonte, Nov 8, 2025)
e37d5f8  wip (kmonte, Nov 10, 2025)
217877d  integration test works (kmonte, Nov 12, 2025)
2 changes: 1 addition & 1 deletion Makefile
@@ -154,7 +154,7 @@ assert_yaml_configs_parse:
# Ex. `make unit_test_py PY_TEST_FILES="eval_metrics_test.py"`
# By default, runs all tests under python/tests/unit.
# See the help text for "--test_file_pattern" in python/tests/test_args.py for more details.
unit_test_py: clean_build_files_py type_check
unit_test_py: clean_build_files_py #type_check
( cd python ; \
python -m tests.unit.main \
--env=test \
3 changes: 3 additions & 0 deletions proto/snapchat/research/gbml/gigl_resource_config.proto
@@ -132,6 +132,9 @@ message VertexAiResourceConfig {
message VertexAiGraphStoreConfig {
VertexAiResourceConfig graph_store_pool = 1;
VertexAiResourceConfig compute_pool = 2;

int32 num_processes_per_storage_machine = 3;
int32 num_processes_per_compute_machine = 4;
}
// (deprecated)
// Configuration for distributed training resources
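The two new fields declare how many worker processes run on each storage machine and each compute machine of a Vertex AI graph-store cluster. A minimal sketch of populating them from Python, assuming a standard protoc-generated module path (the import path and values are illustrative, not taken from this diff):

# Sketch only: the generated-module path and values are assumptions.
from snapchat.research.gbml import gigl_resource_config_pb2 as resource_pb2

graph_store_config = resource_pb2.VertexAiGraphStoreConfig(
    # Fewer processes per storage machine keeps each partition in one process;
    # compute machines typically run more processes for sampling parallelism.
    num_processes_per_storage_machine=1,
    num_processes_per_compute_machine=4,
)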
1 change: 1 addition & 0 deletions python/gigl/distributed/dataset_factory.py
@@ -520,6 +520,7 @@ def build_dataset_from_task_config_uri(
)

# Read from GbmlConfig for preprocessed data metadata, GNN model uri, and bigquery embedding table path
logger.info(f"Reading GbmlConfig from URI: {task_config_uri}")
gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=UriFactory.create_uri(task_config_uri)
)
185 changes: 118 additions & 67 deletions python/gigl/distributed/distributed_neighborloader.py
@@ -3,7 +3,13 @@

import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions
from graphlearn_torch.distributed import (
DistLoader,
DistServer,
MpDistSamplingWorkerOptions,
RemoteDistSamplingWorkerOptions,
request_server,
)
from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType
@@ -20,6 +26,7 @@
shard_nodes_by_process,
strip_label_edges,
)
from gigl.env.distributed import GraphStoreInfo
from gigl.src.common.types.graph_data import (
NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
Expand All @@ -37,10 +44,10 @@
class DistNeighborLoader(DistLoader):
def __init__(
self,
dataset: DistDataset,
dataset: Optional[DistDataset],
num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
input_nodes: Optional[
Union[torch.Tensor, Tuple[NodeType, torch.Tensor]]
Union[torch.Tensor, Tuple[NodeType, torch.Tensor], list[torch.Tensor]]
] = None,
num_workers: int = 1,
batch_size: int = 1,
@@ -54,6 +61,7 @@ def __init__(
num_cpu_threads: Optional[int] = None,
shuffle: bool = False,
drop_last: bool = False,
graph_store_info: Optional[GraphStoreInfo] = None,
):
"""
Note: We try to adhere to pyg dataloader api as much as possible.
@@ -193,6 +201,10 @@ def __init__(
)

if input_nodes is None:
if dataset is None:
raise ValueError(
"Dataset must be provided if input_nodes are not provided."
)
if dataset.node_ids is None:
raise ValueError(
"Dataset must have node ids if input_nodes are not provided."
@@ -205,44 +217,106 @@

# Determines if the node ids passed in are heterogeneous or homogeneous.
self._is_labeled_heterogeneous = False
if isinstance(input_nodes, torch.Tensor):
node_ids = input_nodes

# If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
# if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
if isinstance(dataset.node_ids, abc.Mapping):
if (
len(dataset.node_ids) == 1
and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
):
node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
self._is_labeled_heterogeneous = True
else:
raise ValueError(
f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
if dataset is None:
if graph_store_info is None:
raise ValueError(
"graph_store_info must be provided if dataset is not provided."
)
num_partitions, partition_idx, ntypes, etypes = request_server(
server_rank=0,
func=DistServer.get_dataset_meta,
)
if not isinstance(input_nodes, list):
raise ValueError(
"input_nodes must be a list if dataset is not provided."
)
if (
len(input_nodes)
!= graph_store_info.num_storage_nodes
* graph_store_info.num_processes_per_storage
):
raise ValueError(
f"input_nodes must be a list of length {graph_store_info.num_storage_nodes * graph_store_info.num_processes_per_storage}, got {len(input_nodes)}. E.g. one entry per process in the storage cluster."
)
worker_options = RemoteDistSamplingWorkerOptions(
server_rank=[
server_rank
for server_rank in range(
graph_store_info.num_storage_nodes
* graph_store_info.num_processes_per_storage
)
else:
node_type = None
],
num_workers=num_workers,
worker_devices=[torch.device("cpu") for i in range(num_workers)],
master_addr=graph_store_info.cluster_master_ip,
master_port=graph_store_info.cluster_master_port,
)
else:
node_type, node_ids = input_nodes
assert isinstance(
dataset.node_ids, abc.Mapping
), "Dataset must be heterogeneous if provided input nodes are a tuple."

num_neighbors = patch_fanout_for_sampling(
dataset.get_edge_types(), num_neighbors
)
if isinstance(input_nodes, torch.Tensor):
node_ids = input_nodes

# If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
# if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
if isinstance(dataset.node_ids, abc.Mapping):
if (
len(dataset.node_ids) == 1
and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
):
node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
self._is_labeled_heterogeneous = True
else:
raise ValueError(
f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
)
else:
node_type = None
elif isinstance(input_nodes, tuple):
node_type, node_ids = input_nodes
assert isinstance(
dataset.node_ids, abc.Mapping
), "Dataset must be heterogeneous if provided input nodes are a tuple."
else:
raise ValueError(
f"input_nodes must be a torch.Tensor or a tuple of (node_type, node_ids), got {type(input_nodes)}"
)
etypes = dataset.get_edge_types()

curr_process_nodes = shard_nodes_by_process(
input_nodes=node_ids,
local_process_rank=local_rank,
local_process_world_size=local_world_size,
)
curr_process_nodes = shard_nodes_by_process(
input_nodes=node_ids,
local_process_rank=local_rank,
local_process_world_size=local_world_size,
)

self._node_feature_info = dataset.node_feature_info
self._edge_feature_info = dataset.edge_feature_info
self._node_feature_info = dataset.node_feature_info
self._edge_feature_info = dataset.edge_feature_info

input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)
dist_sampling_ports = (
gigl.distributed.utils.get_free_ports_from_master_node(
num_ports=local_world_size
)
)
dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank]

worker_options = MpDistSamplingWorkerOptions(
num_workers=num_workers,
worker_devices=[torch.device("cpu") for _ in range(num_workers)],
worker_concurrency=worker_concurrency,
# Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group
# need to be connected. Thus, we need master ip address and master port to
# initiate the connection.
# Note that different groups of workers are independent, and thus
# the sampling processes in different groups should be independent, and should
# use different master ports.
master_addr=master_ip_address,
master_port=dist_sampling_port_for_current_rank,
# Load testing shows that when num_rpc_threads exceeds 16, the performance
# will degrade.
num_rpc_threads=min(dataset.num_partitions, 16),
rpc_timeout=600,
channel_size=channel_size,
pin_memory=device.type == "cuda",
)

# Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize
# the memory overhead and CPU contention.
@@ -280,31 +354,14 @@ def __init__(
)

# Sets up worker options for the dataloader
dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node(
num_ports=local_world_size
)
dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank]

worker_options = MpDistSamplingWorkerOptions(
num_workers=num_workers,
worker_devices=[torch.device("cpu") for _ in range(num_workers)],
worker_concurrency=worker_concurrency,
# Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group
# need to be connected. Thus, we need master ip address and master port to
# initiate the connection.
# Note that different groups of workers are independent, and thus
# the sampling processes in different groups should be independent, and should
# use different master ports.
master_addr=master_ip_address,
master_port=dist_sampling_port_for_current_rank,
# Load testing shows that when num_rpc_threads exceeds 16, the performance
# will degrade.
num_rpc_threads=min(dataset.num_partitions, 16),
rpc_timeout=600,
channel_size=channel_size,
pin_memory=device.type == "cuda",
)

if should_cleanup_distributed_context and torch.distributed.is_initialized():
logger.info(
f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
)
torch.distributed.destroy_process_group()

num_neighbors = patch_fanout_for_sampling(etypes, num_neighbors)
sampling_config = SamplingConfig(
sampling_type=SamplingType.NODE,
num_neighbors=num_neighbors,
@@ -315,16 +372,10 @@ def __init__(
collect_features=True,
with_neg=False,
with_weight=False,
edge_dir=dataset.edge_dir,
edge_dir=dataset.edge_dir if dataset is not None else "out",
seed=None, # it's actually optional - None means random.
)

if should_cleanup_distributed_context and torch.distributed.is_initialized():
logger.info(
f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
)
torch.distributed.destroy_process_group()

super().__init__(dataset, input_data, sampling_config, device, worker_options)

def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
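Taken together, these changes give DistNeighborLoader a server/client mode: when dataset is None, it fetches dataset metadata from server rank 0 via request_server(DistServer.get_dataset_meta) and samples through RemoteDistSamplingWorkerOptions instead of spawning local sampling workers. A client-side usage sketch, assuming a GraphStoreInfo has been built elsewhere (e.g. by the CLUSTER_SPEC utilities added in this PR); only the DistNeighborLoader arguments appear in the diff above, the rest is illustrative:

import torch

from gigl.distributed import DistNeighborLoader
from gigl.env.distributed import GraphStoreInfo

info: GraphStoreInfo = ...  # assumed to come from cluster setup utilities

# The new validation requires one seed tensor per storage process:
# len(input_nodes) == num_storage_nodes * num_processes_per_storage.
num_storage_processes = info.num_storage_nodes * info.num_processes_per_storage
seeds = [torch.arange(1_000) for _ in range(num_storage_processes)]

loader = DistNeighborLoader(
    dataset=None,  # no local dataset: sample remotely from the graph store
    num_neighbors=[10, 10],
    input_nodes=seeds,
    graph_store_info=info,
)
for batch in loader:  # yields Data/HeteroData, as in the local-dataset path
    ...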
Empty file.
59 changes: 59 additions & 0 deletions python/gigl/distributed/server_client/remote_dataset.py
@@ -0,0 +1,59 @@
from typing import Optional, Union

import torch

from gigl.common.logger import Logger
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.utils.neighborloader import shard_nodes_by_process
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, FeatureInfo

logger = Logger()

_dataset: Optional[DistDataset] = None


def register_dataset(dataset: DistDataset) -> None:
global _dataset
if _dataset is not None:
raise ValueError("Dataset already registered! Cannot register a new dataset.")
_dataset = dataset


def get_node_feature_info() -> Union[FeatureInfo, dict[NodeType, FeatureInfo], None]:
if _dataset is None:
raise ValueError(
"Dataset not registered! Register the dataset first with `gigl.distributed.server_client.register_dataset`"
)
return _dataset.node_feature_info


def get_edge_feature_info() -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]:
if _dataset is None:
raise ValueError(
"Dataset not registered! Register the dataset first with `gigl.distributed.server_client.register_dataset`"
)
return _dataset.edge_feature_info


def get_node_ids_for_rank(
rank: int, world_size: int, node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE
) -> torch.Tensor:
logger.info(
f"Getting node ids for rank {rank} / {world_size} with node type {node_type}"
)
if _dataset is None:
raise ValueError(
"Dataset not registered! Register the dataset first with `gigl.distributed.server_client.register_dataset`"
)
if isinstance(_dataset.node_ids, torch.Tensor):
nodes = _dataset.node_ids
elif isinstance(_dataset.node_ids, dict):
nodes = _dataset.node_ids[node_type]
else:
raise ValueError(
f"Node ids must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(_dataset.node_ids)}"
)
logger.info(f"Sharding nodes {nodes.shape} for rank {rank} / {world_size}")
logger.info(f"Nodes: {nodes}")
return shard_nodes_by_process(nodes, rank, world_size)
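A short server-side usage sketch for this module: the process holding the partitioned data registers it once, after which the accessors can back remote requests such as DistServer.get_dataset_meta or per-rank seed lookups (the surrounding call sites are assumptions; the functions themselves are defined above):

# Sketch: `dataset` is assumed to be an already-built DistDataset, e.g. from
# gigl.distributed.dataset_factory.build_dataset_from_task_config_uri.
from gigl.distributed.server_client import remote_dataset

remote_dataset.register_dataset(dataset)  # raises if registered twice

# Clients (or RPC handlers) can then query feature metadata and seed shards:
feature_info = remote_dataset.get_node_feature_info()
rank0_seeds = remote_dataset.get_node_ids_for_rank(rank=0, world_size=4)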