diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index a9ac8b71..5b494bc0 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -537,6 +537,11 @@ def lbfgs_step( # noqa: PLR0915, C901 if is_cell_state: cell_filters.compute_cell_forces(model_output, state) + # Update state + state.set_constrained_forces(new_forces) # [N, 3] + state.energy = new_energy # [S] + state.stress = new_stress # [S, 3, 3] or None + # Build new (s, y) for history in per-system format [S, M_ext, 3] or [S, M, 3] # s = position difference, y = gradient difference if is_cell_state: @@ -603,11 +608,6 @@ def lbfgs_step( # noqa: PLR0915, C901 s_hist = s_hist[:, -max_history:] # [S, max_history, ...] y_hist = y_hist[:, -max_history:] - # Update state - state.set_constrained_forces(new_forces) # [N, 3] - state.energy = new_energy # [S] - state.stress = new_stress # [S, 3, 3] or None - if is_cell_state: # Store fractional/scaled for next iteration state.prev_positions = new_frac_positions.clone() # [N, 3] (fractional)