diff --git a/.github/cloud_builder/run_command_on_active_checkout.yaml b/.github/cloud_builder/run_command_on_active_checkout.yaml index eb905fb99..5e898d388 100644 --- a/.github/cloud_builder/run_command_on_active_checkout.yaml +++ b/.github/cloud_builder/run_command_on_active_checkout.yaml @@ -3,7 +3,7 @@ substitutions: options: logging: CLOUD_LOGGING_ONLY steps: - - name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 + - name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 entrypoint: /bin/bash args: - -c diff --git a/dep_vars.env b/dep_vars.env index 0b6ca38ef..c5aa972b5 100644 --- a/dep_vars.env +++ b/dep_vars.env @@ -1,13 +1,13 @@ # Note this file only supports static key value pairs so it can be loaded by make, bash, python, and sbt without any additional parsing. -DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 -DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 -DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:96d2b7ce368e8af7bc7a52eac7b6de4789f06815.41.1 +DOCKER_LATEST_BASE_CUDA_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cuda-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 +DOCKER_LATEST_BASE_CPU_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-cpu-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 +DOCKER_LATEST_BASE_DATAFLOW_IMAGE_NAME_WITH_TAG=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dataflow-base:64296177d7a8214cc5077dc9fddd9696adfdaaf2.42.1 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.9 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.9 -DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.9 -DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.9 -DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.9.yaml +DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cuda:0.0.10 +DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu:0.0.10 +DEFAULT_GIGL_RELEASE_SRC_IMAGE_DATAFLOW_CPU=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/src-cpu-dataflow:0.0.10 +DEFAULT_GIGL_RELEASE_DEV_WORKBENCH_IMAGE=us-central1-docker.pkg.dev/external-snap-ci-github-gigl/public-gigl/gigl-dev-workbench:0.0.10 +DEFAULT_GIGL_RELEASE_KFP_PIPELINE_PATH=gs://public-gigl/releases/pipelines/gigl-pipeline-0.0.10.yaml SPARK_31_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark-custom-tfrecord_2.12-0.5.0.jar SPARK_35_TFRECORD_JAR_GCS_PATH=gs://public-gigl/tools/scala/spark_packages/spark_3.5.0-custom-tfrecord_2.12-0.6.1.jar diff --git a/python/gigl/__init__.py b/python/gigl/__init__.py index 00ec2dcdb..9b36b86cf 100644 --- a/python/gigl/__init__.py +++ b/python/gigl/__init__.py @@ -1 +1 @@ -__version__ = "0.0.9" +__version__ = "0.0.10" diff --git a/python/gigl/common/utils/vertex_ai_context.py b/python/gigl/common/utils/vertex_ai_context.py index 35d1d90ad..bc4a93d02 100644 --- a/python/gigl/common/utils/vertex_ai_context.py +++ b/python/gigl/common/utils/vertex_ai_context.py @@ -8,7 +8,6 @@ from typing import Optional import omegaconf -from google.cloud.aiplatform_v1.types import CustomJobSpec from gigl.common import GcsUri from gigl.common.logger import Logger @@ -183,29 +182,35 @@ class ClusterSpec: cluster: dict[str, list[str]] # Worker pool names mapped to their replica lists environment: str # The environment string (e.g., "cloud") task: TaskInfo # Information about the current task - # The CustomJobSpec for the current job - # See the docs for more info: - # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec - job: Optional[CustomJobSpec] = None - # We use a custom method for parsing, because CustomJobSpec is a protobuf message. + # DESPITE what the docs say, this is *not* a CustomJobSpec. + # It's *sort of* like a PythonPackageSpec, but it's not. + # It has `jobArgs` instead of `args`. + # See an example: + # {"python_module":"","package_uris":[],"job_args":[]} + job: Optional[dict] = None + + # We use a custom method for parsing, the "job" is actually a serialized json string. @classmethod def from_json(cls, json_str: str) -> "ClusterSpec": """Instantiates ClusterSpec from a JSON string.""" cluster_spec_json = json.loads(json_str) if "job" in cluster_spec_json and cluster_spec_json["job"] is not None: - job_spec = CustomJobSpec(**cluster_spec_json.pop("job")) + logger.info(f"Job spec: {cluster_spec_json['job']}") + job_spec = json.loads(cluster_spec_json.pop("job")) else: job_spec = None conf = omegaconf.OmegaConf.create(cluster_spec_json) if isinstance(conf, omegaconf.ListConfig): raise ValueError("ListConfig is not supported") - return cls( + cluster_spec = cls( cluster=conf.cluster, environment=conf.environment, task=conf.task, job=job_spec, ) + logger.info(f"Cluster spec: {cluster_spec}") + return cluster_spec def get_cluster_spec() -> ClusterSpec: diff --git a/python/gigl/distributed/utils/__init__.py b/python/gigl/distributed/utils/__init__.py index b73d00b8b..363fb1470 100644 --- a/python/gigl/distributed/utils/__init__.py +++ b/python/gigl/distributed/utils/__init__.py @@ -3,10 +3,12 @@ """ __all__ = [ + "GraphStoreInfo", "get_available_device", + "get_free_port", "get_free_ports_from_master_node", "get_free_ports_from_node", - "get_free_port", + "get_graph_store_info", "get_internal_ip_from_all_ranks", "get_internal_ip_from_master_node", "get_internal_ip_from_node", @@ -20,9 +22,11 @@ init_neighbor_loader_worker, ) from .networking import ( + GraphStoreInfo, get_free_port, get_free_ports_from_master_node, get_free_ports_from_node, + get_graph_store_info, get_internal_ip_from_all_ranks, get_internal_ip_from_master_node, get_internal_ip_from_node, diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 0ef68dbeb..0f1d8e77a 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -4,6 +4,11 @@ import torch from gigl.common.logger import Logger +from gigl.common.utils.vertex_ai_context import ( + get_cluster_spec, + is_currently_running_in_vertex_ai_job, +) +from gigl.env.distributed import GraphStoreInfo logger = Logger() @@ -179,3 +184,55 @@ def get_internal_ip_from_all_ranks() -> list[str]: assert all(ip for ip in ip_list), "Could not retrieve all ranks' internal IPs" return ip_list + + +def get_graph_store_info() -> GraphStoreInfo: + """ + Get the information about the graph store cluster. + + Returns: + GraphStoreInfo: The information about the graph store cluster. + + Raises: + ValueError: If a torch distributed environment is not initialized. + ValueError: If not running running in a supported environment. + """ + if not torch.distributed.is_initialized(): + raise ValueError("Distributed environment must be initialized") + if is_currently_running_in_vertex_ai_job(): + cluster_spec = get_cluster_spec() + # We setup the VAI cluster such that the compute nodes come first, followed by the storage nodes. + if "workerpool1" in cluster_spec.cluster: + num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) + len( + cluster_spec.cluster["workerpool1"] + ) + else: + num_compute_nodes = len(cluster_spec.cluster["workerpool0"]) + num_storage_nodes = len(cluster_spec.cluster["workerpool2"]) + else: + raise ValueError( + "Must be running on a vertex AI job to get graph store cluster info!" + ) + + cluster_master_ip = get_internal_ip_from_master_node() + # We assume that the compute cluster nodes come first, followed by the storage nodes. + compute_cluster_master_ip = get_internal_ip_from_node(node_rank=0) + storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes) + + cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] + compute_cluster_master_port = get_free_ports_from_node(num_ports=1, node_rank=0)[0] + storage_cluster_master_port = get_free_ports_from_node( + num_ports=1, node_rank=num_compute_nodes + )[0] + + return GraphStoreInfo( + num_cluster_nodes=num_storage_nodes + num_compute_nodes, + num_storage_nodes=num_storage_nodes, + num_compute_nodes=num_compute_nodes, + cluster_master_ip=cluster_master_ip, + storage_cluster_master_ip=storage_cluster_master_ip, + compute_cluster_master_ip=compute_cluster_master_ip, + cluster_master_port=cluster_master_port, + storage_cluster_master_port=storage_cluster_master_port, + compute_cluster_master_port=compute_cluster_master_port, + ) diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py index 84466dde4..e8999be67 100644 --- a/python/gigl/env/distributed.py +++ b/python/gigl/env/distributed.py @@ -19,3 +19,29 @@ class DistributedContext: # Total number of machines global_world_size: int + + +@dataclass(frozen=True) +class GraphStoreInfo: + """Information about a graph store cluster.""" + + # Number of nodes in the whole cluster + num_cluster_nodes: int + # Number of nodes in the storage cluster + num_storage_nodes: int + # Number of nodes in the compute cluster + num_compute_nodes: int + + # IP address of the master node for the whole cluster + cluster_master_ip: str + # IP address of the master node for the storage cluster + storage_cluster_master_ip: str + # IP address of the master node for the compute cluster + compute_cluster_master_ip: str + + # Port of the master node for the whole cluster + cluster_master_port: int + # Port of the master node for the storage cluster + storage_cluster_master_port: int + # Port of the master node for the compute cluster + compute_cluster_master_port: int diff --git a/python/pyproject.toml b/python/pyproject.toml index 5cda3a3cf..e1fde652e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta" name = "gigl" description = "GIgantic Graph Learning Library" readme = "README.md" -version = "0.0.9" +version = "0.0.10" requires-python = ">=3.9,<3.10" # Currently we only support python 3.9 as per deps setup below classifiers = [ "Programming Language :: Python", diff --git a/python/tests/integration/distributed/utils/__init__.py b/python/tests/integration/distributed/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/tests/integration/distributed/utils/networking_test.py b/python/tests/integration/distributed/utils/networking_test.py new file mode 100644 index 000000000..bbbf10283 --- /dev/null +++ b/python/tests/integration/distributed/utils/networking_test.py @@ -0,0 +1,83 @@ +import unittest +import uuid +from textwrap import dedent + +from parameterized import param, parameterized + +from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService +from gigl.env.pipelines_config import get_resource_config +from gigl.common.constants import GIGL_RELEASE_SRC_IMAGE_CPU + + +class NetworkingUtlsIntegrationTest(unittest.TestCase): + def setUp(self): + self._resource_config = get_resource_config() + self._project = self._resource_config.project + self._location = self._resource_config.region + self._service_account = self._resource_config.service_account_email + self._staging_bucket = ( + self._resource_config.temp_assets_regional_bucket_path.uri + ) + self._vertex_ai_service = VertexAIService( + project=self._project, + location=self._location, + service_account=self._service_account, + staging_bucket=self._staging_bucket, + ) + super().setUp() + + @parameterized.expand( + [ + param( + "Test with 1 compute node and 1 storage node", + compute_nodes=1, + storage_nodes=1, + ), + param( + "Test with 2 compute nodes and 2 storage nodes", + compute_nodes=2, + storage_nodes=2, + ), + ] + ) + def test_get_graph_store_info(self, _, storage_nodes, compute_nodes): + job_name = f"GiGL-Integration-Test-Graph-Store-{uuid.uuid4()}" + command = [ + "python", + "-c", + dedent( + f""" + import torch + from gigl.distributed.utils import get_graph_store_info + torch.distributed.init_process_group(backend="gloo") + info = get_graph_store_info() + assert info.num_storage_nodes == {storage_nodes}, f"Expected {storage_nodes} storage nodes, but got {{ info.num_storage_nodes }}" + assert info.num_compute_nodes == {compute_nodes}, f"Expected {compute_nodes} compute nodes, but got {{ info.num_compute_nodes }}" + assert info.num_cluster_nodes == {storage_nodes + compute_nodes}, f"Expected {storage_nodes + compute_nodes} cluster nodes, but got {{ info.num_cluster_nodes }}" + assert info.cluster_master_ip is not None, f"Cluster master IP is None" + assert info.storage_cluster_master_ip is not None, f"Storage cluster master IP is None" + assert info.compute_cluster_master_ip is not None, f"Compute cluster master IP is None" + assert info.cluster_master_port is not None, f"Cluster master port is None" + assert info.storage_cluster_master_port is not None, f"Storage cluster master port is None" + assert info.compute_cluster_master_port is not None, f"Compute cluster master port is None" + """ + ), + ] + compute_cluster_config = VertexAiJobConfig( + job_name=job_name, + container_uri=GIGL_RELEASE_SRC_IMAGE_CPU, + replica_count=compute_nodes, + command=command, + machine_type="n2-standard-8", + ) + storage_cluster_config = VertexAiJobConfig( + job_name=job_name, + container_uri=GIGL_RELEASE_SRC_IMAGE_CPU, + replica_count=storage_nodes, + machine_type="n1-standard-4", + command=command, + ) + + self._vertex_ai_service.launch_graph_store_job( + compute_cluster_config, storage_cluster_config + ) diff --git a/python/tests/unit/common/utils/vertex_ai_context_test.py b/python/tests/unit/common/utils/vertex_ai_context_test.py index db6ad7cc7..aa7dba6f0 100644 --- a/python/tests/unit/common/utils/vertex_ai_context_test.py +++ b/python/tests/unit/common/utils/vertex_ai_context_test.py @@ -3,8 +3,6 @@ import unittest from unittest.mock import call, patch -from google.cloud.aiplatform_v1.types import CustomJobSpec - from gigl.common import GcsUri from gigl.common.services.vertex_ai import LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY from gigl.common.utils.vertex_ai_context import ( @@ -129,11 +127,7 @@ def test_parse_cluster_spec_success(self): }, "task": {"type": "workerpool0", "index": 1, "trial": "trial-123"}, "environment": "cloud", - "job": { - "worker_pool_specs": [ - {"machine_spec": {"machine_type": "n1-standard-4"}} - ] - }, + "job": '{ "worker_pool_specs": [ {"machine_spec": {"machine_type": "n1-standard-4"}}]}', } ) @@ -150,11 +144,11 @@ def test_parse_cluster_spec_success(self): }, environment="cloud", task=TaskInfo(type="workerpool0", index=1, trial="trial-123"), - job=CustomJobSpec( - worker_pool_specs=[ + job={ + "worker_pool_specs": [ {"machine_spec": {"machine_type": "n1-standard-4"}} ] - ), + }, ) self.assertEqual(cluster_spec, expected_cluster_spec) diff --git a/python/tests/unit/distributed/utils/networking_test.py b/python/tests/unit/distributed/utils/networking_test.py index 0a2fcf6be..b96319e09 100644 --- a/python/tests/unit/distributed/utils/networking_test.py +++ b/python/tests/unit/distributed/utils/networking_test.py @@ -1,5 +1,8 @@ +import json +import os import subprocess import unittest +from typing import Optional from unittest.mock import patch import torch @@ -8,8 +11,10 @@ from parameterized import param, parameterized from gigl.distributed.utils import ( + GraphStoreInfo, get_free_ports_from_master_node, get_free_ports_from_node, + get_graph_store_info, get_internal_ip_from_master_node, get_internal_ip_from_node, ) @@ -190,7 +195,7 @@ def tearDown(self): ), ] ) - def test_get_free_ports_from_master_node_two_ranks( + def _test_get_free_ports_from_master_node_two_ranks( self, _name, num_ports, world_size ): init_process_group_init_method = get_process_group_init_method() @@ -218,7 +223,7 @@ def test_get_free_ports_from_master_node_two_ranks( ), ] ) - def test_get_free_ports_from_master_node_two_ranks_custom_master_node_rank( + def _test_get_free_ports_from_master_node_two_ranks_custom_master_node_rank( self, _name, num_ports, world_size, master_node_rank, ports ): init_process_group_init_method = get_process_group_init_method() @@ -234,14 +239,14 @@ def test_get_free_ports_from_master_node_two_ranks_custom_master_node_rank( nprocs=world_size, ) - def test_get_free_ports_from_master_fails_if_process_group_not_initialized(self): + def _test_get_free_ports_from_master_fails_if_process_group_not_initialized(self): with self.assertRaises( AssertionError, msg="An error should be raised since the `dist.init_process_group` is not initialized", ): get_free_ports_from_master_node(num_ports=1) - def test_get_internal_ip_from_master_node(self): + def _test_get_internal_ip_from_master_node(self): init_process_group_init_method = get_process_group_init_method() expected_host_ip = subprocess.check_output(["hostname", "-i"]).decode().strip() world_size = 2 @@ -265,7 +270,7 @@ def test_get_internal_ip_from_master_node(self): ), ] ) - def test_get_internal_ip_from_master_node_with_master_node_rank( + def _test_get_internal_ip_from_master_node_with_master_node_rank( self, _, world_size, master_node_rank ): init_process_group_init_method = get_process_group_init_method() @@ -281,7 +286,7 @@ def test_get_internal_ip_from_master_node_with_master_node_rank( nprocs=world_size, ) - def test_get_internal_ip_from_master_node_fails_if_process_group_not_initialized( + def _test_get_internal_ip_from_master_node_fails_if_process_group_not_initialized( self, ): with self.assertRaises( @@ -289,3 +294,221 @@ def test_get_internal_ip_from_master_node_fails_if_process_group_not_initialized msg="An error should be raised since the `dist.init_process_group` is not initialized", ): get_internal_ip_from_master_node() + + +def _test_get_graph_store_info_in_dist_context( + rank: int, + world_size: int, + init_process_group_init_method: str, + storage_nodes: int, + compute_nodes: int, +): + """Test get_graph_store_info in a real distributed context.""" + # Initialize distributed process group + dist.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=world_size, + rank=rank, + ) + try: + # Call get_graph_store_info + graph_store_info = get_graph_store_info() + + # Verify the result is a GraphStoreInfo instance + assert isinstance( + graph_store_info, GraphStoreInfo + ), "Result should be a GraphStoreInfo instance" + # Verify cluster sizes + assert ( + graph_store_info.num_storage_nodes == storage_nodes + ), f"Expected {storage_nodes} storage nodes" + assert ( + graph_store_info.num_compute_nodes == compute_nodes + ), f"Expected {compute_nodes} compute nodes" + assert ( + graph_store_info.num_cluster_nodes == storage_nodes + compute_nodes + ), "Total nodes should be sum of storage and compute nodes" + + # Verify IP addresses are strings and not empty + assert isinstance( + graph_store_info.cluster_master_ip, str + ), "Cluster master IP should be a string" + assert ( + len(graph_store_info.cluster_master_ip) > 0 + ), "Cluster master IP should not be empty" + assert isinstance( + graph_store_info.storage_cluster_master_ip, str + ), "Storage cluster master IP should be a string" + assert ( + len(graph_store_info.storage_cluster_master_ip) > 0 + ), "Storage cluster master IP should not be empty" + assert isinstance( + graph_store_info.compute_cluster_master_ip, str + ), "Compute cluster master IP should be a string" + assert ( + len(graph_store_info.compute_cluster_master_ip) > 0 + ), "Compute cluster master IP should not be empty" + + # Verify ports are positive integers + assert isinstance( + graph_store_info.cluster_master_port, int + ), "Cluster master port should be an integer" + assert ( + graph_store_info.cluster_master_port > 0 + ), "Cluster master port should be positive" + assert isinstance( + graph_store_info.storage_cluster_master_port, int + ), "Storage cluster master port should be an integer" + assert ( + graph_store_info.storage_cluster_master_port > 0 + ), "Storage cluster master port should be positive" + assert isinstance( + graph_store_info.compute_cluster_master_port, int + ), "Compute cluster master port should be an integer" + assert ( + graph_store_info.compute_cluster_master_port > 0 + ), "Compute cluster master port should be positive" + + # Verify all ranks get the same result (since they should all get the same broadcasted values) + gathered_info: list[Optional[GraphStoreInfo]] = [None] * world_size + dist.all_gather_object(gathered_info, graph_store_info) + + # All ranks should have the same GraphStoreInfo + for i, info in enumerate(gathered_info): + assert info is not None + assert ( + info.num_cluster_nodes == graph_store_info.num_cluster_nodes + ), f"Rank {i} should have same cluster nodes" + assert ( + info.num_storage_nodes == graph_store_info.num_storage_nodes + ), f"Rank {i} should have same storage nodes" + assert ( + info.num_compute_nodes == graph_store_info.num_compute_nodes + ), f"Rank {i} should have same compute nodes" + assert ( + info.cluster_master_ip == graph_store_info.cluster_master_ip + ), f"Rank {i} should have same cluster master IP" + assert ( + info.storage_cluster_master_ip + == graph_store_info.storage_cluster_master_ip + ), f"Rank {i} should have same storage master IP" + assert ( + info.compute_cluster_master_ip + == graph_store_info.compute_cluster_master_ip + ), f"Rank {i} should have same compute master IP" + assert ( + info.cluster_master_port == graph_store_info.cluster_master_port + ), f"Rank {i} should have same cluster master port" + assert ( + info.storage_cluster_master_port + == graph_store_info.storage_cluster_master_port + ), f"Rank {i} should have same storage master port" + assert ( + info.compute_cluster_master_port + == graph_store_info.compute_cluster_master_port + ), f"Rank {i} should have same compute master port" + + finally: + dist.destroy_process_group() + + +def _get_cluster_spec_for_test(worker_pool_sizes: list[int]) -> dict: + cluster_spec: dict = { + "environment": "cloud", + "task": { + "type": "workerpool0", + "index": 0, + }, + "cluster": {}, + } + for i, worker_pool_size in enumerate(worker_pool_sizes): + cluster_spec["cluster"][f"workerpool{i}"] = [ + f"workerpool{i}-{j}:2222" for j in range(worker_pool_size) + ] + return cluster_spec + + +class TestGetGraphStoreInfo(unittest.TestCase): + """Test suite for get_graph_store_info function.""" + + def tearDown(self): + """Clean up after each test.""" + if dist.is_initialized(): + dist.destroy_process_group() + + def _test_get_graph_store_info_fails_when_distributed_not_initialized(self): + """Test that get_graph_store_info fails when distributed environment is not initialized.""" + with self.assertRaises(ValueError) as context: + get_graph_store_info() + + self.assertIn( + "Distributed environment must be initialized", str(context.exception) + ) + + def _test_get_graph_store_info_fails_when_not_running_in_vertex_ai_job(self): + """Test that get_graph_store_info fails when not running in a Vertex AI job.""" + init_process_group_init_method = get_process_group_init_method() + torch.distributed.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=1, + rank=0, + ) + with self.assertRaises(ValueError) as context: + get_graph_store_info() + + self.assertIn( + "Must be running on a vertex AI job to get graph store cluster info!", + str(context.exception), + ) + + @parameterized.expand( + [ + param( + "Test with 1 storage node and 1 compute node", + storage_nodes=1, + compute_nodes=1, + ), + param( + "Test with 2 storage nodes and 1 compute nodes", + storage_nodes=2, + compute_nodes=1, + ), + param( + "Test with 3 storage nodes and 2 compute nodes", + storage_nodes=3, + compute_nodes=2, + ), + ] + ) + def test_get_graph_store_info_success_in_distributed_context( + self, _name, storage_nodes, compute_nodes + ): + """Test successful execution of get_graph_store_info in a real distributed context.""" + init_process_group_init_method = get_process_group_init_method() + world_size = storage_nodes + compute_nodes + if compute_nodes == 1: + worker_pool_sizes = [1, 0, storage_nodes] + else: + worker_pool_sizes = [1, compute_nodes - 1, storage_nodes] + with patch.dict( + os.environ, + { + "CLUSTER_SPEC": json.dumps( + _get_cluster_spec_for_test(worker_pool_sizes) + ), + "CLOUD_ML_JOB_ID": "test_job_id", + }, + clear=False, + ): + mp.spawn( + fn=_test_get_graph_store_info_in_dist_context, + args=( + world_size, + init_process_group_init_method, + storage_nodes, + compute_nodes, + ), + nprocs=world_size, + )