diff --git a/tests/test_state.py b/tests/test_state.py index d489af390..ee568406e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -681,3 +681,56 @@ def test_state_set_cell(ti_sim_state: SimState) -> None: assert torch.allclose( ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions) ) + + +def test_wrap_positions_no_pbc(si_sim_state: SimState) -> None: + """Test wrap_positions returns unwrapped positions when pbc=False.""" + state = si_sim_state.clone() + state.pbc = torch.tensor([False, False, False]) + # Move some atoms outside the cell + state.positions = state.positions + 100.0 + # With no pbc, wrap_positions should return positions unchanged + assert torch.allclose(state.wrap_positions, state.positions) + + +def test_wrap_positions_with_pbc(si_sim_state: SimState) -> None: + """Test wrap_positions wraps positions when pbc=True.""" + state = si_sim_state.clone() + state.pbc = torch.tensor([True, True, True]) + original_positions = state.positions.clone() + # Add one lattice vector to move atoms outside + lattice_shift = state.row_vector_cell[0, 0] # first lattice vector + state.positions = state.positions + lattice_shift + # Wrapped positions should be back to original (within tolerance) + wrapped = state.wrap_positions + assert torch.allclose(wrapped, original_positions, atol=1e-5) + + +def test_wrap_positions_mixed_pbc(si_sim_state: SimState) -> None: + """Test wrap_positions with mixed pbc (True in some dimensions, False in others).""" + state = si_sim_state.clone() + state.pbc = torch.tensor([True, False, True]) # periodic in x and z, not y + original_positions = state.positions.clone() + # Shift by lattice vectors in all directions + shift_x = state.row_vector_cell[0, 0] # first lattice vector (x) + shift_y = state.row_vector_cell[0, 1] # second lattice vector (y) + shift_z = state.row_vector_cell[0, 2] # third lattice vector (z) + state.positions = state.positions + shift_x + shift_y + shift_z + wrapped = state.wrap_positions + # x and z should be wrapped back, y should remain shifted + expected = original_positions + shift_y + assert torch.allclose(wrapped, expected, atol=1e-5) + + +def test_wrap_positions_batched(si_double_sim_state: SimState) -> None: + """Test wrap_positions works with batched systems.""" + state = si_double_sim_state.clone() + state.pbc = torch.tensor([True, True, True]) + original_positions = state.positions.clone() + # Shift all positions by one lattice vector (using first system's cell) + for sys_idx in range(state.n_systems): + mask = state.system_idx == sys_idx + lattice_shift = state.row_vector_cell[sys_idx, 0] + state.positions[mask] = state.positions[mask] + lattice_shift + wrapped = state.wrap_positions + assert torch.allclose(wrapped, original_positions, atol=1e-5) diff --git a/torch_sim/state.py b/torch_sim/state.py index f465938c4..b1406cb73 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -197,8 +197,11 @@ def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, otherwise returns unwrapped positions with shape (n_atoms, 3). """ - # TODO: implement a wrapping method - return self.positions + if not self.pbc.any(): + return self.positions + return ts.transforms.pbc_wrap_batched( + self.positions, self.cell, self.system_idx, self.pbc + ) @property def device(self) -> torch.device: