Accelerated ts.static + added scaling scripts #427

Open

falletta wants to merge 9 commits into TorchSim:main from falletta:speedup_static

Conversation

@falletta (Contributor) commented Jan 30, 2026

Summary


Changes

  • Optimized the memory scaler and split operations, providing a substantial speedup for ts.static (up to 48% for batches larger than 5000); a minimal sketch of the batched-scaler idea follows this list.
  • Added scaling scripts for ts.static, ts.relax, and ts.integrate (NVE, NVT) to analyze scaling performance.
  • Added tests for memory scaler values for non-periodic systems.
  • Updated unit tests and tutorials for the batched version of calculate_memory_scaler.
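
For context, here is a minimal sketch of the batched-scaler idea, assuming the SimState properties n_atoms_per_system and volume referenced elsewhere in this PR; batched_memory_scalers is an illustrative name, not necessarily the exact function added here:

import torch

from torch_sim.state import SimState


def batched_memory_scalers(state: SimState, memory_scales_with: str = "n_atoms") -> list[float]:
    """Compute one memory scaler per system from the batched tensors (sketch only)."""
    n_atoms = state.n_atoms_per_system.to(torch.float64)  # shape: (n_systems,)
    if memory_scales_with == "n_atoms":
        return n_atoms.tolist()
    if memory_scales_with == "n_atoms_x_density":
        volume = torch.abs(state.volume) / 1000  # A^3 -> nm^3, one value per system
        # fall back to n_atoms for degenerate cells (e.g. non-periodic systems)
        return torch.where(volume > 0, n_atoms * n_atoms / volume, n_atoms).tolist()
    raise ValueError(f"Unknown memory_scales_with metric: {memory_scales_with!r}")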

Results
The figure below shows the speedup achieved for static evaluations, 10-step atomic relaxation, 10-step NVE MD, and 10-step NVT MD. The test is performed for an 8-atom cubic supercell of MgO using the mace-mpa model. Prior results are shown in blue, while new results are shown in red. The speedup is calculated as
speedup (%) = (baseline_time / current_time − 1) × 100. We observe that:

  • ts.static achieves a 52.6% speedup for 100,000 structures
  • ts.relax achieves a 4.8% speedup for 1,500 structures
  • ts.integrate (NVE) achieves a 0.9% speedup for 10,000 structures
  • ts.integrate (NVT) achieves a 1.4% speedup for 10,000 structures
[Figure: speedup vs. number of structures for ts.static, ts.relax, and ts.integrate (NVE, NVT); prior results in blue, this PR in red]

Profiling
The figure below shows a detailed performance profile. Additional optimization can be achieved by disabling the trajectory reporter when not needed, which will be addressed in a separate PR.

[Figure: detailed performance profile of ts.static]

Comments

From the scaling plots, we can see that the timings of ts.static and ts.integrate are all consistent with each other. Indeed:

  • ts.static → 85s for 100,000 evaluations
  • ts.integrate NVE → 87s for 10,000 structures (10 MD steps each) → 87s for 100,000 evaluations
  • ts.integrate NVT → 89s for 10,000 structures (10 MD steps each) → 89s for 100,000 evaluations

However, when looking at the relaxation:

  • ts.relax → 63s for 1,000 structures (10 relax steps each) → 63s for 10,000 evaluations → ~630s for 100,000 evaluations

So ts.relax is about 7x slower than ts.static or ts.integrate. The unbatched FrechetCellFilter clearly contributes to that, and will be the focus of a separate PR.

@falletta falletta marked this pull request as draft January 30, 2026 01:22
bbox[i] += 2.0
volume = bbox.prod() / 1000 # convert A^3 to nm^3
number_density = state.n_atoms / volume.item()
# Use cell volume (O(1)); SimState always has a cell. Avoids O(N) position scan.
Collaborator:

non-periodic systems don't have a sensible cell, see #412

Contributor Author:

I now minimized the differences compared to the initial code

Contributor Author:

In addition, I added explicit tests for the memory scaler values and verified that the changes in this PR do not affect the test’s success

Comment on lines 589 to 598
self.memory_scalers = calculate_batched_memory_scalers(
states, self.memory_scales_with
)
self.state_slices = states.split()
Collaborator:

batching makes sense here

Comment on lines 628 to 635
if isinstance(states, SimState):
self.batched_states = [[states[index_bin]] for index_bin in self.index_bins]
Collaborator:

state.split() is identical to this and faster

Contributor Author:

Reusing self.state_slices instead of calling states.split() again makes the code 5% faster, so I'd keep it

@falletta falletta marked this pull request as ready for review January 30, 2026 15:42
)
self.state_slices = states.split()
else:
self.state_slices = states
Member:

why not concat and then call the batched logic?

Contributor Author:

In the if branch, the input is already a single batched SimState, so we call calculate_batched_memory_scalers and then split() once to get state_slices. No concatenation is needed.

In the else branch, the input is a list of states, so we keep state_slices = states and compute scalers per state. We avoid concatenating and using the batched path, since that would require a concat followed by a split(), resulting in extra passes and higher peak memory for the same outcome.
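
Roughly, the branch described above looks like the following. This is a sketch based on the diff hunks quoted in this thread, not the exact PR code: _slices_and_scalers is a hypothetical helper name, calculate_memory_scaler is the batched version discussed later (it returns a list), and its import path is assumed.

from collections.abc import Sequence

from torch_sim.autobatching import calculate_memory_scaler
from torch_sim.state import SimState


def _slices_and_scalers(
    states: SimState | list[SimState], memory_scales_with: str
) -> tuple[Sequence[SimState], list[float]]:
    """Sketch of the two code paths discussed above (not the exact PR code)."""
    if isinstance(states, SimState):
        # single batched SimState: one batched scaler pass, one split()
        scalers = calculate_memory_scaler(states, memory_scales_with)
        return states.split(), scalers
    # already a list of states: keep the slices and compute scalers per state,
    # avoiding the concat + split round trip that the batched path would need
    scalers = [calculate_memory_scaler(s, memory_scales_with)[0] for s in states]
    return states, scalers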


def split(self) -> list[Self]:
"""Split the SimState into a list of single-system SimStates.
def split(self) -> Sequence[Self]: # noqa: C901
Member:

this whole code block looks hard to understand/maintain. why did the external _split_state functional programming method need to be removed?

This looks like it just creates an efficient slicing iter, can we break those parts out more cleanly in the same functional pattern as before? i.e. _get_system_slice(sim_state: SimState, i: int) as a function

Contributor Author (@falletta, Feb 2, 2026):

Ok I got rid of the intermediate class and made the code more functional

@falletta falletta force-pushed the speedup_static branch 2 times, most recently from a260cc0 to 2ba34e0 Compare February 2, 2026 16:23
@falletta (Contributor Author) commented Feb 2, 2026

I’ve implemented the revisions and added an additional optimization to the state creation. Please let me know if you need any further edits.

@orionarcher (Collaborator):

I will give another review in the next few days, thanks @falletta!

@falletta (Contributor Author) commented Feb 3, 2026

Thank you @orionarcher! I made further edits to follow the original logic more closely. Among the changes, I generalized the existing calculate_memory_scaler function to a batched version that now returns a list of floats. This required adjusting the unit tests and tutorials, but I think it's a cleaner approach that avoids having two subroutines for the memory scaler and improves type checking. In addition, I tried to mimic the original split logic as closely as possible; the _split_state subroutine, however, still requires substantial changes. Let me know your thoughts, and I'm happy to keep revising the PR if necessary.

@orionarcher (Collaborator) left a comment:

Thanks @falletta! I am happy to have the batched memory scaler calculations and associated speedup!

A few thoughts:

  1. The batched memory scaler calculation seems quite reasonable and is a great addition to this PR.
  2. The changes to state manipulation don't actually seem to save any time, they just shift from eager to lazy evaluation. In the absence of concrete benchmarking data about those changes specifically, I'd favor removing the changes.
  3. There are several changes that seem arbitrary and unrelated to the main purpose of the PR. I generally support vibe-coded contributions, but it shouldn't be just the reviewers' responsibility to catch random unrelated changes. For example, you introduce a breaking change to the API of initialize_states for no apparent reason. This seems to get past the tests but could easily disrupt downstream behavior.

@@ -0,0 +1,76 @@
"""Scaling for TorchSim NVE."""
Collaborator:

These will all run in CI, so maybe we could put them all in a single script instead of four to speed up testing?

Contributor Author:

Actually, these tests should not run in CI given the changes in .github/workflows/test.yml. In their current form, these are not really intended as tests (there are no assert statements), but rather as performance checks. I think keeping them separate makes it easier to isolate the scaling behavior we want to examine for potential optimizations.

That said, we could turn these into actual tests and include them in the tests repository, for example by checking that the timings stay below the current thresholds. Would you prefer that approach?
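
For illustration, a threshold test of the kind floated above could look like the sketch below. This is hypothetical: the budget, the fixtures (small_batched_state, lj_model), and the marker are made up, and this approach was not adopted in the PR.

import time

import pytest

import torch_sim as ts


STATIC_TIME_BUDGET_S = 5.0  # assumed budget for a small batch on CI hardware


@pytest.mark.slow
def test_static_timing_budget(small_batched_state, lj_model):
    """Fail if a small batched static run regresses past the assumed budget."""
    start = time.perf_counter()
    ts.static(small_batched_state, lj_model)
    assert time.perf_counter() - start < STATIC_TIME_BUDGET_S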

Collaborator:

I do worry about these decaying over time if they aren't tested. The reason to run all scripts in CI is because it makes sure they don't break over time. We shouldn't include them as tests but they should be runnable scripts like the other scripts we have.

Contributor Author:

Right, good point. I merged all the tests into examples/scripts/8.scaling.py and set the maximum number of structures to relatively small values so the test runs quickly.

n_atoms_metric = calculate_memory_scaler(si_sim_state, "n_atoms")
assert n_atoms_metric == si_sim_state.n_atoms
assert n_atoms_metric == [si_sim_state.n_atoms]
assert n_atoms_metric == [8]
Collaborator:

let's get rid of all the hardcoded values in the tests, it just makes them more brittle

Contributor Author:

sure!

volume = bbox.prod() / 1000
scalers.append(s.n_atoms * (s.n_atoms / volume.item()))
return scalers
raise ValueError(f"Invalid metric: {memory_scales_with}")
Collaborator:

the new error value just strips important context, why change it?

Contributor Author:

sure!

Comment on lines 345 to 346
vol = torch.abs(state.volume) / 1000 # A^3 -> nm^3
return torch.where(vol > 0, n * n / vol, n).tolist()
Collaborator:

why is the variable volume elsewhere but vol here?

would also change n -> n_atoms

Contributor Author:

sure!

@@ -325,58 +325,39 @@ def determine_max_batch_size(
def calculate_memory_scaler(
Collaborator:

I suggest changing the name to calculate_memory_scalers since this now returns a list

Contributor Author:

sure!

Comment on lines +140 to -155
n_systems_val = 1
else: # assert that system indices are unique consecutive integers
_, counts = torch.unique_consecutive(initial_system_idx, return_counts=True)
n_systems_val = len(counts)
if not torch.all(counts == torch.bincount(initial_system_idx)):
raise ValueError("System indices must be unique consecutive integers")

if self.constraints:
validate_constraints(self.constraints, state=self)

if self.charge is None:
self.charge = torch.zeros(
self.n_systems, device=self.device, dtype=self.dtype
)
elif self.charge.shape[0] != self.n_systems:
raise ValueError(f"Charge must have shape (n_systems={self.n_systems},)")
self.charge = torch.zeros(n_systems_val, device=self.device, dtype=self.dtype)
elif self.charge.shape[0] != n_systems_val:
raise ValueError(f"Charge must have shape (n_systems={n_systems_val},)")
if self.spin is None:
self.spin = torch.zeros(self.n_systems, device=self.device, dtype=self.dtype)
Collaborator:

why is any of this logic changed? seems unrelated to rest of PR

Contributor Author:

The n_systems property calls torch.unique() every time it's accessed. The old code called this property 4-5 times during __post_init__. The new code computes the value once, stores it in a local variable (n_systems_val), and reuses it, avoiding redundant unique() calls. This speeds up state initialization, especially for large batches. It's a minor optimization, but I think it's still relevant.
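
A toy illustration of that caching pattern (not the real SimState; in the actual diff, n_systems_val is derived from the unique_consecutive counts already computed for validation):

import torch


class DemoState:
    """Toy stand-in for SimState, only to show the property-caching pattern."""

    def __init__(self, system_idx: torch.Tensor) -> None:
        self.system_idx = system_idx

    @property
    def n_systems(self) -> int:
        # mirrors the real property: a torch.unique scan on every access
        return int(torch.unique(self.system_idx).numel())

    def post_init_old(self) -> None:
        # old pattern: each property access re-runs the unique() scan
        self.charge = torch.zeros(self.n_systems)
        self.spin = torch.zeros(self.n_systems)

    def post_init_new(self) -> None:
        # new pattern: evaluate the property once and reuse the local value
        n_systems_val = self.n_systems
        self.charge = torch.zeros(n_systems_val)
        self.spin = torch.zeros(n_systems_val)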

Comment on lines +648 to +651
def _split_state[T: SimState](state: T) -> Sequence[T]:
"""Return a lazy Sequence view of state split into single-system states.

Each single-system state is created on first access, so the call is O(1).
Collaborator:

I guess I don't really follow how this is faster. It just shifts the cost to later down the line by making the evaluation lazy instead of eager, which isn't necessarily better given the code is harder to follow. Could you explain how this is saving time?

Contributor Author:

The eager implementation creates all N SimState objects upfront by running torch.split() across every attribute, which is O(N) work regardless of how many states are actually used. The lazy implementation only computes a cumsum (effectively O(1)) and defers creating states until __getitem__ is called.

This matters because estimate_max_memory_scaler only accesses 2 states (argmin/argmax) out of potentially hundreds—so lazy evaluation builds 2 states, while eager still builds all N.

Benchmarking with n=10000 systems shows the lazy version finishes in 9.89 s vs 11.37 s for eager (~15% faster). If all states are accessed, the total work is basically the same, but the common autobatching paths don’t require all states to be materialized at once.
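
A minimal sketch of the lazy pattern being described, for illustration only: the PR implements it functionally via _split_state, and this _LazySplit class delegates to SimState indexing rather than slicing each attribute directly.

from collections.abc import Sequence

import torch

from torch_sim.state import SimState


class _LazySplit(Sequence):
    """Lazy view of a batched state as a sequence of single-system states."""

    def __init__(self, state: SimState) -> None:
        self._state = state
        # O(n_systems) cumsum over atom counts; no sub-state is built yet
        zero = torch.zeros(1, dtype=torch.int64, device=state.device)
        self._offsets = torch.cat((zero, torch.cumsum(state.n_atoms_per_system, dim=0)))

    def __len__(self) -> int:
        return len(self._offsets) - 1

    def __getitem__(self, idx: int) -> SimState:
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # a single system is materialized only on access; the real code slices the
        # per-atom tensors with self._offsets[idx]:self._offsets[idx + 1], while
        # this sketch simply delegates to SimState indexing
        return self._state[[idx]]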

Collaborator:

This matters because estimate_max_memory_scaler only accesses 2 states (argmin/argmax) out of potentially hundreds—so lazy evaluation builds 2 states, while eager still builds all N.

I don't follow? How can you get the argmin over values that haven't been evaluated? Don't you need to evaluate to know the minimum?

Collaborator:

Could you point me to where the eager and lazy are benchmarked against each other?

Contributor Author:

Please find the benchmark script below. The script defines an OldState(SimState) class using the old _split_state method and produces the results shown. Alternatively, you can run 8.scaling.py with the current version of the code and with the split method reverted, and you’ll get the same results.

=== Comparison ===
  n=   1: eager=2.387301s, lazy=0.026688s, speedup=89.45x
  n=   1: eager=0.275248s, lazy=0.022556s, speedup=12.20x
  n=   1: eager=0.677568s, lazy=0.022692s, speedup=29.86x
  n=   1: eager=0.023758s, lazy=0.021832s, speedup=1.09x
  n=  10: eager=0.283295s, lazy=0.025066s, speedup=11.30x
  n= 100: eager=0.700301s, lazy=0.108458s, speedup=6.46x
  n= 250: eager=1.020541s, lazy=0.252591s, speedup=4.04x
  n= 500: eager=0.572549s, lazy=0.506986s, speedup=1.13x
  n=1000: eager=1.115220s, lazy=0.989014s, speedup=1.13x
  n=1500: eager=1.665096s, lazy=1.469927s, speedup=1.13x
  n=5000: eager=5.755523s, lazy=5.062751s, speedup=1.14x
  n=10000: eager=10.948717s, lazy=9.643398s, speedup=1.14x

Benchmark Script

"""Test comparing OldState (eager split) vs State (lazy split) performance."""

import time
import typing
from typing import Self
from unittest.mock import patch
import torch
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.state import SimState, get_attrs_for_scope


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float64
N_STRUCTURES = [1, 1, 1, 1, 10, 100, 250, 500, 1000, 1500, 5000, 10000]
MAX_MEMORY_SCALER = 400_000
MEMORY_SCALES_WITH = "n_atoms_x_density"



class OldState(SimState):
    """Old state representation that uses eager splitting.

    This class inherits from SimState but overrides split() to use the old eager
    approach where all sub-states are created upfront when split() is called.
    """

    def split(self) -> list[Self]:
        """Split the OldState into a list of single-system OldStates (EAGER).

        This is the OLD approach that creates ALL states upfront, which is O(n)
        where n is the number of systems.

        Returns:
            list[OldState]: A list of OldState objects, one per system
        """
        return self._split_state(self)
            
    @staticmethod
    def _split_state[T: SimState](state: T) -> list[T]:
        """Split a SimState into a list of states, each containing a single system.

        Divides a multi-system state into individual single-system states, preserving
        appropriate properties for each system.

        Args:
            state (SimState): The SimState to split

        Returns:
            list[SimState]: A list of SimState objects, each containing a single
                system
        """
        system_sizes = state.n_atoms_per_system.tolist()

        split_per_atom = {}
        for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"):
            if attr_name != "system_idx":
                split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0)

        split_per_system = {}
        for attr_name, attr_value in get_attrs_for_scope(state, "per-system"):
            if isinstance(attr_value, torch.Tensor):
                split_per_system[attr_name] = torch.split(attr_value, 1, dim=0)
            else:  # Non-tensor attributes are replicated for each split
                split_per_system[attr_name] = [attr_value] * state.n_systems

        global_attrs = dict(get_attrs_for_scope(state, "global"))

        # Create a state for each system
        states: list[T] = []
        n_systems = len(system_sizes)
        zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64)
        cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0)))
        for sys_idx in range(n_systems):
            # Build per-system attributes (padded attributes stay padded for consistency)
            per_system_dict = {
                attr_name: split_per_system[attr_name][sys_idx]
                for attr_name in split_per_system
            }

            system_attrs = {
                # Create a system tensor with all zeros for this system
                "system_idx": torch.zeros(
                    system_sizes[sys_idx], device=state.device, dtype=torch.int64
                ),
                # Add the split per-atom attributes
                **{
                    attr_name: split_per_atom[attr_name][sys_idx]
                    for attr_name in split_per_atom
                },
                # Add the split per-system attributes (with unpadding applied)
                **per_system_dict,
                # Add the global attributes
                **global_attrs,
            }

            atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1])
            new_constraints = [
                new_constraint
                for constraint in state.constraints
                if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx))
            ]

            system_attrs["_constraints"] = new_constraints
            states.append(type(state)(**system_attrs))  # type: ignore[invalid-argument-type]

        return states


    @classmethod
    def from_sim_state(cls, state: SimState) -> "OldState":
        """Create OldState from a SimState."""
        return cls(
            positions=state.positions.clone(),
            masses=state.masses.clone(),
            cell=state.cell.clone(),
            pbc=state.pbc.clone() if isinstance(state.pbc, torch.Tensor) else state.pbc,
            atomic_numbers=state.atomic_numbers.clone(),
            charge=state.charge.clone() if state.charge is not None else None,
            spin=state.spin.clone() if state.spin is not None else None,
            system_idx=state.system_idx.clone() if state.system_idx is not None else None,
            _constraints=state.constraints.copy(),
        )


def run_torchsim_static(
    n_structures_list: list[int],
    base_structure,
    model,
    device: torch.device,
) -> list[float]:
    """Run static calculations for each n using batched path, return timings."""
    autobatcher = ts.BinningAutoBatcher(
        model=model,
        max_memory_scaler=MAX_MEMORY_SCALER,
        memory_scales_with=MEMORY_SCALES_WITH,
    )
    times: list[float] = []
    for n in n_structures_list:
        structures = [base_structure] * n
        if device.type == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        ts.static(structures, model, autobatcher=autobatcher)
        if device.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
        times.append(elapsed)
        print(f"  n={n} static_time={elapsed:.6f}s")
    return times
  


if __name__ == "__main__":

    # Setup
    mgo_atoms = bulk("MgO", crystalstructure="rocksalt", a=4.21, cubic=True)

    print("Loading MACE model...")
    loaded_model = mace_mp(
        model=MaceUrls.mace_mpa_medium,
        return_raw_model=True,
        default_dtype="float64",
        device=str(DEVICE),
    )
    mace_model = MaceModel(
        model=typing.cast("torch.nn.Module", loaded_model),
        device=DEVICE,
        compute_forces=True,
        compute_stress=True,
        dtype=DTYPE,
        enable_cueq=False,
    )

    print("\n=== Static Benchmark Comparison: Eager vs Lazy Split ===\n")

    # Run with eager split (OldState behavior)
    print("Running with eager split (OldState):")
    with patch.object(SimState, "split", lambda self: OldState._split_state(self)):
        eager_times = run_torchsim_static(
            N_STRUCTURES, mgo_atoms, mace_model, DEVICE
        )

    # Run with lazy split (default State behavior)
    print("\nRunning with lazy split (State):")
    lazy_times = run_torchsim_static(
        N_STRUCTURES, mgo_atoms, mace_model, DEVICE
    )

    # Print comparison
    print("\n=== Comparison ===")
    for i, n in enumerate(N_STRUCTURES):
        speedup = eager_times[i] / max(lazy_times[i], 1e-9)
        print(
            f"  n={n:4d}: eager={eager_times[i]:.6f}s, "
            f"lazy={lazy_times[i]:.6f}s, speedup={speedup:.2f}x"
        )

Contributor Author:

The argmin runs over metric_values, which is just a list of floats computed from the batched state tensors before we split anything. This gives us integer indices, and only when we call state_list[idx] do we actually create the two needed states.
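
In other words, the access pattern is roughly the following (names approximate; the import path for calculate_memory_scaler is assumed, and state is an existing batched SimState):

from torch_sim.autobatching import calculate_memory_scaler

metric_values = calculate_memory_scaler(state, memory_scales_with)  # list[float], batched
state_list = state.split()  # lazy view; nothing materialized yet
min_idx = min(range(len(metric_values)), key=metric_values.__getitem__)
max_idx = max(range(len(metric_values)), key=metric_values.__getitem__)
smallest, largest = state_list[min_idx], state_list[max_idx]  # only these two states are built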

Comment on lines 937 to 920
all relevant properties.
their requested order (not natural 0,1,2 order).
Collaborator:

what does this mean?

Contributor Author:

That refers to the fact that calling _slice_state(state, [3, 1, 4]) returns a new state where the systems appear in the exact order requested (3→0, 1→1, 4→2), instead of being sorted in ascending order. I updated the docstring to better reflect this point.
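
A tiny standalone illustration of that ordering contract, using a toy per-system tensor rather than the real _slice_state:

import torch

per_system_energy = torch.tensor([10.0, 11.0, 12.0, 13.0, 14.0])

requested = torch.tensor([3, 1, 4])
print(per_system_energy[requested])                # tensor([13., 11., 14.]) -> requested order
print(per_system_energy[requested.sort().values])  # tensor([11., 13., 14.]) -> ascending order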

Comment on lines -1114 to +1128
device: torch.device | None = None,
dtype: torch.dtype | None = None,
device: torch.device,
dtype: torch.dtype,
Collaborator:

why?

Contributor Author:

This is a minor fix, but I think making device and dtype required improves type checking and forces callers to be explicit. With None defaults, .to(None, None) is a no-op, so the state silently stays where it was, potentially causing device mismatches with the model later.
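
A toy example of the pitfall being described (not torch_sim code): with None defaults, a forgotten device argument leaves the tensor where it was.

import torch


def move(t: torch.Tensor, device: torch.device | None = None,
         dtype: torch.dtype | None = None) -> torch.Tensor:
    # with both arguments left as None, Tensor.to() returns the input unchanged
    return t.to(device=device, dtype=dtype)


x = torch.zeros(3)  # float32 on CPU
y = move(x)         # caller forgot device/dtype -> silent no-op
assert y.device == x.device and y.dtype == x.dtype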

return keep_state, pop_states


def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor) -> T:
Collaborator:

the rewrite of this function adds a lot of code and it's unclear it makes anything faster

Contributor Author:

see comment above

@falletta (Contributor Author) commented Feb 5, 2026

Thanks Orion for reviewing. I applied a few cosmetic changes to address your comments. Regarding your points:

  1. Great!
  2. The changes to the state manipulations are essential, as described in my comments. The scaling scripts provide a benchmark that highlights the performance speedup.
  3. All of the changes were intentional and came out of carefully profiling many functions in the code—from the major bottlenecks in memory_scaler and split to smaller optimizations in the state initialization.

falletta and others added 2 commits February 5, 2026 08:20