From 679588f322d33aef668f74bfe9fbed5825b5cdb0 Mon Sep 17 00:00:00 2001 From: WHOIM1205 Date: Thu, 12 Feb 2026 11:57:55 -0800 Subject: [PATCH 1/3] Fix incorrect Solve gradient for assume_a=sym/her/pos Signed-off-by: WHOIM1205 --- pytensor/tensor/slinalg.py | 16 +++++++++ tests/tensor/test_slinalg.py | 63 ++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 1695ade729..ab9e794efa 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1014,6 +1014,22 @@ def perform(self, node, inputs, outputs): except np.linalg.LinAlgError: outputs[0][0] = np.full(a.shape, np.nan, dtype=a.dtype) + def L_op(self, inputs, outputs, output_gradients): + res = super().L_op(inputs, outputs, output_gradients) + + if self.assume_a in ("sym", "her", "pos"): + A_bar = res[0] + # When assume_a is sym/her/pos, the solver only reads one triangle + # of A and symmetrizes internally. Off-diagonal elements in the read + # triangle contribute to both (i,j) and (j,i) of the effective matrix, + # so we must accumulate the symmetric contribution and zero the unread triangle. + if self.lower: + res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1) + else: + res[0] = ptb.triu(A_bar) + ptb.triu(A_bar.mT, 1) + + return res + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": if not allowed_inplace_inputs: return self diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index b2b08e95ae..c35924385e 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -399,6 +399,69 @@ def test_solve_gradient( lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps ) + @pytest.mark.parametrize("b_shape", [(5, 1), (5,)], ids=["b_col_vec", "b_vec"]) + @pytest.mark.parametrize( + "assume_a, lower", + [ + ("sym", False), + ("sym", True), + ("pos", False), + ("pos", True), + ], + ids=["sym_upper", "sym_lower", "pos_upper", "pos_lower"], + ) + @pytest.mark.skipif( + config.floatX == "float32", + reason="Gradients not numerically stable in float32", + ) + def test_solve_symmetric_gradient_direct( + self, b_shape: tuple[int], assume_a: str, lower: bool + ): + """Test that the gradient of Solve is correct when a pre-structured + matrix is passed directly, without composing with a symmetrization + wrapper. 
This catches bugs where L_op doesn't account for the solver + only reading one triangle of A.""" + rng = np.random.default_rng(utt.fetch_seed()) + + A_raw = rng.normal(size=(5, 5)).astype(config.floatX) + if assume_a == "pos": + A_val = (A_raw @ A_raw.T + 5 * np.eye(5)).astype(config.floatX) + else: + A_val = ((A_raw + A_raw.T) / 2).astype(config.floatX) + b_val = rng.normal(size=b_shape).astype(config.floatX) + + A = pt.tensor("A", shape=(5, 5)) + b = pt.tensor("b", shape=b_shape) + x = solve(A, b, assume_a=assume_a, lower=lower, b_ndim=len(b_shape)) + loss = x.sum() + g_A = grad(loss, A) + f = function([A, b], g_A) + + analytic = f(A_val, b_val) + + # Numerical gradient: perturb only the read triangle + eps = 1e-7 + numerical = np.zeros_like(A_val) + for i in range(5): + for j in range(5): + if lower and j > i: + continue + if not lower and j < i: + continue + A_plus = A_val.copy() + A_plus[i, j] += eps + A_minus = A_val.copy() + A_minus[i, j] -= eps + x_plus = scipy_linalg.solve( + A_plus, b_val, assume_a=assume_a, lower=lower + ) + x_minus = scipy_linalg.solve( + A_minus, b_val, assume_a=assume_a, lower=lower + ) + numerical[i, j] = (x_plus.sum() - x_minus.sum()) / (2 * eps) + + np.testing.assert_allclose(analytic, numerical, atol=1e-5, rtol=1e-5) + def test_solve_tringular_indirection(self): a = pt.matrix("a") b = pt.vector("b") From 9cfbc38464d1370e78d6cb3333232954235125fa Mon Sep 17 00:00:00 2001 From: WHOIM1205 Date: Thu, 12 Feb 2026 14:04:59 -0800 Subject: [PATCH 2/3] refactor: use verify_grad in test_solve_symmetric_gradient_direct Replace manual finite-difference loop with verify_grad using triangular parameterization. Add explicit zero-gradient assertion for unread triangle. Signed-off-by: WHOIM1205 --- tests/tensor/test_slinalg.py | 69 ++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index c35924385e..4960436a5c 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -421,46 +421,61 @@ def test_solve_symmetric_gradient_direct( matrix is passed directly, without composing with a symmetrization wrapper. 
This catches bugs where L_op doesn't account for the solver only reading one triangle of A.""" + n = 5 rng = np.random.default_rng(utt.fetch_seed()) - A_raw = rng.normal(size=(5, 5)).astype(config.floatX) + # Build a valid symmetric (or pos-def) matrix and extract the read triangle + A_raw = rng.normal(size=(n, n)).astype(config.floatX) if assume_a == "pos": - A_val = (A_raw @ A_raw.T + 5 * np.eye(5)).astype(config.floatX) + A_val = (A_raw @ A_raw.T + 5 * np.eye(n)).astype(config.floatX) else: A_val = ((A_raw + A_raw.T) / 2).astype(config.floatX) b_val = rng.normal(size=b_shape).astype(config.floatX) - A = pt.tensor("A", shape=(5, 5)) + if lower: + tri_idx = np.tril_indices(n) + else: + tri_idx = np.triu_indices(n) + tri_val = A_val[tri_idx].astype(config.floatX) + + # --- Part 1: verify_grad with only the read-triangle as free params --- + def solve_from_tri(tri, b): + A = pt.zeros((n, n)) + if lower: + A = pt.set_subtensor(A[pt.tril_indices(n)], tri) + else: + A = pt.set_subtensor(A[pt.triu_indices(n)], tri) + # Enforce symmetry so both triangles are consistent + A = A + A.T - pt.diag(pt.diag(A)) + if assume_a == "pos": + A = A @ A.T + 5 * pt.eye(n) + return solve(A, b, assume_a=assume_a, lower=lower, b_ndim=len(b_shape)) + + # Re-derive tri_val from the reconstruction so verify_grad perturbations are valid + if assume_a == "pos": + # For pos-def, parameterize from A_raw's triangle before the outer product + tri_val_raw = A_raw[tri_idx].astype(config.floatX) + utt.verify_grad( + solve_from_tri, [tri_val_raw, b_val], 3, rng, + ) + else: + utt.verify_grad( + solve_from_tri, [tri_val, b_val], 3, rng, + ) + + # --- Part 2: gradient w.r.t. full A has zeros in the unread triangle --- + A = pt.tensor("A", shape=(n, n)) b = pt.tensor("b", shape=b_shape) x = solve(A, b, assume_a=assume_a, lower=lower, b_ndim=len(b_shape)) loss = x.sum() g_A = grad(loss, A) f = function([A, b], g_A) + g_val = f(A_val, b_val) - analytic = f(A_val, b_val) - - # Numerical gradient: perturb only the read triangle - eps = 1e-7 - numerical = np.zeros_like(A_val) - for i in range(5): - for j in range(5): - if lower and j > i: - continue - if not lower and j < i: - continue - A_plus = A_val.copy() - A_plus[i, j] += eps - A_minus = A_val.copy() - A_minus[i, j] -= eps - x_plus = scipy_linalg.solve( - A_plus, b_val, assume_a=assume_a, lower=lower - ) - x_minus = scipy_linalg.solve( - A_minus, b_val, assume_a=assume_a, lower=lower - ) - numerical[i, j] = (x_plus.sum() - x_minus.sum()) / (2 * eps) - - np.testing.assert_allclose(analytic, numerical, atol=1e-5, rtol=1e-5) + if lower: + assert np.allclose(g_val[np.triu_indices(n, k=1)], 0) + else: + assert np.allclose(g_val[np.tril_indices(n, k=-1)], 0) def test_solve_tringular_indirection(self): a = pt.matrix("a") From 5aa8edab3d0bb32e140861eff776b6f4d0a29c79 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Feb 2026 22:06:04 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tensor/test_slinalg.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 4960436a5c..e0dad0b9ee 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -456,11 +456,17 @@ def solve_from_tri(tri, b): # For pos-def, parameterize from A_raw's triangle before the outer product tri_val_raw = A_raw[tri_idx].astype(config.floatX) 
             utt.verify_grad(
-                solve_from_tri, [tri_val_raw, b_val], 3, rng,
+                solve_from_tri,
+                [tri_val_raw, b_val],
+                3,
+                rng,
             )
         else:
             utt.verify_grad(
-                solve_from_tri, [tri_val, b_val], 3, rng,
+                solve_from_tri,
+                [tri_val, b_val],
+                3,
+                rng,
             )
 
         # --- Part 2: gradient w.r.t. full A has zeros in the unread triangle ---
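
For reference, the triangular-gradient identity that the new L_op implements can be checked outside PyTensor with a short NumPy/SciPy script. This is a standalone sketch, not part of the series; names such as loss, A_bar, analytic, and numeric are illustrative, and it relies only on the behaviour the patch itself documents: scipy.linalg.solve with assume_a="sym" and lower=True reads just the lower triangle of A.

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n = 4
A_raw = rng.normal(size=(n, n))
A = A_raw @ A_raw.T + n * np.eye(n)  # symmetric positive definite
b = rng.normal(size=n)

def loss(A_full):
    # The solver reads only the lower triangle of A_full here.
    return linalg.solve(A_full, b, assume_a="sym", lower=True).sum()

# Gradient of sum(solve(A, b)) with respect to the full dense matrix:
# A_bar = -A^{-T} @ x_bar @ x^T, with x_bar = dL/dx = ones.
x = linalg.solve(A, b, assume_a="sym", lower=True)
x_bar = np.ones(n)
A_bar = -np.linalg.solve(A.T, np.outer(x_bar, x))

# Gradient with respect to the triangle the solver actually reads:
# accumulate the mirrored off-diagonal contribution and zero the unread
# (upper) triangle, matching the lower=True branch of the new L_op.
analytic = np.tril(A_bar) + np.tril(A_bar.T, -1)

# Central finite differences, perturbing only lower-triangle entries.
eps = 1e-6
numeric = np.zeros_like(A)
for i in range(n):
    for j in range(i + 1):
        E = np.zeros_like(A)
        E[i, j] = eps
        numeric[i, j] = (loss(A + E) - loss(A - E)) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, rtol=1e-4, atol=1e-8)

Running the same check with lower=False (and triu in place of tril) exercises the other branch; without the mirrored tril(A_bar.T, -1) term, the off-diagonal entries of the analytic gradient no longer match the finite differences, which is the bug the first commit fixes.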