diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 7ddb8dc5..276d2f03 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -268,14 +268,20 @@ def test_cubic_forces_vanish(self) -> None: constraint.adjust_forces(state, forces) assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) - def test_large_deformation_raises(self) -> None: - """Deformation gradient > 0.25 raises RuntimeError.""" + def test_large_deformation_clamped(self) -> None: + """Per-step deformation > 0.25 is clamped rather than rejected.""" state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - new_cell = state.cell.clone() - new_cell[0] *= 1.5 - with pytest.raises(RuntimeError, match="deformation gradient"): - constraint.adjust_cell(state, new_cell) + orig_cell = state.cell.clone() + new_cell = state.cell.clone() * 1.5 # 50% strain, well over 0.25 + constraint.adjust_cell(state, new_cell) + # Cell should have changed (not rejected) but less than requested + assert not torch.allclose(new_cell, orig_cell * 1.5, atol=1e-6) + # The change should be bounded + identity = torch.eye(3, dtype=DTYPE) + ref_cell = constraint.reference_cells[0] + strain = torch.linalg.solve(ref_cell, new_cell[0].mT) - identity + assert torch.abs(strain).max().item() <= 0.5 + 1e-6 def test_init_mismatched_lengths_raises(self) -> None: """Mismatched rotations/symm_maps lengths raises ValueError.""" @@ -640,3 +646,45 @@ def test_noisy_model_preserves_symmetry_with_constraint( ) assert result["initial_spacegroups"][0] == 229 assert result["final_spacegroups"][0] == 229 + + def test_cumulative_strain_clamp_direct(self) -> None: + """adjust_cell clamps deformation when cumulative strain exceeds limit. + + Directly tests the clamping mechanism by repeatedly applying small + cell deformations that individually pass the per-step check (< 0.25) + but cumulatively exceed max_cumulative_strain. Verifies: + 1. The cell doesn't drift beyond the strain envelope + 2. Symmetry is preserved after many small steps + """ + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + constraint.max_cumulative_strain = 0.15 + assert constraint.reference_cells is not None + ref_cell = constraint.reference_cells[0].clone() + + # Apply 20 small deformations (each ~5% along one axis) + # Total would be ~100% without clamping, well over the 0.15 limit + identity = torch.eye(3, dtype=DTYPE) + for _ in range(20): + # Anisotropic stretch: elongate c-axis by 5% each step + stretch = identity.clone() + stretch[2, 2] = 1.05 + new_cell = (state.row_vector_cell[0] @ stretch).mT.unsqueeze(0) + constraint.adjust_cell(state, new_cell) + state.cell = new_cell + + # Cumulative strain must be clamped to the limit + final_cell = state.row_vector_cell[0] + cumulative = torch.linalg.solve(ref_cell, final_cell) - identity + max_strain = torch.abs(cumulative).max().item() + assert max_strain <= constraint.max_cumulative_strain + 1e-6, ( + f"Strain {max_strain:.4f} exceeded {constraint.max_cumulative_strain}" + ) + + # Without clamping, 1.05^20 = 2.65x → strain ~1.65, far over 0.15 + # Verify it's actually being clamped (not just small steps) + assert max_strain > 0.10, f"Strain {max_strain:.4f} suspiciously low" + + # Symmetry should still be detectable + datasets = get_symmetry_datasets(state, symprec=SYMPREC) + assert datasets[0].number == SPACEGROUPS["fcc"] diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index c7c8d5e5..0e2eff31 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -697,8 +697,10 @@ class FixSymmetry(SystemConstraint): rotations: list[torch.Tensor] symm_maps: list[torch.Tensor] + reference_cells: list[torch.Tensor] | None do_adjust_positions: bool do_adjust_cell: bool + max_cumulative_strain: float def __init__( self, @@ -708,6 +710,8 @@ def __init__( *, adjust_positions: bool = True, adjust_cell: bool = True, + reference_cells: list[torch.Tensor] | None = None, + max_cumulative_strain: float = 0.5, ) -> None: """Initialize FixSymmetry constraint. @@ -717,6 +721,11 @@ def __init__( system_idx: System indices (defaults to 0..n_systems-1). adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. + reference_cells: Initial refined cells (row vectors) per system for + cumulative strain tracking. If None, cumulative check is skipped. + max_cumulative_strain: Maximum allowed cumulative strain from the + reference cell. If exceeded, the cell update is clamped to + keep the structure within this strain envelope. """ n_systems = len(rotations) if len(symm_maps) != n_systems: @@ -735,8 +744,10 @@ def __init__( super().__init__(system_idx=system_idx) self.rotations = rotations self.symm_maps = symm_maps + self.reference_cells = reference_cells self.do_adjust_positions = adjust_positions self.do_adjust_cell = adjust_cell + self.max_cumulative_strain = max_cumulative_strain @classmethod def from_state( @@ -770,7 +781,7 @@ def from_state( from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry - rotations, symm_maps = [], [] + rotations, symm_maps, reference_cells = [], [], [] cumsum = _cumsum_with_zero(state.n_atoms_per_system) for sys_idx in range(state.n_systems): @@ -793,6 +804,8 @@ def from_state( rotations.append(rots) symm_maps.append(smap) + # Store the refined cell as the reference for cumulative strain tracking + reference_cells.append(state.row_vector_cell[sys_idx].clone()) return cls( rotations, @@ -800,6 +813,7 @@ def from_state( system_idx=torch.arange(state.n_systems, device=state.device), adjust_positions=adjust_positions, adjust_cell=adjust_cell, + reference_cells=reference_cells, ) # === Symmetrization hooks === @@ -834,12 +848,17 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``. + Also checks cumulative strain from the initial reference cell. If the + total deformation exceeds ``max_cumulative_strain``, the update is + clamped to prevent phase transitions that would break the symmetry + constraint (e.g. hexagonal → tetragonal cell collapse). + Args: state: Current simulation state. new_cell: Cell tensor (n_systems, 3, 3) in column vector convention. Raises: - RuntimeError: If deformation gradient > 0.25. + RuntimeError: If per-step deformation gradient > 0.25. """ if not self.do_adjust_cell: return @@ -850,16 +869,34 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: for ci, si in enumerate(self.system_idx): cur_cell = state.row_vector_cell[si] new_row = new_cell[si].mT # column → row convention + + # Per-step deformation: clamp large steps to avoid ill-conditioned + # symmetrization while still making progress. The cumulative strain + # guard below is the real safety net against phase transitions. deform_delta = torch.linalg.solve(cur_cell, new_row) - identity max_delta = torch.abs(deform_delta).max().item() - if not (max_delta <= 0.25): # catches NaN via negated comparison - raise RuntimeError( - f"FixSymmetry: deformation gradient {max_delta:.4f} > 0.25 " - f"too large. Use smaller optimization steps." - ) + if not (max_delta <= 0.25): # clamp large steps; negated form catches NaN + deform_delta = deform_delta * (0.25 / max_delta) + + # Symmetrize the per-step deformation rots = self.rotations[ci].to(dtype=state.dtype) sym_delta = symmetrize_rank2(cur_cell, deform_delta, rots) - new_cell[si] = (cur_cell @ (sym_delta + identity)).mT + proposed_cell = cur_cell @ (sym_delta + identity) + + # Cumulative strain check against reference cell + if self.reference_cells is not None: + ref_cell = self.reference_cells[ci].to( + device=state.device, dtype=state.dtype + ) + cumulative_strain = torch.linalg.solve(ref_cell, proposed_cell) - identity + max_cumulative = torch.abs(cumulative_strain).max().item() + if max_cumulative > self.max_cumulative_strain: + # Clamp: scale the cumulative strain to stay within the envelope + scale = self.max_cumulative_strain / max_cumulative + clamped_strain = cumulative_strain * scale + proposed_cell = ref_cell @ (clamped_strain + identity) + + new_cell[si] = proposed_cell.mT # back to column convention def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: """Symmetrize a rank-1 tensor in-place for each constrained system.""" @@ -890,6 +927,8 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 self.system_idx + system_offset, adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=list(self.reference_cells) if self.reference_cells else None, + max_cumulative_strain=self.max_cumulative_strain, ) @classmethod @@ -909,12 +948,18 @@ def merge(cls, constraints: list[Self]) -> Self: rotations = [r for c in constraints for r in c.rotations] symm_maps = [s for c in constraints for s in c.symm_maps] system_idx = torch.cat([c.system_idx for c in constraints]) + # Merge reference cells if all constraints have them + ref_cells = None + if all(c.reference_cells is not None for c in constraints): + ref_cells = [rc for c in constraints for rc in c.reference_cells] return cls( rotations, symm_maps, system_idx=system_idx, adjust_positions=constraints[0].do_adjust_positions, adjust_cell=constraints[0].do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=constraints[0].max_cumulative_strain, ) def select_constraint( @@ -928,12 +973,19 @@ def select_constraint( if not mask.any(): return None local_idx = mask.nonzero(as_tuple=False).flatten().tolist() + ref_cells = ( + [self.reference_cells[idx] for idx in local_idx] + if self.reference_cells + else None + ) return type(self)( [self.rotations[idx] for idx in local_idx], [self.symm_maps[idx] for idx in local_idx], _mask_constraint_indices(self.system_idx[mask], system_mask), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=self.max_cumulative_strain, ) def select_sub_constraint( @@ -945,12 +997,15 @@ def select_sub_constraint( if sys_idx not in self.system_idx: return None local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() + ref_cells = [self.reference_cells[local]] if self.reference_cells else None return type(self)( [self.rotations[local]], [self.symm_maps[local]], torch.tensor([0], device=self.system_idx.device), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=self.max_cumulative_strain, ) def __repr__(self) -> str: