diff --git a/python/gigl/distributed/dist_ablp_neighborloader.py b/python/gigl/distributed/dist_ablp_neighborloader.py index 5ae260c39..078615e0c 100644 --- a/python/gigl/distributed/dist_ablp_neighborloader.py +++ b/python/gigl/distributed/dist_ablp_neighborloader.py @@ -33,7 +33,6 @@ from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, DEFAULT_HOMOGENEOUS_NODE_TYPE, - reverse_edge_type, select_label_edge_types, ) from gigl.utils.data_splitters import get_labels_for_anchor_nodes @@ -243,15 +242,18 @@ def __init__( ) self._is_input_heterogeneous = True anchor_node_type, anchor_node_ids = input_nodes - # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if - # this assumption is no longer valid and/or is too opinionated - assert ( - supervision_edge_type[0] == anchor_node_type - ), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \ - got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}" - supervision_node_type = supervision_edge_type[2] if dataset.edge_dir == "in": - supervision_edge_type = reverse_edge_type(supervision_edge_type) + supervision_node_type = supervision_edge_type[0] + if supervision_edge_type[2] != anchor_node_type: + raise ValueError( + f"Found anchor node type {anchor_node_type} but expected {supervision_edge_type[2]}" + ) + else: + supervision_node_type = supervision_edge_type[2] + if supervision_edge_type[0] != anchor_node_type: + raise ValueError( + f"Found anchor node type {anchor_node_type} but expected {supervision_edge_type[0]}" + ) elif isinstance(input_nodes, torch.Tensor): if supervision_edge_type is not None: diff --git a/python/gigl/distributed/dist_partitioner.py b/python/gigl/distributed/dist_partitioner.py index d1eb4b7ac..8bff6c33c 100644 --- a/python/gigl/distributed/dist_partitioner.py +++ b/python/gigl/distributed/dist_partitioner.py @@ 
-1053,12 +1053,19 @@ def _partition_label_edge_index( """ start_time = time.time() - if edge_type.src_node_type not in node_partition_book: + if self._should_assign_edges_by_src_node: + target_node_type = edge_type.src_node_type + target_edge_src_dst_index = 0 + else: + target_node_type = edge_type.dst_node_type + target_edge_src_dst_index = 1 + + if target_node_type not in node_partition_book: raise ValueError( - f"Edge type {edge_type} source node type {edge_type.src_node_type} not found in the node partition book node keys: {node_partition_book.keys()}" + f"Edge type {edge_type} target node type {target_node_type} not found in the node partition book node keys: {node_partition_book.keys()}" ) - target_node_partition_book = node_partition_book[edge_type.src_node_type] + target_node_partition_book = node_partition_book[target_node_type] if is_positive: assert ( self._positive_label_edge_index is not None @@ -1084,9 +1091,9 @@ def _label_pfn(source_node_ids, _): ), # 'partition_fn' takes 'val_indices' as input, uses it as keys for partition, # and returns the partition index. - rank_indices=label_edge_index[0], + rank_indices=label_edge_index[target_edge_src_dst_index], partition_function=_label_pfn, - total_val_size=label_edge_index[0].size(0), + total_val_size=label_edge_index[target_edge_src_dst_index].size(0), generate_pb=False, ) diff --git a/python/gigl/types/graph.py b/python/gigl/types/graph.py index 26a504675..7527618fb 100644 --- a/python/gigl/types/graph.py +++ b/python/gigl/types/graph.py @@ -109,30 +109,6 @@ class FeatureInfo: dtype: torch.dtype -def _get_label_edges( - labeled_edge_index: torch.Tensor, - edge_dir: Literal["in", "out"], - labeled_edge_type: EdgeType, -) -> tuple[EdgeType, torch.Tensor]: - """ - If edge direction is `out`, return the provided edge type and edge index. Otherwise, reverse the edge type and flip the edge index rows - so that the labeled edge index may be the same direction as the rest of the edges. 
- Args: - labeled_edge_index (torch.Tensor): Edge index containing positive or negative labels for supervision - edge_dir (Literal["in", "out"]): Direction of edges in the graph - labeled_edge_type (EdgeType): Edge type used for the positive or negative labeled edges - Returns: - EdgeType: Labeled edge type, which has been reversed if edge_dir = "in" - torch.Tensor: Labeled edge index, which has its rows flipped if edge_dir = "in" - """ - if edge_dir == "in": - rev_edge_type = reverse_edge_type(labeled_edge_type) - rev_labeled_edge_index = labeled_edge_index.flip(0) - return rev_edge_type, rev_labeled_edge_index - else: - return labeled_edge_type, labeled_edge_index - - # This dataclass should not be frozen, as we are expected to delete its members once they have been registered inside of the partitioner # in order to save memory. @dataclass @@ -153,8 +129,6 @@ class LoadedGraphTensors: def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: """ Convert positive and negative labels to edges. Converts this object in-place to a "heterogeneous" representation. - If the edge direction is "in", we must reverse the supervision edge type. This is because we assume that provided labels are directed - outwards in form (`anchor_node_type`, `relation`, `supervision_node_type`), and all edges in the edge index must be in the same direction. This function requires the following conditions and will throw if they are not met: 1. The positive_label is not None @@ -184,12 +158,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: "Detected multiple edge types in provided edge_index, but no edge types specified for provided positive label." 
) positive_label_edge_type = message_passing_to_positive_label(main_edge_type) - labeled_edge_type, edge_index = _get_label_edges( - labeled_edge_index=self.positive_label, - edge_dir=edge_dir, - labeled_edge_type=positive_label_edge_type, - ) - edge_index_with_labels[labeled_edge_type] = edge_index + edge_index_with_labels[positive_label_edge_type] = self.positive_label logger.info( f"Treating homogeneous positive labels as edge type {positive_label_edge_type}." ) @@ -202,12 +171,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: positive_label_edge_type = message_passing_to_positive_label( positive_label_type ) - labeled_edge_type, edge_index = _get_label_edges( - labeled_edge_index=positive_label_tensor, - edge_dir=edge_dir, - labeled_edge_type=positive_label_edge_type, - ) - edge_index_with_labels[labeled_edge_type] = edge_index + edge_index_with_labels[positive_label_edge_type] = positive_label_tensor logger.info( f"Treating heterogeneous positive labels {positive_label_type} as edge type {positive_label_edge_type}." ) @@ -218,12 +182,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: "Detected multiple edge types in provided edge_index, but no edge types specified for provided negative label." ) negative_label_edge_type = message_passing_to_negative_label(main_edge_type) - labeled_edge_type, edge_index = _get_label_edges( - labeled_edge_index=self.negative_label, - edge_dir=edge_dir, - labeled_edge_type=negative_label_edge_type, - ) - edge_index_with_labels[labeled_edge_type] = edge_index + edge_index_with_labels[negative_label_edge_type] = self.negative_label logger.info( f"Treating homogeneous negative labels as edge type {negative_label_edge_type}." 
) @@ -235,12 +194,7 @@ def treat_labels_as_edges(self, edge_dir: Literal["in", "out"]) -> None: negative_label_edge_type = message_passing_to_negative_label( negative_label_type ) - labeled_edge_type, edge_index = _get_label_edges( - labeled_edge_index=negative_label_tensor, - edge_dir=edge_dir, - labeled_edge_type=negative_label_edge_type, - ) - edge_index_with_labels[labeled_edge_type] = edge_index + edge_index_with_labels[negative_label_edge_type] = negative_label_tensor logger.info( f"Treating heterogeneous negative labels {negative_label_type} as edge type {negative_label_edge_type}." ) @@ -490,17 +444,3 @@ def to_homogeneous( n = next(iter(x.values())) return n return x - - -def reverse_edge_type(edge_type: _EdgeType) -> _EdgeType: - """ - Reverses the source and destination node types of the provided edge type - Args: - edge_type (EdgeType): The target edge to have its source and destinated node types reversed - Returns: - EdgeType: The reversed edge type - """ - if isinstance(edge_type, EdgeType): - return EdgeType(edge_type[2], edge_type[1], edge_type[0]) - else: - return (edge_type[2], edge_type[1], edge_type[0]) diff --git a/python/gigl/utils/data_splitters.py b/python/gigl/utils/data_splitters.py index a7272cd4d..3910f8503 100644 --- a/python/gigl/utils/data_splitters.py +++ b/python/gigl/utils/data_splitters.py @@ -25,7 +25,6 @@ DEFAULT_HOMOGENEOUS_NODE_TYPE, message_passing_to_negative_label, message_passing_to_positive_label, - reverse_edge_type, ) logger = Logger() @@ -242,15 +241,7 @@ def __init__( message_passing_to_negative_label(supervision_edge_type) for supervision_edge_type in supervision_edge_types ] - # If the edge direction is "in", we must reverse the labeled edge type, since separately provided labels are expected to be initially outgoing, and all edges - # in the graph must have the same edge direction. 
- if sampling_direction == "in": - self._labeled_edge_types = [ - reverse_edge_type(labeled_edge_type) - for labeled_edge_type in labeled_edge_types - ] - else: - self._labeled_edge_types = labeled_edge_types + self._labeled_edge_types = labeled_edge_types else: self._labeled_edge_types = supervision_edge_types diff --git a/python/tests/unit/distributed/distributed_neighborloader_test.py b/python/tests/unit/distributed/distributed_neighborloader_test.py index eed655705..160dd4c52 100644 --- a/python/tests/unit/distributed/distributed_neighborloader_test.py +++ b/python/tests/unit/distributed/distributed_neighborloader_test.py @@ -1,6 +1,6 @@ import unittest from collections.abc import Mapping -from typing import Optional, Union +from typing import Literal, Optional, Union import torch import torch.multiprocessing as mp @@ -334,8 +334,13 @@ def _run_dblp_supervised( len(supervision_edge_types) == 1 ), "TODO (mkolodner-sc): Support multiple supervision edge types in dataloading" supervision_edge_type = supervision_edge_types[0] - anchor_node_type = supervision_edge_type.src_node_type - supervision_node_type = supervision_edge_type.dst_node_type + sampling_edge_direction = dataset.edge_dir + if sampling_edge_direction == "in": + anchor_node_type = supervision_edge_type.dst_node_type + supervision_node_type = supervision_edge_type.src_node_type + else: + anchor_node_type = supervision_edge_type.src_node_type + supervision_node_type = supervision_edge_type.dst_node_type assert isinstance(dataset.train_node_ids, dict) assert isinstance(dataset.graph, dict) fanout = [2, 2] @@ -374,20 +379,28 @@ def _run_toy_heterogeneous_ablp( supervision_edge_types: list[EdgeType], fanout: Union[list[int], dict[EdgeType, list[int]]], ): - anchor_node_type = NodeType("user") - supervision_node_type = NodeType("story") assert ( len(supervision_edge_types) == 1 ), "TODO (mkolodner-sc): Support multiple supervision edge types in dataloading" supervision_edge_type = supervision_edge_types[0] 
+ labeled_edge_type = message_passing_to_positive_label(supervision_edge_type) + sampling_edge_direction = dataset.edge_dir + assert isinstance(dataset.train_node_ids, dict) assert isinstance(dataset.graph, dict) - labeled_edge_type = EdgeType( - supervision_node_type, Relation("to_gigl_positive"), anchor_node_type - ) - all_positive_supervision_nodes, all_anchor_nodes, _, _ = dataset.graph[ - labeled_edge_type - ].topo.to_coo() + if sampling_edge_direction == "in": + anchor_node_type = supervision_edge_type.dst_node_type + supervision_node_type = supervision_edge_type.src_node_type + all_positive_supervision_nodes, all_anchor_nodes, _, _ = dataset.graph[ + labeled_edge_type + ].topo.to_coo() + else: + anchor_node_type = supervision_edge_type.src_node_type + supervision_node_type = supervision_edge_type.dst_node_type + all_anchor_nodes, all_positive_supervision_nodes, _, _ = dataset.graph[ + labeled_edge_type + ].topo.to_coo() + loader = DistABLPLoader( dataset=dataset, num_neighbors=fanout, @@ -818,9 +831,21 @@ def test_multiple_neighbor_loader(self): args=(dataset, self._context, expected_data_count), ) + @parameterized.expand( + [ + param( + "Inward edge direction", + sampling_edge_direction="in", + ), + param( + "Outward edge direction", + sampling_edge_direction="out", + ), + ] + ) # TODO: (mkolodner-sc) - Figure out why this test is failing on Google Cloud Build @unittest.skip("Failing on Google Cloud Build - skiping for now") - def test_dblp_supervised(self): + def test_dblp_supervised(self, _, sampling_edge_direction: Literal["in", "out"]): dblp_supervised_info = get_mocked_dataset_artifact_metadata()[ DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name ] @@ -842,7 +867,7 @@ def test_dblp_supervised(self): ) splitter = HashedNodeAnchorLinkSplitter( - sampling_direction="in", + sampling_direction=sampling_edge_direction, supervision_edge_types=supervision_edge_types, should_convert_labels_to_edges=True, ) @@ -850,7 +875,7 @@ def test_dblp_supervised(self): 
dataset = build_dataset( serialized_graph_metadata=serialized_graph_metadata, distributed_context=self._context, - sample_edge_direction="in", + sample_edge_direction=sampling_edge_direction, _ssl_positive_label_percentage=0.1, splitter=splitter, ) @@ -863,17 +888,19 @@ def test_dblp_supervised(self): @parameterized.expand( [ param( - "Tensor-based partitioning, list fanout", + "Tensor-based partitioning, list fanout, inward edge direction", partitioner_class=DistPartitioner, fanout=[2, 2], + sampling_edge_direction="in", ), param( - "Range-based partitioning, list fanout", + "Range-based partitioning, list fanout, inward edge direction", partitioner_class=DistRangePartitioner, fanout=[2, 2], + sampling_edge_direction="in", ), param( - "Range-based partitioning, dict fanout", + "Range-based partitioning, dict fanout, inward edge direction", partitioner_class=DistRangePartitioner, fanout={ EdgeType(NodeType("user"), Relation("to"), NodeType("story")): [ @@ -885,6 +912,13 @@ def test_dblp_supervised(self): 2, ], }, + sampling_edge_direction="in", + ), + param( + "Range-based partitioning, list fanout, outward edge direction", + partitioner_class=DistRangePartitioner, + fanout=[2, 2], + sampling_edge_direction="out", ), ] ) @@ -893,6 +927,7 @@ def test_toy_heterogeneous_ablp( _, partitioner_class: type[DistPartitioner], fanout: Union[list[int], dict[EdgeType, list[int]]], + sampling_edge_direction: Literal["in", "out"], ): toy_heterogeneous_supervised_info = get_mocked_dataset_artifact_metadata()[ HETEROGENEOUS_TOY_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name @@ -915,7 +950,7 @@ def test_toy_heterogeneous_ablp( ) splitter = HashedNodeAnchorLinkSplitter( - sampling_direction="in", + sampling_direction=sampling_edge_direction, supervision_edge_types=supervision_edge_types, should_convert_labels_to_edges=True, ) @@ -923,7 +958,7 @@ def test_toy_heterogeneous_ablp( dataset = build_dataset( serialized_graph_metadata=serialized_graph_metadata, 
distributed_context=self._context, - sample_edge_direction="in", + sample_edge_direction=sampling_edge_direction, _ssl_positive_label_percentage=0.1, splitter=splitter, partitioner_class=partitioner_class, diff --git a/python/tests/unit/distributed/distributed_partitioner_test.py b/python/tests/unit/distributed/distributed_partitioner_test.py index 18358259a..dd14fe6c0 100644 --- a/python/tests/unit/distributed/distributed_partitioner_test.py +++ b/python/tests/unit/distributed/distributed_partitioner_test.py @@ -788,7 +788,7 @@ def test_partitioning_correctness( rank=rank, is_heterogeneous=is_heterogeneous, output_node_partition_book=partition_output.node_partition_book, - should_assign_edges_by_src_node=True, + should_assign_edges_by_src_node=should_assign_edges_by_src_node, output_labeled_edge_index=partition_output.partitioned_positive_labels, expected_edge_types=MOCKED_HETEROGENEOUS_EDGE_TYPES, expected_pb_dtype=expected_pb_dtype, @@ -811,7 +811,7 @@ def test_partitioning_correctness( rank=rank, is_heterogeneous=is_heterogeneous, output_node_partition_book=partition_output.node_partition_book, - should_assign_edges_by_src_node=True, + should_assign_edges_by_src_node=should_assign_edges_by_src_node, output_labeled_edge_index=partition_output.partitioned_negative_labels, expected_edge_types=MOCKED_HETEROGENEOUS_EDGE_TYPES, expected_pb_dtype=expected_pb_dtype,