diff --git a/tests/test_math.py b/tests/test_math.py index 21f42f11..f2f527a5 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -3,8 +3,6 @@ # ruff: noqa: SLF001 -import itertools - import numpy as np import scipy import torch @@ -18,119 +16,70 @@ class TestExpmFrechet: - """Test suite for expm_frechet using numpy arrays converted to torch tensors.""" + """Test suite for expm_frechet for 3x3 matrices.""" def test_expm_frechet(self): """Test basic functionality of expm_frechet against scipy implementation.""" - M = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], dtype=np.float64 - ) - A = np.array([[1, 2], [5, 6]], dtype=np.float64) - E = np.array([[3, 4], [7, 8]], dtype=np.float64) + A = np.array([[1, 2, 0], [5, 6, 0], [0, 0, 1]], dtype=np.float64) + E = np.array([[3, 4, 0], [7, 8, 0], [0, 0, 0]], dtype=np.float64) expected_expm = scipy.linalg.expm(A) - expected_frechet = scipy.linalg.expm(M)[:2, 2:] + expected_frechet = scipy.linalg.expm_frechet(A, E)[1] A = torch.from_numpy(A).to(device=device) E = torch.from_numpy(E).to(device=device) - for kwargs in ({}, {"method": "SPS"}, {"method": "blockEnlarge"}): - # Convert it to numpy arrays before passing it to the function - observed_expm, observed_frechet = fm.expm_frechet(A, E, **kwargs) - assert_allclose(expected_expm, observed_expm.cpu().numpy()) - assert_allclose(expected_frechet, observed_frechet.cpu().numpy()) + observed_expm, observed_frechet = fm.expm_frechet(A, E) + assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14) + assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14) def test_small_norm_expm_frechet(self): - """Test matrices with a range of norms for better coverage.""" - M_original = np.array( - [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], dtype=np.float64 - ) - A_original = np.array([[1, 2], [5, 6]], dtype=np.float64) - E_original = np.array([[3, 4], [7, 8]], dtype=np.float64) - A_original_norm_1 = scipy.linalg.norm(A_original, 1) - selected_m_list = [1, 3, 5, 7, 9, 11, 13, 15] - - m_neighbor_pairs = itertools.pairwise(selected_m_list) - for ma, mb in m_neighbor_pairs: - ell_a = scipy.linalg._expm_frechet.ell_table_61[ma] - ell_b = scipy.linalg._expm_frechet.ell_table_61[mb] - target_norm_1 = 0.5 * (ell_a + ell_b) - scale = target_norm_1 / A_original_norm_1 - M = scale * M_original - A = scale * A_original - E = scale * E_original - expected_expm = scipy.linalg.expm(A) - expected_frechet = scipy.linalg.expm(M)[:2, 2:] - A = torch.from_numpy(A).to(device=device, dtype=DTYPE) - E = torch.from_numpy(E).to(device=device, dtype=DTYPE) - # Convert it to numpy arrays before passing it to the function - observed_expm, observed_frechet = fm.expm_frechet(A, E) - assert_allclose(expected_expm, observed_expm.cpu().numpy()) - assert_allclose(expected_frechet, observed_frechet.cpu().numpy()) + """Test matrices with small norms.""" + A = np.array([[0.1, 0.2, 0], [0.5, 0.6, 0], [0, 0, 0.1]], dtype=np.float64) + E = np.array([[0.3, 0.4, 0], [0.7, 0.8, 0], [0, 0, 0]], dtype=np.float64) + expected_expm = scipy.linalg.expm(A) + expected_frechet = scipy.linalg.expm_frechet(A, E)[1] + + A = torch.from_numpy(A).to(device=device, dtype=DTYPE) + E = torch.from_numpy(E).to(device=device, dtype=DTYPE) + observed_expm, observed_frechet = fm.expm_frechet(A, E) + assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14) + assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14) def test_fuzz(self): - """Test with a variety of random 
inputs to ensure robustness.""" + """Test with a variety of random 3x3 inputs to ensure robustness.""" rng = np.random.default_rng(1726500908359153) - # try a bunch of crazy inputs - rfuncs = ( - np.random.uniform, - np.random.normal, - np.random.standard_cauchy, - np.random.exponential, - ) - ntests = 100 + ntests = 20 for _ in range(ntests): - rfunc = rfuncs[rng.choice(4)] target_norm_1 = rng.exponential() - n = rng.integers(2, 16) - A_original = rfunc(size=(n, n)) - E_original = rfunc(size=(n, n)) + A_original = rng.standard_normal((3, 3)) + E_original = rng.standard_normal((3, 3)) A_original_norm_1 = scipy.linalg.norm(A_original, 1) scale = target_norm_1 / A_original_norm_1 A = scale * A_original E = scale * E_original - M = np.vstack([np.hstack([A, E]), np.hstack([np.zeros_like(A), A])]) expected_expm = scipy.linalg.expm(A) - expected_frechet = scipy.linalg.expm(M)[:n, n:] + expected_frechet = scipy.linalg.expm_frechet(A, E)[1] A = torch.from_numpy(A).to(device=device, dtype=DTYPE) E = torch.from_numpy(E).to(device=device, dtype=DTYPE) - # Convert it to numpy arrays before passing it to the function observed_expm, observed_frechet = fm.expm_frechet(A, E) assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=5e-8) assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-7) - def test_problematic_matrix(self): - """Test a specific matrix that previously uncovered a bug.""" + def test_large_norm_matrices(self): + """Test expm_frechet with larger norm matrices requiring scaling.""" A = np.array( - [[1.50591997, 1.93537998], [0.41203263, 0.23443516]], dtype=np.float64 + [[1.5, 0.8, 0.3], [0.6, 1.2, 0.5], [0.4, 0.7, 1.8]], dtype=np.float64 ) E = np.array( - [[1.87864034, 2.07055038], [1.34102727, 0.67341123]], dtype=np.float64 - ) - A = torch.from_numpy(A).to(device=device, dtype=DTYPE) - E = torch.from_numpy(E).to(device=device, dtype=DTYPE) - # Convert it to numpy arrays before passing it to the function - sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( - A, E, method="blockEnlarge" + [[0.1, 0.2, 0.1], [0.2, 0.15, 0.1], [0.1, 0.1, 0.2]], dtype=np.float64 ) - assert_allclose(sps_expm.cpu().numpy(), blockEnlarge_expm.cpu().numpy()) - assert_allclose(sps_frechet.cpu().numpy(), blockEnlarge_frechet.cpu().numpy()) - - def test_medium_matrix(self): - """Test with a medium-sized matrix to compare performance between methods.""" - n = 1000 - rng = np.random.default_rng() - A = rng.exponential(size=(n, n)) - E = rng.exponential(size=(n, n)) + expected_expm = scipy.linalg.expm(A) + expected_frechet = scipy.linalg.expm_frechet(A, E)[1] A = torch.from_numpy(A).to(device=device, dtype=DTYPE) E = torch.from_numpy(E).to(device=device, dtype=DTYPE) - # Convert it to numpy arrays before passing it to the function - sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( - A, E, method="blockEnlarge" - ) - assert_allclose(sps_expm.cpu().numpy(), blockEnlarge_expm.cpu().numpy()) - assert_allclose(sps_frechet.cpu().numpy(), blockEnlarge_frechet.cpu().numpy()) + observed_expm, observed_frechet = fm.expm_frechet(A, E) + assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-12) + assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-12) class TestExpmFrechetTorch: @@ -138,149 +87,28 @@ class TestExpmFrechetTorch: def test_expm_frechet(self): """Test basic functionality of expm_frechet against 
torch.linalg.matrix_exp.""" - M = torch.tensor( - [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], - dtype=DTYPE, - device=device, - ) - A = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device) - E = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device) + A = torch.tensor([[1, 2, 0], [5, 6, 0], [0, 0, 1]], dtype=DTYPE, device=device) + E = torch.tensor([[3, 4, 0], [7, 8, 0], [0, 0, 0]], dtype=DTYPE, device=device) expected_expm = torch.linalg.matrix_exp(A) - expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:] - for kwargs in ({}, {"method": "SPS"}, {"method": "blockEnlarge"}): - observed_expm, observed_frechet = fm.expm_frechet(A, E, **kwargs) - torch.testing.assert_close(expected_expm, observed_expm) - torch.testing.assert_close(expected_frechet, observed_frechet) - - def test_small_norm_expm_frechet(self): - """Test matrices with a range of norms for better coverage using torch tensors.""" - M_original = torch.tensor( - [ - [1, 2, 3, 4], - [5, 6, 7, 8], - [0, 0, 1, 2], - [0, 0, 5, 6], - ], - dtype=DTYPE, - device=device, - ) - A_original = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device) - E_original = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device) - A_original_norm_1 = torch.linalg.norm(A_original, 1) - selected_m_list = [1, 3, 5, 7, 9, 11, 13, 15] - m_neighbor_pairs = itertools.pairwise(selected_m_list) - for ma, mb in m_neighbor_pairs: - ell_a = fm.ell_table_61[ma] - ell_b = fm.ell_table_61[mb] - target_norm_1 = 0.5 * (ell_a + ell_b) - scale = target_norm_1 / A_original_norm_1 - M = scale * M_original - A = scale * A_original - E = scale * E_original - expected_expm = torch.linalg.matrix_exp(A) - expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:] - observed_expm, observed_frechet = fm.expm_frechet(A, E) - torch.testing.assert_close(expected_expm, observed_expm) - torch.testing.assert_close(expected_frechet, observed_frechet) + observed_expm, _ = fm.expm_frechet(A, E) + torch.testing.assert_close(expected_expm, observed_expm) def test_fuzz(self): - """Test with a variety of random inputs using torch tensors.""" + """Test with a variety of random 3x3 inputs using torch tensors.""" rng = np.random.default_rng(1726500908359153) - # try a bunch of crazy inputs - # Convert random functions to tensor-generating functions - tensor_rfuncs = ( - lambda size, device="cpu": torch.tensor( - rng.uniform(size=size), device=device - ), - lambda size, device="cpu": torch.tensor(rng.normal(size=size), device=device), - lambda size, device="cpu": torch.tensor( - rng.standard_cauchy(size=size), device=device - ), - lambda size, device="cpu": torch.tensor( - rng.exponential(size=size), device=device - ), - ) - ntests = 100 + ntests = 20 for _ in range(ntests): - rfunc = tensor_rfuncs[torch.tensor(rng.choice(4))] - target_norm_1 = torch.tensor(rng.exponential()) - n = torch.tensor(rng.integers(2, 16)) - A_original = rfunc(size=(n, n)) - E_original = rfunc(size=(n, n)) + target_norm_1 = rng.exponential() + A_original = torch.tensor(rng.standard_normal((3, 3)), device=device) + E_original = torch.tensor(rng.standard_normal((3, 3)), device=device) A_original_norm_1 = torch.linalg.norm(A_original, 1) scale = target_norm_1 / A_original_norm_1 A = scale * A_original E = scale * E_original - M = torch.vstack( - [torch.hstack([A, E]), torch.hstack([torch.zeros_like(A), A])] - ) expected_expm = torch.linalg.matrix_exp(A) - expected_frechet = torch.linalg.matrix_exp(M)[:n, n:] - observed_expm, observed_frechet = fm.expm_frechet(A, E) + observed_expm, _ = 
fm.expm_frechet(A, E) torch.testing.assert_close(expected_expm, observed_expm, atol=5e-8, rtol=1e-5) - torch.testing.assert_close( - expected_frechet, observed_frechet, atol=1e-7, rtol=1e-5 - ) - - def test_problematic_matrix(self): - """Test a specific matrix that previously uncovered a bug using torch tensors.""" - A = torch.tensor( - [[1.50591997, 1.93537998], [0.41203263, 0.23443516]], - dtype=DTYPE, - device=device, - ) - E = torch.tensor( - [[1.87864034, 2.07055038], [1.34102727, 0.67341123]], - dtype=DTYPE, - device=device, - ) - sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( - A, E, method="blockEnlarge" - ) - torch.testing.assert_close(sps_expm, blockEnlarge_expm) - torch.testing.assert_close(sps_frechet, blockEnlarge_frechet) - - def test_medium_matrix(self): - """Test with a medium-sized matrix to compare performance - between methods using torch tensors. - """ - n = 1000 - rng = np.random.default_rng() - A = torch.tensor(rng.exponential(size=(n, n))) - E = torch.tensor(rng.exponential(size=(n, n))) - - sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS") - blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet( - A, E, method="blockEnlarge" - ) - torch.testing.assert_close(sps_expm, blockEnlarge_expm) - torch.testing.assert_close(sps_frechet, blockEnlarge_frechet) - - -class TestExpmFrechetTorchGrad: - """Test suite for gradient computation with expm and its Frechet derivative.""" - - def test_expm_frechet(self): - """Test gradient computation for matrix exponential and its Frechet derivative.""" - M = torch.tensor( - [[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], - dtype=DTYPE, - device=device, - ) - A = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device) - E = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device) - expected_expm = torch.linalg.matrix_exp(A) - expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:] - # expm will use the SPS method as default - observed_expm = fm.expm.apply(A) - torch.testing.assert_close(expected_expm, observed_expm) - # Compute the Frechet derivative in the direction of grad_output - A.requires_grad = True - observed_expm = fm.expm.apply(A) - (observed_frechet,) = torch.autograd.grad(observed_expm, A, E, retain_graph=True) - torch.testing.assert_close(expected_frechet, observed_frechet) class TestLogM33: @@ -469,3 +297,53 @@ def test_nearly_degenerate(self): torch.testing.assert_close( M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device) ) + + def test_batched_positive_definite(self): + """Test batched matrix logarithm with positive definite matrices.""" + batch_size = 3 + rng = np.random.default_rng(42) + L = rng.standard_normal((batch_size, 3, 3)) + M_np = np.array([L[i] @ L[i].T + 0.5 * np.eye(3) for i in range(batch_size)]) + M_torch = torch.tensor(M_np, dtype=torch.float64) + + log_torch = fm.matrix_log_33(M_torch) + + for i in range(batch_size): + log_scipy = scipy.linalg.logm(M_np[i]).real + assert_allclose(log_torch[i].numpy(), log_scipy, atol=1e-12) + # Verify round-trip: exp(log(M)) = M + M_recovered = torch.matrix_exp(log_torch[i]) + assert_allclose(M_recovered.numpy(), M_np[i], atol=1e-10) + + +class TestFrechetCellFilterIntegration: + """Integration tests for the Frechet cell filter pipeline.""" + + def test_frechet_derivatives_vs_scipy(self): + """Test Frechet derivative computation matches scipy.""" + n_systems = 2 + torch.manual_seed(42) + + # Create small deformations + deform_log = torch.randn(n_systems, 3, 3, 
dtype=torch.float64) * 0.01 + + # Compute Frechet derivatives for all 9 directions + idx_flat = torch.arange(9) + i_idx, j_idx = idx_flat // 3, idx_flat % 3 + directions = torch.zeros((9, 3, 3), dtype=torch.float64) + directions[idx_flat, i_idx, j_idx] = 1.0 + + A_batch = deform_log.unsqueeze(1).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3) + E_batch = directions.unsqueeze(0).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3) + _, frechet_torch = fm.expm_frechet(A_batch, E_batch) + frechet_torch = frechet_torch.reshape(n_systems, 9, 3, 3) + + # Compare with scipy + for sys_idx in range(n_systems): + for dir_idx in range(9): + A_np = deform_log[sys_idx].numpy() + E_np = directions[dir_idx].numpy() + _, frechet_scipy = scipy.linalg.expm_frechet(A_np, E_np) + assert_allclose( + frechet_torch[sys_idx, dir_idx].numpy(), frechet_scipy, atol=1e-12 + ) diff --git a/torch_sim/math.py b/torch_sim/math.py index 737422dc..58e83b1f 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -2,10 +2,9 @@ # ruff: noqa: FBT001, FBT002, RUF002, RUF003 -from typing import Any, Final +from typing import Final import torch -from torch.autograd import Function @torch.jit.script @@ -26,20 +25,20 @@ def torch_divmod(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch. return d, m -def expm_frechet( # noqa: C901 +def expm_frechet( # noqa: PLR0915, C901 A: torch.Tensor, E: torch.Tensor, - method: str | None = None, check_finite: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Frechet derivative of the matrix exponential of A in the direction E. + Optimized for batched 3x3 matrices. Also handles single 3x3 matrices by + auto-adding a batch dimension. + Args: - A: (N, N) array_like. Matrix of which to take the matrix exponential. - E: (N, N) array_like. Matrix direction in which to take the Frechet derivative. - method: str, optional. Choice of algorithm. Should be one of - - `SPS` (default) - - `blockEnlarge` + A: (B, 3, 3) or (3, 3) tensor. Matrix of which to take the matrix exponential. + E: (B, 3, 3) or (3, 3) tensor. Matrix direction in which to take the Frechet + derivative. Must have same shape as A. check_finite: bool, optional. Whether to check that the input matrix contains only finite numbers. 
Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain @@ -63,50 +62,80 @@ def expm_frechet( # noqa: C901 if not isinstance(E, torch.Tensor): E = torch.tensor(E, dtype=torch.float64) - if A.dim() != 2 or A.shape[0] != A.shape[1]: - raise ValueError("expected A to be a square matrix") - if E.dim() != 2 or E.shape[0] != E.shape[1]: - raise ValueError("expected E to be a square matrix") - if A.shape != E.shape: - raise ValueError("expected A and E to be the same shape") - - if method is None: - method = "SPS" - - if method == "SPS": - expm_A, expm_frechet_AE = expm_frechet_algo_64(A, E) - elif method == "blockEnlarge": - expm_A, expm_frechet_AE = expm_frechet_block_enlarge(A, E) - else: - raise ValueError(f"Unknown {method=}") - - return expm_A, expm_frechet_AE + # Handle unbatched 3x3 input by adding batch dimension + unbatched = A.dim() == 2 + if unbatched: + if A.shape != (3, 3): + raise ValueError("expected A to be (3, 3) or (B, 3, 3)") + if E.shape != (3, 3): + raise ValueError("expected E to be (3, 3) or (B, 3, 3)") + A = A.unsqueeze(0) + E = E.unsqueeze(0) + + if E.dim() != 3 or A.shape != E.shape or A.shape[1:] != (3, 3): + raise ValueError("expected A, E to be (B, 3, 3) with same shape") + + batch_size = A.shape[0] + device, dtype = A.device, A.dtype + ident = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, 3, 3) + + A_norm_1 = torch.norm(A, p=1, dim=(-2, -1)) + scale_val = torch.log2( + torch.clamp(A_norm_1.max() / ell_table_61[13], min=1.0, max=2.0**64) + ) + s = max(0, min(int(torch.ceil(scale_val).item()), 64)) + A = A * 2.0**-s + E = E * 2.0**-s + A2 = torch.matmul(A, A) + M2 = torch.matmul(A, E) + torch.matmul(E, A) + A4 = torch.matmul(A2, A2) + M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2) + A6 = torch.matmul(A2, A4) + M6 = torch.matmul(A4, M2) + torch.matmul(M4, A2) -def expm_frechet_block_enlarge( - A: torch.Tensor, E: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Helper function for testing and profiling. 
+ b = ( + 64764752532480000.0, + 32382376266240000.0, + 7771770303897600.0, + 1187353796428800.0, + 129060195264000.0, + 10559470521600.0, + 670442572800.0, + 33522128640.0, + 1323241920.0, + 40840800.0, + 960960.0, + 16380.0, + 182.0, + 1.0, + ) + W1 = b[13] * A6 + b[11] * A4 + b[9] * A2 + W2 = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident + Z1 = b[12] * A6 + b[10] * A4 + b[8] * A2 + Z2 = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident + W = torch.matmul(A6, W1) + W2 + U = torch.matmul(A, W) + V = torch.matmul(A6, Z1) + Z2 + + Lw1 = b[13] * M6 + b[11] * M4 + b[9] * M2 + Lw2 = b[7] * M6 + b[5] * M4 + b[3] * M2 + Lz1 = b[12] * M6 + b[10] * M4 + b[8] * M2 + Lz2 = b[6] * M6 + b[4] * M4 + b[2] * M2 + Lw = torch.matmul(A6, Lw1) + torch.matmul(M6, W1) + Lw2 + Lu = torch.matmul(A, Lw) + torch.matmul(E, W) + Lv = torch.matmul(A6, Lz1) + torch.matmul(M6, Z1) + Lz2 - Args: - A: Input matrix - E: Direction matrix + R = torch.linalg.solve(-U + V, U + V) + L = torch.linalg.solve(-U + V, Lu + Lv + torch.matmul(Lu - Lv, R)) - Returns: - expm_A: Matrix exponential of A - expm_frechet_AE: torch.Tensor - Frechet derivative of the matrix exponential of A in the direction E - """ - n = A.shape[0] - # Create block matrix M = [[A, E], [0, A]] - M = torch.zeros((2 * n, 2 * n), dtype=A.dtype, device=A.device) - M[:n, :n] = A - M[:n, n:] = E - M[n:, n:] = A + for _ in range(s): + L = torch.matmul(R, L) + torch.matmul(L, R) + R = torch.matmul(R, R) - # Use matrix exponential - expm_M = matrix_exp(M) - return expm_M[:n, :n], expm_M[:n, n:] + if unbatched: + return R.squeeze(0), L.squeeze(0) + return R, L # Maximal values ell_m of ||2**-s A|| such that the backward error bound @@ -138,464 +167,21 @@ def expm_frechet_block_enlarge( ) -def _diff_pade3( - A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute Padé approximation of order 3 for matrix exponential and - its Frechet derivative. - - Args: - A: Input matrix - E: Direction matrix - ident: Identity matrix of same shape as A - - Returns: - U, V, Lu, Lv: Components needed for computing the matrix exponential and - its Frechet derivative - """ - b = (120.0, 60.0, 12.0, 1.0) - A2 = torch.matmul(A, A) - M2 = torch.matmul(A, E) + torch.matmul(E, A) - U = torch.matmul(A, b[3] * A2 + b[1] * ident) - V = b[2] * A2 + b[0] * ident - Lu = torch.matmul(A, b[3] * M2) + torch.matmul(E, b[3] * A2 + b[1] * ident) - Lv = b[2] * M2 - return U, V, Lu, Lv - - -def _diff_pade5( - A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute Padé approximation of order 5 for matrix exponential and - its Frechet derivative. 
- - Args: - A: Input matrix - E: Direction matrix - ident: Identity matrix of same shape as A - - Returns: - U, V, Lu, Lv: Components needed for computing the matrix exponential and - its Frechet derivative - """ - b = (30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0) - A2 = torch.matmul(A, A) - M2 = torch.matmul(A, E) + torch.matmul(E, A) - A4 = torch.matmul(A2, A2) - M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2) - U = torch.matmul(A, b[5] * A4 + b[3] * A2 + b[1] * ident) - V = b[4] * A4 + b[2] * A2 + b[0] * ident - Lu = torch.matmul(A, b[5] * M4 + b[3] * M2) + torch.matmul( - E, b[5] * A4 + b[3] * A2 + b[1] * ident - ) - Lv = b[4] * M4 + b[2] * M2 - return U, V, Lu, Lv - - -def _diff_pade7( - A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute Padé approximation of order 7 for matrix exponential and - its Frechet derivative. - - Args: - A: Input matrix - E: Direction matrix - ident: Identity matrix of same shape as A - - Returns: - U, V, Lu, Lv: Components needed for computing the matrix exponential and - its Frechet derivative - """ - b = (17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0) - A2 = torch.matmul(A, A) - M2 = torch.matmul(A, E) + torch.matmul(E, A) - A4 = torch.matmul(A2, A2) - M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2) - A6 = torch.matmul(A2, A4) - M6 = torch.matmul(A4, M2) + torch.matmul(M4, A2) - U = torch.matmul(A, b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident) - V = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident - Lu = torch.matmul(A, b[7] * M6 + b[5] * M4 + b[3] * M2) + torch.matmul( - E, b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident - ) - Lv = b[6] * M6 + b[4] * M4 + b[2] * M2 - return U, V, Lu, Lv - - -def _diff_pade9( - A: torch.Tensor, E: torch.Tensor, ident: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute Padé approximation of order 9 for matrix exponential and - its Frechet derivative. - - Args: - A: Input matrix - E: Direction matrix - ident: Identity matrix of same shape as A - - Returns: - U, V, Lu, Lv: Components needed for computing the matrix exponential and - its Frechet derivative - """ - b = ( - 17643225600.0, - 8821612800.0, - 2075673600.0, - 302702400.0, - 30270240.0, - 2162160.0, - 110880.0, - 3960.0, - 90.0, - 1.0, - ) - A2 = torch.matmul(A, A) - M2 = torch.matmul(A, E) + torch.matmul(E, A) - A4 = torch.matmul(A2, A2) - M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2) - A6 = torch.matmul(A2, A4) - M6 = torch.matmul(A4, M2) + torch.matmul(M4, A2) - A8 = torch.matmul(A4, A4) - M8 = torch.matmul(A4, M4) + torch.matmul(M4, A4) - U = torch.matmul(A, b[9] * A8 + b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident) - V = b[8] * A8 + b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident - Lu = torch.matmul(A, b[9] * M8 + b[7] * M6 + b[5] * M4 + b[3] * M2) + torch.matmul( - E, b[9] * A8 + b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident - ) - Lv = b[8] * M8 + b[6] * M6 + b[4] * M4 + b[2] * M2 - return U, V, Lu, Lv - - -def expm_frechet_algo_64( - A: torch.Tensor, E: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Compute matrix exponential and its Frechet derivative using Algorithm 6.4. - - This implementation follows Al-Mohy and Higham's Algorithm 6.4 from - "Computing the Frechet Derivative of the Matrix Exponential, with an - application to Condition Number Estimation". 
- - Args: - A: Input matrix - E: Direction matrix - - Returns: - R: Matrix exponential of A - L: Frechet derivative of the matrix exponential in the direction E - """ - n = A.shape[0] - s = None - ident = torch.eye(n, dtype=A.dtype, device=A.device) - A_norm_1 = torch.norm(A, p=1) - m_pade_pairs = ( - (3, _diff_pade3), - (5, _diff_pade5), - (7, _diff_pade7), - (9, _diff_pade9), - ) - - for m, pade in m_pade_pairs: - if A_norm_1 <= ell_table_61[m]: - U, V, Lu, Lv = pade(A, E, ident) - s = 0 - break - - if s is None: - # scaling - s = max(0, int(torch.ceil(torch.log2(A_norm_1 / ell_table_61[13])))) - A = A * 2.0**-s - E = E * 2.0**-s - # pade order 13 - A2 = torch.matmul(A, A) - M2 = torch.matmul(A, E) + torch.matmul(E, A) - A4 = torch.matmul(A2, A2) - M4 = torch.matmul(A2, M2) + torch.matmul(M2, A2) - A6 = torch.matmul(A2, A4) - M6 = torch.matmul(A4, M2) + torch.matmul(M4, A2) - b = ( - 64764752532480000.0, - 32382376266240000.0, - 7771770303897600.0, - 1187353796428800.0, - 129060195264000.0, - 10559470521600.0, - 670442572800.0, - 33522128640.0, - 1323241920.0, - 40840800.0, - 960960.0, - 16380.0, - 182.0, - 1.0, - ) - W1 = b[13] * A6 + b[11] * A4 + b[9] * A2 - W2 = b[7] * A6 + b[5] * A4 + b[3] * A2 + b[1] * ident - Z1 = b[12] * A6 + b[10] * A4 + b[8] * A2 - Z2 = b[6] * A6 + b[4] * A4 + b[2] * A2 + b[0] * ident - W = torch.matmul(A6, W1) + W2 - U = torch.matmul(A, W) - V = torch.matmul(A6, Z1) + Z2 - Lw1 = b[13] * M6 + b[11] * M4 + b[9] * M2 - Lw2 = b[7] * M6 + b[5] * M4 + b[3] * M2 - Lz1 = b[12] * M6 + b[10] * M4 + b[8] * M2 - Lz2 = b[6] * M6 + b[4] * M4 + b[2] * M2 - Lw = torch.matmul(A6, Lw1) + torch.matmul(M6, W1) + Lw2 - Lu = torch.matmul(A, Lw) + torch.matmul(E, W) - Lv = torch.matmul(A6, Lz1) + torch.matmul(M6, Z1) + Lz2 - - # Solve the system (-U + V)X = (U + V) for R - R = torch.linalg.solve(-U + V, U + V) - - # Solve the system (-U + V)X = (Lu + Lv + (Lu - Lv)R) for L - L = torch.linalg.solve(-U + V, Lu + Lv + torch.matmul(Lu - Lv, R)) - - # squaring - for _ in range(s): - L = torch.matmul(R, L) + torch.matmul(L, R) - R = torch.matmul(R, R) - - return R, L - - -def matrix_exp(A: torch.Tensor) -> torch.Tensor: - """Compute the matrix exponential of A using PyTorch's matrix_exp. - - Args: - A: Input matrix - - Returns: - torch.Tensor: Matrix exponential of A - """ - return torch.matrix_exp(A) - - -def vec(M: torch.Tensor) -> torch.Tensor: - """Stack columns of M to construct a single vector. - - This is somewhat standard notation in linear algebra. - - Args: - M: Input matrix - - Returns: - torch.Tensor: Output vector - """ - return M.t().reshape(-1) - - -def expm_frechet_kronform( - A: torch.Tensor, method: str | None = None, check_finite: bool = True +def _identity_for_t( + T: torch.Tensor, dtype: torch.dtype, device: torch.device ) -> torch.Tensor: - """Construct the Kronecker form of the Frechet derivative of expm. - - Args: - A: Square matrix tensor with shape (N, N) - method: Optional extra keyword to be passed to expm_frechet - check_finite: Whether to check that the input matrix contains only finite numbers. - Disabling may give a performance gain, but may result in problems - (crashes, non-termination) if the inputs do contain infinities or NaNs. 
- - Returns: - torch.Tensor: Kronecker form of the Frechet derivative of the matrix exponential - with shape (N*N, N*N) - """ - if check_finite and not torch.isfinite(A).all(): - raise ValueError("Matrix A contains non-finite values") - - # Convert input to torch tensor if it isn't already - if not isinstance(A, torch.Tensor): - A = torch.tensor(A, dtype=torch.float64) - - if A.dim() != 2 or A.shape[0] != A.shape[1]: - raise ValueError("expected a square matrix") - - n = A.shape[0] - ident = torch.eye(n, dtype=A.dtype, device=A.device) - cols = [] - - for i in range(n): - for j in range(n): - E = torch.outer(ident[i], ident[j]) - _, F = expm_frechet(A, E, method=method, check_finite=False) - cols.append(vec(F)) - - return torch.stack(cols, dim=1) - - -def expm_cond(A: torch.Tensor, check_finite: bool = True) -> torch.Tensor: - """Relative condition number of the matrix exponential in the Frobenius norm. - - Args: - A: Square input matrix with shape (N, N) - check_finite: Whether to check that the input matrix contains only finite numbers. - Disabling may give a performance gain, but may result in problems - (crashes, non-termination) if the inputs do contain infinities or NaNs. - - Returns: - kappa: The relative condition number of the matrix exponential - in the Frobenius norm - """ - if check_finite and not torch.isfinite(A).all(): - raise ValueError("Matrix A contains non-finite values") - - # Convert input to torch tensor if it isn't already - if not isinstance(A, torch.Tensor): - A = torch.tensor(A, dtype=torch.float64) - - if A.dim() != 2 or A.shape[0] != A.shape[1]: - raise ValueError("expected a square matrix") - - X = matrix_exp(A) - K = expm_frechet_kronform(A, check_finite=False) - - # The following norm choices are deliberate. - # norms of A and X are Frobenius norms, and norm of K is the induced 2-norm. - norm_p = "fro" # codespell:ignore - A_norm = torch.norm(A, p=norm_p) - X_norm = torch.norm(X, p=norm_p) - K_norm = torch.linalg.matrix_norm(K, ord=2) - - return (K_norm * A_norm) / X_norm # kappa - - -class expm(Function): # noqa: N801 - """Compute the matrix exponential of a matrix or batch of matrices.""" - - @staticmethod - def forward(ctx: Any, A: torch.Tensor) -> torch.Tensor: - """Compute the matrix exponential of A. - - Args: - ctx: ctx - A: Input matrix or batch of matrices - - Returns: - Matrix exponential of A - """ - # Save A for backward pass - ctx.save_for_backward(A) - # Use the matrix_exp function we already have - return matrix_exp(A) - - @staticmethod - def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: - """Compute the gradient of matrix exponential. - - Args: - ctx: ctx - grad_output: Gradient with respect to the output - - Returns: - Gradient with respect to the input - """ - # Retrieve saved tensor - (A,) = ctx.saved_tensors - - # Compute the Frechet derivative in the direction of grad_output - _, frechet_deriv = expm_frechet(A, grad_output, method="SPS", check_finite=False) - return frechet_deriv - - -def _is_valid_matrix(T: torch.Tensor, n: int = 3) -> bool: - """Check if T is a valid nxn matrix. - - Args: - T: The matrix to check - n: The expected dimension of the matrix, default=3 - - Returns: - bool: True if T is a valid nxn tensor, False otherwise - """ - return isinstance(T, torch.Tensor) and T.shape == (n, n) - - -def _determine_eigenvalue_case( # noqa: C901 - T: torch.Tensor, eigenvalues: torch.Tensor, num_tol: float = 1e-16 -) -> str: - """Determine the eigenvalue structure case of matrix T. 
- - Args: - T: The 3x3 matrix to analyze - eigenvalues: The eigenvalues of T - num_tol: Numerical tolerance for comparing eigenvalues, default=1e-16 - - Returns: - The case identifier ("case1a", "case1b", etc.) - - Raises: - ValueError: If the eigenvalue structure cannot be determined - """ - # Get unique values and their counts directly with one call - uniq_vals, counts = torch.unique(eigenvalues, return_counts=True) - - # Use np.isclose to group eigenvalues that are numerically close - # We can create a mask for each unique value to see if other values are close to it - if len(uniq_vals) > 1: - # Check if some "unique" values should actually be considered the same - i = 0 - while i < len(uniq_vals): - # Find all values close to the current one - close_mask = torch.isclose(uniq_vals, uniq_vals[i], rtol=0, atol=num_tol) - close_count = torch.sum(close_mask) - - if close_count > 1: # If there are other close values - # Merge them (keep the first one, remove the others) - counts[i] = torch.sum(counts[close_mask]) - uniq_vals = uniq_vals[~(close_mask & torch.arange(len(close_mask)) != i)] - counts = counts[~(close_mask & torch.arange(len(counts)) != i)] - else: - i += 1 - - # Now determine the case based on the number of unique eigenvalues - if len(uniq_vals) == 1: - # Case 1: All eigenvalues are equal (λ, λ, λ) - lambda_val = uniq_vals[0] - Identity = torch.eye(3, dtype=lambda_val.dtype, device=lambda_val.device) - T_minus_lambdaI = T - lambda_val * Identity - - rank1 = torch.linalg.matrix_rank(T_minus_lambdaI) - if rank1 == 0: - return "case1a" # q(T) = (T - λI) - - rank2 = torch.linalg.matrix_rank(T_minus_lambdaI @ T_minus_lambdaI) - if rank2 == 0: - return "case1b" # q(T) = (T - λI)² - - return "case1c" # q(T) = (T - λI)³ - - if len(uniq_vals) == 2: - # Case 2: Two distinct eigenvalues - # The counts array already tells us which eigenvalue is repeated - if counts.max() != 2 or counts.min() != 1: - raise ValueError("Unexpected eigenvalue pattern for Case 2") - - mu = uniq_vals[torch.argmin(counts)] # The non-repeated eigenvalue - lambda_val = uniq_vals[torch.argmax(counts)] # The repeated eigenvalue - - Identity = torch.eye(3, dtype=lambda_val.dtype, device=lambda_val.device) - T_minus_muI = T - mu * Identity - T_minus_lambdaI = T - lambda_val * Identity - - # Check if (T - μI)(T - λI) annihilates T - if torch.allclose( - T_minus_muI @ T_minus_lambdaI @ T, - torch.zeros((3, 3), dtype=lambda_val.dtype, device=lambda_val.device), - ): - return "case2a" # q(T) = (T - λI)(T - μI) - return "case2b" # q(T) = (T - μI)(T - λI)² - - if len(uniq_vals) == 3: - # Case 3: Three distinct eigenvalues (λ, μ, ν) - return "case3" # q(T) = (T - λI)(T - μI)(T - νI) - - raise ValueError("Could not determine eigenvalue structure") + """Return identity (3, 3) or (n, 3, 3) matching T's batch shape.""" + if T.dim() == 3: + n = T.shape[0] + return torch.eye(3, dtype=dtype, device=device).unsqueeze(0).expand(n, -1, -1) + return torch.eye(3, dtype=dtype, device=device) def _matrix_log_case1a(T: torch.Tensor, lambda_val: torch.Tensor) -> torch.Tensor: """Compute log(T) when q(T) = (T - λI). This is the case where T is a scalar multiple of the identity matrix. + T may be (3, 3) or (n, 3, 3); lambda_val scalar or (n, 1, 1). 
Args: T: The matrix whose logarithm is to be computed @@ -604,9 +190,9 @@ def _matrix_log_case1a(T: torch.Tensor, lambda_val: torch.Tensor) -> torch.Tenso Returns: The logarithm of T, which is log(λ)·I """ - n = T.shape[0] - Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device) - return torch.log(lambda_val) * Identity + dtype, device = lambda_val.dtype, lambda_val.device + identity = _identity_for_t(T, dtype, device) + return torch.log(lambda_val) * identity def _matrix_log_case1b( @@ -615,6 +201,7 @@ def _matrix_log_case1b( """Compute log(T) when q(T) = (T - λI)². This is the case where T has a Jordan block of size 2. + T may be (3, 3) or (n, 3, 3); lambda_val scalar or (n, 1, 1). Args: T: The matrix whose logarithm is to be computed @@ -624,16 +211,32 @@ def _matrix_log_case1b( Returns: The logarithm of T """ - n = T.shape[0] - Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device) - T_minus_lambdaI = T - lambda_val * Identity + dtype, device = lambda_val.dtype, lambda_val.device + identity = _identity_for_t(T, dtype, device) + T_minus_lambdaI = T - lambda_val * identity + denom = torch.clamp(lambda_val.abs(), min=num_tol) + scale = torch.where(lambda_val.abs() > 1, lambda_val, denom) + return torch.log(lambda_val) * identity + T_minus_lambdaI / scale - # For numerical stability, scale appropriately - if abs(lambda_val) > 1: - scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val - return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI - # Alternative computation for small lambda - return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) + +def _ensure_batched( + T: torch.Tensor, *eigenvalues: torch.Tensor +) -> tuple[bool, torch.Tensor, tuple[torch.Tensor, ...]]: + """Ensure T and eigenvalues are in batched form for matrix log computation. + + Args: + T: Matrix of shape (3, 3) or (n, 3, 3) + *eigenvalues: Scalar or (n, 1, 1) shaped eigenvalue tensors + + Returns: + Tuple of (unbatched, T, eigenvalues) where unbatched is True if input was + unbatched, T has shape (n, 3, 3), and eigenvalues have shape (n, 1, 1) + """ + unbatched = T.dim() == 2 + if unbatched: + T = T.unsqueeze(0) + eigenvalues = tuple(ev.view(1, 1, 1) for ev in eigenvalues) + return unbatched, T, eigenvalues def _matrix_log_case1c( @@ -642,6 +245,7 @@ def _matrix_log_case1c( """Compute log(T) when q(T) = (T - λI)³. This is the case where T has a Jordan block of size 3. + T may be (3, 3) or (n, 3, 3); lambda_val scalar or (n, 1, 1). 
Args: T: The matrix whose logarithm is to be computed @@ -651,21 +255,17 @@ def _matrix_log_case1c( Returns: The logarithm of T """ - n = T.shape[0] - Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device) - T_minus_lambdaI = T - lambda_val * Identity - - # Compute (T - λI)² with better numerical stability - T_minus_lambdaI_squared = T_minus_lambdaI @ T_minus_lambdaI - - # For numerical stability + unbatched, T, (lambda_val,) = _ensure_batched(T, lambda_val) + dtype, device = lambda_val.dtype, lambda_val.device + identity = _identity_for_t(T, dtype, device) + T_minus_lambdaI = T - lambda_val * identity + T_minus_lambdaI_squared = torch.bmm(T_minus_lambdaI, T_minus_lambdaI) lambda_squared = lambda_val * lambda_val - - term1 = torch.log(lambda_val) * Identity - term2 = T_minus_lambdaI / max(lambda_val, num_tol) - term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) - - return term1 + term2 - term3 + term1 = torch.log(lambda_val) * identity + term2 = T_minus_lambdaI / torch.clamp(lambda_val.abs(), min=num_tol) + term3 = T_minus_lambdaI_squared / torch.clamp(2 * lambda_squared, min=num_tol) + result = term1 + term2 - term3 + return result.squeeze(0) if unbatched else result def _matrix_log_case2a( @@ -674,6 +274,8 @@ def _matrix_log_case2a( """Compute log(T) when q(T) = (T - λI)(T - μI) with λ≠μ. This is the case with two distinct eigenvalues. + T may be (3, 3) or (n, 3, 3); lambda_val, mu scalar or (n, 1, 1). + Formula: log T = log μ((T - λI)/(μ - λ)) + log λ((T - μI)/(λ - μ)) Args: @@ -686,24 +288,19 @@ def _matrix_log_case2a( The logarithm of T Raises: - ValueError: If λ and μ are too close for numerical stability + ValueError: If λ and μ are too close """ - n = T.shape[0] - Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device) - lambda_minus_mu = lambda_val - mu - - # Check for numerical stability - if torch.abs(lambda_minus_mu) < num_tol: + unbatched, T, (lambda_val, mu) = _ensure_batched(T, lambda_val, mu) + dtype, device = lambda_val.dtype, lambda_val.device + identity = _identity_for_t(T, dtype, device) + if (torch.abs(lambda_val - mu) < num_tol).any(): raise ValueError("λ and μ are too close, computation may be unstable") - - T_minus_lambdaI = T - lambda_val * Identity - T_minus_muI = T - mu * Identity - - # Compute each term separately for better numerical stability + T_minus_lambdaI = T - lambda_val * identity + T_minus_muI = T - mu * identity term1 = torch.log(mu) * (T_minus_lambdaI / (mu - lambda_val)) term2 = torch.log(lambda_val) * (T_minus_muI / (lambda_val - mu)) - - return term1 + term2 + result = term1 + term2 + return result.squeeze(0) if unbatched else result def _matrix_log_case2b( @@ -711,7 +308,9 @@ def _matrix_log_case2b( ) -> torch.Tensor: """Compute log(T) when q(T) = (T - μI)(T - λI)² with λ≠μ. - This is the case with one eigenvalue of multiplicity 2 and one distinct eigenvalue. + This is the case with one eigenvalue of multiplicity 2 and one distinct. + T may be (3, 3) or (n, 3, 3); lambda_val, mu scalar or (n, 1, 1). 
+ Formula: log T = log μ((T - λI)²/(λ - μ)²) - log λ((T - μI)(T - (2λ - μ)I)/(λ - μ)²) + ((T - λI)(T - μI)/(λ(λ - μ))) @@ -726,36 +325,28 @@ def _matrix_log_case2b( The logarithm of T Raises: - ValueError: If λ and μ are too close for numerical stability or - if λ is too close to zero + ValueError: If λ and μ are too close or λ≈0 """ - n = T.shape[0] - Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device) + unbatched, T, (lambda_val, mu) = _ensure_batched(T, lambda_val, mu) + dtype, device = lambda_val.dtype, lambda_val.device + identity = _identity_for_t(T, dtype, device) lambda_minus_mu = lambda_val - mu - lambda_minus_mu_squared = lambda_minus_mu * lambda_minus_mu - - # Check for numerical stability - if torch.abs(lambda_minus_mu) < num_tol: + if (torch.abs(lambda_minus_mu) < num_tol).any(): raise ValueError("λ and μ are too close, computation may be unstable") - - if torch.abs(lambda_val) < num_tol: + if (torch.abs(lambda_val) < num_tol).any(): raise ValueError("λ is too close to zero, computation may be unstable") - - T_minus_lambdaI = T - lambda_val * Identity - T_minus_muI = T - mu * Identity - T_minus_lambdaI_squared = T_minus_lambdaI @ T_minus_lambdaI - - # The term (T - (2λ - μ)I) - T_minus_2lambda_plus_muI = T - (2 * lambda_val - mu) * Identity - - # Compute each term separately for better numerical stability + lambda_minus_mu_squared = lambda_minus_mu * lambda_minus_mu + T_minus_lambdaI = T - lambda_val * identity + T_minus_muI = T - mu * identity + T_minus_lambdaI_squared = torch.bmm(T_minus_lambdaI, T_minus_lambdaI) + T_minus_2lambda_plus_muI = T - (2 * lambda_val - mu) * identity + term2_mat = torch.bmm(T_minus_muI, T_minus_2lambda_plus_muI) term1 = torch.log(mu) * (T_minus_lambdaI_squared / lambda_minus_mu_squared) - term2 = -torch.log(lambda_val) * ( - (T_minus_muI @ T_minus_2lambda_plus_muI) / lambda_minus_mu_squared - ) - term3 = (T_minus_lambdaI @ T_minus_muI) / (lambda_val * lambda_minus_mu) - - return term1 + term2 + term3 + term2 = -torch.log(lambda_val) * (term2_mat / lambda_minus_mu_squared) + term3_mat = torch.bmm(T_minus_lambdaI, T_minus_muI) + term3 = term3_mat / (lambda_val * lambda_minus_mu) + result = term1 + term2 + term3 + return result.squeeze(0) if unbatched else result def _matrix_log_case3( @@ -766,7 +357,9 @@ def _matrix_log_case3( num_tol: float = 1e-16, ) -> torch.Tensor: """Compute log(T) when q(T) = (T - λI)(T - μI)(T - νI) with λ≠μ≠ν≠λ. + This is the case with three distinct eigenvalues. + T may be (3, 3) or (n, 3, 3); lambda_val, mu, nu scalar or (n, 1, 1). 
Formula: log T = log λ((T - μI)(T - νI)/((λ - μ)(λ - ν))) + log μ((T - λI)(T - νI)/((μ - λ)(μ - ν))) @@ -783,135 +376,213 @@ def _matrix_log_case3( The logarithm of T Raises: - ValueError: If any pair of eigenvalues are too close for numerical stability + ValueError: If eigenvalues are too close """ - n = T.shape[0] - Identity = torch.eye(n, dtype=lambda_val.dtype, device=lambda_val.device) - - # Check if eigenvalues are distinct enough for numerical stability - if ( - min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)) - < num_tol - ): + unbatched, T, (lambda_val, mu, nu) = _ensure_batched(T, lambda_val, mu, nu) + dtype, device = lambda_val.dtype, lambda_val.device + identity = _identity_for_t(T, dtype, device) + min_diff = torch.minimum( + torch.minimum( + torch.abs(lambda_val - mu), + torch.abs(lambda_val - nu), + ), + torch.abs(mu - nu), + ) + if (min_diff < num_tol).any(): raise ValueError("Eigenvalues are too close, computation may be unstable") - - T_minus_lambdaI = T - lambda_val * Identity - T_minus_muI = T - mu * Identity - T_minus_nuI = T - nu * Identity - - # Compute the terms for λ - lambda_term_numerator = T_minus_muI @ T_minus_nuI - lambda_term_denominator = (lambda_val - mu) * (lambda_val - nu) + T_minus_lambdaI = T - lambda_val * identity + T_minus_muI = T - mu * identity + T_minus_nuI = T - nu * identity + lambda_term_num = torch.bmm(T_minus_muI, T_minus_nuI) lambda_term = torch.log(lambda_val) * ( - lambda_term_numerator / lambda_term_denominator + lambda_term_num / ((lambda_val - mu) * (lambda_val - nu)) ) + mu_term_num = torch.bmm(T_minus_lambdaI, T_minus_nuI) + mu_term = torch.log(mu) * (mu_term_num / ((mu - lambda_val) * (mu - nu))) + nu_term_num = torch.bmm(T_minus_lambdaI, T_minus_muI) + nu_term = torch.log(nu) * (nu_term_num / ((nu - lambda_val) * (nu - mu))) + result = lambda_term + mu_term + nu_term + return result.squeeze(0) if unbatched else result - # Compute the terms for μ - mu_term_numerator = T_minus_lambdaI @ T_minus_nuI - mu_term_denominator = (mu - lambda_val) * (mu - nu) - mu_term = torch.log(mu) * (mu_term_numerator / mu_term_denominator) - # Compute the terms for ν - nu_term_numerator = T_minus_lambdaI @ T_minus_muI - nu_term_denominator = (nu - lambda_val) * (nu - mu) - nu_term = torch.log(nu) * (nu_term_numerator / nu_term_denominator) - - return lambda_term + mu_term + nu_term +def _determine_matrix_log_cases( + T: torch.Tensor, + sorted_eig: torch.Tensor, + diff: torch.Tensor, + n_unique: torch.Tensor, + valid: torch.Tensor, + num_tol: float, +) -> torch.Tensor: + """Determine which matrix log case applies to each system. 
+ Args: + T: Input matrices of shape (n_systems, 3, 3) + sorted_eig: Sorted eigenvalues of shape (n_systems, 3) + diff: Differences between consecutive eigenvalues (n_systems, 2) + n_unique: Number of unique eigenvalues per system (n_systems,) + valid: Boolean mask of valid systems (n_systems,) + num_tol: Numerical tolerance -def _matrix_log_33( # noqa: C901 - T: torch.Tensor, case: str = "auto", dtype: torch.dtype = torch.float64 + Returns: + Case indices: 0=case1a, 1=case1b, 2=case1c, 3=case2a, 4=case2b, 5=case3, + -1=fallback + """ + n_systems = T.shape[0] + device, dtype_out = T.device, T.dtype + case_indices = torch.full((n_systems,), -1, dtype=torch.long, device=device) + + if not valid.any(): + return case_indices + + eye3 = torch.eye(3, dtype=dtype_out, device=device).unsqueeze(0) + + # Case 1: all eigenvalues equal + m1 = valid & (n_unique == 1) + if m1.any(): + lam = sorted_eig[:, 0:1].unsqueeze(-1) + T_lam = T - lam * eye3 + rank1 = torch.linalg.matrix_rank(T_lam) + rank2 = torch.linalg.matrix_rank(torch.bmm(T_lam, T_lam)) + case_indices.masked_fill_(m1 & (rank1 == 0), 0) + case_indices.masked_fill_(m1 & (rank1 > 0) & (rank2 == 0), 1) + case_indices.masked_fill_(m1 & (rank1 > 0) & (rank2 > 0), 2) + + # Case 2: two distinct eigenvalues + m2 = valid & (n_unique == 2) + if m2.any(): + lam_rep = torch.where( + diff[:, 0:1] <= num_tol, sorted_eig[:, 0:1], sorted_eig[:, 2:3] + ).unsqueeze(-1) + mu_val = torch.where( + diff[:, 0:1] <= num_tol, sorted_eig[:, 2:3], sorted_eig[:, 0:1] + ).unsqueeze(-1) + M = torch.bmm(T - mu_val * eye3, torch.bmm(T - lam_rep * eye3, T)) + case2a = m2 & (torch.linalg.norm(M, dim=(-2, -1)) < num_tol) + case_indices.masked_fill_(case2a, 3) + case_indices.masked_fill_(m2 & ~case2a, 4) + + # Case 3: three distinct eigenvalues + case_indices.masked_fill_(valid & (n_unique == 3), 5) + + return case_indices + + +def _process_matrix_log_case( + case_int: int, + idx_t: torch.Tensor, + T_sub: torch.Tensor, + sorted_sub: torch.Tensor, + dtype_out: torch.dtype, + device: torch.device, + num_tol: float, ) -> torch.Tensor: + """Process a single matrix log case for the given indices. 
+
+    Args:
+        case_int: Case identifier (-1 to 5)
+        idx_t: Indices of systems belonging to this case
+        T_sub: Subset of matrices for this case
+        sorted_sub: Sorted eigenvalues for this case
+        dtype_out: Output dtype
+        device: Device for computation
+        num_tol: Numerical tolerance
+
+    Returns:
+        Computed matrix logarithms for the subset
+    """
+    if case_int == -1:  # Fallback to scipy for complex eigenvalues
+        n_sub = idx_t.numel()
+        result = torch.zeros_like(T_sub)
+        for i in range(n_sub):
+            result[i] = matrix_log_scipy(T_sub[i].cpu()).to(device)
+    elif case_int <= 2:  # Cases 1a, 1b, 1c
+        lam = sorted_sub[:, 0:1].unsqueeze(-1).to(dtype_out)
+        case1_funcs = {
+            0: lambda: _matrix_log_case1a(T_sub, lam),
+            1: lambda: _matrix_log_case1b(T_sub, lam, num_tol),
+            2: lambda: _matrix_log_case1c(T_sub, lam, num_tol),
+        }
+        result = case1_funcs[case_int]()
+    elif case_int <= 4:  # Cases 2a, 2b
+        d = sorted_sub[:, 1:2] - sorted_sub[:, 0:1]
+        lam_rep = (
+            torch.where(d <= num_tol, sorted_sub[:, 0:1], sorted_sub[:, 2:3])
+            .unsqueeze(-1)
+            .to(dtype_out)
+        )
+        mu_val = (
+            torch.where(d <= num_tol, sorted_sub[:, 2:3], sorted_sub[:, 0:1])
+            .unsqueeze(-1)
+            .to(dtype_out)
+        )
+        case2_func = _matrix_log_case2a if case_int == 3 else _matrix_log_case2b
+        result = case2_func(T_sub, lam_rep, mu_val, num_tol)
+    else:  # Case 3: three distinct eigenvalues
+        lam = sorted_sub[:, 0:1].unsqueeze(-1).to(dtype_out)
+        mu_val = sorted_sub[:, 1:2].unsqueeze(-1).to(dtype_out)
+        nu_val = sorted_sub[:, 2:3].unsqueeze(-1).to(dtype_out)
+        result = _matrix_log_case3(T_sub, lam, mu_val, nu_val, num_tol)
+    return result
+
+
+def _matrix_log_33(T: torch.Tensor, dtype: torch.dtype = torch.float64) -> torch.Tensor:
     """Compute the logarithm of 3x3 matrix T based on its eigenvalue structure.
 
+    The logarithm of a 3x3 matrix is known in closed form, as given in the references.
+    Supports both single matrix (3, 3) and batched input (n_systems, 3, 3).
+
Args: - T: The matrix whose logarithm is to be computed - case: One of "auto", "case1a", "case1b", "case1c", "case2a", "case2b", "case3" - - "auto": Automatically determine the structure - - "case1a": All eigenvalues are equal, q(T) = (T - λI) - - "case1b": All eigenvalues are equal, q(T) = (T - λI)² - - "case1c": All eigenvalues are equal, q(T) = (T - λI)³ - - "case2a": Two distinct eigenvalues, q(T) = (T - λI)(T - μI) - - "case2b": Two distinct eigenvalues, q(T) = (T - μI)(T - λI)² - - "case3": Three distinct eigenvalues, q(T) = (T - λI)(T - μI)(T - νI) + T: The matrix whose logarithm is to be computed, shape (3, 3) or (n_systems, 3, 3) dtype: The data type to use for numerical tolerance, default=torch.float64 Returns: - The logarithm of T + The logarithm of T, same shape as input References: - https://link.springer.com/article/10.1007/s10659-008-9169-x """ num_tol = 1e-16 if dtype == torch.float64 else 1e-8 - if not _is_valid_matrix(T): - raise ValueError("Input must be a 3x3 matrix") + # Handle unbatched input by adding batch dimension + unbatched = T.dim() == 2 + if unbatched: + if T.shape != (3, 3): + raise ValueError("Input must be a 3x3 matrix") + T = T.unsqueeze(0) + elif T.shape[1:] != (3, 3): + raise ValueError("Batched input must have shape (n_systems, 3, 3)") - # Compute eigenvalues + device, dtype_out = T.device, T.dtype eigenvalues = torch.linalg.eigvals(T) - # Convert eigenvalues to real if they're complex but with tiny imaginary parts - eigenvalues = ( - torch.real(eigenvalues) - if torch.allclose( - torch.imag(eigenvalues), - torch.zeros_like(torch.imag(eigenvalues)), - atol=num_tol, - ) - else eigenvalues - ) - # If automatic detection, determine the structure - if case == "auto": - case = _determine_eigenvalue_case(T, eigenvalues, num_tol) - - # Case 1: All eigenvalues are equal (λ, λ, λ) - if case in ("case1a", "case1b", "case1c"): - lambda_val = eigenvalues[0] - - # Check for numerical stability - if torch.abs(lambda_val) < num_tol: - raise ValueError("Eigenvalue too close to zero, computation may be unstable") - - if case == "case1a": - return _matrix_log_case1a(T, lambda_val) - if case == "case1b": - return _matrix_log_case1b(T, lambda_val, num_tol) - if case == "case1c": - return _matrix_log_case1c(T, lambda_val, num_tol) - - # Case 2: Two distinct eigenvalues (μ, λ, λ) - elif case in ("case2a", "case2b"): - # Find the unique eigenvalue (μ) and the repeated eigenvalue (λ) - uniq_vals, counts = torch.unique( - torch.round(eigenvalues, decimals=10), return_counts=True - ) - if len(uniq_vals) != 2 or counts.max() != 2: - raise ValueError( - "Case 2 requires exactly two distinct eigenvalues with one repeated" - ) - - mu = uniq_vals[torch.argmin(counts)] # The non-repeated eigenvalue - lambda_val = uniq_vals[torch.argmax(counts)] # The repeated eigenvalue - - if case == "case2a": - return _matrix_log_case2a(T, lambda_val, mu, num_tol) - if case == "case2b": - return _matrix_log_case2b(T, lambda_val, mu, num_tol) + # Check for complex eigenvalues - require scipy fallback + imag_magnitude = torch.abs(torch.imag(eigenvalues)) + has_complex_eig = (imag_magnitude > num_tol).any(dim=1) + eigenvalues_real = torch.real(eigenvalues) - # Case 3: Three distinct eigenvalues (λ, μ, ν) - elif case == "case3": - if len(torch.unique(torch.round(eigenvalues, decimals=10))) != 3: - raise ValueError("Case 3 requires three distinct eigenvalues") + # Sort eigenvalues once for all systems + sorted_eig, _ = torch.sort(eigenvalues_real, dim=1) + diff = sorted_eig[:, 1:] - sorted_eig[:, :-1] + 
n_unique = 1 + (diff > num_tol).sum(dim=1) + valid = ~has_complex_eig & torch.isfinite(eigenvalues_real).all(dim=1) - lambda_val, mu, nu = torch.sort(eigenvalues).values # Sort for consistency - return _matrix_log_case3(T, lambda_val, mu, nu, num_tol) + # Determine case for each system + case_indices = _determine_matrix_log_cases( + T, sorted_eig, diff, n_unique, valid, num_tol + ) - else: - raise ValueError(f"Unknown eigenvalue {case=}") + # Process each case + out = torch.zeros_like(T) + for case_int in range(-1, 6): + mask = case_indices == case_int + if not mask.any(): + continue + idx_t = mask.nonzero(as_tuple=True)[0] + out[idx_t] = _process_matrix_log_case( + case_int, idx_t, T[idx_t], sorted_eig[idx_t], dtype_out, device, num_tol + ) - # should never be reached, just for type checker - raise RuntimeError("Unexpected code path in _matrix_log_33") + return out.squeeze(0) if unbatched else out def matrix_log_scipy(matrix: torch.Tensor) -> torch.Tensor: @@ -953,8 +624,10 @@ def matrix_log_33( ) -> torch.Tensor: """Compute the matrix logarithm of a square 3x3 matrix. + Also supports batched input of shape (n_systems, 3, 3). + Args: - matrix: A square 3x3 matrix tensor + matrix: A square 3x3 matrix tensor, or batch of shape (n_systems, 3, 3) sim_dtype: Simulation dtype, default=torch.float64 fallback_warning: Whether to print a warning when falling back to scipy, default=False @@ -977,6 +650,11 @@ def matrix_log_33( if fallback_warning: print(msg) # noqa: T201 # Fall back to scipy implementation + if matrix.dim() == 3: + out = torch.zeros_like(matrix, dtype=sim_dtype) + for i in range(matrix.shape[0]): + out[i] = matrix_log_scipy(matrix[i].cpu()).to(matrix.device).to(sim_dtype) + return out return matrix_log_scipy(matrix).to(sim_dtype) diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 60ac5a1c..bb013f5e 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -303,25 +303,20 @@ def compute_cell_forces[T: AnyCellState]( for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]): directions[idx, mu, nu] = 1.0 - # Compute deformation gradient log - deform_grad_log = torch.zeros_like(cur_deform_grad) - for sys_idx in range(n_systems): - deform_grad_log[sys_idx] = fm.matrix_log_33(cur_deform_grad[sys_idx]) - - # Compute Frechet derivatives - cell_forces = torch.zeros_like(ucf_cell_grad) - for sys_idx in range(n_systems): - expm_derivs = torch.stack( - [ - fm.expm_frechet(deform_grad_log[sys_idx], direction)[1] - for direction in directions - ] - ) - forces_flat = torch.sum( - expm_derivs * ucf_cell_grad[sys_idx].unsqueeze(0), dim=(1, 2) - ) - cell_forces[sys_idx] = forces_flat.reshape(3, 3) + # Compute deformation gradient log (batched for parallelism) + deform_grad_log = fm.matrix_log_33( + cur_deform_grad, sim_dtype=cur_deform_grad.dtype + ) + # Compute Frechet derivatives (batched over systems and directions) + A_batch = ( + deform_grad_log.unsqueeze(1).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3) + ) + E_batch = directions.unsqueeze(0).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3) + _, expm_derivs_batch = fm.expm_frechet(A_batch, E_batch) + expm_derivs = expm_derivs_batch.reshape(n_systems, 9, 3, 3) + forces_flat = (expm_derivs * ucf_cell_grad.unsqueeze(1)).sum(dim=(2, 3)) + cell_forces = forces_flat.reshape(n_systems, 3, 3) state.cell_forces = cell_forces / state.cell_factor else: # Unit cell force computation # Note (AG): ASE transforms virial as:
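
Usage note: the batched expm_frechet path exercised by the new tests can be driven end to end as in the minimal sketch below, cross-checked per batch entry against scipy's reference implementation. The import path torch_sim.math as fm is assumed (the tests refer to the module as fm); tolerances are illustrative.

    import numpy as np
    import scipy.linalg
    import torch

    import torch_sim.math as fm  # assumed import path, as used in tests/test_math.py

    rng = np.random.default_rng(0)
    A_np = 0.1 * rng.standard_normal((4, 3, 3))  # batch of four 3x3 matrices
    E_np = 0.1 * rng.standard_normal((4, 3, 3))  # directions, same shape as A

    A = torch.from_numpy(A_np)
    E = torch.from_numpy(E_np)

    # Batched call: both outputs have shape (4, 3, 3).
    expm_A, frechet_AE = fm.expm_frechet(A, E)

    # Cross-check each batch entry against scipy.
    for i in range(A_np.shape[0]):
        ref_expm = scipy.linalg.expm(A_np[i])
        _, ref_frechet = scipy.linalg.expm_frechet(A_np[i], E_np[i])
        np.testing.assert_allclose(expm_A[i].numpy(), ref_expm, atol=1e-12)
        np.testing.assert_allclose(frechet_AE[i].numpy(), ref_frechet, atol=1e-12)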
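
The removed blockEnlarge tests relied on the block identity exp([[A, E], [0, A]]) = [[exp(A), L(A, E)], [0, exp(A)]], where L(A, E) is the Frechet derivative. The same identity still gives an independent check of the batched implementation; a torch-only sketch, with illustrative tolerances:

    import torch

    import torch_sim.math as fm  # assumed import path

    torch.manual_seed(0)
    A = torch.randn(5, 3, 3, dtype=torch.float64) * 0.3
    E = torch.randn(5, 3, 3, dtype=torch.float64) * 0.3

    # Build the 6x6 block matrices M = [[A, E], [0, A]] for each batch entry.
    zero = torch.zeros_like(A)
    top = torch.cat([A, E], dim=-1)
    bottom = torch.cat([zero, A], dim=-1)
    M = torch.cat([top, bottom], dim=-2)  # shape (5, 6, 6)

    expm_M = torch.linalg.matrix_exp(M)
    _, frechet = fm.expm_frechet(A, E)

    # The upper-right 3x3 block of exp(M) is the Frechet derivative L(A, E).
    torch.testing.assert_close(frechet, expm_M[:, :3, 3:], atol=1e-8, rtol=1e-6)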
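
In compute_cell_forces, the cell forces amount to contracting the nine directional Frechet derivatives of exp(log F) with the Cartesian cell gradient. A standalone sketch of that contraction follows; deform_grad_log and ucf_cell_grad are stand-ins for the state quantities in the patch, and torch_sim.math as fm is an assumed import path.

    import torch

    import torch_sim.math as fm  # assumed import path

    n_systems = 2
    torch.manual_seed(0)
    deform_grad_log = torch.randn(n_systems, 3, 3, dtype=torch.float64) * 0.01
    ucf_cell_grad = torch.randn(n_systems, 3, 3, dtype=torch.float64)

    # One direction matrix E_k per Cartesian component (mu, nu) of the cell.
    directions = torch.zeros(9, 3, 3, dtype=torch.float64)
    idx = torch.arange(9)
    directions[idx, idx // 3, idx % 3] = 1.0

    # Batch the (system, direction) pairs into a single expm_frechet call.
    A_batch = deform_grad_log.unsqueeze(1).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)
    E_batch = directions.unsqueeze(0).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)
    _, derivs = fm.expm_frechet(A_batch, E_batch)
    derivs = derivs.reshape(n_systems, 9, 3, 3)

    # Contract each derivative with the gradient to get one force component each.
    forces_flat = (derivs * ucf_cell_grad.unsqueeze(1)).sum(dim=(2, 3))
    cell_forces = forces_flat.reshape(n_systems, 3, 3)
    print(cell_forces.shape)  # torch.Size([2, 3, 3])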