
Made trajectory reporter optional in static function#441

Open
falletta wants to merge 7 commits into TorchSim:main from falletta:trajreporter

Conversation


@falletta (Contributor) commented Feb 4, 2026

By making the trajectory reporter optional, we avoid additional computational overhead (see profiling plot in this PR).
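The idea behind the change can be sketched as follows. Note that all names and signatures here are illustrative stand-ins, not TorchSim's actual API:

```python
def compute_properties(state, model):
    # Stand-in for the real static-property evaluation (hypothetical).
    return {"potential_energy": sum(state)}

def static(state, model, trajectory_reporter=None):
    # When no reporter is passed, the reporting step is skipped entirely,
    # so its per-call overhead disappears.
    props = compute_properties(state, model)
    if trajectory_reporter is not None:
        trajectory_reporter(state, props)
    return props

# With trajectory_reporter=None (the new default), no reporting cost is paid.
print(static([1.0, 2.0], model=None))  # {'potential_energy': 3.0}
```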

Comment on lines 819 to 830
else:
    # Collect base properties for each system when no reporter
    for sys_idx in range(sub_state.n_systems):
        atom_mask = sub_state.system_idx == sys_idx
        base_props: dict[str, torch.Tensor] = {
            "potential_energy": static_state.energy[sys_idx],
        }
        if model.compute_forces:
            base_props["forces"] = static_state.forces[atom_mask]
        if model.compute_stress:
            base_props["stress"] = static_state.stress[sys_idx]
        all_props.append(base_props)
Collaborator
Are you sure this is faster?

The profiling data doesn't really say anything about how long the report call takes, and this isn't benchmarked against the current code.

@falletta (Contributor, Author) commented Feb 5, 2026

Yes, I confirm the speedup. Below are the results from running the run_torchsim_static function in 8.scaling.py (see the static PR) before and after the fix. We observe a speedup of up to 24.3%, and it grows with system size. In addition, I verified that the cost associated with the trajectory reporter no longer appears in the profiling analysis.

Previous results:

=== Static benchmark ===
  n=1 static_time=1.928943s
  n=1 static_time=0.272846s
  n=1 static_time=0.683335s
  n=1 static_time=0.026675s
  n=10 static_time=0.281990s
  n=100 static_time=0.705871s
  n=500 static_time=1.510273s
  n=1000 static_time=1.528872s
  n=2500 static_time=3.809000s
  n=5000 static_time=7.890238s

New results:

  n=1 static_time=2.165601s
  n=1 static_time=0.271961s
  n=1 static_time=0.665016s
  n=1 static_time=0.022899s
  n=10 static_time=0.295468s
  n=100 static_time=0.692651s
  n=500 static_time=1.411905s
  n=1000 static_time=1.291772s -> 18.3% speedup
  n=2500 static_time=3.175455s -> 19.9% speedup
  n=5000 static_time=6.348887s -> 24.3% speedup
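As a sanity check, the quoted percentages roughly follow from defining speedup as old_time / new_time - 1, small rounding differences aside. A quick check using the timings from the tables above:

```python
# Speedup defined as old_time / new_time - 1, using timings copied from above.
timings = {
    1000: (1.528872, 1.291772),
    2500: (3.809000, 3.175455),
    5000: (7.890238, 6.348887),
}
for n, (old, new) in timings.items():
    print(f"n={n}: {(old / new - 1) * 100:.1f}% speedup")
```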



- def static(
+ def static(  # noqa: C901
Collaborator

I don't favor disabling the complexity limit here. A lot of effort has been put into keeping these functions minimal and readable.

@falletta (Contributor, Author)

Ok, I introduced helper functions to reduce the complexity of static. However, note that optimize has a # noqa: C901, PLR0915 comment, which disables the complexity limit just as my original implementation did. We might consider introducing helper functions in optimize as well, for consistency with static.
