diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 1695ade729..ab9e794efa 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1014,6 +1014,22 @@ def perform(self, node, inputs, outputs):
         except np.linalg.LinAlgError:
             outputs[0][0] = np.full(a.shape, np.nan, dtype=a.dtype)
 
+    def L_op(self, inputs, outputs, output_gradients):
+        res = super().L_op(inputs, outputs, output_gradients)
+
+        if self.assume_a in ("sym", "her", "pos"):
+            A_bar = res[0]
+            # When assume_a is sym/her/pos, the solver only reads one triangle
+            # of A and symmetrizes internally. Off-diagonal elements in the read
+            # triangle contribute to both (i,j) and (j,i) of the effective matrix,
+            # so we must accumulate the symmetric contribution and zero the unread triangle.
+            if self.lower:
+                res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1)
+            else:
+                res[0] = ptb.triu(A_bar) + ptb.triu(A_bar.mT, 1)
+
+        return res
+
     def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
         if not allowed_inplace_inputs:
             return self
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index b2b08e95ae..e0dad0b9ee 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -399,6 +399,90 @@ def test_solve_gradient(
             lambda A, b: solve_op(A_func(A), b), [A_val, b_val], 3, rng, eps=eps
         )
 
+    @pytest.mark.parametrize("b_shape", [(5, 1), (5,)], ids=["b_col_vec", "b_vec"])
+    @pytest.mark.parametrize(
+        "assume_a, lower",
+        [
+            ("sym", False),
+            ("sym", True),
+            ("pos", False),
+            ("pos", True),
+        ],
+        ids=["sym_upper", "sym_lower", "pos_upper", "pos_lower"],
+    )
+    @pytest.mark.skipif(
+        config.floatX == "float32",
+        reason="Gradients not numerically stable in float32",
+    )
+    def test_solve_symmetric_gradient_direct(
+        self, b_shape: tuple[int], assume_a: str, lower: bool
+    ):
+        """Test that the gradient of Solve is correct when a pre-structured
+        matrix is passed directly, without composing with a symmetrization
+        wrapper. This catches bugs where L_op doesn't account for the solver
+        only reading one triangle of A."""
+        n = 5
+        rng = np.random.default_rng(utt.fetch_seed())
+
+        # Build a valid symmetric (or pos-def) matrix and extract the read triangle
+        A_raw = rng.normal(size=(n, n)).astype(config.floatX)
+        if assume_a == "pos":
+            A_val = (A_raw @ A_raw.T + 5 * np.eye(n)).astype(config.floatX)
+        else:
+            A_val = ((A_raw + A_raw.T) / 2).astype(config.floatX)
+        b_val = rng.normal(size=b_shape).astype(config.floatX)
+
+        if lower:
+            tri_idx = np.tril_indices(n)
+        else:
+            tri_idx = np.triu_indices(n)
+        tri_val = A_val[tri_idx].astype(config.floatX)
+
+        # --- Part 1: verify_grad with only the read-triangle as free params ---
+        def solve_from_tri(tri, b):
+            A = pt.zeros((n, n))
+            if lower:
+                A = pt.set_subtensor(A[pt.tril_indices(n)], tri)
+            else:
+                A = pt.set_subtensor(A[pt.triu_indices(n)], tri)
+            # Enforce symmetry so both triangles are consistent
+            A = A + A.T - pt.diag(pt.diag(A))
+            if assume_a == "pos":
+                A = A @ A.T + 5 * pt.eye(n)
+            return solve(A, b, assume_a=assume_a, lower=lower, b_ndim=len(b_shape))
+
+        # Re-derive tri_val from the reconstruction so verify_grad perturbations are valid
+        if assume_a == "pos":
+            # For pos-def, parameterize from A_raw's triangle before the outer product
+            tri_val_raw = A_raw[tri_idx].astype(config.floatX)
+            utt.verify_grad(
+                solve_from_tri,
+                [tri_val_raw, b_val],
+                3,
+                rng,
+            )
+        else:
+            utt.verify_grad(
+                solve_from_tri,
+                [tri_val, b_val],
+                3,
+                rng,
+            )
+
+        # --- Part 2: gradient w.r.t. full A has zeros in the unread triangle ---
+        A = pt.tensor("A", shape=(n, n))
+        b = pt.tensor("b", shape=b_shape)
+        x = solve(A, b, assume_a=assume_a, lower=lower, b_ndim=len(b_shape))
+        loss = x.sum()
+        g_A = grad(loss, A)
+        f = function([A, b], g_A)
+        g_val = f(A_val, b_val)
+
+        if lower:
+            assert np.allclose(g_val[np.triu_indices(n, k=1)], 0)
+        else:
+            assert np.allclose(g_val[np.tril_indices(n, k=-1)], 0)
+
     def test_solve_tringular_indirection(self):
         a = pt.matrix("a")
         b = pt.vector("b")
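For reference, a minimal standalone NumPy sketch (illustrative only, not part of the patch) of why the triangle fold in L_op is the correct adjoint. It assumes the lower-triangle case, a loss of x.sum(), and the dense-solve adjoint A_bar = -A^{-T} x_bar x^T; the helper loss_from_lower is a hypothetical name introduced here. Folding the strictly-upper entries of A_bar into the lower triangle reproduces central finite differences taken over the read triangle only:

import numpy as np
from scipy.linalg import solve

rng = np.random.default_rng(0)
n = 4
A = rng.normal(size=(n, n))
A = A @ A.T + n * np.eye(n)  # symmetric positive-definite
b = rng.normal(size=n)

x = solve(A, b, assume_a="pos", lower=True)
x_bar = np.ones(n)  # cotangent of loss = x.sum()

# Dense-solve adjoint, then fold the symmetric contribution into the
# read (lower) triangle and zero the unread (upper) triangle.
A_bar = -np.outer(solve(A.T, x_bar), x)
A_bar_folded = np.tril(A_bar) + np.tril(A_bar.T, -1)

def loss_from_lower(A_pert):
    # Mimic a solver that only reads the lower triangle and symmetrizes.
    A_eff = np.tril(A_pert) + np.tril(A_pert, -1).T
    return solve(A_eff, b).sum()

# Central finite differences over the read (lower) triangle only.
eps = 1e-6
fd = np.zeros_like(A)
for i in range(n):
    for j in range(i + 1):
        dA = np.zeros_like(A)
        dA[i, j] = eps
        fd[i, j] = (loss_from_lower(A + dA) - loss_from_lower(A - dA)) / (2 * eps)

np.testing.assert_allclose(A_bar_folded, fd, rtol=1e-5, atol=1e-8)

Perturbing an off-diagonal entry (i, j) of the read triangle moves both (i, j) and (j, i) of the effective matrix, so the finite-difference gradient there is A_bar[i, j] + A_bar[j, i], which is exactly what tril(A_bar) + tril(A_bar.T, -1) computes.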