Accelerated ts.static + added scaling scripts #427
falletta wants to merge 9 commits into TorchSim:main from
Conversation
torch_sim/autobatching.py
Outdated
bbox[i] += 2.0
volume = bbox.prod() / 1000 # convert A^3 to nm^3
number_density = state.n_atoms / volume.item()
# Use cell volume (O(1)); SimState always has a cell. Avoids O(N) position scan.
non-periodic systems don't have a sensible cell, see #412
I now minimized the differences compared to the initial code
In addition, I added explicit tests for the memory scaler values and verified that the changes in this PR do not affect the test’s success
torch_sim/autobatching.py
Outdated
self.memory_scalers = calculate_batched_memory_scalers(
states, self.memory_scales_with
)
self.state_slices = states.split()
batching makes sense here
torch_sim/autobatching.py
Outdated
if isinstance(states, SimState):
self.batched_states = [[states[index_bin]] for index_bin in self.index_bins]
state.split() is identical to this and faster
Reusing self.state_slices instead of calling states.split() again makes the code 5% faster, so I'd keep it
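Schematically, the reuse being argued for looks something like the sketch below (attribute names follow the snippets above; the exact shape of batched_states in the final code is an assumption):

# Inside the batcher's __init__ (hypothetical placement): split once, reuse after.
self.state_slices = states.split()  # one O(n_systems) split, cached on the instance
self.batched_states = [
    [self.state_slices[idx] for idx in index_bin]  # index the cache, no second split()
    for index_bin in self.index_bins
]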
torch_sim/autobatching.py
Outdated
)
self.state_slices = states.split()
else:
self.state_slices = states
why not concat and then call the batched logic?
In the if branch, the input is already a single batched SimState, so we call calculate_batched_memory_scalers and then split() once to get state_slices. No concatenation is needed.
In the else branch, the input is a list of states, so we keep state_slices = states and compute scalers per state. We avoid concatenating and using the batched path, since that would require a concat followed by a split(), resulting in extra passes and higher peak memory for the same outcome.
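Schematically, the two paths described above look like the following sketch (not the PR code; the per-state scaler call in the else branch is an assumption):

# Hypothetical sketch of the branching described above.
if isinstance(states, SimState):
    # Already a batched state: one vectorized scaler pass, then a single split.
    self.memory_scalers = calculate_batched_memory_scalers(
        states, self.memory_scales_with
    )
    self.state_slices = states.split()
else:
    # A list of states: keep them as-is and compute one scaler per state,
    # avoiding a concat followed by a split of the same data.
    self.state_slices = list(states)
    self.memory_scalers = [
        calculate_memory_scaler(s, self.memory_scales_with) for s in states
    ]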
torch_sim/state.py
Outdated
def split(self) -> list[Self]:
"""Split the SimState into a list of single-system SimStates.
def split(self) -> Sequence[Self]: # noqa: C901
this whole code block looks hard to understand/maintain. why did the external _split_state functional programming method need to be removed?
This looks like it just creates an efficient slicing iter, can we break those parts out more cleanly in the same functional pattern as before? i.e. _get_system_slice(sim_state: SimState, i: int) as a function
Ok I got rid of the intermediate class and made the code more functional
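As a rough illustration of that functional pattern (a sketch only: the helper name comes from the suggestion above, and the attribute set passed to the constructor is simplified and assumed):

import torch

from torch_sim.state import SimState


def _get_system_slice(sim_state: SimState, i: int) -> SimState:
    """Build the i-th single-system state on demand (hypothetical, simplified)."""
    mask = sim_state.system_idx == i
    return type(sim_state)(
        positions=sim_state.positions[mask],
        masses=sim_state.masses[mask],
        atomic_numbers=sim_state.atomic_numbers[mask],
        cell=sim_state.cell[i : i + 1],
        pbc=sim_state.pbc,
        system_idx=torch.zeros(
            int(mask.sum()), device=sim_state.device, dtype=torch.int64
        ),
    )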
I’ve implemented the revisions and added an additional optimization to the state creation. Please let me know if you need any further edits.
I will give another review in the next few days, thanks @Fallett!
Thank you @orionarcher! I made further edits to follow the original logic more closely. Among the changes, I generalized the existing
orionarcher left a comment
Thanks @falletta! I am happy to have the batched memory scaler calculations and associated speedup!
A few thoughts:
- The batched memory scaler calculation seems quite reasonable and is a great addition to this PR.
- The changes to state manipulation don't actually seem to save any time, they just shift from eager to lazy evaluation. In the absence of concrete benchmarking data about those changes specifically, I'd favor removing the changes.
- There are several changes that seem arbitrary and unrelated to the main purpose of the PR. I generally support vibe coded contributions but it shouldn't be just the reviewers' responsibility to catch random unrelated changes. For example, you introduce a breaking change to the API of initialize_states for no apparent reason. This seems to get past the tests but could easily disrupt downstream behavior.
examples/scaling/scaling_nve.py
Outdated
@@ -0,0 +1,76 @@
"""Scaling for TorchSim NVE."""
These will all run in CI, so maybe we could put them all in a single script instead of four to speed up testing?
Actually, these tests should not run in CI given the changes in .github/workflows/test.yml. In their current form, these are not really intended as tests (there are no assert statements), but rather as performance checks. I think keeping them separate makes it easier to isolate the scaling behavior we want to examine for potential optimizations.
That said, we could turn these into actual tests and include them in the tests repository, for example by checking that the timings stay below the current thresholds. Would you prefer that approach?
I do worry about these decaying over time if they aren't tested. The reason to run all scripts in CI is because it makes sure they don't break over time. We shouldn't include them as tests but they should be runnable scripts like the other scripts we have.
Right, good point. I merged all the tests into examples/scripts/8.scaling.py and set the maximum number of structures to relatively small values so the test runs quickly.
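For context, the timing harness in such a merged script can stay quite small; a minimal, self-contained sketch of the pattern (not the actual contents of 8.scaling.py):

import time
from collections.abc import Callable, Sequence


def time_scaling(
    label: str, run: Callable[[Sequence], object], batches: dict[int, Sequence]
) -> None:
    """Time one workflow over increasing batch sizes and print the scaling."""
    for n, structures in batches.items():
        t0 = time.perf_counter()
        run(structures)
        print(f"{label}: n={n} took {time.perf_counter() - t0:.3f}s")

Each workflow (ts.static, relaxation, NVE/NVT integration) would then be passed in as run, with the structure counts kept small so the script stays fast when exercised in CI.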
tests/test_autobatching.py
Outdated
n_atoms_metric = calculate_memory_scaler(si_sim_state, "n_atoms")
assert n_atoms_metric == si_sim_state.n_atoms
assert n_atoms_metric == [si_sim_state.n_atoms]
assert n_atoms_metric == [8]
let's get rid of all the hardcoded values in the tests, it just makes them more brittle
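For example, the assertion can be phrased against the fixture itself (names taken from the diff above):

n_atoms_metric = calculate_memory_scaler(si_sim_state, "n_atoms")
# Compare against the fixture rather than a hardcoded 8.
assert n_atoms_metric == [si_sim_state.n_atoms]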
torch_sim/autobatching.py
Outdated
volume = bbox.prod() / 1000
scalers.append(s.n_atoms * (s.n_atoms / volume.item()))
return scalers
raise ValueError(f"Invalid metric: {memory_scales_with}")
the new error value just strips important context, why change it?
torch_sim/autobatching.py
Outdated
vol = torch.abs(state.volume) / 1000 # A^3 -> nm^3
return torch.where(vol > 0, n * n / vol, n).tolist()
why is the variable volume elsewhere but vol here?
would also change n -> n_atoms
torch_sim/autobatching.py
Outdated
@@ -325,58 +325,39 @@ def determine_max_batch_size(
def calculate_memory_scaler(
I suggest changing the name to calculate_memory_scalers since this now returns a list
n_systems_val = 1
else: # assert that system indices are unique consecutive integers
_, counts = torch.unique_consecutive(initial_system_idx, return_counts=True)
n_systems_val = len(counts)
if not torch.all(counts == torch.bincount(initial_system_idx)):
raise ValueError("System indices must be unique consecutive integers")

if self.constraints:
validate_constraints(self.constraints, state=self)

if self.charge is None:
self.charge = torch.zeros(
self.n_systems, device=self.device, dtype=self.dtype
)
elif self.charge.shape[0] != self.n_systems:
raise ValueError(f"Charge must have shape (n_systems={self.n_systems},)")
self.charge = torch.zeros(n_systems_val, device=self.device, dtype=self.dtype)
elif self.charge.shape[0] != n_systems_val:
raise ValueError(f"Charge must have shape (n_systems={n_systems_val},)")
if self.spin is None:
self.spin = torch.zeros(self.n_systems, device=self.device, dtype=self.dtype)
why is any of this logic changed? seems unrelated to rest of PR
The n_systems property calls torch.unique() every time it's accessed. The old code called this property 4-5 times during __post_init__. The new code computes the value once into a local variable (n_systems_val) and reuses it, avoiding redundant unique() calls. This speeds up state initialization, especially for large batches. It’s a minor optimization, but I think it's still relevant.
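A condensed sketch of that pattern (the surrounding __post_init__ logic is simplified; variable names follow the diff above):

# Inside SimState.__post_init__ (condensed): compute the system count once,
# since the n_systems property calls torch.unique() on every access.
if initial_system_idx is None:
    n_systems_val = 1
else:
    _, counts = torch.unique_consecutive(initial_system_idx, return_counts=True)
    n_systems_val = len(counts)

# Reuse the cached value instead of re-deriving self.n_systems each time.
if self.charge is None:
    self.charge = torch.zeros(n_systems_val, device=self.device, dtype=self.dtype)
if self.spin is None:
    self.spin = torch.zeros(n_systems_val, device=self.device, dtype=self.dtype)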
def _split_state[T: SimState](state: T) -> Sequence[T]:
"""Return a lazy Sequence view of state split into single-system states.

Each single-system state is created on first access, so the call is O(1).
I guess I don't really follow how this is faster. It just shifts the cost to later down the line by making the evaluation lazy instead of eager, which isn't necessarily better given the code is harder to follow. Could you explain how this is saving time?
The eager implementation creates all N SimState objects upfront by running torch.split() across every attribute, which is O(N) work regardless of how many states are actually used. The lazy implementation only computes a cumsum (effectively O(1)) and defers creating states until __getitem__ is called.
This matters because estimate_max_memory_scaler only accesses 2 states (argmin/argmax) out of potentially hundreds—so lazy evaluation builds 2 states, while eager still builds all N.
Benchmarking with n=10000 systems shows the lazy version finishes in 9.89 s vs 11.37 s for eager (~15% faster). If all states are accessed, the total work is basically the same, but the common autobatching paths don’t require all states to be materialized at once.
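A rough sketch of what such a lazy view can look like (this is not the PR's implementation; _get_system_slice here is the hypothetical helper sketched earlier in this thread):

from collections.abc import Sequence

from torch_sim.state import SimState


class _LazySplitView(Sequence):
    """Sequence over a batched state whose items are built on first access."""

    def __init__(self, state: SimState) -> None:
        self._state = state  # no per-system tensors are created here

    def __len__(self) -> int:
        return self._state.n_systems

    def __getitem__(self, i: int) -> SimState:
        # All per-system work is deferred to this point.
        return _get_system_slice(self._state, i)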
This matters because estimate_max_memory_scaler only accesses 2 states (argmin/argmax) out of potentially hundreds—so lazy evaluation builds 2 states, while eager still builds all N.
I don't follow? How can you get the argmin over values that haven't been evaluated? Don't you need to evaluate to know the minimum?
Could you point me to where the eager and lazy are benchmarked against each other?
Please find the benchmark script below. The script defines an OldState(SimState) class using the old _split_state method and produces the results shown. Alternatively, you can run 8.scaling.py with the current version of the code and with the split method reverted, and you’ll get the same results.
=== Comparison ===
n= 1: eager=2.387301s, lazy=0.026688s, speedup=89.45x
n= 1: eager=0.275248s, lazy=0.022556s, speedup=12.20x
n= 1: eager=0.677568s, lazy=0.022692s, speedup=29.86x
n= 1: eager=0.023758s, lazy=0.021832s, speedup=1.09x
n= 10: eager=0.283295s, lazy=0.025066s, speedup=11.30x
n= 100: eager=0.700301s, lazy=0.108458s, speedup=6.46x
n= 250: eager=1.020541s, lazy=0.252591s, speedup=4.04x
n= 500: eager=0.572549s, lazy=0.506986s, speedup=1.13x
n=1000: eager=1.115220s, lazy=0.989014s, speedup=1.13x
n=1500: eager=1.665096s, lazy=1.469927s, speedup=1.13x
n=5000: eager=5.755523s, lazy=5.062751s, speedup=1.14x
n=10000: eager=10.948717s, lazy=9.643398s, speedup=1.14x
Benchmark Script
"""Test comparing OldState (eager split) vs State (lazy split) performance."""
import time
import typing
from typing import Self
from unittest.mock import patch
import torch
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp
import torch_sim as ts
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.state import SimState, get_attrs_for_scope
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float64
N_STRUCTURES = [1, 1, 1, 1, 10, 100, 250, 500, 1000, 1500, 5000, 10000]
MAX_MEMORY_SCALER = 400_000
MEMORY_SCALES_WITH = "n_atoms_x_density"
class OldState(SimState):
"""Old state representation that uses eager splitting.
This class inherits from SimState but overrides split() to use the old eager
approach where all sub-states are created upfront when split() is called.
"""
def split(self) -> list[Self]:
"""Split the OldState into a list of single-system OldStates (EAGER).
This is the OLD approach that creates ALL states upfront, which is O(n)
where n is the number of systems.
Returns:
list[OldState]: A list of OldState objects, one per system
"""
return self._split_state(self)
@staticmethod
def _split_state[T: SimState](state: T) -> list[T]:
"""Split a SimState into a list of states, each containing a single system.
Divides a multi-system state into individual single-system states, preserving
appropriate properties for each system.
Args:
state (SimState): The SimState to split
Returns:
list[SimState]: A list of SimState objects, each containing a single
system
"""
system_sizes = state.n_atoms_per_system.tolist()
split_per_atom = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"):
if attr_name != "system_idx":
split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0)
split_per_system = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-system"):
if isinstance(attr_value, torch.Tensor):
split_per_system[attr_name] = torch.split(attr_value, 1, dim=0)
else: # Non-tensor attributes are replicated for each split
split_per_system[attr_name] = [attr_value] * state.n_systems
global_attrs = dict(get_attrs_for_scope(state, "global"))
# Create a state for each system
states: list[T] = []
n_systems = len(system_sizes)
zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64)
cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0)))
for sys_idx in range(n_systems):
# Build per-system attributes (padded attributes stay padded for consistency)
per_system_dict = {
attr_name: split_per_system[attr_name][sys_idx]
for attr_name in split_per_system
}
system_attrs = {
# Create a system tensor with all zeros for this system
"system_idx": torch.zeros(
system_sizes[sys_idx], device=state.device, dtype=torch.int64
),
# Add the split per-atom attributes
**{
attr_name: split_per_atom[attr_name][sys_idx]
for attr_name in split_per_atom
},
# Add the split per-system attributes (with unpadding applied)
**per_system_dict,
# Add the global attributes
**global_attrs,
}
atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1])
new_constraints = [
new_constraint
for constraint in state.constraints
if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx))
]
system_attrs["_constraints"] = new_constraints
states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type]
return states
@classmethod
def from_sim_state(cls, state: SimState) -> "OldState":
"""Create OldState from a SimState."""
return cls(
positions=state.positions.clone(),
masses=state.masses.clone(),
cell=state.cell.clone(),
pbc=state.pbc.clone() if isinstance(state.pbc, torch.Tensor) else state.pbc,
atomic_numbers=state.atomic_numbers.clone(),
charge=state.charge.clone() if state.charge is not None else None,
spin=state.spin.clone() if state.spin is not None else None,
system_idx=state.system_idx.clone() if state.system_idx is not None else None,
_constraints=state.constraints.copy(),
)
def run_torchsim_static(
n_structures_list: list[int],
base_structure,
model,
device: torch.device,
) -> list[float]:
"""Run static calculations for each n using batched path, return timings."""
autobatcher = ts.BinningAutoBatcher(
model=model,
max_memory_scaler=MAX_MEMORY_SCALER,
memory_scales_with=MEMORY_SCALES_WITH,
)
times: list[float] = []
for n in n_structures_list:
structures = [base_structure] * n
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
ts.static(structures, model, autobatcher=autobatcher)
if device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
times.append(elapsed)
print(f" n={n} static_time={elapsed:.6f}s")
return times
if __name__ == "__main__":
# Setup
mgo_atoms = bulk("MgO", crystalstructure="rocksalt", a=4.21, cubic=True)
print("Loading MACE model...")
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
return_raw_model=True,
default_dtype="float64",
device=str(DEVICE),
)
mace_model = MaceModel(
model=typing.cast("torch.nn.Module", loaded_model),
device=DEVICE,
compute_forces=True,
compute_stress=True,
dtype=DTYPE,
enable_cueq=False,
)
print("\n=== Static Benchmark Comparison: Eager vs Lazy Split ===\n")
# Run with eager split (OldState behavior)
print("Running with eager split (OldState):")
with patch.object(SimState, "split", lambda self: OldState._split_state(self)):
eager_times = run_torchsim_static(
N_STRUCTURES, mgo_atoms, mace_model, DEVICE
)
# Run with lazy split (default State behavior)
print("\nRunning with lazy split (State):")
lazy_times = run_torchsim_static(
N_STRUCTURES, mgo_atoms, mace_model, DEVICE
)
# Print comparison
print("\n=== Comparison ===")
for i, n in enumerate(N_STRUCTURES):
speedup = eager_times[i] / max(lazy_times[i], 1e-9)
print(
f" n={n:4d}: eager={eager_times[i]:.6f}s, "
f"lazy={lazy_times[i]:.6f}s, speedup={speedup:.2f}x"
)
The argmin runs over metric_values, which is just a list of floats computed from the batched state tensors before we split anything. This gives us integer indices, and only when we call state_list[idx] do we actually create the 2 needed states.
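Schematically (variable names assumed for illustration):

# metric_values is a plain list of floats computed from the batched tensors.
min_idx = min(range(len(metric_values)), key=metric_values.__getitem__)
max_idx = max(range(len(metric_values)), key=metric_values.__getitem__)
# Only these two accesses materialize single-system states from the lazy view.
smallest_state = state_list[min_idx]
largest_state = state_list[max_idx]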
torch_sim/state.py
Outdated
all relevant properties.
their requested order (not natural 0,1,2 order).
That refers to the fact that calling _slice_state(state, [3, 1, 4]) returns a new state where the systems appear in the exact order requested (3→0, 1→1, 4→2), instead of being sorted in ascending order. I updated the docstring to better reflect this point.
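For example (illustrative only):

sub_state = _slice_state(state, [3, 1, 4])
# Systems appear in the requested order: old 3 -> new 0, old 1 -> new 1, old 4 -> new 2.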
device: torch.device | None = None,
dtype: torch.dtype | None = None,
device: torch.device,
dtype: torch.dtype,
This is a minor fix, but I think making device and dtype required improves type checking and forces callers to be explicit. With None defaults, .to(None, None) is a no-op, so the state silently stays where it was, potentially causing device mismatches with the model later.
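The failure mode being avoided, schematically (the .to() call follows the comment above; the model call is illustrative):

# With optional device/dtype, this is a silent no-op: the state stays where it was.
state = state.to(None, None)
results = model(state)  # may fail later with a device/dtype mismatch

# With required arguments, the caller must be explicit:
state = state.to(torch.device("cuda"), torch.float64)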
return keep_state, pop_states


def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor) -> T:
the rewrite of this function adds a lot of code and it's unclear it makes anything faster
Thanks Orion for reviewing. I applied a few cosmetic changes to address your comments. Regarding your points:
Summary
Changes
Results
The figure below shows the speedup achieved for static evaluations, 10-step atomic relaxation, 10-step NVE MD, and 10-step NVT MD. The test is performed for an 8-atom cubic supercell of MgO using the mace-mpa model. Prior results are shown in blue, while new results are shown in red. The speedup is calculated as speedup (%) = (baseline_time / current_time − 1) × 100. We observe that:
- ts.static achieves a 52.6% speedup for 100,000 structures
- ts.relax achieves a 4.8% speedup for 1,500 structures
- ts.integrate (NVE) achieves a 0.9% speedup for 10,000 structures
- ts.integrate (NVT) achieves a 1.4% speedup for 10,000 structures
Profiling
The figure below shows a detailed performance profile. Additional optimization can be achieved by disabling the trajectory reporter when not needed, which will be addressed in a separate PR.
Comments
From the scaling plots, we can see that the timings of ts.static and ts.integrate are all consistent with each other. Indeed:
- ts.static → 85s for 100'000 evaluations
- ts.integrate NVE → 87s for 10'000 structures (10 MD steps each) → 87s for 100'000 evaluations
- ts.integrate NVT → 89s for 10'000 structures (10 MD steps each) → 89s for 100'000 evaluations
However, when looking at the relaxation:
- ts.relax → 63s for 1'000 structures (10 relax steps each) → 63s for 10'000 evaluations → ~630s for 100'000 evaluations
So ts.relax is about 7x slower than ts.static or ts.integrate. The unbatched FrechetCellFilter clearly contributes to that, and will be the focus of a separate PR.