Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions torch_sim/optimizers/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@
from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs


def _compute_atom_idx(system_idx: torch.Tensor, n_systems: int) -> torch.Tensor:
"""Compute per-system atom indices, vectorized.

Args:
system_idx: System index for each atom [N]
n_systems: Number of systems S

Returns:
Tensor [N] with per-system atom indices
"""
device = system_idx.device
counts = torch.bincount(system_idx, minlength=n_systems)
offsets = torch.zeros(n_systems, device=device, dtype=torch.long)
if n_systems > 1:
offsets[1:] = counts[:-1].cumsum(0)
return torch.arange(len(system_idx), device=device) - offsets[system_idx]


def _atoms_to_padded(
x: torch.Tensor,
system_idx: torch.Tensor,
Expand All @@ -47,37 +65,26 @@ def _atoms_to_padded(
"""
device, dtype = x.device, x.dtype
out = torch.zeros((n_systems, max_atoms, 3), device=device, dtype=dtype)
# Create atom index within each system
atom_idx = torch.zeros_like(system_idx)
for sys in range(n_systems):
mask = system_idx == sys
atom_idx[mask] = torch.arange(mask.sum(), device=device)
atom_idx = _compute_atom_idx(system_idx, n_systems)
out[system_idx, atom_idx] = x
return out


def _padded_to_atoms(
x: torch.Tensor,
system_idx: torch.Tensor,
n_atoms: int,
) -> torch.Tensor:
"""Convert padded per-system [S, M, 3] to atom-indexed [N, 3].

Args:
x: Tensor of shape [S, M, 3]
system_idx: System index for each atom [N]
n_atoms: Total number of atoms N

Returns:
Tensor of shape [N, 3]
"""
n_systems = x.shape[0]
device = x.device
# Create atom index within each system
atom_idx = torch.zeros(n_atoms, device=device, dtype=torch.long)
for sys in range(n_systems):
mask = system_idx == sys
atom_idx[mask] = torch.arange(mask.sum(), device=device)
atom_idx = _compute_atom_idx(system_idx, n_systems)
return x[system_idx, atom_idx] # [N, 3]


Expand Down Expand Up @@ -315,7 +322,6 @@ def lbfgs_step( # noqa: PLR0915, C901
device, dtype = model.device, model.dtype
eps = 1e-8 if dtype == torch.float32 else 1e-16
n_systems = state.n_systems # S
n_atoms = state.n_atoms # N

# Derive max_atoms from history shape: [S, H, M, 3] or [S, H, M_ext, 3]
history_dim = state.s_history.shape[2] # M or M_ext
Expand All @@ -326,17 +332,8 @@ def lbfgs_step( # noqa: PLR0915, C901
max_atoms = history_dim # M
max_atoms_ext = max_atoms

# Create atom index within each system for padding/unpadding
atom_idx_in_sys = torch.zeros(n_atoms, device=device, dtype=torch.long)
for sys in range(n_systems):
mask = state.system_idx == sys
atom_idx_in_sys[mask] = torch.arange(mask.sum(), device=device)

# Create valid atom mask for per-system operations: [S, M]
atom_mask = torch.zeros((n_systems, max_atoms), device=device, dtype=torch.bool)
for sys in range(n_systems):
n_atoms_sys = int(state.max_atoms[sys].item())
atom_mask[sys, :n_atoms_sys] = True
atom_mask = torch.arange(max_atoms, device=device)[None] < state.max_atoms[:, None]

# Extended mask including cell DOFs: [S, M_ext]
if is_cell_state:
Expand Down Expand Up @@ -471,10 +468,10 @@ def lbfgs_step( # noqa: PLR0915, C901
step_padded = step[:, :max_atoms] # [S, M, 3]
step_cell = step[:, max_atoms:] # [S, 3, 3]
# Convert padded step to atom-level
step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms)
step_positions = _padded_to_atoms(step_padded, state.system_idx)
else:
step_padded = step # [S, M, 3]
step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms)
step_positions = _padded_to_atoms(step_padded, state.system_idx)

# Save previous state for history update
# For cell state: store fractional positions and scaled forces (ASE convention)
Expand Down