diff --git a/pyproject.toml b/pyproject.toml index d8fd246b..1428e67b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,8 +44,11 @@ test = [ "pymatgen>=2025.6.14", "pytest-cov>=6", "pytest>=8", + "moyopy>=0.3", + "spglib>=2.6", ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] +symmetry = ["moyopy>=0.3"] mace = ["mace-torch>=0.3.14"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py new file mode 100644 index 00000000..7ddb8dc5 --- /dev/null +++ b/tests/test_fix_symmetry.py @@ -0,0 +1,642 @@ +"""Tests for the FixSymmetry constraint.""" + +import numpy as np +import pytest +import torch +from ase import Atoms +from ase.build import bulk +from ase.constraints import FixSymmetry as ASEFixSymmetry +from ase.spacegroup.symmetrize import refine_symmetry as ase_refine_symmetry +from ase.stress import full_3x3_to_voigt_6_stress, voigt_6_to_full_3x3_stress +from pymatgen.core import Lattice, Structure +from pymatgen.io.ase import AseAtomsAdaptor + +import torch_sim as ts +from torch_sim.constraints import FixCom, FixSymmetry +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.symmetrize import get_symmetry_datasets + + +pytest.importorskip("moyopy") +pytest.importorskip("spglib") # needed by ASE's FixSymmetry + +SPACEGROUPS = {"fcc": 225, "hcp": 194, "diamond": 227, "bcc": 229, "p6bar": 174} +MAX_STEPS = 30 +DTYPE = torch.float64 +SYMPREC = 0.01 +CPU = torch.device("cpu") + + +# === Structure helpers === + + +def _make_p6bar() -> Atoms: + """Create P-6 (space group 174) structure.""" + lattice = Lattice.hexagonal(a=3.0, c=5.0) + structure = Structure.from_spacegroup( + sg=174, lattice=lattice, species=["Si"], coords=[[0.3, 0.1, 0.25]] + ) + return AseAtomsAdaptor.get_atoms(structure) + + +def make_structure(name: str) -> Atoms: + """Create a test structure by name (fcc/hcp/diamond/bcc/p6bar + _rotated suffix).""" + base = name.replace("_rotated", "") + builders = { + "fcc": lambda: bulk("Cu", "fcc", a=3.6), + "hcp": lambda: bulk("Ti", "hcp", a=2.95, c=4.68), + "diamond": lambda: bulk("Si", "diamond", a=5.43), + "bcc": lambda: bulk("Al", "bcc", a=2 / np.sqrt(3), cubic=True), + "p6bar": _make_p6bar, + } + atoms = builders[base]() + if "_rotated" in name: + rotation_product = np.eye(3) + for axis_idx in range(3): + axes = list(range(3)) + axes.remove(axis_idx) + row_idx, col_idx = axes + rot_mat = np.eye(3) + theta = 0.1 * (axis_idx + 1) + rot_mat[row_idx, row_idx] = np.cos(theta) + rot_mat[col_idx, col_idx] = np.cos(theta) + rot_mat[row_idx, col_idx] = np.sin(theta) + rot_mat[col_idx, row_idx] = -np.sin(theta) + rotation_product = np.dot(rotation_product, rot_mat) + atoms.set_cell(atoms.cell @ rotation_product, scale_atoms=True) + return atoms + + +# === Fixtures === + + +@pytest.fixture +def model() -> LennardJonesModel: + """LJ model for testing.""" + return LennardJonesModel( + sigma=1.0, + epsilon=0.05, + cutoff=6.0, + use_neighbor_list=False, + compute_stress=True, + dtype=DTYPE, + ) + + +class NoisyModelWrapper: + """Wrapper that adds noise to forces and stress.""" + + model: LennardJonesModel + rng: np.random.Generator + noise_scale: float + + def __init__(self, model: LennardJonesModel, noise_scale: float = 1e-4) -> None: + self.model = model + self.rng = np.random.default_rng(seed=1) + self.noise_scale = noise_scale + + @property + def device(self) -> torch.device: + return self.model.device + + @property + def dtype(self) -> torch.dtype: + return self.model.dtype + + def __call__(self, state: ts.SimState) -> dict[str, torch.Tensor]: + """Forward pass with added noise.""" + results = self.model(state) + for key in ("forces", "stress"): + if key in results: + noise = torch.tensor( + self.rng.normal(size=results[key].shape), + dtype=results[key].dtype, + device=results[key].device, + ) + results[key] = results[key] + self.noise_scale * noise + return results + + +@pytest.fixture +def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper: + """LJ model with noise added to forces/stress.""" + return NoisyModelWrapper(model) + + +@pytest.fixture +def p6bar_both_constraints() -> tuple[ts.SimState, FixSymmetry, Atoms, ASEFixSymmetry]: + """P-6 structure with both TorchSim and ASE constraints (shared setup).""" + atoms = make_structure("p6bar") + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + ase_atoms = atoms.copy() + ase_refine_symmetry(ase_atoms, symprec=SYMPREC) + ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) + return state, ts_constraint, ase_atoms, ase_constraint + + +# === Optimization helper === + + +def run_optimization_check_symmetry( + state: ts.SimState, + model: LennardJonesModel | NoisyModelWrapper, + constraint: FixSymmetry | None = None, + *, + adjust_cell: bool = True, + max_steps: int = MAX_STEPS, + force_tol: float = 0.001, +) -> dict[str, list[int | None]]: + """Run FIRE optimization and return initial/final space group numbers.""" + initial = get_symmetry_datasets(state, SYMPREC) + if constraint is not None: + state.constraints = [constraint] + init_kwargs = {"cell_filter": ts.CellFilter.frechet} if adjust_cell else None + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, + include_cell_forces=adjust_cell, + ) + final_state = ts.optimize( + system=state, + model=model, + optimizer=ts.Optimizer.fire, + convergence_fn=convergence_fn, + init_kwargs=init_kwargs, + max_steps=max_steps, + steps_between_swaps=1, + ) + final = get_symmetry_datasets(final_state, SYMPREC) + return { + "initial_spacegroups": [d.number if d else None for d in initial], + "final_spacegroups": [d.number if d else None for d in final], + } + + +# === Tests: Creation & Basics === + + +class TestFixSymmetryCreation: + """Tests for FixSymmetry creation and basic behavior.""" + + def test_from_state_batched(self) -> None: + """Batched state with FCC + diamond gets correct ops, atom counts, and DOF.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + assert len(constraint.rotations) == 2 + assert constraint.rotations[0].shape[0] == 48 # cubic + assert constraint.symm_maps[0].shape == (48, 1) # Cu: 1 atom + assert constraint.symm_maps[1].shape == (48, 2) # Si: 2 atoms + assert torch.all(constraint.get_removed_dof(state) == 0) + + def test_p1_identity_is_noop(self) -> None: + """P1 structure has 1 op and symmetrization is a no-op for forces and stress.""" + atoms = Atoms( + "SiGe", + positions=[[0.1, 0.2, 0.3], [1.1, 0.9, 1.3]], + cell=[[3.0, 0.1, 0.2], [0.15, 3.5, 0.1], [0.2, 0.15, 4.0]], + pbc=True, + ) + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + assert constraint.rotations[0].shape[0] == 1 + + forces = torch.randn(2, 3, dtype=DTYPE) + orig_forces = forces.clone() + constraint.adjust_forces(state, forces) + assert torch.allclose(forces, orig_forces, atol=1e-10) + + stress = torch.randn(1, 3, 3, dtype=DTYPE) + stress = (stress + stress.mT) / 2 + orig_stress = stress.clone() + constraint.adjust_stress(state, stress) + assert torch.allclose(stress, orig_stress, atol=1e-10) + + @pytest.mark.parametrize("refine", [True, False]) + def test_from_state_refine_symmetry(self, *, refine: bool) -> None: + """With refine=False state is unmodified; with refine=True it may change.""" + atoms = make_structure("fcc") + rng = np.random.default_rng(42) + atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + orig_pos = state.positions.clone() + _ = FixSymmetry.from_state(state, symprec=SYMPREC, refine_symmetry_state=refine) + if not refine: + assert torch.allclose(state.positions, orig_pos) + + @pytest.mark.parametrize("structure_name", ["fcc", "hcp", "diamond", "p6bar"]) + def test_refine_symmetry_produces_correct_spacegroup( + self, + structure_name: str, + ) -> None: + """Perturbed structure recovers correct spacegroup after refinement.""" + from torch_sim.symmetrize import refine_symmetry + + atoms = make_structure(structure_name) + expected = SPACEGROUPS[structure_name] + rng = np.random.default_rng(42) + atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + + refined_cell, refined_pos = refine_symmetry( + state.row_vector_cell[0], + state.positions, + state.atomic_numbers, + symprec=SYMPREC, + ) + state.cell[0] = refined_cell.mT + state.positions = refined_pos + + datasets = get_symmetry_datasets(state, symprec=1e-4) + assert datasets[0].number == expected + + def test_cubic_forces_vanish(self) -> None: + """Asymmetric force on single cubic atom symmetrizes to zero.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + forces = torch.tensor( + [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], + dtype=DTYPE, + ) + constraint.adjust_forces(state, forces) + assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) + + def test_large_deformation_raises(self) -> None: + """Deformation gradient > 0.25 raises RuntimeError.""" + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + new_cell = state.cell.clone() + new_cell[0] *= 1.5 + with pytest.raises(RuntimeError, match="deformation gradient"): + constraint.adjust_cell(state, new_cell) + + def test_init_mismatched_lengths_raises(self) -> None: + """Mismatched rotations/symm_maps lengths raises ValueError.""" + rots = [torch.eye(3).unsqueeze(0)] + smaps = [torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 2, dtype=torch.long)] + with pytest.raises(ValueError, match="length mismatch"): + FixSymmetry(rots, smaps) + + @pytest.mark.parametrize("method", ["adjust_positions", "adjust_cell"]) + def test_adjust_skipped_when_disabled(self, method: str) -> None: + """adjust_positions=False / adjust_cell=False leaves data unchanged.""" + flag = method.replace("adjust_", "") # "positions" or "cell" + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state( + state, + symprec=SYMPREC, + **{f"adjust_{flag}": False}, + ) + if method == "adjust_positions": + data = state.positions.clone() + 0.1 + else: + data = state.cell.clone() * 1.01 + expected = data.clone() + getattr(constraint, method)(state, data) + assert torch.equal(data, expected) + + +# === Tests: Comparison with ASE === + + +class TestFixSymmetryComparisonWithASE: + """Compare TorchSim FixSymmetry with ASE's implementation on P-6 structure.""" + + def test_force_symmetrization_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Force symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints + rng = np.random.default_rng(42) + forces_np = rng.standard_normal((len(ase_atoms), 3)) + forces_ts = torch.tensor(forces_np.copy(), dtype=DTYPE) + ts_c.adjust_forces(state, forces_ts) + ase_c.adjust_forces(ase_atoms, forces_np) + assert np.allclose(forces_ts.numpy(), forces_np, atol=1e-10) + + def test_stress_symmetrization_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Stress symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints + stress_3x3 = np.array([[10.0, 1.0, 0.5], [1.0, 8.0, 0.3], [0.5, 0.3, 6.0]]) + stress_voigt = full_3x3_to_voigt_6_stress(stress_3x3).copy() + stress_ts = torch.tensor([stress_3x3.copy()], dtype=DTYPE) + ts_c.adjust_stress(state, stress_ts) + ase_c.adjust_stress(ase_atoms, stress_voigt) + assert np.allclose( + stress_ts[0].numpy(), + voigt_6_to_full_3x3_stress(stress_voigt), + atol=1e-10, + ) + + def test_cell_deformation_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Cell deformation symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints + deformed = ase_atoms.get_cell().copy() + deformed[0, 1] += 0.05 + new_cell_ts = torch.tensor([deformed.copy().T], dtype=DTYPE) + ts_c.adjust_cell(state, new_cell_ts) + ase_cell = deformed.copy() + ase_c.adjust_cell(ase_atoms, ase_cell) + assert np.allclose(new_cell_ts[0].mT.numpy(), ase_cell, atol=1e-10) + + def test_position_symmetrization_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Position displacement symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints + rng = np.random.default_rng(42) + disp = rng.standard_normal((len(ase_atoms), 3)) * 0.01 + new_pos_ts = state.positions.clone() + torch.tensor(disp, dtype=DTYPE) + new_pos_ase = ase_atoms.positions.copy() + disp + ts_c.adjust_positions(state, new_pos_ts) + ase_c.adjust_positions(ase_atoms, new_pos_ase) + assert np.allclose(new_pos_ts.numpy(), new_pos_ase, atol=1e-10) + + +# === Tests: Merge, Select, Reindex === + + +class TestFixSymmetryMergeSelectReindex: + """Tests for reindex/merge API, select, and concatenation.""" + + def test_reindex_preserves_symmetry_data(self) -> None: + """reindex shifts system_idx but preserves rotations and symm_maps.""" + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + orig = FixSymmetry.from_state(state, symprec=SYMPREC) + shifted = orig.reindex(atom_offset=100, system_offset=5) + assert shifted.system_idx.item() == 5 + assert torch.equal(shifted.rotations[0], orig.rotations[0]) + assert torch.equal(shifted.symm_maps[0], orig.symm_maps[0]) + + def test_merge_two_constraints(self) -> None: + """Merge two single-system constraints via reindex + merge.""" + s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE) + c1 = FixSymmetry.from_state(s1) + c2 = FixSymmetry.from_state(s2).reindex(atom_offset=0, system_offset=1) + merged = FixSymmetry.merge([c1, c2]) + assert len(merged.rotations) == 2 + assert merged.system_idx.tolist() == [0, 1] + + def test_merge_multi_system_no_duplicate_indices(self) -> None: + """Regression: multi-system constraints must use cumulative offsets.""" + atoms_a = [ + make_structure("fcc"), + make_structure("diamond"), + make_structure("hcp"), + ] + atoms_b = [make_structure("bcc"), make_structure("fcc")] + c_a = FixSymmetry.from_state(ts.io.atoms_to_state(atoms_a, CPU, DTYPE)) + c_b = FixSymmetry.from_state( + ts.io.atoms_to_state(atoms_b, CPU, DTYPE), + ).reindex(atom_offset=0, system_offset=3) + merged = FixSymmetry.merge([c_a, c_b]) + assert merged.system_idx.tolist() == [0, 1, 2, 3, 4] + + def test_system_constraint_merge_multi_system_via_concatenate(self) -> None: + """Regression: merging multi-system FixCom via concatenate_states.""" + s1 = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + s2 = ts.io.atoms_to_state( + [make_structure("bcc"), make_structure("hcp")], + CPU, + DTYPE, + ) + s1.constraints = [FixCom(system_idx=torch.tensor([0, 1]))] + s2.constraints = [FixCom(system_idx=torch.tensor([0, 1]))] + combined = ts.concatenate_states([s1, s2]) + assert combined.constraints[0].system_idx.tolist() == [0, 1, 2, 3] + + def test_concatenate_states_with_fix_symmetry(self) -> None: + """FixSymmetry survives concatenate_states and still symmetrizes correctly.""" + s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE) + s1.constraints = [FixSymmetry.from_state(s1, symprec=SYMPREC)] + s2.constraints = [FixSymmetry.from_state(s2, symprec=SYMPREC)] + combined = ts.concatenate_states([s1, s2]) + constraint = combined.constraints[0] + assert isinstance(constraint, FixSymmetry) + assert constraint.system_idx.tolist() == [0, 1] + assert len(constraint.rotations) == 2 + # Forces on single FCC atom should still vanish + forces = torch.tensor( + [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], + dtype=DTYPE, + ) + constraint.adjust_forces(combined, forces) + assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) + + def test_select_sub_constraint(self) -> None: + """Select second system from batched constraint.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + selected = constraint.select_sub_constraint(torch.tensor([1, 2]), sys_idx=1) + assert selected is not None + assert selected.symm_maps[0].shape[1] == 2 + assert selected.system_idx.item() == 0 + + def test_select_constraint_by_mask(self) -> None: + """Select first system via system_mask.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + selected = constraint.select_constraint( + atom_mask=torch.tensor([True, False, False]), + system_mask=torch.tensor([True, False]), + ) + assert selected is not None + assert len(selected.rotations) == 1 + assert selected.rotations[0].shape[0] == 48 + + def test_select_returns_none_for_nonexistent(self) -> None: + """select_sub_constraint and select_constraint return None when no match.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + assert constraint.select_sub_constraint(torch.tensor([0]), sys_idx=99) is None + assert ( + constraint.select_constraint( + atom_mask=torch.zeros(3, dtype=torch.bool), + system_mask=torch.zeros(2, dtype=torch.bool), + ) + is None + ) + + +# === Tests: build_symmetry_map chunked path === + + +def test_build_symmetry_map_chunked_matches_vectorized() -> None: + """Per-op loop gives same result as vectorized path.""" + import torch_sim.symmetrize as sym_mod + from torch_sim.symmetrize import ( + _extract_symmetry_ops, + _moyo_dataset, + build_symmetry_map, + ) + + state = ts.io.atoms_to_state(make_structure("p6bar"), CPU, DTYPE) + cell = state.row_vector_cell[0] + frac = state.positions @ torch.linalg.inv(cell) + dataset = _moyo_dataset(cell, frac, state.atomic_numbers) + rotations, translations = _extract_symmetry_ops(dataset, DTYPE, CPU) + + old_threshold = sym_mod._SYMM_MAP_CHUNK_THRESHOLD # noqa: SLF001 + try: + sym_mod._SYMM_MAP_CHUNK_THRESHOLD = len(state.positions) + 1 # noqa: SLF001 + vectorized = build_symmetry_map(rotations, translations, frac) + sym_mod._SYMM_MAP_CHUNK_THRESHOLD = 0 # noqa: SLF001 + chunked = build_symmetry_map(rotations, translations, frac) + finally: + sym_mod._SYMM_MAP_CHUNK_THRESHOLD = old_threshold # noqa: SLF001 + assert torch.equal(vectorized, chunked) + + +# === Tests: Optimization === + + +class TestFixSymmetryWithOptimization: + """Test FixSymmetry with actual optimization routines.""" + + @pytest.mark.parametrize("structure_name", ["fcc", "hcp", "diamond"]) + @pytest.mark.parametrize( + ("adjust_positions", "adjust_cell"), + [(True, True), (False, False)], + ) + def test_distorted_preserves_symmetry( + self, + noisy_lj_model: NoisyModelWrapper, + structure_name: str, + *, + adjust_positions: bool, + adjust_cell: bool, + ) -> None: + """Compressed structure relaxes while preserving symmetry.""" + atoms = make_structure(structure_name) + expected = SPACEGROUPS[structure_name] + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + constraint = FixSymmetry.from_state( + state, + symprec=SYMPREC, + adjust_positions=adjust_positions, + adjust_cell=adjust_cell, + ) + state.cell = state.cell * 0.9 + state.positions = state.positions * 0.9 + result = run_optimization_check_symmetry( + state, + noisy_lj_model, + constraint=constraint, + adjust_cell=adjust_cell, + force_tol=0.01, + ) + assert result["final_spacegroups"][0] == expected + + @pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) + def test_cell_filter_preserves_symmetry( + self, + model: LennardJonesModel, + cell_filter: ts.CellFilter, + ) -> None: + """Cell filters with FixSymmetry preserve symmetry.""" + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + state.constraints = [constraint] + initial = get_symmetry_datasets(state, symprec=SYMPREC) + final_state = ts.optimize( + system=state, + model=model, + optimizer=ts.Optimizer.gradient_descent, + convergence_fn=ts.generate_force_convergence_fn(force_tol=0.01), + init_kwargs={"cell_filter": cell_filter}, + max_steps=MAX_STEPS, + ) + final = get_symmetry_datasets(final_state, symprec=SYMPREC) + assert initial[0].number == final[0].number + + @pytest.mark.parametrize("cell_filter", [ts.CellFilter.frechet, ts.CellFilter.unit]) + def test_lbfgs_preserves_symmetry( + self, + noisy_lj_model: NoisyModelWrapper, + cell_filter: ts.CellFilter, + ) -> None: + """Regression: LBFGS must use set_constrained_cell for FixSymmetry support.""" + state = ts.io.atoms_to_state(make_structure("bcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + state.constraints = [constraint] + state.cell = state.cell * 0.95 + state.positions = state.positions * 0.95 + final_state = ts.optimize( + system=state, + model=noisy_lj_model, + optimizer=ts.Optimizer.lbfgs, + convergence_fn=ts.generate_force_convergence_fn( + force_tol=0.01, + include_cell_forces=True, + ), + init_kwargs={"cell_filter": cell_filter}, + max_steps=MAX_STEPS, + ) + final = get_symmetry_datasets(final_state, symprec=SYMPREC) + assert final[0].number == SPACEGROUPS["bcc"] + + @pytest.mark.parametrize("rotated", [False, True]) + def test_noisy_model_loses_symmetry_without_constraint( + self, + noisy_lj_model: NoisyModelWrapper, + *, + rotated: bool, + ) -> None: + """Negative control: without FixSymmetry, noisy forces break symmetry.""" + name = "bcc_rotated" if rotated else "bcc" + state = ts.io.atoms_to_state(make_structure(name), CPU, DTYPE) + result = run_optimization_check_symmetry(state, noisy_lj_model, constraint=None) + assert result["initial_spacegroups"][0] == 229 + assert result["final_spacegroups"][0] != 229 + + @pytest.mark.parametrize("rotated", [False, True]) + def test_noisy_model_preserves_symmetry_with_constraint( + self, + noisy_lj_model: NoisyModelWrapper, + *, + rotated: bool, + ) -> None: + """With FixSymmetry, noisy forces still preserve symmetry.""" + name = "bcc_rotated" if rotated else "bcc" + state = ts.io.atoms_to_state(make_structure(name), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + result = run_optimization_check_symmetry( + state, + noisy_lj_model, + constraint=constraint, + ) + assert result["initial_spacegroups"][0] == 229 + assert result["final_spacegroups"][0] == 229 diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index f0ed8599..c7c8d5e5 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -77,6 +77,30 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces: Forces to be adjusted """ + def adjust_stress( # noqa: B027 + self, state: SimState, stress: torch.Tensor + ) -> None: + """Adjust stress tensor to satisfy the constraint. + + Default is a no-op. Override in subclasses that need stress symmetrization. + + Args: + state: Current simulation state + stress: Stress tensor to be adjusted in-place + """ + + def adjust_cell( # noqa: B027 + self, state: SimState, cell: torch.Tensor + ) -> None: + """Adjust cell to satisfy the constraint. + + Default is a no-op. Override in subclasses that need cell symmetrization. + + Args: + state: Current simulation state + cell: Cell tensor to be adjusted in-place (column vector convention) + """ + @abstractmethod def select_constraint( self, atom_mask: torch.Tensor, system_mask: torch.Tensor @@ -100,6 +124,36 @@ def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> None | Constraint for the given atom and system index """ + @abstractmethod + def reindex(self, atom_offset: int, system_offset: int) -> Self: + """Return a copy with indices shifted to global coordinates. + + Called during state concatenation to adjust indices before merging. + + Args: + atom_offset: Offset to add to atom indices + system_offset: Offset to add to system indices + """ + + @classmethod + @abstractmethod + def merge(cls, constraints: list[Self]) -> Self: + """Merge multiple already-reindexed constraints into one. + + Constraints must have global (absolute) indices — call ``reindex`` + first. Subclasses override this to handle type-specific data. + + Args: + constraints: Constraints to merge (all same type, already reindexed) + """ + + +def _cumsum_with_zero(tensor: torch.Tensor) -> torch.Tensor: + """Cumulative sum with a leading zero, e.g. [3, 2, 4] -> [0, 3, 5, 9].""" + return torch.cat( + [torch.zeros(1, device=tensor.device, dtype=tensor.dtype), tensor.cumsum(dim=0)] + ) + def _mask_constraint_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: cumsum_atom_mask = torch.cumsum(~mask, dim=0) @@ -198,6 +252,15 @@ def select_sub_constraint( return None return type(self)(new_atom_idx) + def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 + """Return copy with atom indices shifted by atom_offset.""" + return type(self)(self.atom_idx + atom_offset) + + @classmethod + def merge(cls, constraints: list[Self]) -> Self: + """Merge by concatenating already-reindexed atom indices.""" + return cls(torch.cat([c.atom_idx for c in constraints])) + class SystemConstraint(Constraint): """Base class for constraints that act on specific system indices. @@ -280,51 +343,55 @@ def select_sub_constraint( """ return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None + def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 + """Return copy with system indices shifted by system_offset.""" + return type(self)(self.system_idx + system_offset) + + @classmethod + def merge(cls, constraints: list[Self]) -> Self: + """Merge by concatenating already-reindexed system indices.""" + return cls(torch.cat([c.system_idx for c in constraints])) + def merge_constraints( constraint_lists: list[list[AtomConstraint | SystemConstraint]], num_atoms_per_state: torch.Tensor, + num_systems_per_state: torch.Tensor | None = None, ) -> list[Constraint]: - """Merge constraints from multiple systems into a single list of constraints. + """Merge constraints from multiple states into a single list. + + Each constraint is first reindexed to global coordinates (via ``reindex``), + then constraints of the same type are merged (via ``merge``). Args: - constraint_lists: List of lists of constraints - num_atoms_per_state: Number of atoms per system + constraint_lists: List of lists of constraints, one list per state + num_atoms_per_state: Number of atoms per state + num_systems_per_state: Number of systems per state. Falls back to 1 + per state if not provided. Returns: List of merged constraints """ from collections import defaultdict - # Calculate offsets: for state i, offset = sum of atoms in states 0 to i-1 + # Calculate cumulative offsets for atoms and systems device, dtype = num_atoms_per_state.device, num_atoms_per_state.dtype - cumsum_atoms = torch.cat( - [ - torch.zeros(1, device=device, dtype=dtype), - torch.cumsum(num_atoms_per_state[:-1], dim=0), - ] - ) + atom_offsets = _cumsum_with_zero(num_atoms_per_state[:-1]) + if num_systems_per_state is None: + num_systems_per_state = torch.ones( + len(constraint_lists), device=device, dtype=dtype + ) + system_offsets = _cumsum_with_zero(num_systems_per_state[:-1]) - # aggregate updated constraint indices by constraint type - constraint_indices: dict[type[Constraint], list[torch.Tensor]] = defaultdict(list) - for i, constraint_list in enumerate(constraint_lists): + # Reindex each constraint to global coordinates, then group by type + grouped: dict[type[Constraint], list[Constraint]] = defaultdict(list) + for state_idx, constraint_list in enumerate(constraint_lists): + a_off = int(atom_offsets[state_idx].item()) + s_off = int(system_offsets[state_idx].item()) for constraint in constraint_list: - if isinstance(constraint, AtomConstraint): - idxs = constraint.atom_idx - offset = cumsum_atoms[i] - elif isinstance(constraint, SystemConstraint): - idxs = constraint.system_idx - offset = i - else: - raise NotImplementedError( - f"Constraint type {type(constraint)} is not implemented" - ) - constraint_indices[type(constraint)].append(idxs + offset) + grouped[type(constraint)].append(constraint.reindex(a_off, s_off)) - return [ - constraint_type(torch.cat(idxs)) - for constraint_type, idxs in constraint_indices.items() - ] + return [ctype.merge(cs) for ctype, cs in grouped.items()] class FixAtoms(AtomConstraint): @@ -615,3 +682,283 @@ def validate_constraints(constraints: list[Constraint], state: SimState) -> None UserWarning, stacklevel=3, ) + + +class FixSymmetry(SystemConstraint): + """Preserve spacegroup symmetry during optimization. + + Symmetrizes forces/momenta as rank-1 tensors and stress/cell deformation + as rank-2 tensors using the crystal's symmetry operations. Each system in + a batch can have different symmetry operations. + + Forces and stress are always symmetrized. Position and cell symmetrization + can be toggled via ``adjust_positions`` and ``adjust_cell``. + """ + + rotations: list[torch.Tensor] + symm_maps: list[torch.Tensor] + do_adjust_positions: bool + do_adjust_cell: bool + + def __init__( + self, + rotations: list[torch.Tensor], + symm_maps: list[torch.Tensor], + system_idx: torch.Tensor | None = None, + *, + adjust_positions: bool = True, + adjust_cell: bool = True, + ) -> None: + """Initialize FixSymmetry constraint. + + Args: + rotations: Rotation tensors per system, each (n_ops, 3, 3). + symm_maps: Atom mapping tensors per system, each (n_ops, n_atoms). + system_idx: System indices (defaults to 0..n_systems-1). + adjust_positions: Whether to symmetrize position displacements. + adjust_cell: Whether to symmetrize cell/stress adjustments. + """ + n_systems = len(rotations) + if len(symm_maps) != n_systems: + raise ValueError( + f"rotations and symm_maps length mismatch: " + f"{n_systems} vs {len(symm_maps)}" + ) + if system_idx is None: + device = rotations[0].device if rotations else torch.device("cpu") + system_idx = torch.arange(n_systems, device=device) + if len(system_idx) != n_systems: + raise ValueError( + f"system_idx length ({len(system_idx)}) != n_systems ({n_systems})" + ) + + super().__init__(system_idx=system_idx) + self.rotations = rotations + self.symm_maps = symm_maps + self.do_adjust_positions = adjust_positions + self.do_adjust_cell = adjust_cell + + @classmethod + def from_state( + cls, + state: SimState, + symprec: float = 0.01, + *, + adjust_positions: bool = True, + adjust_cell: bool = True, + refine_symmetry_state: bool = True, + ) -> Self: + """Create from SimState, optionally refining to ideal symmetry first. + + Warning: + When ``refine_symmetry_state=True`` (default), the input state is + **mutated in-place** to have ideal symmetric positions and cell. + + Args: + state: SimState containing one or more systems. + symprec: Symmetry precision for moyopy. + adjust_positions: Whether to symmetrize position displacements. + adjust_cell: Whether to symmetrize cell/stress adjustments. + refine_symmetry_state: Whether to refine positions/cell to ideal values. + """ + try: + import moyopy # noqa: F401 + except ImportError: + raise ImportError( + "moyopy required for FixSymmetry: pip install moyopy" + ) from None + + from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry + + rotations, symm_maps = [], [] + cumsum = _cumsum_with_zero(state.n_atoms_per_system) + + for sys_idx in range(state.n_systems): + start, end = cumsum[sys_idx].item(), cumsum[sys_idx + 1].item() + cell = state.row_vector_cell[sys_idx] + pos, nums = state.positions[start:end], state.atomic_numbers[start:end] + + if refine_symmetry_state: + # Single moyopy call: refine + get symmetry ops in one pass + cell, pos, rots, smap = refine_and_prep_symmetry( + cell, + pos, + nums, + symprec=symprec, + ) + state.cell[sys_idx] = cell.mT # row→column vector convention + state.positions[start:end] = pos + else: + rots, smap = prep_symmetry(cell, pos, nums, symprec=symprec) + + rotations.append(rots) + symm_maps.append(smap) + + return cls( + rotations, + symm_maps, + system_idx=torch.arange(state.n_systems, device=state.device), + adjust_positions=adjust_positions, + adjust_cell=adjust_cell, + ) + + # === Symmetrization hooks === + + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Symmetrize forces according to crystal symmetry.""" + self._symmetrize_rank1(state, forces) + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Symmetrize position displacements (skipped if do_adjust_positions=False).""" + if not self.do_adjust_positions: + return + displacement = new_positions - state.positions + self._symmetrize_rank1(state, displacement) + new_positions[:] = state.positions + displacement + + def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: + """Symmetrize stress tensor in-place. + + Always runs (like adjust_forces), independent of do_adjust_cell. + """ + from torch_sim.symmetrize import symmetrize_rank2 + + dtype = stress.dtype + for ci, si in enumerate(self.system_idx): + rots = self.rotations[ci].to(dtype=dtype) + stress[si] = symmetrize_rank2(state.row_vector_cell[si], stress[si], rots) + + def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: + """Symmetrize cell deformation gradient in-place. + + Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a + rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``. + + Args: + state: Current simulation state. + new_cell: Cell tensor (n_systems, 3, 3) in column vector convention. + + Raises: + RuntimeError: If deformation gradient > 0.25. + """ + if not self.do_adjust_cell: + return + + from torch_sim.symmetrize import symmetrize_rank2 + + identity = torch.eye(3, device=state.device, dtype=state.dtype) + for ci, si in enumerate(self.system_idx): + cur_cell = state.row_vector_cell[si] + new_row = new_cell[si].mT # column → row convention + deform_delta = torch.linalg.solve(cur_cell, new_row) - identity + max_delta = torch.abs(deform_delta).max().item() + if not (max_delta <= 0.25): # catches NaN via negated comparison + raise RuntimeError( + f"FixSymmetry: deformation gradient {max_delta:.4f} > 0.25 " + f"too large. Use smaller optimization steps." + ) + rots = self.rotations[ci].to(dtype=state.dtype) + sym_delta = symmetrize_rank2(cur_cell, deform_delta, rots) + new_cell[si] = (cur_cell @ (sym_delta + identity)).mT + + def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: + """Symmetrize a rank-1 tensor in-place for each constrained system.""" + from torch_sim.symmetrize import symmetrize_rank1 + + cumsum = _cumsum_with_zero(state.n_atoms_per_system) + dtype = vectors.dtype + for ci, si in enumerate(self.system_idx): + start, end = cumsum[si].item(), cumsum[si + 1].item() + vectors[start:end] = symmetrize_rank1( + state.row_vector_cell[si], + vectors[start:end], + self.rotations[ci].to(dtype=dtype), + self.symm_maps[ci], + ) + + # === Constraint interface === + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Returns zero - constrains direction, not DOF count.""" + return torch.zeros(state.n_systems, dtype=torch.long, device=state.device) + + def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 + """Return copy with system indices shifted by system_offset.""" + return type(self)( + list(self.rotations), + list(self.symm_maps), + self.system_idx + system_offset, + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + ) + + @classmethod + def merge(cls, constraints: list[Self]) -> Self: + """Merge by concatenating rotations, symm_maps, and system indices.""" + if not constraints: + raise ValueError("Cannot merge empty constraint list") + if any( + c.do_adjust_positions != constraints[0].do_adjust_positions + or c.do_adjust_cell != constraints[0].do_adjust_cell + for c in constraints[1:] + ): + raise ValueError( + "Cannot merge FixSymmetry constraints with different " + "adjust_positions/adjust_cell settings" + ) + rotations = [r for c in constraints for r in c.rotations] + symm_maps = [s for c in constraints for s in c.symm_maps] + system_idx = torch.cat([c.system_idx for c in constraints]) + return cls( + rotations, + symm_maps, + system_idx=system_idx, + adjust_positions=constraints[0].do_adjust_positions, + adjust_cell=constraints[0].do_adjust_cell, + ) + + def select_constraint( + self, + atom_mask: torch.Tensor, # noqa: ARG002 + system_mask: torch.Tensor, + ) -> Self | None: + """Select constraint for systems matching the mask.""" + keep = torch.where(system_mask)[0] + mask = torch.isin(self.system_idx, keep) + if not mask.any(): + return None + local_idx = mask.nonzero(as_tuple=False).flatten().tolist() + return type(self)( + [self.rotations[idx] for idx in local_idx], + [self.symm_maps[idx] for idx in local_idx], + _mask_constraint_indices(self.system_idx[mask], system_mask), + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + ) + + def select_sub_constraint( + self, + atom_idx: torch.Tensor, # noqa: ARG002 + sys_idx: int, + ) -> Self | None: + """Select constraint for a single system.""" + if sys_idx not in self.system_idx: + return None + local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() + return type(self)( + [self.rotations[local]], + [self.symm_maps[local]], + torch.tensor([0], device=self.system_idx.device), + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + ) + + def __repr__(self) -> str: + """String representation.""" + n_ops = [r.shape[0] for r in self.rotations] + ops = str(n_ops) if len(n_ops) <= 3 else f"[{n_ops[0]}, ..., {n_ops[-1]}]" + return ( + f"FixSymmetry(n_systems={len(self.rotations)}, n_ops={ops}, " + f"adjust_positions={self.do_adjust_positions}, " + f"adjust_cell={self.do_adjust_cell})" + ) diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 3a061f23..60ac5a1c 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -61,6 +61,18 @@ def _compute_cell_masses(state: SimState) -> torch.Tensor: return cell_masses.unsqueeze(-1).expand(-1, 3) +def _get_constrained_stress( + model_output: dict[str, torch.Tensor], state: SimState +) -> torch.Tensor: + """Clone stress from model output and apply constraint symmetrization.""" + if not state.constraints: + return model_output["stress"] + stress = model_output["stress"].clone() + for constraint in state.constraints: + constraint.adjust_stress(state, stress) + return stress + + def _apply_constraints( virial: torch.Tensor, *, hydrostatic_strain: bool, constant_volume: bool ) -> torch.Tensor: @@ -110,7 +122,7 @@ def unit_cell_filter_init[T: AnyCellState]( model_output = model(state) # Calculate initial cell forces - stress = model_output["stress"] + stress = _get_constrained_stress(model_output, state) volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) virial = _apply_constraints( @@ -162,7 +174,7 @@ def frechet_cell_filter_init[T: AnyCellState]( model_output = model(state) # Calculate initial cell forces using Frechet approach - stress = model_output["stress"] + stress = _get_constrained_stress(model_output, state) volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) virial = _apply_constraints( @@ -222,9 +234,10 @@ def unit_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> # Update cell from new positions cell_update = cell_positions_new / cell_factor_expanded - state.row_vector_cell = torch.bmm( - state.reference_cell.mT, cell_update.transpose(-2, -1) - ) + new_cell = torch.bmm(state.reference_cell.mT, cell_update.transpose(-2, -1)) + + # Apply cell constraints (in-place, column vector convention) + state.set_constrained_cell(new_cell.mT.contiguous()) state.cell_positions = cell_positions_new @@ -249,7 +262,9 @@ def frechet_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) new_row_vector_cell = torch.bmm( state.reference_cell.mT, deform_grad_new.transpose(-2, -1) ) - state.row_vector_cell = new_row_vector_cell + + # Apply cell constraints (in-place, column vector convention) + state.set_constrained_cell(new_row_vector_cell.mT.contiguous()) state.cell_positions = cell_positions_new @@ -257,7 +272,7 @@ def compute_cell_forces[T: AnyCellState]( model_output: dict[str, torch.Tensor], state: T ) -> None: """Compute cell forces for both unit and frechet methods.""" - stress = model_output["stress"] + stress = _get_constrained_stress(model_output, state) volumes = torch.linalg.det(state.cell).view(state.n_systems, 1, 1) virial = -volumes * (stress + state.pressure) virial = _apply_constraints( diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 763cb646..00503a7b 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -405,17 +405,15 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 if isinstance(state, CellFireState): # For cell optimization, handle both atomic and cell position updates # This follows the ASE FIRE implementation pattern - # Transform atomic positions to fractional coordinates cur_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.set_constrained_positions( - torch.linalg.solve( - cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) - ).squeeze(-1) - + dr_atom - ) + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) + ).squeeze(-1) + # Store fractional positions (will transform to Cartesian after cell update) + new_frac_positions = frac_positions + dr_atom # Update cell positions directly based on stored cell filter type if hasattr(state, "cell_filter") and state.cell_filter is not None: @@ -436,18 +434,21 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) deform_grad_new = cell_positions_new / cell_factor_expanded - # Update cell from deformation gradient - state.row_vector_cell = torch.bmm( - state.reference_cell.mT, deform_grad_new.transpose(-2, -1) - ) + # Compute new cell from deformation gradient + new_col_vector_cell = torch.bmm(deform_grad_new, state.reference_cell) + + # Apply cell constraints and scale positions to new cell coordinates + # (needed for correct displacement calculation in position constraints) + state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) - # Transform positions back to Cartesian + # Transform fractional positions to Cartesian using NEW deformation gradient new_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) + state.set_constrained_positions( torch.bmm( - state.positions.unsqueeze(1), + new_frac_positions.unsqueeze(1), new_deform_grad[state.system_idx].transpose(-2, -1), ).squeeze(1) ) diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index f4dd79a3..a9ac8b71 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -504,10 +504,11 @@ def lbfgs_step( # noqa: PLR0915, C901 deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3] # Update cell: new_cell = reference_cell @ deform_grad^T - # reference_cell.mT: [S, 3, 3], deform_grad_new: [S, 3, 3] - state.row_vector_cell = torch.bmm( - state.reference_cell.mT, deform_grad_new.transpose(-2, -1) + # Use set_constrained_cell to apply cell constraints (e.g. FixSymmetry) + new_col_vector_cell = torch.bmm( + deform_grad_new, state.reference_cell ) # [S, 3, 3] + state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) # Apply position step in fractional space, then convert to Cartesian new_frac = frac_positions + step_positions # [N, 3] diff --git a/torch_sim/state.py b/torch_sim/state.py index b1406cb7..063e9b58 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -277,6 +277,23 @@ def set_constrained_positions(self, new_positions: torch.Tensor) -> None: constraint.adjust_positions(self, new_positions) self.positions = new_positions + def set_constrained_cell( + self, + new_cell: torch.Tensor, + scale_atoms: bool = False, # noqa: FBT001, FBT002 + ) -> None: + """Set the cell, apply constraints, and optionally scale atomic positions. + + Args: + new_cell: New cell tensor with shape (n_systems, 3, 3) + in column vector convention + scale_atoms: Whether to scale atomic positions to preserve + fractional coordinates. Defaults to False. + """ + for constraint in self.constraints: + constraint.adjust_cell(self, new_cell) + self.set_cell(new_cell, scale_atoms=scale_atoms) + @property def constraints(self) -> list[Constraint]: """Get the constraints for the SimState. @@ -1104,8 +1121,11 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Merge constraints constraint_lists = [state.constraints for state in states] + num_systems_per_state = [state.n_systems for state in states] constraints = merge_constraints( - constraint_lists, torch.tensor(num_atoms_per_state, device=target_device) + constraint_lists, + torch.tensor(num_atoms_per_state, device=target_device), + torch.tensor(num_systems_per_state, device=target_device), ) # Create a new instance of the same class diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py new file mode 100644 index 00000000..8677e8f0 --- /dev/null +++ b/torch_sim/symmetrize.py @@ -0,0 +1,237 @@ +"""Symmetry utilities for crystal structures using moyopy. + +Functions operate on single (unbatched) systems. The ``n_ops`` dimension +refers to the number of symmetry operations of the space group. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + + +if TYPE_CHECKING: + from moyopy import MoyoDataset + + from torch_sim.state import SimState + + +def _moyo_dataset( + cell: torch.Tensor, + frac_pos: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 1e-4, +) -> MoyoDataset: + """Get MoyoDataset from cell, fractional positions, and atomic numbers.""" + from moyopy import Cell, MoyoDataset + + moyo_cell = Cell( + basis=cell.detach().cpu().tolist(), + positions=frac_pos.detach().cpu().tolist(), + numbers=atomic_numbers.detach().cpu().int().tolist(), + ) + return MoyoDataset(moyo_cell, symprec=symprec) + + +def _extract_symmetry_ops( + dataset: MoyoDataset, dtype: torch.dtype, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """Extract rotation and translation tensors from a MoyoDataset. + + Returns: + (rotations, translations) with shapes (n_ops, 3, 3) and (n_ops, 3). + """ + rotations = torch.as_tensor( + dataset.operations.rotations, dtype=dtype, device=device + ).round() + translations = torch.as_tensor( + dataset.operations.translations, dtype=dtype, device=device + ) + return rotations, translations + + +def get_symmetry_datasets(state: SimState, symprec: float = 1e-4) -> list[MoyoDataset]: + """Get MoyoDataset for each system in a SimState.""" + datasets = [] + for single in state.split(): + cell = single.row_vector_cell[0] + frac = single.positions @ torch.linalg.inv(cell) + datasets.append(_moyo_dataset(cell, frac, single.atomic_numbers, symprec)) + return datasets + + +# Above this threshold, build_symmetry_map falls back to a per-operation loop +# to avoid allocating an O(n_ops * n_atoms^2) tensor that can OOM on supercells. +_SYMM_MAP_CHUNK_THRESHOLD = 200 + + +def build_symmetry_map( + rotations: torch.Tensor, + translations: torch.Tensor, + frac_pos: torch.Tensor, +) -> torch.Tensor: + """Build atom mapping for each symmetry operation. + + For each (R, t), maps atom i to atom j where R @ frac_i + t ≈ frac_j (mod 1). + + Returns: + Symmetry mapping tensor, shape (n_ops, n_atoms). + """ + n_ops = rotations.shape[0] + n_atoms = frac_pos.shape[0] + + if n_atoms <= _SYMM_MAP_CHUNK_THRESHOLD: + # Vectorized: allocates (n_ops, n_atoms, n_atoms, 3) — fast for small systems + # einsum computes R[o] @ frac[n] for all (o, n) pairs at once + new_pos = torch.einsum("oij,nj->oni", rotations, frac_pos) + translations[:, None] + delta = frac_pos[None, None] - new_pos[:, :, None] + delta -= delta.round() + return torch.argmin(torch.linalg.norm(delta, dim=-1), dim=-1).long() + + # Per-op loop: allocates only (n_atoms, n_atoms, 3) at a time + # Equivalent to vectorized path: frac @ R.T == R @ frac per row + result = torch.empty(n_ops, n_atoms, dtype=torch.long, device=frac_pos.device) + for op_idx in range(n_ops): + new_pos_op = frac_pos @ rotations[op_idx].T + translations[op_idx] + delta = frac_pos[None, :, :] - new_pos_op[:, None, :] + delta -= delta.round() + result[op_idx] = torch.argmin(torch.linalg.norm(delta, dim=-1), dim=-1) + return result + + +def prep_symmetry( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 1e-4, +) -> tuple[torch.Tensor, torch.Tensor]: + """Get symmetry rotations and atom mappings for a structure. + + Returns: + (rotations, symm_map) with shapes (n_ops, 3, 3) and (n_ops, n_atoms). + """ + frac_pos = positions @ torch.linalg.inv(cell) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + rotations, translations = _extract_symmetry_ops(dataset, cell.dtype, cell.device) + return rotations, build_symmetry_map(rotations, translations, frac_pos) + + +def _refine_symmetry_impl( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 0.01, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Core refinement returning all intermediate data for reuse. + + Returns: + (refined_cell, refined_positions, rotations, translations) + """ + dtype, device = cell.dtype, cell.device + frac_pos = positions @ torch.linalg.inv(cell) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + rotations, translations = _extract_symmetry_ops(dataset, dtype, device) + n_ops, n_atoms = rotations.shape[0], positions.shape[0] + + # Symmetrize cell metric: g_sym = avg(R^T @ g @ R), then polar decomposition + metric = cell @ cell.T + metric_sym = torch.einsum("nji,jk,nkl->il", rotations, metric, rotations) / n_ops + + def _mat_sqrt(mat: torch.Tensor) -> torch.Tensor: + evals, evecs = torch.linalg.eigh(mat) + return evecs @ torch.diag(evals.clamp(min=0).sqrt()) @ evecs.T + + new_cell = _mat_sqrt(metric_sym) @ torch.linalg.solve(_mat_sqrt(metric), cell) + + # Symmetrize positions via displacement averaging over symmetry orbits + new_frac = positions @ torch.linalg.inv(new_cell) + symm_map = build_symmetry_map(rotations, translations, new_frac) + + transformed = torch.einsum("oij,nj->oni", rotations, new_frac) + translations[:, None] + disp = transformed - new_frac[symm_map] + disp -= disp.round() # wrap into [-0.5, 0.5] + + target = symm_map.reshape(-1).unsqueeze(-1).expand(-1, 3) + accum = torch.zeros(n_atoms, 3, dtype=dtype, device=device) + accum.scatter_add_(0, target, disp.reshape(-1, 3)) + + new_positions = (new_frac + accum / n_ops) @ new_cell + return new_cell, new_positions, rotations, translations + + +def refine_symmetry( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 0.01, +) -> tuple[torch.Tensor, torch.Tensor]: + """Symmetrize cell and positions according to the detected space group. + + Uses polar decomposition for the cell metric tensor and scatter-add + averaging over symmetry orbits for atomic positions. + + Returns: + (symmetrized_cell, symmetrized_positions) as row vectors. + """ + new_cell, new_positions, _rotations, _translations = _refine_symmetry_impl( + cell, positions, atomic_numbers, symprec + ) + return new_cell, new_positions + + +def refine_and_prep_symmetry( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 0.01, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Refine symmetry and get ops/mappings in a single moyopy call. + + Combines ``refine_symmetry`` and ``prep_symmetry`` to avoid redundant + symmetry detection. Used by ``FixSymmetry.from_state``. + + Returns: + (refined_cell, refined_positions, rotations, symm_map) + """ + new_cell, new_positions, rotations, translations = _refine_symmetry_impl( + cell, positions, atomic_numbers, symprec + ) + # Build symm_map on the final refined fractional coordinates + refined_frac = new_positions @ torch.linalg.inv(new_cell) + symm_map = build_symmetry_map(rotations, translations, refined_frac) + return new_cell, new_positions, rotations, symm_map + + +def symmetrize_rank1( + lattice: torch.Tensor, + vectors: torch.Tensor, + rotations: torch.Tensor, + symm_map: torch.Tensor, +) -> torch.Tensor: + """Symmetrize a rank-1 per-atom tensor (forces, velocities, displacements). + + Works in fractional coordinates internally. Returns symmetrized Cartesian tensor. + """ + n_ops, n_atoms = rotations.shape[0], vectors.shape[0] + scaled = vectors @ torch.linalg.inv(lattice) + # Rotate each vector by each symmetry op: scaled @ R^T + rotated = torch.einsum("ij,nkj->nik", scaled, rotations).reshape(-1, 3) + # Scatter-add to target atoms and average + target = symm_map.reshape(-1).unsqueeze(-1).expand(-1, 3) + accum = torch.zeros(n_atoms, 3, dtype=vectors.dtype, device=vectors.device) + accum.scatter_add_(0, target, rotated) + return (accum / n_ops) @ lattice + + +def symmetrize_rank2( + lattice: torch.Tensor, + tensor: torch.Tensor, + rotations: torch.Tensor, +) -> torch.Tensor: + """Symmetrize a rank-2 tensor (stress, strain) over all symmetry operations.""" + n_ops = rotations.shape[0] + inv_lat = torch.linalg.inv(lattice) + scaled = lattice @ tensor @ lattice.T + sym_scaled = torch.einsum("nji,jk,nkl->il", rotations, scaled, rotations) / n_ops + return inv_lat @ sym_scaled @ inv_lat.T