208 changes: 208 additions & 0 deletions examples/scripts/8_scaling.py
@@ -0,0 +1,208 @@
"""Scaling benchmarks for static, relax, NVE, and NVT."""

# %%
# /// script
# dependencies = [
# "torch_sim_atomistic[mace,test]"
# ]
# ///

import time
import typing

import torch
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp
from pymatgen.io.ase import AseAtomsAdaptor

import torch_sim as ts
from torch_sim.models.mace import MaceModel, MaceUrls


# Shared constants
N_STRUCTURES_STATIC = [1, 1, 1, 1, 10, 100, 500, 1000, 2500, 5000]
N_STRUCTURES_RELAX = [1, 10, 100, 500]
N_STRUCTURES_NVE = [1, 10, 100, 500]
N_STRUCTURES_NVT = [1, 10, 100, 500]
RELAX_STEPS = 10
MD_STEPS = 10
MAX_MEMORY_SCALER = 400_000
MEMORY_SCALES_WITH = "n_atoms_x_density"
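# Note: MAX_MEMORY_SCALER caps the total memory metric packed into a single batch and
# MEMORY_SCALES_WITH picks how that metric is computed (here atom count x number
# density); both are forwarded to the autobatchers below. The repeated leading 1s in
# N_STRUCTURES_STATIC presumably serve as warm-up runs before the larger batches.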


def load_mace_model(device: torch.device) -> MaceModel:
"""Load MACE model for benchmarking."""
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
return_raw_model=True,
default_dtype="float64",
device=str(device),
)
return MaceModel(
model=typing.cast("torch.nn.Module", loaded_model),
device=device,
compute_forces=True,
compute_stress=True,
dtype=torch.float64,
enable_cueq=False,
)


def run_torchsim_static(
n_structures_list: list[int],
base_structure: typing.Any,
model: MaceModel,
device: torch.device,
) -> list[float]:
"""Run static calculations for each n using batched path, return timings."""
autobatcher = ts.BinningAutoBatcher(
model=model,
max_memory_scaler=MAX_MEMORY_SCALER,
memory_scales_with=MEMORY_SCALES_WITH,
)
times: list[float] = []
for n in n_structures_list:
structures = [base_structure] * n
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
ts.static(structures, model, autobatcher=autobatcher)
if device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
times.append(elapsed)
print(f" n={n} static_time={elapsed:.6f}s")
return times


def run_torchsim_relax(
n_structures_list: list[int],
base_structure: typing.Any,
model: MaceModel,
device: torch.device,
) -> list[float]:
"""Run relaxation with ts.optimize for each n; return timings."""
autobatcher = ts.InFlightAutoBatcher(
model=model,
max_memory_scaler=MAX_MEMORY_SCALER,
memory_scales_with=MEMORY_SCALES_WITH,
)
times: list[float] = []
for n in n_structures_list:
structures = [base_structure] * n
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
ts.optimize(
system=structures,
model=model,
optimizer=ts.optimizers.Optimizer.fire,
init_kwargs={
"cell_filter": ts.optimizers.cell_filters.CellFilter.frechet,
"constant_volume": False,
"hydrostatic_strain": True,
},
max_steps=RELAX_STEPS,
convergence_fn=ts.runners.generate_force_convergence_fn(
force_tol=1e-3,
include_cell_forces=True,
),
autobatcher=autobatcher,
)
if device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
times.append(elapsed)
print(f" n={n} relax_{RELAX_STEPS}_time={elapsed:.6f}s")
return times


def run_torchsim_nve(
n_structures_list: list[int],
base_structure: typing.Any,
model: MaceModel,
device: torch.device,
) -> list[float]:
"""Run NVE MD for MD_STEPS per n; return times."""
times: list[float] = []
for n in n_structures_list:
structures = [base_structure] * n
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
ts.integrate(
system=structures,
model=model,
integrator=ts.Integrator.nve,
n_steps=MD_STEPS,
temperature=300.0,
timestep=0.002,
autobatcher=ts.BinningAutoBatcher(
model=model,
max_memory_scaler=MAX_MEMORY_SCALER,
memory_scales_with=MEMORY_SCALES_WITH,
),
)
if device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
times.append(elapsed)
print(f" n={n} nve_time={elapsed:.6f}s")
return times


def run_torchsim_nvt(
n_structures_list: list[int],
base_structure: typing.Any,
model: MaceModel,
device: torch.device,
) -> list[float]:
"""Run NVT (Nose-Hoover) MD for MD_STEPS per n; return times."""
times: list[float] = []
for n in n_structures_list:
structures = [base_structure] * n
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
ts.integrate(
system=structures,
model=model,
integrator=ts.Integrator.nvt_nose_hoover,
n_steps=MD_STEPS,
temperature=300.0,
timestep=0.002,
autobatcher=ts.BinningAutoBatcher(
model=model,
max_memory_scaler=MAX_MEMORY_SCALER,
memory_scales_with=MEMORY_SCALES_WITH,
),
)
if device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
times.append(elapsed)
print(f" n={n} nvt_time={elapsed:.6f}s")
return times


if __name__ == "__main__":
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mgo_ase = bulk(name="MgO", crystalstructure="rocksalt", a=4.21, cubic=True)
base_structure = AseAtomsAdaptor.get_structure(atoms=mgo_ase)

# Load model once
model = load_mace_model(device)

# Run all benchmarks
print("=== Static benchmark ===")
static_times = run_torchsim_static(N_STRUCTURES_STATIC, base_structure, model, device)

print("\n=== Relax benchmark ===")
relax_times = run_torchsim_relax(N_STRUCTURES_RELAX, base_structure, model, device)

print("\n=== NVE benchmark ===")
nve_times = run_torchsim_nve(N_STRUCTURES_NVE, base_structure, model, device)

print("\n=== NVT benchmark ===")
nvt_times = run_torchsim_nvt(N_STRUCTURES_NVT, base_structure, model, device)
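
A minimal sketch (not part of the diff above) of how the per-n timings collected by this script could be reduced to per-structure throughput once all four benchmarks have run. The `summarize` helper and its formatting are illustrative only.

    def summarize(label: str, n_list: list[int], times: list[float]) -> None:
        """Print total and per-structure wall time for one benchmark."""
        for n, t in zip(n_list, times, strict=True):
            print(f"{label:>8} n={n:<6d} total={t:9.3f}s per_structure={t / n:9.4f}s")

    summarize("static", N_STRUCTURES_STATIC, static_times)
    summarize("relax", N_STRUCTURES_RELAX, relax_times)
    summarize("nve", N_STRUCTURES_NVE, nve_times)
    summarize("nvt", N_STRUCTURES_NVT, nvt_times)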
20 changes: 10 additions & 10 deletions examples/tutorials/autobatching_tutorial.py
@@ -50,7 +50,7 @@

# %%
import torch
-from torch_sim.autobatching import calculate_memory_scaler
+from torch_sim.autobatching import calculate_memory_scalers
from ase.build import bulk


@@ -63,14 +63,14 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64)

-# Calculate memory scaling factor based on atom count
-atom_metric = calculate_memory_scaler(state, memory_scales_with="n_atoms")
+# Calculate memory scaling factor based on atom count (returns list, one per system)
+atom_metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms")

# Calculate memory scaling based on atom count and density
-density_metric = calculate_memory_scaler(state, memory_scales_with="n_atoms_x_density")
+density_metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms_x_density")

print(f"Atom-based memory metric: {atom_metric}")
print(f"Density-based memory metric: {density_metric:.2f}")
print(f"Atom-based memory metrics: {atom_metrics}")
print(f"Density-based memory metrics: {[f'{m:.2f}' for m in density_metrics]}")


# %% [markdown]
@@ -95,11 +95,11 @@
mace = mace_mp(model="small", return_raw_model=True)
mace_model = MaceModel(model=mace, device=device)

-state_list = state.split()
-memory_metric_values = [
-    calculate_memory_scaler(s, memory_scales_with="n_atoms") for s in state_list
-]
+# calculate_memory_scalers returns a list with one value per system in the state
+memory_metric_values = calculate_memory_scalers(state, memory_scales_with="n_atoms")

+# estimate_max_memory_scaler needs a list of individual states
+state_list = state.split()
max_memory_metric = estimate_max_memory_scaler(
state_list, mace_model, metric_values=memory_metric_values
)
32 changes: 20 additions & 12 deletions tests/test_autobatching.py
@@ -7,7 +7,7 @@
from torch_sim.autobatching import (
BinningAutoBatcher,
InFlightAutoBatcher,
-    calculate_memory_scaler,
+    calculate_memory_scalers,
determine_max_batch_size,
to_constant_volume_bins,
)
@@ -93,31 +93,39 @@ def test_bounds_and_tuples():
def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None:
"""Test calculation of scaling metrics for a state."""
# Test n_atoms metric
-    n_atoms_metric = calculate_memory_scaler(si_sim_state, "n_atoms")
-    assert n_atoms_metric == si_sim_state.n_atoms
+    n_atoms_metric = calculate_memory_scalers(si_sim_state, "n_atoms")
+    assert n_atoms_metric == [si_sim_state.n_atoms]

# Test n_atoms_x_density metric
-    density_metric = calculate_memory_scaler(si_sim_state, "n_atoms_x_density")
+    density_metric = calculate_memory_scalers(si_sim_state, "n_atoms_x_density")
volume = torch.abs(torch.linalg.det(si_sim_state.cell[0])) / 1000
expected = si_sim_state.n_atoms * (si_sim_state.n_atoms / volume.item())
-    assert pytest.approx(density_metric, rel=1e-5) == expected
+    assert pytest.approx(density_metric[0], rel=1e-5) == expected

# Test invalid metric
with pytest.raises(ValueError, match="Invalid metric"):
-        calculate_memory_scaler(si_sim_state, "invalid_metric")
+        calculate_memory_scalers(si_sim_state, "invalid_metric")


def test_calculate_scaling_metric_non_periodic(benzene_sim_state: ts.SimState) -> None:
"""Test calculation of scaling metrics for a non-periodic state."""
# Test that calculate passes
-    n_atoms_metric = calculate_memory_scaler(benzene_sim_state, "n_atoms")
-    assert n_atoms_metric == benzene_sim_state.n_atoms
+    n_atoms_metric = calculate_memory_scalers(benzene_sim_state, "n_atoms")
+    assert n_atoms_metric == [benzene_sim_state.n_atoms]

# Test n_atoms_x_density metric works for non-periodic systems
-    n_atoms_x_density_metric = calculate_memory_scaler(
+    n_atoms_x_density_metric = calculate_memory_scalers(
benzene_sim_state, "n_atoms_x_density"
)
-    assert n_atoms_x_density_metric > 0
+    assert n_atoms_x_density_metric[0] > 0
+    bbox = (
+        benzene_sim_state.positions.max(dim=0).values
+        - benzene_sim_state.positions.min(dim=0).values
+    ).clone()
+    for i, p in enumerate(benzene_sim_state.pbc):
+        if not p:
+            bbox[i] += 2.0
+    assert pytest.approx(n_atoms_x_density_metric[0], rel=1e-5) == (
+        benzene_sim_state.n_atoms**2 / (bbox.prod().item() / 1000)
+    )


def test_split_state(si_double_sim_state: ts.SimState) -> None:
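
For downstream code still calling the old singular helper, the rename in this PR amounts to switching to the plural function and indexing the returned list. A minimal sketch based on the changes above; `state` stands for any SimState already built elsewhere:

    from torch_sim.autobatching import calculate_memory_scalers

    # One value per system in `state`; a single-system state yields a one-element list.
    metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms")
    first_metric = metrics[0]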