
Accelerated ts.optimize by batching Frechet Cell Filter #439

Open
falletta wants to merge 9 commits into TorchSim:main from falletta:speedup_relax

Conversation

@falletta
Contributor

falletta commented Feb 3, 2026

Update Summary


1. torch_sim/math.py

Removed unbatched/legacy code:

  • expm_frechet_block_enlarge (helper function for block enlargement method)
  • _diff_pade3, _diff_pade5, _diff_pade7, _diff_pade9 (Padé approximation helpers)
  • expm_frechet_algo_64 (original algorithm implementation)
  • matrix_exp (custom matrix exponential function)
  • vec, expm_frechet_kronform (Kronecker form helpers)
  • expm_cond (condition number estimation)
  • class expm (autograd Function class)
  • _is_valid_matrix, _determine_eigenvalue_case (unbatched helpers)

Refactored expm_frechet:

  • Now optimized specifically for batched 3x3 matrices (common case for cell operations)
  • Handles both (B, 3, 3) batched input and (3, 3) single matrix input (auto-adds batch dim)
  • Removed method parameter (was "SPS" or "blockEnlarge")
  • Inlined the algorithm directly instead of calling helper functions
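For reference, the quantity being computed — the Fréchet derivative L(A, E) of the matrix exponential — can be cross-checked with the block-enlargement identity that the removed blockEnlarge path used. A minimal batched sketch (illustrative only, not the PR's Padé-based implementation):

```python
import torch

def expm_frechet_block(A: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
    """Frechet derivative of expm via the block-enlargement identity:
    expm([[A, E], [0, A]]) has L(A, E) in its upper-right block.
    Accepts batched (B, 3, 3) input; a single (3, 3) matrix gets a
    batch dimension added, mirroring the refactored behavior."""
    if A.dim() == 2:  # auto-add batch dim for (3, 3) input
        A, E = A.unsqueeze(0), E.unsqueeze(0)
    b, n, _ = A.shape
    M = torch.zeros(b, 2 * n, 2 * n, dtype=A.dtype, device=A.device)
    M[:, :n, :n] = A
    M[:, :n, n:] = E
    M[:, n:, n:] = A
    return torch.linalg.matrix_exp(M)[:, :n, n:]
```

This is numerically exact but does 6x6 exponentials, which is why a specialized 3x3 Padé path can be faster; it is still handy as a test oracle.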

Refactored matrix_log_33:

  • Added _ensure_batched, _determine_matrix_log_cases, _process_matrix_log_case helpers
  • Made the matrix log computation work in batched mode
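The batched round-trip property that the new tests verify can be sketched for the simplest (symmetric positive-definite) case via eigendecomposition; this is a simplified stand-in, since the actual matrix_log_33 also handles complex-eigenvalue cases:

```python
import torch

def matrix_log_spd(A: torch.Tensor) -> torch.Tensor:
    """Batched matrix logarithm for symmetric positive-definite (B, 3, 3)
    matrices via eigendecomposition. A simplified illustration of the
    batched mode; the real matrix_log_33 also treats complex eigenvalues."""
    w, V = torch.linalg.eigh(A)  # real eigenvalues, orthonormal eigenvectors
    return V @ torch.diag_embed(torch.log(w)) @ V.transpose(-1, -2)
```

Round-tripping through torch.linalg.matrix_exp should recover the input, which is the verification strategy described in the added TestLogM33 test.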

2. torch_sim/optimizers/cell_filters.py

Vectorized compute_cell_forces:

  • Before: used nested loops over systems and directions (9 iterations per system)
  • After: uses batched matrix operations across all systems and all 9 directions simultaneously
  • Key optimization: expm_frechet(A_batch, E_batch) is now called once with all n_systems * 9 matrices batched together
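The batching pattern can be sketched as follows (variable names are illustrative, not the actual cell_filters.py code): the 9 unit-direction matrices E_ij are tiled across all systems so that one batched call replaces the former nested loops.

```python
import torch

n_systems = 4
# one cell-shaped matrix per system (illustrative random data)
A = torch.randn(n_systems, 3, 3, dtype=torch.float64)

# all 9 direction matrices: E_dirs[k] has a single 1 in entry (k // 3, k % 3)
E_dirs = torch.eye(9, dtype=torch.float64).reshape(9, 3, 3)

# pair every system with every direction: (n_systems * 9, 3, 3) each
A_batch = A.repeat_interleave(9, dim=0)   # A_0 x9, A_1 x9, ...
E_batch = E_dirs.repeat(n_systems, 1, 1)  # E_00..E_22 repeated per system
# a single batched expm_frechet(A_batch, E_batch) call now covers all
# systems and all 9 directions at once
```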

3. tests/test_math.py

Refactored tests:

  • TestExpmFrechet: test_expm_frechet, test_small_norm_expm_frechet, test_fuzz
  • TestExpmFrechetTorch: test_expm_frechet, test_fuzz

All updated to use 3x3 matrices and simplified by removing method parameter testing. Fuzz tests streamlined with fewer iterations.

Removed tests:

  • test_problematic_matrix, test_medium_matrix (both numpy and torch versions)
  • TestExpmFrechetTorchGrad class

Tests for comparing computation methods and large matrix performance no longer apply to the 3x3-specialized implementation.

Added tests:

  • TestExpmFrechet.test_large_norm_matrices - Tests scaling behavior for larger norm matrices
  • TestLogM33.test_batched_positive_definite - Tests batched matrix logarithm with round-trip verification
  • TestFrechetCellFilterIntegration - Integration tests for the cell filter pipeline
  • test_wrap_positions_* - Tests for the new wrap_positions property

Results

The figure below shows the speedup achieved for a 10-step atomic relaxation. The test is performed on an 8-atom cubic supercell of MgO using the mace-mpa model. Prior results are shown in blue, while new results are shown in red. The speedup is calculated as speedup (%) = (baseline_time / current_time − 1) × 100. We observe a speedup of up to 564% for large batches.
[Screenshot: speedup benchmark figure, 2026-02-03]
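Spelled out as a helper (the timings in the example are illustrative, not the measured values):

```python
def speedup_pct(baseline_time: float, current_time: float) -> float:
    """speedup (%) = (baseline_time / current_time - 1) * 100."""
    return (baseline_time / current_time - 1.0) * 100.0

# e.g. a relaxation that drops from 6.64 s to 1.00 s is a 564% speedup,
# and equal timings give 0% (no change)
```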

Co-authored-by: Cursor <cursoragent@cursor.com>
@orionarcher
Collaborator

Could we get some tests verifying identical numerical behavior of the old and new versions? Can be deleted before merging when we get rid of the unbatched version.

@falletta
Contributor Author

falletta commented Feb 4, 2026

@orionarcher I added test_math_frechet.py to compare the batched code against the SciPy implementations. Please have a look—happy to revise it as needed, and if it looks good, we can include it directly in test_math.py.
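A cross-check against SciPy along these lines (a sketch of the idea, not the actual test_math_frechet.py contents; the torch side here uses the block-enlargement identity as a stand-in for the batched implementation):

```python
import numpy as np
import torch
from scipy.linalg import expm_frechet  # SciPy reference implementation

def expm_frechet_torch(A: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
    """Batched Frechet derivative via expm([[A, E], [0, A]]) (stand-in)."""
    b, n = A.shape[0], A.shape[-1]
    M = torch.zeros(b, 2 * n, 2 * n, dtype=A.dtype)
    M[:, :n, :n] = A
    M[:, :n, n:] = E
    M[:, n:, n:] = A
    return torch.linalg.matrix_exp(M)[:, :n, n:]

rng = np.random.default_rng(0)
A_np = rng.standard_normal((3, 3, 3))
E_np = rng.standard_normal((3, 3, 3))
# one batched torch call vs. per-matrix SciPy calls
L = expm_frechet_torch(torch.from_numpy(A_np), torch.from_numpy(E_np))
```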

num_tol = 1e-16 if dtype == torch.float64 else 1e-8
batched = T.dim() == 3

if batched:
Collaborator

Why support both batched and unbatched versions?

Contributor Author

I have now removed all unbatched code.

class TestExpmFrechet:
"""Tests for expm_frechet against scipy.linalg.expm_frechet."""

def test_small_matrix(self):
Collaborator

Are we testing the batched or unbatched versions here?

Contributor Author

I incorporated the batched tests only in test_math.py

@thomasloux
Collaborator

Haven't looked carefully at the implementation, but I would potentially support having two separate functions for the batched (B, 3, 3) and unbatched (3, 3) algorithms. This would also prevent graph breaks in the future and be easier to read, and in practice a state.cell is always (B, 3, 3), potentially with B=1, so we would always use the batched version anyway.

@falletta
Contributor Author

falletta commented Feb 5, 2026

@orionarcher I removed all unbatched and unused code while preserving the new performance speedups. Please see the PR description for a detailed list of changes.

@thomasloux It's a good point, but for now it's probably better to keep things clean and stick to the batched implementation only, which also lets us remove quite a few lines of dead code.



Development

Successfully merging this pull request may close these issues.

Implement batched versions of math operations used in Frechet and all treatment of complex eigs in logM

4 participants