From 210d377940e6bd1b130268edd8a35e1c5b4d9c8c Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sun, 2 Nov 2025 13:22:03 +0000
Subject: [PATCH 1/2] Initial plan

From efe0be55c91a25ca49153e4d7169e479deddf94a Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sun, 2 Nov 2025 13:29:33 +0000
Subject: [PATCH 2/2] Add device movement tests for all tab_network modules

Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>
---
 .../tab_network/test_attentive_transformer.py | 26 ++++++++++++
 tests/tab_network/test_embedding_generator.py | 26 ++++++++++++
 tests/tab_network/test_feat_transformer.py    | 25 ++++++++++++
 tests/tab_network/test_gbn.py                 | 16 ++++++++
 tests/tab_network/test_glu.py                 | 40 +++++++++++++++++++
 tests/tab_network/test_random_obfuscator.py   | 18 +++++++++
 tests/tab_network/test_tabnet.py              | 31 ++++++++++++++
 tests/tab_network/test_tabnet_decoder.py      | 26 ++++++++++++
 tests/tab_network/test_tabnet_noembeddings.py | 29 ++++++++++++++
 tests/tab_network/test_tabnet_pretraining.py  | 31 ++++++++++++++
 10 files changed, 268 insertions(+)

diff --git a/tests/tab_network/test_attentive_transformer.py b/tests/tab_network/test_attentive_transformer.py
index c498489..e91e7e8 100644
--- a/tests/tab_network/test_attentive_transformer.py
+++ b/tests/tab_network/test_attentive_transformer.py
@@ -41,3 +41,29 @@ def test_invalid_mask_type():
             momentum=0.02,
             mask_type="invalid_mask_type",
         )
+
+
+def test_attentive_transformer_device_movement():
+    """Test that AttentiveTransformer moves to the correct device with the model."""
+    input_dim = 8
+    group_dim = 5
+    group_matrix = torch.randint(0, 2, size=(group_dim, input_dim)).float()
+
+    transformer = AttentiveTransformer(
+        input_dim,
+        group_dim,
+        group_matrix,
+        virtual_batch_size=2,
+        momentum=0.02,
+        mask_type="sparsemax",
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    transformer = transformer.to(device)
+
+    bs = 2
+    priors = torch.rand((bs, group_dim)).to(device)
+    processed_feat = torch.rand((bs, input_dim)).to(device)
+
+    output = transformer.forward(priors, processed_feat)
+    assert output.shape == (bs, group_dim)
diff --git a/tests/tab_network/test_embedding_generator.py b/tests/tab_network/test_embedding_generator.py
index f10cd66..a1a793e 100644
--- a/tests/tab_network/test_embedding_generator.py
+++ b/tests/tab_network/test_embedding_generator.py
@@ -117,3 +117,29 @@ def traced_forward(self, x):
     finally:
         # Restore the original method
         EmbeddingGenerator.forward = original_forward
+
+
+def test_embedding_generator_device_movement():
+    """Test that EmbeddingGenerator moves to the correct device with the model."""
+    input_dim = 10
+    cat_idxs = [0, 2]
+    cat_dims = [3, 4]
+    cat_emb_dims = [2, 2]
+    group_matrix = torch.rand(2, input_dim)
+
+    generator = EmbeddingGenerator(
+        input_dim,
+        cat_dims,
+        cat_idxs,
+        cat_emb_dims,
+        group_matrix,
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    generator = generator.to(device)
+
+    batch_size = 2
+    x = torch.randint(0, 3, (batch_size, input_dim)).to(device)
+
+    output = generator(x)
+    assert output.shape[0] == batch_size
diff --git a/tests/tab_network/test_feat_transformer.py b/tests/tab_network/test_feat_transformer.py
index 9369f85..6521d96 100644
--- a/tests/tab_network/test_feat_transformer.py
+++ b/tests/tab_network/test_feat_transformer.py
@@ -65,3 +65,28 @@ def test_feat_transformer_no_independent_layers():
     input_data = torch.rand((bs, input_dim))
     output = transformer.forward(input_data)
     assert output.shape == (bs, output_dim)
+
+
+def test_feat_transformer_device_movement():
+    """Test that FeatTransformer moves to the correct device with the model."""
+    input_dim = 10
+    output_dim = 8
+    shared_layers = torch.nn.ModuleList([torch.nn.Linear(10, 16)])
+
+    transformer = FeatTransformer(
+        input_dim,
+        output_dim,
+        shared_layers,
+        n_glu_independent=2,
+        virtual_batch_size=2,
+        momentum=0.02,
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    transformer = transformer.to(device)
+
+    bs = 2
+    input_data = torch.rand((bs, input_dim)).to(device)
+
+    output = transformer.forward(input_data)
+    assert output.shape == (bs, output_dim)
diff --git a/tests/tab_network/test_gbn.py b/tests/tab_network/test_gbn.py
index 6a2f78c..7a77c1b 100644
--- a/tests/tab_network/test_gbn.py
+++ b/tests/tab_network/test_gbn.py
@@ -27,3 +27,19 @@ def test_gbn():
     output_full = full_batch_gbn(y)
 
     assert output_full.shape == y.shape
+
+
+def test_gbn_device_movement():
+    """Test that GBN moves to the correct device with the model."""
+    feature_dim = 16
+    batch_size = 2
+    virtual_batch_size = 2
+
+    gbn = GBN(feature_dim, momentum=0.1, virtual_batch_size=virtual_batch_size)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    gbn = gbn.to(device)
+
+    x = torch.rand((batch_size, feature_dim)).to(device)
+    output = gbn(x)
+    assert output.shape == x.shape
diff --git a/tests/tab_network/test_glu.py b/tests/tab_network/test_glu.py
index 4732b6e..4d6810e 100644
--- a/tests/tab_network/test_glu.py
+++ b/tests/tab_network/test_glu.py
@@ -38,3 +38,43 @@ def test_glu_block():
     # This should work now since we're using first=True
     output = glu_block(x)
     assert output.shape == (batch_size, output_dim)
+
+
+def test_glu_layer_device_movement():
+    """Test that GLU_Layer moves to the correct device with the model."""
+    input_dim = 16
+    output_dim = 8
+
+    glu_layer = GLU_Layer(input_dim, output_dim, virtual_batch_size=2, momentum=0.02)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    glu_layer = glu_layer.to(device)
+
+    batch_size = 2
+    x = torch.rand((batch_size, input_dim)).to(device)
+
+    output = glu_layer(x)
+    assert output.shape == (batch_size, output_dim)
+
+
+def test_glu_block_device_movement():
+    """Test that GLU_Block moves to the correct device with the model."""
+    input_dim = 8
+    output_dim = 8
+
+    glu_block = GLU_Block(
+        input_dim,
+        output_dim,
+        first=True,
+        virtual_batch_size=2,
+        momentum=0.02,
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    glu_block = glu_block.to(device)
+
+    batch_size = 2
+    x = torch.rand((batch_size, input_dim)).to(device)
+
+    output = glu_block(x)
+    assert output.shape == (batch_size, output_dim)
diff --git a/tests/tab_network/test_random_obfuscator.py b/tests/tab_network/test_random_obfuscator.py
index 6ec2513..db24755 100644
--- a/tests/tab_network/test_random_obfuscator.py
+++ b/tests/tab_network/test_random_obfuscator.py
@@ -15,3 +15,21 @@ def test_random_obfuscator():
     assert masked_x.shape == x.shape
     assert obfuscated_groups.shape == (bs, group_matrix.shape[0])
     assert obfuscated_vars.shape == x.shape
+
+
+def test_random_obfuscator_device_movement():
+    """Test that RandomObfuscator moves to the correct device with the model."""
+    bs = 2
+    input_dim = 16
+    pretraining_ratio = 0.2
+    group_matrix = torch.randint(0, 2, size=(5, input_dim)).float()
+
+    obfuscator = RandomObfuscator(pretraining_ratio, group_matrix)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    obfuscator = obfuscator.to(device)
+
+    x = torch.rand((bs, input_dim)).to(device)
+
+    masked_x, obfuscated_groups, obfuscated_vars = obfuscator.forward(x)
+    assert masked_x.shape == x.shape
diff --git a/tests/tab_network/test_tabnet.py b/tests/tab_network/test_tabnet.py
index c90922e..86127cc 100644
--- a/tests/tab_network/test_tabnet.py
+++ b/tests/tab_network/test_tabnet.py
@@ -134,3 +134,34 @@ def test_tabnet_validation_errors():
             n_shared=0,
             group_attention_matrix=group_matrix,
         )
+
+
+def test_tabnet_device_movement():
+    """Test that TabNet moves to the correct device with the model."""
+    input_dim = 16
+    output_dim = 8
+    group_matrix = torch.rand((2, input_dim))
+
+    tabnet = TabNet(
+        input_dim=input_dim,
+        output_dim=output_dim,
+        n_d=8,
+        n_a=8,
+        n_steps=3,
+        gamma=1.3,
+        n_independent=2,
+        n_shared=2,
+        virtual_batch_size=2,
+        momentum=0.02,
+        mask_type="sparsemax",
+        group_attention_matrix=group_matrix,
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    tabnet = tabnet.to(device)
+
+    batch_size = 2
+    x = torch.rand((batch_size, input_dim)).to(device)
+
+    out, M_loss = tabnet.forward(x)
+    assert out.shape == (batch_size, output_dim)
diff --git a/tests/tab_network/test_tabnet_decoder.py b/tests/tab_network/test_tabnet_decoder.py
index eb93de5..f3af902 100644
--- a/tests/tab_network/test_tabnet_decoder.py
+++ b/tests/tab_network/test_tabnet_decoder.py
@@ -41,3 +41,29 @@ def test_tabnet_decoder_no_shared():
     out = decoder.forward(steps_output)
 
     assert out.shape == (batch_size, input_dim)
+
+
+def test_tabnet_decoder_device_movement():
+    """Test that TabNetDecoder moves to the correct device with the model."""
+    input_dim = 16
+    n_d = 8
+    n_steps = 3
+
+    decoder = TabNetDecoder(
+        input_dim=input_dim,
+        n_d=n_d,
+        n_steps=n_steps,
+        n_independent=2,
+        n_shared=2,
+        virtual_batch_size=2,
+        momentum=0.02,
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    decoder = decoder.to(device)
+
+    batch_size = 2
+    steps_output = [torch.rand((batch_size, n_d)).to(device) for _ in range(n_steps)]
+
+    out = decoder.forward(steps_output)
+    assert out.shape == (batch_size, input_dim)
diff --git a/tests/tab_network/test_tabnet_noembeddings.py b/tests/tab_network/test_tabnet_noembeddings.py
index 955e837..7a5ffa0 100644
--- a/tests/tab_network/test_tabnet_noembeddings.py
+++ b/tests/tab_network/test_tabnet_noembeddings.py
@@ -115,3 +115,32 @@ def test_tabnet_no_embeddings_multi_output():
     assert len(outputs) == len(output_dim)
     for i, dim in enumerate(output_dim):
         assert outputs[i].shape == (batch_size, dim)
+
+
+def test_tabnet_no_embeddings_device_movement():
+    """Test that TabNetNoEmbeddings moves to the correct device with the model."""
+    input_dim = 16
+    output_dim = 8
+
+    tabnet_no_emb = TabNetNoEmbeddings(
+        input_dim=input_dim,
+        output_dim=output_dim,
+        n_d=8,
+        n_a=8,
+        n_steps=3,
+        gamma=1.3,
+        n_independent=2,
+        n_shared=2,
+        virtual_batch_size=2,
+        momentum=0.02,
+        mask_type="sparsemax",
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    tabnet_no_emb = tabnet_no_emb.to(device)
+
+    batch_size = 2
+    x = torch.rand((batch_size, input_dim)).to(device)
+
+    out, M_loss = tabnet_no_emb.forward(x)
+    assert out.shape == (batch_size, output_dim)
diff --git a/tests/tab_network/test_tabnet_pretraining.py b/tests/tab_network/test_tabnet_pretraining.py
index b71a0d7..e1a0a41 100644
--- a/tests/tab_network/test_tabnet_pretraining.py
+++ b/tests/tab_network/test_tabnet_pretraining.py
@@ -238,3 +238,34 @@ def test_tabnet_pretraining_forward_masks():
     assert explanation_mask.shape == (batch_size, input_dim)
     assert isinstance(masks_dict, dict)
     assert len(masks_dict) == tabnet.n_steps
+
+
+def test_tabnet_pretraining_device_movement():
+    """Test that TabNetPretraining moves to the correct device with the model."""
+    input_dim = 16
+    pretraining_ratio = 0.2
+    group_matrix = torch.rand((2, input_dim))
+
+    tabnet_pretraining = TabNetPretraining(
+        input_dim=input_dim,
+        pretraining_ratio=pretraining_ratio,
+        n_d=8,
+        n_a=8,
+        n_steps=3,
+        gamma=1.3,
+        n_independent=2,
+        n_shared=2,
+        virtual_batch_size=2,
+        momentum=0.02,
+        mask_type="sparsemax",
+        group_attention_matrix=group_matrix,
+    )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    tabnet_pretraining = tabnet_pretraining.to(device)
+
+    batch_size = 2
+    x = torch.rand((batch_size, input_dim)).to(device)
+
+    reconstructed, embedded_x, obf_vars = tabnet_pretraining.forward(x)
+    assert reconstructed.shape[0] == batch_size