diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 5dbd0f39..a49ba07d 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -748,22 +748,7 @@ def static( state: SimState = ts.initialize_state(system, model.device, model.dtype) batch_iterator = _configure_batches_iterator(state, model, autobatcher=autobatcher) - properties = ["potential_energy"] - if model.compute_forces: - properties.append("forces") - if model.compute_stress: - properties.append("stress") - if isinstance(trajectory_reporter, dict): - trajectory_reporter = copy.deepcopy(trajectory_reporter) - trajectory_reporter["state_kwargs"] = { - "variable_atomic_numbers": True, - "variable_masses": True, - "save_forces": model.compute_forces, - } - trajectory_reporter = _configure_reporter( - trajectory_reporter or dict(filenames=None), - properties=properties, - ) + trajectory_reporter, _ = _configure_static_reporter(trajectory_reporter, model) @dataclass(kw_only=True) class StaticState(SimState): @@ -778,7 +763,7 @@ class StaticState(SimState): } all_props: list[dict[str, torch.Tensor]] = [] - og_filenames = trajectory_reporter.filenames + og_filenames = trajectory_reporter.filenames if trajectory_reporter else None tqdm_pbar = None if pbar and autobatcher: @@ -812,13 +797,17 @@ class StaticState(SimState): ), ) - props = trajectory_reporter.report(static_state, 0, model=model) - all_props.extend(props) + if trajectory_reporter: + props = trajectory_reporter.report(static_state, 0, model=model) + all_props.extend(props) + else: + all_props.extend(_collect_base_properties(sub_state, static_state, model)) if tqdm_pbar: tqdm_pbar.update(static_state.n_systems) - trajectory_reporter.finish() + if trajectory_reporter: + trajectory_reporter.finish() if isinstance(batch_iterator, BinningAutoBatcher): # reorder properties to match original order of states @@ -827,3 +816,65 @@ class StaticState(SimState): return [prop for _, prop in sorted(indexed_props, key=lambda x: x[0])] return all_props + + +def _configure_static_reporter( + trajectory_reporter: TrajectoryReporter | dict | None, + model: ModelInterface, +) -> tuple[TrajectoryReporter | None, list[str]]: + """Configure trajectory reporter for static calculations. + + Args: + trajectory_reporter: Optional reporter or dict config + model: The model interface to check which properties are computed + + Returns: + Tuple of (configured reporter or None, list of properties) + """ + properties = ["potential_energy"] + if model.compute_forces: + properties.append("forces") + if model.compute_stress: + properties.append("stress") + + if trajectory_reporter is None: + return None, properties + + if isinstance(trajectory_reporter, dict): + trajectory_reporter = copy.deepcopy(trajectory_reporter) + trajectory_reporter["state_kwargs"] = { + "variable_atomic_numbers": True, + "variable_masses": True, + "save_forces": model.compute_forces, + } + + return _configure_reporter(trajectory_reporter, properties=properties), properties + + +def _collect_base_properties( + sub_state: SimState, + static_state: SimState, + model: ModelInterface, +) -> list[dict[str, torch.Tensor]]: + """Collect base properties for each system when no trajectory reporter is used. + + Args: + sub_state: The sub-state being processed + static_state: State containing energy, forces, and stress from model + model: The model interface to check which properties are computed + + Returns: + List of property dictionaries, one per system + """ + props_list: list[dict[str, torch.Tensor]] = [] + for sys_idx in range(sub_state.n_systems): + atom_mask = sub_state.system_idx == sys_idx + base_props: dict[str, torch.Tensor] = { + "potential_energy": static_state.energy[sys_idx], + } + if model.compute_forces: + base_props["forces"] = static_state.forces[atom_mask] + if model.compute_stress: + base_props["stress"] = static_state.stress[sys_idx] + props_list.append(base_props) + return props_list