Made trajectory reporter optional in static function #441
falletta wants to merge 7 commits into TorchSim:main
torch_sim/runners.py
Outdated
    else:
        # Collect base properties for each system when no reporter
        for sys_idx in range(sub_state.n_systems):
            atom_mask = sub_state.system_idx == sys_idx
            base_props: dict[str, torch.Tensor] = {
                "potential_energy": static_state.energy[sys_idx],
            }
            if model.compute_forces:
                base_props["forces"] = static_state.forces[atom_mask]
            if model.compute_stress:
                base_props["stress"] = static_state.stress[sys_idx]
            all_props.append(base_props)
Are you sure this is faster?
The profiling data doesn't really say anything about how long the report call takes, and this isn't benchmarked against the current code.
Yes, I can confirm the speedup. Below are the results from running the run_torchsim_static function in 8.scaling.py (see the static PR) before and after the fix. We observe a speedup of up to 24.3%, and the speedup grows with system size. In addition, I verified that the cost associated with the trajectory reporter disappears from the profiling analysis.
Previous results:
=== Static benchmark ===
n=1 static_time=1.928943s
n=1 static_time=0.272846s
n=1 static_time=0.683335s
n=1 static_time=0.026675s
n=10 static_time=0.281990s
n=100 static_time=0.705871s
n=500 static_time=1.510273s
n=1000 static_time=1.528872s
n=2500 static_time=3.809000s
n=5000 static_time=7.890238s
New results:
n=1 static_time=2.165601s
n=1 static_time=0.271961s
n=1 static_time=0.665016s
n=1 static_time=0.022899s
n=10 static_time=0.295468s
n=100 static_time=0.692651s
n=500 static_time=1.411905s
n=1000 static_time=1.291772s -> 18.3% speedup
n=2500 static_time=3.175455s -> 19.9% speedup
n=5000 static_time=6.348887s -> 24.3% speedup
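For reference, here is a minimal sketch of how the speedup percentages above can be reproduced from the listed timings. It assumes that "speedup" means t_old / t_new - 1, which is inferred from the numbers rather than stated in the comment; the quoted values agree with this definition to within about 0.1 percentage points. The function names are illustrative only.

```python
# Timings in seconds, copied from the benchmark output above.
previous = {1000: 1.528872, 2500: 3.809000, 5000: 7.890238}
new = {1000: 1.291772, 2500: 3.175455, 5000: 6.348887}


def speedup_percent(t_old: float, t_new: float) -> float:
    """Relative speedup, assuming the definition t_old / t_new - 1."""
    return (t_old / t_new - 1.0) * 100.0


def time_reduction_percent(t_old: float, t_new: float) -> float:
    """Fraction of the old wall time that was saved, as a percentage."""
    return (1.0 - t_new / t_old) * 100.0


speedups = {n: speedup_percent(previous[n], new[n]) for n in previous}
reductions = {n: time_reduction_percent(previous[n], new[n]) for n in previous}

for n in sorted(previous):
    print(f"n={n} speedup={speedups[n]:.2f}% time_saved={reductions[n]:.2f}%")
```

Note that the two definitions differ: expressed as a reduction in wall time (1 - t_new / t_old), the same measurements give somewhat smaller percentages.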
torch_sim/runners.py
Outdated
    - def static(
    + def static(  # noqa: C901
I don't favor disabling the complexity limit here. A lot of effort has been put into keeping these functions minimal and readable.
OK, I introduced helper functions to reduce the complexity of static. Note, however, that optimize carries a # noqa: C901, PLR0915, which disables the complexity and statement-count checks, much like my original implementation did. We might consider introducing helper functions in optimize as well, for consistency with static.
By making the trajectory reporter optional, we avoid additional computational overhead (see profiling plot in this PR).
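As an illustration of the pattern discussed in this thread, here is a minimal pure-Python sketch of an optional reporter combined with a small helper that keeps the main function short. It is not the code in torch_sim/runners.py: the names static_sketch and _collect_base_props, the dict-based state, and the callable reporter are all hypothetical, and plain lists stand in for tensors.

```python
from __future__ import annotations

from typing import Any, Callable, Optional

# Hypothetical stand-in for a batched state: per-system energies, per-atom
# forces, and a per-atom system index (names mirror the diff above).
State = dict


def _collect_base_props(state: State, compute_forces: bool = True) -> list[dict[str, Any]]:
    """Split batched results into one property dict per system."""
    all_props = []
    for sys_idx in range(state["n_systems"]):
        props: dict[str, Any] = {"potential_energy": state["energy"][sys_idx]}
        if compute_forces:
            # Keep only the forces of atoms that belong to this system.
            props["forces"] = [
                force
                for force, owner in zip(state["forces"], state["system_idx"])
                if owner == sys_idx
            ]
        all_props.append(props)
    return all_props


def static_sketch(
    state: State,
    reporter: Optional[Callable[[State], list[dict[str, Any]]]] = None,
) -> list[dict[str, Any]]:
    """Use the reporter when one is given, else collect properties directly."""
    if reporter is not None:
        return reporter(state)
    return _collect_base_props(state)


# Two systems: the first has two atoms, the second has one.
example_state = {
    "n_systems": 2,
    "energy": [-1.5, -0.7],
    "forces": [[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0], [0.0, 0.0, 0.2]],
    "system_idx": [0, 0, 1],
}
result = static_sketch(example_state)
```

Moving the per-system loop into a helper leaves the main function with a single branch on the reporter, which is the kind of change that can make a complexity suppression unnecessary.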