From 8a58122bc55c71aa444e1b4015af688bafa562a2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 2 Nov 2025 16:14:03 +0000 Subject: [PATCH 1/6] Initial plan From 270b221d9c590f6daa828669d1fb0dab2cd07a16 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 2 Nov 2025 16:20:24 +0000 Subject: [PATCH 2/6] Fix MPS device movement issues by registering tensors as buffers Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com> --- pytorch_tabnet/tab_network/embedding_generator.py | 11 +++++++---- pytorch_tabnet/tab_network/random_obfuscator.py | 3 ++- pytorch_tabnet/utils/device.py | 12 ++++++++---- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/pytorch_tabnet/tab_network/embedding_generator.py b/pytorch_tabnet/tab_network/embedding_generator.py index 27c13c4..74a1f14 100644 --- a/pytorch_tabnet/tab_network/embedding_generator.py +++ b/pytorch_tabnet/tab_network/embedding_generator.py @@ -39,7 +39,8 @@ def __init__( if cat_dims == [] and cat_idxs == []: self.skip_embedding = True self.post_embed_dim = input_dim - self.embedding_group_matrix = group_matrix.to(group_matrix.device) + # Register as buffer to ensure it moves with the model when .to(device) is called + self.register_buffer("embedding_group_matrix", group_matrix.clone()) return else: self.skip_embedding = False @@ -60,23 +61,25 @@ def __init__( # update group matrix n_groups = group_matrix.shape[0] - self.embedding_group_matrix = torch.empty((n_groups, self.post_embed_dim), device=group_matrix.device) + embedding_group_matrix = torch.empty((n_groups, self.post_embed_dim), device=group_matrix.device) for group_idx in range(n_groups): post_emb_idx = 0 cat_feat_counter = 0 for init_feat_idx in range(input_dim): if self.continuous_idx[init_feat_idx] == 1: # this means that no embedding is applied to this column - self.embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[group_idx, init_feat_idx] # noqa + embedding_group_matrix[group_idx, post_emb_idx] = group_matrix[group_idx, init_feat_idx] # noqa post_emb_idx += 1 else: # this is a categorical feature which creates multiple embeddings n_embeddings = cat_emb_dims[cat_feat_counter] - self.embedding_group_matrix[group_idx, post_emb_idx : post_emb_idx + n_embeddings] = ( + embedding_group_matrix[group_idx, post_emb_idx : post_emb_idx + n_embeddings] = ( group_matrix[group_idx, init_feat_idx] / n_embeddings ) # noqa post_emb_idx += n_embeddings cat_feat_counter += 1 + # Register as buffer to ensure it moves with the model when .to(device) is called + self.register_buffer("embedding_group_matrix", embedding_group_matrix) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply embeddings to inputs diff --git a/pytorch_tabnet/tab_network/random_obfuscator.py b/pytorch_tabnet/tab_network/random_obfuscator.py index 86454b6..f497467 100644 --- a/pytorch_tabnet/tab_network/random_obfuscator.py +++ b/pytorch_tabnet/tab_network/random_obfuscator.py @@ -20,7 +20,8 @@ def __init__(self, pretraining_ratio: float, group_matrix: torch.Tensor): super(RandomObfuscator, self).__init__() self.pretraining_ratio = pretraining_ratio # group matrix is set to boolean here to pass all posssible information - self.group_matrix = (group_matrix > 0) + 0.0 + # Register as buffer to ensure it moves with the model when .to(device) is called + self.register_buffer("group_matrix", (group_matrix > 0) + 0.0) self.num_groups = group_matrix.shape[0] def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/pytorch_tabnet/utils/device.py b/pytorch_tabnet/utils/device.py index c9ac92b..be206b4 100644 --- a/pytorch_tabnet/utils/device.py +++ b/pytorch_tabnet/utils/device.py @@ -6,25 +6,29 @@ def define_device(device_name: str) -> str: """Define the device to use during training and inference. - If auto it will detect automatically whether to use cuda or cpu. + If auto it will detect automatically whether to use mps, cuda or cpu. Parameters ---------- device_name : str - Either "auto", "cpu" or "cuda" + Either "auto", "cpu", "cuda", or "mps" Returns ------- str - Either "cpu" or "cuda" + Either "cpu", "cuda", or "mps" """ if device_name == "auto": - if torch.cuda.is_available(): + if torch.backends.mps.is_available(): + return "mps" + elif torch.cuda.is_available(): return "cuda" else: return "cpu" elif device_name == "cuda" and not torch.cuda.is_available(): return "cpu" + elif device_name == "mps" and not torch.backends.mps.is_available(): + return "cpu" else: return device_name From 3ef07ce1f6b52725a7113160e1c9efde0aaf000a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 2 Nov 2025 16:23:00 +0000 Subject: [PATCH 3/6] Final verification: all tests pass, no security issues Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com> --- nul | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 nul diff --git a/nul b/nul new file mode 100644 index 0000000..e69de29 From 101de47ee521e96d12c175e467d59e0ac1585d82 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 2 Nov 2025 16:23:29 +0000 Subject: [PATCH 4/6] Remove spurious nul file --- nul | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 nul diff --git a/nul b/nul deleted file mode 100644 index e69de29..0000000 From 1a6a515aeb302555c19e71ec4715110b97cfbf96 Mon Sep 17 00:00:00 2001 From: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com> Date: Sun, 2 Nov 2025 18:49:07 +0200 Subject: [PATCH 5/6] Simplify device selection logic Removed MPS device check from auto device selection. Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com> --- pytorch_tabnet/utils/device.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_tabnet/utils/device.py b/pytorch_tabnet/utils/device.py index be206b4..eb87d87 100644 --- a/pytorch_tabnet/utils/device.py +++ b/pytorch_tabnet/utils/device.py @@ -20,15 +20,11 @@ def define_device(device_name: str) -> str: """ if device_name == "auto": - if torch.backends.mps.is_available(): - return "mps" elif torch.cuda.is_available(): return "cuda" else: return "cpu" elif device_name == "cuda" and not torch.cuda.is_available(): return "cpu" - elif device_name == "mps" and not torch.backends.mps.is_available(): - return "cpu" else: return device_name From 24f314525b1970909a06d168286875d7b359b451 Mon Sep 17 00:00:00 2001 From: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com> Date: Sun, 2 Nov 2025 18:49:46 +0200 Subject: [PATCH 6/6] Fix device selection logic for 'auto' case Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com> --- pytorch_tabnet/utils/device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_tabnet/utils/device.py b/pytorch_tabnet/utils/device.py index eb87d87..4cf9418 100644 --- a/pytorch_tabnet/utils/device.py +++ b/pytorch_tabnet/utils/device.py @@ -20,7 +20,7 @@ def define_device(device_name: str) -> str: """ if device_name == "auto": - elif torch.cuda.is_available(): + if torch.cuda.is_available(): return "cuda" else: return "cpu"