Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pytorch_tabnet/tab_network/embedding_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(
if cat_dims == [] and cat_idxs == []:
self.skip_embedding = True
self.post_embed_dim = input_dim
self.embedding_group_matrix = group_matrix.to(group_matrix.device)
# Register as buffer to ensure it moves with the model when .to(device) is called
self.register_buffer("embedding_group_matrix", group_matrix.clone())
return
else:
self.skip_embedding = False
Expand All @@ -60,23 +61,25 @@ def __init__(

# update group matrix
n_groups = group_matrix.shape[0]
self.embedding_group_matrix = torch.empty((n_groups, self.post_embed_dim), device=group_matrix.device)
embedding_group_matrix = torch.empty((n_groups, self.post_embed_dim), device=group_matrix.device)
for group_idx in range(n_groups):
post_emb_idx = 0
cat_feat_counter = 0
for init_feat_idx in range(input_dim):
if self.continuous_idx[init_feat_idx] == 1:
# this means that no embedding is applied to this column
self.embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[group_idx, init_feat_idx] # noqa
embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[group_idx, init_feat_idx] # noqa
post_emb_idx += 1
else:
# this is a categorical feature which creates multiple embeddings
n_embeddings = cat_emb_dims[cat_feat_counter]
self.embedding_group_matrix[group_idx, post_emb_idx : post_emb_idx + n_embeddings] = (
embedding_group_matrix[group_idx, post_emb_idx : post_emb_idx + n_embeddings] = (
group_matrix[group_idx, init_feat_idx] / n_embeddings
) # noqa
post_emb_idx += n_embeddings
cat_feat_counter += 1
# Register as buffer to ensure it moves with the model when .to(device) is called
self.register_buffer("embedding_group_matrix", embedding_group_matrix)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply embeddings to inputs
Expand Down
3 changes: 2 additions & 1 deletion pytorch_tabnet/tab_network/random_obfuscator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def __init__(self, pretraining_ratio: float, group_matrix: torch.Tensor):
super(RandomObfuscator, self).__init__()
self.pretraining_ratio = pretraining_ratio
# group matrix is set to boolean here to pass all possible information
self.group_matrix = (group_matrix > 0) + 0.0
# Register as buffer to ensure it moves with the model when .to(device) is called
self.register_buffer("group_matrix", (group_matrix > 0) + 0.0)
self.num_groups = group_matrix.shape[0]

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_tabnet/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
def define_device(device_name: str) -> str:
"""Define the device to use during training and inference.

If auto it will detect automatically whether to use cuda or cpu.
If auto it will detect automatically whether to use mps, cuda or cpu.

Parameters
----------
device_name : str
Either "auto", "cpu" or "cuda"
Either "auto", "cpu", "cuda", or "mps"

Returns
-------
str
Either "cpu" or "cuda"
Either "cpu", "cuda", or "mps"

"""
if device_name == "auto":
Expand Down