diff --git a/examples/scripts/8_scaling.py b/examples/scripts/8_scaling.py new file mode 100644 index 000000000..cea561a87 --- /dev/null +++ b/examples/scripts/8_scaling.py @@ -0,0 +1,208 @@ +"""Scaling benchmarks for static, relax, NVE, and NVT.""" + +# %% +# /// script +# dependencies = [ +# "torch_sim_atomistic[mace,test]" +# ] +# /// + +import time +import typing + +import torch +from ase.build import bulk +from mace.calculators.foundations_models import mace_mp +from pymatgen.io.ase import AseAtomsAdaptor + +import torch_sim as ts +from torch_sim.models.mace import MaceModel, MaceUrls + + +# Shared constants +N_STRUCTURES_STATIC = [1, 1, 1, 1, 10, 100, 500, 1000, 2500, 5000] +N_STRUCTURES_RELAX = [1, 10, 100, 500] +N_STRUCTURES_NVE = [1, 10, 100, 500] +N_STRUCTURES_NVT = [1, 10, 100, 500] +RELAX_STEPS = 10 +MD_STEPS = 10 +MAX_MEMORY_SCALER = 400_000 +MEMORY_SCALES_WITH = "n_atoms_x_density" + + +def load_mace_model(device: torch.device) -> MaceModel: + """Load MACE model for benchmarking.""" + loaded_model = mace_mp( + model=MaceUrls.mace_mpa_medium, + return_raw_model=True, + default_dtype="float64", + device=str(device), + ) + return MaceModel( + model=typing.cast("torch.nn.Module", loaded_model), + device=device, + compute_forces=True, + compute_stress=True, + dtype=torch.float64, + enable_cueq=False, + ) + + +def run_torchsim_static( + n_structures_list: list[int], + base_structure: typing.Any, + model: MaceModel, + device: torch.device, +) -> list[float]: + """Run static calculations for each n using batched path, return timings.""" + autobatcher = ts.BinningAutoBatcher( + model=model, + max_memory_scaler=MAX_MEMORY_SCALER, + memory_scales_with=MEMORY_SCALES_WITH, + ) + times: list[float] = [] + for n in n_structures_list: + structures = [base_structure] * n + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + ts.static(structures, model, autobatcher=autobatcher) + if device.type == "cuda": + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + times.append(elapsed) + print(f" n={n} static_time={elapsed:.6f}s") + return times + + +def run_torchsim_relax( + n_structures_list: list[int], + base_structure: typing.Any, + model: MaceModel, + device: torch.device, +) -> list[float]: + """Run relaxation with ts.optimize for each n; return timings.""" + autobatcher = ts.InFlightAutoBatcher( + model=model, + max_memory_scaler=MAX_MEMORY_SCALER, + memory_scales_with=MEMORY_SCALES_WITH, + ) + times: list[float] = [] + for n in n_structures_list: + structures = [base_structure] * n + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + ts.optimize( + system=structures, + model=model, + optimizer=ts.optimizers.Optimizer.fire, + init_kwargs={ + "cell_filter": ts.optimizers.cell_filters.CellFilter.frechet, + "constant_volume": False, + "hydrostatic_strain": True, + }, + max_steps=RELAX_STEPS, + convergence_fn=ts.runners.generate_force_convergence_fn( + force_tol=1e-3, + include_cell_forces=True, + ), + autobatcher=autobatcher, + ) + if device.type == "cuda": + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + times.append(elapsed) + print(f" n={n} relax_{RELAX_STEPS}_time={elapsed:.6f}s") + return times + + +def run_torchsim_nve( + n_structures_list: list[int], + base_structure: typing.Any, + model: MaceModel, + device: torch.device, +) -> list[float]: + """Run NVE MD for MD_STEPS per n; return times.""" + times: list[float] = [] + for n in n_structures_list: + structures = [base_structure] * n + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + ts.integrate( + system=structures, + model=model, + integrator=ts.Integrator.nve, + n_steps=MD_STEPS, + temperature=300.0, + timestep=0.002, + autobatcher=ts.BinningAutoBatcher( + model=model, + max_memory_scaler=MAX_MEMORY_SCALER, + memory_scales_with=MEMORY_SCALES_WITH, + ), + ) + if device.type == "cuda": + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + times.append(elapsed) + print(f" n={n} nve_time={elapsed:.6f}s") + return times + + +def run_torchsim_nvt( + n_structures_list: list[int], + base_structure: typing.Any, + model: MaceModel, + device: torch.device, +) -> list[float]: + """Run NVT (Nose-Hoover) MD for MD_STEPS per n; return times.""" + times: list[float] = [] + for n in n_structures_list: + structures = [base_structure] * n + if device.type == "cuda": + torch.cuda.synchronize() + t0 = time.perf_counter() + ts.integrate( + system=structures, + model=model, + integrator=ts.Integrator.nvt_nose_hoover, + n_steps=MD_STEPS, + temperature=300.0, + timestep=0.002, + autobatcher=ts.BinningAutoBatcher( + model=model, + max_memory_scaler=MAX_MEMORY_SCALER, + memory_scales_with=MEMORY_SCALES_WITH, + ), + ) + if device.type == "cuda": + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + times.append(elapsed) + print(f" n={n} nvt_time={elapsed:.6f}s") + return times + + +if __name__ == "__main__": + # Setup + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + mgo_ase = bulk(name="MgO", crystalstructure="rocksalt", a=4.21, cubic=True) + base_structure = AseAtomsAdaptor.get_structure(atoms=mgo_ase) + + # Load model once + model = load_mace_model(device) + + # Run all benchmarks + print("=== Static benchmark ===") + static_times = run_torchsim_static(N_STRUCTURES_STATIC, base_structure, model, device) + + print("\n=== Relax benchmark ===") + relax_times = run_torchsim_relax(N_STRUCTURES_RELAX, base_structure, model, device) + + print("\n=== NVE benchmark ===") + nve_times = run_torchsim_nve(N_STRUCTURES_NVE, base_structure, model, device) + + print("\n=== NVT benchmark ===") + nvt_times = run_torchsim_nvt(N_STRUCTURES_NVT, base_structure, model, device) diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index 1b2c473bd..07e55354b 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -50,7 +50,7 @@ # %% import torch -from torch_sim.autobatching import calculate_memory_scaler +from torch_sim.autobatching import calculate_memory_scalers from ase.build import bulk @@ -63,14 +63,14 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64) -# Calculate memory scaling factor based on atom count -atom_metric = calculate_memory_scaler(state, memory_scales_with="n_atoms") +# Calculate memory scaling factor based on atom count (returns list, one per system) +atom_metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms") # Calculate memory scaling based on atom count and density -density_metric = calculate_memory_scaler(state, memory_scales_with="n_atoms_x_density") +density_metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms_x_density") -print(f"Atom-based memory metric: {atom_metric}") -print(f"Density-based memory metric: {density_metric:.2f}") +print(f"Atom-based memory metrics: {atom_metrics}") +print(f"Density-based memory metrics: {[f'{m:.2f}' for m in density_metrics]}") # %% [markdown] @@ -95,11 +95,11 @@ mace = mace_mp(model="small", return_raw_model=True) mace_model = MaceModel(model=mace, device=device) -state_list = state.split() -memory_metric_values = [ - calculate_memory_scaler(s, memory_scales_with="n_atoms") for s in state_list -] +# calculate_memory_scalers returns a list with one value per system in the state +memory_metric_values = calculate_memory_scalers(state, memory_scales_with="n_atoms") +# estimate_max_memory_scaler needs a list of individual states +state_list = state.split() max_memory_metric = estimate_max_memory_scaler( state_list, mace_model, metric_values=memory_metric_values ) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index c3e9aea23..953eb0a65 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -7,7 +7,7 @@ from torch_sim.autobatching import ( BinningAutoBatcher, InFlightAutoBatcher, - calculate_memory_scaler, + calculate_memory_scalers, determine_max_batch_size, to_constant_volume_bins, ) @@ -93,31 +93,39 @@ def test_bounds_and_tuples(): def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None: """Test calculation of scaling metrics for a state.""" # Test n_atoms metric - n_atoms_metric = calculate_memory_scaler(si_sim_state, "n_atoms") - assert n_atoms_metric == si_sim_state.n_atoms + n_atoms_metric = calculate_memory_scalers(si_sim_state, "n_atoms") + assert n_atoms_metric == [si_sim_state.n_atoms] # Test n_atoms_x_density metric - density_metric = calculate_memory_scaler(si_sim_state, "n_atoms_x_density") + density_metric = calculate_memory_scalers(si_sim_state, "n_atoms_x_density") volume = torch.abs(torch.linalg.det(si_sim_state.cell[0])) / 1000 expected = si_sim_state.n_atoms * (si_sim_state.n_atoms / volume.item()) - assert pytest.approx(density_metric, rel=1e-5) == expected + assert pytest.approx(density_metric[0], rel=1e-5) == expected # Test invalid metric with pytest.raises(ValueError, match="Invalid metric"): - calculate_memory_scaler(si_sim_state, "invalid_metric") + calculate_memory_scalers(si_sim_state, "invalid_metric") def test_calculate_scaling_metric_non_periodic(benzene_sim_state: ts.SimState) -> None: """Test calculation of scaling metrics for a non-periodic state.""" - # Test that calculate passes - n_atoms_metric = calculate_memory_scaler(benzene_sim_state, "n_atoms") - assert n_atoms_metric == benzene_sim_state.n_atoms + n_atoms_metric = calculate_memory_scalers(benzene_sim_state, "n_atoms") + assert n_atoms_metric == [benzene_sim_state.n_atoms] - # Test n_atoms_x_density metric works for non-periodic systems - n_atoms_x_density_metric = calculate_memory_scaler( + n_atoms_x_density_metric = calculate_memory_scalers( benzene_sim_state, "n_atoms_x_density" ) - assert n_atoms_x_density_metric > 0 + assert n_atoms_x_density_metric[0] > 0 + bbox = ( + benzene_sim_state.positions.max(dim=0).values + - benzene_sim_state.positions.min(dim=0).values + ).clone() + for i, p in enumerate(benzene_sim_state.pbc): + if not p: + bbox[i] += 2.0 + assert pytest.approx(n_atoms_x_density_metric[0], rel=1e-5) == ( + benzene_sim_state.n_atoms**2 / (bbox.prod().item() / 1000) + ) def test_split_state(si_double_sim_state: ts.SimState) -> None: diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index a0cf6aabb..077a72ce3 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -322,58 +322,41 @@ def determine_max_batch_size( return sizes[-1] -def calculate_memory_scaler( +def calculate_memory_scalers( state: SimState, memory_scales_with: MemoryScaling = "n_atoms_x_density", -) -> float: - """Calculate a metric that estimates memory requirements for a state. +) -> list[float]: + """Calculate memory scaling metric for each system in a state. - Provides different scaling metrics that correlate with memory usage. - Models with radial neighbor cutoffs generally scale with "n_atoms_x_density", - while models with a fixed number of neighbors scale with "n_atoms". - The choice of metric can significantly impact the accuracy of memory requirement - estimations for different types of simulation systems. + Uses vectorized operations for batched periodic states. Args: - state (SimState): State to calculate metric for, with shape information - specific to the SimState instance. - memory_scales_with ("n_atoms_x_density" | "n_atoms"): Type of metric - to use. "n_atoms" uses only atom count and is suitable for models that - have a fixed number of neighbors. "n_atoms_x_density" uses atom count - multiplied by number density and is better for models with radial cutoffs - Defaults to "n_atoms_x_density". + state (SimState): State to calculate metric for. + memory_scales_with ("n_atoms_x_density" | "n_atoms"): Metric type. Returns: - float: Calculated metric value. - - Raises: - ValueError: If state has multiple batches or if an invalid metric type is - provided. - - Example:: - - # Calculate memory scaling factor based on atom count - metric = calculate_memory_scaler(state, memory_scales_with="n_atoms") - - # Calculate memory scaling factor based on atom count and density - metric = calculate_memory_scaler(state, memory_scales_with="n_atoms_x_density") + list[float]: One scaler per system. """ - if state.n_systems > 1: - return sum(calculate_memory_scaler(s, memory_scales_with) for s in state.split()) if memory_scales_with == "n_atoms": - return state.n_atoms + return state.n_atoms_per_system.tolist() if memory_scales_with == "n_atoms_x_density": - if all(state.pbc): - volume = torch.abs(torch.linalg.det(state.cell[0])) / 1000 - else: - bbox = state.positions.max(dim=0).values - state.positions.min(dim=0).values - # add 2 A in non-periodic directions to account for 2D systems and slabs - for i, periodic in enumerate(state.pbc): - if not periodic: - bbox[i] += 2.0 - volume = bbox.prod() / 1000 # convert A^3 to nm^3 - number_density = state.n_atoms / volume.item() - return state.n_atoms * number_density + if state.n_systems > 1 and state.pbc.all().item(): # vectorized path + n_atoms = state.n_atoms_per_system.to(state.volume.dtype) + volume = torch.abs(state.volume) / 1000 # A^3 -> nm^3 + return torch.where(volume > 0, n_atoms * n_atoms / volume, n_atoms).tolist() + # per-system path (non-periodic or single system) + scalers = [] + for s in state.split(): + if all(s.pbc): + volume = torch.abs(torch.linalg.det(s.cell[0])) / 1000 + else: + bbox = s.positions.max(dim=0).values - s.positions.min(dim=0).values + for i, periodic in enumerate(s.pbc): + if not periodic: + bbox[i] += 2.0 + volume = bbox.prod() / 1000 + scalers.append(s.n_atoms * (s.n_atoms / volume.item())) + return scalers raise ValueError( f"Invalid metric: {memory_scales_with}, must be one of {get_args(MemoryScaling)}" ) @@ -397,7 +380,7 @@ def estimate_max_memory_scaler( state_list (list[SimState]): States to test, each with shape information specific to the SimState instance. metric_values (list[float]): Corresponding metric values for each state, - as calculated by calculate_memory_scaler(). + as calculated by calculate_memory_scalers(). **kwargs: Additional keyword arguments passed to determine_max_batch_size. Returns: @@ -406,7 +389,7 @@ def estimate_max_memory_scaler( Example:: # Calculate metrics for a set of states - metrics = [calculate_memory_scaler(state) for state in states] + metrics = [calculate_memory_scalers(state) for state in states] # Estimate maximum safe metric value max_metric = estimate_max_memory_scaler(model, states, metrics) @@ -556,11 +539,11 @@ def load_states(self, states: T | Sequence[T]) -> float: This method resets the current state bin index, so any ongoing iteration will be restarted when this method is called. """ - self.state_slices = states.split() if isinstance(states, SimState) else states - self.memory_scalers = [ - calculate_memory_scaler(state_slice, self.memory_scales_with) - for state_slice in self.state_slices - ] + batched = ( + states if isinstance(states, SimState) else ts.concatenate_states(states) + ) + self.memory_scalers = calculate_memory_scalers(batched, self.memory_scales_with) + self.state_slices = batched.split() if not self.max_memory_scaler: self.max_memory_scaler = estimate_max_memory_scaler( self.state_slices, @@ -589,9 +572,8 @@ def load_states(self, states: T | Sequence[T]) -> float: ) # list[dict[original_index: int, memory_scale:float]] # Convert to list of lists of indices self.index_bins = [list(batch.keys()) for batch in self.index_bins] - self.batched_states = [] - for index_bin in self.index_bins: - self.batched_states.append([self.state_slices[idx] for idx in index_bin]) + # Build batches: one sliced state per bin + self.batched_states = [[batched[index_bin]] for index_bin in self.index_bins] self.current_state_bin = 0 return self.max_memory_scaler @@ -846,7 +828,7 @@ def load_states(self, states: Sequence[T] | Iterator[T] | T) -> None: """ if isinstance(states, SimState): states = states.split() - if isinstance(states, list | tuple): + if not isinstance(states, Iterator): states = iter(states) self.states_iterator = states @@ -875,7 +857,7 @@ def _get_next_states(self) -> list[T]: new_idx: list[int] = [] new_states: list[T] = [] for state in self.states_iterator: - metric = calculate_memory_scaler(state, self.memory_scales_with) + metric = calculate_memory_scalers(state, self.memory_scales_with)[0] if metric > self.max_memory_scaler: raise ValueError( f"State {metric=} is greater than max_metric {self.max_memory_scaler}" @@ -931,7 +913,7 @@ def _get_first_batch(self) -> T: # we need to sample a state and use it to estimate the max metric # for the first batch first_state = next(self.states_iterator) - first_metric = calculate_memory_scaler(first_state, self.memory_scales_with) + first_metric = calculate_memory_scalers(first_state, self.memory_scales_with)[0] self.current_scalers += [first_metric] self.current_idx += [0] self.iteration_count.append(0) # Initialize attempt counter for first state diff --git a/torch_sim/state.py b/torch_sim/state.py index 063e9b585..4b882da0c 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -137,8 +137,10 @@ def __post_init__(self) -> None: # noqa: C901 self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) + n_systems_val = 1 else: # assert that system indices are unique consecutive integers _, counts = torch.unique_consecutive(initial_system_idx, return_counts=True) + n_systems_val = len(counts) if not torch.all(counts == torch.bincount(initial_system_idx)): raise ValueError("System indices must be unique consecutive integers") @@ -146,15 +148,13 @@ def __post_init__(self) -> None: # noqa: C901 validate_constraints(self.constraints, state=self) if self.charge is None: - self.charge = torch.zeros( - self.n_systems, device=self.device, dtype=self.dtype - ) - elif self.charge.shape[0] != self.n_systems: - raise ValueError(f"Charge must have shape (n_systems={self.n_systems},)") + self.charge = torch.zeros(n_systems_val, device=self.device, dtype=self.dtype) + elif self.charge.shape[0] != n_systems_val: + raise ValueError(f"Charge must have shape (n_systems={n_systems_val},)") if self.spin is None: - self.spin = torch.zeros(self.n_systems, device=self.device, dtype=self.dtype) - elif self.spin.shape[0] != self.n_systems: - raise ValueError(f"Spin must have shape (n_systems={self.n_systems},)") + self.spin = torch.zeros(n_systems_val, device=self.device, dtype=self.dtype) + elif self.spin.shape[0] != n_systems_val: + raise ValueError(f"Spin must have shape (n_systems={n_systems_val},)") if self.cell.ndim != 3 and initial_system_idx is None: self.cell = self.cell.unsqueeze(0) @@ -162,9 +162,10 @@ def __post_init__(self) -> None: # noqa: C901 if self.cell.shape[-2:] != (3, 3): raise ValueError("Cell must have shape (n_systems, 3, 3)") - if self.cell.shape[0] != self.n_systems: + if self.cell.shape[0] != n_systems_val: raise ValueError( - f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" + f"Cell must have shape (n_systems={n_systems_val}, 3, 3), " + f"got {self.cell.shape}" ) # if devices aren't all the same, raise an error, in a clean way @@ -450,14 +451,14 @@ def to_phonopy(self) -> list["PhonopyAtoms"]: """ return ts.io.state_to_phonopy(self) - def split(self) -> list[Self]: - """Split the SimState into a list of single-system SimStates. + def split(self) -> Sequence[Self]: + """Split the SimState into a sequence of single-system SimStates (O(1)). - Divides the current state into separate states, each containing a single system, - preserving all properties appropriately for each system. + Each single-system state is created on first access (index or iteration), + so the call itself is O(1). Use like a list: len(s), s[i], for x in s. Returns: - list[SimState]: A list of SimState objects, one per system + Sequence[SimState]: A sequence of SimState objects, one per system """ return _split_state(self) @@ -664,6 +665,58 @@ def deform_grad(self) -> torch.Tensor: return self._deform_grad(self.reference_row_vector_cell, self.row_vector_cell) +def _split_state[T: SimState](state: T) -> Sequence[T]: + """Return a lazy Sequence view of state split into single-system states. + + Each single-system state is created on first access, so the call is O(1). + + Args: + state: The SimState to split. + + Returns: + A Sequence of SimState objects, one per system. + """ + cumsum = torch.cat( + (state.n_atoms_per_system.new_zeros(1), state.n_atoms_per_system.cumsum(0)) + ) + n_systems = state.n_systems + + def get_single(idx: int) -> T: + start, end = int(cumsum[idx]), int(cumsum[idx + 1]) + attrs: dict[str, Any] = { + "system_idx": torch.zeros( + end - start, device=state.device, dtype=torch.int64 + ), + **dict(get_attrs_for_scope(state, "global")), + } + for name, val in get_attrs_for_scope(state, "per-atom"): + if name != "system_idx": + attrs[name] = val[start:end] + for name, val in get_attrs_for_scope(state, "per-system"): + attrs[name] = val[idx : idx + 1] if isinstance(val, torch.Tensor) else val + atom_idx = torch.arange(start, end, device=state.device) + attrs["_constraints"] = [ + c + for con in state.constraints + if (c := con.select_sub_constraint(atom_idx, idx)) + ] + return type(state)(**attrs) + + def _len(_: object) -> int: + return n_systems + + def _getitem(_: object, idx: int | slice) -> T | list[T]: + if isinstance(idx, slice): + return [get_single(i) for i in range(*idx.indices(n_systems))] + if idx < 0: + idx += n_systems + if not 0 <= idx < n_systems: + raise IndexError(f"index {idx} out of range [0, {n_systems})") + return get_single(idx) + + return type("SplitSeq", (Sequence,), {"__len__": _len, "__getitem__": _getitem})() + + def _normalize_system_indices( system_indices: int | Sequence[int] | slice | torch.Tensor, n_systems: int, @@ -833,76 +886,6 @@ def _filter_attrs_by_mask( return filtered_attrs -def _split_state[T: SimState](state: T) -> list[T]: - """Split a SimState into a list of states, each containing a single system. - - Divides a multi-system state into individual single-system states, preserving - appropriate properties for each system. - - Args: - state (SimState): The SimState to split - - Returns: - list[SimState]: A list of SimState objects, each containing a single - system - """ - system_sizes = state.n_atoms_per_system.tolist() - - split_per_atom = {} - for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): - if attr_name != "system_idx": - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) - - split_per_system = {} - for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): - if isinstance(attr_value, torch.Tensor): - split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) - else: # Non-tensor attributes are replicated for each split - split_per_system[attr_name] = [attr_value] * state.n_systems - - global_attrs = dict(get_attrs_for_scope(state, "global")) - - # Create a state for each system - states: list[T] = [] - n_systems = len(system_sizes) - zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64) - cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0))) - for sys_idx in range(n_systems): - # Build per-system attributes (padded attributes stay padded for consistency) - per_system_dict = { - attr_name: split_per_system[attr_name][sys_idx] - for attr_name in split_per_system - } - - system_attrs = { - # Create a system tensor with all zeros for this system - "system_idx": torch.zeros( - system_sizes[sys_idx], device=state.device, dtype=torch.int64 - ), - # Add the split per-atom attributes - **{ - attr_name: split_per_atom[attr_name][sys_idx] - for attr_name in split_per_atom - }, - # Add the split per-system attributes (with unpadding applied) - **per_system_dict, - # Add the global attributes - **global_attrs, - } - - atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1]) - new_constraints = [ - new_constraint - for constraint in state.constraints - if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx)) - ] - - system_attrs["_constraints"] = new_constraints - states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type] - - return states - - def _pop_states[T: SimState]( state: T, pop_indices: list[int] | torch.Tensor ) -> tuple[T, list[T]]: @@ -945,7 +928,7 @@ def _pop_states[T: SimState]( # Create and split the pop state pop_state: T = type(state)(**pop_attrs) # type: ignore[assignment] - pop_states = _split_state(pop_state) + pop_states = list(pop_state.split()) return keep_state, pop_states @@ -954,12 +937,13 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor """Slice a substate from the SimState containing only the specified system indices. Creates a new SimState containing only the specified systems, preserving - all relevant properties. + the requested order. E.g., system_indices=[3, 1, 4] results in original + systems 3, 1, 4 becoming new systems 0, 1, 2. Args: state (SimState): The state to slice system_indices (list[int] | torch.Tensor): System indices to include in the - sliced state + sliced state (order preserved in the result) Returns: SimState: A new SimState object containing only the specified systems @@ -975,15 +959,45 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor if len(system_indices) == 0: raise ValueError("system_indices cannot be empty") - # Create masks for the atoms and systems to include - system_range = torch.arange(state.n_systems, device=state.device) - system_mask = torch.isin(system_range, system_indices) - atom_mask = torch.isin(state.system_idx, system_indices) + # Build atom indices in requested order (preserves system_indices order) + system_indices = system_indices.reshape(-1) + cumsum = torch.cat( + (state.n_atoms_per_system.new_zeros(1), state.n_atoms_per_system.cumsum(0)) + ) + atom_indices = torch.cat( + [ + torch.arange(cumsum[i].item(), cumsum[i + 1].item(), device=state.device) + for i in system_indices + ] + ) + + # Create masks for constraint selection + atom_mask = torch.zeros(state.n_atoms, dtype=torch.bool, device=state.device) + atom_mask[atom_indices] = True + system_mask = torch.zeros(state.n_systems, dtype=torch.bool, device=state.device) + system_mask[system_indices] = True - # Filter attributes - filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask) + # Build inverse map for system_idx remapping + inv = torch.empty( + system_indices.max().item() + 1, device=state.device, dtype=torch.long + ) + inv[system_indices] = torch.arange(len(system_indices), device=state.device) - # Create the sliced state + # Filter attributes preserving requested order + filtered_attrs = dict(get_attrs_for_scope(state, "global")) + filtered_attrs["_constraints"] = [ + c + for con in copy.deepcopy(state.constraints) + if (c := con.select_constraint(atom_mask, system_mask)) + ] + for name, val in get_attrs_for_scope(state, "per-atom"): + filtered_attrs[name] = ( + inv[val[atom_indices]] if name == "system_idx" else val[atom_indices] + ) + for name, val in get_attrs_for_scope(state, "per-system"): + filtered_attrs[name] = ( + val[system_indices] if isinstance(val, torch.Tensor) else val + ) return type(state)(**filtered_attrs) # type: ignore[invalid-return-type] @@ -1134,8 +1148,8 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 def initialize_state( system: StateLike, - device: torch.device | None = None, - dtype: torch.dtype | None = None, + device: torch.device, + dtype: torch.dtype, ) -> SimState: """Initialize state tensors from a atomistic system representation.