From 240ed7448cc9fb1af7554284d48c0466aad73d7c Mon Sep 17 00:00:00 2001 From: praagnya Date: Fri, 21 Nov 2025 22:01:12 -0700 Subject: [PATCH 1/2] Implement SimState.wrap_positions using pbc_wrap_batched --- torch_sim/state.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index c59c2efd3..e5d5c8c53 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -167,8 +167,11 @@ def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, otherwise returns unwrapped positions with shape (n_atoms, 3). """ - # TODO: implement a wrapping method - return self.positions + if not self.pbc: + return self.positions + return ts.transforms.pbc_wrap_batched( + self.positions, self.cell, self.system_idx, self.pbc + ) @property def device(self) -> torch.device: From 4de1c87f9857e12ac86a64dde1c082433605c492 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 5 Feb 2026 09:18:03 -0500 Subject: [PATCH 2/2] test: fix addition and add tests to wrapped positions property --- tests/test_state.py | 53 +++++++++++++++++++++++++++++++++++++++++++++ torch_sim/state.py | 2 +- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/test_state.py b/tests/test_state.py index d489af390..ee568406e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -681,3 +681,56 @@ def test_state_set_cell(ti_sim_state: SimState) -> None: assert torch.allclose( ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions) ) + + +def test_wrap_positions_no_pbc(si_sim_state: SimState) -> None: + """Test wrap_positions returns unwrapped positions when pbc=False.""" + state = si_sim_state.clone() + state.pbc = torch.tensor([False, False, False]) + # Move some atoms outside the cell + state.positions = state.positions + 100.0 + # With no pbc, wrap_positions should return positions unchanged + assert torch.allclose(state.wrap_positions, state.positions) + + +def test_wrap_positions_with_pbc(si_sim_state: SimState) -> None: + """Test wrap_positions wraps positions when pbc=True.""" + state = si_sim_state.clone() + state.pbc = torch.tensor([True, True, True]) + original_positions = state.positions.clone() + # Add one lattice vector to move atoms outside + lattice_shift = state.row_vector_cell[0, 0] # first lattice vector + state.positions = state.positions + lattice_shift + # Wrapped positions should be back to original (within tolerance) + wrapped = state.wrap_positions + assert torch.allclose(wrapped, original_positions, atol=1e-5) + + +def test_wrap_positions_mixed_pbc(si_sim_state: SimState) -> None: + """Test wrap_positions with mixed pbc (True in some dimensions, False in others).""" + state = si_sim_state.clone() + state.pbc = torch.tensor([True, False, True]) # periodic in x and z, not y + original_positions = state.positions.clone() + # Shift by lattice vectors in all directions + shift_x = state.row_vector_cell[0, 0] # first lattice vector (x) + shift_y = state.row_vector_cell[0, 1] # second lattice vector (y) + shift_z = state.row_vector_cell[0, 2] # third lattice vector (z) + state.positions = state.positions + shift_x + shift_y + shift_z + wrapped = state.wrap_positions + # x and z should be wrapped back, y should remain shifted + expected = original_positions + shift_y + assert torch.allclose(wrapped, expected, atol=1e-5) + + +def test_wrap_positions_batched(si_double_sim_state: SimState) -> None: + """Test wrap_positions works with batched systems.""" + state = si_double_sim_state.clone() + state.pbc = torch.tensor([True, True, True]) + original_positions = state.positions.clone() + # Shift all positions by one lattice vector (using first system's cell) + for sys_idx in range(state.n_systems): + mask = state.system_idx == sys_idx + lattice_shift = state.row_vector_cell[sys_idx, 0] + state.positions[mask] = state.positions[mask] + lattice_shift + wrapped = state.wrap_positions + assert torch.allclose(wrapped, original_positions, atol=1e-5) diff --git a/torch_sim/state.py b/torch_sim/state.py index 83e3fb504..b1406cb73 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -197,7 +197,7 @@ def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, otherwise returns unwrapped positions with shape (n_atoms, 3). """ - if not self.pbc: + if not self.pbc.any(): return self.positions return ts.transforms.pbc_wrap_batched( self.positions, self.cell, self.system_idx, self.pbc