Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,56 @@ def test_state_set_cell(ti_sim_state: SimState) -> None:
assert torch.allclose(
ti_sim_state.positions.cpu(), torch.from_numpy(ase_atoms.positions)
)


def test_wrap_positions_no_pbc(si_sim_state: SimState) -> None:
    """wrap_positions must be a no-op when every periodic boundary is off."""
    state = si_sim_state.clone()
    state.pbc = torch.tensor([False] * 3)
    # Push every atom far outside the cell; with pbc disabled nothing
    # should pull them back in.
    state.positions = state.positions + 100.0
    assert torch.allclose(state.wrap_positions, state.positions)


def test_wrap_positions_with_pbc(si_sim_state: SimState) -> None:
    """A shift by one full lattice vector is undone by wrap_positions
    when the system is periodic in every direction.
    """
    state = si_sim_state.clone()
    state.pbc = torch.tensor([True] * 3)
    reference = state.positions.clone()
    # Translate all atoms by the first lattice vector of the (only) system.
    state.positions = state.positions + state.row_vector_cell[0, 0]
    # Wrapping should map the atoms back onto their original sites.
    assert torch.allclose(state.wrap_positions, reference, atol=1e-5)


def test_wrap_positions_mixed_pbc(si_sim_state: SimState) -> None:
    """With pbc only along x and z, a y-lattice shift must survive wrapping."""
    state = si_sim_state.clone()
    state.pbc = torch.tensor([True, False, True])  # aperiodic along y
    reference = state.positions.clone()
    cell_rows = state.row_vector_cell[0]  # lattice vectors of the single system
    # Translate every atom by all three lattice vectors at once.
    state.positions = state.positions + cell_rows[0] + cell_rows[1] + cell_rows[2]
    # The x and z components wrap back; the y shift is kept since y is
    # not periodic.
    assert torch.allclose(
        state.wrap_positions, reference + cell_rows[1], atol=1e-5
    )


def test_wrap_positions_batched(si_double_sim_state: SimState) -> None:
    """wrap_positions handles multiple systems, each wrapped by its own cell."""
    state = si_double_sim_state.clone()
    state.pbc = torch.tensor([True] * 3)
    reference = state.positions.clone()
    # Per system, translate its atoms by that system's first lattice vector.
    for idx in range(state.n_systems):
        atoms = state.system_idx == idx
        state.positions[atoms] = (
            state.positions[atoms] + state.row_vector_cell[idx, 0]
        )
    # Every atom should be wrapped back to its original position.
    assert torch.allclose(state.wrap_positions, reference, atol=1e-5)
7 changes: 5 additions & 2 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,11 @@ def wrap_positions(self) -> torch.Tensor:
"""Atomic positions wrapped according to periodic boundary conditions if pbc=True,
otherwise returns unwrapped positions with shape (n_atoms, 3).
"""
# TODO: implement a wrapping method
return self.positions
if not self.pbc.any():
return self.positions
return ts.transforms.pbc_wrap_batched(
self.positions, self.cell, self.system_idx, self.pbc
)

@property
def device(self) -> torch.device:
Expand Down