Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 92 additions & 214 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

# ruff: noqa: SLF001

import itertools

import numpy as np
import scipy
import torch
Expand All @@ -18,269 +16,99 @@


class TestExpmFrechet:
"""Test suite for expm_frechet using numpy arrays converted to torch tensors."""
"""Test suite for expm_frechet for 3x3 matrices."""

def test_expm_frechet(self):
"""Test basic functionality of expm_frechet against scipy implementation."""
M = np.array(
[[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], dtype=np.float64
)
A = np.array([[1, 2], [5, 6]], dtype=np.float64)
E = np.array([[3, 4], [7, 8]], dtype=np.float64)
A = np.array([[1, 2, 0], [5, 6, 0], [0, 0, 1]], dtype=np.float64)
E = np.array([[3, 4, 0], [7, 8, 0], [0, 0, 0]], dtype=np.float64)
expected_expm = scipy.linalg.expm(A)
expected_frechet = scipy.linalg.expm(M)[:2, 2:]
expected_frechet = scipy.linalg.expm_frechet(A, E)[1]

A = torch.from_numpy(A).to(device=device)
E = torch.from_numpy(E).to(device=device)
for kwargs in ({}, {"method": "SPS"}, {"method": "blockEnlarge"}):
# Convert it to numpy arrays before passing it to the function
observed_expm, observed_frechet = fm.expm_frechet(A, E, **kwargs)
assert_allclose(expected_expm, observed_expm.cpu().numpy())
assert_allclose(expected_frechet, observed_frechet.cpu().numpy())
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14)
assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14)

def test_small_norm_expm_frechet(self):
"""Test matrices with a range of norms for better coverage."""
M_original = np.array(
[[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]], dtype=np.float64
)
A_original = np.array([[1, 2], [5, 6]], dtype=np.float64)
E_original = np.array([[3, 4], [7, 8]], dtype=np.float64)
A_original_norm_1 = scipy.linalg.norm(A_original, 1)
selected_m_list = [1, 3, 5, 7, 9, 11, 13, 15]

m_neighbor_pairs = itertools.pairwise(selected_m_list)
for ma, mb in m_neighbor_pairs:
ell_a = scipy.linalg._expm_frechet.ell_table_61[ma]
ell_b = scipy.linalg._expm_frechet.ell_table_61[mb]
target_norm_1 = 0.5 * (ell_a + ell_b)
scale = target_norm_1 / A_original_norm_1
M = scale * M_original
A = scale * A_original
E = scale * E_original
expected_expm = scipy.linalg.expm(A)
expected_frechet = scipy.linalg.expm(M)[:2, 2:]
A = torch.from_numpy(A).to(device=device, dtype=DTYPE)
E = torch.from_numpy(E).to(device=device, dtype=DTYPE)
# Convert it to numpy arrays before passing it to the function
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy())
assert_allclose(expected_frechet, observed_frechet.cpu().numpy())
"""Test matrices with small norms."""
A = np.array([[0.1, 0.2, 0], [0.5, 0.6, 0], [0, 0, 0.1]], dtype=np.float64)
E = np.array([[0.3, 0.4, 0], [0.7, 0.8, 0], [0, 0, 0]], dtype=np.float64)
expected_expm = scipy.linalg.expm(A)
expected_frechet = scipy.linalg.expm_frechet(A, E)[1]

A = torch.from_numpy(A).to(device=device, dtype=DTYPE)
E = torch.from_numpy(E).to(device=device, dtype=DTYPE)
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-14)
assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-14)

def test_fuzz(self):
"""Test with a variety of random inputs to ensure robustness."""
"""Test with a variety of random 3x3 inputs to ensure robustness."""
rng = np.random.default_rng(1726500908359153)
# try a bunch of crazy inputs
rfuncs = (
np.random.uniform,
np.random.normal,
np.random.standard_cauchy,
np.random.exponential,
)
ntests = 100
ntests = 20
for _ in range(ntests):
rfunc = rfuncs[rng.choice(4)]
target_norm_1 = rng.exponential()
n = rng.integers(2, 16)
A_original = rfunc(size=(n, n))
E_original = rfunc(size=(n, n))
A_original = rng.standard_normal((3, 3))
E_original = rng.standard_normal((3, 3))
A_original_norm_1 = scipy.linalg.norm(A_original, 1)
scale = target_norm_1 / A_original_norm_1
A = scale * A_original
E = scale * E_original
M = np.vstack([np.hstack([A, E]), np.hstack([np.zeros_like(A), A])])
expected_expm = scipy.linalg.expm(A)
expected_frechet = scipy.linalg.expm(M)[:n, n:]
expected_frechet = scipy.linalg.expm_frechet(A, E)[1]
A = torch.from_numpy(A).to(device=device, dtype=DTYPE)
E = torch.from_numpy(E).to(device=device, dtype=DTYPE)
# Convert it to numpy arrays before passing it to the function
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=5e-8)
assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-7)

def test_problematic_matrix(self):
"""Test a specific matrix that previously uncovered a bug."""
def test_large_norm_matrices(self):
"""Test expm_frechet with larger norm matrices requiring scaling."""
A = np.array(
[[1.50591997, 1.93537998], [0.41203263, 0.23443516]], dtype=np.float64
[[1.5, 0.8, 0.3], [0.6, 1.2, 0.5], [0.4, 0.7, 1.8]], dtype=np.float64
)
E = np.array(
[[1.87864034, 2.07055038], [1.34102727, 0.67341123]], dtype=np.float64
)
A = torch.from_numpy(A).to(device=device, dtype=DTYPE)
E = torch.from_numpy(E).to(device=device, dtype=DTYPE)
# Convert it to numpy arrays before passing it to the function
sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS")
blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet(
A, E, method="blockEnlarge"
[[0.1, 0.2, 0.1], [0.2, 0.15, 0.1], [0.1, 0.1, 0.2]], dtype=np.float64
)
assert_allclose(sps_expm.cpu().numpy(), blockEnlarge_expm.cpu().numpy())
assert_allclose(sps_frechet.cpu().numpy(), blockEnlarge_frechet.cpu().numpy())

def test_medium_matrix(self):
"""Test with a medium-sized matrix to compare performance between methods."""
n = 1000
rng = np.random.default_rng()
A = rng.exponential(size=(n, n))
E = rng.exponential(size=(n, n))
expected_expm = scipy.linalg.expm(A)
expected_frechet = scipy.linalg.expm_frechet(A, E)[1]

A = torch.from_numpy(A).to(device=device, dtype=DTYPE)
E = torch.from_numpy(E).to(device=device, dtype=DTYPE)
# Convert it to numpy arrays before passing it to the function
sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS")
blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet(
A, E, method="blockEnlarge"
)
assert_allclose(sps_expm.cpu().numpy(), blockEnlarge_expm.cpu().numpy())
assert_allclose(sps_frechet.cpu().numpy(), blockEnlarge_frechet.cpu().numpy())
observed_expm, observed_frechet = fm.expm_frechet(A, E)
assert_allclose(expected_expm, observed_expm.cpu().numpy(), atol=1e-12)
assert_allclose(expected_frechet, observed_frechet.cpu().numpy(), atol=1e-12)


class TestExpmFrechetTorch:
"""Test suite for expm_frechet using native torch tensors."""

def test_expm_frechet(self):
"""Test basic functionality of expm_frechet against torch.linalg.matrix_exp."""
M = torch.tensor(
[[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]],
dtype=DTYPE,
device=device,
)
A = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device)
E = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device)
A = torch.tensor([[1, 2, 0], [5, 6, 0], [0, 0, 1]], dtype=DTYPE, device=device)
E = torch.tensor([[3, 4, 0], [7, 8, 0], [0, 0, 0]], dtype=DTYPE, device=device)
expected_expm = torch.linalg.matrix_exp(A)
expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:]

for kwargs in ({}, {"method": "SPS"}, {"method": "blockEnlarge"}):
observed_expm, observed_frechet = fm.expm_frechet(A, E, **kwargs)
torch.testing.assert_close(expected_expm, observed_expm)
torch.testing.assert_close(expected_frechet, observed_frechet)

def test_small_norm_expm_frechet(self):
"""Test matrices with a range of norms for better coverage using torch tensors."""
M_original = torch.tensor(
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[0, 0, 1, 2],
[0, 0, 5, 6],
],
dtype=DTYPE,
device=device,
)
A_original = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device)
E_original = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device)
A_original_norm_1 = torch.linalg.norm(A_original, 1)
selected_m_list = [1, 3, 5, 7, 9, 11, 13, 15]
m_neighbor_pairs = itertools.pairwise(selected_m_list)
for ma, mb in m_neighbor_pairs:
ell_a = fm.ell_table_61[ma]
ell_b = fm.ell_table_61[mb]
target_norm_1 = 0.5 * (ell_a + ell_b)
scale = target_norm_1 / A_original_norm_1
M = scale * M_original
A = scale * A_original
E = scale * E_original
expected_expm = torch.linalg.matrix_exp(A)
expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:]
observed_expm, observed_frechet = fm.expm_frechet(A, E)
torch.testing.assert_close(expected_expm, observed_expm)
torch.testing.assert_close(expected_frechet, observed_frechet)
observed_expm, _ = fm.expm_frechet(A, E)
torch.testing.assert_close(expected_expm, observed_expm)

def test_fuzz(self):
"""Test with a variety of random inputs using torch tensors."""
"""Test with a variety of random 3x3 inputs using torch tensors."""
rng = np.random.default_rng(1726500908359153)
# try a bunch of crazy inputs
# Convert random functions to tensor-generating functions
tensor_rfuncs = (
lambda size, device="cpu": torch.tensor(
rng.uniform(size=size), device=device
),
lambda size, device="cpu": torch.tensor(rng.normal(size=size), device=device),
lambda size, device="cpu": torch.tensor(
rng.standard_cauchy(size=size), device=device
),
lambda size, device="cpu": torch.tensor(
rng.exponential(size=size), device=device
),
)
ntests = 100
ntests = 20
for _ in range(ntests):
rfunc = tensor_rfuncs[torch.tensor(rng.choice(4))]
target_norm_1 = torch.tensor(rng.exponential())
n = torch.tensor(rng.integers(2, 16))
A_original = rfunc(size=(n, n))
E_original = rfunc(size=(n, n))
target_norm_1 = rng.exponential()
A_original = torch.tensor(rng.standard_normal((3, 3)), device=device)
E_original = torch.tensor(rng.standard_normal((3, 3)), device=device)
A_original_norm_1 = torch.linalg.norm(A_original, 1)
scale = target_norm_1 / A_original_norm_1
A = scale * A_original
E = scale * E_original
M = torch.vstack(
[torch.hstack([A, E]), torch.hstack([torch.zeros_like(A), A])]
)
expected_expm = torch.linalg.matrix_exp(A)
expected_frechet = torch.linalg.matrix_exp(M)[:n, n:]
observed_expm, observed_frechet = fm.expm_frechet(A, E)
observed_expm, _ = fm.expm_frechet(A, E)
torch.testing.assert_close(expected_expm, observed_expm, atol=5e-8, rtol=1e-5)
torch.testing.assert_close(
expected_frechet, observed_frechet, atol=1e-7, rtol=1e-5
)

def test_problematic_matrix(self):
"""Test a specific matrix that previously uncovered a bug using torch tensors."""
A = torch.tensor(
[[1.50591997, 1.93537998], [0.41203263, 0.23443516]],
dtype=DTYPE,
device=device,
)
E = torch.tensor(
[[1.87864034, 2.07055038], [1.34102727, 0.67341123]],
dtype=DTYPE,
device=device,
)
sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS")
blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet(
A, E, method="blockEnlarge"
)
torch.testing.assert_close(sps_expm, blockEnlarge_expm)
torch.testing.assert_close(sps_frechet, blockEnlarge_frechet)

def test_medium_matrix(self):
"""Test with a medium-sized matrix to compare performance
between methods using torch tensors.
"""
n = 1000
rng = np.random.default_rng()
A = torch.tensor(rng.exponential(size=(n, n)))
E = torch.tensor(rng.exponential(size=(n, n)))

sps_expm, sps_frechet = fm.expm_frechet(A, E, method="SPS")
blockEnlarge_expm, blockEnlarge_frechet = fm.expm_frechet(
A, E, method="blockEnlarge"
)
torch.testing.assert_close(sps_expm, blockEnlarge_expm)
torch.testing.assert_close(sps_frechet, blockEnlarge_frechet)


class TestExpmFrechetTorchGrad:
"""Test suite for gradient computation with expm and its Frechet derivative."""

def test_expm_frechet(self):
"""Test gradient computation for matrix exponential and its Frechet derivative."""
M = torch.tensor(
[[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 1, 2], [0, 0, 5, 6]],
dtype=DTYPE,
device=device,
)
A = torch.tensor([[1, 2], [5, 6]], dtype=DTYPE, device=device)
E = torch.tensor([[3, 4], [7, 8]], dtype=DTYPE, device=device)
expected_expm = torch.linalg.matrix_exp(A)
expected_frechet = torch.linalg.matrix_exp(M)[:2, 2:]
# expm will use the SPS method as default
observed_expm = fm.expm.apply(A)
torch.testing.assert_close(expected_expm, observed_expm)
# Compute the Frechet derivative in the direction of grad_output
A.requires_grad = True
observed_expm = fm.expm.apply(A)
(observed_frechet,) = torch.autograd.grad(observed_expm, A, E, retain_graph=True)
torch.testing.assert_close(expected_frechet, observed_frechet)


class TestLogM33:
Expand Down Expand Up @@ -469,3 +297,53 @@ def test_nearly_degenerate(self):
torch.testing.assert_close(
M_logm, torch.tensor(scipy_logm, dtype=DTYPE, device=device)
)

def test_batched_positive_definite(self):
"""Test batched matrix logarithm with positive definite matrices."""
batch_size = 3
rng = np.random.default_rng(42)
L = rng.standard_normal((batch_size, 3, 3))
M_np = np.array([L[i] @ L[i].T + 0.5 * np.eye(3) for i in range(batch_size)])
M_torch = torch.tensor(M_np, dtype=torch.float64)

log_torch = fm.matrix_log_33(M_torch)

for i in range(batch_size):
log_scipy = scipy.linalg.logm(M_np[i]).real
assert_allclose(log_torch[i].numpy(), log_scipy, atol=1e-12)
# Verify round-trip: exp(log(M)) = M
M_recovered = torch.matrix_exp(log_torch[i])
assert_allclose(M_recovered.numpy(), M_np[i], atol=1e-10)


class TestFrechetCellFilterIntegration:
"""Integration tests for the Frechet cell filter pipeline."""

def test_frechet_derivatives_vs_scipy(self):
"""Test Frechet derivative computation matches scipy."""
n_systems = 2
torch.manual_seed(42)

# Create small deformations
deform_log = torch.randn(n_systems, 3, 3, dtype=torch.float64) * 0.01

# Compute Frechet derivatives for all 9 directions
idx_flat = torch.arange(9)
i_idx, j_idx = idx_flat // 3, idx_flat % 3
directions = torch.zeros((9, 3, 3), dtype=torch.float64)
directions[idx_flat, i_idx, j_idx] = 1.0

A_batch = deform_log.unsqueeze(1).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)
E_batch = directions.unsqueeze(0).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)
_, frechet_torch = fm.expm_frechet(A_batch, E_batch)
frechet_torch = frechet_torch.reshape(n_systems, 9, 3, 3)

# Compare with scipy
for sys_idx in range(n_systems):
for dir_idx in range(9):
A_np = deform_log[sys_idx].numpy()
E_np = directions[dir_idx].numpy()
_, frechet_scipy = scipy.linalg.expm_frechet(A_np, E_np)
assert_allclose(
frechet_torch[sys_idx, dir_idx].numpy(), frechet_scipy, atol=1e-12
)
Loading
Loading