Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 71 additions & 20 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,22 +748,7 @@ def static(
state: SimState = ts.initialize_state(system, model.device, model.dtype)

batch_iterator = _configure_batches_iterator(state, model, autobatcher=autobatcher)
properties = ["potential_energy"]
if model.compute_forces:
properties.append("forces")
if model.compute_stress:
properties.append("stress")
if isinstance(trajectory_reporter, dict):
trajectory_reporter = copy.deepcopy(trajectory_reporter)
trajectory_reporter["state_kwargs"] = {
"variable_atomic_numbers": True,
"variable_masses": True,
"save_forces": model.compute_forces,
}
trajectory_reporter = _configure_reporter(
trajectory_reporter or dict(filenames=None),
properties=properties,
)
trajectory_reporter, _ = _configure_static_reporter(trajectory_reporter, model)

@dataclass(kw_only=True)
class StaticState(SimState):
Expand All @@ -778,7 +763,7 @@ class StaticState(SimState):
}

all_props: list[dict[str, torch.Tensor]] = []
og_filenames = trajectory_reporter.filenames
og_filenames = trajectory_reporter.filenames if trajectory_reporter else None

tqdm_pbar = None
if pbar and autobatcher:
Expand Down Expand Up @@ -812,13 +797,17 @@ class StaticState(SimState):
),
)

props = trajectory_reporter.report(static_state, 0, model=model)
all_props.extend(props)
if trajectory_reporter:
props = trajectory_reporter.report(static_state, 0, model=model)
all_props.extend(props)
else:
all_props.extend(_collect_base_properties(sub_state, static_state, model))

if tqdm_pbar:
tqdm_pbar.update(static_state.n_systems)

trajectory_reporter.finish()
if trajectory_reporter:
trajectory_reporter.finish()

if isinstance(batch_iterator, BinningAutoBatcher):
# reorder properties to match original order of states
Expand All @@ -827,3 +816,65 @@ class StaticState(SimState):
return [prop for _, prop in sorted(indexed_props, key=lambda x: x[0])]

return all_props


def _configure_static_reporter(
trajectory_reporter: TrajectoryReporter | dict | None,
model: ModelInterface,
) -> tuple[TrajectoryReporter | None, list[str]]:
"""Configure trajectory reporter for static calculations.

Args:
trajectory_reporter: Optional reporter or dict config
model: The model interface to check which properties are computed

Returns:
Tuple of (configured reporter or None, list of properties)
"""
properties = ["potential_energy"]
if model.compute_forces:
properties.append("forces")
if model.compute_stress:
properties.append("stress")

if trajectory_reporter is None:
return None, properties

if isinstance(trajectory_reporter, dict):
trajectory_reporter = copy.deepcopy(trajectory_reporter)
trajectory_reporter["state_kwargs"] = {
"variable_atomic_numbers": True,
"variable_masses": True,
"save_forces": model.compute_forces,
}

return _configure_reporter(trajectory_reporter, properties=properties), properties


def _collect_base_properties(
sub_state: SimState,
static_state: SimState,
model: ModelInterface,
) -> list[dict[str, torch.Tensor]]:
"""Collect base properties for each system when no trajectory reporter is used.

Args:
sub_state: The sub-state being processed
static_state: State containing energy, forces, and stress from model
model: The model interface to check which properties are computed

Returns:
List of property dictionaries, one per system
"""
props_list: list[dict[str, torch.Tensor]] = []
for sys_idx in range(sub_state.n_systems):
atom_mask = sub_state.system_idx == sys_idx
base_props: dict[str, torch.Tensor] = {
"potential_energy": static_state.energy[sys_idx],
}
if model.compute_forces:
base_props["forces"] = static_state.forces[atom_mask]
if model.compute_stress:
base_props["stress"] = static_state.stress[sys_idx]
props_list.append(base_props)
return props_list
Loading