26 changes: 26 additions & 0 deletions tests/tab_network/test_attentive_transformer.py
@@ -41,3 +41,29 @@ def test_invalid_mask_type():
            momentum=0.02,
            mask_type="invalid_mask_type",
        )


def test_attentive_transformer_device_movement():
    """Test that AttentiveTransformer moves to the correct device with the model."""
    input_dim = 8
    group_dim = 5
    group_matrix = torch.randint(0, 2, size=(group_dim, input_dim)).float()

    transformer = AttentiveTransformer(
        input_dim,
        group_dim,
        group_matrix,
        virtual_batch_size=2,
        momentum=0.02,
        mask_type="sparsemax",
    )

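    # Prefer CUDA, then Apple Silicon MPS, and fall back to CPU.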
    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    transformer = transformer.to(device)

    bs = 2
    priors = torch.rand((bs, group_dim)).to(device)
    processed_feat = torch.rand((bs, input_dim)).to(device)

    output = transformer.forward(priors, processed_feat)
    assert output.shape == (bs, group_dim)
26 changes: 26 additions & 0 deletions tests/tab_network/test_embedding_generator.py
@@ -117,3 +117,29 @@ def traced_forward(self, x):
    finally:
        # Restore the original method
        EmbeddingGenerator.forward = original_forward


def test_embedding_generator_device_movement():
    """Test that EmbeddingGenerator moves to the correct device with the model."""
    input_dim = 10
    cat_idxs = [0, 2]
    cat_dims = [3, 4]
    cat_emb_dims = [2, 2]
    group_matrix = torch.rand(2, input_dim)

    generator = EmbeddingGenerator(
        input_dim,
        cat_dims,
        cat_idxs,
        cat_emb_dims,
        group_matrix,
    )

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    generator = generator.to(device)

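    # Values in [0, 3) are valid indices for both categorical dims (3 and 4).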
    batch_size = 2
    x = torch.randint(0, 3, (batch_size, input_dim)).to(device)

    output = generator(x)
    assert output.shape[0] == batch_size
25 changes: 25 additions & 0 deletions tests/tab_network/test_feat_transformer.py
@@ -65,3 +65,28 @@ def test_feat_transformer_no_independent_layers():
    input_data = torch.rand((bs, input_dim))
    output = transformer.forward(input_data)
    assert output.shape == (bs, output_dim)


def test_feat_transformer_device_movement():
    """Test that FeatTransformer moves to the correct device with the model."""
    input_dim = 10
    output_dim = 8
    shared_layers = torch.nn.ModuleList([torch.nn.Linear(10, 16)])
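    # Linear(10, 16) feeds the GLU, which consumes 2 * output_dim features (value and gate halves).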

    transformer = FeatTransformer(
        input_dim,
        output_dim,
        shared_layers,
        n_glu_independent=2,
        virtual_batch_size=2,
        momentum=0.02,
    )

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    transformer = transformer.to(device)

    bs = 2
    input_data = torch.rand((bs, input_dim)).to(device)

    output = transformer.forward(input_data)
    assert output.shape == (bs, output_dim)
16 changes: 16 additions & 0 deletions tests/tab_network/test_gbn.py
@@ -27,3 +27,19 @@ def test_gbn():
    output_full = full_batch_gbn(y)

    assert output_full.shape == y.shape


def test_gbn_device_movement():
    """Test that GBN moves to the correct device with the model."""
    feature_dim = 16
    batch_size = 2
    virtual_batch_size = 2

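    # Ghost Batch Norm normalizes each virtual_batch_size chunk independently.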
    gbn = GBN(feature_dim, momentum=0.1, virtual_batch_size=virtual_batch_size)

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    gbn = gbn.to(device)

    x = torch.rand((batch_size, feature_dim)).to(device)
    output = gbn(x)
    assert output.shape == x.shape
40 changes: 40 additions & 0 deletions tests/tab_network/test_glu.py
@@ -38,3 +38,43 @@ def test_glu_block():
    # This should work now since we're using first=True
    output = glu_block(x)
    assert output.shape == (batch_size, output_dim)


def test_glu_layer_device_movement():
    """Test that GLU_Layer moves to the correct device with the model."""
    input_dim = 16
    output_dim = 8

    glu_layer = GLU_Layer(input_dim, output_dim, virtual_batch_size=2, momentum=0.02)

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    glu_layer = glu_layer.to(device)

    batch_size = 2
    x = torch.rand((batch_size, input_dim)).to(device)

    output = glu_layer(x)
    assert output.shape == (batch_size, output_dim)


def test_glu_block_device_movement():
    """Test that GLU_Block moves to the correct device with the model."""
    input_dim = 8
    output_dim = 8

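    # first=True means the block's first GLU layer transforms the raw input without the scaled residual.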
    glu_block = GLU_Block(
        input_dim,
        output_dim,
        first=True,
        virtual_batch_size=2,
        momentum=0.02,
    )

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    glu_block = glu_block.to(device)

    batch_size = 2
    x = torch.rand((batch_size, input_dim)).to(device)

    output = glu_block(x)
    assert output.shape == (batch_size, output_dim)
18 changes: 18 additions & 0 deletions tests/tab_network/test_random_obfuscator.py
@@ -15,3 +15,21 @@ def test_random_obfuscator():
    assert masked_x.shape == x.shape
    assert obfuscated_groups.shape == (bs, group_matrix.shape[0])
    assert obfuscated_vars.shape == x.shape


def test_random_obfuscator_device_movement():
    """Test that RandomObfuscator moves to the correct device with the model."""
    bs = 2
    input_dim = 16
    pretraining_ratio = 0.2
    group_matrix = torch.randint(0, 2, size=(5, input_dim)).float()

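    # Roughly pretraining_ratio of the feature groups get masked at random.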
    obfuscator = RandomObfuscator(pretraining_ratio, group_matrix)

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    obfuscator = obfuscator.to(device)

    x = torch.rand((bs, input_dim)).to(device)

    masked_x, obfuscated_groups, obfuscated_vars = obfuscator.forward(x)
    assert masked_x.shape == x.shape
31 changes: 31 additions & 0 deletions tests/tab_network/test_tabnet.py
@@ -134,3 +134,34 @@ def test_tabnet_validation_errors():
            n_shared=0,
            group_attention_matrix=group_matrix,
        )


def test_tabnet_device_movement():
    """Test that TabNet moves to the correct device with the model."""
    input_dim = 16
    output_dim = 8
    group_matrix = torch.rand((2, input_dim))

    tabnet = TabNet(
        input_dim=input_dim,
        output_dim=output_dim,
        n_d=8,
        n_a=8,
        n_steps=3,
        gamma=1.3,
        n_independent=2,
        n_shared=2,
        virtual_batch_size=2,
        momentum=0.02,
        mask_type="sparsemax",
        group_attention_matrix=group_matrix,
    )

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    tabnet = tabnet.to(device)

    batch_size = 2
    x = torch.rand((batch_size, input_dim)).to(device)

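    # M_loss is TabNet's sparsity regularization term, averaged over the attention steps.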
    out, M_loss = tabnet.forward(x)
    assert out.shape == (batch_size, output_dim)
26 changes: 26 additions & 0 deletions tests/tab_network/test_tabnet_decoder.py
@@ -41,3 +41,29 @@ def test_tabnet_decoder_no_shared():
    out = decoder.forward(steps_output)

    assert out.shape == (batch_size, input_dim)


def test_tabnet_decoder_device_movement():
    """Test that TabNetDecoder moves to the correct device with the model."""
    input_dim = 16
    n_d = 8
    n_steps = 3

    decoder = TabNetDecoder(
        input_dim=input_dim,
        n_d=n_d,
        n_steps=n_steps,
        n_independent=2,
        n_shared=2,
        virtual_batch_size=2,
        momentum=0.02,
    )

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    decoder = decoder.to(device)

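    # The decoder expects one (batch_size, n_d) tensor per step.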
    batch_size = 2
    steps_output = [torch.rand((batch_size, n_d)).to(device) for _ in range(n_steps)]

    out = decoder.forward(steps_output)
    assert out.shape == (batch_size, input_dim)
29 changes: 29 additions & 0 deletions tests/tab_network/test_tabnet_noembeddings.py
@@ -115,3 +115,32 @@ def test_tabnet_no_embeddings_multi_output():
    assert len(outputs) == len(output_dim)
    for i, dim in enumerate(output_dim):
        assert outputs[i].shape == (batch_size, dim)


def test_tabnet_no_embeddings_device_movement():
    """Test that TabNetNoEmbeddings moves to the correct device with the model."""
    input_dim = 16
    output_dim = 8

    tabnet_no_emb = TabNetNoEmbeddings(
        input_dim=input_dim,
        output_dim=output_dim,
        n_d=8,
        n_a=8,
        n_steps=3,
        gamma=1.3,
        n_independent=2,
        n_shared=2,
        virtual_batch_size=2,
        momentum=0.02,
        mask_type="sparsemax",
    )

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    tabnet_no_emb = tabnet_no_emb.to(device)

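    # TabNetNoEmbeddings takes already-numeric input; no embedding lookup happens here.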
    batch_size = 2
    x = torch.rand((batch_size, input_dim)).to(device)

    out, M_loss = tabnet_no_emb.forward(x)
    assert out.shape == (batch_size, output_dim)
31 changes: 31 additions & 0 deletions tests/tab_network/test_tabnet_pretraining.py
@@ -238,3 +238,34 @@ def test_tabnet_pretraining_forward_masks():
    assert explanation_mask.shape == (batch_size, input_dim)
    assert isinstance(masks_dict, dict)
    assert len(masks_dict) == tabnet.n_steps


def test_tabnet_pretraining_device_movement():
    """Test that TabNetPretraining moves to the correct device with the model."""
    input_dim = 16
    pretraining_ratio = 0.2
    group_matrix = torch.rand((2, input_dim))

    tabnet_pretraining = TabNetPretraining(
        input_dim=input_dim,
        pretraining_ratio=pretraining_ratio,
        n_d=8,
        n_a=8,
        n_steps=3,
        gamma=1.3,
        n_independent=2,
        n_shared=2,
        virtual_batch_size=2,
        momentum=0.02,
        mask_type="sparsemax",
        group_attention_matrix=group_matrix,
    )

    device = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    tabnet_pretraining = tabnet_pretraining.to(device)

    batch_size = 2
    x = torch.rand((batch_size, input_dim)).to(device)

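    # forward returns the reconstruction, the embedded input, and the obfuscation mask.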
    reconstructed, embedded_x, obf_vars = tabnet_pretraining.forward(x)
    assert reconstructed.shape[0] == batch_size