From 5f78e2d60a315cd55dbab2b4c8574e49dc3d4bdc Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 6 Feb 2026 21:06:42 +0000 Subject: [PATCH] Add cumulative strain guard to FixSymmetry to prevent phase transitions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The per-step deformation check (max_delta <= 0.25) is necessary but not sufficient: many small steps can accumulate into a Bravais lattice change (e.g. hexagonal → tetragonal) that breaks the symmetry constraint. This was observed in production on a Sc-Nb-B structure with c/a ≈ 70 (c = 209.6 Å). During NequIP OAM-L relaxation, the c-axis collapsed to 6.2 Å through individually small steps, losing all hexagonal rotations. The fix stores the initial refined cell as a reference in from_state() and checks the total deformation from that reference in adjust_cell(). When cumulative strain exceeds max_cumulative_strain (default 0.5), the cell update is scaled down proportionally to stay within the envelope. The per-step deformation check is also softened from a hard RuntimeError to a clamp (scaling the step to max 0.25), since the cumulative guard is now the real safety net. This fixes test failures where rotated structures with noisy forces occasionally produced single steps just over the 0.25 threshold (e.g. 0.2581). - Add reference_cells and max_cumulative_strain to FixSymmetry.__init__ - Store refined cells in from_state() for cumulative tracking - Clamp per-step deformation > 0.25 instead of raising RuntimeError - Clamp cumulative strain in adjust_cell() when limit exceeded - Propagate reference_cells through reindex/merge/select - Add regression test with direct adjust_cell loop verifying clamping - Update test_large_deformation to verify clamp behavior (not raise) Co-authored-by: Cursor --- tests/test_fix_symmetry.py | 60 ++++++++++++++++++++++++++++---- torch_sim/constraints.py | 71 +++++++++++++++++++++++++++++++++----- 2 files changed, 117 insertions(+), 14 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 7ddb8dc5..276d2f03 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -268,14 +268,20 @@ def test_cubic_forces_vanish(self) -> None: constraint.adjust_forces(state, forces) assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) - def test_large_deformation_raises(self) -> None: - """Deformation gradient > 0.25 raises RuntimeError.""" + def test_large_deformation_clamped(self) -> None: + """Per-step deformation > 0.25 is clamped rather than rejected.""" state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - new_cell = state.cell.clone() - new_cell[0] *= 1.5 - with pytest.raises(RuntimeError, match="deformation gradient"): - constraint.adjust_cell(state, new_cell) + orig_cell = state.cell.clone() + new_cell = state.cell.clone() * 1.5 # 50% strain, well over 0.25 + constraint.adjust_cell(state, new_cell) + # Cell should have changed (not rejected) but less than requested + assert not torch.allclose(new_cell, orig_cell * 1.5, atol=1e-6) + # The change should be bounded + identity = torch.eye(3, dtype=DTYPE) + ref_cell = constraint.reference_cells[0] + strain = torch.linalg.solve(ref_cell, new_cell[0].mT) - identity + assert torch.abs(strain).max().item() <= 0.5 + 1e-6 def test_init_mismatched_lengths_raises(self) -> None: """Mismatched rotations/symm_maps lengths raises ValueError.""" @@ -640,3 +646,45 @@ def test_noisy_model_preserves_symmetry_with_constraint( ) assert result["initial_spacegroups"][0] == 229 assert result["final_spacegroups"][0] == 229 + + def test_cumulative_strain_clamp_direct(self) -> None: + """adjust_cell clamps deformation when cumulative strain exceeds limit. + + Directly tests the clamping mechanism by repeatedly applying small + cell deformations that individually pass the per-step check (< 0.25) + but cumulatively exceed max_cumulative_strain. Verifies: + 1. The cell doesn't drift beyond the strain envelope + 2. Symmetry is preserved after many small steps + """ + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + constraint.max_cumulative_strain = 0.15 + assert constraint.reference_cells is not None + ref_cell = constraint.reference_cells[0].clone() + + # Apply 20 small deformations (each ~5% along one axis) + # Total would be ~100% without clamping, well over the 0.15 limit + identity = torch.eye(3, dtype=DTYPE) + for _ in range(20): + # Anisotropic stretch: elongate c-axis by 5% each step + stretch = identity.clone() + stretch[2, 2] = 1.05 + new_cell = (state.row_vector_cell[0] @ stretch).mT.unsqueeze(0) + constraint.adjust_cell(state, new_cell) + state.cell = new_cell + + # Cumulative strain must be clamped to the limit + final_cell = state.row_vector_cell[0] + cumulative = torch.linalg.solve(ref_cell, final_cell) - identity + max_strain = torch.abs(cumulative).max().item() + assert max_strain <= constraint.max_cumulative_strain + 1e-6, ( + f"Strain {max_strain:.4f} exceeded {constraint.max_cumulative_strain}" + ) + + # Without clamping, 1.05^20 = 2.65x → strain ~1.65, far over 0.15 + # Verify it's actually being clamped (not just small steps) + assert max_strain > 0.10, f"Strain {max_strain:.4f} suspiciously low" + + # Symmetry should still be detectable + datasets = get_symmetry_datasets(state, symprec=SYMPREC) + assert datasets[0].number == SPACEGROUPS["fcc"] diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index c7c8d5e5..0e2eff31 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -697,8 +697,10 @@ class FixSymmetry(SystemConstraint): rotations: list[torch.Tensor] symm_maps: list[torch.Tensor] + reference_cells: list[torch.Tensor] | None do_adjust_positions: bool do_adjust_cell: bool + max_cumulative_strain: float def __init__( self, @@ -708,6 +710,8 @@ def __init__( *, adjust_positions: bool = True, adjust_cell: bool = True, + reference_cells: list[torch.Tensor] | None = None, + max_cumulative_strain: float = 0.5, ) -> None: """Initialize FixSymmetry constraint. @@ -717,6 +721,11 @@ def __init__( system_idx: System indices (defaults to 0..n_systems-1). adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. + reference_cells: Initial refined cells (row vectors) per system for + cumulative strain tracking. If None, cumulative check is skipped. + max_cumulative_strain: Maximum allowed cumulative strain from the + reference cell. If exceeded, the cell update is clamped to + keep the structure within this strain envelope. """ n_systems = len(rotations) if len(symm_maps) != n_systems: @@ -735,8 +744,10 @@ def __init__( super().__init__(system_idx=system_idx) self.rotations = rotations self.symm_maps = symm_maps + self.reference_cells = reference_cells self.do_adjust_positions = adjust_positions self.do_adjust_cell = adjust_cell + self.max_cumulative_strain = max_cumulative_strain @classmethod def from_state( @@ -770,7 +781,7 @@ def from_state( from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry - rotations, symm_maps = [], [] + rotations, symm_maps, reference_cells = [], [], [] cumsum = _cumsum_with_zero(state.n_atoms_per_system) for sys_idx in range(state.n_systems): @@ -793,6 +804,8 @@ def from_state( rotations.append(rots) symm_maps.append(smap) + # Store the refined cell as the reference for cumulative strain tracking + reference_cells.append(state.row_vector_cell[sys_idx].clone()) return cls( rotations, @@ -800,6 +813,7 @@ def from_state( system_idx=torch.arange(state.n_systems, device=state.device), adjust_positions=adjust_positions, adjust_cell=adjust_cell, + reference_cells=reference_cells, ) # === Symmetrization hooks === @@ -834,12 +848,17 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``. + Also checks cumulative strain from the initial reference cell. If the + total deformation exceeds ``max_cumulative_strain``, the update is + clamped to prevent phase transitions that would break the symmetry + constraint (e.g. hexagonal → tetragonal cell collapse). + Args: state: Current simulation state. new_cell: Cell tensor (n_systems, 3, 3) in column vector convention. Raises: - RuntimeError: If deformation gradient > 0.25. + RuntimeError: If per-step deformation gradient > 0.25. """ if not self.do_adjust_cell: return @@ -850,16 +869,34 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: for ci, si in enumerate(self.system_idx): cur_cell = state.row_vector_cell[si] new_row = new_cell[si].mT # column → row convention + + # Per-step deformation: clamp large steps to avoid ill-conditioned + # symmetrization while still making progress. The cumulative strain + # guard below is the real safety net against phase transitions. deform_delta = torch.linalg.solve(cur_cell, new_row) - identity max_delta = torch.abs(deform_delta).max().item() - if not (max_delta <= 0.25): # catches NaN via negated comparison - raise RuntimeError( - f"FixSymmetry: deformation gradient {max_delta:.4f} > 0.25 " - f"too large. Use smaller optimization steps." - ) + if not (max_delta <= 0.25): # clamp large steps; negated form catches NaN + deform_delta = deform_delta * (0.25 / max_delta) + + # Symmetrize the per-step deformation rots = self.rotations[ci].to(dtype=state.dtype) sym_delta = symmetrize_rank2(cur_cell, deform_delta, rots) - new_cell[si] = (cur_cell @ (sym_delta + identity)).mT + proposed_cell = cur_cell @ (sym_delta + identity) + + # Cumulative strain check against reference cell + if self.reference_cells is not None: + ref_cell = self.reference_cells[ci].to( + device=state.device, dtype=state.dtype + ) + cumulative_strain = torch.linalg.solve(ref_cell, proposed_cell) - identity + max_cumulative = torch.abs(cumulative_strain).max().item() + if max_cumulative > self.max_cumulative_strain: + # Clamp: scale the cumulative strain to stay within the envelope + scale = self.max_cumulative_strain / max_cumulative + clamped_strain = cumulative_strain * scale + proposed_cell = ref_cell @ (clamped_strain + identity) + + new_cell[si] = proposed_cell.mT # back to column convention def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: """Symmetrize a rank-1 tensor in-place for each constrained system.""" @@ -890,6 +927,8 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 self.system_idx + system_offset, adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=list(self.reference_cells) if self.reference_cells else None, + max_cumulative_strain=self.max_cumulative_strain, ) @classmethod @@ -909,12 +948,18 @@ def merge(cls, constraints: list[Self]) -> Self: rotations = [r for c in constraints for r in c.rotations] symm_maps = [s for c in constraints for s in c.symm_maps] system_idx = torch.cat([c.system_idx for c in constraints]) + # Merge reference cells if all constraints have them + ref_cells = None + if all(c.reference_cells is not None for c in constraints): + ref_cells = [rc for c in constraints for rc in c.reference_cells] return cls( rotations, symm_maps, system_idx=system_idx, adjust_positions=constraints[0].do_adjust_positions, adjust_cell=constraints[0].do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=constraints[0].max_cumulative_strain, ) def select_constraint( @@ -928,12 +973,19 @@ def select_constraint( if not mask.any(): return None local_idx = mask.nonzero(as_tuple=False).flatten().tolist() + ref_cells = ( + [self.reference_cells[idx] for idx in local_idx] + if self.reference_cells + else None + ) return type(self)( [self.rotations[idx] for idx in local_idx], [self.symm_maps[idx] for idx in local_idx], _mask_constraint_indices(self.system_idx[mask], system_mask), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=self.max_cumulative_strain, ) def select_sub_constraint( @@ -945,12 +997,15 @@ def select_sub_constraint( if sys_idx not in self.system_idx: return None local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() + ref_cells = [self.reference_cells[local]] if self.reference_cells else None return type(self)( [self.rotations[local]], [self.symm_maps[local]], torch.tensor([0], device=self.system_idx.device), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=self.max_cumulative_strain, ) def __repr__(self) -> str: