From bb1844bc33595287191cceb51f2b242183ec1d78 Mon Sep 17 00:00:00 2001 From: swong3 Date: Thu, 16 Oct 2025 19:04:51 +0000 Subject: [PATCH 01/16] Added DMP tests (single process) and simple fix to deal with awaitable object --- python/gigl/module/models.py | 413 ++++++++++++++++++++++++++++ python/tests/unit/nn/models_test.py | 1 + 2 files changed, 414 insertions(+) diff --git a/python/gigl/module/models.py b/python/gigl/module/models.py index afc55387a..6b65186b6 100644 --- a/python/gigl/module/models.py +++ b/python/gigl/module/models.py @@ -1,11 +1,424 @@ from gigl.common.logger import Logger from gigl.nn.models import LightGCN, LinkPredictionGNN +<<<<<<< HEAD __all__ = ["LinkPredictionGNN", "LightGCN"] +======= +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +from torch_geometric.data import Data, HeteroData +from torch_geometric.nn.conv import LGConv +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.distributed.types import Awaitable +from typing_extensions import Self +>>>>>>> 0b19adb (Added DMP tests (single process) and simple fix to deal with awaitable object) logger = Logger() +<<<<<<< HEAD logger.warning( "gigl.module.models is deprecated and will be removed in a future release. " "Please use the `gigl.nn.models` module instead." ) +======= + +class LinkPredictionGNN(nn.Module): + """ + Link Prediction GNN model for both homogeneous and heterogeneous use cases + Args: + encoder (nn.Module): Either BasicGNN or Heterogeneous GNN for generating embeddings + decoder (nn.Module): Decoder for transforming embeddings into scores. + Recommended to use `gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder` + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + ) -> None: + super().__init__() + self._encoder = encoder + self._decoder = decoder + + def forward( + self, + data: Union[Data, HeteroData], + device: torch.device, + output_node_types: Optional[list[NodeType]] = None, + ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: + if isinstance(data, HeteroData): + if output_node_types is None: + raise ValueError( + "Output node types must be specified in forward() pass for heterogeneous model" + ) + return self._encoder( + data=data, output_node_types=output_node_types, device=device + ) + else: + return self._encoder(data=data, device=device) + + def decode( + self, + query_embeddings: torch.Tensor, + candidate_embeddings: torch.Tensor, + ) -> torch.Tensor: + return self._decoder( + query_embeddings=query_embeddings, + candidate_embeddings=candidate_embeddings, + ) + + @property + def encoder(self) -> nn.Module: + return self._encoder + + @property + def decoder(self) -> nn.Module: + return self._decoder + + def to_ddp( + self, + device: Optional[torch.device], + find_unused_encoder_parameters: bool = False, + ) -> Self: + """ + Converts the model to DistributedDataParallel (DDP) mode. + + We do this because DDP does *not* expect the forward method of the modules it wraps to be called directly. + See how DistributedDataParallel.forward calls _pre_forward: + https://github.com/pytorch/pytorch/blob/26807dcf277feb2d99ab88d7b6da526488baea93/torch/nn/parallel/distributed.py#L1657 + If we do not do this, then calling forward() on the individual modules may not work correctly. 
+ + Calling this function makes it safe to do: `LinkPredictionGNN.decoder(data, device)` + + Args: + device (Optional[torch.device]): The device to which the model should be moved. + If None, will default to CPU. + find_unused_encoder_parameters (bool): Whether to find unused parameters in the model. + This should be set to True if the model has parameters that are not used in the forward pass. + Returns: + LinkPredictionGNN: A new instance of LinkPredictionGNN for use with DDP. + """ + + if device is None: + device = torch.device("cpu") + ddp_encoder = DistributedDataParallel( + self._encoder.to(device), + device_ids=[device] if device.type != "cpu" else None, + find_unused_parameters=find_unused_encoder_parameters, + ) + # Do this "backwards" so the we can define "ddp_decoder" as a nn.Module first... + if not any(p.requires_grad for p in self._decoder.parameters()): + # If the decoder has no trainable parameters, we can just use it as is + ddp_decoder = self._decoder.to(device) + else: + # Only wrap the decoder in DDP if it has parameters that require gradients + # Otherwise DDP will complain about no parameters to train. + ddp_decoder = DistributedDataParallel( + self._decoder.to(device), + device_ids=[device] if device.type != "cpu" else None, + ) + self._encoder = ddp_encoder + self._decoder = ddp_decoder + return self + + def unwrap_from_ddp(self) -> "LinkPredictionGNN": + """ + Unwraps the model from DistributedDataParallel if it is wrapped. + + Returns: + LinkPredictionGNN: A new instance of LinkPredictionGNN with the original encoder and decoder. + """ + if isinstance(self._encoder, DistributedDataParallel): + encoder = self._encoder.module + else: + encoder = self._encoder + + if isinstance(self._decoder, DistributedDataParallel): + decoder = self._decoder.module + else: + decoder = self._decoder + + return LinkPredictionGNN(encoder=encoder, decoder=decoder) + + +# TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement. +# TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific +# TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer) +class LightGCN(nn.Module): + """ + LightGCN model with TorchRec integration for distributed ID embeddings. + + Reference: https://arxiv.org/pdf/2002.02126 + + This class extends the basic LightGCN implementation to use TorchRec's + distributed embedding tables for handling large-scale ID embeddings. + + Args: + node_type_to_num_nodes (Union[int, Dict[NodeType, int]]): Map from node types + to node counts. Can also pass a single int for homogeneous graphs. + embedding_dim (int): Dimension of node embeddings D. Default: 64. + num_layers (int): Number of LightGCN propagation layers K. Default: 2. + device (torch.device): Device to run the computation on. Default: CPU. + layer_weights (Optional[List[float]]): Weights for [e^(0), e^(1), ..., e^(K)]. + Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None. 
+ """ + + def __init__( + self, + node_type_to_num_nodes: Union[int, dict[NodeType, int]], + embedding_dim: int = 64, + num_layers: int = 2, + device: torch.device = torch.device("cpu"), + layer_weights: Optional[list[float]] = None, + ): + super().__init__() + + self._node_type_to_num_nodes = to_heterogeneous_node(node_type_to_num_nodes) + self._embedding_dim = embedding_dim + self._num_layers = num_layers + self._device = device + + # Construct LightGCN α weights: include e^(0) + K propagated layers ==> K+1 weights + if layer_weights is None: + layer_weights = [1.0 / (num_layers + 1)] * (num_layers + 1) + else: + if len(layer_weights) != (num_layers + 1): + raise ValueError( + f"layer_weights must have length K+1={num_layers+1}, got {len(layer_weights)}" + ) + + # Register layer weights as a buffer so it moves with the model to different devices + self.register_buffer( + "_layer_weights", + torch.tensor(layer_weights, dtype=torch.float32), + ) + + # Build TorchRec EBC (one table per node type) + # feature key naming convention: f"{node_type}_id" + self._feature_keys: list[str] = [ + f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys() + ] + tables: list[EmbeddingBagConfig] = [] + for node_type, num_nodes in self._node_type_to_num_nodes.items(): + tables.append( + EmbeddingBagConfig( + name=f"node_embedding_{node_type}", + embedding_dim=embedding_dim, + num_embeddings=num_nodes, + feature_names=[f"{node_type}_id"], + ) + ) + + self._embedding_bag_collection = EmbeddingBagCollection( + tables=tables, device=self._device + ) + + # Construct LightGCN propagation layers (LGConv = Ā X) + self._convs = nn.ModuleList( + [LGConv() for _ in range(self._num_layers)] + ) # K layers + + def forward( + self, + data: Union[Data, HeteroData], + device: torch.device, + output_node_types: Optional[list[NodeType]] = None, + anchor_node_ids: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: + """ + Forward pass of the LightGCN model. + + Args: + data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous). + device (torch.device): Device to run the computation on. + output_node_types (Optional[List[NodeType]]): List of node types to return + embeddings for. Required for heterogeneous graphs. Default: None. + anchor_node_ids (Optional[torch.Tensor]): Local node indices to return + embeddings for. If None, returns embeddings for all nodes. Default: None. + + Returns: + Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings. + For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim]. + For heterogeneous graphs, returns dict mapping node types to embeddings. + """ + if isinstance(data, HeteroData): + raise NotImplementedError("HeteroData is not yet supported for LightGCN") + output_node_types = output_node_types or list(data.node_types) + return self._forward_heterogeneous( + data, device, output_node_types, anchor_node_ids + ) + else: + return self._forward_homogeneous(data, device, anchor_node_ids) + + def _forward_homogeneous( + self, + data: Data, + device: torch.device, + anchor_node_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for homogeneous graphs using LightGCN propagation. 
+ + Notation follows the LightGCN paper (https://arxiv.org/pdf/2002.02126): + - e^(0): Initial embeddings (no propagation) + - e^(k): Embeddings after k layers of graph convolution + - z: Final embedding = weighted sum of [e^(0), e^(1), ..., e^(K)] + + Variable naming: + - embeddings_0: Initial embeddings e^(0) for subgraph nodes + - embeddings_k: Current layer embeddings during propagation + - all_layer_embeddings: List containing [e^(0), e^(1), ..., e^(K)] + - final_embeddings: Final node embeddings (weighted sum) + + Args: + data (Data): PyG Data object containing edge_index and node IDs. + device (torch.device): Device to run computation on. + anchor_node_ids (Optional[torch.Tensor]): Local node indices to return + embeddings for. If None, returns embeddings for all nodes. Default: None. + + Returns: + torch.Tensor: Tensor of shape [num_nodes, embedding_dim] containing + final LightGCN embeddings. + """ + # Check if model is setup to be homogeneous + assert len(self._feature_keys) == 1, ( + f"Homogeneous path expects exactly one node type; got " + f"{len(self._feature_keys)} types: {self._feature_keys}" + ) + key = self._feature_keys[0] + edge_index = data.edge_index.to( + device + ) # shape [2, E], where E is the number of edges + + assert hasattr( + data, "node" + ), "Subgraph must include .node to map local→global IDs." + global_ids = data.node.to( + device + ).long() # shape [N_sub], maps local 0..N_sub-1 → global ids + + embeddings_0 = self._lookup_embeddings_for_single_node_type( + key, global_ids + ) # shape [N_sub, D], where N_sub is number of nodes in subgraph and D is embedding_dim + + # When using DMP, EmbeddingBagCollection returns Awaitable that needs to be resolved + if isinstance(embeddings_0, Awaitable): + embeddings_0 = embeddings_0.wait() + + all_layer_embeddings: list[torch.Tensor] = [embeddings_0] + embeddings_k = embeddings_0 + + for conv in self._convs: + embeddings_k = conv( + embeddings_k, edge_index + ) # shape [N_sub, D], normalized neighbor averaging over *subgraph* edges + all_layer_embeddings.append(embeddings_k) + + final_embeddings = self._weighted_layer_sum( + all_layer_embeddings + ) # shape [N_sub, D], weighted sum of all layer embeddings + + # If anchor node ids are provided, return the embeddings for the anchor nodes only + if anchor_node_ids is not None: + anchors_local = anchor_node_ids.to(device).long() # shape [num_anchors] + return final_embeddings[ + anchors_local + ] # shape [num_anchors, D], embeddings for anchor nodes only + + # Otherwise, return the embeddings for all nodes in the subgraph + return ( + final_embeddings # shape [N_sub, D], embeddings for all nodes in subgraph + ) + + def _lookup_embeddings_for_single_node_type( + self, node_type: str, ids: torch.Tensor + ) -> torch.Tensor: + """ + Fetch per-ID embeddings for a single node type using EmbeddingBagCollection. + + This method constructs a KeyedJaggedTensor (KJT) that includes all EBC feature + keys to ensure consistent forward pass behavior. For the requested node type, + we create B bags of length 1 (one per ID). For all other node types, we create + B bags of length 0. With SUM pooling, non-requested node types contribute zeros + and the requested node type acts as identity lookup. + + Args: + node_type (str): Feature key for the node type (e.g., "user_id", "item_id"). + ids (torch.Tensor): Node IDs to look up, shape [batch_size]. + + Returns: + torch.Tensor: Embeddings for the requested node type, shape [batch_size, embedding_dim]. 
+ """ + if node_type not in self._feature_keys: + raise KeyError( + f"Unknown feature key '{node_type}'. Valid keys: {self._feature_keys}" + ) + + # Number of examples (one ID per "bag") + batch_size = int(ids.numel()) # B is the number of node IDs to lookup + device = ids.device + + # Build lengths in key-major order: for each key, we give B lengths. + # - requested key: ones (each example has 1 id) + # - other keys: zeros (each example has 0 ids) + lengths_per_key: list[torch.Tensor] = [] + for nt in self._feature_keys: + if nt == node_type: + lengths_per_key.append( + torch.ones(batch_size, dtype=torch.long, device=device) + ) # shape [B], all ones for requested key + else: + lengths_per_key.append( + torch.zeros(batch_size, dtype=torch.long, device=device) + ) # shape [B], all zeros for other keys + + lengths = torch.cat( + lengths_per_key, dim=0 + ) # shape [batch_size * num_keys], concatenated lengths for all keys + + # Values only contain the requested key's ids (sum of other lengths is 0) + kjt = KeyedJaggedTensor( + keys=self._feature_keys, # include ALL keys known by EBC + values=ids.long(), # shape [batch_size], only batch_size values for the requested key + lengths=lengths, # shape [batch_size * num_keys], batch_size lengths per key, concatenated key-major + ) + + out = self._embedding_bag_collection( + kjt + ) # KeyedTensor (dict-like): out[key] -> [batch_size, D] + return out[node_type] # shape [batch_size, D], embeddings for the requested key + + def _weighted_layer_sum( + self, all_layer_embeddings: list[torch.Tensor] + ) -> torch.Tensor: + """ + Computes weighted sum: w_0 * e^(0) + w_1 * e^(1) + ... + w_K * e^(K). + + This implements the final aggregation step in LightGCN where embeddings from + all layers (including the initial e^(0)) are combined using learned weights. + + Args: + all_layer_embeddings (List[torch.Tensor]): List [e^(0), e^(1), ..., e^(K)] + where each tensor has shape [N, D]. + + Returns: + torch.Tensor: Weighted sum of all layer embeddings, shape [N, D]. + """ + if len(all_layer_embeddings) != len(self._layer_weights): + raise ValueError( + f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." + ) + + # Stack all layer embeddings and compute weighted sum + # _layer_weights is already a tensor buffer registered in __init__ + stacked = torch.stack(all_layer_embeddings, dim=0) # shape [K+1, N, D] + w = self._layer_weights.to(stacked.device) # shape [K+1], ensure on same device + out = (stacked * w.view(-1, 1, 1)).sum( + dim=0 + ) # shape [N, D], w_0*X_0 + w_1*X_1 + ... 
+ + return out +>>>>>>> 0b19adb (Added DMP tests (single process) and simple fix to deal with awaitable object) diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index b201876d9..202413754 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -5,6 +5,7 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +import torch.distributed as dist from torch_geometric.data import Data, HeteroData from torch_geometric.nn.models import LightGCN as PyGLightGCN from torchrec.distributed.model_parallel import ( From ed6a2cd7a9ba8821c200f8428663448855f8f3c6 Mon Sep 17 00:00:00 2001 From: swong3 Date: Fri, 17 Oct 2025 23:29:28 +0000 Subject: [PATCH 02/16] Minor changes to DMP testing --- python/gigl/module/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/gigl/module/models.py b/python/gigl/module/models.py index 6b65186b6..994b62ff8 100644 --- a/python/gigl/module/models.py +++ b/python/gigl/module/models.py @@ -147,6 +147,7 @@ def unwrap_from_ddp(self) -> "LinkPredictionGNN": # TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement. # TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific # TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer) +# TODO(swong3): Look into why we have to call wait() on the Awaitable returned by the embedding bag collection. class LightGCN(nn.Module): """ LightGCN model with TorchRec integration for distributed ID embeddings. From 39e1d6d4dd2a55ac99ae2a9e329627359050a221 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 20 Oct 2025 20:30:20 +0000 Subject: [PATCH 03/16] Added larger world size test --- python/gigl/module/models.py | 414 ---------------------------- python/tests/unit/nn/models_test.py | 1 + 2 files changed, 1 insertion(+), 414 deletions(-) diff --git a/python/gigl/module/models.py b/python/gigl/module/models.py index 994b62ff8..afc55387a 100644 --- a/python/gigl/module/models.py +++ b/python/gigl/module/models.py @@ -1,425 +1,11 @@ from gigl.common.logger import Logger from gigl.nn.models import LightGCN, LinkPredictionGNN -<<<<<<< HEAD __all__ = ["LinkPredictionGNN", "LightGCN"] -======= -import torch -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel -from torch_geometric.data import Data, HeteroData -from torch_geometric.nn.conv import LGConv -from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.modules.embedding_modules import EmbeddingBagCollection -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -from torchrec.distributed.types import Awaitable -from typing_extensions import Self ->>>>>>> 0b19adb (Added DMP tests (single process) and simple fix to deal with awaitable object) logger = Logger() -<<<<<<< HEAD logger.warning( "gigl.module.models is deprecated and will be removed in a future release. " "Please use the `gigl.nn.models` module instead." ) -======= - -class LinkPredictionGNN(nn.Module): - """ - Link Prediction GNN model for both homogeneous and heterogeneous use cases - Args: - encoder (nn.Module): Either BasicGNN or Heterogeneous GNN for generating embeddings - decoder (nn.Module): Decoder for transforming embeddings into scores. 
- Recommended to use `gigl.src.common.models.pyg.link_prediction.LinkPredictionDecoder` - """ - - def __init__( - self, - encoder: nn.Module, - decoder: nn.Module, - ) -> None: - super().__init__() - self._encoder = encoder - self._decoder = decoder - - def forward( - self, - data: Union[Data, HeteroData], - device: torch.device, - output_node_types: Optional[list[NodeType]] = None, - ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: - if isinstance(data, HeteroData): - if output_node_types is None: - raise ValueError( - "Output node types must be specified in forward() pass for heterogeneous model" - ) - return self._encoder( - data=data, output_node_types=output_node_types, device=device - ) - else: - return self._encoder(data=data, device=device) - - def decode( - self, - query_embeddings: torch.Tensor, - candidate_embeddings: torch.Tensor, - ) -> torch.Tensor: - return self._decoder( - query_embeddings=query_embeddings, - candidate_embeddings=candidate_embeddings, - ) - - @property - def encoder(self) -> nn.Module: - return self._encoder - - @property - def decoder(self) -> nn.Module: - return self._decoder - - def to_ddp( - self, - device: Optional[torch.device], - find_unused_encoder_parameters: bool = False, - ) -> Self: - """ - Converts the model to DistributedDataParallel (DDP) mode. - - We do this because DDP does *not* expect the forward method of the modules it wraps to be called directly. - See how DistributedDataParallel.forward calls _pre_forward: - https://github.com/pytorch/pytorch/blob/26807dcf277feb2d99ab88d7b6da526488baea93/torch/nn/parallel/distributed.py#L1657 - If we do not do this, then calling forward() on the individual modules may not work correctly. - - Calling this function makes it safe to do: `LinkPredictionGNN.decoder(data, device)` - - Args: - device (Optional[torch.device]): The device to which the model should be moved. - If None, will default to CPU. - find_unused_encoder_parameters (bool): Whether to find unused parameters in the model. - This should be set to True if the model has parameters that are not used in the forward pass. - Returns: - LinkPredictionGNN: A new instance of LinkPredictionGNN for use with DDP. - """ - - if device is None: - device = torch.device("cpu") - ddp_encoder = DistributedDataParallel( - self._encoder.to(device), - device_ids=[device] if device.type != "cpu" else None, - find_unused_parameters=find_unused_encoder_parameters, - ) - # Do this "backwards" so the we can define "ddp_decoder" as a nn.Module first... - if not any(p.requires_grad for p in self._decoder.parameters()): - # If the decoder has no trainable parameters, we can just use it as is - ddp_decoder = self._decoder.to(device) - else: - # Only wrap the decoder in DDP if it has parameters that require gradients - # Otherwise DDP will complain about no parameters to train. - ddp_decoder = DistributedDataParallel( - self._decoder.to(device), - device_ids=[device] if device.type != "cpu" else None, - ) - self._encoder = ddp_encoder - self._decoder = ddp_decoder - return self - - def unwrap_from_ddp(self) -> "LinkPredictionGNN": - """ - Unwraps the model from DistributedDataParallel if it is wrapped. - - Returns: - LinkPredictionGNN: A new instance of LinkPredictionGNN with the original encoder and decoder. 
- """ - if isinstance(self._encoder, DistributedDataParallel): - encoder = self._encoder.module - else: - encoder = self._encoder - - if isinstance(self._decoder, DistributedDataParallel): - decoder = self._decoder.module - else: - decoder = self._decoder - - return LinkPredictionGNN(encoder=encoder, decoder=decoder) - - -# TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement. -# TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific -# TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer) -# TODO(swong3): Look into why we have to call wait() on the Awaitable returned by the embedding bag collection. -class LightGCN(nn.Module): - """ - LightGCN model with TorchRec integration for distributed ID embeddings. - - Reference: https://arxiv.org/pdf/2002.02126 - - This class extends the basic LightGCN implementation to use TorchRec's - distributed embedding tables for handling large-scale ID embeddings. - - Args: - node_type_to_num_nodes (Union[int, Dict[NodeType, int]]): Map from node types - to node counts. Can also pass a single int for homogeneous graphs. - embedding_dim (int): Dimension of node embeddings D. Default: 64. - num_layers (int): Number of LightGCN propagation layers K. Default: 2. - device (torch.device): Device to run the computation on. Default: CPU. - layer_weights (Optional[List[float]]): Weights for [e^(0), e^(1), ..., e^(K)]. - Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None. - """ - - def __init__( - self, - node_type_to_num_nodes: Union[int, dict[NodeType, int]], - embedding_dim: int = 64, - num_layers: int = 2, - device: torch.device = torch.device("cpu"), - layer_weights: Optional[list[float]] = None, - ): - super().__init__() - - self._node_type_to_num_nodes = to_heterogeneous_node(node_type_to_num_nodes) - self._embedding_dim = embedding_dim - self._num_layers = num_layers - self._device = device - - # Construct LightGCN α weights: include e^(0) + K propagated layers ==> K+1 weights - if layer_weights is None: - layer_weights = [1.0 / (num_layers + 1)] * (num_layers + 1) - else: - if len(layer_weights) != (num_layers + 1): - raise ValueError( - f"layer_weights must have length K+1={num_layers+1}, got {len(layer_weights)}" - ) - - # Register layer weights as a buffer so it moves with the model to different devices - self.register_buffer( - "_layer_weights", - torch.tensor(layer_weights, dtype=torch.float32), - ) - - # Build TorchRec EBC (one table per node type) - # feature key naming convention: f"{node_type}_id" - self._feature_keys: list[str] = [ - f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys() - ] - tables: list[EmbeddingBagConfig] = [] - for node_type, num_nodes in self._node_type_to_num_nodes.items(): - tables.append( - EmbeddingBagConfig( - name=f"node_embedding_{node_type}", - embedding_dim=embedding_dim, - num_embeddings=num_nodes, - feature_names=[f"{node_type}_id"], - ) - ) - - self._embedding_bag_collection = EmbeddingBagCollection( - tables=tables, device=self._device - ) - - # Construct LightGCN propagation layers (LGConv = Ā X) - self._convs = nn.ModuleList( - [LGConv() for _ in range(self._num_layers)] - ) # K layers - - def forward( - self, - data: Union[Data, HeteroData], - device: torch.device, - output_node_types: Optional[list[NodeType]] = None, - anchor_node_ids: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: - """ - 
Forward pass of the LightGCN model. - - Args: - data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous). - device (torch.device): Device to run the computation on. - output_node_types (Optional[List[NodeType]]): List of node types to return - embeddings for. Required for heterogeneous graphs. Default: None. - anchor_node_ids (Optional[torch.Tensor]): Local node indices to return - embeddings for. If None, returns embeddings for all nodes. Default: None. - - Returns: - Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings. - For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim]. - For heterogeneous graphs, returns dict mapping node types to embeddings. - """ - if isinstance(data, HeteroData): - raise NotImplementedError("HeteroData is not yet supported for LightGCN") - output_node_types = output_node_types or list(data.node_types) - return self._forward_heterogeneous( - data, device, output_node_types, anchor_node_ids - ) - else: - return self._forward_homogeneous(data, device, anchor_node_ids) - - def _forward_homogeneous( - self, - data: Data, - device: torch.device, - anchor_node_ids: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Forward pass for homogeneous graphs using LightGCN propagation. - - Notation follows the LightGCN paper (https://arxiv.org/pdf/2002.02126): - - e^(0): Initial embeddings (no propagation) - - e^(k): Embeddings after k layers of graph convolution - - z: Final embedding = weighted sum of [e^(0), e^(1), ..., e^(K)] - - Variable naming: - - embeddings_0: Initial embeddings e^(0) for subgraph nodes - - embeddings_k: Current layer embeddings during propagation - - all_layer_embeddings: List containing [e^(0), e^(1), ..., e^(K)] - - final_embeddings: Final node embeddings (weighted sum) - - Args: - data (Data): PyG Data object containing edge_index and node IDs. - device (torch.device): Device to run computation on. - anchor_node_ids (Optional[torch.Tensor]): Local node indices to return - embeddings for. If None, returns embeddings for all nodes. Default: None. - - Returns: - torch.Tensor: Tensor of shape [num_nodes, embedding_dim] containing - final LightGCN embeddings. - """ - # Check if model is setup to be homogeneous - assert len(self._feature_keys) == 1, ( - f"Homogeneous path expects exactly one node type; got " - f"{len(self._feature_keys)} types: {self._feature_keys}" - ) - key = self._feature_keys[0] - edge_index = data.edge_index.to( - device - ) # shape [2, E], where E is the number of edges - - assert hasattr( - data, "node" - ), "Subgraph must include .node to map local→global IDs." 
- global_ids = data.node.to( - device - ).long() # shape [N_sub], maps local 0..N_sub-1 → global ids - - embeddings_0 = self._lookup_embeddings_for_single_node_type( - key, global_ids - ) # shape [N_sub, D], where N_sub is number of nodes in subgraph and D is embedding_dim - - # When using DMP, EmbeddingBagCollection returns Awaitable that needs to be resolved - if isinstance(embeddings_0, Awaitable): - embeddings_0 = embeddings_0.wait() - - all_layer_embeddings: list[torch.Tensor] = [embeddings_0] - embeddings_k = embeddings_0 - - for conv in self._convs: - embeddings_k = conv( - embeddings_k, edge_index - ) # shape [N_sub, D], normalized neighbor averaging over *subgraph* edges - all_layer_embeddings.append(embeddings_k) - - final_embeddings = self._weighted_layer_sum( - all_layer_embeddings - ) # shape [N_sub, D], weighted sum of all layer embeddings - - # If anchor node ids are provided, return the embeddings for the anchor nodes only - if anchor_node_ids is not None: - anchors_local = anchor_node_ids.to(device).long() # shape [num_anchors] - return final_embeddings[ - anchors_local - ] # shape [num_anchors, D], embeddings for anchor nodes only - - # Otherwise, return the embeddings for all nodes in the subgraph - return ( - final_embeddings # shape [N_sub, D], embeddings for all nodes in subgraph - ) - - def _lookup_embeddings_for_single_node_type( - self, node_type: str, ids: torch.Tensor - ) -> torch.Tensor: - """ - Fetch per-ID embeddings for a single node type using EmbeddingBagCollection. - - This method constructs a KeyedJaggedTensor (KJT) that includes all EBC feature - keys to ensure consistent forward pass behavior. For the requested node type, - we create B bags of length 1 (one per ID). For all other node types, we create - B bags of length 0. With SUM pooling, non-requested node types contribute zeros - and the requested node type acts as identity lookup. - - Args: - node_type (str): Feature key for the node type (e.g., "user_id", "item_id"). - ids (torch.Tensor): Node IDs to look up, shape [batch_size]. - - Returns: - torch.Tensor: Embeddings for the requested node type, shape [batch_size, embedding_dim]. - """ - if node_type not in self._feature_keys: - raise KeyError( - f"Unknown feature key '{node_type}'. Valid keys: {self._feature_keys}" - ) - - # Number of examples (one ID per "bag") - batch_size = int(ids.numel()) # B is the number of node IDs to lookup - device = ids.device - - # Build lengths in key-major order: for each key, we give B lengths. 
- # - requested key: ones (each example has 1 id) - # - other keys: zeros (each example has 0 ids) - lengths_per_key: list[torch.Tensor] = [] - for nt in self._feature_keys: - if nt == node_type: - lengths_per_key.append( - torch.ones(batch_size, dtype=torch.long, device=device) - ) # shape [B], all ones for requested key - else: - lengths_per_key.append( - torch.zeros(batch_size, dtype=torch.long, device=device) - ) # shape [B], all zeros for other keys - - lengths = torch.cat( - lengths_per_key, dim=0 - ) # shape [batch_size * num_keys], concatenated lengths for all keys - - # Values only contain the requested key's ids (sum of other lengths is 0) - kjt = KeyedJaggedTensor( - keys=self._feature_keys, # include ALL keys known by EBC - values=ids.long(), # shape [batch_size], only batch_size values for the requested key - lengths=lengths, # shape [batch_size * num_keys], batch_size lengths per key, concatenated key-major - ) - - out = self._embedding_bag_collection( - kjt - ) # KeyedTensor (dict-like): out[key] -> [batch_size, D] - return out[node_type] # shape [batch_size, D], embeddings for the requested key - - def _weighted_layer_sum( - self, all_layer_embeddings: list[torch.Tensor] - ) -> torch.Tensor: - """ - Computes weighted sum: w_0 * e^(0) + w_1 * e^(1) + ... + w_K * e^(K). - - This implements the final aggregation step in LightGCN where embeddings from - all layers (including the initial e^(0)) are combined using learned weights. - - Args: - all_layer_embeddings (List[torch.Tensor]): List [e^(0), e^(1), ..., e^(K)] - where each tensor has shape [N, D]. - - Returns: - torch.Tensor: Weighted sum of all layer embeddings, shape [N, D]. - """ - if len(all_layer_embeddings) != len(self._layer_weights): - raise ValueError( - f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." - ) - - # Stack all layer embeddings and compute weighted sum - # _layer_weights is already a tensor buffer registered in __init__ - stacked = torch.stack(all_layer_embeddings, dim=0) # shape [K+1, N, D] - w = self._layer_weights.to(stacked.device) # shape [K+1], ensure on same device - out = (stacked * w.view(-1, 1, 1)).sum( - dim=0 - ) # shape [N, D], w_0*X_0 + w_1*X_1 + ... - - return out ->>>>>>> 0b19adb (Added DMP tests (single process) and simple fix to deal with awaitable object) diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index 202413754..64d633022 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -6,6 +6,7 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.distributed as dist +import torch.multiprocessing as mp from torch_geometric.data import Data, HeteroData from torch_geometric.nn.models import LightGCN as PyGLightGCN from torchrec.distributed.model_parallel import ( From 50e1f475450c43faafe7ae103dc04026b8dab98b Mon Sep 17 00:00:00 2001 From: swong3 Date: Fri, 31 Oct 2025 07:22:38 +0000 Subject: [PATCH 04/16] Bipartite support for lightgcn --- python/tests/unit/nn/models_test.py | 105 ++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index 64d633022..7c1b71f22 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -328,6 +328,111 @@ def test_dmp_multiprocess(self): nprocs=world_size, ) + def test_compare_bipartite_with_math(self): + """Test that bipartite implementation matches the mathematical formulation of LightGCN. 
+ + This test converts the homogeneous 4-node graph into a bipartite graph to verify + that the bipartite implementation produces identical results. The same initial + embeddings and edge structure are used, just split by node type. + + Graph structure: + - Homogeneous: nodes [0, 1, 2, 3] with edges 0->2, 0->3, 1->3, 2->0, 3->0, 3->1 + - Bipartite: users [0, 1] and items [0, 1] with equivalent cross-type edges + + Expected behavior: Bipartite embeddings should match the homogeneous embeddings + for the corresponding nodes (user 0 = node 0, user 1 = node 1, etc.) + """ + # Create bipartite graph + num_users = 2 + num_items = 2 + + node_type_to_num_nodes = { + NodeType("user"): num_users, + NodeType("item"): num_items, + } + + model = self._create_lightgcn_model(node_type_to_num_nodes) + + # Use same embeddings as homogeneous test, split by node type + user_embeddings = torch.tensor( + [ + [0.2, 0.5, 0.1, 0.4], # User 0 (was Node 0) + [0.6, 0.1, 0.2, 0.5], # User 1 (was Node 1) + ], + dtype=torch.float32, + ) + + item_embeddings = torch.tensor( + [ + [0.9, 0.4, 0.1, 0.4], # Item 0 (was Node 2) + [0.3, 0.8, 0.3, 0.6], # Item 1 (was Node 3) + ], + dtype=torch.float32, + ) + + with torch.no_grad(): + user_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_user" + ] + user_table.weight[:] = user_embeddings + + item_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_item" + ] + item_table.weight[:] = item_embeddings + + data = HeteroData() + + # User nodes (local IDs 0, 1 map to global IDs 0, 1) + data["user"].node = torch.tensor([0, 1], dtype=torch.long) + data["user"].num_nodes = num_users + + # Item nodes (local IDs 0, 1 map to global IDs 0, 1) + data["item"].node = torch.tensor([0, 1], dtype=torch.long) + data["item"].num_nodes = num_items + + # User to item edges (converting from original homogeneous edges) + # Original: 0->2, 0->3, 1->3 becomes user 0->item 0, user 0->item 1, user 1->item 1 + data["user", "to", "item"].edge_index = torch.tensor( + [[0, 0, 1], [0, 1, 1]], dtype=torch.long + ) + + # Item to user edges (reverse direction) + # Original: 2->0, 3->0, 3->1 becomes item 0->user 0, item 1->user 0, item 1->user 1 + data["item", "to", "user"].edge_index = torch.tensor( + [[0, 1, 1], [0, 0, 1]], dtype=torch.long + ) + + # Forward pass + output = model( + data, + self.device, + output_node_types=[NodeType("user"), NodeType("item")], + ) + + expected_user_embeddings = torch.tensor( + [ + [0.4495, 0.5311, 0.1555, 0.4865], # User 0 + [0.3943, 0.2975, 0.1825, 0.4386], # User 1 + ], + dtype=torch.float32, + ) + + expected_item_embeddings = torch.tensor( + [ + [0.5325, 0.4121, 0.1089, 0.3650], # Item 0 + [0.4558, 0.6207, 0.2506, 0.5817], # Item 1 + ], + dtype=torch.float32, + ) + + # Check that bipartite output matches expected + self.assertTrue( + torch.allclose(output[NodeType("user")], expected_user_embeddings, atol=1e-4, rtol=1e-4) + ) + self.assertTrue( + torch.allclose(output[NodeType("item")], expected_item_embeddings, atol=1e-4, rtol=1e-4) + ) def _run_dmp_multiprocess_test( rank: int, From 1cdc76c86bc68049651a66f85a473ca45199a9db Mon Sep 17 00:00:00 2001 From: swong3 Date: Sat, 8 Nov 2025 07:01:22 +0000 Subject: [PATCH 05/16] rebase --- python/gigl/nn/models.py | 169 +++++++++++++++++++++++++++++++++++---- 1 file changed, 152 insertions(+), 17 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index 9fa29f62e..29f1028ed 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -216,30 +216,40 @@ 
def forward( data: Union[Data, HeteroData], device: torch.device, output_node_types: Optional[list[NodeType]] = None, - anchor_node_ids: Optional[torch.Tensor] = None, + anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """ Forward pass of the LightGCN model. Args: - data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous). + data (Union[Data, HeteroData]): Graph data. + - For homogeneous: Data object with edge_index and node field + - For bipartite: HeteroData with 2 node types and edge_index_dict device (torch.device): Device to run the computation on. - output_node_types (Optional[List[NodeType]]): List of node types to return - embeddings for. Required for heterogeneous graphs. Default: None. - anchor_node_ids (Optional[torch.Tensor]): Local node indices to return - embeddings for. If None, returns embeddings for all nodes. Default: None. + output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. + Required for bipartite graphs. If None, returns embeddings for all node types. Default: None. + anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): + Local node indices to return embeddings for. + - For homogeneous: torch.Tensor of shape [num_anchors] + - For bipartite: dict mapping node types to anchor tensors + If None, returns embeddings for all nodes. Default: None. Returns: Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings. - For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim]. - For heterogeneous graphs, returns dict mapping node types to embeddings. + - For homogeneous: tensor of shape [num_nodes, embedding_dim] + - For bipartite: dict mapping node types to embeddings """ - if isinstance(data, HeteroData): - raise NotImplementedError("HeteroData is not yet supported for LightGCN") - output_node_types = output_node_types or list(data.node_types) - return self._forward_heterogeneous( - data, device, output_node_types, anchor_node_ids - ) + is_bipartite = isinstance(data, HeteroData) + + # Validate model configuration + num_node_types = len(self._feature_keys) + assert num_node_types in [1, 2], ( + f"LightGCN only supports homogeneous (1 node type) or bipartite (2 node types) graphs; " + f"got {num_node_types} node types: {self._feature_keys}" + ) + + if is_bipartite: + return self._forward_bipartite(data, device, output_node_types, anchor_node_ids) else: return self._forward_homogeneous(data, device, anchor_node_ids) @@ -247,7 +257,7 @@ def _forward_homogeneous( self, data: Data, device: torch.device, - anchor_node_ids: Optional[torch.Tensor] = None, + anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, ) -> torch.Tensor: """ Forward pass for homogeneous graphs using LightGCN propagation. @@ -266,7 +276,7 @@ def _forward_homogeneous( Args: data (Data): PyG Data object containing edge_index and node IDs. device (torch.device): Device to run computation on. - anchor_node_ids (Optional[torch.Tensor]): Local node indices to return + anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): Local node indices to return embeddings for. If None, returns embeddings for all nodes. Default: None. 
Returns: @@ -313,7 +323,10 @@ def _forward_homogeneous( # If anchor node ids are provided, return the embeddings for the anchor nodes only if anchor_node_ids is not None: - anchors_local = anchor_node_ids.to(device).long() # shape [num_anchors] + if isinstance(anchor_node_ids, torch.Tensor): + anchors_local = anchor_node_ids.to(device).long() # shape [num_anchors] + else: + anchors_local = anchor_node_ids[NodeType(key)].to(device).long() # shape [num_anchors] return final_embeddings[ anchors_local ] # shape [num_anchors, D], embeddings for anchor nodes only @@ -323,6 +336,128 @@ def _forward_homogeneous( final_embeddings # shape [N_sub, D], embeddings for all nodes in subgraph ) + def _forward_bipartite( + self, + data: HeteroData, + device: torch.device, + output_node_types: Optional[list[NodeType]] = None, + anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, + ) -> dict[NodeType, torch.Tensor]: + """ + Forward pass for bipartite graphs using LightGCN propagation. + + For bipartite graphs (e.g., user-item), we have two node types and edges between them. + LightGCN propagates embeddings across both node types through the bipartite structure. + + Args: + data (HeteroData): PyG HeteroData object with 2 node types. + device (torch.device): Device to run computation on. + output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. + If None, returns all node types. Default: None. + anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): Dict mapping node types + to local anchor indices. If None, returns all nodes. Default: None. + + Returns: + Dict[NodeType, torch.Tensor]: Dict mapping node types to their embeddings, + each of shape [num_nodes_of_type, embedding_dim]. + """ + # Determine which node types to process + if output_node_types is None: + output_node_types = [NodeType(str(nt)) for nt in data.node_types] + + # Lookup initial embeddings e^(0) for each node type + node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {} + + for node_type in output_node_types: + node_type_str = str(node_type) + key = f"{node_type_str}_id" + + assert hasattr(data[node_type_str], "node"), ( + f"Subgraph must include .node field for node type {node_type_str}" + ) + + global_ids = data[node_type_str].node.to(device).long() # shape [N_type] + + embeddings = self._lookup_embeddings_for_single_node_type( + key, global_ids + ) # shape [N_type, D] + + # Handle DMP Awaitable + if isinstance(embeddings, Awaitable): + embeddings = embeddings.wait() + + node_type_to_embeddings_0[node_type] = embeddings + + # LightGCN propagation across node types + all_node_types = list(node_type_to_embeddings_0.keys()) + + # For bipartite, we need to create a unified edge representation + # Collect all edges and map node indices to a combined space + # Node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1) + node_type_to_offset = {} + offset = 0 + for node_type in all_node_types: + node_type_to_offset[node_type] = offset + node_type_str = str(node_type) + offset += data[node_type_str].num_nodes + + # Combine all embeddings into a single tensor + combined_embeddings_0 = torch.cat( + [node_type_to_embeddings_0[nt] for nt in all_node_types], dim=0 + ) # shape [total_nodes, D] + + # Combine all edges into a single edge_index + combined_edge_list = [] + for edge_type_tuple in data.edge_types: + src_nt_str, _, dst_nt_str = edge_type_tuple + src_node_type = NodeType(src_nt_str) + dst_node_type = NodeType(dst_nt_str) 
+ + edge_index = data[edge_type_tuple].edge_index.to(device) # shape [2, E] + + # Offset the indices to the combined node space + src_offset = node_type_to_offset[src_node_type] + dst_offset = node_type_to_offset[dst_node_type] + + offset_edge_index = edge_index.clone() + offset_edge_index[0] += src_offset + offset_edge_index[1] += dst_offset + + combined_edge_list.append(offset_edge_index) + + combined_edge_index = torch.cat(combined_edge_list, dim=1) # shape [2, total_edges] + + # Track all layer embeddings + all_layer_embeddings = [combined_embeddings_0] + current_embeddings = combined_embeddings_0 + + # Perform K layers of propagation + for conv in self._convs: + current_embeddings = conv(current_embeddings, combined_edge_index) # shape [total_nodes, D] + all_layer_embeddings.append(current_embeddings) + + # Weighted sum across layers + combined_final_embeddings = self._weighted_layer_sum(all_layer_embeddings) # shape [total_nodes, D] + + # Split back into per-node-type embeddings + final_embeddings: dict[NodeType, torch.Tensor] = {} + for node_type in all_node_types: + start_idx = node_type_to_offset[node_type] + node_type_str = str(node_type) + num_nodes = data[node_type_str].num_nodes + end_idx = start_idx + num_nodes + + final_embeddings[node_type] = combined_final_embeddings[start_idx:end_idx] # shape [num_nodes, D] + + # Extract anchor nodes if specified + if anchor_node_ids is not None: + for node_type in all_node_types: + if isinstance(anchor_node_ids, dict) and node_type in anchor_node_ids: + anchors = anchor_node_ids[node_type].to(device).long() + final_embeddings[node_type] = final_embeddings[node_type][anchors] + + return final_embeddings + def _lookup_embeddings_for_single_node_type( self, node_type: str, ids: torch.Tensor ) -> torch.Tensor: From 117925a4ac5b2e578cb261ac979401d743f89963 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 10 Nov 2025 05:59:56 +0000 Subject: [PATCH 06/16] Removed extraneous imports --- python/tests/unit/nn/models_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index 7c1b71f22..408ac592c 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -5,8 +5,6 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn -import torch.distributed as dist -import torch.multiprocessing as mp from torch_geometric.data import Data, HeteroData from torch_geometric.nn.models import LightGCN as PyGLightGCN from torchrec.distributed.model_parallel import ( From 2928a9b96f707713c47cf7918c91ca64aee665d3 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 10 Nov 2025 08:31:00 +0000 Subject: [PATCH 07/16] PR changes --- python/gigl/nn/models.py | 48 +++++++++++++++++++++-------- python/tests/unit/nn/models_test.py | 35 +++++++++++++++++++++ 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index 29f1028ed..6f446a552 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -156,6 +156,20 @@ class LightGCN(nn.Module): Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None. """ + @staticmethod + def _get_feature_key(node_type: Union[str, NodeType]) -> str: + """ + Get the feature key for a node type's embedding table. + + Args: + node_type: Node type as string or NodeType object. 
+ + Returns: + str: Feature key in format "{node_type}_id" + """ + print("IM HERE") + return f"{node_type}_id" + def __init__( self, node_type_to_num_nodes: Union[int, dict[NodeType, int]], @@ -187,9 +201,8 @@ def __init__( ) # Build TorchRec EBC (one table per node type) - # feature key naming convention: f"{node_type}_id" self._feature_keys: list[str] = [ - f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys() + self._get_feature_key(node_type) for node_type in self._node_type_to_num_nodes.keys() ] tables: list[EmbeddingBagConfig] = [] for node_type, num_nodes in self._node_type_to_num_nodes.items(): @@ -198,7 +211,7 @@ def __init__( name=f"node_embedding_{node_type}", embedding_dim=embedding_dim, num_embeddings=num_nodes, - feature_names=[f"{node_type}_id"], + feature_names=[self._get_feature_key(node_type)], ) ) @@ -249,15 +262,27 @@ def forward( ) if is_bipartite: + # For bipartite graphs, anchor_node_ids must be a dict, not a Tensor + if anchor_node_ids is not None and not isinstance(anchor_node_ids, dict): + raise TypeError( + f"For bipartite graphs, anchor_node_ids must be a dict or None, " + f"got {type(anchor_node_ids)}" + ) return self._forward_bipartite(data, device, output_node_types, anchor_node_ids) else: + # For homogeneous graphs, anchor_node_ids must be a Tensor, not a dict + if anchor_node_ids is not None and not isinstance(anchor_node_ids, torch.Tensor): + raise TypeError( + f"For homogeneous graphs, anchor_node_ids must be a Tensor or None, " + f"got {type(anchor_node_ids)}" + ) return self._forward_homogeneous(data, device, anchor_node_ids) def _forward_homogeneous( self, data: Data, device: torch.device, - anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, + anchor_node_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass for homogeneous graphs using LightGCN propagation. @@ -276,7 +301,7 @@ def _forward_homogeneous( Args: data (Data): PyG Data object containing edge_index and node IDs. device (torch.device): Device to run computation on. - anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): Local node indices to return + anchor_node_ids (Optional[torch.Tensor]): Local node indices to return embeddings for. If None, returns embeddings for all nodes. Default: None. Returns: @@ -323,10 +348,7 @@ def _forward_homogeneous( # If anchor node ids are provided, return the embeddings for the anchor nodes only if anchor_node_ids is not None: - if isinstance(anchor_node_ids, torch.Tensor): - anchors_local = anchor_node_ids.to(device).long() # shape [num_anchors] - else: - anchors_local = anchor_node_ids[NodeType(key)].to(device).long() # shape [num_anchors] + anchors_local = anchor_node_ids.to(device).long() # shape [num_anchors] return final_embeddings[ anchors_local ] # shape [num_anchors, D], embeddings for anchor nodes only @@ -341,7 +363,7 @@ def _forward_bipartite( data: HeteroData, device: torch.device, output_node_types: Optional[list[NodeType]] = None, - anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, + anchor_node_ids: Optional[dict[NodeType, torch.Tensor]] = None, ) -> dict[NodeType, torch.Tensor]: """ Forward pass for bipartite graphs using LightGCN propagation. @@ -354,7 +376,7 @@ def _forward_bipartite( device (torch.device): Device to run computation on. output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. If None, returns all node types. Default: None. 
- anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): Dict mapping node types + anchor_node_ids (Optional[Dict[NodeType, torch.Tensor]]): Dict mapping node types to local anchor indices. If None, returns all nodes. Default: None. Returns: @@ -370,7 +392,7 @@ def _forward_bipartite( for node_type in output_node_types: node_type_str = str(node_type) - key = f"{node_type_str}_id" + key = self._get_feature_key(node_type_str) assert hasattr(data[node_type_str], "node"), ( f"Subgraph must include .node field for node type {node_type_str}" @@ -452,7 +474,7 @@ def _forward_bipartite( # Extract anchor nodes if specified if anchor_node_ids is not None: for node_type in all_node_types: - if isinstance(anchor_node_ids, dict) and node_type in anchor_node_ids: + if node_type in anchor_node_ids: anchors = anchor_node_ids[node_type].to(device).long() final_embeddings[node_type] = final_embeddings[node_type][anchors] diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index 408ac592c..1ec10cb1c 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -432,6 +432,41 @@ def test_compare_bipartite_with_math(self): torch.allclose(output[NodeType("item")], expected_item_embeddings, atol=1e-4, rtol=1e-4) ) + # Test with anchor nodes - select specific nodes from each type + anchor_node_ids = { + NodeType("user"): torch.tensor([0], dtype=torch.long), # Select user 0 + NodeType("item"): torch.tensor([1], dtype=torch.long), # Select item 1 + } + + output_with_anchors = model( + data, + self.device, + output_node_types=[NodeType("user"), NodeType("item")], + anchor_node_ids=anchor_node_ids, + ) + + # Check shapes - should only return embeddings for anchor nodes + self.assertEqual(output_with_anchors[NodeType("user")].shape, (1, self.embedding_dim)) + self.assertEqual(output_with_anchors[NodeType("item")].shape, (1, self.embedding_dim)) + + # Check values - should match the corresponding rows from full output + self.assertTrue( + torch.allclose( + output_with_anchors[NodeType("user")], + expected_user_embeddings[0:1], # User 0 + atol=1e-4, + rtol=1e-4, + ) + ) + self.assertTrue( + torch.allclose( + output_with_anchors[NodeType("item")], + expected_item_embeddings[1:2], # Item 1 + atol=1e-4, + rtol=1e-4, + ) + ) + def _run_dmp_multiprocess_test( rank: int, world_size: int, From f880b85074b0d75a08ad6c0b79460f30e693dded Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 10 Nov 2025 08:34:19 +0000 Subject: [PATCH 08/16] PR changes --- python/gigl/nn/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index 6f446a552..b003792a7 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -167,7 +167,6 @@ def _get_feature_key(node_type: Union[str, NodeType]) -> str: Returns: str: Feature key in format "{node_type}_id" """ - print("IM HERE") return f"{node_type}_id" def __init__( From e6b96effa489c572d7f8d78c33901e4cb5980ce3 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 10 Nov 2025 18:59:18 +0000 Subject: [PATCH 09/16] Add types --- python/gigl/nn/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index b003792a7..1a5ed3be6 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -415,7 +415,7 @@ def _forward_bipartite( # For bipartite, we need to create a unified edge representation # Collect all edges and map node indices to a combined space # Node type 0 gets 
indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1) - node_type_to_offset = {} + node_type_to_offset: dict[NodeType, int] = {} offset = 0 for node_type in all_node_types: node_type_to_offset[node_type] = offset @@ -428,7 +428,7 @@ def _forward_bipartite( ) # shape [total_nodes, D] # Combine all edges into a single edge_index - combined_edge_list = [] + combined_edge_list: list[torch.Tensor] = [] for edge_type_tuple in data.edge_types: src_nt_str, _, dst_nt_str = edge_type_tuple src_node_type = NodeType(src_nt_str) @@ -449,7 +449,7 @@ def _forward_bipartite( combined_edge_index = torch.cat(combined_edge_list, dim=1) # shape [2, total_edges] # Track all layer embeddings - all_layer_embeddings = [combined_embeddings_0] + all_layer_embeddings: list[torch.Tensor] = [combined_embeddings_0] current_embeddings = combined_embeddings_0 # Perform K layers of propagation From 7cc6e93d19952c7618922dfe366091ac86505e39 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 10 Nov 2025 21:54:49 +0000 Subject: [PATCH 10/16] Chanegd from bipartite to heterogeneous --- python/gigl/nn/models.py | 82 +++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 39 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index 1a5ed3be6..b771ec4aa 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -134,6 +134,19 @@ def unwrap_from_ddp(self) -> "LinkPredictionGNN": return LinkPredictionGNN(encoder=encoder, decoder=decoder) +def _get_feature_key(node_type: Union[str, NodeType]) -> str: + """ + Get the feature key for a node type's embedding table. + + Args: + node_type: Node type as string or NodeType object. + + Returns: + str: Feature key in format "{node_type}_id" + """ + return f"{node_type}_id" + + # TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement. # TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific # TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer) @@ -156,19 +169,6 @@ class LightGCN(nn.Module): Must have length K+1. If None, uses uniform weights 1/(K+1). Default: None. """ - @staticmethod - def _get_feature_key(node_type: Union[str, NodeType]) -> str: - """ - Get the feature key for a node type's embedding table. - - Args: - node_type: Node type as string or NodeType object. 
- - Returns: - str: Feature key in format "{node_type}_id" - """ - return f"{node_type}_id" - def __init__( self, node_type_to_num_nodes: Union[int, dict[NodeType, int]], @@ -201,8 +201,17 @@ def __init__( # Build TorchRec EBC (one table per node type) self._feature_keys: list[str] = [ - self._get_feature_key(node_type) for node_type in self._node_type_to_num_nodes.keys() + _get_feature_key(node_type) for node_type in self._node_type_to_num_nodes.keys() ] + + # Validate model configuration: restrict to homogeneous or bipartite graphs + num_node_types = len(self._feature_keys) + if num_node_types not in [1, 2]: + raise ValueError( + f"LightGCN only supports homogeneous (1 node type) or bipartite (2 node types) graphs; " + f"got {num_node_types} node types: {self._feature_keys}" + ) + tables: list[EmbeddingBagConfig] = [] for node_type, num_nodes in self._node_type_to_num_nodes.items(): tables.append( @@ -210,7 +219,7 @@ def __init__( name=f"node_embedding_{node_type}", embedding_dim=embedding_dim, num_embeddings=num_nodes, - feature_names=[self._get_feature_key(node_type)], + feature_names=[_get_feature_key(node_type)], ) ) @@ -236,38 +245,31 @@ def forward( Args: data (Union[Data, HeteroData]): Graph data. - For homogeneous: Data object with edge_index and node field - - For bipartite: HeteroData with 2 node types and edge_index_dict + - For heterogeneous: HeteroData with node types and edge_index_dict device (torch.device): Device to run the computation on. output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. - Required for bipartite graphs. If None, returns embeddings for all node types. Default: None. + Required for heterogeneous graphs. If None, returns embeddings for all node types. Default: None. anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): Local node indices to return embeddings for. - For homogeneous: torch.Tensor of shape [num_anchors] - - For bipartite: dict mapping node types to anchor tensors + - For heterogeneous: dict mapping node types to anchor tensors If None, returns embeddings for all nodes. Default: None. Returns: Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings. 
- For homogeneous: tensor of shape [num_nodes, embedding_dim] - - For bipartite: dict mapping node types to embeddings + - For heterogeneous: dict mapping node types to embeddings """ - is_bipartite = isinstance(data, HeteroData) - - # Validate model configuration - num_node_types = len(self._feature_keys) - assert num_node_types in [1, 2], ( - f"LightGCN only supports homogeneous (1 node type) or bipartite (2 node types) graphs; " - f"got {num_node_types} node types: {self._feature_keys}" - ) + is_heterogeneous = isinstance(data, HeteroData) - if is_bipartite: - # For bipartite graphs, anchor_node_ids must be a dict, not a Tensor + if is_heterogeneous: + # For heterogeneous graphs, anchor_node_ids must be a dict, not a Tensor if anchor_node_ids is not None and not isinstance(anchor_node_ids, dict): raise TypeError( - f"For bipartite graphs, anchor_node_ids must be a dict or None, " + f"For heterogeneous graphs, anchor_node_ids must be a dict or None, " f"got {type(anchor_node_ids)}" ) - return self._forward_bipartite(data, device, output_node_types, anchor_node_ids) + return self._forward_heterogeneous(data, device, output_node_types, anchor_node_ids) else: # For homogeneous graphs, anchor_node_ids must be a Tensor, not a dict if anchor_node_ids is not None and not isinstance(anchor_node_ids, torch.Tensor): @@ -357,7 +359,7 @@ def _forward_homogeneous( final_embeddings # shape [N_sub, D], embeddings for all nodes in subgraph ) - def _forward_bipartite( + def _forward_heterogeneous( self, data: HeteroData, device: torch.device, @@ -365,13 +367,15 @@ def _forward_bipartite( anchor_node_ids: Optional[dict[NodeType, torch.Tensor]] = None, ) -> dict[NodeType, torch.Tensor]: """ - Forward pass for bipartite graphs using LightGCN propagation. + Forward pass for heterogeneous graphs using LightGCN propagation. - For bipartite graphs (e.g., user-item), we have two node types and edges between them. - LightGCN propagates embeddings across both node types through the bipartite structure. + For heterogeneous graphs (e.g., user-item), we have + multiple node types. Note that we restrict to one edge type. LightGCN propagates embeddings across + all node types by creating a unified node space, running propagation, then splitting + back into per-type embeddings. Args: - data (HeteroData): PyG HeteroData object with 2 node types. + data (HeteroData): PyG HeteroData object with node types. device (torch.device): Device to run computation on. output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. If None, returns all node types. Default: None. 
@@ -391,7 +395,7 @@ def _forward_bipartite( for node_type in output_node_types: node_type_str = str(node_type) - key = self._get_feature_key(node_type_str) + key = _get_feature_key(node_type_str) assert hasattr(data[node_type_str], "node"), ( f"Subgraph must include .node field for node type {node_type_str}" @@ -412,9 +416,9 @@ def _forward_bipartite( # LightGCN propagation across node types all_node_types = list(node_type_to_embeddings_0.keys()) - # For bipartite, we need to create a unified edge representation + # For heterogeneous graphs, we need to create a unified edge representation # Collect all edges and map node indices to a combined space - # Node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1) + # E.g., node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1) node_type_to_offset: dict[NodeType, int] = {} offset = 0 for node_type in all_node_types: From 62fc9ed94b51aa04d09b85e0d18f48cb1ee84dc3 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 10 Nov 2025 22:38:23 +0000 Subject: [PATCH 11/16] Separate anchor node tests --- python/tests/unit/nn/models_test.py | 68 ++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index 1ec10cb1c..0e6229a03 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -432,6 +432,70 @@ def test_compare_bipartite_with_math(self): torch.allclose(output[NodeType("item")], expected_item_embeddings, atol=1e-4, rtol=1e-4) ) + def test_bipartite_with_anchor_nodes(self): + """Test anchor node selection in bipartite/heterogeneous graphs.""" + # Create bipartite graph + num_users = 2 + num_items = 2 + + node_type_to_num_nodes = { + NodeType("user"): num_users, + NodeType("item"): num_items, + } + + model = self._create_lightgcn_model(node_type_to_num_nodes) + + # Set embeddings + user_embeddings = torch.tensor( + [ + [0.2, 0.5, 0.1, 0.4], # User 0 + [0.6, 0.1, 0.2, 0.5], # User 1 + ], + dtype=torch.float32, + ) + + item_embeddings = torch.tensor( + [ + [0.9, 0.4, 0.1, 0.4], # Item 0 + [0.3, 0.8, 0.3, 0.6], # Item 1 + ], + dtype=torch.float32, + ) + + with torch.no_grad(): + user_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_user" + ] + user_table.weight[:] = user_embeddings + + item_table = model._embedding_bag_collection.embedding_bags[ + "node_embedding_item" + ] + item_table.weight[:] = item_embeddings + + data = HeteroData() + + # Set up nodes + data["user"].node = torch.tensor([0, 1], dtype=torch.long) + data["user"].num_nodes = num_users + data["item"].node = torch.tensor([0, 1], dtype=torch.long) + data["item"].num_nodes = num_items + + # Set up edges + data["user", "to", "item"].edge_index = torch.tensor( + [[0, 0, 1], [0, 1, 1]], dtype=torch.long + ) + data["item", "to", "user"].edge_index = torch.tensor( + [[0, 1, 1], [0, 0, 1]], dtype=torch.long + ) + + # First get full output to compare against + full_output = model( + data, + self.device, + output_node_types=[NodeType("user"), NodeType("item")], + ) + # Test with anchor nodes - select specific nodes from each type anchor_node_ids = { NodeType("user"): torch.tensor([0], dtype=torch.long), # Select user 0 @@ -453,7 +517,7 @@ def test_compare_bipartite_with_math(self): self.assertTrue( torch.allclose( output_with_anchors[NodeType("user")], - expected_user_embeddings[0:1], # User 0 + full_output[NodeType("user")][0:1], # User 0 atol=1e-4, 
rtol=1e-4, ) @@ -461,7 +525,7 @@ def test_compare_bipartite_with_math(self): self.assertTrue( torch.allclose( output_with_anchors[NodeType("item")], - expected_item_embeddings[1:2], # Item 1 + full_output[NodeType("item")][1:2], # Item 1 atol=1e-4, rtol=1e-4, ) From b0697157241de87fe68c723663289e0f00bd8475 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 10 Nov 2025 22:45:11 +0000 Subject: [PATCH 12/16] Sorted output node types --- python/gigl/nn/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index b771ec4aa..e73b06a78 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -388,7 +388,8 @@ def _forward_heterogeneous( """ # Determine which node types to process if output_node_types is None: - output_node_types = [NodeType(str(nt)) for nt in data.node_types] + # Sort node types for deterministic ordering across machines + output_node_types = sorted([NodeType(str(nt)) for nt in data.node_types], key=str) # Lookup initial embeddings e^(0) for each node type node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {} @@ -414,7 +415,8 @@ def _forward_heterogeneous( node_type_to_embeddings_0[node_type] = embeddings # LightGCN propagation across node types - all_node_types = list(node_type_to_embeddings_0.keys()) + # Sort node types for deterministic ordering across machines + all_node_types = sorted(node_type_to_embeddings_0.keys(), key=str) # For heterogeneous graphs, we need to create a unified edge representation # Collect all edges and map node indices to a combined space From cf84fcd1eb8aadd9b35a4db21daa6f1bab55a661 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 17 Nov 2025 18:31:29 +0000 Subject: [PATCH 13/16] Sorting changes --- python/gigl/nn/models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index e73b06a78..4f59af03b 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -200,20 +200,23 @@ def __init__( ) # Build TorchRec EBC (one table per node type) + # Sort node types for deterministic ordering across machines self._feature_keys: list[str] = [ - _get_feature_key(node_type) for node_type in self._node_type_to_num_nodes.keys() + _get_feature_key(node_type) for node_type in sorted(self._node_type_to_num_nodes.keys(), key=str) ] # Validate model configuration: restrict to homogeneous or bipartite graphs num_node_types = len(self._feature_keys) if num_node_types not in [1, 2]: + # TODO(kmonte, swong3): We should loosen this restriction and allow fully heterogenous graphs in the future. raise ValueError( f"LightGCN only supports homogeneous (1 node type) or bipartite (2 node types) graphs; " f"got {num_node_types} node types: {self._feature_keys}" ) tables: list[EmbeddingBagConfig] = [] - for node_type, num_nodes in self._node_type_to_num_nodes.items(): + # Sort node types for deterministic ordering across machines + for node_type, num_nodes in sorted(self._node_type_to_num_nodes.items(), key=lambda x: str(x[0])): tables.append( EmbeddingBagConfig( name=f"node_embedding_{node_type}", @@ -248,7 +251,7 @@ def forward( - For heterogeneous: HeteroData with node types and edge_index_dict device (torch.device): Device to run the computation on. output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. - Required for heterogeneous graphs. If None, returns embeddings for all node types. Default: None. + If None, returns embeddings for all node types. 
Default: None.
         anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]):
             Local node indices to return embeddings for.
             - For homogeneous: torch.Tensor of shape [num_anchors]
@@ -370,7 +373,7 @@ def _forward_heterogeneous(
         Forward pass for heterogeneous graphs using LightGCN propagation.
 
         For heterogeneous graphs (e.g., user-item), we have
-        multiple node types. Note that we restrict to one edge type. LightGCN propagates embeddings across
+        multiple node types. LightGCN propagates embeddings across
         all node types by creating a unified node space, running propagation, then splitting
         back into per-type embeddings.
 

From 9ca696daa5412484f3ceed0a5a48365992396e46 Mon Sep 17 00:00:00 2001
From: swong3
Date: Mon, 17 Nov 2025 19:08:57 +0000
Subject: [PATCH 14/16] Changed unexpected behavior for anchor nodes (bipartite case)

---
 python/gigl/nn/models.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py
index 4f59af03b..5728d73ce 100644
--- a/python/gigl/nn/models.py
+++ b/python/gigl/nn/models.py
@@ -202,7 +202,7 @@ def __init__(
         # Build TorchRec EBC (one table per node type)
         # Sort node types for deterministic ordering across machines
         self._feature_keys: list[str] = [
-            _get_feature_key(node_type) for node_type in sorted(self._node_type_to_num_nodes.keys(), key=str)
+            _get_feature_key(node_type) for node_type in sorted(self._node_type_to_num_nodes.keys())
         ]
 
         # Validate model configuration: restrict to homogeneous or bipartite graphs
@@ -216,7 +216,7 @@ def __init__(
 
         tables: list[EmbeddingBagConfig] = []
         # Sort node types for deterministic ordering across machines
-        for node_type, num_nodes in sorted(self._node_type_to_num_nodes.items(), key=lambda x: str(x[0])):
+        for node_type, num_nodes in sorted(self._node_type_to_num_nodes.items()):
             tables.append(
                 EmbeddingBagConfig(
                     name=f"node_embedding_{node_type}",
@@ -392,7 +392,7 @@ def _forward_heterogeneous(
         # Determine which node types to process
         if output_node_types is None:
             # Sort node types for deterministic ordering across machines
-            output_node_types = sorted([NodeType(str(nt)) for nt in data.node_types], key=str)
+            output_node_types = [NodeType(nt) for nt in sorted(data.node_types)]
 
         # Lookup initial embeddings e^(0) for each node type
         node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {}
@@ -419,7 +419,7 @@ def _forward_heterogeneous(
 
         # LightGCN propagation across node types
         # Sort node types for deterministic ordering across machines
-        all_node_types = sorted(node_type_to_embeddings_0.keys(), key=str)
+        all_node_types = sorted(node_type_to_embeddings_0.keys())
 
         # For heterogeneous graphs, we need to create a unified edge representation
         # Collect all edges and map node indices to a combined space
@@ -437,8 +437,9 @@ def _forward_heterogeneous(
         )  # shape [total_nodes, D]
 
         # Combine all edges into a single edge_index
+        # Sort edge types for deterministic ordering across machines
         combined_edge_list: list[torch.Tensor] = []
-        for edge_type_tuple in data.edge_types:
+        for edge_type_tuple in sorted(data.edge_types):
             src_nt_str, _, dst_nt_str = edge_type_tuple
             src_node_type = NodeType(src_nt_str)
             dst_node_type = NodeType(dst_nt_str)
@@ -481,10 +482,13 @@ def _forward_heterogeneous(
 
         # Extract anchor nodes if specified
         if anchor_node_ids is not None:
+            # Only return embeddings for node types specified in anchor_node_ids
+            filtered_embeddings: dict[NodeType, torch.Tensor] = {}
             for node_type in all_node_types:
                 if node_type in anchor_node_ids:
                     anchors 
= anchor_node_ids[node_type].to(device).long() - final_embeddings[node_type] = final_embeddings[node_type][anchors] + filtered_embeddings[node_type] = final_embeddings[node_type][anchors] + return filtered_embeddings return final_embeddings From 15f9a079412968c3d3ef208124a4cbd78c1f8763 Mon Sep 17 00:00:00 2001 From: swong3 Date: Fri, 21 Nov 2025 21:18:24 +0000 Subject: [PATCH 15/16] Remove output_node_types --- python/gigl/nn/models.py | 27 +++++++++++++-------------- python/tests/unit/nn/models_test.py | 20 ++++---------------- 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index 5728d73ce..24cfa6e1a 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -239,7 +239,6 @@ def forward( self, data: Union[Data, HeteroData], device: torch.device, - output_node_types: Optional[list[NodeType]] = None, anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: """ @@ -250,8 +249,6 @@ def forward( - For homogeneous: Data object with edge_index and node field - For heterogeneous: HeteroData with node types and edge_index_dict device (torch.device): Device to run the computation on. - output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. - If None, returns embeddings for all node types. Default: None. anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): Local node indices to return embeddings for. - For homogeneous: torch.Tensor of shape [num_anchors] @@ -272,7 +269,7 @@ def forward( f"For heterogeneous graphs, anchor_node_ids must be a dict or None, " f"got {type(anchor_node_ids)}" ) - return self._forward_heterogeneous(data, device, output_node_types, anchor_node_ids) + return self._forward_heterogeneous(data, device, anchor_node_ids) else: # For homogeneous graphs, anchor_node_ids must be a Tensor, not a dict if anchor_node_ids is not None and not isinstance(anchor_node_ids, torch.Tensor): @@ -366,7 +363,6 @@ def _forward_heterogeneous( self, data: HeteroData, device: torch.device, - output_node_types: Optional[list[NodeType]] = None, anchor_node_ids: Optional[dict[NodeType, torch.Tensor]] = None, ) -> dict[NodeType, torch.Tensor]: """ @@ -377,27 +373,30 @@ def _forward_heterogeneous( all node types by creating a unified node space, running propagation, then splitting back into per-type embeddings. + Note: All node types in the graph are processed during message passing, as this is + required for correct GNN computation. Use anchor_node_ids to filter which node types + and specific nodes are returned in the output. + Args: data (HeteroData): PyG HeteroData object with node types. device (torch.device): Device to run computation on. - output_node_types (Optional[List[NodeType]]): Node types to return embeddings for. - If None, returns all node types. Default: None. anchor_node_ids (Optional[Dict[NodeType, torch.Tensor]]): Dict mapping node types - to local anchor indices. If None, returns all nodes. Default: None. + to local anchor indices. If None, returns all nodes for all types. + If provided, only returns embeddings for the specified node types and indices. Returns: Dict[NodeType, torch.Tensor]: Dict mapping node types to their embeddings, - each of shape [num_nodes_of_type, embedding_dim]. + each of shape [num_nodes_of_type, embedding_dim] (or [num_anchors, embedding_dim] + if anchor_node_ids is provided for that type). 
""" - # Determine which node types to process - if output_node_types is None: - # Sort node types for deterministic ordering across machines - output_node_types = [NodeType(nt) for nt in sorted(data.node_types)] + # Process all node types - this is required for correct message passing in GNNs + # Sort node types for deterministic ordering across machines + all_node_types_in_data = [NodeType(nt) for nt in sorted(data.node_types)] # Lookup initial embeddings e^(0) for each node type node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {} - for node_type in output_node_types: + for node_type in all_node_types_in_data: node_type_str = str(node_type) key = _get_feature_key(node_type_str) diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index 0e6229a03..74576a948 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -401,11 +401,10 @@ def test_compare_bipartite_with_math(self): [[0, 1, 1], [0, 0, 1]], dtype=torch.long ) - # Forward pass + # Forward pass - will return both user and item embeddings output = model( data, self.device, - output_node_types=[NodeType("user"), NodeType("item")], ) expected_user_embeddings = torch.tensor( @@ -489,29 +488,26 @@ def test_bipartite_with_anchor_nodes(self): [[0, 1, 1], [0, 0, 1]], dtype=torch.long ) - # First get full output to compare against + # First get full output to compare against (will return all node types) full_output = model( data, self.device, - output_node_types=[NodeType("user"), NodeType("item")], ) - # Test with anchor nodes - select specific nodes from each type + # Test with anchor nodes - select specific nodes from specific types + # By only including "user" in anchor_node_ids, we'll only get user embeddings back anchor_node_ids = { NodeType("user"): torch.tensor([0], dtype=torch.long), # Select user 0 - NodeType("item"): torch.tensor([1], dtype=torch.long), # Select item 1 } output_with_anchors = model( data, self.device, - output_node_types=[NodeType("user"), NodeType("item")], anchor_node_ids=anchor_node_ids, ) # Check shapes - should only return embeddings for anchor nodes self.assertEqual(output_with_anchors[NodeType("user")].shape, (1, self.embedding_dim)) - self.assertEqual(output_with_anchors[NodeType("item")].shape, (1, self.embedding_dim)) # Check values - should match the corresponding rows from full output self.assertTrue( @@ -522,14 +518,6 @@ def test_bipartite_with_anchor_nodes(self): rtol=1e-4, ) ) - self.assertTrue( - torch.allclose( - output_with_anchors[NodeType("item")], - full_output[NodeType("item")][1:2], # Item 1 - atol=1e-4, - rtol=1e-4, - ) - ) def _run_dmp_multiprocess_test( rank: int, From 00b06be5cf0ea7f5dfae01e10e5521800d1c0446 Mon Sep 17 00:00:00 2001 From: swong3 Date: Mon, 24 Nov 2025 19:43:25 +0000 Subject: [PATCH 16/16] fix --- python/gigl/nn/models.py | 12 ++++-------- python/tests/unit/nn/models_test.py | 4 ++-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/gigl/nn/models.py b/python/gigl/nn/models.py index 24cfa6e1a..26876636a 100644 --- a/python/gigl/nn/models.py +++ b/python/gigl/nn/models.py @@ -416,23 +416,19 @@ def _forward_heterogeneous( node_type_to_embeddings_0[node_type] = embeddings - # LightGCN propagation across node types - # Sort node types for deterministic ordering across machines - all_node_types = sorted(node_type_to_embeddings_0.keys()) - # For heterogeneous graphs, we need to create a unified edge representation # Collect all edges and map node indices to a combined space # 
E.g., node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1) node_type_to_offset: dict[NodeType, int] = {} offset = 0 - for node_type in all_node_types: + for node_type in all_node_types_in_data: node_type_to_offset[node_type] = offset node_type_str = str(node_type) offset += data[node_type_str].num_nodes # Combine all embeddings into a single tensor combined_embeddings_0 = torch.cat( - [node_type_to_embeddings_0[nt] for nt in all_node_types], dim=0 + [node_type_to_embeddings_0[nt] for nt in all_node_types_in_data], dim=0 ) # shape [total_nodes, D] # Combine all edges into a single edge_index @@ -471,7 +467,7 @@ def _forward_heterogeneous( # Split back into per-node-type embeddings final_embeddings: dict[NodeType, torch.Tensor] = {} - for node_type in all_node_types: + for node_type in all_node_types_in_data: start_idx = node_type_to_offset[node_type] node_type_str = str(node_type) num_nodes = data[node_type_str].num_nodes @@ -483,7 +479,7 @@ def _forward_heterogeneous( if anchor_node_ids is not None: # Only return embeddings for node types specified in anchor_node_ids filtered_embeddings: dict[NodeType, torch.Tensor] = {} - for node_type in all_node_types: + for node_type in all_node_types_in_data: if node_type in anchor_node_ids: anchors = anchor_node_ids[node_type].to(device).long() filtered_embeddings[node_type] = final_embeddings[node_type][anchors] diff --git a/python/tests/unit/nn/models_test.py b/python/tests/unit/nn/models_test.py index 74576a948..82a75119d 100644 --- a/python/tests/unit/nn/models_test.py +++ b/python/tests/unit/nn/models_test.py @@ -506,8 +506,8 @@ def test_bipartite_with_anchor_nodes(self): anchor_node_ids=anchor_node_ids, ) - # Check shapes - should only return embeddings for anchor nodes - self.assertEqual(output_with_anchors[NodeType("user")].shape, (1, self.embedding_dim)) + # Check that only user embeddings are returned + self.assertEqual(output_with_anchors.keys(), set([NodeType("user")])) # Check values - should match the corresponding rows from full output self.assertTrue(