From 38933d2b565074a9826a7e42fe1643d40df241c5 Mon Sep 17 00:00:00 2001 From: falletta Date: Wed, 4 Feb 2026 13:09:02 -0800 Subject: [PATCH 1/4] trajectory reporter None in static if not specified --- torch_sim/runners.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 6b7204dee..96652c193 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -709,7 +709,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 return state # type: ignore[return-value] -def static( +def static( # noqa: C901 system: StateLike, model: ModelInterface, *, @@ -753,17 +753,18 @@ def static( properties.append("forces") if model.compute_stress: properties.append("stress") - if isinstance(trajectory_reporter, dict): - trajectory_reporter = copy.deepcopy(trajectory_reporter) - trajectory_reporter["state_kwargs"] = { - "variable_atomic_numbers": True, - "variable_masses": True, - "save_forces": model.compute_forces, - } - trajectory_reporter = _configure_reporter( - trajectory_reporter or dict(filenames=None), - properties=properties, - ) + if trajectory_reporter is not None: + if isinstance(trajectory_reporter, dict): + trajectory_reporter = copy.deepcopy(trajectory_reporter) + trajectory_reporter["state_kwargs"] = { + "variable_atomic_numbers": True, + "variable_masses": True, + "save_forces": model.compute_forces, + } + trajectory_reporter = _configure_reporter( + trajectory_reporter, + properties=properties, + ) @dataclass(kw_only=True) class StaticState(SimState): @@ -778,7 +779,7 @@ class StaticState(SimState): } all_props: list[dict[str, torch.Tensor]] = [] - og_filenames = trajectory_reporter.filenames + og_filenames = trajectory_reporter.filenames if trajectory_reporter else None tqdm_pbar = None if pbar and autobatcher: @@ -812,13 +813,15 @@ class StaticState(SimState): ), ) - props = trajectory_reporter.report(static_state, 0, model=model) - all_props.extend(props) + if trajectory_reporter: + props = trajectory_reporter.report(static_state, 0, model=model) + all_props.extend(props) if tqdm_pbar: tqdm_pbar.update(static_state.n_systems) - trajectory_reporter.finish() + if trajectory_reporter: + trajectory_reporter.finish() if isinstance(batch_iterator, BinningAutoBatcher): # reorder properties to match original order of states From 71f356d7b29f18b8a4feca47c63e47930c43a52e Mon Sep 17 00:00:00 2001 From: falletta Date: Wed, 4 Feb 2026 13:16:59 -0800 Subject: [PATCH 2/4] populate all_props --- torch_sim/runners.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 96652c193..6d73b2b22 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -813,9 +813,24 @@ class StaticState(SimState): ), ) + # Collect base properties for each system + for sys_idx in range(sub_state.n_systems): + base_props: dict[str, torch.Tensor] = { + "potential_energy": static_state.energy[sys_idx], + } + atom_mask = sub_state.system_idx == sys_idx + if model.compute_forces: + base_props["forces"] = static_state.forces[atom_mask] + if model.compute_stress: + base_props["stress"] = static_state.stress[sys_idx] + all_props.append(base_props) + if trajectory_reporter: props = trajectory_reporter.report(static_state, 0, model=model) - all_props.extend(props) + # Merge any additional properties from reporter into base props + start_idx = len(all_props) - static_state.n_systems + for i, prop in enumerate(props): + all_props[start_idx + i].update(prop) if tqdm_pbar: tqdm_pbar.update(static_state.n_systems) From 21f2a7fe37258c4d7b1bf416fdee5aa2ad321115 Mon Sep 17 00:00:00 2001 From: falletta Date: Wed, 4 Feb 2026 13:21:15 -0800 Subject: [PATCH 3/4] collecting props when no reporter --- torch_sim/runners.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 6d73b2b22..39e952dbc 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -813,24 +813,21 @@ class StaticState(SimState): ), ) - # Collect base properties for each system - for sys_idx in range(sub_state.n_systems): - base_props: dict[str, torch.Tensor] = { - "potential_energy": static_state.energy[sys_idx], - } - atom_mask = sub_state.system_idx == sys_idx - if model.compute_forces: - base_props["forces"] = static_state.forces[atom_mask] - if model.compute_stress: - base_props["stress"] = static_state.stress[sys_idx] - all_props.append(base_props) - if trajectory_reporter: props = trajectory_reporter.report(static_state, 0, model=model) - # Merge any additional properties from reporter into base props - start_idx = len(all_props) - static_state.n_systems - for i, prop in enumerate(props): - all_props[start_idx + i].update(prop) + all_props.extend(props) + else: + # Collect base properties for each system when no reporter + for sys_idx in range(sub_state.n_systems): + atom_mask = sub_state.system_idx == sys_idx + base_props: dict[str, torch.Tensor] = { + "potential_energy": static_state.energy[sys_idx], + } + if model.compute_forces: + base_props["forces"] = static_state.forces[atom_mask] + if model.compute_stress: + base_props["stress"] = static_state.stress[sys_idx] + all_props.append(base_props) if tqdm_pbar: tqdm_pbar.update(static_state.n_systems) From af5c899d0aebad7e490d5a16afe3b1344bf7ace0 Mon Sep 17 00:00:00 2001 From: falletta Date: Thu, 5 Feb 2026 13:45:35 -0800 Subject: [PATCH 4/4] added helper functions --- torch_sim/runners.py | 94 ++++++++++++++++++++++++++++++-------------- 1 file changed, 65 insertions(+), 29 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 39e952dbc..516ae62d2 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -709,7 +709,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 return state # type: ignore[return-value] -def static( # noqa: C901 +def static( system: StateLike, model: ModelInterface, *, @@ -748,23 +748,7 @@ def static( # noqa: C901 state: SimState = ts.initialize_state(system, model.device, model.dtype) batch_iterator = _configure_batches_iterator(state, model, autobatcher=autobatcher) - properties = ["potential_energy"] - if model.compute_forces: - properties.append("forces") - if model.compute_stress: - properties.append("stress") - if trajectory_reporter is not None: - if isinstance(trajectory_reporter, dict): - trajectory_reporter = copy.deepcopy(trajectory_reporter) - trajectory_reporter["state_kwargs"] = { - "variable_atomic_numbers": True, - "variable_masses": True, - "save_forces": model.compute_forces, - } - trajectory_reporter = _configure_reporter( - trajectory_reporter, - properties=properties, - ) + trajectory_reporter, _ = _configure_static_reporter(trajectory_reporter, model) @dataclass(kw_only=True) class StaticState(SimState): @@ -817,17 +801,7 @@ class StaticState(SimState): props = trajectory_reporter.report(static_state, 0, model=model) all_props.extend(props) else: - # Collect base properties for each system when no reporter - for sys_idx in range(sub_state.n_systems): - atom_mask = sub_state.system_idx == sys_idx - base_props: dict[str, torch.Tensor] = { - "potential_energy": static_state.energy[sys_idx], - } - if model.compute_forces: - base_props["forces"] = static_state.forces[atom_mask] - if model.compute_stress: - base_props["stress"] = static_state.stress[sys_idx] - all_props.append(base_props) + all_props.extend(_collect_base_properties(sub_state, static_state, model)) if tqdm_pbar: tqdm_pbar.update(static_state.n_systems) @@ -842,3 +816,65 @@ class StaticState(SimState): return [prop for _, prop in sorted(indexed_props, key=lambda x: x[0])] return all_props + + +def _configure_static_reporter( + trajectory_reporter: TrajectoryReporter | dict | None, + model: ModelInterface, +) -> tuple[TrajectoryReporter | None, list[str]]: + """Configure trajectory reporter for static calculations. + + Args: + trajectory_reporter: Optional reporter or dict config + model: The model interface to check which properties are computed + + Returns: + Tuple of (configured reporter or None, list of properties) + """ + properties = ["potential_energy"] + if model.compute_forces: + properties.append("forces") + if model.compute_stress: + properties.append("stress") + + if trajectory_reporter is None: + return None, properties + + if isinstance(trajectory_reporter, dict): + trajectory_reporter = copy.deepcopy(trajectory_reporter) + trajectory_reporter["state_kwargs"] = { + "variable_atomic_numbers": True, + "variable_masses": True, + "save_forces": model.compute_forces, + } + + return _configure_reporter(trajectory_reporter, properties=properties), properties + + +def _collect_base_properties( + sub_state: SimState, + static_state: SimState, + model: ModelInterface, +) -> list[dict[str, torch.Tensor]]: + """Collect base properties for each system when no trajectory reporter is used. + + Args: + sub_state: The sub-state being processed + static_state: State containing energy, forces, and stress from model + model: The model interface to check which properties are computed + + Returns: + List of property dictionaries, one per system + """ + props_list: list[dict[str, torch.Tensor]] = [] + for sys_idx in range(sub_state.n_systems): + atom_mask = sub_state.system_idx == sys_idx + base_props: dict[str, torch.Tensor] = { + "potential_energy": static_state.energy[sys_idx], + } + if model.compute_forces: + base_props["forces"] = static_state.forces[atom_mask] + if model.compute_stress: + base_props["stress"] = static_state.stress[sys_idx] + props_list.append(base_props) + return props_list