diff --git a/python/gigl/common/services/vertex_ai.py b/python/gigl/common/services/vertex_ai.py index 1e7b288e..b3ce330c 100644 --- a/python/gigl/common/services/vertex_ai.py +++ b/python/gigl/common/services/vertex_ai.py @@ -85,6 +85,7 @@ def get_pipeline() -> int: # NOTE: `get_pipeline` here is the Pipeline name DEFAULT_PIPELINE_TIMEOUT_S: Final[int] = 60 * 60 * 36 # 36 hours DEFAULT_CUSTOM_JOB_TIMEOUT_S: Final[int] = 60 * 60 * 24 # 24 hours +BOOT_DISK_PLACEHOLDER: Final[str] = "DISK_TYPE_UNSPECIFIED" @dataclass @@ -98,7 +99,7 @@ class VertexAiJobConfig: accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED" accelerator_count: int = 0 replica_count: int = 1 - boot_disk_type: str = "pd-ssd" # Persistent Disk SSD + boot_disk_type: str = BOOT_DISK_PLACEHOLDER # Persistent Disk SSD boot_disk_size_gb: int = 100 # Default disk size in GB labels: Optional[dict[str, str]] = None timeout_s: Optional[ @@ -107,6 +108,15 @@ class VertexAiJobConfig: enable_web_access: bool = True scheduling_strategy: Optional[aiplatform.gapic.Scheduling.Strategy] = None + def __post_init__(self): + if self.boot_disk_type is BOOT_DISK_PLACEHOLDER: + if self.machine_type.startswith("g4-"): + logger.info(f"No boot disk type set, and g4 machine detected, using hyperdisk-balanced") + self.boot_disk_type = "hyperdisk-balanced" # g4 machines require use of hyperdisk-balanced + else: + logger.info(f"No boot disk type set, using pd-ssd") + self.boot_disk_type = "pd-ssd" + class VertexAIService: """ diff --git a/python/tests/unit/common/services/__init__.py b/python/tests/unit/common/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/tests/unit/common/services/vertex_ai_test.py b/python/tests/unit/common/services/vertex_ai_test.py new file mode 100644 index 00000000..9e2bf6e3 --- /dev/null +++ b/python/tests/unit/common/services/vertex_ai_test.py @@ -0,0 +1,36 @@ +import unittest + +from parameterized import param, parameterized + +from gigl.common.services.vertex_ai import VertexAiJobConfig + + +class VertexAIServiceTest(unittest.TestCase): + @parameterized.expand( + [ + param( + "g4 machine ; should default to hyperdisk-balanced", + machine_type="g4-standard-8", + expected_boot_disk_type="hyperdisk-balanced", + ), + param( + "n1 machine ; should default to pd-ssd", + machine_type="n1-standard-4", + expected_boot_disk_type="pd-ssd", + ), + ] + ) + def test_default_boot_disk_for_machine( + self, _, machine_type, expected_boot_disk_type + ): + job_config = VertexAiJobConfig( + job_name="job_name", + container_uri="container_uri", + command=["command"], + machine_type=machine_type, + ) + self.assertEqual(job_config.boot_disk_type, expected_boot_disk_type) + + +if __name__ == "__main__": + unittest.main()