From bad1079f6f1ab8e7be18d89ec1fd83ef6d257095 Mon Sep 17 00:00:00 2001 From: eh-mls Date: Thu, 19 Feb 2026 00:54:21 +0200 Subject: [PATCH] Refactor device handling and tensor operations for improved clarity and performance --- pytorch_tabnet/data_handlers/tb_dataloader.py | 2 +- .../abstract_models/base_supervised_model.py | 2 +- pytorch_tabnet/tab_models/multitask.py | 4 +- pytorch_tabnet/tab_models/pretraining.py | 39 ++++++------------- pytorch_tabnet/tab_models/tab_class.py | 2 +- pytorch_tabnet/utils/device.py | 4 ++ tests/utils/test_device.py | 12 ++++++ 7 files changed, 32 insertions(+), 33 deletions(-) diff --git a/pytorch_tabnet/data_handlers/tb_dataloader.py b/pytorch_tabnet/data_handlers/tb_dataloader.py index 050f963..982d1f9 100644 --- a/pytorch_tabnet/data_handlers/tb_dataloader.py +++ b/pytorch_tabnet/data_handlers/tb_dataloader.py @@ -31,7 +31,7 @@ def __iter__(self) -> Iterable[Tuple[torch.Tensor, tn_type, tn_type]]: ds_len = len(self.dataset) perm = None if not self.predict: - perm = torch.randperm(ds_len, pin_memory=self.pin_memory) + perm = torch.randperm(ds_len, device="cpu") batched_starts = [i for i in range(0, ds_len, self.batch_size)] batched_starts += [0] if len(batched_starts) == 0 else [] for start in batched_starts[: len(self)]: diff --git a/pytorch_tabnet/tab_models/abstract_models/base_supervised_model.py b/pytorch_tabnet/tab_models/abstract_models/base_supervised_model.py index 7fae904..9a77d73 100644 --- a/pytorch_tabnet/tab_models/abstract_models/base_supervised_model.py +++ b/pytorch_tabnet/tab_models/abstract_models/base_supervised_model.py @@ -85,7 +85,7 @@ def _train_epoch(self, train_loader: TBDataLoader) -> None: X = X.to(self.device) # type: ignore y = y.to(self.device) # type: ignore if w is not None: # type: ignore - w = w.to(self.device) # type: ignore + w = w.to(device=self.device, dtype=torch.float32) # type: ignore batch_logs = self._train_batch(X, y, w) diff --git a/pytorch_tabnet/tab_models/multitask.py 
b/pytorch_tabnet/tab_models/multitask.py index 68bf4cb..a06c214 100644 --- a/pytorch_tabnet/tab_models/multitask.py +++ b/pytorch_tabnet/tab_models/multitask.py @@ -180,7 +180,7 @@ def predict(self, X: Union[torch.Tensor, np.ndarray]) -> List[np.ndarray]: results: dict = {} with torch.no_grad(): for data, _, _ in dataloader: # type: ignore - data = data.to(self.device).float() + data = data.float().to(self.device) output, _ = self.network(data) predictions = [ torch.argmax(torch.nn.Softmax(dim=1)(task_output), dim=1).cpu().detach().numpy().reshape(-1) for task_output in output @@ -219,7 +219,7 @@ def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> List[np.ndarray]: results: dict = {} for data, _, _ in dataloader: # type: ignore - data = data.to(self.device).float() + data = data.float().to(self.device) output, _ = self.network(data) predictions = [torch.nn.Softmax(dim=1)(task_output).cpu().detach().numpy() for task_output in output] for task_idx in range(len(self.output_dim)): diff --git a/pytorch_tabnet/tab_models/pretraining.py b/pytorch_tabnet/tab_models/pretraining.py index 1db7eea..121a217 100644 --- a/pytorch_tabnet/tab_models/pretraining.py +++ b/pytorch_tabnet/tab_models/pretraining.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader from .. import tab_network -from ..data_handlers import UnifiedDataset, create_dataloaders_pt +from ..data_handlers import create_dataloaders_pt from ..error_handlers import filter_weights, validate_eval_set from ..metrics import ( UnsupervisedLoss, @@ -396,7 +396,7 @@ def _predict_batch(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, t Model outputs, embedded inputs, and obfuscated variables. 
""" - X = X.to(self.device).float() + X = X.float().to(self.device) return self.network(X) def stack_batches( # type: ignore[override] @@ -427,38 +427,21 @@ def stack_batches( # type: ignore[override] obf_vars = torch.vstack(list_obfuscation) return output, embedded_x, obf_vars - def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Predict outputs and embeddings for a batch. + def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Predict reconstructed values for inputs. Parameters ---------- - X : np.ndarray or scipy.sparse.csr_matrix - Input data. + X : np.ndarray + Input matrix. Returns ------- - tuple - Predictions and embedded inputs. + Tuple[np.ndarray, np.ndarray] + Reconstructed values and masks. """ self.network.eval() - - dataloader = DataLoader( - UnifiedDataset(X), - batch_size=self.batch_size, - shuffle=False, - ) - - results = [] - embedded_res = [] - with torch.no_grad(): - for _batch_nb, data in enumerate(dataloader): - data = data.to(self.device).float() - output, embeded_x, _ = self.network(data) - predictions = output - results.append(predictions) - embedded_res.append(embeded_x) - res_output = torch.vstack(results).cpu().detach().numpy() - - embedded_inputs = torch.vstack(embedded_res).cpu().detach().numpy() - return res_output, embedded_inputs + X = torch.from_numpy(X).float().to(self.device) + output, _, obf_vars = self.network(X) + return output.cpu().detach().numpy(), obf_vars.cpu().detach().numpy() diff --git a/pytorch_tabnet/tab_models/tab_class.py b/pytorch_tabnet/tab_models/tab_class.py index 9dbf3a2..e1cf0b3 100644 --- a/pytorch_tabnet/tab_models/tab_class.py +++ b/pytorch_tabnet/tab_models/tab_class.py @@ -202,7 +202,7 @@ def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> np.ndarray: results: List[np.ndarray] = [] with torch.no_grad(): for _batch_nb, (data, _, _) in enumerate(dataloader): # type: ignore - data = data.to(self.device).float() # type: ignore + data = 
data.float().to(self.device) # type: ignore output: torch.Tensor _M_loss: torch.Tensor diff --git a/pytorch_tabnet/utils/device.py b/pytorch_tabnet/utils/device.py index 4cf9418..ed7c17b 100644 --- a/pytorch_tabnet/utils/device.py +++ b/pytorch_tabnet/utils/device.py @@ -22,9 +22,13 @@ def define_device(device_name: str) -> str: if device_name == "auto": if torch.cuda.is_available(): return "cuda" + elif torch.backends.mps.is_available(): + return "mps" else: return "cpu" elif device_name == "cuda" and not torch.cuda.is_available(): return "cpu" + elif device_name == "mps" and not torch.backends.mps.is_available(): + return "cpu" else: return device_name diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py index 9e61bb0..4b2d7fe 100644 --- a/tests/utils/test_device.py +++ b/tests/utils/test_device.py @@ -10,6 +10,7 @@ def test_define_device_auto_with_cuda(self, monkeypatch): """Test define_device with 'auto' when CUDA is available.""" # Mock torch.cuda.is_available to return True monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False) assert define_device("auto") == "cuda" @@ -17,13 +18,23 @@ def test_define_device_auto_without_cuda(self, monkeypatch): """Test define_device with 'auto' when CUDA is not available.""" # Mock torch.cuda.is_available to return False monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False) assert define_device("auto") == "cpu" + def test_define_device_auto_with_mps(self, monkeypatch): + """Test define_device with 'auto' when MPS is available.""" + # Mock CUDA as unavailable and MPS as available + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True) + + assert define_device("auto") == "mps" + def test_define_device_cuda_not_available(self, monkeypatch): """Test define_device with 'cuda' 
when CUDA is not available.""" # Mock torch.cuda.is_available to return False monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False) assert define_device("cuda") == "cpu" @@ -31,6 +42,7 @@ def test_define_device_cuda_available(self, monkeypatch): """Test define_device with 'cuda' when CUDA is available.""" # Mock torch.cuda.is_available to return True monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False) assert define_device("cuda") == "cuda"