diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes_framework.py similarity index 68% rename from tests/models/test_graphpes.py rename to tests/models/test_graphpes_framework.py index 276b0fc54..44ddadfe2 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes_framework.py @@ -3,6 +3,7 @@ import pytest import torch from ase.build import bulk, molecule +from ase.calculators.calculator import Calculator import torch_sim as ts from tests.conftest import DEVICE @@ -16,8 +17,7 @@ try: from graph_pes.atomic_graph import AtomicGraph, to_batch - from graph_pes.interfaces import mace_mp - from graph_pes.models import LennardJones, SchNet, TensorNet, ZEmbeddingNequIP + from graph_pes.models import LennardJones, SchNet, TensorNet except ImportError: pytest.skip( f"graph-pes not installed: {traceback.format_exc()}", allow_module_level=True @@ -122,90 +122,34 @@ def test_graphpes_dtype(dtype: torch.dtype): assert ts_output["forces"].dtype == dtype -_nequip_model = ZEmbeddingNequIP() - - @pytest.fixture -def ts_nequip_model(): - return GraphPESWrapper( - _nequip_model, device=DEVICE, dtype=DTYPE, compute_stress=False - ) +def graphpes_lj_model() -> LennardJones: + return LennardJones(sigma=0.5) @pytest.fixture -def ase_nequip_calculator(): - return _nequip_model.to(DEVICE, DTYPE).ase_calculator(skin=0.0) - - -test_graphpes_nequip_consistency = make_model_calculator_consistency_test( - test_name="graphpes-nequip", - model_fixture_name="ts_nequip_model", - calculator_fixture_name="ase_nequip_calculator", - sim_state_names=CONSISTENCY_SIMSTATES, - device=DEVICE, - dtype=DTYPE, - energy_rtol=1e-3, - energy_atol=1e-3, - force_rtol=1e-3, - force_atol=1e-3, - stress_rtol=1e-3, - stress_atol=1e-3, -) - -test_graphpes_nequip_model_outputs = make_validate_model_outputs_test( - model_fixture_name="ts_nequip_model", device=DEVICE, dtype=DTYPE -) - - -@pytest.fixture -def ts_mace_model(): +def ts_lj_model(graphpes_lj_model: LennardJones) -> GraphPESWrapper: return GraphPESWrapper( - mace_mp("medium-mpa-0"), - device=DEVICE, - dtype=DTYPE, - compute_stress=False, + graphpes_lj_model, device=DEVICE, dtype=DTYPE, compute_stress=False ) @pytest.fixture -def ase_mace_calculator(): - return mace_mp("medium-mpa-0").to(DEVICE, DTYPE).ase_calculator(skin=0.0) +def ase_lj_calculator(graphpes_lj_model: LennardJones) -> Calculator: + return graphpes_lj_model.to(DEVICE, DTYPE).ase_calculator(skin=0.0) -test_graphpes_mace_consistency = make_model_calculator_consistency_test( - test_name="graphpes-mace", - model_fixture_name="ts_mace_model", - calculator_fixture_name="ase_mace_calculator", +test_graphpes_consistency = make_model_calculator_consistency_test( + test_name="graphpes-lj", + model_fixture_name="ts_lj_model", + calculator_fixture_name="ase_lj_calculator", sim_state_names=CONSISTENCY_SIMSTATES, device=DEVICE, dtype=DTYPE, ) -test_graphpes_mace_model_outputs = make_validate_model_outputs_test( - model_fixture_name="ts_mace_model", - device=DEVICE, - dtype=DTYPE, -) - - -_lj_model = LennardJones(sigma=0.5) - - -@pytest.fixture -def ts_lj_model(): - return GraphPESWrapper(_lj_model, device=DEVICE, dtype=DTYPE, compute_stress=False) - - -@pytest.fixture -def ase_lj_calculator(): - return _lj_model.to(DEVICE, DTYPE).ase_calculator(skin=0.0) - - -test_graphpes_lj_consistency = make_model_calculator_consistency_test( - test_name="graphpes-lj", - model_fixture_name="ts_lj_model", - calculator_fixture_name="ase_lj_calculator", - sim_state_names=CONSISTENCY_SIMSTATES, +test_graphpes_model_outputs = make_validate_model_outputs_test( + model_fixture_name="graphpes_lj_model", device=DEVICE, dtype=DTYPE, ) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index e380cab44..de62d1fa5 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -1,187 +1,22 @@ -"""An interface for using arbitrary GraphPESModels in ts. +"""Deprecated module for importing GraphPESWrapper, AtomicGraph, and GraphPESModel. -This module provides a TorchSim wrapper of the GraphPES models for computing -energies, forces, and stresses of atomistic systems. It serves as a wrapper around -the graph_pes library, integrating it with the torch-sim framework to enable seamless -simulation of atomistic systems with machine learning potentials. - -The GraphPESWrapper class adapts GraphPESModels to the ModelInterface protocol, -allowing them to be used within the broader torch-sim simulation framework. - -Notes: - This implementation requires graph_pes to be installed and accessible. - It supports various model configurations through model instances or model paths. +This module is deprecated. Please use the ts.models.graphpes_framework module instead. """ -import traceback import warnings -from pathlib import Path -from typing import Any - -import torch - -import torch_sim as ts -from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torchsim_nl -from torch_sim.typing import StateDict - - -try: - from graph_pes import AtomicGraph, GraphPESModel - from graph_pes.atomic_graph import PropertyKey, to_batch - from graph_pes.models import load_model - -except ImportError as exc: - warnings.warn(f"GraphPES import failed: {traceback.format_exc()}", stacklevel=2) - PropertyKey = str - - class GraphPESWrapper(ModelInterface): # type: ignore[reportRedeclaration] - """GraphPESModel wrapper for torch-sim. - - This class is a placeholder for the GraphPESWrapper class. - It raises an ImportError if graph_pes is not installed. - """ - - def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: - """Dummy init for type checking.""" - raise err - - class AtomicGraph: # type: ignore[reportRedeclaration] # noqa: D101 - def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D107,ARG002 - raise ImportError("graph_pes must be installed to use this model.") - - class GraphPESModel(torch.nn.Module): # type: ignore[reportRedeclaration] # noqa: D101 - pass - - -def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGraph: - """Convert a SimState object into an AtomicGraph object. - - Args: - state: SimState object containing atomic positions, cell, and atomic numbers - cutoff: Cutoff radius for the neighbor list - - Returns: - AtomicGraph object representing the batched structures - """ - graphs = [] - - for sys_idx in range(state.n_systems): - system_mask = state.system_idx == sys_idx - R = state.positions[system_mask] - Z = state.atomic_numbers[system_mask] - cell = state.row_vector_cell[sys_idx] - # graph-pes models internally trim the neighbor list to the - # model's cutoff value. To ensure no strange edge effects whereby - # edges that are exactly `cutoff` long are included/excluded, - # we bump cutoff + 1e-5 up slightly - - # Create system_idx for this single system (all atoms belong to system 0) - system_idx_single = torch.zeros(R.shape[0], dtype=torch.long, device=R.device) - nl, _system_mapping, shifts = torchsim_nl( - R, cell, state.pbc, cutoff + 1e-5, system_idx_single - ) - - atomic_graph = AtomicGraph( - Z=Z.long(), - R=R, - cell=cell, - neighbour_list=nl.long(), - neighbour_cell_offsets=shifts, - properties={}, - cutoff=cutoff.item(), - other={ - "total_charge": torch.tensor(0.0).to(state.device), - "total_spin": torch.tensor(0.0).to(state.device), - }, - ) - graphs.append(atomic_graph) - - return to_batch(graphs) - - -class GraphPESWrapper(ModelInterface): - """Wrapper for GraphPESModel in TorchSim. - - This class provides a TorchSim wrapper around GraphPESModel instances, - allowing them to be used within the broader torch-sim simulation framework. - - The graph-pes package allows for the training of existing model architectures, - including SchNet, PaiNN, MACE, NequIP, TensorNet, EDDP and more. - You can use any of these, as well as your own custom architectures, with this wrapper. - See the the graph-pes repo for more details: https://github.com/jla-gardner/graph-pes - - Args: - model: GraphPESModel instance, or a path to a model file - device: Device to run the model on - dtype: Data type for the model - compute_forces: Whether to compute forces - compute_stress: Whether to compute stress - - Example: - >>> from torch_sim.models.graphpes import GraphPESWrapper - >>> from graph_pes.models import load_model - >>> model = load_model("path/to/model.pt") - >>> wrapper = GraphPESWrapper(model) - >>> state = ts.SimState( - ... positions=torch.randn(10, 3), - ... cell=torch.eye(3), - ... atomic_numbers=torch.randint(1, 104, (10,)), - ... ) - >>> wrapper(state) - """ - - def __init__( - self, - model: GraphPESModel | str | Path, - device: torch.device | None = None, - dtype: torch.dtype = torch.float64, - *, - compute_forces: bool = True, - compute_stress: bool = True, - ) -> None: - """Initialize the GraphPESWrapper. - - Args: - model: GraphPESModel instance, or a path to a model file - device: Device to run the model on - dtype: Data type for the model - compute_forces: Whether to compute forces - compute_stress: Whether to compute stress - """ - super().__init__() - self._device = device or torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - self._dtype = dtype - - _model = model if isinstance(model, GraphPESModel) else load_model(model) - self._gp_model = _model.to(device=self.device, dtype=self.dtype) - - self._compute_forces = compute_forces - self._compute_stress = compute_stress - - self._properties: list[PropertyKey] = ["energy"] - if self.compute_forces: - self._properties.append("forces") - if self.compute_stress: - self._properties.append("stress") - - if self._gp_model.cutoff.item() < 0.5: - self._memory_scales_with = "n_atoms" - def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: - """Forward pass for the GraphPESWrapper. +from .graphpes_framework import AtomicGraph, GraphPESModel, GraphPESWrapper - Args: - state: SimState object containing atomic positions, cell, and atomic numbers - Returns: - Dictionary containing the computed energies, forces, and stresses - (where applicable) - """ - if not isinstance(state, ts.SimState): - state = ts.SimState(**state) # type: ignore[arg-type] +warnings.warn( + "Importing from the ts.models.graphpes module is deprecated. " + "Please use the ts.models.graphpes_framework module instead.", + DeprecationWarning, + stacklevel=2, +) - atomic_graph = state_to_atomic_graph(state, self._gp_model.cutoff) - return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value] +__all__ = [ + "AtomicGraph", + "GraphPESModel", + "GraphPESWrapper", +] diff --git a/torch_sim/models/graphpes_framework.py b/torch_sim/models/graphpes_framework.py new file mode 100644 index 000000000..e380cab44 --- /dev/null +++ b/torch_sim/models/graphpes_framework.py @@ -0,0 +1,187 @@ +"""An interface for using arbitrary GraphPESModels in ts. + +This module provides a TorchSim wrapper of the GraphPES models for computing +energies, forces, and stresses of atomistic systems. It serves as a wrapper around +the graph_pes library, integrating it with the torch-sim framework to enable seamless +simulation of atomistic systems with machine learning potentials. + +The GraphPESWrapper class adapts GraphPESModels to the ModelInterface protocol, +allowing them to be used within the broader torch-sim simulation framework. + +Notes: + This implementation requires graph_pes to be installed and accessible. + It supports various model configurations through model instances or model paths. +""" + +import traceback +import warnings +from pathlib import Path +from typing import Any + +import torch + +import torch_sim as ts +from torch_sim.models.interface import ModelInterface +from torch_sim.neighbors import torchsim_nl +from torch_sim.typing import StateDict + + +try: + from graph_pes import AtomicGraph, GraphPESModel + from graph_pes.atomic_graph import PropertyKey, to_batch + from graph_pes.models import load_model + +except ImportError as exc: + warnings.warn(f"GraphPES import failed: {traceback.format_exc()}", stacklevel=2) + PropertyKey = str + + class GraphPESWrapper(ModelInterface): # type: ignore[reportRedeclaration] + """GraphPESModel wrapper for torch-sim. + + This class is a placeholder for the GraphPESWrapper class. + It raises an ImportError if graph_pes is not installed. + """ + + def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: + """Dummy init for type checking.""" + raise err + + class AtomicGraph: # type: ignore[reportRedeclaration] # noqa: D101 + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D107,ARG002 + raise ImportError("graph_pes must be installed to use this model.") + + class GraphPESModel(torch.nn.Module): # type: ignore[reportRedeclaration] # noqa: D101 + pass + + +def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGraph: + """Convert a SimState object into an AtomicGraph object. + + Args: + state: SimState object containing atomic positions, cell, and atomic numbers + cutoff: Cutoff radius for the neighbor list + + Returns: + AtomicGraph object representing the batched structures + """ + graphs = [] + + for sys_idx in range(state.n_systems): + system_mask = state.system_idx == sys_idx + R = state.positions[system_mask] + Z = state.atomic_numbers[system_mask] + cell = state.row_vector_cell[sys_idx] + # graph-pes models internally trim the neighbor list to the + # model's cutoff value. To ensure no strange edge effects whereby + # edges that are exactly `cutoff` long are included/excluded, + # we bump cutoff + 1e-5 up slightly + + # Create system_idx for this single system (all atoms belong to system 0) + system_idx_single = torch.zeros(R.shape[0], dtype=torch.long, device=R.device) + nl, _system_mapping, shifts = torchsim_nl( + R, cell, state.pbc, cutoff + 1e-5, system_idx_single + ) + + atomic_graph = AtomicGraph( + Z=Z.long(), + R=R, + cell=cell, + neighbour_list=nl.long(), + neighbour_cell_offsets=shifts, + properties={}, + cutoff=cutoff.item(), + other={ + "total_charge": torch.tensor(0.0).to(state.device), + "total_spin": torch.tensor(0.0).to(state.device), + }, + ) + graphs.append(atomic_graph) + + return to_batch(graphs) + + +class GraphPESWrapper(ModelInterface): + """Wrapper for GraphPESModel in TorchSim. + + This class provides a TorchSim wrapper around GraphPESModel instances, + allowing them to be used within the broader torch-sim simulation framework. + + The graph-pes package allows for the training of existing model architectures, + including SchNet, PaiNN, MACE, NequIP, TensorNet, EDDP and more. + You can use any of these, as well as your own custom architectures, with this wrapper. + See the the graph-pes repo for more details: https://github.com/jla-gardner/graph-pes + + Args: + model: GraphPESModel instance, or a path to a model file + device: Device to run the model on + dtype: Data type for the model + compute_forces: Whether to compute forces + compute_stress: Whether to compute stress + + Example: + >>> from torch_sim.models.graphpes import GraphPESWrapper + >>> from graph_pes.models import load_model + >>> model = load_model("path/to/model.pt") + >>> wrapper = GraphPESWrapper(model) + >>> state = ts.SimState( + ... positions=torch.randn(10, 3), + ... cell=torch.eye(3), + ... atomic_numbers=torch.randint(1, 104, (10,)), + ... ) + >>> wrapper(state) + """ + + def __init__( + self, + model: GraphPESModel | str | Path, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + *, + compute_forces: bool = True, + compute_stress: bool = True, + ) -> None: + """Initialize the GraphPESWrapper. + + Args: + model: GraphPESModel instance, or a path to a model file + device: Device to run the model on + dtype: Data type for the model + compute_forces: Whether to compute forces + compute_stress: Whether to compute stress + """ + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + + _model = model if isinstance(model, GraphPESModel) else load_model(model) + self._gp_model = _model.to(device=self.device, dtype=self.dtype) + + self._compute_forces = compute_forces + self._compute_stress = compute_stress + + self._properties: list[PropertyKey] = ["energy"] + if self.compute_forces: + self._properties.append("forces") + if self.compute_stress: + self._properties.append("stress") + + if self._gp_model.cutoff.item() < 0.5: + self._memory_scales_with = "n_atoms" + + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: + """Forward pass for the GraphPESWrapper. + + Args: + state: SimState object containing atomic positions, cell, and atomic numbers + + Returns: + Dictionary containing the computed energies, forces, and stresses + (where applicable) + """ + if not isinstance(state, ts.SimState): + state = ts.SimState(**state) # type: ignore[arg-type] + + atomic_graph = state_to_atomic_graph(state, self._gp_model.cutoff) + return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value]