Skip to content

Fix MPS device movement for constant tensors in TabNet modules#279

Merged
DanielAvdar merged 6 commits intoadd-simple-tests-tab-networkfrom
copilot/fix-device-movement-issue
Nov 2, 2025
Merged

Fix MPS device movement for constant tensors in TabNet modules#279
DanielAvdar merged 6 commits intoadd-simple-tests-tab-networkfrom
copilot/fix-device-movement-issue

Conversation

Copy link
Contributor

Copilot AI commented Nov 2, 2025

Tensor device mismatch errors occur on macOS MPS backend because constant tensors (group_matrix, embedding_group_matrix) are not registered as buffers, preventing automatic device movement during .to(device) calls.

Changes

  • pytorch_tabnet/utils/device.py: Added MPS device detection with priority ordering: MPS → CUDA → CPU
  • pytorch_tabnet/tab_network/random_obfuscator.py: Register group_matrix as buffer via self.register_buffer()
  • pytorch_tabnet/tab_network/embedding_generator.py: Register embedding_group_matrix as buffer in both skip-embedding and with-embedding paths

Example

Before (device mismatch on MPS):

obfuscator = RandomObfuscator(0.2, group_matrix)
obfuscator = obfuscator.to('mps')
x = torch.rand((2, 16), device='mps')
obfuscator(x)  # RuntimeError: mat2 is on CPU, expected MPS

After (correct device movement):

obfuscator = RandomObfuscator(0.2, group_matrix)
obfuscator = obfuscator.to('mps')
x = torch.rand((2, 16), device='mps')
obfuscator(x)  # ✓ All tensors on MPS
Original prompt

This section details on the original issue you should resolve

<issue_title>Device movement issues and tensor mismatch on macOS MPS device</issue_title>
<issue_description>## Summary

Running tests and models on macOS with PyTorch MPS (Apple Silicon GPU) backend results in errors related to tensor device mismatch, specifically during matrix multiplication operations (e.g., torch.mm).

Problem Details

  • Tests such as test_random_obfuscator_device_movement and test_tabnet_pretraining_device_movement fail when running on macOS MPS due to one tensor being on CPU and the other on MPS device.
  • The error message is:

RuntimeError: Tensor for argument #2 'mat2' is on CPU, but expected it to be on MPS (while checking arguments for mm)

Cause

  • Certain constant tensors (such as group matrices or attention matrices) are not always properly moved to the same device as model parameters and input tensors.
  • This commonly happens on MPS because device logic in some places only checks for "cuda" and "cpu".
  • All tensors used in matrix ops must be on the same device; PyTorch's MPS backend enforces this strictly.

Proposed Solution

  • Register constant tensors (e.g., group_attention_matrix) as buffers with self.register_buffer(...) so they move with .to(device).
  • When creating new tensors in forward or test code, specify device (e.g., torch.ones(..., device=x.device)).
  • Update device detection logic (such as in define_device) to handle 'mps' properly:
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
  • Before any matrix multiply, ensure both operands are on the same device:
mat2 = mat2.to(mat1.device)
result = torch.mm(mat1, mat2)

Related code

Label

bug

tl;dr: Device movement issues on macOS MPS due to constant tensors and device logic. Register tensors as buffers, create new tensors on correct device, and improve device detection for 'mps'.</issue_description>

Comments on the Issue (you are @copilot in this section)


✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.

Copilot AI and others added 3 commits November 2, 2025 16:20
Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>
Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix device movement issues with PyTorch on macOS MPS Fix MPS device movement for constant tensors in TabNet modules Nov 2, 2025
Copilot AI requested a review from DanielAvdar November 2, 2025 16:25
Removed MPS device check from auto device selection.

Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com>
Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com>
@DanielAvdar DanielAvdar marked this pull request as ready for review November 2, 2025 16:53
@DanielAvdar DanielAvdar merged commit 0f10578 into add-simple-tests-tab-network Nov 2, 2025
24 checks passed
@DanielAvdar DanielAvdar deleted the copilot/fix-device-movement-issue branch November 2, 2025 16:53
DanielAvdar added a commit that referenced this pull request Nov 2, 2025
* Initial plan

* Add device movement tests for all tab_network modules

Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>

* Fix MPS device movement for constant tensors in TabNet modules (#279)

* Initial plan

* Fix MPS device movement issues by registering tensors as buffers

Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>

* Final verification: all tests pass, no security issues

Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>

* Remove spurious nul file

* Simplify device selection logic

Removed MPS device check from auto device selection.

Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com>

* Fix device selection logic for 'auto' case

Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com>

---------

Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: DanielAvdar <66269169+DanielAvdar@users.noreply.github.com>

---------

Signed-off-by: Daniel Avdar <66269169+DanielAvdar@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments