Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_tabnet/data_handlers/tb_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __iter__(self) -> Iterable[Tuple[torch.Tensor, tn_type, tn_type]]:
ds_len = len(self.dataset)
perm = None
if not self.predict:
perm = torch.randperm(ds_len, pin_memory=self.pin_memory)
perm = torch.randperm(ds_len, device="cpu")
batched_starts = [i for i in range(0, ds_len, self.batch_size)]
batched_starts += [0] if len(batched_starts) == 0 else []
for start in batched_starts[: len(self)]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _train_epoch(self, train_loader: TBDataLoader) -> None:
X = X.to(self.device) # type: ignore
y = y.to(self.device) # type: ignore
if w is not None: # type: ignore
w = w.to(self.device) # type: ignore
w = w.to(device=self.device, dtype=torch.float32) # type: ignore

batch_logs = self._train_batch(X, y, w)

Expand Down
4 changes: 2 additions & 2 deletions pytorch_tabnet/tab_models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def predict(self, X: Union[torch.Tensor, np.ndarray]) -> List[np.ndarray]:
results: dict = {}
with torch.no_grad():
for data, _, _ in dataloader: # type: ignore
data = data.to(self.device).float()
data = data.float().to(self.device)
output, _ = self.network(data)
predictions = [
torch.argmax(torch.nn.Softmax(dim=1)(task_output), dim=1).cpu().detach().numpy().reshape(-1) for task_output in output
Expand Down Expand Up @@ -219,7 +219,7 @@ def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> List[np.ndarray]:

results: dict = {}
for data, _, _ in dataloader: # type: ignore
data = data.to(self.device).float()
data = data.float().to(self.device)
output, _ = self.network(data)
predictions = [torch.nn.Softmax(dim=1)(task_output).cpu().detach().numpy() for task_output in output]
for task_idx in range(len(self.output_dim)):
Expand Down
39 changes: 11 additions & 28 deletions pytorch_tabnet/tab_models/pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.utils.data import DataLoader

from .. import tab_network
from ..data_handlers import UnifiedDataset, create_dataloaders_pt
from ..data_handlers import create_dataloaders_pt
from ..error_handlers import filter_weights, validate_eval_set
from ..metrics import (
UnsupervisedLoss,
Expand Down Expand Up @@ -396,7 +396,7 @@ def _predict_batch(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, t
Model outputs, embedded inputs, and obfuscated variables.

"""
X = X.to(self.device).float()
X = X.float().to(self.device)
return self.network(X)

def stack_batches( # type: ignore[override]
Expand Down Expand Up @@ -427,38 +427,21 @@ def stack_batches( # type: ignore[override]
obf_vars = torch.vstack(list_obfuscation)
return output, embedded_x, obf_vars

def predict(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Predict outputs and embeddings for a batch.
def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Predict reconstructed values for inputs.

Parameters
----------
X : np.ndarray or scipy.sparse.csr_matrix
Input data.
X : np.ndarray
Input matrix.

Returns
-------
tuple
Predictions and embedded inputs.
Tuple[np.ndarray, np.ndarray]
Reconstructed values and masks.

"""
self.network.eval()

dataloader = DataLoader(
UnifiedDataset(X),
batch_size=self.batch_size,
shuffle=False,
)

results = []
embedded_res = []
with torch.no_grad():
for _batch_nb, data in enumerate(dataloader):
data = data.to(self.device).float()
output, embeded_x, _ = self.network(data)
predictions = output
results.append(predictions)
embedded_res.append(embeded_x)
res_output = torch.vstack(results).cpu().detach().numpy()

embedded_inputs = torch.vstack(embedded_res).cpu().detach().numpy()
return res_output, embedded_inputs
X = torch.from_numpy(X).float().to(self.device)
output, _, obf_vars = self.network(X)
return output.cpu().detach().numpy(), obf_vars.cpu().detach().numpy()
2 changes: 1 addition & 1 deletion pytorch_tabnet/tab_models/tab_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def predict_proba(self, X: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
results: List[np.ndarray] = []
with torch.no_grad():
for _batch_nb, (data, _, _) in enumerate(dataloader): # type: ignore
data = data.to(self.device).float() # type: ignore
data = data.float().to(self.device) # type: ignore

output: torch.Tensor
_M_loss: torch.Tensor
Expand Down
4 changes: 4 additions & 0 deletions pytorch_tabnet/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ def define_device(device_name: str) -> str:
if device_name == "auto":
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
elif device_name == "cuda" and not torch.cuda.is_available():
return "cpu"
elif device_name == "mps" and not torch.backends.mps.is_available():
return "cpu"
else:
return device_name
12 changes: 12 additions & 0 deletions tests/utils/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,39 @@ def test_define_device_auto_with_cuda(self, monkeypatch):
"""Test define_device with 'auto' when CUDA is available."""
# Mock torch.cuda.is_available to return True
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False)

assert define_device("auto") == "cuda"

def test_define_device_auto_without_cuda(self, monkeypatch):
    """'auto' falls back to 'cpu' when neither CUDA nor MPS is available."""
    # Force both accelerator backends to report unavailable.
    monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

    result = define_device("auto")
    assert result == "cpu"

def test_define_device_auto_with_mps(self, monkeypatch):
    """'auto' resolves to 'mps' when MPS is available but CUDA is not."""
    # CUDA off, MPS on: the MPS branch should win.
    monkeypatch.setattr(torch.backends.mps, "is_available", lambda: True)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

    result = define_device("auto")
    assert result == "mps"

def test_define_device_cuda_not_available(self, monkeypatch):
    """An explicit 'cuda' request degrades to 'cpu' when CUDA is unavailable."""
    # Neither backend is available; the explicit request cannot be honored.
    monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

    result = define_device("cuda")
    assert result == "cpu"

def test_define_device_cuda_available(self, monkeypatch):
    """An explicit 'cuda' request is honored when CUDA reports available."""
    # CUDA on, MPS off: the requested device passes through unchanged.
    monkeypatch.setattr(torch.backends.mps, "is_available", lambda: False)
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)

    result = define_device("cuda")
    assert result == "cuda"

Expand Down
Loading