diff --git a/pytorch_tabnet/tab_network/embedding_generator.py b/pytorch_tabnet/tab_network/embedding_generator.py index 27c13c4..74a1f14 100644 --- a/pytorch_tabnet/tab_network/embedding_generator.py +++ b/pytorch_tabnet/tab_network/embedding_generator.py @@ -39,7 +39,8 @@ def __init__( if cat_dims == [] and cat_idxs == []: self.skip_embedding = True self.post_embed_dim = input_dim - self.embedding_group_matrix = group_matrix.to(group_matrix.device) + # Register as buffer to ensure it moves with the model when .to(device) is called + self.register_buffer("embedding_group_matrix", group_matrix.clone()) return else: self.skip_embedding = False @@ -60,23 +61,25 @@ def __init__( # update group matrix n_groups = group_matrix.shape[0] - self.embedding_group_matrix = torch.empty((n_groups, self.post_embed_dim), device=group_matrix.device) + embedding_group_matrix = torch.empty((n_groups, self.post_embed_dim), device=group_matrix.device) for group_idx in range(n_groups): post_emb_idx = 0 cat_feat_counter = 0 for init_feat_idx in range(input_dim): if self.continuous_idx[init_feat_idx] == 1: # this means that no embedding is applied to this column - self.embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[group_idx, init_feat_idx] # noqa + embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[group_idx, init_feat_idx] # noqa post_emb_idx += 1 else: # this is a categorical feature which creates multiple embeddings n_embeddings = cat_emb_dims[cat_feat_counter] - self.embedding_group_matrix[group_idx, post_emb_idx : post_emb_idx + n_embeddings] = ( + embedding_group_matrix[group_idx, post_emb_idx : post_emb_idx + n_embeddings] = ( group_matrix[group_idx, init_feat_idx] / n_embeddings ) # noqa post_emb_idx += n_embeddings cat_feat_counter += 1 + # Register as buffer to ensure it moves with the model when .to(device) is called + self.register_buffer("embedding_group_matrix", embedding_group_matrix) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply embeddings to inputs diff --git a/pytorch_tabnet/tab_network/random_obfuscator.py b/pytorch_tabnet/tab_network/random_obfuscator.py index 86454b6..f497467 100644 --- a/pytorch_tabnet/tab_network/random_obfuscator.py +++ b/pytorch_tabnet/tab_network/random_obfuscator.py @@ -20,7 +20,8 @@ def __init__(self, pretraining_ratio: float, group_matrix: torch.Tensor): super(RandomObfuscator, self).__init__() self.pretraining_ratio = pretraining_ratio # group matrix is set to boolean here to pass all posssible information - self.group_matrix = (group_matrix > 0) + 0.0 + # Register as buffer to ensure it moves with the model when .to(device) is called + self.register_buffer("group_matrix", (group_matrix > 0) + 0.0) self.num_groups = group_matrix.shape[0] def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/pytorch_tabnet/utils/device.py b/pytorch_tabnet/utils/device.py index c9ac92b..4cf9418 100644 --- a/pytorch_tabnet/utils/device.py +++ b/pytorch_tabnet/utils/device.py @@ -6,17 +6,17 @@ def define_device(device_name: str) -> str: """Define the device to use during training and inference. - If auto it will detect automatically whether to use cuda or cpu. + If auto it will detect automatically whether to use mps, cuda or cpu. Parameters ---------- device_name : str - Either "auto", "cpu" or "cuda" + Either "auto", "cpu", "cuda", or "mps" Returns ------- str - Either "cpu" or "cuda" + Either "cpu", "cuda", or "mps" """ if device_name == "auto":