diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py
index 5b494bc0..53a7e0bd 100644
--- a/torch_sim/optimizers/lbfgs.py
+++ b/torch_sim/optimizers/lbfgs.py
@@ -28,6 +28,24 @@
 from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs
 
 
+def _compute_atom_idx(system_idx: torch.Tensor, n_systems: int) -> torch.Tensor:
+    """Compute per-system atom indices, vectorized.
+
+    Args:
+        system_idx: System index for each atom [N]
+        n_systems: Number of systems S
+
+    Returns:
+        Tensor [N] with per-system atom indices
+    """
+    device = system_idx.device
+    counts = torch.bincount(system_idx, minlength=n_systems)
+    offsets = torch.zeros(n_systems, device=device, dtype=torch.long)
+    if n_systems > 1:
+        offsets[1:] = counts[:-1].cumsum(0)
+    return torch.arange(len(system_idx), device=device) - offsets[system_idx]
+
+
 def _atoms_to_padded(
     x: torch.Tensor,
     system_idx: torch.Tensor,
@@ -47,11 +65,7 @@ def _atoms_to_padded(
     """
     device, dtype = x.device, x.dtype
     out = torch.zeros((n_systems, max_atoms, 3), device=device, dtype=dtype)
-    # Create atom index within each system
-    atom_idx = torch.zeros_like(system_idx)
-    for sys in range(n_systems):
-        mask = system_idx == sys
-        atom_idx[mask] = torch.arange(mask.sum(), device=device)
+    atom_idx = _compute_atom_idx(system_idx, n_systems)
     out[system_idx, atom_idx] = x
     return out
 
@@ -59,25 +73,18 @@ def _atoms_to_padded(
 def _padded_to_atoms(
     x: torch.Tensor,
     system_idx: torch.Tensor,
-    n_atoms: int,
 ) -> torch.Tensor:
     """Convert padded per-system [S, M, 3] to atom-indexed [N, 3].
 
     Args:
         x: Tensor of shape [S, M, 3]
         system_idx: System index for each atom [N]
-        n_atoms: Total number of atoms N
 
     Returns:
         Tensor of shape [N, 3]
     """
     n_systems = x.shape[0]
-    device = x.device
-    # Create atom index within each system
-    atom_idx = torch.zeros(n_atoms, device=device, dtype=torch.long)
-    for sys in range(n_systems):
-        mask = system_idx == sys
-        atom_idx[mask] = torch.arange(mask.sum(), device=device)
+    atom_idx = _compute_atom_idx(system_idx, n_systems)
     return x[system_idx, atom_idx]  # [N, 3]
 
 
@@ -315,7 +322,6 @@ def lbfgs_step(  # noqa: PLR0915, C901
     device, dtype = model.device, model.dtype
     eps = 1e-8 if dtype == torch.float32 else 1e-16
     n_systems = state.n_systems  # S
-    n_atoms = state.n_atoms  # N
 
     # Derive max_atoms from history shape: [S, H, M, 3] or [S, H, M_ext, 3]
     history_dim = state.s_history.shape[2]  # M or M_ext
@@ -326,17 +332,8 @@ def lbfgs_step(  # noqa: PLR0915, C901
         max_atoms = history_dim  # M
         max_atoms_ext = max_atoms
 
-    # Create atom index within each system for padding/unpadding
-    atom_idx_in_sys = torch.zeros(n_atoms, device=device, dtype=torch.long)
-    for sys in range(n_systems):
-        mask = state.system_idx == sys
-        atom_idx_in_sys[mask] = torch.arange(mask.sum(), device=device)
-
     # Create valid atom mask for per-system operations: [S, M]
-    atom_mask = torch.zeros((n_systems, max_atoms), device=device, dtype=torch.bool)
-    for sys in range(n_systems):
-        n_atoms_sys = int(state.max_atoms[sys].item())
-        atom_mask[sys, :n_atoms_sys] = True
+    atom_mask = torch.arange(max_atoms, device=device)[None] < state.max_atoms[:, None]
 
     # Extended mask including cell DOFs: [S, M_ext]
     if is_cell_state:
@@ -471,10 +468,10 @@ def lbfgs_step(  # noqa: PLR0915, C901
         step_padded = step[:, :max_atoms]  # [S, M, 3]
         step_cell = step[:, max_atoms:]  # [S, 3, 3]
         # Convert padded step to atom-level
-        step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms)
+        step_positions = _padded_to_atoms(step_padded, state.system_idx)
     else:
         step_padded = step  # [S, M, 3]
-        step_positions = _padded_to_atoms(step_padded, state.system_idx, n_atoms)
+        step_positions = _padded_to_atoms(step_padded, state.system_idx)
 
     # Save previous state for history update
     # For cell state: store fractional positions and scaled forces (ASE convention)
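For reviewers, here is a minimal standalone sketch (not part of the patch) comparing the new vectorized `_compute_atom_idx` against the per-system loop it replaces. The name `_loop_atom_idx` and the toy `system_idx` are illustrative only; the equivalence assumes atoms are grouped contiguously by system, which the offset-subtraction trick relies on.

```python
import torch


def _compute_atom_idx(system_idx: torch.Tensor, n_systems: int) -> torch.Tensor:
    # Vectorized: subtract each system's starting offset from the global atom index.
    device = system_idx.device
    counts = torch.bincount(system_idx, minlength=n_systems)
    offsets = torch.zeros(n_systems, device=device, dtype=torch.long)
    if n_systems > 1:
        offsets[1:] = counts[:-1].cumsum(0)
    return torch.arange(len(system_idx), device=device) - offsets[system_idx]


def _loop_atom_idx(system_idx: torch.Tensor, n_systems: int) -> torch.Tensor:
    # Reference loop (the pattern removed by this patch), kept for comparison only.
    atom_idx = torch.zeros_like(system_idx)
    for sys in range(n_systems):
        mask = system_idx == sys
        atom_idx[mask] = torch.arange(int(mask.sum()), device=system_idx.device)
    return atom_idx


system_idx = torch.tensor([0, 0, 0, 1, 1, 2])  # 3 systems with 3, 2, 1 atoms
assert torch.equal(_compute_atom_idx(system_idx, 3), _loop_atom_idx(system_idx, 3))
print(_compute_atom_idx(system_idx, 3))  # tensor([0, 1, 2, 0, 1, 0])
```

The same broadcasting idea underlies the new `atom_mask` line: comparing a `[1, M]` arange against the `[S, 1]` per-system atom counts yields the `[S, M]` validity mask without a Python loop.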