Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions tests/test_fix_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,20 @@ def test_cubic_forces_vanish(self) -> None:
constraint.adjust_forces(state, forces)
assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10)

def test_large_deformation_raises(self) -> None:
"""Deformation gradient > 0.25 raises RuntimeError."""
def test_large_deformation_clamped(self) -> None:
"""Per-step deformation > 0.25 is clamped rather than rejected."""
state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
new_cell = state.cell.clone()
new_cell[0] *= 1.5
with pytest.raises(RuntimeError, match="deformation gradient"):
constraint.adjust_cell(state, new_cell)
orig_cell = state.cell.clone()
new_cell = state.cell.clone() * 1.5 # 50% strain, well over 0.25
constraint.adjust_cell(state, new_cell)
# Cell should have changed (not rejected) but less than requested
assert not torch.allclose(new_cell, orig_cell * 1.5, atol=1e-6)
# The change should be bounded
identity = torch.eye(3, dtype=DTYPE)
ref_cell = constraint.reference_cells[0]
strain = torch.linalg.solve(ref_cell, new_cell[0].mT) - identity
assert torch.abs(strain).max().item() <= 0.5 + 1e-6

def test_init_mismatched_lengths_raises(self) -> None:
"""Mismatched rotations/symm_maps lengths raises ValueError."""
Expand Down Expand Up @@ -640,3 +646,45 @@ def test_noisy_model_preserves_symmetry_with_constraint(
)
assert result["initial_spacegroups"][0] == 229
assert result["final_spacegroups"][0] == 229

def test_cumulative_strain_clamp_direct(self) -> None:
"""adjust_cell clamps deformation when cumulative strain exceeds limit.

Directly tests the clamping mechanism by repeatedly applying small
cell deformations that individually pass the per-step check (< 0.25)
but cumulatively exceed max_cumulative_strain. Verifies:
1. The cell doesn't drift beyond the strain envelope
2. Symmetry is preserved after many small steps
"""
state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
constraint.max_cumulative_strain = 0.15
assert constraint.reference_cells is not None
ref_cell = constraint.reference_cells[0].clone()

# Apply 20 small deformations (each ~5% along one axis)
# Total would be ~100% without clamping, well over the 0.15 limit
identity = torch.eye(3, dtype=DTYPE)
for _ in range(20):
# Anisotropic stretch: elongate c-axis by 5% each step
stretch = identity.clone()
stretch[2, 2] = 1.05
new_cell = (state.row_vector_cell[0] @ stretch).mT.unsqueeze(0)
constraint.adjust_cell(state, new_cell)
state.cell = new_cell

# Cumulative strain must be clamped to the limit
final_cell = state.row_vector_cell[0]
cumulative = torch.linalg.solve(ref_cell, final_cell) - identity
max_strain = torch.abs(cumulative).max().item()
assert max_strain <= constraint.max_cumulative_strain + 1e-6, (
f"Strain {max_strain:.4f} exceeded {constraint.max_cumulative_strain}"
)

# Without clamping, 1.05^20 = 2.65x → strain ~1.65, far over 0.15
# Verify it's actually being clamped (not just small steps)
assert max_strain > 0.10, f"Strain {max_strain:.4f} suspiciously low"

# Symmetry should still be detectable
datasets = get_symmetry_datasets(state, symprec=SYMPREC)
assert datasets[0].number == SPACEGROUPS["fcc"]
71 changes: 63 additions & 8 deletions torch_sim/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,10 @@ class FixSymmetry(SystemConstraint):

rotations: list[torch.Tensor]
symm_maps: list[torch.Tensor]
reference_cells: list[torch.Tensor] | None
do_adjust_positions: bool
do_adjust_cell: bool
max_cumulative_strain: float

def __init__(
self,
Expand All @@ -708,6 +710,8 @@ def __init__(
*,
adjust_positions: bool = True,
adjust_cell: bool = True,
reference_cells: list[torch.Tensor] | None = None,
max_cumulative_strain: float = 0.5,
) -> None:
"""Initialize FixSymmetry constraint.

Expand All @@ -717,6 +721,11 @@ def __init__(
system_idx: System indices (defaults to 0..n_systems-1).
adjust_positions: Whether to symmetrize position displacements.
adjust_cell: Whether to symmetrize cell/stress adjustments.
reference_cells: Initial refined cells (row vectors) per system for
cumulative strain tracking. If None, cumulative check is skipped.
max_cumulative_strain: Maximum allowed cumulative strain from the
reference cell. If exceeded, the cell update is clamped to
keep the structure within this strain envelope.
"""
n_systems = len(rotations)
if len(symm_maps) != n_systems:
Expand All @@ -735,8 +744,10 @@ def __init__(
super().__init__(system_idx=system_idx)
self.rotations = rotations
self.symm_maps = symm_maps
self.reference_cells = reference_cells
self.do_adjust_positions = adjust_positions
self.do_adjust_cell = adjust_cell
self.max_cumulative_strain = max_cumulative_strain

@classmethod
def from_state(
Expand Down Expand Up @@ -770,7 +781,7 @@ def from_state(

from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry

rotations, symm_maps = [], []
rotations, symm_maps, reference_cells = [], [], []
cumsum = _cumsum_with_zero(state.n_atoms_per_system)

for sys_idx in range(state.n_systems):
Expand All @@ -793,13 +804,16 @@ def from_state(

rotations.append(rots)
symm_maps.append(smap)
# Store the refined cell as the reference for cumulative strain tracking
reference_cells.append(state.row_vector_cell[sys_idx].clone())

return cls(
rotations,
symm_maps,
system_idx=torch.arange(state.n_systems, device=state.device),
adjust_positions=adjust_positions,
adjust_cell=adjust_cell,
reference_cells=reference_cells,
)

# === Symmetrization hooks ===
Expand Down Expand Up @@ -834,12 +848,17 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None:
Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a
rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``.

Also checks cumulative strain from the initial reference cell. If the
total deformation exceeds ``max_cumulative_strain``, the update is
clamped to prevent phase transitions that would break the symmetry
constraint (e.g. hexagonal → tetragonal cell collapse).

Args:
state: Current simulation state.
new_cell: Cell tensor (n_systems, 3, 3) in column vector convention.

Raises:
RuntimeError: If deformation gradient > 0.25.
RuntimeError: If per-step deformation gradient > 0.25.
"""
if not self.do_adjust_cell:
return
Expand All @@ -850,16 +869,34 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None:
for ci, si in enumerate(self.system_idx):
cur_cell = state.row_vector_cell[si]
new_row = new_cell[si].mT # column → row convention

# Per-step deformation: clamp large steps to avoid ill-conditioned
# symmetrization while still making progress. The cumulative strain
# guard below is the real safety net against phase transitions.
deform_delta = torch.linalg.solve(cur_cell, new_row) - identity
max_delta = torch.abs(deform_delta).max().item()
if not (max_delta <= 0.25): # catches NaN via negated comparison
raise RuntimeError(
f"FixSymmetry: deformation gradient {max_delta:.4f} > 0.25 "
f"too large. Use smaller optimization steps."
)
if not (max_delta <= 0.25): # clamp large steps; negated form catches NaN
deform_delta = deform_delta * (0.25 / max_delta)

# Symmetrize the per-step deformation
rots = self.rotations[ci].to(dtype=state.dtype)
sym_delta = symmetrize_rank2(cur_cell, deform_delta, rots)
new_cell[si] = (cur_cell @ (sym_delta + identity)).mT
proposed_cell = cur_cell @ (sym_delta + identity)

# Cumulative strain check against reference cell
if self.reference_cells is not None:
ref_cell = self.reference_cells[ci].to(
device=state.device, dtype=state.dtype
)
cumulative_strain = torch.linalg.solve(ref_cell, proposed_cell) - identity
max_cumulative = torch.abs(cumulative_strain).max().item()
if max_cumulative > self.max_cumulative_strain:
# Clamp: scale the cumulative strain to stay within the envelope
scale = self.max_cumulative_strain / max_cumulative
clamped_strain = cumulative_strain * scale
proposed_cell = ref_cell @ (clamped_strain + identity)

new_cell[si] = proposed_cell.mT # back to column convention

def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None:
"""Symmetrize a rank-1 tensor in-place for each constrained system."""
Expand Down Expand Up @@ -890,6 +927,8 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002
self.system_idx + system_offset,
adjust_positions=self.do_adjust_positions,
adjust_cell=self.do_adjust_cell,
reference_cells=list(self.reference_cells) if self.reference_cells else None,
max_cumulative_strain=self.max_cumulative_strain,
)

@classmethod
Expand All @@ -909,12 +948,18 @@ def merge(cls, constraints: list[Self]) -> Self:
rotations = [r for c in constraints for r in c.rotations]
symm_maps = [s for c in constraints for s in c.symm_maps]
system_idx = torch.cat([c.system_idx for c in constraints])
# Merge reference cells if all constraints have them
ref_cells = None
if all(c.reference_cells is not None for c in constraints):
ref_cells = [rc for c in constraints for rc in c.reference_cells]
return cls(
rotations,
symm_maps,
system_idx=system_idx,
adjust_positions=constraints[0].do_adjust_positions,
adjust_cell=constraints[0].do_adjust_cell,
reference_cells=ref_cells,
max_cumulative_strain=constraints[0].max_cumulative_strain,
)

def select_constraint(
Expand All @@ -928,12 +973,19 @@ def select_constraint(
if not mask.any():
return None
local_idx = mask.nonzero(as_tuple=False).flatten().tolist()
ref_cells = (
[self.reference_cells[idx] for idx in local_idx]
if self.reference_cells
else None
)
return type(self)(
[self.rotations[idx] for idx in local_idx],
[self.symm_maps[idx] for idx in local_idx],
_mask_constraint_indices(self.system_idx[mask], system_mask),
adjust_positions=self.do_adjust_positions,
adjust_cell=self.do_adjust_cell,
reference_cells=ref_cells,
max_cumulative_strain=self.max_cumulative_strain,
)

def select_sub_constraint(
Expand All @@ -945,12 +997,15 @@ def select_sub_constraint(
if sys_idx not in self.system_idx:
return None
local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item()
ref_cells = [self.reference_cells[local]] if self.reference_cells else None
return type(self)(
[self.rotations[local]],
[self.symm_maps[local]],
torch.tensor([0], device=self.system_idx.device),
adjust_positions=self.do_adjust_positions,
adjust_cell=self.do_adjust_cell,
reference_cells=ref_cells,
max_cumulative_strain=self.max_cumulative_strain,
)

def __repr__(self) -> str:
Expand Down
Loading