From d0ab8e486c1abeb5494f2b161bd231b852b13bf0 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 3 Feb 2026 17:07:10 +0000 Subject: [PATCH 01/16] initial FixSymmetry --- pyproject.toml | 2 + tests/test_fix_symmetry.py | 718 +++++++++++++++++++++++++++ torch_sim/constraints.py | 659 +++++++++++++++++++++++- torch_sim/optimizers/cell_filters.py | 34 +- torch_sim/optimizers/fire.py | 34 +- torch_sim/state.py | 17 + torch_sim/symmetrize.py | 456 +++++++++++++++++ 7 files changed, 1872 insertions(+), 48 deletions(-) create mode 100644 tests/test_fix_symmetry.py create mode 100644 torch_sim/symmetrize.py diff --git a/pyproject.toml b/pyproject.toml index 8093d764a..abd719cf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,8 +44,10 @@ test = [ "pymatgen>=2025.6.14", "pytest-cov>=6", "pytest>=8", + "spglib>=2.5", ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] +symmetry = ["spglib>=2.5"] mace = ["mace-torch>=0.3.14"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py new file mode 100644 index 000000000..7b9c4458f --- /dev/null +++ b/tests/test_fix_symmetry.py @@ -0,0 +1,718 @@ +"""Tests for the FixSymmetry constraint.""" + +from typing import Literal, TypedDict + +import numpy as np +import pytest +import torch +from ase import Atoms +from ase.build import bulk +from ase.constraints import FixSymmetry as ASEFixSymmetry +from ase.spacegroup.symmetrize import refine_symmetry as ase_refine_symmetry +from ase.stress import full_3x3_to_voigt_6_stress, voigt_6_to_full_3x3_stress +from pymatgen.core import Lattice, Structure +from pymatgen.io.ase import AseAtomsAdaptor + +import torch_sim as ts +from torch_sim.constraints import FixSymmetry +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.symmetrize import get_symmetry_datasets + +# Skip all tests if spglib is not available +spglib = pytest.importorskip("spglib") +from spglib import SpglibDataset + + +class OptimizationResult(TypedDict): + """Return type for run_optimization_check_symmetry.""" + + initial_spacegroups: list[int | None] + final_spacegroups: list[int | None] + initial_datasets: list[SpglibDataset] + final_datasets: list[SpglibDataset] + final_state: ts.SimState + final_atoms_list: list[Atoms] + + +# ============================================================================= +# Structure Definitions (Single Source of Truth) +# ============================================================================= + +# Expected space groups for each structure type +SPACEGROUPS = { + "fcc": 225, # Fm-3m + "hcp": 194, # P6_3/mmc + "diamond": 227, # Fd-3m + "bcc": 229, # Im-3m + "p6bar": 174, # P-6 (low symmetry) +} + +# Default maximum optimization steps for tests +MAX_STEPS = 30 + +# Default dtype for tests (torch.float64 recommended for numerical precision) +DTYPE = torch.float64 + +# Default symmetry precision for spglib +SYMPREC = 0.01 + + +def _make_p6bar() -> Atoms: + """Create low-symmetry P-6 (space group 174) structure using pymatgen.""" + lattice = Lattice.hexagonal(a=3.0, c=5.0) + structure = Structure.from_spacegroup( + sg=174, lattice=lattice, species=["Si"], coords=[[0.3, 0.1, 0.25]] + ) + return AseAtomsAdaptor.get_atoms(structure) + + +def make_structure(name: str) -> Atoms: + """Create a standard test structure by name. + + This is the single source of truth for test structures. + Use this instead of inline bulk() calls to avoid duplication. + + Args: + name: One of "fcc", "hcp", "diamond", "bcc", "p6bar" with optional + "_supercell" and/or "_rotated" suffix + + Returns: + ASE Atoms object + """ + base_name = name.replace("_supercell", "").replace("_rotated", "") + structures = { + "fcc": lambda: bulk("Cu", "fcc", a=3.6), + "hcp": lambda: bulk("Ti", "hcp", a=2.95, c=4.68), + "diamond": lambda: bulk("Si", "diamond", a=5.43), + "bcc": lambda: bulk("Al", "bcc", a=2 / np.sqrt(3), cubic=True), + "p6bar": _make_p6bar, + } + atoms = structures[base_name]() + if "_supercell" in name: + atoms = atoms * (2, 2, 2) + if "_rotated" in name: + # Apply 3 rotation matrices (matching ASE's test setup) + F = np.eye(3) + for k in range(3): + L = list(range(3)) + L.remove(k) + (i, j) = L + R = np.eye(3) + theta = 0.1 * (k + 1) + R[i, i] = np.cos(theta) + R[j, j] = np.cos(theta) + R[i, j] = np.sin(theta) + R[j, i] = -np.sin(theta) + F = np.dot(F, R) + atoms.set_cell(atoms.cell @ F, scale_atoms=True) + return atoms + + +# ============================================================================= +# Shared Fixtures +# ============================================================================= + + +@pytest.fixture(params=["fcc", "hcp", "hcp_supercell", "diamond", "p6bar"]) +def structure_with_spacegroup(request: pytest.FixtureRequest) -> tuple[Atoms, int]: + """Parameterized fixture returning (atoms, expected_spacegroup).""" + name = request.param + atoms = make_structure(name) + base_name = name.replace("_supercell", "") + return atoms, SPACEGROUPS[base_name] + + +@pytest.fixture +def model() -> LennardJonesModel: + """Create a LennardJonesModel for testing.""" + return LennardJonesModel( + sigma=1.0, + epsilon=0.05, + cutoff=6.0, + use_neighbor_list=False, + compute_stress=True, + dtype=DTYPE, + ) + +@pytest.fixture +def noisy_lj_model(model: LennardJonesModel): + """Create a LJ model that adds noise to forces/stress (like ASE's NoisyLennardJones).""" + + class NoisyModelWrapper: + """Wrapper that adds noise to forces and stress from an underlying model.""" + + def __init__(self, model, rng_seed: int = 1, noise_scale: float = 1e-4): + self.model = model + self.rng = np.random.RandomState(rng_seed) + self.noise_scale = noise_scale + + @property + def device(self): + return self.model.device + @property + def dtype(self): + return self.model.dtype + + def __call__(self, state): + results = self.model(state) + # Add noise to forces + if "forces" in results: + noise = self.rng.normal(size=results["forces"].shape) + results["forces"] = results["forces"] + self.noise_scale * torch.tensor( + noise, dtype=results["forces"].dtype, device=results["forces"].device + ) + # Add noise to stress + if "stress" in results: + noise = self.rng.normal(size=results["stress"].shape) + results["stress"] = results["stress"] + self.noise_scale * torch.tensor( + noise, dtype=results["stress"].dtype, device=results["stress"].device + ) + return results + return NoisyModelWrapper(model) + + +# ============================================================================= +# Shared Helper Functions +# ============================================================================= + + +def get_symmetry_dataset_from_atoms(atoms: Atoms, symprec: float = SYMPREC) -> SpglibDataset: + """Get full symmetry dataset for an ASE Atoms object using spglib directly.""" + return spglib.get_symmetry_dataset( + (atoms.cell[:], atoms.get_scaled_positions(), atoms.numbers), + symprec=symprec, + ) + + +def run_optimization_check_symmetry( + state: ts.SimState, + model: LennardJonesModel, + constraint: FixSymmetry | None = None, + adjust_cell: bool = True, + symprec: float = SYMPREC, + max_steps: int = MAX_STEPS, + force_tol: float = 0.001, +) -> OptimizationResult: + """Run optimization and return initial/final symmetry info. + + This is the core helper for testing symmetry preservation during optimization. + + Args: + state: torch-sim SimState (can be batched) + model: torch-sim model for optimization + constraint: Optional FixSymmetry constraint to apply. If None, no constraint. + adjust_cell: Whether to enable cell optimization (with Frechet filter) + symprec: Symmetry precision for spglib checks + max_steps: Maximum optimization steps + force_tol: Force convergence tolerance + + Returns: + Dict with keys: + - 'initial_spacegroups': List of initial space group numbers + - 'final_spacegroups': List of final space group numbers + - 'initial_datasets': List of full spglib datasets for initial structures + - 'final_datasets': List of full spglib datasets for final structures + - 'final_state': Final SimState + - 'final_atoms_list': List of final ASE Atoms objects + """ + # Get initial symmetry for all systems using torch_sim.symmetrize + initial_datasets = get_symmetry_datasets(state, symprec) + + if constraint is not None: + state.constraints = [constraint] + + # Run optimization + init_kwargs = {"cell_filter": ts.CellFilter.frechet} if adjust_cell else None + # When doing cell optimization, include cell_forces in convergence check + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=adjust_cell + ) + final_state = ts.optimize( + system=state, + model=model, + optimizer=ts.Optimizer.fire, + convergence_fn=convergence_fn, + init_kwargs=init_kwargs, + max_steps=max_steps, + steps_between_swaps=1, + ) + + # Get final symmetry for all systems + final_datasets = get_symmetry_datasets(final_state, symprec) + final_atoms_list = final_state.to_atoms() + + return { + "initial_spacegroups": [d.number if d else None for d in initial_datasets], + "final_spacegroups": [d.number if d else None for d in final_datasets], + "initial_datasets": initial_datasets, + "final_datasets": final_datasets, + "final_state": final_state, + "final_atoms_list": final_atoms_list, + } + + +# ============================================================================= +# Test Classes +# ============================================================================= + + +class TestFixSymmetryCreation: + """Tests for FixSymmetry constraint creation.""" + + def test_from_state_batched(self): + """Test creating FixSymmetry from batched SimState with different structures.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + torch.device("cpu"), + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + assert len(constraint.rotations) == 2 + assert len(constraint.symm_maps) == 2 + assert constraint.system_idx.shape == (2,) + # Both have cubic symmetry (48 ops) but different number of atoms + assert constraint.rotations[0].shape[0] == 48 + assert constraint.rotations[1].shape[0] == 48 + # Cu FCC has 1 atom, Si diamond has 2 + assert constraint.symm_maps[0].shape == (48, 1) + assert constraint.symm_maps[1].shape == (48, 2) + + def test_p1_identity_only(self): + """Test P1 (no symmetry) has only identity and doesn't change forces/stress.""" + atoms = Atoms( + "SiGe", positions=[[0.1, 0.2, 0.3], [1.1, 0.9, 1.3]], + cell=[[3.0, 0.1, 0.2], [0.15, 3.5, 0.1], [0.2, 0.15, 4.0]], pbc=True + ) + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + assert constraint.rotations[0].shape[0] == 1, "P1 should have 1 operation" + + # Forces should be unchanged + forces = torch.randn(2, 3, dtype=DTYPE) + original_forces = forces.clone() + constraint.adjust_forces(state, forces) + assert torch.allclose(forces, original_forces, atol=1e-10) + + # Stress should be unchanged (identity symmetrization) + stress = torch.randn(1, 3, 3, dtype=DTYPE) + # Make it symmetric (stress tensors are symmetric) + stress = (stress + stress.transpose(-1, -2)) / 2 + original_stress = stress.clone() + constraint.adjust_stress(state, stress) + assert torch.allclose(stress, original_stress, atol=1e-10) + + def test_symmetry_datasets_match_spglib(self): + """Test get_symmetry_datasets matches spglib for single and batched states.""" + atoms_list = [make_structure(name) for name in ["fcc", "diamond", "hcp"]] + + # Test batched state + batched_state = ts.io.atoms_to_state(atoms_list, torch.device("cpu"), DTYPE) + ts_datasets = get_symmetry_datasets(batched_state, SYMPREC) + assert len(ts_datasets) == 3 + + # Compare each with direct spglib call (covers both single and batched) + for i, atoms in enumerate(atoms_list): + spglib_dataset = get_symmetry_dataset_from_atoms(atoms, SYMPREC) + + # Compare key fields + assert ts_datasets[i].number == spglib_dataset.number, ( + f"Space group mismatch for {atoms_list[i].get_chemical_formula()}: " + f"{ts_datasets[i].number} vs {spglib_dataset.number}" + ) + assert ts_datasets[i].international == spglib_dataset.international + assert ts_datasets[i].hall == spglib_dataset.hall + assert len(ts_datasets[i].rotations) == len(spglib_dataset.rotations) + assert np.allclose(ts_datasets[i].rotations, spglib_dataset.rotations) + assert np.allclose(ts_datasets[i].translations, spglib_dataset.translations, atol=1e-10) + + +class TestFixSymmetryComparisonWithASE: + """Compare TorchSim FixSymmetry with ASE's implementation.""" + + def test_symmetrize_forces_batched(self): + """Test force symmetrization for batched systems with different structures.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + torch.device("cpu"), + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Create asymmetric forces (1 atom for Cu FCC, 2 atoms for Si diamond) + forces = torch.tensor( + [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], dtype=DTYPE + ) + + constraint.adjust_forces(state, forces) + + # First atom (Cu FCC) should have zero force due to cubic symmetry + assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) + + def test_force_symmetrization_matches_ase(self): + """Compare force symmetrization with ASE using a multi-atom structure.""" + atoms = make_structure("p6bar") + + # Create TorchSim state and constraint + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Set up ASE constraint + ase_atoms = atoms.copy() + ase_refine_symmetry(ase_atoms, symprec=SYMPREC) + ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) + + # Create random test forces + np.random.seed(42) + forces_np = np.random.randn(len(atoms), 3) + forces_ts = torch.tensor(forces_np.copy(), dtype=DTYPE) + + # Symmetrize with both + ts_constraint.adjust_forces(state, forces_ts) + ase_constraint.adjust_forces(ase_atoms, forces_np) + + # Compare results + assert np.allclose(forces_ts.numpy(), forces_np, atol=1e-10) + + def test_stress_symmetrization_matches_ase(self): + """Compare stress symmetrization with ASE's implementation.""" + atoms = make_structure("p6bar") + + # Create TorchSim state and constraint + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Set up ASE constraint + ase_atoms = atoms.copy() + ase_refine_symmetry(ase_atoms, symprec=SYMPREC) + ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) + + # Create asymmetric but symmetric (as a matrix) stress tensor + stress_3x3 = np.array([ + [10.0, 1.0, 0.5], + [1.0, 8.0, 0.3], + [0.5, 0.3, 6.0] + ]) + + # ASE uses Voigt notation + stress_voigt = full_3x3_to_voigt_6_stress(stress_3x3) + stress_voigt_copy = stress_voigt.copy() + + # TorchSim uses 3x3 tensor with batch dimension + stress_ts = torch.tensor([stress_3x3.copy()], dtype=DTYPE) + + # Symmetrize with both + ts_constraint.adjust_stress(state, stress_ts) + ase_constraint.adjust_stress(ase_atoms, stress_voigt_copy) + + # Convert ASE result back to 3x3 + ase_result_3x3 = voigt_6_to_full_3x3_stress(stress_voigt_copy) + + # Compare results + assert np.allclose(stress_ts[0].numpy(), ase_result_3x3, atol=1e-10), ( + f"Stress mismatch:\nTorchSim:\n{stress_ts[0].numpy()}\nASE:\n{ase_result_3x3}" + ) + + def test_cell_deformation_symmetrization_matches_ase(self): + """Compare cell deformation symmetrization with ASE.""" + atoms = make_structure("p6bar") + + # Create TorchSim state and constraint + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Set up ASE constraint + ase_atoms = atoms.copy() + ase_refine_symmetry(ase_atoms, symprec=SYMPREC) + ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) + + # Create a small asymmetric deformation of the cell + original_cell = ase_atoms.get_cell().copy() + deformed_cell = original_cell.copy() + deformed_cell[0, 1] += 0.05 # Small off-diagonal perturbation + + # TorchSim - need column vector convention for adjust_cell + new_cell_ts = torch.tensor( + [deformed_cell.copy().T], dtype=DTYPE # Transpose for column vectors + ) + ts_constraint.adjust_cell(state, new_cell_ts) + ts_result = new_cell_ts[0].mT.numpy() # Back to row vectors + + # ASE + ase_cell = deformed_cell.copy() + ase_constraint.adjust_cell(ase_atoms, ase_cell) + + # Compare results + assert np.allclose(ts_result, ase_cell, atol=1e-10), ( + f"Cell mismatch:\nTorchSim:\n{ts_result}\nASE:\n{ase_cell}" + ) + + +class TestFixSymmetryMergeAndSelect: + """Tests for FixSymmetry.merge, select_constraint, and select_sub_constraint methods.""" + + def test_merge_two_constraints(self): + """Test merging two FixSymmetry constraints.""" + state1 = ts.io.atoms_to_state( + make_structure("fcc"), torch.device("cpu"), DTYPE + ) + state2 = ts.io.atoms_to_state( + make_structure("diamond"), torch.device("cpu"), DTYPE + ) + c1 = FixSymmetry.from_state(state1, symprec=SYMPREC) + c2 = FixSymmetry.from_state(state2, symprec=SYMPREC) + + merged = FixSymmetry.merge([c1, c2], state_indices=[0, 1], atom_offsets=None) + + assert len(merged.rotations) == 2 + assert len(merged.symm_maps) == 2 + assert merged.system_idx.tolist() == [0, 1] + + @pytest.mark.parametrize("mismatch_field", ["adjust_positions", "adjust_cell"]) + def test_merge_mismatched_settings_raises(self, mismatch_field: Literal['adjust_positions'] | Literal['adjust_cell']): + """Test that merging constraints with different settings raises ValueError.""" + state1 = ts.io.atoms_to_state( + make_structure("fcc"), torch.device("cpu"), DTYPE + ) + state2 = ts.io.atoms_to_state( + make_structure("diamond"), torch.device("cpu"), DTYPE + ) + + kwargs1 = {mismatch_field: True} + kwargs2 = {mismatch_field: False} + c1 = FixSymmetry.from_state(state1, symprec=SYMPREC, **kwargs1) + c2 = FixSymmetry.from_state(state2, symprec=SYMPREC, **kwargs2) + + with pytest.raises(ValueError, match=f"different {mismatch_field} settings"): + FixSymmetry.merge([c1, c2], state_indices=[0, 1], atom_offsets=None) + + def test_select_constraint_single_system(self): + """Test selecting a single system from batched constraint.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + torch.device("cpu"), + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Create masks to select only first system + atom_mask = torch.tensor([True, False, False], dtype=torch.bool) # 1 Cu + 2 Si atoms + system_mask = torch.tensor([True, False], dtype=torch.bool) + + selected = constraint.select_constraint(atom_mask, system_mask) + + assert selected is not None + assert len(selected.rotations) == 1 + assert len(selected.symm_maps) == 1 + assert selected.system_idx.shape == (1,) + # Should have Cu's 48 symmetry operations + assert selected.rotations[0].shape[0] == 48 + + def test_select_sub_constraint(self): + """Test selecting a specific system by index.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + torch.device("cpu"), + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Select second system (Si diamond) + # Note: atom_idx is ignored for FixSymmetry + selected = constraint.select_sub_constraint( + atom_idx=torch.tensor([1, 2]), sys_idx=1 + ) + + assert selected is not None + assert len(selected.rotations) == 1 + # Si diamond has 2 atoms + assert selected.symm_maps[0].shape[1] == 2 + # New system_idx should be 0 (renumbered) + assert selected.system_idx.item() == 0 + + +class TestFixSymmetryWithOptimization: + """Test FixSymmetry with actual optimization routines. + + Uses the shared run_optimization_check_symmetry helper for most tests. + """ + + @pytest.mark.parametrize("structure_name", ["fcc", "hcp", "diamond", "p6bar"]) + @pytest.mark.parametrize( + "adjust_positions,adjust_cell", + [(True, True), (True, False), (False, True), (False, False)], + ) + def test_distorted_structure_preserves_symmetry( + self, noisy_lj_model, structure_name: str, adjust_positions: bool, adjust_cell: bool + ): + """Test that a distorted structure relaxes while preserving symmetry. + + All combinations of adjust_positions and adjust_cell should preserve symmetry + because forces are always symmetrized (matching ASE's behavior). + + """ + atoms = make_structure(structure_name) + expected_spacegroup = SPACEGROUPS[structure_name] + + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + + # Create constraint BEFORE distorting - captures ideal symmetry + constraint = FixSymmetry.from_state( + state, + symprec=SYMPREC, + adjust_positions=adjust_positions, + adjust_cell=adjust_cell, + ) + + # Now distort the cell (uniform scaling preserves symmetry but creates forces) + # Scale by 0.9 to compress - this creates repulsive forces + scale_factor = 0.9 + state.cell = state.cell * scale_factor + state.positions = state.positions * scale_factor + + result = run_optimization_check_symmetry( + state, noisy_lj_model, constraint=constraint, adjust_cell=adjust_cell, + max_steps=MAX_STEPS, force_tol=0.01 # Looser tolerance to ensure movement + ) + + assert result["final_spacegroups"][0] == expected_spacegroup, ( + f"Space group changed from {expected_spacegroup} to " + f"{result['final_spacegroups'][0]} with adjust_positions={adjust_positions}, " + f"adjust_cell={adjust_cell}" + ) + + @pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) + def test_cell_filter_preserves_symmetry(self, model: LennardJonesModel, cell_filter: ts.CellFilter | ts.CellFilter): + """Test that cell filters with FixSymmetry preserve symmetry.""" + state = ts.io.atoms_to_state( + make_structure("fcc"), torch.device("cpu"), DTYPE + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + state.constraints = [constraint] + + initial_datasets = get_symmetry_datasets(state, symprec=SYMPREC) + + final_state = ts.optimize( + system=state, + model=model, + optimizer=ts.Optimizer.gradient_descent, + convergence_fn=ts.generate_force_convergence_fn(force_tol=0.01), + init_kwargs={"cell_filter": cell_filter}, + max_steps=MAX_STEPS, + ) + + final_datasets = get_symmetry_datasets(final_state, symprec=SYMPREC) + assert initial_datasets[0].number == final_datasets[0].number + + @pytest.mark.parametrize("rotated", [False, True]) + def test_noisy_model_loses_symmetry_without_constraint(self, noisy_lj_model, rotated: bool): + """Test that WITHOUT FixSymmetry, optimization with noisy forces loses symmetry. + + This is a negative control - verifies that noisy forces will break symmetry + if no constraint is applied. Mirrors ASE's test_no_symmetrization. + """ + name = "bcc_rotated" if rotated else "bcc" + bcc_atoms = make_structure(name) + state = ts.io.atoms_to_state(bcc_atoms, torch.device("cpu"), DTYPE) + result = run_optimization_check_symmetry( + state, noisy_lj_model, constraint=None, max_steps=MAX_STEPS, symprec=SYMPREC + ) + + # Initial should be BCC (space group 229) + assert result["initial_spacegroups"][0] == 229 + # Final should have lost symmetry (different space group) + assert result["final_spacegroups"][0] != 229, ( + f"Symmetry should be lost without constraint, but final space group " + f"is still {result['final_spacegroups'][0]}" + ) + + @pytest.mark.parametrize("rotated", [False, True]) + def test_noisy_model_preserves_symmetry_with_constraint(self, noisy_lj_model, rotated: bool): + """Test that WITH FixSymmetry, optimization with noisy forces preserves symmetry. + + Mirrors ASE's test_sym_adj_cell. + """ + bcc_atoms = make_structure("bcc_rotated" if rotated else "bcc") + state = ts.io.atoms_to_state(bcc_atoms, torch.device("cpu"), DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + result = run_optimization_check_symmetry( + state, noisy_lj_model, constraint=constraint, max_steps=MAX_STEPS, + ) + + assert result["initial_spacegroups"][0] == 229 + assert result["final_spacegroups"][0] == 229, ( + f"Symmetry should be preserved with constraint, but final spacegroup " + f"changed to {result['final_spacegroups'][0]}" + ) + + + +class TestFixSymmetryEdgeCases: + """Tests for edge cases and error handling.""" + + def test_get_removed_dof_raises(self): + """Test that get_removed_dof raises NotImplementedError.""" + state = ts.io.atoms_to_state( + make_structure("fcc"), torch.device("cpu"), DTYPE + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + with pytest.raises(NotImplementedError, match="get_removed_dof"): + constraint.get_removed_dof(state) + + def test_large_deformation_gradient_raises(self): + """Test that large deformation gradient raises RuntimeError.""" + state = ts.io.atoms_to_state( + make_structure("fcc"), torch.device("cpu"), DTYPE + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Create a very large deformation (> 0.25) + # FCC cell has zeros on diagonal, so modify all elements by a large factor + new_cell_col = state.cell.clone() # Column vector convention + new_cell_col[0] *= 1.5 # 50% stretch of entire cell + + with pytest.raises(RuntimeError, match="large deformation gradient"): + constraint.adjust_cell(state, new_cell_col) + + def test_medium_deformation_gradient_warns(self): + """Test that medium deformation gradient emits warning.""" + state = ts.io.atoms_to_state( + make_structure("fcc"), torch.device("cpu"), DTYPE + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + + # Create a medium deformation (> 0.15 but < 0.25) + new_cell_col = state.cell.clone() # Column vector convention + new_cell_col[0] *= 1.2 # 20% stretch of entire cell + + with pytest.warns(UserWarning, match="may be ill-behaved"): + constraint.adjust_cell(state, new_cell_col) + + @pytest.mark.parametrize("refine_symmetry_state", [True, False]) + def test_from_state_refine_symmetry(self, refine_symmetry_state: bool): + """Test from_state with different refine_symmetry_state settings.""" + atoms = make_structure("fcc") + # Add small perturbation + perturbed = atoms.copy() + perturbed.positions += np.random.randn(*perturbed.positions.shape) * 0.001 + + state = ts.io.atoms_to_state( + perturbed, torch.device("cpu"), DTYPE + ) + original_positions = state.positions.clone() + original_cell = state.cell.clone() + + _ = FixSymmetry.from_state( + state, symprec=SYMPREC, refine_symmetry_state=refine_symmetry_state + ) + + if not refine_symmetry_state: + # State should not be modified + assert torch.allclose(state.positions, original_positions) + assert torch.allclose(state.cell, original_cell) + else: + # State may be modified (positions refined to ideal) + # We just check the function runs without error + assert state.positions.shape == original_positions.shape \ No newline at end of file diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index f0ed85991..7c29f2dee 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -16,6 +16,13 @@ import torch +from torch_sim.symmetrize import ( + _prep_symmetry, + refine_symmetry, + symmetrize_rank1, + symmetrize_rank2, +) + if TYPE_CHECKING: from torch_sim.state import SimState @@ -77,6 +84,28 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces: Forces to be adjusted """ + @abstractmethod + def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: + """Adjust stress tensor to satisfy the constraint. + + This method should modify stress in-place. + + Args: + state: Current simulation state + stress: Stress tensor to be adjusted in-place + """ + + @abstractmethod + def adjust_cell(self, state: SimState, cell: torch.Tensor) -> None: + """Adjust cell to satisfy the constraint. + + This method should modify cell in-place. + + Args: + state: Current simulation state + cell: Cell tensor to be adjusted in-place (column vector convention) + """ + @abstractmethod def select_constraint( self, atom_mask: torch.Tensor, system_mask: torch.Tensor @@ -100,6 +129,34 @@ def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> None | Constraint for the given atom and system index """ + @classmethod + def merge( + cls, + constraints: list[Self], + state_indices: list[int], + atom_offsets: torch.Tensor, + ) -> Self: + """Merge multiple constraints of the same type into one. + + This method is called during state concatenation to combine constraints + from multiple states. Subclasses can override this for custom merge logic. + + Args: + constraints: List of constraints to merge (all of the same type) + state_indices: Index of the source state for each constraint + atom_offsets: Cumulative atom counts for offset calculation + + Returns: + A single merged constraint + + Raises: + NotImplementedError: If the constraint type doesn't support merging + """ + raise NotImplementedError( + f"Constraint type {cls.__name__} does not implement merge. " + "Override this method to support state concatenation." + ) + def _mask_constraint_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: cumsum_atom_mask = torch.cumsum(~mask, dim=0) @@ -198,6 +255,29 @@ def select_sub_constraint( return None return type(self)(new_atom_idx) + @classmethod + def merge( + cls, + constraints: list[Self], + state_indices: list[int], + atom_offsets: torch.Tensor, + ) -> Self: + """Merge multiple AtomConstraints by concatenating indices with offsets. + + Args: + constraints: List of constraints to merge + state_indices: Index of the source state for each constraint + atom_offsets: Cumulative atom counts for offset calculation + + Returns: + A single merged constraint with adjusted atom indices + """ + all_indices = [] + for constraint, state_idx in zip(constraints, state_indices, strict=False): + offset = atom_offsets[state_idx] + all_indices.append(constraint.atom_idx + offset) + return cls(torch.cat(all_indices)) + class SystemConstraint(Constraint): """Base class for constraints that act on specific system indices. @@ -280,6 +360,29 @@ def select_sub_constraint( """ return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None + @classmethod + def merge( + cls, + constraints: list[Self], + state_indices: list[int], + atom_offsets: torch.Tensor, # noqa: ARG003 + ) -> Self: + """Merge multiple SystemConstraints by concatenating indices with offsets. + + Args: + constraints: List of constraints to merge + state_indices: Index of the source state for each constraint + atom_offsets: Cumulative atom counts (unused for SystemConstraint) + + Returns: + A single merged constraint with adjusted system indices + """ + all_indices = [] + for constraint, state_idx in zip(constraints, state_indices, strict=False): + # For SystemConstraint, the offset is the state index itself + all_indices.append(constraint.system_idx + state_idx) + return cls(torch.cat(all_indices)) + def merge_constraints( constraint_lists: list[list[AtomConstraint | SystemConstraint]], @@ -296,35 +399,32 @@ def merge_constraints( """ from collections import defaultdict - # Calculate offsets: for state i, offset = sum of atoms in states 0 to i-1 + # Calculate atom offsets: for state i, offset = sum of atoms in states 0 to i-1 device, dtype = num_atoms_per_state.device, num_atoms_per_state.dtype - cumsum_atoms = torch.cat( + atom_offsets = torch.cat( [ torch.zeros(1, device=device, dtype=dtype), torch.cumsum(num_atoms_per_state[:-1], dim=0), ] ) - # aggregate updated constraint indices by constraint type - constraint_indices: dict[type[Constraint], list[torch.Tensor]] = defaultdict(list) - for i, constraint_list in enumerate(constraint_lists): + # Group constraints by type, tracking their source state index + constraints_by_type: dict[type[Constraint], tuple[list, list[int]]] = defaultdict( + lambda: ([], []) + ) + for state_idx, constraint_list in enumerate(constraint_lists): for constraint in constraint_list: - if isinstance(constraint, AtomConstraint): - idxs = constraint.atom_idx - offset = cumsum_atoms[i] - elif isinstance(constraint, SystemConstraint): - idxs = constraint.system_idx - offset = i - else: - raise NotImplementedError( - f"Constraint type {type(constraint)} is not implemented" - ) - constraint_indices[type(constraint)].append(idxs + offset) + constraints, indices = constraints_by_type[type(constraint)] + constraints.append(constraint) + indices.append(state_idx) + + # Merge each group using the constraint's merge method + result = [] + for constraint_type, (constraints, state_indices) in constraints_by_type.items(): + merged = constraint_type.merge(constraints, state_indices, atom_offsets) + result.append(merged) - return [ - constraint_type(torch.cat(idxs)) - for constraint_type, idxs in constraint_indices.items() - ] + return result class FixAtoms(AtomConstraint): @@ -393,6 +493,20 @@ def adjust_forces( """ forces[self.atom_idx] = 0.0 + def adjust_stress( + self, + state: SimState, + stress: torch.Tensor, + ) -> None: + """No stress adjustment needed for FixAtoms.""" + + def adjust_cell( + self, + state: SimState, + cell: torch.Tensor, + ) -> None: + """No cell adjustment needed for FixAtoms.""" + def __repr__(self) -> str: """String representation of the constraint.""" if len(self.atom_idx) <= 10: @@ -510,6 +624,20 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces_change[self.system_idx] = lmd[self.system_idx] forces -= forces_change[state.system_idx] * state.masses.unsqueeze(-1) + def adjust_stress( + self, + state: SimState, + stress: torch.Tensor, + ) -> None: + """No stress adjustment needed for FixCom.""" + + def adjust_cell( + self, + state: SimState, + cell: torch.Tensor, + ) -> None: + """No cell adjustment needed for FixCom.""" + def __repr__(self) -> str: """String representation of the constraint.""" return f"FixCom(system_idx={self.system_idx})" @@ -615,3 +743,494 @@ def validate_constraints(constraints: list[Constraint], state: SimState) -> None UserWarning, stacklevel=3, ) + + +class FixSymmetry(SystemConstraint): + """Constraint to preserve spacegroup symmetry during optimization. + + This constraint symmetrizes forces, positions, and cell/stress + according to the crystal symmetry operations. Each system in a batch can + have different symmetry operations. + + Requires the spglib package to be available for automatic symmetry detection. + + The constraint works by: + - Symmetrizing forces/momenta as rank-1 tensors using all symmetry operations + - Symmetrizing position displacements similarly for position adjustments + - Symmetrizing stress/cell deformation as rank-2 tensors + + Attributes: + rotations: List of rotation matrices for each system, + shape (n_ops, 3, 3) per system. + symm_maps: List of symmetry atom mappings for each system, + shape (n_ops, n_atoms) per system. + do_adjust_positions: Whether to symmetrize position adjustments. + do_adjust_cell: Whether to symmetrize cell/stress adjustments. + + Examples: + Create from SimState: + >>> constraint = FixSymmetry.from_state(state, symprec=0.01) + """ + + # Type annotations + rotations: list[torch.Tensor] + symm_maps: list[torch.Tensor] + do_adjust_positions: bool + do_adjust_cell: bool + + def __init__( + self, + rotations: list[torch.Tensor], + symm_maps: list[torch.Tensor], + system_idx: torch.Tensor | None = None, + *, + adjust_positions: bool = True, + adjust_cell: bool = True, + ) -> None: + """Initialize FixSymmetry constraint. + + Args: + rotations: List of rotation tensors, one per system. + Each tensor has shape (n_ops, 3, 3). + symm_maps: List of symmetry mapping tensors, one per system. + Each tensor has shape (n_ops, n_atoms_in_system). + system_idx: Indices of systems this constraint applies to. + If None, defaults to [0, 1, ..., n_systems-1]. + adjust_positions: Whether to symmetrize position adjustments. + adjust_cell: Whether to symmetrize cell/stress adjustments. + + Raises: + ValueError: If lists have mismatched lengths or system_idx is wrong length. + """ + n_systems = len(rotations) + + # Validate list lengths + if len(symm_maps) != n_systems: + raise ValueError( + "rotations and symm_maps must have the same length. " + f"Got {len(rotations)}, {len(symm_maps)}." + ) + + if system_idx is None: + # Infer device from rotations tensors + device = rotations[0].device if rotations else torch.device("cpu") + system_idx = torch.arange(n_systems, device=device) + + if len(system_idx) != n_systems: + raise ValueError( + f"system_idx length ({len(system_idx)}) must match " + f"number of systems ({n_systems})" + ) + + super().__init__(system_idx=system_idx) + + self.rotations = rotations + self.symm_maps = symm_maps + self.do_adjust_positions = adjust_positions + self.do_adjust_cell = adjust_cell + + @classmethod + def from_state( + cls, + state: SimState, + symprec: float = 0.01, + *, + adjust_positions: bool = True, + adjust_cell: bool = True, + refine_symmetry_state: bool = True, + ) -> Self: + """Create FixSymmetry constraint from a SimState. + + Directly uses tensor data from the state to determine symmetry. + + Warning: + By default, this method **mutates the input state** in-place to refine + the atomic positions and cell vectors to ideal symmetric values. + Set ``refine_symmetry_state=False`` to skip this refinement if you + want to preserve the original state (though this may lead to + symmetry detection issues if the structure is not already ideal). + + Args: + state: SimState containing one or more systems. + symprec: Symmetry precision for spglib. + adjust_positions: Whether to symmetrize position adjustments. + adjust_cell: Whether to symmetrize cell/stress adjustments. + refine_symmetry_state: Whether to refine the state's positions and cell + to ideal symmetric values. When True (default), the input state + is modified in-place. When False, the state is not modified but + the constraint may not work correctly if the structure deviates + from ideal symmetry. + + Returns: + FixSymmetry constraint configured for the state's structures. + """ + try: + import spglib # noqa: F401 + except ImportError: + raise ImportError("spglib is required for FixSymmetry.from_state") from None + + rotations = [] + symm_maps = [] + + # Get atom counts per system for slicing + atoms_per_system = state.n_atoms_per_system + cumsum = torch.cat( + [ + torch.zeros(1, device=state.device, dtype=torch.long), + torch.cumsum(atoms_per_system, dim=0), + ] + ) + + for sys_idx in range(state.n_systems): + start = cumsum[sys_idx].item() + end = cumsum[sys_idx + 1].item() + + # Extract data for this system + cell = state.row_vector_cell[sys_idx] + positions = state.positions[start:end] + atomic_numbers = state.atomic_numbers[start:end] + + if refine_symmetry_state: + # Refine symmetry of the structure first + refined_cell, refined_positions = refine_symmetry( + cell, positions, atomic_numbers, symprec=symprec + ) + + # Apply refined cell and positions back to state + state.cell[sys_idx] = refined_cell.mT # row→column vector convention + state.positions[start:end] = refined_positions + + # Get symmetry operations using refined structure + rots, symm_map = _prep_symmetry( + refined_cell, refined_positions, atomic_numbers, symprec=symprec + ) + else: + # Use structure as-is without refinement + rots, symm_map = _prep_symmetry( + cell, positions, atomic_numbers, symprec=symprec + ) + + rotations.append(rots) + symm_maps.append(symm_map) + + system_idx = torch.arange(state.n_systems, device=state.device) + + return cls( + rotations=rotations, + symm_maps=symm_maps, + system_idx=system_idx, + adjust_positions=adjust_positions, + adjust_cell=adjust_cell, + ) + + @classmethod + def merge( + cls, + constraints: list[Self], + state_indices: list[int], + atom_offsets: torch.Tensor, # noqa: ARG003 + ) -> Self: + """Merge multiple FixSymmetry constraints into one. + + Args: + constraints: List of FixSymmetry constraints to merge. + state_indices: Index of the source state for each constraint. + atom_offsets: Cumulative atom counts (unused for FixSymmetry). + + Returns: + Merged FixSymmetry constraint. + + Raises: + ValueError: If constraints list is empty or if constraints have + mismatched adjust_positions or adjust_cell settings. + """ + if not constraints: + raise ValueError("Cannot merge empty list of constraints") + + # Validate that all constraints have matching settings + first_adjust_positions = constraints[0].do_adjust_positions + first_adjust_cell = constraints[0].do_adjust_cell + + for i, constraint in enumerate(constraints[1:], start=1): + if constraint.do_adjust_positions != first_adjust_positions: + raise ValueError( + f"Cannot merge FixSymmetry constraints with different " + f"adjust_positions settings: constraint 0 has " + f"adjust_positions={first_adjust_positions}, but constraint " + f"{i} has adjust_positions={constraint.do_adjust_positions}" + ) + if constraint.do_adjust_cell != first_adjust_cell: + raise ValueError( + f"Cannot merge FixSymmetry constraints with different " + f"adjust_cell settings: constraint 0 has " + f"adjust_cell={first_adjust_cell}, but constraint " + f"{i} has adjust_cell={constraint.do_adjust_cell}" + ) + + rotations = [] + symm_maps = [] + system_indices = [] + + for constraint, offset in zip(constraints, state_indices, strict=False): + for i in range(len(constraint.rotations)): + rotations.append(constraint.rotations[i]) + symm_maps.append(constraint.symm_maps[i]) + system_indices.append(offset + i) + + device = rotations[0].device + + return cls( + rotations=rotations, + symm_maps=symm_maps, + system_idx=torch.tensor(system_indices, device=device), + adjust_positions=first_adjust_positions, + adjust_cell=first_adjust_cell, + ) + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get number of removed degrees of freedom. + + FixSymmetry doesn't explicitly remove DOF in the same way as FixAtoms. + This matches ASE's FixSymmetry behavior which also raises NotImplementedError. + + Args: + state: Simulation state + + Raises: + NotImplementedError: FixSymmetry does not support DOF counting. + """ + raise NotImplementedError("FixSymmetry does not implement get_removed_dof.") + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Symmetrize position displacements. + + Args: + state: Current simulation state + new_positions: Proposed new positions to be adjusted in-place + """ + if not self.do_adjust_positions: + return + + # Compute displacement from current positions + displacement = new_positions - state.positions + + # Symmetrize the displacement + self._symmetrize_rank1(state, displacement) + + # Apply symmetrized displacement + new_positions[:] = state.positions + displacement + + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Symmetrize forces according to crystal symmetry. + + Args: + state: Current simulation state + forces: Forces to be adjusted in-place + """ + self._symmetrize_rank1(state, forces) + + def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: + """Symmetrize momenta according to crystal symmetry. + + Args: + state: Current simulation state + momenta: Momenta to be adjusted in-place + """ + self._symmetrize_rank1(state, momenta) + + def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: + """Symmetrize cell deformation in-place. + + Computes the deformation gradient as ``(cell_inv @ new_cell).T - I`` + and symmetrizes it as a rank-2 tensor. + + Args: + state: Current simulation state + new_cell: Proposed new cell tensor of shape (n_systems, 3, 3) + in column vector convention, modified in-place. + + Raises: + RuntimeError: If the deformation gradient step is too large (> 0.25), + which can cause incorrect symmetrization. + + Warns: + UserWarning: If the deformation gradient step is large (> 0.15), + symmetrization may be ill-behaved. + """ + if not self.do_adjust_cell: + return + + device = state.device + dtype = state.dtype + identity = torch.eye(3, device=device, dtype=dtype) + + for sys_idx_local, sys_idx_global in enumerate(self.system_idx): + # Get current and new cells in row vector convention + cur_cell = state.row_vector_cell[sys_idx_global] + new_cell_row = new_cell[sys_idx_global].mT + + # Calculate deformation gradient + cur_cell_inv = torch.linalg.inv(cur_cell) + delta_deform_grad = (cur_cell_inv @ new_cell_row).mT - identity + + # Check for large deformation gradient (following ASE) + max_delta = torch.abs(delta_deform_grad).max().item() + if max_delta > 0.25: + raise RuntimeError( + f"FixSymmetry adjust_cell does not work properly with large " + f"deformation gradient step {max_delta:.4f} > 0.25. " + f"Consider using smaller optimization steps." + ) + if max_delta > 0.15: + warnings.warn( + f"FixSymmetry adjust_cell may be ill-behaved with large " + f"deformation gradient step {max_delta:.4f} > 0.15", + UserWarning, + stacklevel=2, + ) + + # Symmetrize deformation gradient directly + symmetrized_delta = symmetrize_rank2( + cur_cell, delta_deform_grad, self.rotations[sys_idx_local].to(dtype=dtype) + ) + + # Reconstruct cell and update in-place + new_cell_row_sym = cur_cell @ (symmetrized_delta + identity).mT + new_cell[sys_idx_global] = new_cell_row_sym.mT # Back to column convention + + def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: + """Symmetrize stress tensor in-place. + + Args: + state: Current simulation state + stress: Stress tensor of shape (n_systems, 3, 3), modified in-place. + """ + dtype = stress.dtype + + for sys_idx_local, sys_idx_global in enumerate(self.system_idx): + # Get current cell and symmetrize stress directly + cur_cell = state.row_vector_cell[sys_idx_global] + sys_stress = stress[sys_idx_global] + symmetrized = symmetrize_rank2( + cur_cell, sys_stress, self.rotations[sys_idx_local].to(dtype=dtype) + ) + stress[sys_idx_global] = symmetrized + + def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: + """Symmetrize rank-1 tensors (forces, momenta, displacements) in-place. + + Uses fractional-coordinate rotations from spglib together with the current + cell to transform vectors. The cell is fetched at runtime to ensure + correctness during variable-cell relaxation. + + Args: + state: Current simulation state (used for cell and atom indexing) + vectors: Tensor of shape (n_atoms, 3) to be symmetrized in-place + """ + # Get atom counts per system + atoms_per_system = state.n_atoms_per_system + cumsum = torch.cat( + [ + torch.zeros(1, device=state.device, dtype=torch.long), + torch.cumsum(atoms_per_system, dim=0), + ] + ) + + dtype = vectors.dtype + for sys_idx_local, sys_idx_global in enumerate(self.system_idx): + start = cumsum[sys_idx_global].item() + end = cumsum[sys_idx_global + 1].item() + + # Extract vectors for this system + sys_vectors = vectors[start:end] + + # Get current cell for this system + cell = state.row_vector_cell[sys_idx_global] + + # Symmetrize directly + symmetrized = symmetrize_rank1( + cell, + sys_vectors, + self.rotations[sys_idx_local].to(dtype=dtype), + self.symm_maps[sys_idx_local], + ) + + # Update in place + vectors[start:end] = symmetrized + + def select_constraint( + self, + atom_mask: torch.Tensor, # noqa: ARG002 + system_mask: torch.Tensor, + ) -> Self | None: + """Select constraint for systems matching the mask. + + Args: + atom_mask: Boolean mask for atoms (not used for SystemConstraint) + system_mask: Boolean mask for systems to keep + + Returns: + New FixSymmetry for selected systems, or None if no systems match. + """ + # Get indices of systems that are in both system_mask and self.system_idx + keep_global_indices = torch.where(system_mask)[0] + mask = torch.isin(self.system_idx, keep_global_indices) + + if not mask.any(): + return None + + new_rotations = [self.rotations[i] for i in range(len(mask)) if mask[i]] + new_symm_maps = [self.symm_maps[i] for i in range(len(mask)) if mask[i]] + + # Remap system indices + new_system_idx = _mask_constraint_indices(self.system_idx[mask], system_mask) + + return type(self)( + rotations=new_rotations, + symm_maps=new_symm_maps, + system_idx=new_system_idx, + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + ) + + def select_sub_constraint( + self, + atom_idx: torch.Tensor, # noqa: ARG002 + sys_idx: int, + ) -> Self | None: + """Select constraint for a single system. + + Args: + atom_idx: Atom indices (not used, kept for interface compatibility) + sys_idx: System index to select + + Returns: + New FixSymmetry for the selected system, or None if not found. + """ + if sys_idx not in self.system_idx: + return None + + local_idx = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() + + return type(self)( + rotations=[self.rotations[local_idx]], + symm_maps=[self.symm_maps[local_idx]], + system_idx=torch.tensor([0], device=self.system_idx.device), + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + ) + + def __repr__(self) -> str: + """String representation of the constraint.""" + n_systems = len(self.rotations) + n_ops_list = [r.shape[0] for r in self.rotations] + if len(n_ops_list) <= 3: + ops_str = str(n_ops_list) + else: + ops_str = f"[{n_ops_list[0]}, ..., {n_ops_list[-1]}]" + return ( + f"FixSymmetry(n_systems={n_systems}, " + f"n_ops={ops_str}, " + f"adjust_positions={self.do_adjust_positions}, " + f"adjust_cell={self.do_adjust_cell})" + ) diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 3ff0cf2de..abb41c0d8 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -110,7 +110,13 @@ def unit_cell_filter_init[T: AnyCellState]( model_output = model(state) # Calculate initial cell forces - stress = model_output["stress"] + stress = model_output["stress"].clone() + + # Apply stress constraints if any exist + if state.constraints: + for constraint in state.constraints: + constraint.adjust_stress(state, stress) + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) virial = _apply_constraints( @@ -162,7 +168,12 @@ def frechet_cell_filter_init[T: AnyCellState]( model_output = model(state) # Calculate initial cell forces using Frechet approach - stress = model_output["stress"] + stress = model_output["stress"].clone() + + # Apply stress constraints if present + for constraint in state.constraints: + constraint.adjust_stress(state, stress) + volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) virial = _apply_constraints( @@ -222,9 +233,10 @@ def unit_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> # Update cell from new positions cell_update = cell_positions_new / cell_factor_expanded - state.row_vector_cell = torch.bmm( - state.reference_cell.mT, cell_update.transpose(-2, -1) - ) + new_cell = torch.bmm(state.reference_cell.mT, cell_update.transpose(-2, -1)) + + # Apply cell constraints (in-place, column vector convention) + state.set_constrained_cell(new_cell.mT.clone()) state.cell_positions = cell_positions_new @@ -249,7 +261,9 @@ def frechet_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) new_row_vector_cell = torch.bmm( state.reference_cell.mT, deform_grad_new.transpose(-2, -1) ) - state.row_vector_cell = new_row_vector_cell + + # Apply cell constraints (in-place, column vector convention) + state.set_constrained_cell(new_row_vector_cell.mT.clone()) state.cell_positions = cell_positions_new @@ -257,7 +271,13 @@ def compute_cell_forces[T: AnyCellState]( model_output: dict[str, torch.Tensor], state: T ) -> None: """Compute cell forces for both unit and frechet methods.""" - stress = model_output["stress"] + stress = model_output["stress"].clone() + + # Apply stress constraints if any exist + if state.constraints: + for constraint in state.constraints: + constraint.adjust_stress(state, stress) + volumes = torch.linalg.det(state.cell).view(state.n_systems, 1, 1) virial = -volumes * (stress + state.pressure) virial = _apply_constraints( diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 5583cc0c6..9b3ce8e3a 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -415,17 +415,19 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 if isinstance(state, CellFireState): # For cell optimization, handle both atomic and cell position updates # This follows the ASE FIRE implementation pattern - # Transform atomic positions to fractional coordinates cur_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.set_constrained_positions( - torch.linalg.solve( - cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) - ).squeeze(-1) - + dr_atom - ) + frac_positions = torch.linalg.solve( + cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) + ).squeeze(-1) + new_frac_positions = frac_positions + dr_atom + # Back to Cartesian coordinates + new_positions = torch.bmm( + new_frac_positions.unsqueeze(1), cur_deform_grad[state.system_idx] + ).squeeze(1) + state.set_constrained_positions(new_positions) # Update cell positions directly based on stored cell filter type if hasattr(state, "cell_filter") and state.cell_filter is not None: @@ -446,21 +448,11 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1) deform_grad_new = cell_positions_new / cell_factor_expanded - # Update cell from deformation gradient - state.row_vector_cell = torch.bmm( - state.reference_cell.mT, deform_grad_new.transpose(-2, -1) - ) + # Compute new cell from deformation gradient + new_col_vector_cell = torch.bmm(state.reference_cell, deform_grad_new) - # Transform positions back to Cartesian - new_deform_grad = cell_filters.deform_grad( - state.reference_cell.mT, state.row_vector_cell - ) - state.set_constrained_positions( - torch.bmm( - state.positions.unsqueeze(1), - new_deform_grad[state.system_idx].transpose(-2, -1), - ).squeeze(1) - ) + # Apply cell constraints and scale positions to preserve fractional coords + state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) else: state.set_constrained_positions(state.positions + dr_atom) diff --git a/torch_sim/state.py b/torch_sim/state.py index e97101d73..be8dff1f9 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -271,6 +271,23 @@ def set_constrained_positions(self, new_positions: torch.Tensor) -> None: constraint.adjust_positions(self, new_positions) self.positions = new_positions + def set_constrained_cell( + self, + new_cell: torch.Tensor, + scale_atoms: bool = False, # noqa: FBT001, FBT002 + ) -> None: + """Set the cell, apply constraints, and optionally scale atomic positions. + + Args: + new_cell: New cell tensor with shape (n_systems, 3, 3) + in column vector convention + scale_atoms: Whether to scale atomic positions to preserve + fractional coordinates. Defaults to False. + """ + for constraint in self.constraints: + constraint.adjust_cell(self, new_cell) + self.set_cell(new_cell, scale_atoms=scale_atoms) + @property def constraints(self) -> list[Constraint]: """Get the constraints for the SimState. diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py new file mode 100644 index 000000000..7b82a8c60 --- /dev/null +++ b/torch_sim/symmetrize.py @@ -0,0 +1,456 @@ +"""Symmetry refinement utilities for crystal structures. + +This module provides functions for refining and symmetrizing atomic structures +using spglib. It is adapted from ASE's spacegroup.symmetrize module but +reimplemented to work with torch tensors directly. + +The main entry point is `refine_symmetry` which symmetrizes both the cell +and atomic positions according to the detected space group symmetry. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from spglib import SpglibDataset + + from torch_sim.state import SimState + + +__all__ = [ + "refine_symmetry", + "build_symmetry_map", + "symmetrize_rank1", + "symmetrize_rank2", + "get_symmetry_datasets", +] + + +def _get_symmetry_dataset( + cell: torch.Tensor, + scaled_positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 1.0e-6, +) -> SpglibDataset | None: + """Get symmetry dataset from spglib. + + Args: + cell: Unit cell as row vectors, shape (3, 3) + scaled_positions: Fractional coordinates, shape (n_atoms, 3) + atomic_numbers: Atomic numbers, shape (n_atoms,) + symprec: Symmetry precision + + Returns: + Symmetry dataset with attribute access + """ + import spglib + + # Convert tensors to numpy for spglib + cell_np = cell.detach().cpu().numpy() + positions_np = scaled_positions.detach().cpu().numpy() + numbers_np = atomic_numbers.detach().cpu().numpy() + + cell_tuple = (cell_np, positions_np, numbers_np) + return spglib.get_symmetry_dataset(cell_tuple, symprec=symprec) + + +def get_symmetry_datasets( + state: SimState, + symprec: float = 1.0e-6, +) -> list[SpglibDataset | None]: + """Get symmetry datasets for all systems in a SimState. + + Args: + state: SimState containing one or more systems + symprec: Symmetry precision for spglib + + Returns: + List of spglib symmetry datasets, one per system in the state. + Returns None for systems where symmetry detection fails. + """ + datasets = [] + + for single_state in state.split(): + cell = single_state.row_vector_cell[0] + positions = single_state.positions + + # Compute scaled (fractional) positions for this system + scaled_positions = _get_scaled_positions(positions, cell) + + dataset = _get_symmetry_dataset( + cell=cell, + scaled_positions=scaled_positions, + atomic_numbers=single_state.atomic_numbers, + symprec=symprec, + ) + datasets.append(dataset) + + return datasets + + +def _get_scaled_positions( + positions: torch.Tensor, + cell: torch.Tensor, +) -> torch.Tensor: + """Convert Cartesian positions to fractional coordinates. + + Args: + positions: Cartesian positions, shape (n_atoms, 3) + cell: Unit cell as row vectors, shape (3, 3) + + Returns: + Fractional coordinates, shape (n_atoms, 3) + """ + inv_cell = torch.linalg.inv(cell) + return positions @ inv_cell + + +def _symmetrize_cell( + cell: torch.Tensor, + dataset: SpglibDataset, +) -> torch.Tensor: + """Symmetrize the cell based on the symmetry dataset. + + Args: + cell: Unit cell as row vectors, shape (3, 3) + dataset: spglib symmetry dataset + + Returns: + Symmetrized cell as row vectors, shape (3, 3) + """ + device = cell.device + dtype = cell.dtype + + # Get standardized cell and apply transformations + std_cell = torch.as_tensor(dataset.std_lattice, dtype=dtype, device=device) + trans_matrix = torch.as_tensor( + dataset.transformation_matrix, dtype=dtype, device=device + ) + rot_matrix = torch.as_tensor( + dataset.std_rotation_matrix, dtype=dtype, device=device + ) + + trans_std_cell = trans_matrix.T @ std_cell + rot_trans_std_cell = trans_std_cell @ rot_matrix + + return rot_trans_std_cell + + +def _symmetrize_positions( + positions: torch.Tensor, + dataset: SpglibDataset, + primitive_cell: tuple, +) -> torch.Tensor: + """Symmetrize atomic positions. + + Args: + positions: Cartesian positions, shape (n_atoms, 3) + dataset: spglib symmetry dataset + primitive_cell: Result from spglib.find_primitive (cell, positions, numbers) + + Returns: + Symmetrized Cartesian positions, shape (n_atoms, 3) + """ + device = positions.device + dtype = positions.dtype + + prim_cell_np, _prim_scaled_pos, _prim_types = primitive_cell + prim_cell = torch.as_tensor(prim_cell_np, dtype=dtype, device=device) + + # Calculate offset between standard cell and actual cell + std_cell = torch.as_tensor(dataset.std_lattice, dtype=dtype, device=device) + rot_matrix = torch.as_tensor( + dataset.std_rotation_matrix, dtype=dtype, device=device + ) + std_positions = torch.as_tensor(dataset.std_positions, dtype=dtype, device=device) + + rot_std_cell = std_cell @ rot_matrix + rot_std_pos = std_positions @ rot_std_cell + + # Get mapping indices + mapping_to_primitive = list(dataset.mapping_to_primitive) + std_mapping_to_primitive = list(dataset.std_mapping_to_primitive) + + dp0 = positions[mapping_to_primitive.index(0)] - rot_std_pos[ + std_mapping_to_primitive.index(0) + ] + + # Create aligned set of standard cell positions + rot_prim_cell = prim_cell @ rot_matrix + inv_rot_prim_cell = torch.linalg.inv(rot_prim_cell) + aligned_std_pos = rot_std_pos + dp0 + + # Find ideal positions + new_positions = positions.clone() + n_atoms = positions.shape[0] + + for i_at in range(n_atoms): + std_i_at = std_mapping_to_primitive.index(mapping_to_primitive[i_at]) + dp = aligned_std_pos[std_i_at] - positions[i_at] + dp_s = dp @ inv_rot_prim_cell + new_positions[i_at] = aligned_std_pos[std_i_at] - torch.round(dp_s) @ rot_prim_cell + + return new_positions + + +def refine_symmetry( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 0.01, + verbose: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Refine symmetry of a structure. + + This function symmetrizes both the cell and atomic positions according + to the detected space group symmetry. + + The refinement process: + 1. Detect symmetry of the input structure + 2. Symmetrize the cell vectors to match the ideal lattice + 3. Symmetrize atomic positions to ideal Wyckoff positions + + Args: + cell: Unit cell as row vectors, shape (3, 3) + positions: Cartesian positions, shape (n_atoms, 3) + atomic_numbers: Atomic numbers, shape (n_atoms,) + symprec: Symmetry precision for spglib + verbose: If True, print symmetry information before and after + + Returns: + Tuple of (symmetrized_cell, symmetrized_positions): + - symmetrized_cell: Symmetrized cell as row vectors, shape (3, 3) + - symmetrized_positions: Symmetrized Cartesian positions, shape (n_atoms, 3) + """ + import spglib + + # Step 1: Check and symmetrize cell + scaled_positions = _get_scaled_positions(positions, cell) + dataset = _get_symmetry_dataset(cell, scaled_positions, atomic_numbers, symprec) + + if dataset is None: + raise RuntimeError("spglib could not determine symmetry for structure") + + if verbose: + print( + f"symmetrize: prec {symprec} got symmetry group number {dataset.number}, " + f"international (Hermann-Mauguin) {dataset.international}, " + f"Hall {dataset.hall}" + ) + + new_cell = _symmetrize_cell(cell, dataset) + + # Scale positions to new cell + new_positions = scaled_positions @ new_cell + + # Step 2: Check and symmetrize positions with the new cell + new_scaled_positions = _get_scaled_positions(new_positions, new_cell) + dataset = _get_symmetry_dataset( + new_cell, new_scaled_positions, atomic_numbers, symprec + ) + + if dataset is None: + raise RuntimeError("spglib could not determine symmetry after cell refinement") + + # Find primitive cell + cell_np = new_cell.detach().cpu().numpy() + positions_np = new_scaled_positions.detach().cpu().numpy() + numbers_np = atomic_numbers.detach().cpu().numpy() + + primitive_result = spglib.find_primitive( + (cell_np, positions_np, numbers_np), symprec=symprec + ) + if primitive_result is None: + raise RuntimeError("spglib could not find primitive cell") + + new_positions = _symmetrize_positions( + new_positions, dataset, primitive_result + ) + + # Final check + if verbose: + final_scaled = _get_scaled_positions(new_positions, new_cell) + final_dataset = _get_symmetry_dataset( + new_cell, final_scaled, atomic_numbers, 1e-4 + ) + if final_dataset is not None: + print( + f"symmetrize: prec 1e-4 got symmetry group number " + f"{final_dataset.number}, " + f"international (Hermann-Mauguin) {final_dataset.international}, " + f"Hall {final_dataset.hall}" + ) + + return new_cell, new_positions + + +def _prep_symmetry( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 1.0e-6, + verbose: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare structure for symmetry-preserving minimization. + + This function determines the symmetry operations and atom mappings + needed for symmetry-constrained optimization. + + Args: + cell: Unit cell as row vectors, shape (3, 3) + positions: Cartesian positions, shape (n_atoms, 3) + atomic_numbers: Atomic numbers, shape (n_atoms,) + symprec: Symmetry precision for spglib + verbose: If True, print symmetry information + + Returns: + Tuple of (rotations, symm_map): + - rotations: Rotation matrices, shape (n_ops, 3, 3) + - symm_map: Atom mapping tensor, shape (n_ops, n_atoms) + """ + device = cell.device + dtype = cell.dtype + + scaled_positions = _get_scaled_positions(positions, cell) + dataset = _get_symmetry_dataset(cell, scaled_positions, atomic_numbers, symprec) + + if dataset is None: + raise RuntimeError("spglib could not determine symmetry for structure") + + if verbose: + print( + f"symmetrize: prec {symprec} got symmetry group number {dataset.number}, " + f"international (Hermann-Mauguin) {dataset.international}, " + f"Hall {dataset.hall}" + ) + + rotations = torch.as_tensor(dataset.rotations.copy(), dtype=dtype, device=device) + translations = torch.as_tensor( + dataset.translations.copy(), dtype=dtype, device=device + ) + + # Build symmetry mapping + symm_map = build_symmetry_map(rotations, translations, scaled_positions) + + return rotations, symm_map + + +def build_symmetry_map( + rotations: torch.Tensor, + translations: torch.Tensor, + scaled_positions: torch.Tensor, +) -> torch.Tensor: + """Build symmetry atom mapping for each symmetry operation. + + For each symmetry operation, determines which atom each atom maps to. + + Args: + rotations: Rotation matrices, shape (n_ops, 3, 3) + translations: Translation vectors, shape (n_ops, 3) + scaled_positions: Fractional coordinates, shape (n_atoms, 3) + + Returns: + Symmetry mapping tensor, shape (n_ops, n_atoms) + """ + # Transform all atoms by all symmetry operations at once + # new_pos: (n_ops, n_atoms, 3) + new_pos = torch.einsum("oij,nj->oni", rotations, scaled_positions) + translations[:, None, :] + + # Compute wrapped deltas to account for periodicity + # delta: (n_ops, n_atoms, n_atoms, 3) + delta = scaled_positions[None, None, :, :] - new_pos[:, :, None, :] + delta -= delta.round() # wrap into [-0.5, 0.5] + + # Distances to all candidate atoms, then choose nearest + distances = torch.linalg.norm(delta, dim=-1) # (n_ops, n_atoms, n_atoms) + symm_map = torch.argmin(distances, dim=-1).to(dtype=torch.long) # (n_ops, n_atoms) + + return symm_map + + +def symmetrize_rank1( + lattice: torch.Tensor, + forces: torch.Tensor, + rotations: torch.Tensor, + symm_map: torch.Tensor, +) -> torch.Tensor: + """Symmetrize rank-1 tensor (forces, velocities, etc). + + Args: + lattice: Cell vectors as row vectors, shape (3, 3) + forces: Forces array, shape (n_atoms, 3) + rotations: Rotation matrices, shape (n_ops, 3, 3) + symm_map: Atom mapping for each symmetry operation, shape (n_ops, n_atoms) + + Returns: + Symmetrized forces, shape (n_atoms, 3) + """ + n_ops = rotations.shape[0] + n_atoms = forces.shape[0] + + # Transform to scaled (fractional) coordinates: (n_atoms, 3) + scaled_forces = forces @ lattice.inverse() + + # Apply all rotations at once: (n_ops, n_atoms, 3) + # rotations: (n_ops, 3, 3), scaled_forces: (n_atoms, 3) + # For each op: scaled_forces @ rot.T (rotate the vectors) + # Note: we use rotations.mT to get the transpose of each rotation matrix + transformed_forces = torch.einsum("ij,nkj->nik", scaled_forces, rotations) + + # Flatten for scatter: (n_ops * n_atoms, 3) + transformed_flat = transformed_forces.reshape(-1, 3) + + # Flatten symm_map to get target indices: (n_ops * n_atoms,) + target_indices = symm_map.reshape(-1) + + # Expand target indices to match 3D coordinates: (n_ops * n_atoms, 3) + target_indices_expanded = target_indices.unsqueeze(-1).expand(-1, 3) + + # Scatter add to accumulate forces at target atoms + # Result shape: (n_atoms, 3) + accumulated = torch.zeros(n_atoms, 3, dtype=forces.dtype, device=forces.device) + accumulated.scatter_add_(0, target_indices_expanded, transformed_flat) + + # Average over symmetry operations + symmetrized_scaled = accumulated / n_ops + + # Transform back to Cartesian + symmetrized_forces = symmetrized_scaled @ lattice + + return symmetrized_forces + + +def symmetrize_rank2( + lattice: torch.Tensor, + stress: torch.Tensor, + rotations: torch.Tensor, +) -> torch.Tensor: + """Symmetrize rank-2 tensor (stress, strain, etc). + + Args: + lattice: Cell vectors as row vectors, shape (3, 3) + stress: Stress tensor, shape (3, 3) + rotations: Rotation matrices, shape (n_ops, 3, 3) + + Returns: + Symmetrized stress tensor, shape (3, 3) + """ + n_ops = rotations.shape[0] + inv_lattice = lattice.inverse() + + # Scale stress: lattice @ stress @ lattice.T + scaled_stress = lattice @ stress @ lattice.T + + # Symmetrize in scaled coordinates using vectorized operations + # r.T @ scaled_stress @ r for all rotations at once + # For r.T @ A @ r: result[i,l] = sum_j,k r[j,i] * A[j,k] * r[k,l] + # With batched rotations: einsum "nji,jk,nkl->il" + symmetrized_scaled_stress = torch.einsum( + "nji,jk,nkl->il", rotations, scaled_stress, rotations + ) / n_ops + + # Transform back: inv_lattice @ symmetrized_scaled_stress @ inv_lattice.T + return inv_lattice @ symmetrized_scaled_stress @ inv_lattice.T From 3d4880ee422b485cec1ded6cd748bcddc309e787 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Tue, 3 Feb 2026 17:07:21 +0000 Subject: [PATCH 02/16] style --- tests/test_fix_symmetry.py | 113 +++++++++++++++++++++---------------- torch_sim/symmetrize.py | 39 ++++++------- 2 files changed, 83 insertions(+), 69 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 7b9c4458f..994018b44 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -18,6 +18,7 @@ from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.symmetrize import get_symmetry_datasets + # Skip all tests if spglib is not available spglib = pytest.importorskip("spglib") from spglib import SpglibDataset @@ -40,11 +41,11 @@ class OptimizationResult(TypedDict): # Expected space groups for each structure type SPACEGROUPS = { - "fcc": 225, # Fm-3m - "hcp": 194, # P6_3/mmc + "fcc": 225, # Fm-3m + "hcp": 194, # P6_3/mmc "diamond": 227, # Fd-3m - "bcc": 229, # Im-3m - "p6bar": 174, # P-6 (low symmetry) + "bcc": 229, # Im-3m + "p6bar": 174, # P-6 (low symmetry) } # Default maximum optimization steps for tests @@ -134,6 +135,7 @@ def model() -> LennardJonesModel: dtype=DTYPE, ) + @pytest.fixture def noisy_lj_model(model: LennardJonesModel): """Create a LJ model that adds noise to forces/stress (like ASE's NoisyLennardJones).""" @@ -145,10 +147,11 @@ def __init__(self, model, rng_seed: int = 1, noise_scale: float = 1e-4): self.model = model self.rng = np.random.RandomState(rng_seed) self.noise_scale = noise_scale - + @property def device(self): return self.model.device + @property def dtype(self): return self.model.dtype @@ -168,6 +171,7 @@ def __call__(self, state): noise, dtype=results["stress"].dtype, device=results["stress"].device ) return results + return NoisyModelWrapper(model) @@ -176,7 +180,9 @@ def __call__(self, state): # ============================================================================= -def get_symmetry_dataset_from_atoms(atoms: Atoms, symprec: float = SYMPREC) -> SpglibDataset: +def get_symmetry_dataset_from_atoms( + atoms: Atoms, symprec: float = SYMPREC +) -> SpglibDataset: """Get full symmetry dataset for an ASE Atoms object using spglib directly.""" return spglib.get_symmetry_dataset( (atoms.cell[:], atoms.get_scaled_positions(), atoms.numbers), @@ -281,8 +287,10 @@ def test_from_state_batched(self): def test_p1_identity_only(self): """Test P1 (no symmetry) has only identity and doesn't change forces/stress.""" atoms = Atoms( - "SiGe", positions=[[0.1, 0.2, 0.3], [1.1, 0.9, 1.3]], - cell=[[3.0, 0.1, 0.2], [0.15, 3.5, 0.1], [0.2, 0.15, 4.0]], pbc=True + "SiGe", + positions=[[0.1, 0.2, 0.3], [1.1, 0.9, 1.3]], + cell=[[3.0, 0.1, 0.2], [0.15, 3.5, 0.1], [0.2, 0.15, 4.0]], + pbc=True, ) state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) @@ -315,17 +323,19 @@ def test_symmetry_datasets_match_spglib(self): # Compare each with direct spglib call (covers both single and batched) for i, atoms in enumerate(atoms_list): spglib_dataset = get_symmetry_dataset_from_atoms(atoms, SYMPREC) - + # Compare key fields assert ts_datasets[i].number == spglib_dataset.number, ( - f"Space group mismatch for {atoms_list[i].get_chemical_formula()}: " + f"Space group mismatch for {atoms.get_chemical_formula()}: " f"{ts_datasets[i].number} vs {spglib_dataset.number}" ) assert ts_datasets[i].international == spglib_dataset.international assert ts_datasets[i].hall == spglib_dataset.hall assert len(ts_datasets[i].rotations) == len(spglib_dataset.rotations) assert np.allclose(ts_datasets[i].rotations, spglib_dataset.rotations) - assert np.allclose(ts_datasets[i].translations, spglib_dataset.translations, atol=1e-10) + assert np.allclose( + ts_datasets[i].translations, spglib_dataset.translations, atol=1e-10 + ) class TestFixSymmetryComparisonWithASE: @@ -389,11 +399,7 @@ def test_stress_symmetrization_matches_ase(self): ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) # Create asymmetric but symmetric (as a matrix) stress tensor - stress_3x3 = np.array([ - [10.0, 1.0, 0.5], - [1.0, 8.0, 0.3], - [0.5, 0.3, 6.0] - ]) + stress_3x3 = np.array([[10.0, 1.0, 0.5], [1.0, 8.0, 0.3], [0.5, 0.3, 6.0]]) # ASE uses Voigt notation stress_voigt = full_3x3_to_voigt_6_stress(stress_3x3) @@ -434,7 +440,8 @@ def test_cell_deformation_symmetrization_matches_ase(self): # TorchSim - need column vector convention for adjust_cell new_cell_ts = torch.tensor( - [deformed_cell.copy().T], dtype=DTYPE # Transpose for column vectors + [deformed_cell.copy().T], + dtype=DTYPE, # Transpose for column vectors ) ts_constraint.adjust_cell(state, new_cell_ts) ts_result = new_cell_ts[0].mT.numpy() # Back to row vectors @@ -454,9 +461,7 @@ class TestFixSymmetryMergeAndSelect: def test_merge_two_constraints(self): """Test merging two FixSymmetry constraints.""" - state1 = ts.io.atoms_to_state( - make_structure("fcc"), torch.device("cpu"), DTYPE - ) + state1 = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) state2 = ts.io.atoms_to_state( make_structure("diamond"), torch.device("cpu"), DTYPE ) @@ -470,15 +475,15 @@ def test_merge_two_constraints(self): assert merged.system_idx.tolist() == [0, 1] @pytest.mark.parametrize("mismatch_field", ["adjust_positions", "adjust_cell"]) - def test_merge_mismatched_settings_raises(self, mismatch_field: Literal['adjust_positions'] | Literal['adjust_cell']): + def test_merge_mismatched_settings_raises( + self, mismatch_field: Literal["adjust_positions", "adjust_cell"] + ): """Test that merging constraints with different settings raises ValueError.""" - state1 = ts.io.atoms_to_state( - make_structure("fcc"), torch.device("cpu"), DTYPE - ) + state1 = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) state2 = ts.io.atoms_to_state( make_structure("diamond"), torch.device("cpu"), DTYPE ) - + kwargs1 = {mismatch_field: True} kwargs2 = {mismatch_field: False} c1 = FixSymmetry.from_state(state1, symprec=SYMPREC, **kwargs1) @@ -497,7 +502,9 @@ def test_select_constraint_single_system(self): constraint = FixSymmetry.from_state(state, symprec=SYMPREC) # Create masks to select only first system - atom_mask = torch.tensor([True, False, False], dtype=torch.bool) # 1 Cu + 2 Si atoms + atom_mask = torch.tensor( + [True, False, False], dtype=torch.bool + ) # 1 Cu + 2 Si atoms system_mask = torch.tensor([True, False], dtype=torch.bool) selected = constraint.select_constraint(atom_mask, system_mask) @@ -544,7 +551,11 @@ class TestFixSymmetryWithOptimization: [(True, True), (True, False), (False, True), (False, False)], ) def test_distorted_structure_preserves_symmetry( - self, noisy_lj_model, structure_name: str, adjust_positions: bool, adjust_cell: bool + self, + noisy_lj_model, + structure_name: str, + adjust_positions: bool, + adjust_cell: bool, ): """Test that a distorted structure relaxes while preserving symmetry. @@ -572,8 +583,12 @@ def test_distorted_structure_preserves_symmetry( state.positions = state.positions * scale_factor result = run_optimization_check_symmetry( - state, noisy_lj_model, constraint=constraint, adjust_cell=adjust_cell, - max_steps=MAX_STEPS, force_tol=0.01 # Looser tolerance to ensure movement + state, + noisy_lj_model, + constraint=constraint, + adjust_cell=adjust_cell, + max_steps=MAX_STEPS, + force_tol=0.01, # Looser tolerance to ensure movement ) assert result["final_spacegroups"][0] == expected_spacegroup, ( @@ -583,11 +598,11 @@ def test_distorted_structure_preserves_symmetry( ) @pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) - def test_cell_filter_preserves_symmetry(self, model: LennardJonesModel, cell_filter: ts.CellFilter | ts.CellFilter): + def test_cell_filter_preserves_symmetry( + self, model: LennardJonesModel, cell_filter: ts.CellFilter + ): """Test that cell filters with FixSymmetry preserve symmetry.""" - state = ts.io.atoms_to_state( - make_structure("fcc"), torch.device("cpu"), DTYPE - ) + state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) state.constraints = [constraint] @@ -606,7 +621,9 @@ def test_cell_filter_preserves_symmetry(self, model: LennardJonesModel, cell_fil assert initial_datasets[0].number == final_datasets[0].number @pytest.mark.parametrize("rotated", [False, True]) - def test_noisy_model_loses_symmetry_without_constraint(self, noisy_lj_model, rotated: bool): + def test_noisy_model_loses_symmetry_without_constraint( + self, noisy_lj_model, rotated: bool + ): """Test that WITHOUT FixSymmetry, optimization with noisy forces loses symmetry. This is a negative control - verifies that noisy forces will break symmetry @@ -628,7 +645,9 @@ def test_noisy_model_loses_symmetry_without_constraint(self, noisy_lj_model, rot ) @pytest.mark.parametrize("rotated", [False, True]) - def test_noisy_model_preserves_symmetry_with_constraint(self, noisy_lj_model, rotated: bool): + def test_noisy_model_preserves_symmetry_with_constraint( + self, noisy_lj_model, rotated: bool + ): """Test that WITH FixSymmetry, optimization with noisy forces preserves symmetry. Mirrors ASE's test_sym_adj_cell. @@ -637,7 +656,10 @@ def test_noisy_model_preserves_symmetry_with_constraint(self, noisy_lj_model, ro state = ts.io.atoms_to_state(bcc_atoms, torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) result = run_optimization_check_symmetry( - state, noisy_lj_model, constraint=constraint, max_steps=MAX_STEPS, + state, + noisy_lj_model, + constraint=constraint, + max_steps=MAX_STEPS, ) assert result["initial_spacegroups"][0] == 229 @@ -647,15 +669,12 @@ def test_noisy_model_preserves_symmetry_with_constraint(self, noisy_lj_model, ro ) - class TestFixSymmetryEdgeCases: """Tests for edge cases and error handling.""" def test_get_removed_dof_raises(self): """Test that get_removed_dof raises NotImplementedError.""" - state = ts.io.atoms_to_state( - make_structure("fcc"), torch.device("cpu"), DTYPE - ) + state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) with pytest.raises(NotImplementedError, match="get_removed_dof"): @@ -663,9 +682,7 @@ def test_get_removed_dof_raises(self): def test_large_deformation_gradient_raises(self): """Test that large deformation gradient raises RuntimeError.""" - state = ts.io.atoms_to_state( - make_structure("fcc"), torch.device("cpu"), DTYPE - ) + state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) # Create a very large deformation (> 0.25) @@ -678,9 +695,7 @@ def test_large_deformation_gradient_raises(self): def test_medium_deformation_gradient_warns(self): """Test that medium deformation gradient emits warning.""" - state = ts.io.atoms_to_state( - make_structure("fcc"), torch.device("cpu"), DTYPE - ) + state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) # Create a medium deformation (> 0.15 but < 0.25) @@ -698,9 +713,7 @@ def test_from_state_refine_symmetry(self, refine_symmetry_state: bool): perturbed = atoms.copy() perturbed.positions += np.random.randn(*perturbed.positions.shape) * 0.001 - state = ts.io.atoms_to_state( - perturbed, torch.device("cpu"), DTYPE - ) + state = ts.io.atoms_to_state(perturbed, torch.device("cpu"), DTYPE) original_positions = state.positions.clone() original_cell = state.cell.clone() @@ -715,4 +728,4 @@ def test_from_state_refine_symmetry(self, refine_symmetry_state: bool): else: # State may be modified (positions refined to ideal) # We just check the function runs without error - assert state.positions.shape == original_positions.shape \ No newline at end of file + assert state.positions.shape == original_positions.shape diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 7b82a8c60..eed506b94 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -14,6 +14,7 @@ import torch + if TYPE_CHECKING: from spglib import SpglibDataset @@ -21,11 +22,11 @@ __all__ = [ - "refine_symmetry", "build_symmetry_map", + "get_symmetry_datasets", + "refine_symmetry", "symmetrize_rank1", "symmetrize_rank2", - "get_symmetry_datasets", ] @@ -129,9 +130,7 @@ def _symmetrize_cell( trans_matrix = torch.as_tensor( dataset.transformation_matrix, dtype=dtype, device=device ) - rot_matrix = torch.as_tensor( - dataset.std_rotation_matrix, dtype=dtype, device=device - ) + rot_matrix = torch.as_tensor(dataset.std_rotation_matrix, dtype=dtype, device=device) trans_std_cell = trans_matrix.T @ std_cell rot_trans_std_cell = trans_std_cell @ rot_matrix @@ -162,9 +161,7 @@ def _symmetrize_positions( # Calculate offset between standard cell and actual cell std_cell = torch.as_tensor(dataset.std_lattice, dtype=dtype, device=device) - rot_matrix = torch.as_tensor( - dataset.std_rotation_matrix, dtype=dtype, device=device - ) + rot_matrix = torch.as_tensor(dataset.std_rotation_matrix, dtype=dtype, device=device) std_positions = torch.as_tensor(dataset.std_positions, dtype=dtype, device=device) rot_std_cell = std_cell @ rot_matrix @@ -174,9 +171,10 @@ def _symmetrize_positions( mapping_to_primitive = list(dataset.mapping_to_primitive) std_mapping_to_primitive = list(dataset.std_mapping_to_primitive) - dp0 = positions[mapping_to_primitive.index(0)] - rot_std_pos[ - std_mapping_to_primitive.index(0) - ] + dp0 = ( + positions[mapping_to_primitive.index(0)] + - rot_std_pos[std_mapping_to_primitive.index(0)] + ) # Create aligned set of standard cell positions rot_prim_cell = prim_cell @ rot_matrix @@ -191,7 +189,9 @@ def _symmetrize_positions( std_i_at = std_mapping_to_primitive.index(mapping_to_primitive[i_at]) dp = aligned_std_pos[std_i_at] - positions[i_at] dp_s = dp @ inv_rot_prim_cell - new_positions[i_at] = aligned_std_pos[std_i_at] - torch.round(dp_s) @ rot_prim_cell + new_positions[i_at] = ( + aligned_std_pos[std_i_at] - torch.round(dp_s) @ rot_prim_cell + ) return new_positions @@ -266,9 +266,7 @@ def refine_symmetry( if primitive_result is None: raise RuntimeError("spglib could not find primitive cell") - new_positions = _symmetrize_positions( - new_positions, dataset, primitive_result - ) + new_positions = _symmetrize_positions(new_positions, dataset, primitive_result) # Final check if verbose: @@ -357,7 +355,10 @@ def build_symmetry_map( """ # Transform all atoms by all symmetry operations at once # new_pos: (n_ops, n_atoms, 3) - new_pos = torch.einsum("oij,nj->oni", rotations, scaled_positions) + translations[:, None, :] + new_pos = ( + torch.einsum("oij,nj->oni", rotations, scaled_positions) + + translations[:, None, :] + ) # Compute wrapped deltas to account for periodicity # delta: (n_ops, n_atoms, n_atoms, 3) @@ -448,9 +449,9 @@ def symmetrize_rank2( # r.T @ scaled_stress @ r for all rotations at once # For r.T @ A @ r: result[i,l] = sum_j,k r[j,i] * A[j,k] * r[k,l] # With batched rotations: einsum "nji,jk,nkl->il" - symmetrized_scaled_stress = torch.einsum( - "nji,jk,nkl->il", rotations, scaled_stress, rotations - ) / n_ops + symmetrized_scaled_stress = ( + torch.einsum("nji,jk,nkl->il", rotations, scaled_stress, rotations) / n_ops + ) # Transform back: inv_lattice @ symmetrized_scaled_stress @ inv_lattice.T return inv_lattice @ symmetrized_scaled_stress @ inv_lattice.T From 0d08bbadaf2708a4c6272f2ddad314c0c1806df6 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 4 Feb 2026 13:21:28 +0000 Subject: [PATCH 03/16] style --- tests/test_fix_symmetry.py | 103 +++++++++++++++++++++---------------- torch_sim/symmetrize.py | 51 ++++++++++-------- 2 files changed, 88 insertions(+), 66 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 994018b44..d74105cf3 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -21,7 +21,7 @@ # Skip all tests if spglib is not available spglib = pytest.importorskip("spglib") -from spglib import SpglibDataset +from spglib import SpglibDataset # noqa: E402 class OptimizationResult(TypedDict): @@ -136,42 +136,54 @@ def model() -> LennardJonesModel: ) +class NoisyModelWrapper: + """Wrapper that adds noise to forces and stress from an underlying model.""" + + def __init__( + self, + model: LennardJonesModel, + rng_seed: int = 1, + noise_scale: float = 1e-4, + ) -> None: + self.model = model + self.rng = np.random.default_rng(rng_seed) + self.noise_scale = noise_scale + + @property + def device(self) -> torch.device: + return self.model.device + + @property + def dtype(self) -> torch.dtype: + return self.model.dtype + + def __call__(self, state: ts.SimState) -> dict[str, torch.Tensor]: + results = self.model(state) + # Add noise to forces + if "forces" in results: + noise = self.rng.normal(size=results["forces"].shape) + results["forces"] = results["forces"] + self.noise_scale * torch.tensor( + noise, + dtype=results["forces"].dtype, + device=results["forces"].device, + ) + # Add noise to stress + if "stress" in results: + noise = self.rng.normal(size=results["stress"].shape) + results["stress"] = results["stress"] + self.noise_scale * torch.tensor( + noise, + dtype=results["stress"].dtype, + device=results["stress"].device, + ) + return results + + @pytest.fixture -def noisy_lj_model(model: LennardJonesModel): - """Create a LJ model that adds noise to forces/stress (like ASE's NoisyLennardJones).""" - - class NoisyModelWrapper: - """Wrapper that adds noise to forces and stress from an underlying model.""" - - def __init__(self, model, rng_seed: int = 1, noise_scale: float = 1e-4): - self.model = model - self.rng = np.random.RandomState(rng_seed) - self.noise_scale = noise_scale - - @property - def device(self): - return self.model.device - - @property - def dtype(self): - return self.model.dtype - - def __call__(self, state): - results = self.model(state) - # Add noise to forces - if "forces" in results: - noise = self.rng.normal(size=results["forces"].shape) - results["forces"] = results["forces"] + self.noise_scale * torch.tensor( - noise, dtype=results["forces"].dtype, device=results["forces"].device - ) - # Add noise to stress - if "stress" in results: - noise = self.rng.normal(size=results["stress"].shape) - results["stress"] = results["stress"] + self.noise_scale * torch.tensor( - noise, dtype=results["stress"].dtype, device=results["stress"].device - ) - return results +def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper: + """Create a LJ model that adds noise to forces/stress. + Similar to ASE's NoisyLennardJones. + """ return NoisyModelWrapper(model) @@ -194,6 +206,7 @@ def run_optimization_check_symmetry( state: ts.SimState, model: LennardJonesModel, constraint: FixSymmetry | None = None, + *, adjust_cell: bool = True, symprec: float = SYMPREC, max_steps: int = MAX_STEPS, @@ -374,8 +387,8 @@ def test_force_symmetrization_matches_ase(self): ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) # Create random test forces - np.random.seed(42) - forces_np = np.random.randn(len(atoms), 3) + rng = np.random.default_rng(42) + forces_np = rng.standard_normal((len(atoms), 3)) forces_ts = torch.tensor(forces_np.copy(), dtype=DTYPE) # Symmetrize with both @@ -457,7 +470,7 @@ def test_cell_deformation_symmetrization_matches_ase(self): class TestFixSymmetryMergeAndSelect: - """Tests for FixSymmetry.merge, select_constraint, and select_sub_constraint methods.""" + """Tests for FixSymmetry.merge, select_constraint, select_sub_constraint.""" def test_merge_two_constraints(self): """Test merging two FixSymmetry constraints.""" @@ -547,13 +560,14 @@ class TestFixSymmetryWithOptimization: @pytest.mark.parametrize("structure_name", ["fcc", "hcp", "diamond", "p6bar"]) @pytest.mark.parametrize( - "adjust_positions,adjust_cell", + ("adjust_positions", "adjust_cell"), [(True, True), (True, False), (False, True), (False, False)], ) def test_distorted_structure_preserves_symmetry( self, - noisy_lj_model, + noisy_lj_model: NoisyModelWrapper, structure_name: str, + *, adjust_positions: bool, adjust_cell: bool, ): @@ -622,7 +636,7 @@ def test_cell_filter_preserves_symmetry( @pytest.mark.parametrize("rotated", [False, True]) def test_noisy_model_loses_symmetry_without_constraint( - self, noisy_lj_model, rotated: bool + self, noisy_lj_model: NoisyModelWrapper, *, rotated: bool ): """Test that WITHOUT FixSymmetry, optimization with noisy forces loses symmetry. @@ -646,7 +660,7 @@ def test_noisy_model_loses_symmetry_without_constraint( @pytest.mark.parametrize("rotated", [False, True]) def test_noisy_model_preserves_symmetry_with_constraint( - self, noisy_lj_model, rotated: bool + self, noisy_lj_model: NoisyModelWrapper, *, rotated: bool ): """Test that WITH FixSymmetry, optimization with noisy forces preserves symmetry. @@ -706,12 +720,13 @@ def test_medium_deformation_gradient_warns(self): constraint.adjust_cell(state, new_cell_col) @pytest.mark.parametrize("refine_symmetry_state", [True, False]) - def test_from_state_refine_symmetry(self, refine_symmetry_state: bool): + def test_from_state_refine_symmetry(self, *, refine_symmetry_state: bool): """Test from_state with different refine_symmetry_state settings.""" atoms = make_structure("fcc") # Add small perturbation perturbed = atoms.copy() - perturbed.positions += np.random.randn(*perturbed.positions.shape) * 0.001 + rng = np.random.default_rng(42) + perturbed.positions += rng.standard_normal(perturbed.positions.shape) * 0.001 state = ts.io.atoms_to_state(perturbed, torch.device("cpu"), DTYPE) original_positions = state.positions.clone() diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index eed506b94..3d5bf6fb0 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -10,11 +10,15 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING import torch +logger = logging.getLogger(__name__) + + if TYPE_CHECKING: from spglib import SpglibDataset @@ -133,9 +137,7 @@ def _symmetrize_cell( rot_matrix = torch.as_tensor(dataset.std_rotation_matrix, dtype=dtype, device=device) trans_std_cell = trans_matrix.T @ std_cell - rot_trans_std_cell = trans_std_cell @ rot_matrix - - return rot_trans_std_cell + return trans_std_cell @ rot_matrix def _symmetrize_positions( @@ -201,6 +203,7 @@ def refine_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 0.01, + *, verbose: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Refine symmetry of a structure. @@ -235,10 +238,13 @@ def refine_symmetry( raise RuntimeError("spglib could not determine symmetry for structure") if verbose: - print( - f"symmetrize: prec {symprec} got symmetry group number {dataset.number}, " - f"international (Hermann-Mauguin) {dataset.international}, " - f"Hall {dataset.hall}" + logger.info( + "symmetrize: prec %s got symmetry group number %s, " + "international (Hermann-Mauguin) %s, Hall %s", + symprec, + dataset.number, + dataset.international, + dataset.hall, ) new_cell = _symmetrize_cell(cell, dataset) @@ -275,11 +281,12 @@ def refine_symmetry( new_cell, final_scaled, atomic_numbers, 1e-4 ) if final_dataset is not None: - print( - f"symmetrize: prec 1e-4 got symmetry group number " - f"{final_dataset.number}, " - f"international (Hermann-Mauguin) {final_dataset.international}, " - f"Hall {final_dataset.hall}" + logger.info( + "symmetrize: prec 1e-4 got symmetry group number %s, " + "international (Hermann-Mauguin) %s, Hall %s", + final_dataset.number, + final_dataset.international, + final_dataset.hall, ) return new_cell, new_positions @@ -290,6 +297,7 @@ def _prep_symmetry( positions: torch.Tensor, atomic_numbers: torch.Tensor, symprec: float = 1.0e-6, + *, verbose: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Prepare structure for symmetry-preserving minimization. @@ -319,10 +327,13 @@ def _prep_symmetry( raise RuntimeError("spglib could not determine symmetry for structure") if verbose: - print( - f"symmetrize: prec {symprec} got symmetry group number {dataset.number}, " - f"international (Hermann-Mauguin) {dataset.international}, " - f"Hall {dataset.hall}" + logger.info( + "symmetrize: prec %s got symmetry group number %s, " + "international (Hermann-Mauguin) %s, Hall %s", + symprec, + dataset.number, + dataset.international, + dataset.hall, ) rotations = torch.as_tensor(dataset.rotations.copy(), dtype=dtype, device=device) @@ -367,9 +378,7 @@ def build_symmetry_map( # Distances to all candidate atoms, then choose nearest distances = torch.linalg.norm(delta, dim=-1) # (n_ops, n_atoms, n_atoms) - symm_map = torch.argmin(distances, dim=-1).to(dtype=torch.long) # (n_ops, n_atoms) - - return symm_map + return torch.argmin(distances, dim=-1).to(dtype=torch.long) # (n_ops, n_atoms) def symmetrize_rank1( @@ -419,9 +428,7 @@ def symmetrize_rank1( symmetrized_scaled = accumulated / n_ops # Transform back to Cartesian - symmetrized_forces = symmetrized_scaled @ lattice - - return symmetrized_forces + return symmetrized_scaled @ lattice def symmetrize_rank2( From 29472af40bad65016655eb9578800fc315d3cdc4 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 4 Feb 2026 13:21:34 +0000 Subject: [PATCH 04/16] style --- tests/test_fix_symmetry.py | 1 - torch_sim/symmetrize.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index d74105cf3..20a16986a 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -18,7 +18,6 @@ from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.symmetrize import get_symmetry_datasets - # Skip all tests if spglib is not available spglib = pytest.importorskip("spglib") from spglib import SpglibDataset # noqa: E402 diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 3d5bf6fb0..31eba9ea4 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -15,7 +15,6 @@ import torch - logger = logging.getLogger(__name__) From 5d7b5b0e788754a16cf916840a2e753e81016504 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 4 Feb 2026 13:44:59 +0000 Subject: [PATCH 05/16] fix --- tests/test_fix_symmetry.py | 1 + torch_sim/optimizers/fire.py | 21 +++++++++++++++------ torch_sim/symmetrize.py | 1 + 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 20a16986a..d74105cf3 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -18,6 +18,7 @@ from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.symmetrize import get_symmetry_datasets + # Skip all tests if spglib is not available spglib = pytest.importorskip("spglib") from spglib import SpglibDataset # noqa: E402 diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 07287cb23..778fbaaea 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -412,12 +412,8 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) ).squeeze(-1) + # Store fractional positions (will transform to Cartesian after cell update) new_frac_positions = frac_positions + dr_atom - # Back to Cartesian coordinates - new_positions = torch.bmm( - new_frac_positions.unsqueeze(1), cur_deform_grad[state.system_idx] - ).squeeze(1) - state.set_constrained_positions(new_positions) # Update cell positions directly based on stored cell filter type if hasattr(state, "cell_filter") and state.cell_filter is not None: @@ -441,8 +437,21 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 # Compute new cell from deformation gradient new_col_vector_cell = torch.bmm(state.reference_cell, deform_grad_new) - # Apply cell constraints and scale positions to preserve fractional coords + # Apply cell constraints and scale positions to new cell coordinates + # (needed for correct displacement calculation in position constraints) state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) + + # Transform fractional positions to Cartesian using NEW deformation gradient + new_deform_grad = cell_filters.deform_grad( + state.reference_cell.mT, state.row_vector_cell + ) + + state.set_constrained_positions( + torch.bmm( + new_frac_positions.unsqueeze(1), + new_deform_grad[state.system_idx].transpose(-2, -1), + ).squeeze(1) + ) else: state.set_constrained_positions(state.positions + dr_atom) diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 31eba9ea4..3d5bf6fb0 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -15,6 +15,7 @@ import torch + logger = logging.getLogger(__name__) From 22f019dcfa199b88df970aa7cdc227ad97f2833d Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 4 Feb 2026 14:42:55 +0000 Subject: [PATCH 06/16] bump spglib --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fa25eef23..a3b27d1b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,10 +44,10 @@ test = [ "pymatgen>=2025.6.14", "pytest-cov>=6", "pytest>=8", - "spglib>=2.5", + "spglib>=2.6", ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] -symmetry = ["spglib>=2.5"] +symmetry = ["spglib>=2.6"] mace = ["mace-torch>=0.3.14"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] From fb6e234ccbad033ee71bcb7899703db5fff809f9 Mon Sep 17 00:00:00 2001 From: Daniel Zuegner Date: Wed, 4 Feb 2026 15:02:54 +0000 Subject: [PATCH 07/16] fix --- torch_sim/optimizers/fire.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 778fbaaea..00503a7b9 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -435,7 +435,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 deform_grad_new = cell_positions_new / cell_factor_expanded # Compute new cell from deformation gradient - new_col_vector_cell = torch.bmm(state.reference_cell, deform_grad_new) + new_col_vector_cell = torch.bmm(deform_grad_new, state.reference_cell) # Apply cell constraints and scale positions to new cell coordinates # (needed for correct displacement calculation in position constraints) From c6a8d0f50e4107f5671618cd7c32ae802b2bd5d4 Mon Sep 17 00:00:00 2001 From: janosh_per Date: Fri, 6 Feb 2026 02:02:06 +0000 Subject: [PATCH 08/16] fix FixSymmetry.merge() producing duplicate system indices with autobatcher MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When InFlightAutoBatcher concatenates states, it calls merge_constraints() which passes state_indices (the position of each source state in the concatenation list) to FixSymmetry.merge(). The old code used these state_indices as offsets: system_indices.append(offset + i). This breaks when a constraint covers multiple systems. E.g. merging constraint_A (5 systems) with constraint_B (3 systems) using state_indices=[0, 1] produces system_indices=[0,1,2,3,4, 1,2,3] — duplicates at 1,2,3 — triggering "Duplicate system indices found in SystemConstraint". Fix: use a cumulative running offset instead of state_indices, so the merged indices are always [0,1,2,...,N-1]. Co-authored-by: Cursor --- torch_sim/constraints.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 7c29f2dee..221746015 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -971,11 +971,18 @@ def merge( symm_maps = [] system_indices = [] - for constraint, offset in zip(constraints, state_indices, strict=False): - for i in range(len(constraint.rotations)): - rotations.append(constraint.rotations[i]) - symm_maps.append(constraint.symm_maps[i]) - system_indices.append(offset + i) + # Use cumulative system count as offset instead of state_indices directly. + # Each constraint can cover multiple systems, so state_indices (which is the + # position of the source state in the concatenation list) does not account + # for multi-system constraints. Using a running offset avoids duplicate + # system indices when merging constraints from states with different system counts. + cumulative_offset = 0 + for constraint in constraints: + for idx in range(len(constraint.rotations)): + rotations.append(constraint.rotations[idx]) + symm_maps.append(constraint.symm_maps[idx]) + system_indices.append(cumulative_offset + idx) + cumulative_offset += len(constraint.rotations) device = rotations[0].device From eac42b653d8c1de0556a15c773bb9d6af102fbe1 Mon Sep 17 00:00:00 2001 From: janosh_per Date: Fri, 6 Feb 2026 02:13:43 +0000 Subject: [PATCH 09/16] use set_constrained_cell in lbfgs_step for FixSymmetry support lbfgs_step was setting state.row_vector_cell directly, bypassing set_constrained_cell() which applies cell constraints like FixSymmetry. This meant FixSymmetry's adjust_cell (which symmetrizes the deformation gradient) was never called during LBFGS cell optimization. Fix: compute new_col_vector_cell and use set_constrained_cell() with scale_atoms=True, matching the pattern already used in fire_step. Tested on 1 and 100 structures with MACE+cueq LBFGS + FixSymmetry. Co-authored-by: Cursor --- torch_sim/optimizers/lbfgs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index f4dd79a34..a9ac8b718 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -504,10 +504,11 @@ def lbfgs_step( # noqa: PLR0915, C901 deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3] # Update cell: new_cell = reference_cell @ deform_grad^T - # reference_cell.mT: [S, 3, 3], deform_grad_new: [S, 3, 3] - state.row_vector_cell = torch.bmm( - state.reference_cell.mT, deform_grad_new.transpose(-2, -1) + # Use set_constrained_cell to apply cell constraints (e.g. FixSymmetry) + new_col_vector_cell = torch.bmm( + deform_grad_new, state.reference_cell ) # [S, 3, 3] + state.set_constrained_cell(new_col_vector_cell, scale_atoms=True) # Apply position step in fractional space, then convert to Cartesian new_frac = frac_positions + step_positions # [N, 3] From 8aed11d69434d85e897171ff72582d9c24315028 Mon Sep 17 00:00:00 2001 From: janosh Date: Thu, 5 Feb 2026 19:05:32 -0800 Subject: [PATCH 10/16] add regression tests for FixSymmetry fixes, clean up constraint code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_merge_multi_system_constraints_no_duplicate_indices: merges a 3-system and 2-system FixSymmetry constraint and asserts sequential system_idx [0,1,2,3,4] — the old state_indices-based offset produced duplicates [0,1,2,1,2] which triggered SystemConstraint validation - test_lbfgs_cell_optimization_preserves_symmetry: runs LBFGS with cell_filter + FixSymmetry on a BCC structure with noisy forces and asserts spacegroup is preserved — the old lbfgs_step bypassed set_constrained_cell so FixSymmetry.adjust_cell was never called - make adjust_stress/adjust_cell default no-ops in Constraint base class instead of @abstractmethod, removing 4 boilerplate overrides in FixAtoms/FixCom that did nothing - remove __all__ from symmetrize.py - remove unused OptimizationResult TypedDict, structure_with_spacegroup fixture, and unnecessary if-guards on constraint loops in cell_filters Co-authored-by: Cursor --- tests/test_fix_symmetry.py | 102 ++++++++++++++++----------- torch_sim/constraints.py | 52 +++----------- torch_sim/models/lennard_jones.py | 2 +- torch_sim/models/morse.py | 2 +- torch_sim/models/particle_life.py | 2 +- torch_sim/models/soft_sphere.py | 4 +- torch_sim/optimizers/cell_filters.py | 16 ++--- torch_sim/symmetrize.py | 9 --- 8 files changed, 84 insertions(+), 105 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index d74105cf3..ed6029657 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -1,6 +1,6 @@ """Tests for the FixSymmetry constraint.""" -from typing import Literal, TypedDict +from typing import Literal import numpy as np import pytest @@ -24,17 +24,6 @@ from spglib import SpglibDataset # noqa: E402 -class OptimizationResult(TypedDict): - """Return type for run_optimization_check_symmetry.""" - - initial_spacegroups: list[int | None] - final_spacegroups: list[int | None] - initial_datasets: list[SpglibDataset] - final_datasets: list[SpglibDataset] - final_state: ts.SimState - final_atoms_list: list[Atoms] - - # ============================================================================= # Structure Definitions (Single Source of Truth) # ============================================================================= @@ -114,15 +103,6 @@ def make_structure(name: str) -> Atoms: # ============================================================================= -@pytest.fixture(params=["fcc", "hcp", "hcp_supercell", "diamond", "p6bar"]) -def structure_with_spacegroup(request: pytest.FixtureRequest) -> tuple[Atoms, int]: - """Parameterized fixture returning (atoms, expected_spacegroup).""" - name = request.param - atoms = make_structure(name) - base_name = name.replace("_supercell", "") - return atoms, SPACEGROUPS[base_name] - - @pytest.fixture def model() -> LennardJonesModel: """Create a LennardJonesModel for testing.""" @@ -211,10 +191,8 @@ def run_optimization_check_symmetry( symprec: float = SYMPREC, max_steps: int = MAX_STEPS, force_tol: float = 0.001, -) -> OptimizationResult: - """Run optimization and return initial/final symmetry info. - - This is the core helper for testing symmetry preservation during optimization. +) -> dict[str, list[int | None]]: + """Run FIRE optimization and return initial/final space group numbers. Args: state: torch-sim SimState (can be batched) @@ -226,23 +204,14 @@ def run_optimization_check_symmetry( force_tol: Force convergence tolerance Returns: - Dict with keys: - - 'initial_spacegroups': List of initial space group numbers - - 'final_spacegroups': List of final space group numbers - - 'initial_datasets': List of full spglib datasets for initial structures - - 'final_datasets': List of full spglib datasets for final structures - - 'final_state': Final SimState - - 'final_atoms_list': List of final ASE Atoms objects + Dict with 'initial_spacegroups' and 'final_spacegroups' lists. """ - # Get initial symmetry for all systems using torch_sim.symmetrize initial_datasets = get_symmetry_datasets(state, symprec) if constraint is not None: state.constraints = [constraint] - # Run optimization init_kwargs = {"cell_filter": ts.CellFilter.frechet} if adjust_cell else None - # When doing cell optimization, include cell_forces in convergence check convergence_fn = ts.generate_force_convergence_fn( force_tol=force_tol, include_cell_forces=adjust_cell ) @@ -256,17 +225,11 @@ def run_optimization_check_symmetry( steps_between_swaps=1, ) - # Get final symmetry for all systems final_datasets = get_symmetry_datasets(final_state, symprec) - final_atoms_list = final_state.to_atoms() return { "initial_spacegroups": [d.number if d else None for d in initial_datasets], "final_spacegroups": [d.number if d else None for d in final_datasets], - "initial_datasets": initial_datasets, - "final_datasets": final_datasets, - "final_state": final_state, - "final_atoms_list": final_atoms_list, } @@ -487,6 +450,28 @@ def test_merge_two_constraints(self): assert len(merged.symm_maps) == 2 assert merged.system_idx.tolist() == [0, 1] + def test_merge_multi_system_constraints_no_duplicate_indices(self): + """Regression: merging multi-system constraints must not produce duplicates.""" + # Create two batched states so each constraint covers multiple systems + atoms_a = [ + make_structure("fcc"), + make_structure("diamond"), + make_structure("hcp"), + ] + atoms_b = [make_structure("bcc"), make_structure("fcc")] + state_a = ts.io.atoms_to_state(atoms_a, torch.device("cpu"), DTYPE) + state_b = ts.io.atoms_to_state(atoms_b, torch.device("cpu"), DTYPE) + c_a = FixSymmetry.from_state(state_a, symprec=SYMPREC) # 3 systems + c_b = FixSymmetry.from_state(state_b, symprec=SYMPREC) # 2 systems + + # Old bug: state_indices=[0, 1] was used as offsets → [0,1,2, 1,2] (duplicates) + # Fix: cumulative offset → [0,1,2, 3,4] + merged = FixSymmetry.merge([c_a, c_b], state_indices=[0, 1], atom_offsets=None) + + assert len(merged.rotations) == 5 + assert len(merged.symm_maps) == 5 + assert merged.system_idx.tolist() == [0, 1, 2, 3, 4] + @pytest.mark.parametrize("mismatch_field", ["adjust_positions", "adjust_cell"]) def test_merge_mismatched_settings_raises( self, mismatch_field: Literal["adjust_positions", "adjust_cell"] @@ -634,6 +619,41 @@ def test_cell_filter_preserves_symmetry( final_datasets = get_symmetry_datasets(final_state, symprec=SYMPREC) assert initial_datasets[0].number == final_datasets[0].number + @pytest.mark.parametrize("cell_filter", [ts.CellFilter.frechet, ts.CellFilter.unit]) + def test_lbfgs_cell_optimization_preserves_symmetry( + self, + noisy_lj_model: NoisyModelWrapper, + cell_filter: ts.CellFilter, + ): + """Regression: LBFGS must use set_constrained_cell for FixSymmetry support.""" + state = ts.io.atoms_to_state(make_structure("bcc"), torch.device("cpu"), DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + state.constraints = [constraint] + + # Compress cell to create forces + state.cell = state.cell * 0.95 + state.positions = state.positions * 0.95 + + initial_datasets = get_symmetry_datasets(state, symprec=SYMPREC) + assert initial_datasets[0].number == SPACEGROUPS["bcc"] + + final_state = ts.optimize( + system=state, + model=noisy_lj_model, + optimizer=ts.Optimizer.lbfgs, + convergence_fn=ts.generate_force_convergence_fn( + force_tol=0.01, include_cell_forces=True + ), + init_kwargs={"cell_filter": cell_filter}, + max_steps=MAX_STEPS, + ) + + final_datasets = get_symmetry_datasets(final_state, symprec=SYMPREC) + assert final_datasets[0].number == SPACEGROUPS["bcc"], ( + f"LBFGS+{cell_filter} lost symmetry: {SPACEGROUPS['bcc']} -> " + f"{final_datasets[0].number}" + ) + @pytest.mark.parametrize("rotated", [False, True]) def test_noisy_model_loses_symmetry_without_constraint( self, noisy_lj_model: NoisyModelWrapper, *, rotated: bool diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 221746015..b1aaa34a2 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -84,22 +84,24 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces: Forces to be adjusted """ - @abstractmethod - def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: + def adjust_stress( # noqa: B027 + self, state: SimState, stress: torch.Tensor + ) -> None: """Adjust stress tensor to satisfy the constraint. - This method should modify stress in-place. + Default is a no-op. Override in subclasses that need stress symmetrization. Args: state: Current simulation state stress: Stress tensor to be adjusted in-place """ - @abstractmethod - def adjust_cell(self, state: SimState, cell: torch.Tensor) -> None: + def adjust_cell( # noqa: B027 + self, state: SimState, cell: torch.Tensor + ) -> None: """Adjust cell to satisfy the constraint. - This method should modify cell in-place. + Default is a no-op. Override in subclasses that need cell symmetrization. Args: state: Current simulation state @@ -493,20 +495,6 @@ def adjust_forces( """ forces[self.atom_idx] = 0.0 - def adjust_stress( - self, - state: SimState, - stress: torch.Tensor, - ) -> None: - """No stress adjustment needed for FixAtoms.""" - - def adjust_cell( - self, - state: SimState, - cell: torch.Tensor, - ) -> None: - """No cell adjustment needed for FixAtoms.""" - def __repr__(self) -> str: """String representation of the constraint.""" if len(self.atom_idx) <= 10: @@ -624,20 +612,6 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: forces_change[self.system_idx] = lmd[self.system_idx] forces -= forces_change[state.system_idx] * state.masses.unsqueeze(-1) - def adjust_stress( - self, - state: SimState, - stress: torch.Tensor, - ) -> None: - """No stress adjustment needed for FixCom.""" - - def adjust_cell( - self, - state: SimState, - cell: torch.Tensor, - ) -> None: - """No cell adjustment needed for FixCom.""" - def __repr__(self) -> str: """String representation of the constraint.""" return f"FixCom(system_idx={self.system_idx})" @@ -927,14 +901,14 @@ def from_state( def merge( cls, constraints: list[Self], - state_indices: list[int], + state_indices: list[int], # noqa: ARG003 atom_offsets: torch.Tensor, # noqa: ARG003 ) -> Self: """Merge multiple FixSymmetry constraints into one. Args: constraints: List of FixSymmetry constraints to merge. - state_indices: Index of the source state for each constraint. + state_indices: Index of the source state for each constraint (unused). atom_offsets: Cumulative atom counts (unused for FixSymmetry). Returns: @@ -971,11 +945,7 @@ def merge( symm_maps = [] system_indices = [] - # Use cumulative system count as offset instead of state_indices directly. - # Each constraint can cover multiple systems, so state_indices (which is the - # position of the source state in the concatenation list) does not account - # for multi-system constraints. Using a running offset avoids duplicate - # system indices when merging constraints from states with different system counts. + # Use cumulative offset (not state_indices) to handle multi-system constraints cumulative_offset = 0 for constraint in constraints: for idx in range(len(constraint.rotations)): diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index f54e9fe35..788a437f4 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -275,7 +275,7 @@ def unbatched_forward( ) if self.use_neighbor_list: - mapping, _, shifts_idx = torchsim_nl( + mapping, _system_mapping, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 78c62f3f3..5444422a6 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -281,7 +281,7 @@ def unbatched_forward( ) if self.use_neighbor_list: - mapping, _, shifts_idx = torchsim_nl( + mapping, _system_mapping, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index baa1f8520..1031a073e 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -164,7 +164,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: ) if self.use_neighbor_list: - mapping, _, shifts_idx = torchsim_nl( + mapping, _system_mapping, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 60d647829..d2c537787 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -299,7 +299,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: ) if self.use_neighbor_list: - mapping, _, shifts_idx = torchsim_nl( + mapping, _system_mapping, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, @@ -727,7 +727,7 @@ def unbatched_forward( # noqa: PLR0915 system_idx = torch.zeros( positions.shape[0], dtype=torch.long, device=self.device ) - mapping, _, shifts_idx = torchsim_nl( + mapping, _system_mapping, shifts_idx = torchsim_nl( positions=positions, cell=cell, pbc=self.pbc, diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 1850234f6..78576059f 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -112,10 +112,9 @@ def unit_cell_filter_init[T: AnyCellState]( # Calculate initial cell forces stress = model_output["stress"].clone() - # Apply stress constraints if any exist - if state.constraints: - for constraint in state.constraints: - constraint.adjust_stress(state, stress) + # Apply stress constraints (e.g. FixSymmetry) + for constraint in state.constraints: + constraint.adjust_stress(state, stress) volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) @@ -170,7 +169,7 @@ def frechet_cell_filter_init[T: AnyCellState]( # Calculate initial cell forces using Frechet approach stress = model_output["stress"].clone() - # Apply stress constraints if present + # Apply stress constraints (e.g. FixSymmetry) for constraint in state.constraints: constraint.adjust_stress(state, stress) @@ -273,10 +272,9 @@ def compute_cell_forces[T: AnyCellState]( """Compute cell forces for both unit and frechet methods.""" stress = model_output["stress"].clone() - # Apply stress constraints if any exist - if state.constraints: - for constraint in state.constraints: - constraint.adjust_stress(state, stress) + # Apply stress constraints (e.g. FixSymmetry) + for constraint in state.constraints: + constraint.adjust_stress(state, stress) volumes = torch.linalg.det(state.cell).view(state.n_systems, 1, 1) virial = -volumes * (stress + state.pressure) diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 3d5bf6fb0..2274cd64c 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -25,15 +25,6 @@ from torch_sim.state import SimState -__all__ = [ - "build_symmetry_map", - "get_symmetry_datasets", - "refine_symmetry", - "symmetrize_rank1", - "symmetrize_rank2", -] - - def _get_symmetry_dataset( cell: torch.Tensor, scaled_positions: torch.Tensor, From 73532ac6388e0ca7792f8e71c662e1f8053df72e Mon Sep 17 00:00:00 2001 From: janosh Date: Thu, 5 Feb 2026 19:33:32 -0800 Subject: [PATCH 11/16] switch from spglib to moyopy, address PR review comments - replace all spglib usage with moyopy (moyo's Python bindings) for symmetry detection, refinement, and dataset queries - rewrite refine_symmetry using metric tensor polar decomposition + periodic-aware position averaging (no longer depends on spglib's standardized cell pipeline) - FixSymmetry.get_removed_dof returns 0 tensor instead of raising NotImplementedError so temperature calculations work in MD - rename sys_idx_local/sys_idx_global to constraint_idx/sys_idx for clarity (thomasloux review) - add cross-reference to transforms.get_fractional_coordinates in _get_scaled_positions docstring - document n_ops meaning and unbatched nature in symmetrize.py module docstring Co-authored-by: Cursor --- tests/test_fix_symmetry.py | 44 +++--- torch_sim/constraints.py | 62 ++++---- torch_sim/symmetrize.py | 303 +++++++++++++++---------------------- 3 files changed, 170 insertions(+), 239 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index ed6029657..acb82aea8 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -19,9 +19,10 @@ from torch_sim.symmetrize import get_symmetry_datasets -# Skip all tests if spglib is not available +# Skip all tests if moyopy is not available +moyopy = pytest.importorskip("moyopy") +# spglib still needed for ASE comparison tests spglib = pytest.importorskip("spglib") -from spglib import SpglibDataset # noqa: E402 # ============================================================================= @@ -172,9 +173,9 @@ def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper: # ============================================================================= -def get_symmetry_dataset_from_atoms( +def get_spglib_dataset_from_atoms( atoms: Atoms, symprec: float = SYMPREC -) -> SpglibDataset: +) -> spglib.SpglibDataset: """Get full symmetry dataset for an ASE Atoms object using spglib directly.""" return spglib.get_symmetry_dataset( (atoms.cell[:], atoms.get_scaled_positions(), atoms.numbers), @@ -288,30 +289,25 @@ def test_p1_identity_only(self): assert torch.allclose(stress, original_stress, atol=1e-10) def test_symmetry_datasets_match_spglib(self): - """Test get_symmetry_datasets matches spglib for single and batched states.""" + """Test get_symmetry_datasets space groups match spglib.""" atoms_list = [make_structure(name) for name in ["fcc", "diamond", "hcp"]] # Test batched state batched_state = ts.io.atoms_to_state(atoms_list, torch.device("cpu"), DTYPE) - ts_datasets = get_symmetry_datasets(batched_state, SYMPREC) - assert len(ts_datasets) == 3 + moyo_datasets = get_symmetry_datasets(batched_state, SYMPREC) + assert len(moyo_datasets) == 3 - # Compare each with direct spglib call (covers both single and batched) - for i, atoms in enumerate(atoms_list): - spglib_dataset = get_symmetry_dataset_from_atoms(atoms, SYMPREC) + # Compare space group numbers with spglib + for idx, atoms in enumerate(atoms_list): + spglib_dataset = get_spglib_dataset_from_atoms(atoms, SYMPREC) - # Compare key fields - assert ts_datasets[i].number == spglib_dataset.number, ( + assert moyo_datasets[idx].number == spglib_dataset.number, ( f"Space group mismatch for {atoms.get_chemical_formula()}: " - f"{ts_datasets[i].number} vs {spglib_dataset.number}" - ) - assert ts_datasets[i].international == spglib_dataset.international - assert ts_datasets[i].hall == spglib_dataset.hall - assert len(ts_datasets[i].rotations) == len(spglib_dataset.rotations) - assert np.allclose(ts_datasets[i].rotations, spglib_dataset.rotations) - assert np.allclose( - ts_datasets[i].translations, spglib_dataset.translations, atol=1e-10 + f"moyopy={moyo_datasets[idx].number} vs " + f"spglib={spglib_dataset.number}" ) + # Both should find the same number of symmetry operations + assert len(moyo_datasets[idx].operations) == len(spglib_dataset.rotations) class TestFixSymmetryComparisonWithASE: @@ -706,13 +702,13 @@ def test_noisy_model_preserves_symmetry_with_constraint( class TestFixSymmetryEdgeCases: """Tests for edge cases and error handling.""" - def test_get_removed_dof_raises(self): - """Test that get_removed_dof raises NotImplementedError.""" + def test_get_removed_dof_returns_zero(self): + """Test get_removed_dof returns zero (constrains direction, not DOF count).""" state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - with pytest.raises(NotImplementedError, match="get_removed_dof"): - constraint.get_removed_dof(state) + dof = constraint.get_removed_dof(state) + assert torch.all(dof == 0) def test_large_deformation_gradient_raises(self): """Test that large deformation gradient raises RuntimeError.""" diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index b1aaa34a2..59ac540ee 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -726,7 +726,7 @@ class FixSymmetry(SystemConstraint): according to the crystal symmetry operations. Each system in a batch can have different symmetry operations. - Requires the spglib package to be available for automatic symmetry detection. + Requires the moyopy package to be available for automatic symmetry detection. The constraint works by: - Symmetrizing forces/momenta as rank-1 tensors using all symmetry operations @@ -826,7 +826,7 @@ def from_state( Args: state: SimState containing one or more systems. - symprec: Symmetry precision for spglib. + symprec: Symmetry precision for moyopy. adjust_positions: Whether to symmetrize position adjustments. adjust_cell: Whether to symmetrize cell/stress adjustments. refine_symmetry_state: Whether to refine the state's positions and cell @@ -839,9 +839,12 @@ def from_state( FixSymmetry constraint configured for the state's structures. """ try: - import spglib # noqa: F401 + import moyopy # noqa: F401 except ImportError: - raise ImportError("spglib is required for FixSymmetry.from_state") from None + raise ImportError( + "moyopy is required for FixSymmetry.from_state. " + "Install with: pip install moyopy" + ) from None rotations = [] symm_maps = [] @@ -967,16 +970,16 @@ def merge( def get_removed_dof(self, state: SimState) -> torch.Tensor: """Get number of removed degrees of freedom. - FixSymmetry doesn't explicitly remove DOF in the same way as FixAtoms. - This matches ASE's FixSymmetry behavior which also raises NotImplementedError. + FixSymmetry constrains motion direction rather than removing explicit DOF, + so returns 0 to avoid breaking temperature calculations in MD. Args: state: Simulation state - Raises: - NotImplementedError: FixSymmetry does not support DOF counting. + Returns: + Zero tensor of shape (n_systems,) """ - raise NotImplementedError("FixSymmetry does not implement get_removed_dof.") + return torch.zeros(state.n_systems, dtype=torch.long, device=state.device) def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: """Symmetrize position displacements. @@ -1041,10 +1044,10 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: dtype = state.dtype identity = torch.eye(3, device=device, dtype=dtype) - for sys_idx_local, sys_idx_global in enumerate(self.system_idx): + for constraint_idx, sys_idx in enumerate(self.system_idx): # Get current and new cells in row vector convention - cur_cell = state.row_vector_cell[sys_idx_global] - new_cell_row = new_cell[sys_idx_global].mT + cur_cell = state.row_vector_cell[sys_idx] + new_cell_row = new_cell[sys_idx].mT # Calculate deformation gradient cur_cell_inv = torch.linalg.inv(cur_cell) @@ -1067,13 +1070,12 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: ) # Symmetrize deformation gradient directly - symmetrized_delta = symmetrize_rank2( - cur_cell, delta_deform_grad, self.rotations[sys_idx_local].to(dtype=dtype) - ) + rots = self.rotations[constraint_idx].to(dtype=dtype) + symmetrized_delta = symmetrize_rank2(cur_cell, delta_deform_grad, rots) # Reconstruct cell and update in-place new_cell_row_sym = cur_cell @ (symmetrized_delta + identity).mT - new_cell[sys_idx_global] = new_cell_row_sym.mT # Back to column convention + new_cell[sys_idx] = new_cell_row_sym.mT # Back to column convention def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: """Symmetrize stress tensor in-place. @@ -1084,19 +1086,17 @@ def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: """ dtype = stress.dtype - for sys_idx_local, sys_idx_global in enumerate(self.system_idx): - # Get current cell and symmetrize stress directly - cur_cell = state.row_vector_cell[sys_idx_global] - sys_stress = stress[sys_idx_global] + for constraint_idx, sys_idx in enumerate(self.system_idx): + cur_cell = state.row_vector_cell[sys_idx] symmetrized = symmetrize_rank2( - cur_cell, sys_stress, self.rotations[sys_idx_local].to(dtype=dtype) + cur_cell, stress[sys_idx], self.rotations[constraint_idx].to(dtype=dtype) ) - stress[sys_idx_global] = symmetrized + stress[sys_idx] = symmetrized def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: """Symmetrize rank-1 tensors (forces, momenta, displacements) in-place. - Uses fractional-coordinate rotations from spglib together with the current + Uses fractional-coordinate rotations from moyopy together with the current cell to transform vectors. The cell is fetched at runtime to ensure correctness during variable-cell relaxation. @@ -1114,22 +1114,18 @@ def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: ) dtype = vectors.dtype - for sys_idx_local, sys_idx_global in enumerate(self.system_idx): - start = cumsum[sys_idx_global].item() - end = cumsum[sys_idx_global + 1].item() + for constraint_idx, sys_idx in enumerate(self.system_idx): + start = cumsum[sys_idx].item() + end = cumsum[sys_idx + 1].item() - # Extract vectors for this system sys_vectors = vectors[start:end] + cell = state.row_vector_cell[sys_idx] - # Get current cell for this system - cell = state.row_vector_cell[sys_idx_global] - - # Symmetrize directly symmetrized = symmetrize_rank1( cell, sys_vectors, - self.rotations[sys_idx_local].to(dtype=dtype), - self.symm_maps[sys_idx_local], + self.rotations[constraint_idx].to(dtype=dtype), + self.symm_maps[constraint_idx], ) # Update in place diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 2274cd64c..600494f27 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -1,11 +1,14 @@ """Symmetry refinement utilities for crystal structures. This module provides functions for refining and symmetrizing atomic structures -using spglib. It is adapted from ASE's spacegroup.symmetrize module but -reimplemented to work with torch tensors directly. +using moyopy (Python bindings for the moyo crystal symmetry library). The main entry point is `refine_symmetry` which symmetrizes both the cell and atomic positions according to the detected space group symmetry. + +Note: Functions in this module operate on single (unbatched) systems. +The `n_ops` dimension refers to the number of symmetry operations +(rotations + translations) of the space group. """ from __future__ import annotations @@ -20,52 +23,50 @@ if TYPE_CHECKING: - from spglib import SpglibDataset + from moyopy import MoyoDataset from torch_sim.state import SimState -def _get_symmetry_dataset( +def _get_moyo_dataset( cell: torch.Tensor, scaled_positions: torch.Tensor, atomic_numbers: torch.Tensor, - symprec: float = 1.0e-6, -) -> SpglibDataset | None: - """Get symmetry dataset from spglib. + symprec: float = 1.0e-4, +) -> MoyoDataset: + """Get symmetry dataset from moyopy. Args: cell: Unit cell as row vectors, shape (3, 3) scaled_positions: Fractional coordinates, shape (n_atoms, 3) atomic_numbers: Atomic numbers, shape (n_atoms,) - symprec: Symmetry precision + symprec: Symmetry precision in units of cell basis vectors Returns: - Symmetry dataset with attribute access + MoyoDataset with symmetry information """ - import spglib + from moyopy import Cell, MoyoDataset - # Convert tensors to numpy for spglib - cell_np = cell.detach().cpu().numpy() - positions_np = scaled_positions.detach().cpu().numpy() - numbers_np = atomic_numbers.detach().cpu().numpy() + cell_list = cell.detach().cpu().tolist() + positions_list = scaled_positions.detach().cpu().tolist() + numbers_list = atomic_numbers.detach().cpu().int().tolist() - cell_tuple = (cell_np, positions_np, numbers_np) - return spglib.get_symmetry_dataset(cell_tuple, symprec=symprec) + moyo_cell = Cell(basis=cell_list, positions=positions_list, numbers=numbers_list) + return MoyoDataset(moyo_cell, symprec=symprec) def get_symmetry_datasets( state: SimState, - symprec: float = 1.0e-6, -) -> list[SpglibDataset | None]: + symprec: float = 1.0e-4, +) -> list[MoyoDataset]: """Get symmetry datasets for all systems in a SimState. Args: state: SimState containing one or more systems - symprec: Symmetry precision for spglib + symprec: Symmetry precision for moyopy Returns: - List of spglib symmetry datasets, one per system in the state. - Returns None for systems where symmetry detection fails. + List of MoyoDataset objects, one per system in the state. """ datasets = [] @@ -73,10 +74,9 @@ def get_symmetry_datasets( cell = single_state.row_vector_cell[0] positions = single_state.positions - # Compute scaled (fractional) positions for this system scaled_positions = _get_scaled_positions(positions, cell) - dataset = _get_symmetry_dataset( + dataset = _get_moyo_dataset( cell=cell, scaled_positions=scaled_positions, atomic_numbers=single_state.atomic_numbers, @@ -91,7 +91,9 @@ def _get_scaled_positions( positions: torch.Tensor, cell: torch.Tensor, ) -> torch.Tensor: - """Convert Cartesian positions to fractional coordinates. + """Convert Cartesian positions to fractional coordinates (unbatched). + + See also ``transforms.get_fractional_coordinates`` for the batched version. Args: positions: Cartesian positions, shape (n_atoms, 3) @@ -100,93 +102,7 @@ def _get_scaled_positions( Returns: Fractional coordinates, shape (n_atoms, 3) """ - inv_cell = torch.linalg.inv(cell) - return positions @ inv_cell - - -def _symmetrize_cell( - cell: torch.Tensor, - dataset: SpglibDataset, -) -> torch.Tensor: - """Symmetrize the cell based on the symmetry dataset. - - Args: - cell: Unit cell as row vectors, shape (3, 3) - dataset: spglib symmetry dataset - - Returns: - Symmetrized cell as row vectors, shape (3, 3) - """ - device = cell.device - dtype = cell.dtype - - # Get standardized cell and apply transformations - std_cell = torch.as_tensor(dataset.std_lattice, dtype=dtype, device=device) - trans_matrix = torch.as_tensor( - dataset.transformation_matrix, dtype=dtype, device=device - ) - rot_matrix = torch.as_tensor(dataset.std_rotation_matrix, dtype=dtype, device=device) - - trans_std_cell = trans_matrix.T @ std_cell - return trans_std_cell @ rot_matrix - - -def _symmetrize_positions( - positions: torch.Tensor, - dataset: SpglibDataset, - primitive_cell: tuple, -) -> torch.Tensor: - """Symmetrize atomic positions. - - Args: - positions: Cartesian positions, shape (n_atoms, 3) - dataset: spglib symmetry dataset - primitive_cell: Result from spglib.find_primitive (cell, positions, numbers) - - Returns: - Symmetrized Cartesian positions, shape (n_atoms, 3) - """ - device = positions.device - dtype = positions.dtype - - prim_cell_np, _prim_scaled_pos, _prim_types = primitive_cell - prim_cell = torch.as_tensor(prim_cell_np, dtype=dtype, device=device) - - # Calculate offset between standard cell and actual cell - std_cell = torch.as_tensor(dataset.std_lattice, dtype=dtype, device=device) - rot_matrix = torch.as_tensor(dataset.std_rotation_matrix, dtype=dtype, device=device) - std_positions = torch.as_tensor(dataset.std_positions, dtype=dtype, device=device) - - rot_std_cell = std_cell @ rot_matrix - rot_std_pos = std_positions @ rot_std_cell - - # Get mapping indices - mapping_to_primitive = list(dataset.mapping_to_primitive) - std_mapping_to_primitive = list(dataset.std_mapping_to_primitive) - - dp0 = ( - positions[mapping_to_primitive.index(0)] - - rot_std_pos[std_mapping_to_primitive.index(0)] - ) - - # Create aligned set of standard cell positions - rot_prim_cell = prim_cell @ rot_matrix - inv_rot_prim_cell = torch.linalg.inv(rot_prim_cell) - aligned_std_pos = rot_std_pos + dp0 - - # Find ideal positions - new_positions = positions.clone() - n_atoms = positions.shape[0] - - for i_at in range(n_atoms): - std_i_at = std_mapping_to_primitive.index(mapping_to_primitive[i_at]) - dp = aligned_std_pos[std_i_at] - positions[i_at] - dp_s = dp @ inv_rot_prim_cell - new_positions[i_at] = ( - aligned_std_pos[std_i_at] - torch.round(dp_s) @ rot_prim_cell - ) - - return new_positions + return positions @ torch.linalg.inv(cell) def refine_symmetry( @@ -199,137 +115,154 @@ def refine_symmetry( ) -> tuple[torch.Tensor, torch.Tensor]: """Refine symmetry of a structure. - This function symmetrizes both the cell and atomic positions according - to the detected space group symmetry. + Symmetrizes both cell vectors and atomic positions by averaging + over the detected symmetry operations using polar decomposition + for the cell metric and scatter-add averaging for positions. The refinement process: - 1. Detect symmetry of the input structure - 2. Symmetrize the cell vectors to match the ideal lattice - 3. Symmetrize atomic positions to ideal Wyckoff positions + 1. Detect symmetry operations of the input structure + 2. Symmetrize the cell metric tensor (preserving cell orientation) + 3. Symmetrize atomic positions by averaging over symmetry orbits Args: cell: Unit cell as row vectors, shape (3, 3) positions: Cartesian positions, shape (n_atoms, 3) atomic_numbers: Atomic numbers, shape (n_atoms,) - symprec: Symmetry precision for spglib - verbose: If True, print symmetry information before and after + symprec: Symmetry precision for moyopy + verbose: If True, log symmetry information before and after Returns: Tuple of (symmetrized_cell, symmetrized_positions): - symmetrized_cell: Symmetrized cell as row vectors, shape (3, 3) - symmetrized_positions: Symmetrized Cartesian positions, shape (n_atoms, 3) """ - import spglib + device = cell.device + dtype = cell.dtype - # Step 1: Check and symmetrize cell + # Step 1: Detect symmetry scaled_positions = _get_scaled_positions(positions, cell) - dataset = _get_symmetry_dataset(cell, scaled_positions, atomic_numbers, symprec) - - if dataset is None: - raise RuntimeError("spglib could not determine symmetry for structure") + dataset = _get_moyo_dataset(cell, scaled_positions, atomic_numbers, symprec) if verbose: logger.info( - "symmetrize: prec %s got symmetry group number %s, " - "international (Hermann-Mauguin) %s, Hall %s", + "symmetrize: prec %s got space group number %s", symprec, dataset.number, - dataset.international, - dataset.hall, ) - new_cell = _symmetrize_cell(cell, dataset) - - # Scale positions to new cell - new_positions = scaled_positions @ new_cell - - # Step 2: Check and symmetrize positions with the new cell - new_scaled_positions = _get_scaled_positions(new_positions, new_cell) - dataset = _get_symmetry_dataset( - new_cell, new_scaled_positions, atomic_numbers, symprec + rotations = torch.as_tensor( + dataset.operations.rotations, dtype=dtype, device=device + ).round() + translations = torch.as_tensor( + dataset.operations.translations, dtype=dtype, device=device ) + n_ops = rotations.shape[0] - if dataset is None: - raise RuntimeError("spglib could not determine symmetry after cell refinement") - - # Find primitive cell - cell_np = new_cell.detach().cpu().numpy() - positions_np = new_scaled_positions.detach().cpu().numpy() - numbers_np = atomic_numbers.detach().cpu().numpy() + # Step 2: Symmetrize cell via metric tensor + polar decomposition + # Row-vector metric: g[i,j] = a_i · a_j = (cell @ cell.T)[i,j] + # Symmetry invariance: R.T @ g @ R = g for all rotations R + metric = cell @ cell.T + metric_sym = torch.einsum("nji,jk,nkl->il", rotations, metric, rotations) / n_ops + + # Left polar decomposition: cell = P @ V where P = sqrt(metric) + # Keep same orientation V but with symmetrized metric P_sym + sqrt_metric = _matrix_sqrt(metric) + sqrt_metric_sym = _matrix_sqrt(metric_sym) + new_cell = sqrt_metric_sym @ torch.linalg.inv(sqrt_metric) @ cell + + # Step 3: Symmetrize positions by averaging displacements over symmetry orbits + # Recompute fractional coords in the symmetrized cell + new_frac = positions @ torch.linalg.inv(new_cell) + symm_map = build_symmetry_map(rotations, translations, new_frac) + + # For each op, transform fractional positions: R @ frac + t + new_frac_all = ( + torch.einsum("oij,nj->oni", rotations, new_frac) + translations[:, None, :] + ) # (n_ops, n_atoms, 3) + # Compute displacement from target atom's current position, wrapped for periodicity + n_atoms = positions.shape[0] + target_frac = new_frac[symm_map] # (n_ops, n_atoms, 3) + displacement = new_frac_all - target_frac + displacement -= displacement.round() # wrap into [-0.5, 0.5] - primitive_result = spglib.find_primitive( - (cell_np, positions_np, numbers_np), symprec=symprec - ) - if primitive_result is None: - raise RuntimeError("spglib could not find primitive cell") + # Scatter-add wrapped displacements to target atoms and average + target = symm_map.reshape(-1).unsqueeze(-1).expand(-1, 3) + accum = torch.zeros(n_atoms, 3, dtype=dtype, device=device) + accum.scatter_add_(0, target, displacement.reshape(-1, 3)) + sym_frac = new_frac + accum / n_ops - new_positions = _symmetrize_positions(new_positions, dataset, primitive_result) + new_positions = sym_frac @ new_cell - # Final check if verbose: final_scaled = _get_scaled_positions(new_positions, new_cell) - final_dataset = _get_symmetry_dataset( - new_cell, final_scaled, atomic_numbers, 1e-4 + final_dataset = _get_moyo_dataset(new_cell, final_scaled, atomic_numbers, 1e-4) + logger.info( + "symmetrize: prec 1e-4 got space group number %s", + final_dataset.number, ) - if final_dataset is not None: - logger.info( - "symmetrize: prec 1e-4 got symmetry group number %s, " - "international (Hermann-Mauguin) %s, Hall %s", - final_dataset.number, - final_dataset.international, - final_dataset.hall, - ) return new_cell, new_positions +def _matrix_sqrt(mat: torch.Tensor) -> torch.Tensor: + """Compute matrix square root of a symmetric positive-definite matrix. + + Uses eigendecomposition: sqrt(A) = Q @ diag(sqrt(eigenvalues)) @ Q.T + + Args: + mat: Symmetric positive-definite matrix, shape (3, 3) + + Returns: + Matrix square root, shape (3, 3) + """ + eigenvalues, eigenvectors = torch.linalg.eigh(mat) + return eigenvectors @ torch.diag(eigenvalues.sqrt()) @ eigenvectors.T + + def _prep_symmetry( cell: torch.Tensor, positions: torch.Tensor, atomic_numbers: torch.Tensor, - symprec: float = 1.0e-6, + symprec: float = 1.0e-4, *, verbose: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Prepare structure for symmetry-preserving minimization. - This function determines the symmetry operations and atom mappings - needed for symmetry-constrained optimization. + Determines the symmetry operations (rotations in fractional coordinates) + and atom mappings needed for symmetry-constrained optimization. Args: cell: Unit cell as row vectors, shape (3, 3) positions: Cartesian positions, shape (n_atoms, 3) atomic_numbers: Atomic numbers, shape (n_atoms,) - symprec: Symmetry precision for spglib - verbose: If True, print symmetry information + symprec: Symmetry precision for moyopy + verbose: If True, log symmetry information Returns: Tuple of (rotations, symm_map): - - rotations: Rotation matrices, shape (n_ops, 3, 3) + - rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) - symm_map: Atom mapping tensor, shape (n_ops, n_atoms) """ device = cell.device dtype = cell.dtype scaled_positions = _get_scaled_positions(positions, cell) - dataset = _get_symmetry_dataset(cell, scaled_positions, atomic_numbers, symprec) - - if dataset is None: - raise RuntimeError("spglib could not determine symmetry for structure") + dataset = _get_moyo_dataset(cell, scaled_positions, atomic_numbers, symprec) if verbose: logger.info( - "symmetrize: prec %s got symmetry group number %s, " - "international (Hermann-Mauguin) %s, Hall %s", + "symmetrize: prec %s got space group number %s, n_ops %d", symprec, dataset.number, - dataset.international, - dataset.hall, + len(dataset.operations), ) - rotations = torch.as_tensor(dataset.rotations.copy(), dtype=dtype, device=device) + rotations = torch.as_tensor( + dataset.operations.rotations, dtype=dtype, device=device + ).round() translations = torch.as_tensor( - dataset.translations.copy(), dtype=dtype, device=device + dataset.operations.translations, dtype=dtype, device=device ) # Build symmetry mapping @@ -345,11 +278,12 @@ def build_symmetry_map( ) -> torch.Tensor: """Build symmetry atom mapping for each symmetry operation. - For each symmetry operation, determines which atom each atom maps to. + For each symmetry operation (R, t), determines which atom each atom + maps to: atom i → atom j where R @ frac_i + t ≈ frac_j (mod 1). Args: - rotations: Rotation matrices, shape (n_ops, 3, 3) - translations: Translation vectors, shape (n_ops, 3) + rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) + translations: Translation vectors in fractional coords, shape (n_ops, 3) scaled_positions: Fractional coordinates, shape (n_atoms, 3) Returns: @@ -380,10 +314,13 @@ def symmetrize_rank1( ) -> torch.Tensor: """Symmetrize rank-1 tensor (forces, velocities, etc). + Averages the tensor over all symmetry operations, respecting atom + permutations. Works in fractional coordinates internally. + Args: lattice: Cell vectors as row vectors, shape (3, 3) forces: Forces array, shape (n_atoms, 3) - rotations: Rotation matrices, shape (n_ops, 3, 3) + rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) symm_map: Atom mapping for each symmetry operation, shape (n_ops, n_atoms) Returns: @@ -429,10 +366,12 @@ def symmetrize_rank2( ) -> torch.Tensor: """Symmetrize rank-2 tensor (stress, strain, etc). + Averages the tensor over all symmetry operations in scaled coordinates. + Args: lattice: Cell vectors as row vectors, shape (3, 3) stress: Stress tensor, shape (3, 3) - rotations: Rotation matrices, shape (n_ops, 3, 3) + rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) Returns: Symmetrized stress tensor, shape (3, 3) From aff7b65b5b23d1ea1995f75c43d902abbb2d8a3b Mon Sep 17 00:00:00 2001 From: janosh Date: Fri, 6 Feb 2026 08:07:05 -0800 Subject: [PATCH 12/16] simplify FixSymmetry implementation, improve robustness - Rewrite symmetrize.py: inline trivial helpers, extract shared moyopy logic into _extract_symmetry_ops and _refine_symmetry_impl - Add chunked fallback in build_symmetry_map for large systems (>200 atoms) to avoid O(n_ops * n_atoms^2) OOM - Add refine_and_prep_symmetry() to combine refinement + symmetry detection in a single moyopy call, eliminating redundant C-library invocation - Lazy-import symmetrize functions inside FixSymmetry methods so importing constraints.py for FixAtoms/FixCom doesn't load the symmetry module - Simplify FixSymmetry.adjust_cell by removing confusing .mT transpose dance - Revert unrelated _system_mapping renames in model files - Add missing tests: position symmetrization vs ASE, refine_symmetry correctness, moyopy to pyproject.toml symmetry extra Co-authored-by: Cursor --- pyproject.toml | 3 +- tests/test_fix_symmetry.py | 693 +++++++++------------------ torch_sim/constraints.py | 465 +++++------------- torch_sim/models/lennard_jones.py | 2 +- torch_sim/models/morse.py | 2 +- torch_sim/models/particle_life.py | 2 +- torch_sim/models/soft_sphere.py | 4 +- torch_sim/optimizers/cell_filters.py | 35 +- torch_sim/symmetrize.py | 443 ++++++----------- 9 files changed, 525 insertions(+), 1124 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3b27d1b5..1428e67b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,10 +44,11 @@ test = [ "pymatgen>=2025.6.14", "pytest-cov>=6", "pytest>=8", + "moyopy>=0.3", "spglib>=2.6", ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] -symmetry = ["spglib>=2.6"] +symmetry = ["moyopy>=0.3"] mace = ["mace-torch>=0.3.14"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index acb82aea8..ffd328017 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -1,7 +1,5 @@ """Tests for the FixSymmetry constraint.""" -from typing import Literal - import numpy as np import pytest import torch @@ -19,37 +17,20 @@ from torch_sim.symmetrize import get_symmetry_datasets -# Skip all tests if moyopy is not available -moyopy = pytest.importorskip("moyopy") -# spglib still needed for ASE comparison tests -spglib = pytest.importorskip("spglib") - - -# ============================================================================= -# Structure Definitions (Single Source of Truth) -# ============================================================================= +pytest.importorskip("moyopy") +pytest.importorskip("spglib") # needed by ASE's FixSymmetry -# Expected space groups for each structure type -SPACEGROUPS = { - "fcc": 225, # Fm-3m - "hcp": 194, # P6_3/mmc - "diamond": 227, # Fd-3m - "bcc": 229, # Im-3m - "p6bar": 174, # P-6 (low symmetry) -} - -# Default maximum optimization steps for tests +SPACEGROUPS = {"fcc": 225, "hcp": 194, "diamond": 227, "bcc": 229, "p6bar": 174} MAX_STEPS = 30 - -# Default dtype for tests (torch.float64 recommended for numerical precision) DTYPE = torch.float64 - -# Default symmetry precision for spglib SYMPREC = 0.01 +# === Structure helpers === + + def _make_p6bar() -> Atoms: - """Create low-symmetry P-6 (space group 174) structure using pymatgen.""" + """Create P-6 (space group 174) structure.""" lattice = Lattice.hexagonal(a=3.0, c=5.0) structure = Structure.from_spacegroup( sg=174, lattice=lattice, species=["Si"], coords=[[0.3, 0.1, 0.25]] @@ -58,55 +39,39 @@ def _make_p6bar() -> Atoms: def make_structure(name: str) -> Atoms: - """Create a standard test structure by name. - - This is the single source of truth for test structures. - Use this instead of inline bulk() calls to avoid duplication. - - Args: - name: One of "fcc", "hcp", "diamond", "bcc", "p6bar" with optional - "_supercell" and/or "_rotated" suffix - - Returns: - ASE Atoms object - """ - base_name = name.replace("_supercell", "").replace("_rotated", "") - structures = { + """Create a test structure by name (fcc/hcp/diamond/bcc/p6bar + _rotated suffix).""" + base = name.replace("_rotated", "") + builders = { "fcc": lambda: bulk("Cu", "fcc", a=3.6), "hcp": lambda: bulk("Ti", "hcp", a=2.95, c=4.68), "diamond": lambda: bulk("Si", "diamond", a=5.43), "bcc": lambda: bulk("Al", "bcc", a=2 / np.sqrt(3), cubic=True), "p6bar": _make_p6bar, } - atoms = structures[base_name]() - if "_supercell" in name: - atoms = atoms * (2, 2, 2) + atoms = builders[base]() if "_rotated" in name: - # Apply 3 rotation matrices (matching ASE's test setup) - F = np.eye(3) - for k in range(3): - L = list(range(3)) - L.remove(k) - (i, j) = L - R = np.eye(3) - theta = 0.1 * (k + 1) - R[i, i] = np.cos(theta) - R[j, j] = np.cos(theta) - R[i, j] = np.sin(theta) - R[j, i] = -np.sin(theta) - F = np.dot(F, R) - atoms.set_cell(atoms.cell @ F, scale_atoms=True) + rotation_product = np.eye(3) + for axis_idx in range(3): + axes = list(range(3)) + axes.remove(axis_idx) + row_idx, col_idx = axes + rot_mat = np.eye(3) + theta = 0.1 * (axis_idx + 1) + rot_mat[row_idx, row_idx] = np.cos(theta) + rot_mat[col_idx, col_idx] = np.cos(theta) + rot_mat[row_idx, col_idx] = np.sin(theta) + rot_mat[col_idx, row_idx] = -np.sin(theta) + rotation_product = np.dot(rotation_product, rot_mat) + atoms.set_cell(atoms.cell @ rotation_product, scale_atoms=True) return atoms -# ============================================================================= -# Shared Fixtures -# ============================================================================= +# === Fixtures === @pytest.fixture def model() -> LennardJonesModel: - """Create a LennardJonesModel for testing.""" + """LJ model for testing.""" return LennardJonesModel( sigma=1.0, epsilon=0.05, @@ -118,16 +83,11 @@ def model() -> LennardJonesModel: class NoisyModelWrapper: - """Wrapper that adds noise to forces and stress from an underlying model.""" + """Wrapper that adds noise to forces and stress.""" - def __init__( - self, - model: LennardJonesModel, - rng_seed: int = 1, - noise_scale: float = 1e-4, - ) -> None: + def __init__(self, model: LennardJonesModel, noise_scale: float = 1e-4) -> None: self.model = model - self.rng = np.random.default_rng(rng_seed) + self.rng = np.random.default_rng(seed=1) self.noise_scale = noise_scale @property @@ -139,82 +99,45 @@ def dtype(self) -> torch.dtype: return self.model.dtype def __call__(self, state: ts.SimState) -> dict[str, torch.Tensor]: + """Forward pass with added noise.""" results = self.model(state) - # Add noise to forces - if "forces" in results: - noise = self.rng.normal(size=results["forces"].shape) - results["forces"] = results["forces"] + self.noise_scale * torch.tensor( - noise, - dtype=results["forces"].dtype, - device=results["forces"].device, - ) - # Add noise to stress - if "stress" in results: - noise = self.rng.normal(size=results["stress"].shape) - results["stress"] = results["stress"] + self.noise_scale * torch.tensor( - noise, - dtype=results["stress"].dtype, - device=results["stress"].device, - ) + for key in ("forces", "stress"): + if key in results: + noise = torch.tensor( + self.rng.normal(size=results[key].shape), + dtype=results[key].dtype, + device=results[key].device, + ) + results[key] = results[key] + self.noise_scale * noise return results @pytest.fixture def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper: - """Create a LJ model that adds noise to forces/stress. - - Similar to ASE's NoisyLennardJones. - """ + """LJ model with noise added to forces/stress.""" return NoisyModelWrapper(model) -# ============================================================================= -# Shared Helper Functions -# ============================================================================= - - -def get_spglib_dataset_from_atoms( - atoms: Atoms, symprec: float = SYMPREC -) -> spglib.SpglibDataset: - """Get full symmetry dataset for an ASE Atoms object using spglib directly.""" - return spglib.get_symmetry_dataset( - (atoms.cell[:], atoms.get_scaled_positions(), atoms.numbers), - symprec=symprec, - ) +# === Optimization helper === def run_optimization_check_symmetry( state: ts.SimState, - model: LennardJonesModel, + model: LennardJonesModel | NoisyModelWrapper, constraint: FixSymmetry | None = None, *, adjust_cell: bool = True, - symprec: float = SYMPREC, max_steps: int = MAX_STEPS, force_tol: float = 0.001, ) -> dict[str, list[int | None]]: - """Run FIRE optimization and return initial/final space group numbers. - - Args: - state: torch-sim SimState (can be batched) - model: torch-sim model for optimization - constraint: Optional FixSymmetry constraint to apply. If None, no constraint. - adjust_cell: Whether to enable cell optimization (with Frechet filter) - symprec: Symmetry precision for spglib checks - max_steps: Maximum optimization steps - force_tol: Force convergence tolerance - - Returns: - Dict with 'initial_spacegroups' and 'final_spacegroups' lists. - """ - initial_datasets = get_symmetry_datasets(state, symprec) - + """Run FIRE optimization and return initial/final space group numbers.""" + initial = get_symmetry_datasets(state, SYMPREC) if constraint is not None: state.constraints = [constraint] - init_kwargs = {"cell_filter": ts.CellFilter.frechet} if adjust_cell else None convergence_fn = ts.generate_force_convergence_fn( - force_tol=force_tol, include_cell_forces=adjust_cell + force_tol=force_tol, + include_cell_forces=adjust_cell, ) final_state = ts.optimize( system=state, @@ -225,44 +148,34 @@ def run_optimization_check_symmetry( max_steps=max_steps, steps_between_swaps=1, ) - - final_datasets = get_symmetry_datasets(final_state, symprec) - + final = get_symmetry_datasets(final_state, SYMPREC) return { - "initial_spacegroups": [d.number if d else None for d in initial_datasets], - "final_spacegroups": [d.number if d else None for d in final_datasets], + "initial_spacegroups": [d.number if d else None for d in initial], + "final_spacegroups": [d.number if d else None for d in final], } -# ============================================================================= -# Test Classes -# ============================================================================= +# === Tests: Creation & Basics === class TestFixSymmetryCreation: - """Tests for FixSymmetry constraint creation.""" + """Tests for FixSymmetry creation and basic behavior.""" - def test_from_state_batched(self): - """Test creating FixSymmetry from batched SimState with different structures.""" + def test_from_state_batched(self) -> None: + """Batched state with FCC + diamond gets correct ops and atom counts.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], torch.device("cpu"), DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - assert len(constraint.rotations) == 2 - assert len(constraint.symm_maps) == 2 - assert constraint.system_idx.shape == (2,) - # Both have cubic symmetry (48 ops) but different number of atoms - assert constraint.rotations[0].shape[0] == 48 - assert constraint.rotations[1].shape[0] == 48 - # Cu FCC has 1 atom, Si diamond has 2 - assert constraint.symm_maps[0].shape == (48, 1) - assert constraint.symm_maps[1].shape == (48, 2) - - def test_p1_identity_only(self): - """Test P1 (no symmetry) has only identity and doesn't change forces/stress.""" + assert constraint.rotations[0].shape[0] == 48 # cubic + assert constraint.symm_maps[0].shape == (48, 1) # Cu: 1 atom + assert constraint.symm_maps[1].shape == (48, 2) # Si: 2 atoms + + def test_p1_identity_only(self) -> None: + """P1 structure has 1 op and symmetrization is a no-op.""" atoms = Atoms( "SiGe", positions=[[0.1, 0.2, 0.3], [1.1, 0.9, 1.3]], @@ -271,338 +184,292 @@ def test_p1_identity_only(self): ) state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + assert constraint.rotations[0].shape[0] == 1 - assert constraint.rotations[0].shape[0] == 1, "P1 should have 1 operation" - - # Forces should be unchanged forces = torch.randn(2, 3, dtype=DTYPE) - original_forces = forces.clone() + original = forces.clone() constraint.adjust_forces(state, forces) - assert torch.allclose(forces, original_forces, atol=1e-10) + assert torch.allclose(forces, original, atol=1e-10) - # Stress should be unchanged (identity symmetrization) stress = torch.randn(1, 3, 3, dtype=DTYPE) - # Make it symmetric (stress tensors are symmetric) - stress = (stress + stress.transpose(-1, -2)) / 2 + stress = (stress + stress.mT) / 2 original_stress = stress.clone() constraint.adjust_stress(state, stress) assert torch.allclose(stress, original_stress, atol=1e-10) - def test_symmetry_datasets_match_spglib(self): - """Test get_symmetry_datasets space groups match spglib.""" - atoms_list = [make_structure(name) for name in ["fcc", "diamond", "hcp"]] - - # Test batched state - batched_state = ts.io.atoms_to_state(atoms_list, torch.device("cpu"), DTYPE) - moyo_datasets = get_symmetry_datasets(batched_state, SYMPREC) - assert len(moyo_datasets) == 3 - - # Compare space group numbers with spglib - for idx, atoms in enumerate(atoms_list): - spglib_dataset = get_spglib_dataset_from_atoms(atoms, SYMPREC) + def test_get_removed_dof_returns_zero(self) -> None: + """FixSymmetry constrains direction, not DOF count.""" + state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + assert torch.all(constraint.get_removed_dof(state) == 0) - assert moyo_datasets[idx].number == spglib_dataset.number, ( - f"Space group mismatch for {atoms.get_chemical_formula()}: " - f"moyopy={moyo_datasets[idx].number} vs " - f"spglib={spglib_dataset.number}" - ) - # Both should find the same number of symmetry operations - assert len(moyo_datasets[idx].operations) == len(spglib_dataset.rotations) + @pytest.mark.parametrize("refine", [True, False]) + def test_from_state_refine_symmetry(self, *, refine: bool) -> None: + """With refine=False state is unmodified; with refine=True it may change.""" + atoms = make_structure("fcc") + rng = np.random.default_rng(42) + atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + orig_pos = state.positions.clone() + _ = FixSymmetry.from_state(state, symprec=SYMPREC, refine_symmetry_state=refine) + if not refine: + assert torch.allclose(state.positions, orig_pos) + @pytest.mark.parametrize("structure_name", ["fcc", "hcp", "diamond", "p6bar"]) + def test_refine_symmetry_produces_correct_spacegroup( + self, + structure_name: str, + ) -> None: + """Perturbed structure recovers correct spacegroup after refinement.""" + from torch_sim.symmetrize import get_symmetry_datasets, refine_symmetry -class TestFixSymmetryComparisonWithASE: - """Compare TorchSim FixSymmetry with ASE's implementation.""" + atoms = make_structure(structure_name) + expected = SPACEGROUPS[structure_name] + rng = np.random.default_rng(42) + atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 - def test_symmetrize_forces_batched(self): - """Test force symmetrization for batched systems with different structures.""" - state = ts.io.atoms_to_state( - [make_structure("fcc"), make_structure("diamond")], - torch.device("cpu"), - DTYPE, + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + cell = state.row_vector_cell[0] + pos = state.positions + nums = state.atomic_numbers + + refined_cell, refined_pos = refine_symmetry(cell, pos, nums, symprec=SYMPREC) + state.cell[0] = refined_cell.mT + state.positions = refined_pos + + # Check at tight precision + datasets = get_symmetry_datasets(state, symprec=1e-4) + assert datasets[0].number == expected, ( + f"{structure_name}: expected SG {expected}, got {datasets[0].number}" ) + + def test_large_deformation_raises(self) -> None: + """Deformation gradient > 0.25 raises RuntimeError.""" + state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + new_cell = state.cell.clone() + new_cell[0] *= 1.5 + with pytest.raises(RuntimeError, match="deformation gradient"): + constraint.adjust_cell(state, new_cell) - # Create asymmetric forces (1 atom for Cu FCC, 2 atoms for Si diamond) - forces = torch.tensor( - [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], dtype=DTYPE - ) - constraint.adjust_forces(state, forces) +# === Tests: Comparison with ASE === - # First atom (Cu FCC) should have zero force due to cubic symmetry - assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) - def test_force_symmetrization_matches_ase(self): - """Compare force symmetrization with ASE using a multi-atom structure.""" - atoms = make_structure("p6bar") +class TestFixSymmetryComparisonWithASE: + """Compare TorchSim FixSymmetry with ASE's implementation.""" - # Create TorchSim state and constraint + def test_force_symmetrization_matches_ase(self) -> None: + """Force symmetrization matches ASE on multi-atom P-6 structure.""" + atoms = make_structure("p6bar") state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - # Set up ASE constraint ase_atoms = atoms.copy() ase_refine_symmetry(ase_atoms, symprec=SYMPREC) ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) - # Create random test forces rng = np.random.default_rng(42) forces_np = rng.standard_normal((len(atoms), 3)) forces_ts = torch.tensor(forces_np.copy(), dtype=DTYPE) - # Symmetrize with both ts_constraint.adjust_forces(state, forces_ts) ase_constraint.adjust_forces(ase_atoms, forces_np) - - # Compare results assert np.allclose(forces_ts.numpy(), forces_np, atol=1e-10) - def test_stress_symmetrization_matches_ase(self): - """Compare stress symmetrization with ASE's implementation.""" + def test_stress_symmetrization_matches_ase(self) -> None: + """Stress symmetrization matches ASE on P-6 structure.""" atoms = make_structure("p6bar") - - # Create TorchSim state and constraint state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - # Set up ASE constraint ase_atoms = atoms.copy() ase_refine_symmetry(ase_atoms, symprec=SYMPREC) ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) - # Create asymmetric but symmetric (as a matrix) stress tensor stress_3x3 = np.array([[10.0, 1.0, 0.5], [1.0, 8.0, 0.3], [0.5, 0.3, 6.0]]) - - # ASE uses Voigt notation stress_voigt = full_3x3_to_voigt_6_stress(stress_3x3) stress_voigt_copy = stress_voigt.copy() - - # TorchSim uses 3x3 tensor with batch dimension stress_ts = torch.tensor([stress_3x3.copy()], dtype=DTYPE) - # Symmetrize with both ts_constraint.adjust_stress(state, stress_ts) ase_constraint.adjust_stress(ase_atoms, stress_voigt_copy) + ase_result = voigt_6_to_full_3x3_stress(stress_voigt_copy) + assert np.allclose(stress_ts[0].numpy(), ase_result, atol=1e-10) - # Convert ASE result back to 3x3 - ase_result_3x3 = voigt_6_to_full_3x3_stress(stress_voigt_copy) - - # Compare results - assert np.allclose(stress_ts[0].numpy(), ase_result_3x3, atol=1e-10), ( - f"Stress mismatch:\nTorchSim:\n{stress_ts[0].numpy()}\nASE:\n{ase_result_3x3}" - ) - - def test_cell_deformation_symmetrization_matches_ase(self): - """Compare cell deformation symmetrization with ASE.""" + def test_cell_deformation_matches_ase(self) -> None: + """Cell deformation symmetrization matches ASE on P-6 structure.""" atoms = make_structure("p6bar") - - # Create TorchSim state and constraint state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - # Set up ASE constraint ase_atoms = atoms.copy() ase_refine_symmetry(ase_atoms, symprec=SYMPREC) ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) - # Create a small asymmetric deformation of the cell original_cell = ase_atoms.get_cell().copy() deformed_cell = original_cell.copy() - deformed_cell[0, 1] += 0.05 # Small off-diagonal perturbation + deformed_cell[0, 1] += 0.05 - # TorchSim - need column vector convention for adjust_cell - new_cell_ts = torch.tensor( - [deformed_cell.copy().T], - dtype=DTYPE, # Transpose for column vectors - ) + new_cell_ts = torch.tensor([deformed_cell.copy().T], dtype=DTYPE) ts_constraint.adjust_cell(state, new_cell_ts) - ts_result = new_cell_ts[0].mT.numpy() # Back to row vectors + ts_result = new_cell_ts[0].mT.numpy() - # ASE ase_cell = deformed_cell.copy() ase_constraint.adjust_cell(ase_atoms, ase_cell) + assert np.allclose(ts_result, ase_cell, atol=1e-10) - # Compare results - assert np.allclose(ts_result, ase_cell, atol=1e-10), ( - f"Cell mismatch:\nTorchSim:\n{ts_result}\nASE:\n{ase_cell}" - ) + def test_position_symmetrization_matches_ase(self) -> None: + """Position displacement symmetrization matches ASE on P-6 structure.""" + atoms = make_structure("p6bar") + state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + ase_atoms = atoms.copy() + ase_refine_symmetry(ase_atoms, symprec=SYMPREC) + ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) -class TestFixSymmetryMergeAndSelect: - """Tests for FixSymmetry.merge, select_constraint, select_sub_constraint.""" + # Create a displacement by proposing new positions + rng = np.random.default_rng(42) + displacement = rng.standard_normal((len(atoms), 3)) * 0.01 + new_pos_ts = state.positions.clone() + torch.tensor(displacement, dtype=DTYPE) + new_pos_ase = ase_atoms.positions.copy() + displacement - def test_merge_two_constraints(self): - """Test merging two FixSymmetry constraints.""" - state1 = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) - state2 = ts.io.atoms_to_state( - make_structure("diamond"), torch.device("cpu"), DTYPE + ts_constraint.adjust_positions(state, new_pos_ts) + ase_constraint.adjust_positions(ase_atoms, new_pos_ase) + assert np.allclose(new_pos_ts.numpy(), new_pos_ase, atol=1e-10) + + def test_cubic_forces_vanish(self) -> None: + """Asymmetric force on single cubic atom symmetrizes to zero.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + torch.device("cpu"), + DTYPE, ) - c1 = FixSymmetry.from_state(state1, symprec=SYMPREC) - c2 = FixSymmetry.from_state(state2, symprec=SYMPREC) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + forces = torch.tensor( + [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], + dtype=DTYPE, + ) + constraint.adjust_forces(state, forces) + assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) + - merged = FixSymmetry.merge([c1, c2], state_indices=[0, 1], atom_offsets=None) +# === Tests: Merge & Select === + +class TestFixSymmetryMergeAndSelect: + """Tests for merge, select_constraint, select_sub_constraint.""" + + def test_merge_two_constraints(self) -> None: + """Merge two single-system constraints.""" + s1 = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) + s2 = ts.io.atoms_to_state(make_structure("diamond"), torch.device("cpu"), DTYPE) + merged = FixSymmetry.merge( + [FixSymmetry.from_state(s1), FixSymmetry.from_state(s2)], + state_indices=[0, 1], + atom_offsets=None, + ) assert len(merged.rotations) == 2 - assert len(merged.symm_maps) == 2 assert merged.system_idx.tolist() == [0, 1] - def test_merge_multi_system_constraints_no_duplicate_indices(self): - """Regression: merging multi-system constraints must not produce duplicates.""" - # Create two batched states so each constraint covers multiple systems + def test_merge_multi_system_no_duplicate_indices(self) -> None: + """Regression: multi-system constraints must use cumulative offsets.""" atoms_a = [ make_structure("fcc"), make_structure("diamond"), make_structure("hcp"), ] atoms_b = [make_structure("bcc"), make_structure("fcc")] - state_a = ts.io.atoms_to_state(atoms_a, torch.device("cpu"), DTYPE) - state_b = ts.io.atoms_to_state(atoms_b, torch.device("cpu"), DTYPE) - c_a = FixSymmetry.from_state(state_a, symprec=SYMPREC) # 3 systems - c_b = FixSymmetry.from_state(state_b, symprec=SYMPREC) # 2 systems - - # Old bug: state_indices=[0, 1] was used as offsets → [0,1,2, 1,2] (duplicates) - # Fix: cumulative offset → [0,1,2, 3,4] + c_a = FixSymmetry.from_state( + ts.io.atoms_to_state(atoms_a, torch.device("cpu"), DTYPE), + ) + c_b = FixSymmetry.from_state( + ts.io.atoms_to_state(atoms_b, torch.device("cpu"), DTYPE), + ) merged = FixSymmetry.merge([c_a, c_b], state_indices=[0, 1], atom_offsets=None) - - assert len(merged.rotations) == 5 - assert len(merged.symm_maps) == 5 assert merged.system_idx.tolist() == [0, 1, 2, 3, 4] - @pytest.mark.parametrize("mismatch_field", ["adjust_positions", "adjust_cell"]) - def test_merge_mismatched_settings_raises( - self, mismatch_field: Literal["adjust_positions", "adjust_cell"] - ): - """Test that merging constraints with different settings raises ValueError.""" - state1 = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) - state2 = ts.io.atoms_to_state( - make_structure("diamond"), torch.device("cpu"), DTYPE - ) - - kwargs1 = {mismatch_field: True} - kwargs2 = {mismatch_field: False} - c1 = FixSymmetry.from_state(state1, symprec=SYMPREC, **kwargs1) - c2 = FixSymmetry.from_state(state2, symprec=SYMPREC, **kwargs2) - - with pytest.raises(ValueError, match=f"different {mismatch_field} settings"): - FixSymmetry.merge([c1, c2], state_indices=[0, 1], atom_offsets=None) - - def test_select_constraint_single_system(self): - """Test selecting a single system from batched constraint.""" + def test_select_sub_constraint(self) -> None: + """Select second system from batched constraint.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], torch.device("cpu"), DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - # Create masks to select only first system - atom_mask = torch.tensor( - [True, False, False], dtype=torch.bool - ) # 1 Cu + 2 Si atoms - system_mask = torch.tensor([True, False], dtype=torch.bool) - - selected = constraint.select_constraint(atom_mask, system_mask) - + selected = constraint.select_sub_constraint(torch.tensor([1, 2]), sys_idx=1) assert selected is not None - assert len(selected.rotations) == 1 - assert len(selected.symm_maps) == 1 - assert selected.system_idx.shape == (1,) - # Should have Cu's 48 symmetry operations - assert selected.rotations[0].shape[0] == 48 + assert selected.symm_maps[0].shape[1] == 2 # Si diamond: 2 atoms + assert selected.system_idx.item() == 0 - def test_select_sub_constraint(self): - """Test selecting a specific system by index.""" + def test_select_constraint_by_mask(self) -> None: + """Select first system via system_mask.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], torch.device("cpu"), DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - # Select second system (Si diamond) - # Note: atom_idx is ignored for FixSymmetry - selected = constraint.select_sub_constraint( - atom_idx=torch.tensor([1, 2]), sys_idx=1 - ) - + atom_mask = torch.tensor([True, False, False], dtype=torch.bool) + system_mask = torch.tensor([True, False], dtype=torch.bool) + selected = constraint.select_constraint(atom_mask, system_mask) assert selected is not None assert len(selected.rotations) == 1 - # Si diamond has 2 atoms - assert selected.symm_maps[0].shape[1] == 2 - # New system_idx should be 0 (renumbered) - assert selected.system_idx.item() == 0 + assert selected.rotations[0].shape[0] == 48 -class TestFixSymmetryWithOptimization: - """Test FixSymmetry with actual optimization routines. +# === Tests: Optimization === - Uses the shared run_optimization_check_symmetry helper for most tests. - """ - @pytest.mark.parametrize("structure_name", ["fcc", "hcp", "diamond", "p6bar"]) +class TestFixSymmetryWithOptimization: + """Test FixSymmetry with actual optimization routines.""" + + @pytest.mark.parametrize("structure_name", ["fcc", "hcp", "diamond"]) @pytest.mark.parametrize( ("adjust_positions", "adjust_cell"), - [(True, True), (True, False), (False, True), (False, False)], + [(True, True), (False, False)], ) - def test_distorted_structure_preserves_symmetry( + def test_distorted_preserves_symmetry( self, noisy_lj_model: NoisyModelWrapper, structure_name: str, *, adjust_positions: bool, adjust_cell: bool, - ): - """Test that a distorted structure relaxes while preserving symmetry. - - All combinations of adjust_positions and adjust_cell should preserve symmetry - because forces are always symmetrized (matching ASE's behavior). - - """ + ) -> None: + """Compressed structure relaxes while preserving symmetry.""" atoms = make_structure(structure_name) - expected_spacegroup = SPACEGROUPS[structure_name] - + expected = SPACEGROUPS[structure_name] state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) - - # Create constraint BEFORE distorting - captures ideal symmetry constraint = FixSymmetry.from_state( state, symprec=SYMPREC, adjust_positions=adjust_positions, adjust_cell=adjust_cell, ) - - # Now distort the cell (uniform scaling preserves symmetry but creates forces) - # Scale by 0.9 to compress - this creates repulsive forces - scale_factor = 0.9 - state.cell = state.cell * scale_factor - state.positions = state.positions * scale_factor - + state.cell = state.cell * 0.9 + state.positions = state.positions * 0.9 result = run_optimization_check_symmetry( state, noisy_lj_model, constraint=constraint, adjust_cell=adjust_cell, - max_steps=MAX_STEPS, - force_tol=0.01, # Looser tolerance to ensure movement - ) - - assert result["final_spacegroups"][0] == expected_spacegroup, ( - f"Space group changed from {expected_spacegroup} to " - f"{result['final_spacegroups'][0]} with adjust_positions={adjust_positions}, " - f"adjust_cell={adjust_cell}" + force_tol=0.01, ) + assert result["final_spacegroups"][0] == expected @pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) def test_cell_filter_preserves_symmetry( - self, model: LennardJonesModel, cell_filter: ts.CellFilter - ): - """Test that cell filters with FixSymmetry preserve symmetry.""" + self, + model: LennardJonesModel, + cell_filter: ts.CellFilter, + ) -> None: + """Cell filters with FixSymmetry preserve symmetry.""" state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) state.constraints = [constraint] - - initial_datasets = get_symmetry_datasets(state, symprec=SYMPREC) - + initial = get_symmetry_datasets(state, symprec=SYMPREC) final_state = ts.optimize( system=state, model=model, @@ -611,152 +478,64 @@ def test_cell_filter_preserves_symmetry( init_kwargs={"cell_filter": cell_filter}, max_steps=MAX_STEPS, ) - - final_datasets = get_symmetry_datasets(final_state, symprec=SYMPREC) - assert initial_datasets[0].number == final_datasets[0].number + final = get_symmetry_datasets(final_state, symprec=SYMPREC) + assert initial[0].number == final[0].number @pytest.mark.parametrize("cell_filter", [ts.CellFilter.frechet, ts.CellFilter.unit]) - def test_lbfgs_cell_optimization_preserves_symmetry( + def test_lbfgs_preserves_symmetry( self, noisy_lj_model: NoisyModelWrapper, cell_filter: ts.CellFilter, - ): + ) -> None: """Regression: LBFGS must use set_constrained_cell for FixSymmetry support.""" state = ts.io.atoms_to_state(make_structure("bcc"), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) state.constraints = [constraint] - - # Compress cell to create forces state.cell = state.cell * 0.95 state.positions = state.positions * 0.95 - - initial_datasets = get_symmetry_datasets(state, symprec=SYMPREC) - assert initial_datasets[0].number == SPACEGROUPS["bcc"] - final_state = ts.optimize( system=state, model=noisy_lj_model, optimizer=ts.Optimizer.lbfgs, convergence_fn=ts.generate_force_convergence_fn( - force_tol=0.01, include_cell_forces=True + force_tol=0.01, + include_cell_forces=True, ), init_kwargs={"cell_filter": cell_filter}, max_steps=MAX_STEPS, ) - - final_datasets = get_symmetry_datasets(final_state, symprec=SYMPREC) - assert final_datasets[0].number == SPACEGROUPS["bcc"], ( - f"LBFGS+{cell_filter} lost symmetry: {SPACEGROUPS['bcc']} -> " - f"{final_datasets[0].number}" - ) + final = get_symmetry_datasets(final_state, symprec=SYMPREC) + assert final[0].number == SPACEGROUPS["bcc"] @pytest.mark.parametrize("rotated", [False, True]) def test_noisy_model_loses_symmetry_without_constraint( - self, noisy_lj_model: NoisyModelWrapper, *, rotated: bool - ): - """Test that WITHOUT FixSymmetry, optimization with noisy forces loses symmetry. - - This is a negative control - verifies that noisy forces will break symmetry - if no constraint is applied. Mirrors ASE's test_no_symmetrization. - """ + self, + noisy_lj_model: NoisyModelWrapper, + *, + rotated: bool, + ) -> None: + """Negative control: without FixSymmetry, noisy forces break symmetry.""" name = "bcc_rotated" if rotated else "bcc" - bcc_atoms = make_structure(name) - state = ts.io.atoms_to_state(bcc_atoms, torch.device("cpu"), DTYPE) - result = run_optimization_check_symmetry( - state, noisy_lj_model, constraint=None, max_steps=MAX_STEPS, symprec=SYMPREC - ) - - # Initial should be BCC (space group 229) + state = ts.io.atoms_to_state(make_structure(name), torch.device("cpu"), DTYPE) + result = run_optimization_check_symmetry(state, noisy_lj_model, constraint=None) assert result["initial_spacegroups"][0] == 229 - # Final should have lost symmetry (different space group) - assert result["final_spacegroups"][0] != 229, ( - f"Symmetry should be lost without constraint, but final space group " - f"is still {result['final_spacegroups'][0]}" - ) + assert result["final_spacegroups"][0] != 229 @pytest.mark.parametrize("rotated", [False, True]) def test_noisy_model_preserves_symmetry_with_constraint( - self, noisy_lj_model: NoisyModelWrapper, *, rotated: bool - ): - """Test that WITH FixSymmetry, optimization with noisy forces preserves symmetry. - - Mirrors ASE's test_sym_adj_cell. - """ - bcc_atoms = make_structure("bcc_rotated" if rotated else "bcc") - state = ts.io.atoms_to_state(bcc_atoms, torch.device("cpu"), DTYPE) + self, + noisy_lj_model: NoisyModelWrapper, + *, + rotated: bool, + ) -> None: + """With FixSymmetry, noisy forces still preserve symmetry.""" + name = "bcc_rotated" if rotated else "bcc" + state = ts.io.atoms_to_state(make_structure(name), torch.device("cpu"), DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) result = run_optimization_check_symmetry( state, noisy_lj_model, constraint=constraint, - max_steps=MAX_STEPS, ) - assert result["initial_spacegroups"][0] == 229 - assert result["final_spacegroups"][0] == 229, ( - f"Symmetry should be preserved with constraint, but final spacegroup " - f"changed to {result['final_spacegroups'][0]}" - ) - - -class TestFixSymmetryEdgeCases: - """Tests for edge cases and error handling.""" - - def test_get_removed_dof_returns_zero(self): - """Test get_removed_dof returns zero (constrains direction, not DOF count).""" - state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) - constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - dof = constraint.get_removed_dof(state) - assert torch.all(dof == 0) - - def test_large_deformation_gradient_raises(self): - """Test that large deformation gradient raises RuntimeError.""" - state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) - constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - # Create a very large deformation (> 0.25) - # FCC cell has zeros on diagonal, so modify all elements by a large factor - new_cell_col = state.cell.clone() # Column vector convention - new_cell_col[0] *= 1.5 # 50% stretch of entire cell - - with pytest.raises(RuntimeError, match="large deformation gradient"): - constraint.adjust_cell(state, new_cell_col) - - def test_medium_deformation_gradient_warns(self): - """Test that medium deformation gradient emits warning.""" - state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) - constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - # Create a medium deformation (> 0.15 but < 0.25) - new_cell_col = state.cell.clone() # Column vector convention - new_cell_col[0] *= 1.2 # 20% stretch of entire cell - - with pytest.warns(UserWarning, match="may be ill-behaved"): - constraint.adjust_cell(state, new_cell_col) - - @pytest.mark.parametrize("refine_symmetry_state", [True, False]) - def test_from_state_refine_symmetry(self, *, refine_symmetry_state: bool): - """Test from_state with different refine_symmetry_state settings.""" - atoms = make_structure("fcc") - # Add small perturbation - perturbed = atoms.copy() - rng = np.random.default_rng(42) - perturbed.positions += rng.standard_normal(perturbed.positions.shape) * 0.001 - - state = ts.io.atoms_to_state(perturbed, torch.device("cpu"), DTYPE) - original_positions = state.positions.clone() - original_cell = state.cell.clone() - - _ = FixSymmetry.from_state( - state, symprec=SYMPREC, refine_symmetry_state=refine_symmetry_state - ) - - if not refine_symmetry_state: - # State should not be modified - assert torch.allclose(state.positions, original_positions) - assert torch.allclose(state.cell, original_cell) - else: - # State may be modified (positions refined to ideal) - # We just check the function runs without error - assert state.positions.shape == original_positions.shape + assert result["final_spacegroups"][0] == 229 diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 59ac540ee..33fc78852 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -16,13 +16,6 @@ import torch -from torch_sim.symmetrize import ( - _prep_symmetry, - refine_symmetry, - symmetrize_rank1, - symmetrize_rank2, -) - if TYPE_CHECKING: from torch_sim.state import SimState @@ -720,33 +713,16 @@ def validate_constraints(constraints: list[Constraint], state: SimState) -> None class FixSymmetry(SystemConstraint): - """Constraint to preserve spacegroup symmetry during optimization. - - This constraint symmetrizes forces, positions, and cell/stress - according to the crystal symmetry operations. Each system in a batch can - have different symmetry operations. - - Requires the moyopy package to be available for automatic symmetry detection. + """Preserve spacegroup symmetry during optimization. - The constraint works by: - - Symmetrizing forces/momenta as rank-1 tensors using all symmetry operations - - Symmetrizing position displacements similarly for position adjustments - - Symmetrizing stress/cell deformation as rank-2 tensors + Symmetrizes forces/momenta as rank-1 tensors and stress/cell deformation + as rank-2 tensors using the crystal's symmetry operations. Each system in + a batch can have different symmetry operations. - Attributes: - rotations: List of rotation matrices for each system, - shape (n_ops, 3, 3) per system. - symm_maps: List of symmetry atom mappings for each system, - shape (n_ops, n_atoms) per system. - do_adjust_positions: Whether to symmetrize position adjustments. - do_adjust_cell: Whether to symmetrize cell/stress adjustments. - - Examples: - Create from SimState: - >>> constraint = FixSymmetry.from_state(state, symprec=0.01) + Forces and stress are always symmetrized. Position and cell symmetrization + can be toggled via ``adjust_positions`` and ``adjust_cell``. """ - # Type annotations rotations: list[torch.Tensor] symm_maps: list[torch.Tensor] do_adjust_positions: bool @@ -764,40 +740,27 @@ def __init__( """Initialize FixSymmetry constraint. Args: - rotations: List of rotation tensors, one per system. - Each tensor has shape (n_ops, 3, 3). - symm_maps: List of symmetry mapping tensors, one per system. - Each tensor has shape (n_ops, n_atoms_in_system). - system_idx: Indices of systems this constraint applies to. - If None, defaults to [0, 1, ..., n_systems-1]. - adjust_positions: Whether to symmetrize position adjustments. + rotations: Rotation tensors per system, each (n_ops, 3, 3). + symm_maps: Atom mapping tensors per system, each (n_ops, n_atoms). + system_idx: System indices (defaults to 0..n_systems-1). + adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. - - Raises: - ValueError: If lists have mismatched lengths or system_idx is wrong length. """ n_systems = len(rotations) - - # Validate list lengths if len(symm_maps) != n_systems: raise ValueError( - "rotations and symm_maps must have the same length. " - f"Got {len(rotations)}, {len(symm_maps)}." + f"rotations and symm_maps length mismatch: " + f"{n_systems} vs {len(symm_maps)}" ) - if system_idx is None: - # Infer device from rotations tensors device = rotations[0].device if rotations else torch.device("cpu") system_idx = torch.arange(n_systems, device=device) - if len(system_idx) != n_systems: raise ValueError( - f"system_idx length ({len(system_idx)}) must match " - f"number of systems ({n_systems})" + f"system_idx length ({len(system_idx)}) != n_systems ({n_systems})" ) super().__init__(system_idx=system_idx) - self.rotations = rotations self.symm_maps = symm_maps self.do_adjust_positions = adjust_positions @@ -813,355 +776,192 @@ def from_state( adjust_cell: bool = True, refine_symmetry_state: bool = True, ) -> Self: - """Create FixSymmetry constraint from a SimState. - - Directly uses tensor data from the state to determine symmetry. + """Create from SimState, optionally refining to ideal symmetry first. Warning: - By default, this method **mutates the input state** in-place to refine - the atomic positions and cell vectors to ideal symmetric values. - Set ``refine_symmetry_state=False`` to skip this refinement if you - want to preserve the original state (though this may lead to - symmetry detection issues if the structure is not already ideal). + When ``refine_symmetry_state=True`` (default), the input state is + **mutated in-place** to have ideal symmetric positions and cell. Args: state: SimState containing one or more systems. symprec: Symmetry precision for moyopy. - adjust_positions: Whether to symmetrize position adjustments. + adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. - refine_symmetry_state: Whether to refine the state's positions and cell - to ideal symmetric values. When True (default), the input state - is modified in-place. When False, the state is not modified but - the constraint may not work correctly if the structure deviates - from ideal symmetry. - - Returns: - FixSymmetry constraint configured for the state's structures. + refine_symmetry_state: Whether to refine positions/cell to ideal values. """ try: import moyopy # noqa: F401 except ImportError: raise ImportError( - "moyopy is required for FixSymmetry.from_state. " - "Install with: pip install moyopy" + "moyopy required for FixSymmetry: pip install moyopy" ) from None - rotations = [] - symm_maps = [] + from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry - # Get atom counts per system for slicing - atoms_per_system = state.n_atoms_per_system + rotations, symm_maps = [], [] + n_per = state.n_atoms_per_system cumsum = torch.cat( [ torch.zeros(1, device=state.device, dtype=torch.long), - torch.cumsum(atoms_per_system, dim=0), + torch.cumsum(n_per, dim=0), ] ) for sys_idx in range(state.n_systems): - start = cumsum[sys_idx].item() - end = cumsum[sys_idx + 1].item() - - # Extract data for this system + start, end = cumsum[sys_idx].item(), cumsum[sys_idx + 1].item() cell = state.row_vector_cell[sys_idx] - positions = state.positions[start:end] - atomic_numbers = state.atomic_numbers[start:end] + pos, nums = state.positions[start:end], state.atomic_numbers[start:end] if refine_symmetry_state: - # Refine symmetry of the structure first - refined_cell, refined_positions = refine_symmetry( - cell, positions, atomic_numbers, symprec=symprec - ) - - # Apply refined cell and positions back to state - state.cell[sys_idx] = refined_cell.mT # row→column vector convention - state.positions[start:end] = refined_positions - - # Get symmetry operations using refined structure - rots, symm_map = _prep_symmetry( - refined_cell, refined_positions, atomic_numbers, symprec=symprec + # Single moyopy call: refine + get symmetry ops in one pass + cell, pos, rots, smap = refine_and_prep_symmetry( + cell, + pos, + nums, + symprec=symprec, ) + state.cell[sys_idx] = cell.mT # row→column vector convention + state.positions[start:end] = pos else: - # Use structure as-is without refinement - rots, symm_map = _prep_symmetry( - cell, positions, atomic_numbers, symprec=symprec - ) + rots, smap = prep_symmetry(cell, pos, nums, symprec=symprec) rotations.append(rots) - symm_maps.append(symm_map) - - system_idx = torch.arange(state.n_systems, device=state.device) + symm_maps.append(smap) return cls( - rotations=rotations, - symm_maps=symm_maps, - system_idx=system_idx, + rotations, + symm_maps, + system_idx=torch.arange(state.n_systems, device=state.device), adjust_positions=adjust_positions, adjust_cell=adjust_cell, ) - @classmethod - def merge( - cls, - constraints: list[Self], - state_indices: list[int], # noqa: ARG003 - atom_offsets: torch.Tensor, # noqa: ARG003 - ) -> Self: - """Merge multiple FixSymmetry constraints into one. - - Args: - constraints: List of FixSymmetry constraints to merge. - state_indices: Index of the source state for each constraint (unused). - atom_offsets: Cumulative atom counts (unused for FixSymmetry). - - Returns: - Merged FixSymmetry constraint. - - Raises: - ValueError: If constraints list is empty or if constraints have - mismatched adjust_positions or adjust_cell settings. - """ - if not constraints: - raise ValueError("Cannot merge empty list of constraints") - - # Validate that all constraints have matching settings - first_adjust_positions = constraints[0].do_adjust_positions - first_adjust_cell = constraints[0].do_adjust_cell - - for i, constraint in enumerate(constraints[1:], start=1): - if constraint.do_adjust_positions != first_adjust_positions: - raise ValueError( - f"Cannot merge FixSymmetry constraints with different " - f"adjust_positions settings: constraint 0 has " - f"adjust_positions={first_adjust_positions}, but constraint " - f"{i} has adjust_positions={constraint.do_adjust_positions}" - ) - if constraint.do_adjust_cell != first_adjust_cell: - raise ValueError( - f"Cannot merge FixSymmetry constraints with different " - f"adjust_cell settings: constraint 0 has " - f"adjust_cell={first_adjust_cell}, but constraint " - f"{i} has adjust_cell={constraint.do_adjust_cell}" - ) - - rotations = [] - symm_maps = [] - system_indices = [] - - # Use cumulative offset (not state_indices) to handle multi-system constraints - cumulative_offset = 0 - for constraint in constraints: - for idx in range(len(constraint.rotations)): - rotations.append(constraint.rotations[idx]) - symm_maps.append(constraint.symm_maps[idx]) - system_indices.append(cumulative_offset + idx) - cumulative_offset += len(constraint.rotations) - - device = rotations[0].device - - return cls( - rotations=rotations, - symm_maps=symm_maps, - system_idx=torch.tensor(system_indices, device=device), - adjust_positions=first_adjust_positions, - adjust_cell=first_adjust_cell, - ) - - def get_removed_dof(self, state: SimState) -> torch.Tensor: - """Get number of removed degrees of freedom. - - FixSymmetry constrains motion direction rather than removing explicit DOF, - so returns 0 to avoid breaking temperature calculations in MD. + # === Symmetrization hooks === - Args: - state: Simulation state - - Returns: - Zero tensor of shape (n_systems,) - """ - return torch.zeros(state.n_systems, dtype=torch.long, device=state.device) + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Symmetrize forces according to crystal symmetry.""" + self._symmetrize_rank1(state, forces) def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: - """Symmetrize position displacements. - - Args: - state: Current simulation state - new_positions: Proposed new positions to be adjusted in-place - """ + """Symmetrize position displacements (skipped if do_adjust_positions=False).""" if not self.do_adjust_positions: return - - # Compute displacement from current positions displacement = new_positions - state.positions - - # Symmetrize the displacement self._symmetrize_rank1(state, displacement) - - # Apply symmetrized displacement new_positions[:] = state.positions + displacement - def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: - """Symmetrize forces according to crystal symmetry. + def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: + """Symmetrize stress tensor in-place. - Args: - state: Current simulation state - forces: Forces to be adjusted in-place + Always runs (like adjust_forces), independent of do_adjust_cell. """ - self._symmetrize_rank1(state, forces) - - def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: - """Symmetrize momenta according to crystal symmetry. + from torch_sim.symmetrize import symmetrize_rank2 - Args: - state: Current simulation state - momenta: Momenta to be adjusted in-place - """ - self._symmetrize_rank1(state, momenta) + dtype = stress.dtype + for ci, si in enumerate(self.system_idx): + rots = self.rotations[ci].to(dtype=dtype) + stress[si] = symmetrize_rank2(state.row_vector_cell[si], stress[si], rots) def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: - """Symmetrize cell deformation in-place. + """Symmetrize cell deformation gradient in-place. - Computes the deformation gradient as ``(cell_inv @ new_cell).T - I`` - and symmetrizes it as a rank-2 tensor. + Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a + rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``. Args: - state: Current simulation state - new_cell: Proposed new cell tensor of shape (n_systems, 3, 3) - in column vector convention, modified in-place. + state: Current simulation state. + new_cell: Cell tensor (n_systems, 3, 3) in column vector convention. Raises: - RuntimeError: If the deformation gradient step is too large (> 0.25), - which can cause incorrect symmetrization. - - Warns: - UserWarning: If the deformation gradient step is large (> 0.15), - symmetrization may be ill-behaved. + RuntimeError: If deformation gradient > 0.25. """ if not self.do_adjust_cell: return - device = state.device - dtype = state.dtype - identity = torch.eye(3, device=device, dtype=dtype) - - for constraint_idx, sys_idx in enumerate(self.system_idx): - # Get current and new cells in row vector convention - cur_cell = state.row_vector_cell[sys_idx] - new_cell_row = new_cell[sys_idx].mT - - # Calculate deformation gradient - cur_cell_inv = torch.linalg.inv(cur_cell) - delta_deform_grad = (cur_cell_inv @ new_cell_row).mT - identity + from torch_sim.symmetrize import symmetrize_rank2 - # Check for large deformation gradient (following ASE) - max_delta = torch.abs(delta_deform_grad).max().item() + identity = torch.eye(3, device=state.device, dtype=state.dtype) + for ci, si in enumerate(self.system_idx): + cur_cell = state.row_vector_cell[si] + new_row = new_cell[si].mT # column → row convention + deform_delta = torch.linalg.inv(cur_cell) @ new_row - identity + max_delta = torch.abs(deform_delta).max().item() if max_delta > 0.25: raise RuntimeError( - f"FixSymmetry adjust_cell does not work properly with large " - f"deformation gradient step {max_delta:.4f} > 0.25. " - f"Consider using smaller optimization steps." + f"FixSymmetry: deformation gradient {max_delta:.4f} > 0.25 " + f"too large. Use smaller optimization steps." ) - if max_delta > 0.15: - warnings.warn( - f"FixSymmetry adjust_cell may be ill-behaved with large " - f"deformation gradient step {max_delta:.4f} > 0.15", - UserWarning, - stacklevel=2, - ) - - # Symmetrize deformation gradient directly - rots = self.rotations[constraint_idx].to(dtype=dtype) - symmetrized_delta = symmetrize_rank2(cur_cell, delta_deform_grad, rots) - - # Reconstruct cell and update in-place - new_cell_row_sym = cur_cell @ (symmetrized_delta + identity).mT - new_cell[sys_idx] = new_cell_row_sym.mT # Back to column convention - - def adjust_stress(self, state: SimState, stress: torch.Tensor) -> None: - """Symmetrize stress tensor in-place. - - Args: - state: Current simulation state - stress: Stress tensor of shape (n_systems, 3, 3), modified in-place. - """ - dtype = stress.dtype - - for constraint_idx, sys_idx in enumerate(self.system_idx): - cur_cell = state.row_vector_cell[sys_idx] - symmetrized = symmetrize_rank2( - cur_cell, stress[sys_idx], self.rotations[constraint_idx].to(dtype=dtype) - ) - stress[sys_idx] = symmetrized + rots = self.rotations[ci].to(dtype=state.dtype) + sym_delta = symmetrize_rank2(cur_cell, deform_delta, rots) + new_cell[si] = (cur_cell @ (sym_delta + identity)).mT def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: - """Symmetrize rank-1 tensors (forces, momenta, displacements) in-place. + """Symmetrize a rank-1 tensor in-place for each constrained system.""" + from torch_sim.symmetrize import symmetrize_rank1 - Uses fractional-coordinate rotations from moyopy together with the current - cell to transform vectors. The cell is fetched at runtime to ensure - correctness during variable-cell relaxation. - - Args: - state: Current simulation state (used for cell and atom indexing) - vectors: Tensor of shape (n_atoms, 3) to be symmetrized in-place - """ - # Get atom counts per system - atoms_per_system = state.n_atoms_per_system cumsum = torch.cat( [ torch.zeros(1, device=state.device, dtype=torch.long), - torch.cumsum(atoms_per_system, dim=0), + torch.cumsum(state.n_atoms_per_system, dim=0), ] ) - dtype = vectors.dtype - for constraint_idx, sys_idx in enumerate(self.system_idx): - start = cumsum[sys_idx].item() - end = cumsum[sys_idx + 1].item() + for ci, si in enumerate(self.system_idx): + start, end = cumsum[si].item(), cumsum[si + 1].item() + vectors[start:end] = symmetrize_rank1( + state.row_vector_cell[si], + vectors[start:end], + self.rotations[ci].to(dtype=dtype), + self.symm_maps[ci], + ) - sys_vectors = vectors[start:end] - cell = state.row_vector_cell[sys_idx] + # === Constraint interface === - symmetrized = symmetrize_rank1( - cell, - sys_vectors, - self.rotations[constraint_idx].to(dtype=dtype), - self.symm_maps[constraint_idx], - ) + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Returns zero - constrains direction, not DOF count.""" + return torch.zeros(state.n_systems, dtype=torch.long, device=state.device) - # Update in place - vectors[start:end] = symmetrized + @classmethod + def merge( + cls, + constraints: list[Self], + state_indices: list[int], # noqa: ARG003 + atom_offsets: torch.Tensor, # noqa: ARG003 + ) -> Self: + """Merge multiple FixSymmetry constraints into one.""" + if not constraints: + raise ValueError("Cannot merge empty constraint list") + rotations, symm_maps, sys_indices = [], [], [] + offset = 0 + for constraint in constraints: + for idx in range(len(constraint.rotations)): + rotations.append(constraint.rotations[idx]) + symm_maps.append(constraint.symm_maps[idx]) + sys_indices.append(offset + idx) + offset += len(constraint.rotations) + return cls( + rotations, + symm_maps, + system_idx=torch.tensor(sys_indices, device=rotations[0].device), + adjust_positions=constraints[0].do_adjust_positions, + adjust_cell=constraints[0].do_adjust_cell, + ) def select_constraint( self, atom_mask: torch.Tensor, # noqa: ARG002 system_mask: torch.Tensor, ) -> Self | None: - """Select constraint for systems matching the mask. - - Args: - atom_mask: Boolean mask for atoms (not used for SystemConstraint) - system_mask: Boolean mask for systems to keep - - Returns: - New FixSymmetry for selected systems, or None if no systems match. - """ - # Get indices of systems that are in both system_mask and self.system_idx - keep_global_indices = torch.where(system_mask)[0] - mask = torch.isin(self.system_idx, keep_global_indices) - + """Select constraint for systems matching the mask.""" + keep = torch.where(system_mask)[0] + mask = torch.isin(self.system_idx, keep) if not mask.any(): return None - - new_rotations = [self.rotations[i] for i in range(len(mask)) if mask[i]] - new_symm_maps = [self.symm_maps[i] for i in range(len(mask)) if mask[i]] - - # Remap system indices - new_system_idx = _mask_constraint_indices(self.system_idx[mask], system_mask) - + indices = [idx for idx in range(len(mask)) if mask[idx]] return type(self)( - rotations=new_rotations, - symm_maps=new_symm_maps, - system_idx=new_system_idx, + [self.rotations[idx] for idx in indices], + [self.symm_maps[idx] for idx in indices], + _mask_constraint_indices(self.system_idx[mask], system_mask), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, ) @@ -1171,39 +971,24 @@ def select_sub_constraint( atom_idx: torch.Tensor, # noqa: ARG002 sys_idx: int, ) -> Self | None: - """Select constraint for a single system. - - Args: - atom_idx: Atom indices (not used, kept for interface compatibility) - sys_idx: System index to select - - Returns: - New FixSymmetry for the selected system, or None if not found. - """ + """Select constraint for a single system.""" if sys_idx not in self.system_idx: return None - - local_idx = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() - + local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() return type(self)( - rotations=[self.rotations[local_idx]], - symm_maps=[self.symm_maps[local_idx]], - system_idx=torch.tensor([0], device=self.system_idx.device), + [self.rotations[local]], + [self.symm_maps[local]], + torch.tensor([0], device=self.system_idx.device), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, ) def __repr__(self) -> str: - """String representation of the constraint.""" - n_systems = len(self.rotations) - n_ops_list = [r.shape[0] for r in self.rotations] - if len(n_ops_list) <= 3: - ops_str = str(n_ops_list) - else: - ops_str = f"[{n_ops_list[0]}, ..., {n_ops_list[-1]}]" + """String representation.""" + n_ops = [r.shape[0] for r in self.rotations] + ops = str(n_ops) if len(n_ops) <= 3 else f"[{n_ops[0]}, ..., {n_ops[-1]}]" return ( - f"FixSymmetry(n_systems={n_systems}, " - f"n_ops={ops_str}, " + f"FixSymmetry(n_systems={len(self.rotations)}, n_ops={ops}, " f"adjust_positions={self.do_adjust_positions}, " f"adjust_cell={self.do_adjust_cell})" ) diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 788a437f4..f54e9fe35 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -275,7 +275,7 @@ def unbatched_forward( ) if self.use_neighbor_list: - mapping, _system_mapping, shifts_idx = torchsim_nl( + mapping, _, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 5444422a6..78c62f3f3 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -281,7 +281,7 @@ def unbatched_forward( ) if self.use_neighbor_list: - mapping, _system_mapping, shifts_idx = torchsim_nl( + mapping, _, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index 1031a073e..baa1f8520 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -164,7 +164,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: ) if self.use_neighbor_list: - mapping, _system_mapping, shifts_idx = torchsim_nl( + mapping, _, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index d2c537787..60d647829 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -299,7 +299,7 @@ def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: ) if self.use_neighbor_list: - mapping, _system_mapping, shifts_idx = torchsim_nl( + mapping, _, shifts_idx = torchsim_nl( positions=wrapped_positions, cell=cell, pbc=pbc, @@ -727,7 +727,7 @@ def unbatched_forward( # noqa: PLR0915 system_idx = torch.zeros( positions.shape[0], dtype=torch.long, device=self.device ) - mapping, _system_mapping, shifts_idx = torchsim_nl( + mapping, _, shifts_idx = torchsim_nl( positions=positions, cell=cell, pbc=self.pbc, diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 78576059f..1067296d3 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -61,6 +61,16 @@ def _compute_cell_masses(state: SimState) -> torch.Tensor: return cell_masses.unsqueeze(-1).expand(-1, 3) +def _get_constrained_stress( + model_output: dict[str, torch.Tensor], state: SimState +) -> torch.Tensor: + """Clone stress from model output and apply constraint symmetrization.""" + stress = model_output["stress"].clone() + for constraint in state.constraints: + constraint.adjust_stress(state, stress) + return stress + + def _apply_constraints( virial: torch.Tensor, *, hydrostatic_strain: bool, constant_volume: bool ) -> torch.Tensor: @@ -110,12 +120,7 @@ def unit_cell_filter_init[T: AnyCellState]( model_output = model(state) # Calculate initial cell forces - stress = model_output["stress"].clone() - - # Apply stress constraints (e.g. FixSymmetry) - for constraint in state.constraints: - constraint.adjust_stress(state, stress) - + stress = _get_constrained_stress(model_output, state) volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) virial = _apply_constraints( @@ -167,12 +172,7 @@ def frechet_cell_filter_init[T: AnyCellState]( model_output = model(state) # Calculate initial cell forces using Frechet approach - stress = model_output["stress"].clone() - - # Apply stress constraints (e.g. FixSymmetry) - for constraint in state.constraints: - constraint.adjust_stress(state, stress) - + stress = _get_constrained_stress(model_output, state) volumes = torch.linalg.det(state.cell).view(n_systems, 1, 1) virial = -volumes * (stress + pressure) virial = _apply_constraints( @@ -235,7 +235,7 @@ def unit_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) -> new_cell = torch.bmm(state.reference_cell.mT, cell_update.transpose(-2, -1)) # Apply cell constraints (in-place, column vector convention) - state.set_constrained_cell(new_cell.mT.clone()) + state.set_constrained_cell(new_cell.mT.contiguous()) state.cell_positions = cell_positions_new @@ -262,7 +262,7 @@ def frechet_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor) ) # Apply cell constraints (in-place, column vector convention) - state.set_constrained_cell(new_row_vector_cell.mT.clone()) + state.set_constrained_cell(new_row_vector_cell.mT.contiguous()) state.cell_positions = cell_positions_new @@ -270,12 +270,7 @@ def compute_cell_forces[T: AnyCellState]( model_output: dict[str, torch.Tensor], state: T ) -> None: """Compute cell forces for both unit and frechet methods.""" - stress = model_output["stress"].clone() - - # Apply stress constraints (e.g. FixSymmetry) - for constraint in state.constraints: - constraint.adjust_stress(state, stress) - + stress = _get_constrained_stress(model_output, state) volumes = torch.linalg.det(state.cell).view(state.n_systems, 1, 1) virial = -volumes * (stress + state.pressure) virial = _apply_constraints( diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 600494f27..3a49de311 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -1,394 +1,235 @@ -"""Symmetry refinement utilities for crystal structures. +"""Symmetry utilities for crystal structures using moyopy. -This module provides functions for refining and symmetrizing atomic structures -using moyopy (Python bindings for the moyo crystal symmetry library). - -The main entry point is `refine_symmetry` which symmetrizes both the cell -and atomic positions according to the detected space group symmetry. - -Note: Functions in this module operate on single (unbatched) systems. -The `n_ops` dimension refers to the number of symmetry operations -(rotations + translations) of the space group. +Functions operate on single (unbatched) systems. The ``n_ops`` dimension +refers to the number of symmetry operations of the space group. """ from __future__ import annotations -import logging from typing import TYPE_CHECKING import torch -logger = logging.getLogger(__name__) - - if TYPE_CHECKING: from moyopy import MoyoDataset from torch_sim.state import SimState -def _get_moyo_dataset( +def _moyo_dataset( cell: torch.Tensor, - scaled_positions: torch.Tensor, + frac_pos: torch.Tensor, atomic_numbers: torch.Tensor, - symprec: float = 1.0e-4, + symprec: float = 1e-4, ) -> MoyoDataset: - """Get symmetry dataset from moyopy. - - Args: - cell: Unit cell as row vectors, shape (3, 3) - scaled_positions: Fractional coordinates, shape (n_atoms, 3) - atomic_numbers: Atomic numbers, shape (n_atoms,) - symprec: Symmetry precision in units of cell basis vectors - - Returns: - MoyoDataset with symmetry information - """ + """Get MoyoDataset from cell, fractional positions, and atomic numbers.""" from moyopy import Cell, MoyoDataset - cell_list = cell.detach().cpu().tolist() - positions_list = scaled_positions.detach().cpu().tolist() - numbers_list = atomic_numbers.detach().cpu().int().tolist() - - moyo_cell = Cell(basis=cell_list, positions=positions_list, numbers=numbers_list) + moyo_cell = Cell( + basis=cell.detach().cpu().tolist(), + positions=frac_pos.detach().cpu().tolist(), + numbers=atomic_numbers.detach().cpu().int().tolist(), + ) return MoyoDataset(moyo_cell, symprec=symprec) -def get_symmetry_datasets( - state: SimState, - symprec: float = 1.0e-4, -) -> list[MoyoDataset]: - """Get symmetry datasets for all systems in a SimState. - - Args: - state: SimState containing one or more systems - symprec: Symmetry precision for moyopy +def _extract_symmetry_ops( + dataset: MoyoDataset, dtype: torch.dtype, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """Extract rotation and translation tensors from a MoyoDataset. Returns: - List of MoyoDataset objects, one per system in the state. + (rotations, translations) with shapes (n_ops, 3, 3) and (n_ops, 3). """ - datasets = [] + rotations = torch.as_tensor( + dataset.operations.rotations, dtype=dtype, device=device + ).round() + translations = torch.as_tensor( + dataset.operations.translations, dtype=dtype, device=device + ) + return rotations, translations - for single_state in state.split(): - cell = single_state.row_vector_cell[0] - positions = single_state.positions - scaled_positions = _get_scaled_positions(positions, cell) +def get_symmetry_datasets(state: SimState, symprec: float = 1e-4) -> list[MoyoDataset]: + """Get MoyoDataset for each system in a SimState.""" + datasets = [] + for single in state.split(): + cell = single.row_vector_cell[0] + frac = single.positions @ torch.linalg.inv(cell) + datasets.append(_moyo_dataset(cell, frac, single.atomic_numbers, symprec)) + return datasets - dataset = _get_moyo_dataset( - cell=cell, - scaled_positions=scaled_positions, - atomic_numbers=single_state.atomic_numbers, - symprec=symprec, - ) - datasets.append(dataset) - return datasets +# Above this threshold, build_symmetry_map falls back to a per-operation loop +# to avoid allocating an O(n_ops * n_atoms^2) tensor that can OOM on supercells. +_SYMM_MAP_CHUNK_THRESHOLD = 200 -def _get_scaled_positions( - positions: torch.Tensor, - cell: torch.Tensor, +def build_symmetry_map( + rotations: torch.Tensor, + translations: torch.Tensor, + frac_pos: torch.Tensor, ) -> torch.Tensor: - """Convert Cartesian positions to fractional coordinates (unbatched). + """Build atom mapping for each symmetry operation. - See also ``transforms.get_fractional_coordinates`` for the batched version. - - Args: - positions: Cartesian positions, shape (n_atoms, 3) - cell: Unit cell as row vectors, shape (3, 3) + For each (R, t), maps atom i to atom j where R @ frac_i + t ≈ frac_j (mod 1). Returns: - Fractional coordinates, shape (n_atoms, 3) + Symmetry mapping tensor, shape (n_ops, n_atoms). """ - return positions @ torch.linalg.inv(cell) - - -def refine_symmetry( + n_ops = rotations.shape[0] + n_atoms = frac_pos.shape[0] + + if n_atoms <= _SYMM_MAP_CHUNK_THRESHOLD: + # Vectorized: allocates (n_ops, n_atoms, n_atoms, 3) — fast for small systems + new_pos = torch.einsum("oij,nj->oni", rotations, frac_pos) + translations[:, None] + delta = frac_pos[None, None] - new_pos[:, :, None] + delta -= delta.round() + return torch.argmin(torch.linalg.norm(delta, dim=-1), dim=-1).long() + + # Per-op loop: allocates only (n_atoms, n_atoms, 3) at a time + result = torch.empty(n_ops, n_atoms, dtype=torch.long, device=frac_pos.device) + for op_idx in range(n_ops): + new_pos_op = frac_pos @ rotations[op_idx].T + translations[op_idx] + delta = frac_pos[None, :, :] - new_pos_op[:, None, :] + delta -= delta.round() + result[op_idx] = torch.argmin(torch.linalg.norm(delta, dim=-1), dim=-1) + return result + + +def prep_symmetry( cell: torch.Tensor, positions: torch.Tensor, atomic_numbers: torch.Tensor, - symprec: float = 0.01, - *, - verbose: bool = False, + symprec: float = 1e-4, ) -> tuple[torch.Tensor, torch.Tensor]: - """Refine symmetry of a structure. - - Symmetrizes both cell vectors and atomic positions by averaging - over the detected symmetry operations using polar decomposition - for the cell metric and scatter-add averaging for positions. - - The refinement process: - 1. Detect symmetry operations of the input structure - 2. Symmetrize the cell metric tensor (preserving cell orientation) - 3. Symmetrize atomic positions by averaging over symmetry orbits - - Args: - cell: Unit cell as row vectors, shape (3, 3) - positions: Cartesian positions, shape (n_atoms, 3) - atomic_numbers: Atomic numbers, shape (n_atoms,) - symprec: Symmetry precision for moyopy - verbose: If True, log symmetry information before and after + """Get symmetry rotations and atom mappings for a structure. Returns: - Tuple of (symmetrized_cell, symmetrized_positions): - - symmetrized_cell: Symmetrized cell as row vectors, shape (3, 3) - - symmetrized_positions: Symmetrized Cartesian positions, shape (n_atoms, 3) + (rotations, symm_map) with shapes (n_ops, 3, 3) and (n_ops, n_atoms). """ - device = cell.device - dtype = cell.dtype + frac_pos = positions @ torch.linalg.inv(cell) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + rotations, translations = _extract_symmetry_ops(dataset, cell.dtype, cell.device) + return rotations, build_symmetry_map(rotations, translations, frac_pos) - # Step 1: Detect symmetry - scaled_positions = _get_scaled_positions(positions, cell) - dataset = _get_moyo_dataset(cell, scaled_positions, atomic_numbers, symprec) - if verbose: - logger.info( - "symmetrize: prec %s got space group number %s", - symprec, - dataset.number, - ) +def _refine_symmetry_impl( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 0.01, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Core refinement returning all intermediate data for reuse. - rotations = torch.as_tensor( - dataset.operations.rotations, dtype=dtype, device=device - ).round() - translations = torch.as_tensor( - dataset.operations.translations, dtype=dtype, device=device - ) - n_ops = rotations.shape[0] + Returns: + (refined_cell, refined_positions, rotations, translations) + """ + dtype, device = cell.dtype, cell.device + frac_pos = positions @ torch.linalg.inv(cell) + dataset = _moyo_dataset(cell, frac_pos, atomic_numbers, symprec) + rotations, translations = _extract_symmetry_ops(dataset, dtype, device) + n_ops, n_atoms = rotations.shape[0], positions.shape[0] - # Step 2: Symmetrize cell via metric tensor + polar decomposition - # Row-vector metric: g[i,j] = a_i · a_j = (cell @ cell.T)[i,j] - # Symmetry invariance: R.T @ g @ R = g for all rotations R + # Symmetrize cell metric: g_sym = avg(R^T @ g @ R), then polar decomposition metric = cell @ cell.T metric_sym = torch.einsum("nji,jk,nkl->il", rotations, metric, rotations) / n_ops - # Left polar decomposition: cell = P @ V where P = sqrt(metric) - # Keep same orientation V but with symmetrized metric P_sym - sqrt_metric = _matrix_sqrt(metric) - sqrt_metric_sym = _matrix_sqrt(metric_sym) - new_cell = sqrt_metric_sym @ torch.linalg.inv(sqrt_metric) @ cell + def _mat_sqrt(mat: torch.Tensor) -> torch.Tensor: + evals, evecs = torch.linalg.eigh(mat) + return evecs @ torch.diag(evals.sqrt()) @ evecs.T + + new_cell = _mat_sqrt(metric_sym) @ torch.linalg.inv(_mat_sqrt(metric)) @ cell - # Step 3: Symmetrize positions by averaging displacements over symmetry orbits - # Recompute fractional coords in the symmetrized cell + # Symmetrize positions via displacement averaging over symmetry orbits new_frac = positions @ torch.linalg.inv(new_cell) symm_map = build_symmetry_map(rotations, translations, new_frac) - # For each op, transform fractional positions: R @ frac + t - new_frac_all = ( - torch.einsum("oij,nj->oni", rotations, new_frac) + translations[:, None, :] - ) # (n_ops, n_atoms, 3) - # Compute displacement from target atom's current position, wrapped for periodicity - n_atoms = positions.shape[0] - target_frac = new_frac[symm_map] # (n_ops, n_atoms, 3) - displacement = new_frac_all - target_frac - displacement -= displacement.round() # wrap into [-0.5, 0.5] - - # Scatter-add wrapped displacements to target atoms and average + transformed = torch.einsum("oij,nj->oni", rotations, new_frac) + translations[:, None] + disp = transformed - new_frac[symm_map] + disp -= disp.round() # wrap into [-0.5, 0.5] + target = symm_map.reshape(-1).unsqueeze(-1).expand(-1, 3) accum = torch.zeros(n_atoms, 3, dtype=dtype, device=device) - accum.scatter_add_(0, target, displacement.reshape(-1, 3)) - sym_frac = new_frac + accum / n_ops - - new_positions = sym_frac @ new_cell - - if verbose: - final_scaled = _get_scaled_positions(new_positions, new_cell) - final_dataset = _get_moyo_dataset(new_cell, final_scaled, atomic_numbers, 1e-4) - logger.info( - "symmetrize: prec 1e-4 got space group number %s", - final_dataset.number, - ) - - return new_cell, new_positions - - -def _matrix_sqrt(mat: torch.Tensor) -> torch.Tensor: - """Compute matrix square root of a symmetric positive-definite matrix. - - Uses eigendecomposition: sqrt(A) = Q @ diag(sqrt(eigenvalues)) @ Q.T + accum.scatter_add_(0, target, disp.reshape(-1, 3)) - Args: - mat: Symmetric positive-definite matrix, shape (3, 3) - - Returns: - Matrix square root, shape (3, 3) - """ - eigenvalues, eigenvectors = torch.linalg.eigh(mat) - return eigenvectors @ torch.diag(eigenvalues.sqrt()) @ eigenvectors.T + new_positions = (new_frac + accum / n_ops) @ new_cell + return new_cell, new_positions, rotations, translations -def _prep_symmetry( +def refine_symmetry( cell: torch.Tensor, positions: torch.Tensor, atomic_numbers: torch.Tensor, - symprec: float = 1.0e-4, - *, - verbose: bool = False, + symprec: float = 0.01, ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare structure for symmetry-preserving minimization. + """Symmetrize cell and positions according to the detected space group. - Determines the symmetry operations (rotations in fractional coordinates) - and atom mappings needed for symmetry-constrained optimization. - - Args: - cell: Unit cell as row vectors, shape (3, 3) - positions: Cartesian positions, shape (n_atoms, 3) - atomic_numbers: Atomic numbers, shape (n_atoms,) - symprec: Symmetry precision for moyopy - verbose: If True, log symmetry information + Uses polar decomposition for the cell metric tensor and scatter-add + averaging over symmetry orbits for atomic positions. Returns: - Tuple of (rotations, symm_map): - - rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) - - symm_map: Atom mapping tensor, shape (n_ops, n_atoms) + (symmetrized_cell, symmetrized_positions) as row vectors. """ - device = cell.device - dtype = cell.dtype - - scaled_positions = _get_scaled_positions(positions, cell) - dataset = _get_moyo_dataset(cell, scaled_positions, atomic_numbers, symprec) - - if verbose: - logger.info( - "symmetrize: prec %s got space group number %s, n_ops %d", - symprec, - dataset.number, - len(dataset.operations), - ) - - rotations = torch.as_tensor( - dataset.operations.rotations, dtype=dtype, device=device - ).round() - translations = torch.as_tensor( - dataset.operations.translations, dtype=dtype, device=device + new_cell, new_positions, _rotations, _translations = _refine_symmetry_impl( + cell, positions, atomic_numbers, symprec ) + return new_cell, new_positions - # Build symmetry mapping - symm_map = build_symmetry_map(rotations, translations, scaled_positions) - - return rotations, symm_map - - -def build_symmetry_map( - rotations: torch.Tensor, - translations: torch.Tensor, - scaled_positions: torch.Tensor, -) -> torch.Tensor: - """Build symmetry atom mapping for each symmetry operation. - For each symmetry operation (R, t), determines which atom each atom - maps to: atom i → atom j where R @ frac_i + t ≈ frac_j (mod 1). +def refine_and_prep_symmetry( + cell: torch.Tensor, + positions: torch.Tensor, + atomic_numbers: torch.Tensor, + symprec: float = 0.01, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Refine symmetry and get ops/mappings in a single moyopy call. - Args: - rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) - translations: Translation vectors in fractional coords, shape (n_ops, 3) - scaled_positions: Fractional coordinates, shape (n_atoms, 3) + Combines ``refine_symmetry`` and ``prep_symmetry`` to avoid redundant + symmetry detection. Used by ``FixSymmetry.from_state``. Returns: - Symmetry mapping tensor, shape (n_ops, n_atoms) + (refined_cell, refined_positions, rotations, symm_map) """ - # Transform all atoms by all symmetry operations at once - # new_pos: (n_ops, n_atoms, 3) - new_pos = ( - torch.einsum("oij,nj->oni", rotations, scaled_positions) - + translations[:, None, :] + new_cell, new_positions, rotations, translations = _refine_symmetry_impl( + cell, positions, atomic_numbers, symprec ) - - # Compute wrapped deltas to account for periodicity - # delta: (n_ops, n_atoms, n_atoms, 3) - delta = scaled_positions[None, None, :, :] - new_pos[:, :, None, :] - delta -= delta.round() # wrap into [-0.5, 0.5] - - # Distances to all candidate atoms, then choose nearest - distances = torch.linalg.norm(delta, dim=-1) # (n_ops, n_atoms, n_atoms) - return torch.argmin(distances, dim=-1).to(dtype=torch.long) # (n_ops, n_atoms) + # Build symm_map on the final refined fractional coordinates + refined_frac = new_positions @ torch.linalg.inv(new_cell) + symm_map = build_symmetry_map(rotations, translations, refined_frac) + return new_cell, new_positions, rotations, symm_map def symmetrize_rank1( lattice: torch.Tensor, - forces: torch.Tensor, + vectors: torch.Tensor, rotations: torch.Tensor, symm_map: torch.Tensor, ) -> torch.Tensor: - """Symmetrize rank-1 tensor (forces, velocities, etc). - - Averages the tensor over all symmetry operations, respecting atom - permutations. Works in fractional coordinates internally. - - Args: - lattice: Cell vectors as row vectors, shape (3, 3) - forces: Forces array, shape (n_atoms, 3) - rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) - symm_map: Atom mapping for each symmetry operation, shape (n_ops, n_atoms) + """Symmetrize a rank-1 per-atom tensor (forces, velocities, displacements). - Returns: - Symmetrized forces, shape (n_atoms, 3) + Works in fractional coordinates internally. Returns symmetrized Cartesian tensor. """ - n_ops = rotations.shape[0] - n_atoms = forces.shape[0] - - # Transform to scaled (fractional) coordinates: (n_atoms, 3) - scaled_forces = forces @ lattice.inverse() - - # Apply all rotations at once: (n_ops, n_atoms, 3) - # rotations: (n_ops, 3, 3), scaled_forces: (n_atoms, 3) - # For each op: scaled_forces @ rot.T (rotate the vectors) - # Note: we use rotations.mT to get the transpose of each rotation matrix - transformed_forces = torch.einsum("ij,nkj->nik", scaled_forces, rotations) - - # Flatten for scatter: (n_ops * n_atoms, 3) - transformed_flat = transformed_forces.reshape(-1, 3) - - # Flatten symm_map to get target indices: (n_ops * n_atoms,) - target_indices = symm_map.reshape(-1) - - # Expand target indices to match 3D coordinates: (n_ops * n_atoms, 3) - target_indices_expanded = target_indices.unsqueeze(-1).expand(-1, 3) - - # Scatter add to accumulate forces at target atoms - # Result shape: (n_atoms, 3) - accumulated = torch.zeros(n_atoms, 3, dtype=forces.dtype, device=forces.device) - accumulated.scatter_add_(0, target_indices_expanded, transformed_flat) - - # Average over symmetry operations - symmetrized_scaled = accumulated / n_ops - - # Transform back to Cartesian - return symmetrized_scaled @ lattice + n_ops, n_atoms = rotations.shape[0], vectors.shape[0] + scaled = vectors @ torch.linalg.inv(lattice) + # Rotate each vector by each symmetry op: scaled @ R^T + rotated = torch.einsum("ij,nkj->nik", scaled, rotations).reshape(-1, 3) + # Scatter-add to target atoms and average + target = symm_map.reshape(-1).unsqueeze(-1).expand(-1, 3) + accum = torch.zeros(n_atoms, 3, dtype=vectors.dtype, device=vectors.device) + accum.scatter_add_(0, target, rotated) + return (accum / n_ops) @ lattice def symmetrize_rank2( lattice: torch.Tensor, - stress: torch.Tensor, + tensor: torch.Tensor, rotations: torch.Tensor, ) -> torch.Tensor: - """Symmetrize rank-2 tensor (stress, strain, etc). - - Averages the tensor over all symmetry operations in scaled coordinates. - - Args: - lattice: Cell vectors as row vectors, shape (3, 3) - stress: Stress tensor, shape (3, 3) - rotations: Rotation matrices in fractional coords, shape (n_ops, 3, 3) - - Returns: - Symmetrized stress tensor, shape (3, 3) - """ + """Symmetrize a rank-2 tensor (stress, strain) over all symmetry operations.""" n_ops = rotations.shape[0] - inv_lattice = lattice.inverse() - - # Scale stress: lattice @ stress @ lattice.T - scaled_stress = lattice @ stress @ lattice.T - - # Symmetrize in scaled coordinates using vectorized operations - # r.T @ scaled_stress @ r for all rotations at once - # For r.T @ A @ r: result[i,l] = sum_j,k r[j,i] * A[j,k] * r[k,l] - # With batched rotations: einsum "nji,jk,nkl->il" - symmetrized_scaled_stress = ( - torch.einsum("nji,jk,nkl->il", rotations, scaled_stress, rotations) / n_ops - ) - - # Transform back: inv_lattice @ symmetrized_scaled_stress @ inv_lattice.T - return inv_lattice @ symmetrized_scaled_stress @ inv_lattice.T + inv_lat = torch.linalg.inv(lattice) + scaled = lattice @ tensor @ lattice.T + sym_scaled = torch.einsum("nji,jk,nkl->il", rotations, scaled, rotations) / n_ops + return inv_lat @ sym_scaled @ inv_lat.T From c770bcc93367f81ed4d59d8863673b9b2c5218fe Mon Sep 17 00:00:00 2001 From: janosh Date: Fri, 6 Feb 2026 08:50:48 -0800 Subject: [PATCH 13/16] fix SystemConstraint.merge producing duplicate indices for multi-system states SystemConstraint.merge used the raw state enumeration index as the system offset, which is only correct when each state has exactly 1 system. For multi-system states (e.g. merging two FixCom([0,1]) constraints), this produced duplicate indices [0,1,1,2] instead of [0,1,2,3]. Fix: merge_constraints now computes cumulative system offsets from num_systems_per_state (passed from concatenate_states) and uses those as the state_indices for SystemConstraint.merge, so the offset correctly accounts for multi-system states and gaps from states without the constraint. Co-authored-by: Cursor --- torch_sim/constraints.py | 33 +++++++++++++++++++++++++++------ torch_sim/state.py | 5 ++++- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 33fc78852..8913d5a78 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -366,28 +366,32 @@ def merge( Args: constraints: List of constraints to merge - state_indices: Index of the source state for each constraint + state_indices: Cumulative system offset for each constraint's source + state (computed by ``merge_constraints``) atom_offsets: Cumulative atom counts (unused for SystemConstraint) Returns: A single merged constraint with adjusted system indices """ all_indices = [] - for constraint, state_idx in zip(constraints, state_indices, strict=False): - # For SystemConstraint, the offset is the state index itself - all_indices.append(constraint.system_idx + state_idx) + for constraint, offset in zip(constraints, state_indices, strict=False): + all_indices.append(constraint.system_idx + offset) return cls(torch.cat(all_indices)) def merge_constraints( constraint_lists: list[list[AtomConstraint | SystemConstraint]], num_atoms_per_state: torch.Tensor, + num_systems_per_state: torch.Tensor | None = None, ) -> list[Constraint]: """Merge constraints from multiple systems into a single list of constraints. Args: constraint_lists: List of lists of constraints - num_atoms_per_state: Number of atoms per system + num_atoms_per_state: Number of atoms per state + num_systems_per_state: Number of systems per state (needed for correct + SystemConstraint offsets in multi-system states). Falls back to 1 + per state if not provided. Returns: List of merged constraints @@ -403,6 +407,19 @@ def merge_constraints( ] ) + # Calculate system offsets for SystemConstraints + if num_systems_per_state is None: + # Default: assume 1 system per state (backward compatible) + num_systems_per_state = torch.ones( + len(constraint_lists), device=device, dtype=dtype + ) + system_offsets = torch.cat( + [ + torch.zeros(1, device=device, dtype=dtype), + torch.cumsum(num_systems_per_state[:-1], dim=0), + ] + ) + # Group constraints by type, tracking their source state index constraints_by_type: dict[type[Constraint], tuple[list, list[int]]] = defaultdict( lambda: ([], []) @@ -411,7 +428,11 @@ def merge_constraints( for constraint in constraint_list: constraints, indices = constraints_by_type[type(constraint)] constraints.append(constraint) - indices.append(state_idx) + # SystemConstraints need cumulative system offsets, not raw state indices + if isinstance(constraint, SystemConstraint): + indices.append(int(system_offsets[state_idx].item())) + else: + indices.append(state_idx) # Merge each group using the constraint's merge method result = [] diff --git a/torch_sim/state.py b/torch_sim/state.py index eb412a821..063e9b585 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -1121,8 +1121,11 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Merge constraints constraint_lists = [state.constraints for state in states] + num_systems_per_state = [state.n_systems for state in states] constraints = merge_constraints( - constraint_lists, torch.tensor(num_atoms_per_state, device=target_device) + constraint_lists, + torch.tensor(num_atoms_per_state, device=target_device), + torch.tensor(num_systems_per_state, device=target_device), ) # Create a new instance of the same class From cea861d29ecbc89431887074827682b154bcb4b8 Mon Sep 17 00:00:00 2001 From: janosh Date: Fri, 6 Feb 2026 08:55:20 -0800 Subject: [PATCH 14/16] refactor constraint merging: split into reindex() + merge() The old merge(constraints, state_indices, atom_offsets) conflated two concerns: shifting indices to global coordinates and concatenating constraints. The state_indices parameter meant different things for different constraint types, and FixSymmetry ignored it entirely. New design: - reindex(atom_offset, system_offset): returns a copy with indices shifted to global coordinates. Each subclass knows what to shift. - merge(constraints): just concatenates already-reindexed constraints. No offset logic needed. - merge_constraints() orchestrates: reindex first, then merge. This eliminates the ambiguous state_indices parameter and gives each method a single clear responsibility. Co-authored-by: Cursor --- tests/test_fix_symmetry.py | 12 ++- torch_sim/constraints.py | 159 ++++++++++++++----------------------- 2 files changed, 63 insertions(+), 108 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index ffd328017..adff7b29b 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -365,11 +365,9 @@ def test_merge_two_constraints(self) -> None: """Merge two single-system constraints.""" s1 = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) s2 = ts.io.atoms_to_state(make_structure("diamond"), torch.device("cpu"), DTYPE) - merged = FixSymmetry.merge( - [FixSymmetry.from_state(s1), FixSymmetry.from_state(s2)], - state_indices=[0, 1], - atom_offsets=None, - ) + c1 = FixSymmetry.from_state(s1) + c2 = FixSymmetry.from_state(s2).reindex(atom_offset=0, system_offset=1) + merged = FixSymmetry.merge([c1, c2]) assert len(merged.rotations) == 2 assert merged.system_idx.tolist() == [0, 1] @@ -386,8 +384,8 @@ def test_merge_multi_system_no_duplicate_indices(self) -> None: ) c_b = FixSymmetry.from_state( ts.io.atoms_to_state(atoms_b, torch.device("cpu"), DTYPE), - ) - merged = FixSymmetry.merge([c_a, c_b], state_indices=[0, 1], atom_offsets=None) + ).reindex(atom_offset=0, system_offset=3) + merged = FixSymmetry.merge([c_a, c_b]) assert merged.system_idx.tolist() == [0, 1, 2, 3, 4] def test_select_sub_constraint(self) -> None: diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 8913d5a78..4979b8286 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -124,28 +124,26 @@ def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> None | Constraint for the given atom and system index """ - @classmethod - def merge( - cls, - constraints: list[Self], - state_indices: list[int], - atom_offsets: torch.Tensor, - ) -> Self: - """Merge multiple constraints of the same type into one. + @abstractmethod + def reindex(self, atom_offset: int, system_offset: int) -> Self: + """Return a copy with indices shifted to global coordinates. - This method is called during state concatenation to combine constraints - from multiple states. Subclasses can override this for custom merge logic. + Called during state concatenation to adjust indices before merging. Args: - constraints: List of constraints to merge (all of the same type) - state_indices: Index of the source state for each constraint - atom_offsets: Cumulative atom counts for offset calculation + atom_offset: Offset to add to atom indices + system_offset: Offset to add to system indices + """ - Returns: - A single merged constraint + @classmethod + def merge(cls, constraints: list[Self]) -> Self: + """Merge multiple already-reindexed constraints into one. - Raises: - NotImplementedError: If the constraint type doesn't support merging + Constraints must have global (absolute) indices — call ``reindex`` + first. Subclasses override this to handle type-specific data. + + Args: + constraints: Constraints to merge (all same type, already reindexed) """ raise NotImplementedError( f"Constraint type {cls.__name__} does not implement merge. " @@ -250,28 +248,14 @@ def select_sub_constraint( return None return type(self)(new_atom_idx) - @classmethod - def merge( - cls, - constraints: list[Self], - state_indices: list[int], - atom_offsets: torch.Tensor, - ) -> Self: - """Merge multiple AtomConstraints by concatenating indices with offsets. - - Args: - constraints: List of constraints to merge - state_indices: Index of the source state for each constraint - atom_offsets: Cumulative atom counts for offset calculation + def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 + """Return copy with atom indices shifted by atom_offset.""" + return type(self)(self.atom_idx + atom_offset) - Returns: - A single merged constraint with adjusted atom indices - """ - all_indices = [] - for constraint, state_idx in zip(constraints, state_indices, strict=False): - offset = atom_offsets[state_idx] - all_indices.append(constraint.atom_idx + offset) - return cls(torch.cat(all_indices)) + @classmethod + def merge(cls, constraints: list[Self]) -> Self: + """Merge by concatenating already-reindexed atom indices.""" + return cls(torch.cat([c.atom_idx for c in constraints])) class SystemConstraint(Constraint): @@ -355,28 +339,14 @@ def select_sub_constraint( """ return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None - @classmethod - def merge( - cls, - constraints: list[Self], - state_indices: list[int], - atom_offsets: torch.Tensor, # noqa: ARG003 - ) -> Self: - """Merge multiple SystemConstraints by concatenating indices with offsets. - - Args: - constraints: List of constraints to merge - state_indices: Cumulative system offset for each constraint's source - state (computed by ``merge_constraints``) - atom_offsets: Cumulative atom counts (unused for SystemConstraint) + def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 + """Return copy with system indices shifted by system_offset.""" + return type(self)(self.system_idx + system_offset) - Returns: - A single merged constraint with adjusted system indices - """ - all_indices = [] - for constraint, offset in zip(constraints, state_indices, strict=False): - all_indices.append(constraint.system_idx + offset) - return cls(torch.cat(all_indices)) + @classmethod + def merge(cls, constraints: list[Self]) -> Self: + """Merge by concatenating already-reindexed system indices.""" + return cls(torch.cat([c.system_idx for c in constraints])) def merge_constraints( @@ -384,13 +354,15 @@ def merge_constraints( num_atoms_per_state: torch.Tensor, num_systems_per_state: torch.Tensor | None = None, ) -> list[Constraint]: - """Merge constraints from multiple systems into a single list of constraints. + """Merge constraints from multiple states into a single list. + + Each constraint is first reindexed to global coordinates (via ``reindex``), + then constraints of the same type are merged (via ``merge``). Args: - constraint_lists: List of lists of constraints + constraint_lists: List of lists of constraints, one list per state num_atoms_per_state: Number of atoms per state - num_systems_per_state: Number of systems per state (needed for correct - SystemConstraint offsets in multi-system states). Falls back to 1 + num_systems_per_state: Number of systems per state. Falls back to 1 per state if not provided. Returns: @@ -398,7 +370,7 @@ def merge_constraints( """ from collections import defaultdict - # Calculate atom offsets: for state i, offset = sum of atoms in states 0 to i-1 + # Calculate cumulative offsets for atoms and systems device, dtype = num_atoms_per_state.device, num_atoms_per_state.dtype atom_offsets = torch.cat( [ @@ -406,10 +378,7 @@ def merge_constraints( torch.cumsum(num_atoms_per_state[:-1], dim=0), ] ) - - # Calculate system offsets for SystemConstraints if num_systems_per_state is None: - # Default: assume 1 system per state (backward compatible) num_systems_per_state = torch.ones( len(constraint_lists), device=device, dtype=dtype ) @@ -420,27 +389,15 @@ def merge_constraints( ] ) - # Group constraints by type, tracking their source state index - constraints_by_type: dict[type[Constraint], tuple[list, list[int]]] = defaultdict( - lambda: ([], []) - ) + # Reindex each constraint to global coordinates, then group by type + grouped: dict[type[Constraint], list[Constraint]] = defaultdict(list) for state_idx, constraint_list in enumerate(constraint_lists): + a_off = int(atom_offsets[state_idx].item()) + s_off = int(system_offsets[state_idx].item()) for constraint in constraint_list: - constraints, indices = constraints_by_type[type(constraint)] - constraints.append(constraint) - # SystemConstraints need cumulative system offsets, not raw state indices - if isinstance(constraint, SystemConstraint): - indices.append(int(system_offsets[state_idx].item())) - else: - indices.append(state_idx) + grouped[type(constraint)].append(constraint.reindex(a_off, s_off)) - # Merge each group using the constraint's merge method - result = [] - for constraint_type, (constraints, state_indices) in constraints_by_type.items(): - merged = constraint_type.merge(constraints, state_indices, atom_offsets) - result.append(merged) - - return result + return [ctype.merge(cs) for ctype, cs in grouped.items()] class FixAtoms(AtomConstraint): @@ -942,28 +899,28 @@ def get_removed_dof(self, state: SimState) -> torch.Tensor: """Returns zero - constrains direction, not DOF count.""" return torch.zeros(state.n_systems, dtype=torch.long, device=state.device) + def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 + """Return copy with system indices shifted by system_offset.""" + return type(self)( + self.rotations, + self.symm_maps, + self.system_idx + system_offset, + adjust_positions=self.do_adjust_positions, + adjust_cell=self.do_adjust_cell, + ) + @classmethod - def merge( - cls, - constraints: list[Self], - state_indices: list[int], # noqa: ARG003 - atom_offsets: torch.Tensor, # noqa: ARG003 - ) -> Self: - """Merge multiple FixSymmetry constraints into one.""" + def merge(cls, constraints: list[Self]) -> Self: + """Merge by concatenating rotations, symm_maps, and system indices.""" if not constraints: raise ValueError("Cannot merge empty constraint list") - rotations, symm_maps, sys_indices = [], [], [] - offset = 0 - for constraint in constraints: - for idx in range(len(constraint.rotations)): - rotations.append(constraint.rotations[idx]) - symm_maps.append(constraint.symm_maps[idx]) - sys_indices.append(offset + idx) - offset += len(constraint.rotations) + rotations = [r for c in constraints for r in c.rotations] + symm_maps = [s for c in constraints for s in c.symm_maps] + system_idx = torch.cat([c.system_idx for c in constraints]) return cls( rotations, symm_maps, - system_idx=torch.tensor(sys_indices, device=rotations[0].device), + system_idx=system_idx, adjust_positions=constraints[0].do_adjust_positions, adjust_cell=constraints[0].do_adjust_cell, ) From e6bcacbf4be63a3e1dcb8a170bc71cdd7533a757 Mon Sep 17 00:00:00 2001 From: janosh Date: Fri, 6 Feb 2026 09:09:00 -0800 Subject: [PATCH 15/16] add regression tests, simplify test structure New tests covering key fixes: - reindex() preserves rotations/symm_maps while shifting system_idx - SystemConstraint.merge multi-system duplicate-indices regression - concatenate_states end-to-end with FixSymmetry and FixCom - FixSymmetry.__init__ validation, select returning None - adjust_positions/adjust_cell skip when disabled - build_symmetry_map chunked vs vectorized path equivalence - refine_symmetry recovers correct spacegroup after perturbation - position symmetrization matches ASE Simplifications: - Extract p6bar_both_constraints fixture (dedup 4 ASE comparison tests) - Merge 2 test classes into TestFixSymmetryMergeSelectReindex - Combine similar tests via parametrize and multi-assert - Unwrap single-test class, add CPU constant Co-authored-by: Cursor --- tests/test_fix_symmetry.py | 367 +++++++++++++++++++++++-------------- 1 file changed, 233 insertions(+), 134 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index adff7b29b..8697fcf4b 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -12,7 +12,7 @@ from pymatgen.io.ase import AseAtomsAdaptor import torch_sim as ts -from torch_sim.constraints import FixSymmetry +from torch_sim.constraints import FixCom, FixSymmetry from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.symmetrize import get_symmetry_datasets @@ -24,6 +24,7 @@ MAX_STEPS = 30 DTYPE = torch.float64 SYMPREC = 0.01 +CPU = torch.device("cpu") # === Structure helpers === @@ -118,6 +119,18 @@ def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper: return NoisyModelWrapper(model) +@pytest.fixture +def p6bar_both_constraints() -> tuple[ts.SimState, FixSymmetry, Atoms, ASEFixSymmetry]: + """P-6 structure with both TorchSim and ASE constraints (shared setup).""" + atoms = make_structure("p6bar") + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + ase_atoms = atoms.copy() + ase_refine_symmetry(ase_atoms, symprec=SYMPREC) + ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) + return state, ts_constraint, ase_atoms, ase_constraint + + # === Optimization helper === @@ -162,10 +175,10 @@ class TestFixSymmetryCreation: """Tests for FixSymmetry creation and basic behavior.""" def test_from_state_batched(self) -> None: - """Batched state with FCC + diamond gets correct ops and atom counts.""" + """Batched state with FCC + diamond gets correct ops, atom counts, and DOF.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - torch.device("cpu"), + CPU, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) @@ -173,35 +186,30 @@ def test_from_state_batched(self) -> None: assert constraint.rotations[0].shape[0] == 48 # cubic assert constraint.symm_maps[0].shape == (48, 1) # Cu: 1 atom assert constraint.symm_maps[1].shape == (48, 2) # Si: 2 atoms + assert torch.all(constraint.get_removed_dof(state) == 0) - def test_p1_identity_only(self) -> None: - """P1 structure has 1 op and symmetrization is a no-op.""" + def test_p1_identity_is_noop(self) -> None: + """P1 structure has 1 op and symmetrization is a no-op for forces and stress.""" atoms = Atoms( "SiGe", positions=[[0.1, 0.2, 0.3], [1.1, 0.9, 1.3]], cell=[[3.0, 0.1, 0.2], [0.15, 3.5, 0.1], [0.2, 0.15, 4.0]], pbc=True, ) - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) assert constraint.rotations[0].shape[0] == 1 forces = torch.randn(2, 3, dtype=DTYPE) - original = forces.clone() + orig_forces = forces.clone() constraint.adjust_forces(state, forces) - assert torch.allclose(forces, original, atol=1e-10) + assert torch.allclose(forces, orig_forces, atol=1e-10) stress = torch.randn(1, 3, 3, dtype=DTYPE) stress = (stress + stress.mT) / 2 - original_stress = stress.clone() + orig_stress = stress.clone() constraint.adjust_stress(state, stress) - assert torch.allclose(stress, original_stress, atol=1e-10) - - def test_get_removed_dof_returns_zero(self) -> None: - """FixSymmetry constrains direction, not DOF count.""" - state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) - constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - assert torch.all(constraint.get_removed_dof(state) == 0) + assert torch.allclose(stress, orig_stress, atol=1e-10) @pytest.mark.parametrize("refine", [True, False]) def test_from_state_refine_symmetry(self, *, refine: bool) -> None: @@ -209,7 +217,7 @@ def test_from_state_refine_symmetry(self, *, refine: bool) -> None: atoms = make_structure("fcc") rng = np.random.default_rng(42) atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) orig_pos = state.positions.clone() _ = FixSymmetry.from_state(state, symprec=SYMPREC, refine_symmetry_state=refine) if not refine: @@ -221,150 +229,160 @@ def test_refine_symmetry_produces_correct_spacegroup( structure_name: str, ) -> None: """Perturbed structure recovers correct spacegroup after refinement.""" - from torch_sim.symmetrize import get_symmetry_datasets, refine_symmetry + from torch_sim.symmetrize import refine_symmetry atoms = make_structure(structure_name) expected = SPACEGROUPS[structure_name] rng = np.random.default_rng(42) atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) - cell = state.row_vector_cell[0] - pos = state.positions - nums = state.atomic_numbers - - refined_cell, refined_pos = refine_symmetry(cell, pos, nums, symprec=SYMPREC) + refined_cell, refined_pos = refine_symmetry( + state.row_vector_cell[0], + state.positions, + state.atomic_numbers, + symprec=SYMPREC, + ) state.cell[0] = refined_cell.mT state.positions = refined_pos - # Check at tight precision datasets = get_symmetry_datasets(state, symprec=1e-4) - assert datasets[0].number == expected, ( - f"{structure_name}: expected SG {expected}, got {datasets[0].number}" + assert datasets[0].number == expected + + def test_cubic_forces_vanish(self) -> None: + """Asymmetric force on single cubic atom symmetrizes to zero.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + forces = torch.tensor( + [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], + dtype=DTYPE, + ) + constraint.adjust_forces(state, forces) + assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) def test_large_deformation_raises(self) -> None: """Deformation gradient > 0.25 raises RuntimeError.""" - state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) new_cell = state.cell.clone() new_cell[0] *= 1.5 with pytest.raises(RuntimeError, match="deformation gradient"): constraint.adjust_cell(state, new_cell) + def test_init_mismatched_lengths_raises(self) -> None: + """Mismatched rotations/symm_maps lengths raises ValueError.""" + rots = [torch.eye(3).unsqueeze(0)] + smaps = [torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 2, dtype=torch.long)] + with pytest.raises(ValueError, match="length mismatch"): + FixSymmetry(rots, smaps) + + @pytest.mark.parametrize("method", ["adjust_positions", "adjust_cell"]) + def test_adjust_skipped_when_disabled(self, method: str) -> None: + """adjust_positions=False / adjust_cell=False leaves data unchanged.""" + flag = method.replace("adjust_", "") # "positions" or "cell" + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state( + state, + symprec=SYMPREC, + **{f"adjust_{flag}": False}, + ) + if method == "adjust_positions": + data = state.positions.clone() + 0.1 + else: + data = state.cell.clone() * 1.01 + expected = data.clone() + getattr(constraint, method)(state, data) + assert torch.equal(data, expected) + # === Tests: Comparison with ASE === class TestFixSymmetryComparisonWithASE: - """Compare TorchSim FixSymmetry with ASE's implementation.""" - - def test_force_symmetrization_matches_ase(self) -> None: - """Force symmetrization matches ASE on multi-atom P-6 structure.""" - atoms = make_structure("p6bar") - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) - ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - ase_atoms = atoms.copy() - ase_refine_symmetry(ase_atoms, symprec=SYMPREC) - ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) + """Compare TorchSim FixSymmetry with ASE's implementation on P-6 structure.""" + def test_force_symmetrization_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Force symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints rng = np.random.default_rng(42) - forces_np = rng.standard_normal((len(atoms), 3)) + forces_np = rng.standard_normal((len(ase_atoms), 3)) forces_ts = torch.tensor(forces_np.copy(), dtype=DTYPE) - - ts_constraint.adjust_forces(state, forces_ts) - ase_constraint.adjust_forces(ase_atoms, forces_np) + ts_c.adjust_forces(state, forces_ts) + ase_c.adjust_forces(ase_atoms, forces_np) assert np.allclose(forces_ts.numpy(), forces_np, atol=1e-10) - def test_stress_symmetrization_matches_ase(self) -> None: - """Stress symmetrization matches ASE on P-6 structure.""" - atoms = make_structure("p6bar") - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) - ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - ase_atoms = atoms.copy() - ase_refine_symmetry(ase_atoms, symprec=SYMPREC) - ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) - + def test_stress_symmetrization_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Stress symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints stress_3x3 = np.array([[10.0, 1.0, 0.5], [1.0, 8.0, 0.3], [0.5, 0.3, 6.0]]) - stress_voigt = full_3x3_to_voigt_6_stress(stress_3x3) - stress_voigt_copy = stress_voigt.copy() + stress_voigt = full_3x3_to_voigt_6_stress(stress_3x3).copy() stress_ts = torch.tensor([stress_3x3.copy()], dtype=DTYPE) + ts_c.adjust_stress(state, stress_ts) + ase_c.adjust_stress(ase_atoms, stress_voigt) + assert np.allclose( + stress_ts[0].numpy(), + voigt_6_to_full_3x3_stress(stress_voigt), + atol=1e-10, + ) - ts_constraint.adjust_stress(state, stress_ts) - ase_constraint.adjust_stress(ase_atoms, stress_voigt_copy) - ase_result = voigt_6_to_full_3x3_stress(stress_voigt_copy) - assert np.allclose(stress_ts[0].numpy(), ase_result, atol=1e-10) - - def test_cell_deformation_matches_ase(self) -> None: - """Cell deformation symmetrization matches ASE on P-6 structure.""" - atoms = make_structure("p6bar") - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) - ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - ase_atoms = atoms.copy() - ase_refine_symmetry(ase_atoms, symprec=SYMPREC) - ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) - - original_cell = ase_atoms.get_cell().copy() - deformed_cell = original_cell.copy() - deformed_cell[0, 1] += 0.05 - - new_cell_ts = torch.tensor([deformed_cell.copy().T], dtype=DTYPE) - ts_constraint.adjust_cell(state, new_cell_ts) - ts_result = new_cell_ts[0].mT.numpy() - - ase_cell = deformed_cell.copy() - ase_constraint.adjust_cell(ase_atoms, ase_cell) - assert np.allclose(ts_result, ase_cell, atol=1e-10) - - def test_position_symmetrization_matches_ase(self) -> None: - """Position displacement symmetrization matches ASE on P-6 structure.""" - atoms = make_structure("p6bar") - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) - ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - - ase_atoms = atoms.copy() - ase_refine_symmetry(ase_atoms, symprec=SYMPREC) - ase_constraint = ASEFixSymmetry(ase_atoms, symprec=SYMPREC) - - # Create a displacement by proposing new positions + def test_cell_deformation_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Cell deformation symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints + deformed = ase_atoms.get_cell().copy() + deformed[0, 1] += 0.05 + new_cell_ts = torch.tensor([deformed.copy().T], dtype=DTYPE) + ts_c.adjust_cell(state, new_cell_ts) + ase_cell = deformed.copy() + ase_c.adjust_cell(ase_atoms, ase_cell) + assert np.allclose(new_cell_ts[0].mT.numpy(), ase_cell, atol=1e-10) + + def test_position_symmetrization_matches_ase( + self, + p6bar_both_constraints: tuple, + ) -> None: + """Position displacement symmetrization matches ASE.""" + state, ts_c, ase_atoms, ase_c = p6bar_both_constraints rng = np.random.default_rng(42) - displacement = rng.standard_normal((len(atoms), 3)) * 0.01 - new_pos_ts = state.positions.clone() + torch.tensor(displacement, dtype=DTYPE) - new_pos_ase = ase_atoms.positions.copy() + displacement - - ts_constraint.adjust_positions(state, new_pos_ts) - ase_constraint.adjust_positions(ase_atoms, new_pos_ase) + disp = rng.standard_normal((len(ase_atoms), 3)) * 0.01 + new_pos_ts = state.positions.clone() + torch.tensor(disp, dtype=DTYPE) + new_pos_ase = ase_atoms.positions.copy() + disp + ts_c.adjust_positions(state, new_pos_ts) + ase_c.adjust_positions(ase_atoms, new_pos_ase) assert np.allclose(new_pos_ts.numpy(), new_pos_ase, atol=1e-10) - def test_cubic_forces_vanish(self) -> None: - """Asymmetric force on single cubic atom symmetrizes to zero.""" - state = ts.io.atoms_to_state( - [make_structure("fcc"), make_structure("diamond")], - torch.device("cpu"), - DTYPE, - ) - constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - forces = torch.tensor( - [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], - dtype=DTYPE, - ) - constraint.adjust_forces(state, forces) - assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) +# === Tests: Merge, Select, Reindex === -# === Tests: Merge & Select === +class TestFixSymmetryMergeSelectReindex: + """Tests for reindex/merge API, select, and concatenation.""" -class TestFixSymmetryMergeAndSelect: - """Tests for merge, select_constraint, select_sub_constraint.""" + def test_reindex_preserves_symmetry_data(self) -> None: + """reindex shifts system_idx but preserves rotations and symm_maps.""" + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + orig = FixSymmetry.from_state(state, symprec=SYMPREC) + shifted = orig.reindex(atom_offset=100, system_offset=5) + assert shifted.system_idx.item() == 5 + assert torch.equal(shifted.rotations[0], orig.rotations[0]) + assert torch.equal(shifted.symm_maps[0], orig.symm_maps[0]) def test_merge_two_constraints(self) -> None: - """Merge two single-system constraints.""" - s1 = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) - s2 = ts.io.atoms_to_state(make_structure("diamond"), torch.device("cpu"), DTYPE) + """Merge two single-system constraints via reindex + merge.""" + s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE) c1 = FixSymmetry.from_state(s1) c2 = FixSymmetry.from_state(s2).reindex(atom_offset=0, system_offset=1) merged = FixSymmetry.merge([c1, c2]) @@ -379,43 +397,124 @@ def test_merge_multi_system_no_duplicate_indices(self) -> None: make_structure("hcp"), ] atoms_b = [make_structure("bcc"), make_structure("fcc")] - c_a = FixSymmetry.from_state( - ts.io.atoms_to_state(atoms_a, torch.device("cpu"), DTYPE), - ) + c_a = FixSymmetry.from_state(ts.io.atoms_to_state(atoms_a, CPU, DTYPE)) c_b = FixSymmetry.from_state( - ts.io.atoms_to_state(atoms_b, torch.device("cpu"), DTYPE), + ts.io.atoms_to_state(atoms_b, CPU, DTYPE), ).reindex(atom_offset=0, system_offset=3) merged = FixSymmetry.merge([c_a, c_b]) assert merged.system_idx.tolist() == [0, 1, 2, 3, 4] + def test_system_constraint_merge_multi_system_via_concatenate(self) -> None: + """Regression: merging multi-system FixCom via concatenate_states.""" + s1 = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + s2 = ts.io.atoms_to_state( + [make_structure("bcc"), make_structure("hcp")], + CPU, + DTYPE, + ) + s1.constraints = [FixCom(system_idx=torch.tensor([0, 1]))] + s2.constraints = [FixCom(system_idx=torch.tensor([0, 1]))] + combined = ts.concatenate_states([s1, s2]) + assert combined.constraints[0].system_idx.tolist() == [0, 1, 2, 3] + + def test_concatenate_states_with_fix_symmetry(self) -> None: + """FixSymmetry survives concatenate_states and still symmetrizes correctly.""" + s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE) + s1.constraints = [FixSymmetry.from_state(s1, symprec=SYMPREC)] + s2.constraints = [FixSymmetry.from_state(s2, symprec=SYMPREC)] + combined = ts.concatenate_states([s1, s2]) + constraint = combined.constraints[0] + assert isinstance(constraint, FixSymmetry) + assert constraint.system_idx.tolist() == [0, 1] + assert len(constraint.rotations) == 2 + # Forces on single FCC atom should still vanish + forces = torch.tensor( + [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], + dtype=DTYPE, + ) + constraint.adjust_forces(combined, forces) + assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) + def test_select_sub_constraint(self) -> None: """Select second system from batched constraint.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - torch.device("cpu"), + CPU, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) selected = constraint.select_sub_constraint(torch.tensor([1, 2]), sys_idx=1) assert selected is not None - assert selected.symm_maps[0].shape[1] == 2 # Si diamond: 2 atoms + assert selected.symm_maps[0].shape[1] == 2 assert selected.system_idx.item() == 0 def test_select_constraint_by_mask(self) -> None: """Select first system via system_mask.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - torch.device("cpu"), + CPU, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - atom_mask = torch.tensor([True, False, False], dtype=torch.bool) - system_mask = torch.tensor([True, False], dtype=torch.bool) - selected = constraint.select_constraint(atom_mask, system_mask) + selected = constraint.select_constraint( + atom_mask=torch.tensor([True, False, False]), + system_mask=torch.tensor([True, False]), + ) assert selected is not None assert len(selected.rotations) == 1 assert selected.rotations[0].shape[0] == 48 + def test_select_returns_none_for_nonexistent(self) -> None: + """select_sub_constraint and select_constraint return None when no match.""" + state = ts.io.atoms_to_state( + [make_structure("fcc"), make_structure("diamond")], + CPU, + DTYPE, + ) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + assert constraint.select_sub_constraint(torch.tensor([0]), sys_idx=99) is None + assert ( + constraint.select_constraint( + atom_mask=torch.zeros(3, dtype=torch.bool), + system_mask=torch.zeros(2, dtype=torch.bool), + ) + is None + ) + + +# === Tests: build_symmetry_map chunked path === + + +def test_build_symmetry_map_chunked_matches_vectorized() -> None: + """Per-op loop gives same result as vectorized path.""" + import torch_sim.symmetrize as sym_mod + from torch_sim.symmetrize import ( + _extract_symmetry_ops, + _moyo_dataset, + build_symmetry_map, + ) + + state = ts.io.atoms_to_state(make_structure("p6bar"), CPU, DTYPE) + cell = state.row_vector_cell[0] + frac = state.positions @ torch.linalg.inv(cell) + dataset = _moyo_dataset(cell, frac, state.atomic_numbers) + rotations, translations = _extract_symmetry_ops(dataset, DTYPE, CPU) + + old_threshold = sym_mod._SYMM_MAP_CHUNK_THRESHOLD # noqa: SLF001 + try: + sym_mod._SYMM_MAP_CHUNK_THRESHOLD = len(state.positions) + 1 # noqa: SLF001 + vectorized = build_symmetry_map(rotations, translations, frac) + sym_mod._SYMM_MAP_CHUNK_THRESHOLD = 0 # noqa: SLF001 + chunked = build_symmetry_map(rotations, translations, frac) + finally: + sym_mod._SYMM_MAP_CHUNK_THRESHOLD = old_threshold # noqa: SLF001 + assert torch.equal(vectorized, chunked) + # === Tests: Optimization === @@ -439,7 +538,7 @@ def test_distorted_preserves_symmetry( """Compressed structure relaxes while preserving symmetry.""" atoms = make_structure(structure_name) expected = SPACEGROUPS[structure_name] - state = ts.io.atoms_to_state(atoms, torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(atoms, CPU, DTYPE) constraint = FixSymmetry.from_state( state, symprec=SYMPREC, @@ -464,7 +563,7 @@ def test_cell_filter_preserves_symmetry( cell_filter: ts.CellFilter, ) -> None: """Cell filters with FixSymmetry preserve symmetry.""" - state = ts.io.atoms_to_state(make_structure("fcc"), torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) state.constraints = [constraint] initial = get_symmetry_datasets(state, symprec=SYMPREC) @@ -486,7 +585,7 @@ def test_lbfgs_preserves_symmetry( cell_filter: ts.CellFilter, ) -> None: """Regression: LBFGS must use set_constrained_cell for FixSymmetry support.""" - state = ts.io.atoms_to_state(make_structure("bcc"), torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(make_structure("bcc"), CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) state.constraints = [constraint] state.cell = state.cell * 0.95 @@ -514,7 +613,7 @@ def test_noisy_model_loses_symmetry_without_constraint( ) -> None: """Negative control: without FixSymmetry, noisy forces break symmetry.""" name = "bcc_rotated" if rotated else "bcc" - state = ts.io.atoms_to_state(make_structure(name), torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(make_structure(name), CPU, DTYPE) result = run_optimization_check_symmetry(state, noisy_lj_model, constraint=None) assert result["initial_spacegroups"][0] == 229 assert result["final_spacegroups"][0] != 229 @@ -528,7 +627,7 @@ def test_noisy_model_preserves_symmetry_with_constraint( ) -> None: """With FixSymmetry, noisy forces still preserve symmetry.""" name = "bcc_rotated" if rotated else "bcc" - state = ts.io.atoms_to_state(make_structure(name), torch.device("cpu"), DTYPE) + state = ts.io.atoms_to_state(make_structure(name), CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) result = run_optimization_check_symmetry( state, From a16d4b504a6ac312120b9170105798867b0bc509 Mon Sep 17 00:00:00 2001 From: janosh Date: Fri, 6 Feb 2026 09:41:13 -0800 Subject: [PATCH 16/16] harden FixSymmetry constraint and clean up code - fix reindex() sharing mutable rotations/symm_maps lists between original and copy (shallow copy with list()) - validate adjust_positions/adjust_cell flag consistency in merge() - catch NaN in deformation guard via negated comparison (NaN > x is False) - clamp eigenvalues in _mat_sqrt to prevent NaN from float noise - use torch.linalg.solve instead of inv @ matmul for numerical stability - extract _cumsum_with_zero utility to deduplicate 4 identical patterns - skip stress clone in _get_constrained_stress when no constraints - make Constraint.merge @abstractmethod (was non-abstract NotImplementedError) - use mask.nonzero() instead of manual list comprehension in select_constraint - add type annotations to NoisyModelWrapper test helper Co-authored-by: Cursor --- tests/test_fix_symmetry.py | 4 ++ torch_sim/constraints.py | 64 ++++++++++++---------------- torch_sim/optimizers/cell_filters.py | 2 + torch_sim/symmetrize.py | 6 ++- 4 files changed, 38 insertions(+), 38 deletions(-) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 8697fcf4b..7ddb8dc56 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -86,6 +86,10 @@ def model() -> LennardJonesModel: class NoisyModelWrapper: """Wrapper that adds noise to forces and stress.""" + model: LennardJonesModel + rng: np.random.Generator + noise_scale: float + def __init__(self, model: LennardJonesModel, noise_scale: float = 1e-4) -> None: self.model = model self.rng = np.random.default_rng(seed=1) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 4979b8286..c7c8d5e50 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -136,6 +136,7 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: """ @classmethod + @abstractmethod def merge(cls, constraints: list[Self]) -> Self: """Merge multiple already-reindexed constraints into one. @@ -145,10 +146,13 @@ def merge(cls, constraints: list[Self]) -> Self: Args: constraints: Constraints to merge (all same type, already reindexed) """ - raise NotImplementedError( - f"Constraint type {cls.__name__} does not implement merge. " - "Override this method to support state concatenation." - ) + + +def _cumsum_with_zero(tensor: torch.Tensor) -> torch.Tensor: + """Cumulative sum with a leading zero, e.g. [3, 2, 4] -> [0, 3, 5, 9].""" + return torch.cat( + [torch.zeros(1, device=tensor.device, dtype=tensor.dtype), tensor.cumsum(dim=0)] + ) def _mask_constraint_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: @@ -372,22 +376,12 @@ def merge_constraints( # Calculate cumulative offsets for atoms and systems device, dtype = num_atoms_per_state.device, num_atoms_per_state.dtype - atom_offsets = torch.cat( - [ - torch.zeros(1, device=device, dtype=dtype), - torch.cumsum(num_atoms_per_state[:-1], dim=0), - ] - ) + atom_offsets = _cumsum_with_zero(num_atoms_per_state[:-1]) if num_systems_per_state is None: num_systems_per_state = torch.ones( len(constraint_lists), device=device, dtype=dtype ) - system_offsets = torch.cat( - [ - torch.zeros(1, device=device, dtype=dtype), - torch.cumsum(num_systems_per_state[:-1], dim=0), - ] - ) + system_offsets = _cumsum_with_zero(num_systems_per_state[:-1]) # Reindex each constraint to global coordinates, then group by type grouped: dict[type[Constraint], list[Constraint]] = defaultdict(list) @@ -777,13 +771,7 @@ def from_state( from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry rotations, symm_maps = [], [] - n_per = state.n_atoms_per_system - cumsum = torch.cat( - [ - torch.zeros(1, device=state.device, dtype=torch.long), - torch.cumsum(n_per, dim=0), - ] - ) + cumsum = _cumsum_with_zero(state.n_atoms_per_system) for sys_idx in range(state.n_systems): start, end = cumsum[sys_idx].item(), cumsum[sys_idx + 1].item() @@ -862,9 +850,9 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: for ci, si in enumerate(self.system_idx): cur_cell = state.row_vector_cell[si] new_row = new_cell[si].mT # column → row convention - deform_delta = torch.linalg.inv(cur_cell) @ new_row - identity + deform_delta = torch.linalg.solve(cur_cell, new_row) - identity max_delta = torch.abs(deform_delta).max().item() - if max_delta > 0.25: + if not (max_delta <= 0.25): # catches NaN via negated comparison raise RuntimeError( f"FixSymmetry: deformation gradient {max_delta:.4f} > 0.25 " f"too large. Use smaller optimization steps." @@ -877,12 +865,7 @@ def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: """Symmetrize a rank-1 tensor in-place for each constrained system.""" from torch_sim.symmetrize import symmetrize_rank1 - cumsum = torch.cat( - [ - torch.zeros(1, device=state.device, dtype=torch.long), - torch.cumsum(state.n_atoms_per_system, dim=0), - ] - ) + cumsum = _cumsum_with_zero(state.n_atoms_per_system) dtype = vectors.dtype for ci, si in enumerate(self.system_idx): start, end = cumsum[si].item(), cumsum[si + 1].item() @@ -902,8 +885,8 @@ def get_removed_dof(self, state: SimState) -> torch.Tensor: def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 """Return copy with system indices shifted by system_offset.""" return type(self)( - self.rotations, - self.symm_maps, + list(self.rotations), + list(self.symm_maps), self.system_idx + system_offset, adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, @@ -914,6 +897,15 @@ def merge(cls, constraints: list[Self]) -> Self: """Merge by concatenating rotations, symm_maps, and system indices.""" if not constraints: raise ValueError("Cannot merge empty constraint list") + if any( + c.do_adjust_positions != constraints[0].do_adjust_positions + or c.do_adjust_cell != constraints[0].do_adjust_cell + for c in constraints[1:] + ): + raise ValueError( + "Cannot merge FixSymmetry constraints with different " + "adjust_positions/adjust_cell settings" + ) rotations = [r for c in constraints for r in c.rotations] symm_maps = [s for c in constraints for s in c.symm_maps] system_idx = torch.cat([c.system_idx for c in constraints]) @@ -935,10 +927,10 @@ def select_constraint( mask = torch.isin(self.system_idx, keep) if not mask.any(): return None - indices = [idx for idx in range(len(mask)) if mask[idx]] + local_idx = mask.nonzero(as_tuple=False).flatten().tolist() return type(self)( - [self.rotations[idx] for idx in indices], - [self.symm_maps[idx] for idx in indices], + [self.rotations[idx] for idx in local_idx], + [self.symm_maps[idx] for idx in local_idx], _mask_constraint_indices(self.system_idx[mask], system_mask), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 1067296d3..60ac5a1cc 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -65,6 +65,8 @@ def _get_constrained_stress( model_output: dict[str, torch.Tensor], state: SimState ) -> torch.Tensor: """Clone stress from model output and apply constraint symmetrization.""" + if not state.constraints: + return model_output["stress"] stress = model_output["stress"].clone() for constraint in state.constraints: constraint.adjust_stress(state, stress) diff --git a/torch_sim/symmetrize.py b/torch_sim/symmetrize.py index 3a49de311..8677e8f05 100644 --- a/torch_sim/symmetrize.py +++ b/torch_sim/symmetrize.py @@ -83,12 +83,14 @@ def build_symmetry_map( if n_atoms <= _SYMM_MAP_CHUNK_THRESHOLD: # Vectorized: allocates (n_ops, n_atoms, n_atoms, 3) — fast for small systems + # einsum computes R[o] @ frac[n] for all (o, n) pairs at once new_pos = torch.einsum("oij,nj->oni", rotations, frac_pos) + translations[:, None] delta = frac_pos[None, None] - new_pos[:, :, None] delta -= delta.round() return torch.argmin(torch.linalg.norm(delta, dim=-1), dim=-1).long() # Per-op loop: allocates only (n_atoms, n_atoms, 3) at a time + # Equivalent to vectorized path: frac @ R.T == R @ frac per row result = torch.empty(n_ops, n_atoms, dtype=torch.long, device=frac_pos.device) for op_idx in range(n_ops): new_pos_op = frac_pos @ rotations[op_idx].T + translations[op_idx] @@ -138,9 +140,9 @@ def _refine_symmetry_impl( def _mat_sqrt(mat: torch.Tensor) -> torch.Tensor: evals, evecs = torch.linalg.eigh(mat) - return evecs @ torch.diag(evals.sqrt()) @ evecs.T + return evecs @ torch.diag(evals.clamp(min=0).sqrt()) @ evecs.T - new_cell = _mat_sqrt(metric_sym) @ torch.linalg.inv(_mat_sqrt(metric)) @ cell + new_cell = _mat_sqrt(metric_sym) @ torch.linalg.solve(_mat_sqrt(metric), cell) # Symmetrize positions via displacement averaging over symmetry orbits new_frac = positions @ torch.linalg.inv(new_cell)