
Fix incorrect gradient in Solve for structured assume_a (sym/pos/her) #1887

Open

WHOIM1205 wants to merge 3 commits into pymc-devs:main from WHOIM1205:fix-solve-symmetric-gradient

Conversation

@WHOIM1205 Contributor

Fix gradient handling in Solve for structured assume_a cases


Summary

SolveBase.L_op computes gradients assuming all entries of A are independent. This is correct for assume_a="gen".

However, when using structured assumptions ("sym", "her", "pos"), the solver only reads one triangle of A. The backward pass did not account for this, resulting in incorrect gradients when a pre-structured matrix was passed directly into pt.linalg.solve.

Existing tests did not catch this because they wrapped the input matrix with a symmetrization transform, which masked the issue via the chain rule.
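
For illustration, a minimal sketch of the two call patterns (a sketch only; variable names and values below are mine, not taken from the test suite):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.dmatrix("A")
b = pt.dvector("b")

# Old test pattern: symmetrize inside the graph.  The gradient of (A + A.T)/2
# redistributes contributions over both triangles, which hides the bug.
x_masked = pt.linalg.solve((A + A.T) / 2, b, assume_a="sym")

# Failing pattern: pass an already-symmetric matrix straight into solve.
x_direct = pt.linalg.solve(A, b, assume_a="sym")

f = pytensor.function(
    [A, b],
    [pt.grad(x_masked.sum(), A), pt.grad(x_direct.sum(), A)],
)

A_val = np.array([[2.0, 0.5], [0.5, 1.0]])
b_val = np.array([1.0, -1.0])
g_masked, g_direct = f(A_val, b_val)
# Before the fix, g_direct treats every entry of A as independent instead of
# accumulating the contribution onto the read triangle and zeroing the other.
```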


Fix

Updated pytensor/tensor/slinalg.py:

Before (inherited from `SolveBase`):

```python
A_bar = -outer(b_bar, c)
```
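
After (sketch): the "gen" gradient is folded onto the triangle the solver reads. The helper below is hypothetical, written only to illustrate the accumulation applied inside `SolveBase.L_op`; the actual diff in slinalg.py may be organized differently (see the suggested change further down for the conjugate-transpose detail).

```python
import pytensor.tensor as pt

def fold_A_bar_onto_read_triangle(A_bar, assume_a="sym", lower=True):
    # Hypothetical helper (not the actual slinalg.py code): off-diagonal
    # entries of the read triangle contribute to both (i, j) and (j, i) of
    # the effective matrix, so their gradients are accumulated on that
    # triangle and the unread triangle is zeroed out.
    if assume_a not in ("sym", "her", "pos"):
        return A_bar  # "gen": all entries independent, nothing to fold
    # Conjugate transpose keeps the Hermitian case correct (no-op for real A).
    other = A_bar.conj().mT
    if lower:
        return pt.tril(A_bar) + pt.tril(other, -1)
    return pt.triu(A_bar) + pt.triu(other, 1)
```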

- All test cases pass locally.

Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
@jessegrabowski Member

Is this related to #1230 ?

@WHOIM1205 Contributor Author

> Is this related to #1230 ?

This PR is not directly related to #1230.
#1230 concerns solve_triangular gradients (specifically unit_diag and trans handling), whereas this PR addresses gradient handling in Solve when assume_a is structured (sym/pos/her).
They affect different operators and different gradient paths.

@ricardoV94 Member

I think the question was whether it's the same nature of issue


Review comment on the old finite-difference test:

```python
analytic = f(A_val, b_val)

# Numerical gradient: perturb only the read triangle
```
@ricardoV94 Member (Feb 12, 2026)

Can we concoct a graph that allows us to use verify_grad instead and still works as a regression test?

Something whose input is just the triangular entries? I'm assuming they were being half counted?

You can still verify they came out as zeros on an explicit grad fn

@WHOIM1205 Contributor Author

> Can we concoct a graph that allows us to use verify_grad instead and still works as a regression test?
>
> Something whose input is just the triangular entries? I'm assuming they were being half counted?
>
> You can still verify they came out as zeros on an explicit grad fn

Thanks for the suggestion, I've updated the test accordingly.

The manual finite-difference loop has been replaced with utt.verify_grad using a triangular parameterization of the structured entries: the symmetric matrix is reconstructed inside the graph, and I've also added an explicit assertion that the unread triangle has zero gradients.
All parametrized cases are passing locally.
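
For reference, a sketch of that kind of regression test (names and values are illustrative, not necessarily what the PR adds; verify_grad here is pytensor.gradient.verify_grad, which the test suite's utt.verify_grad wraps):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import verify_grad

rng = np.random.default_rng(42)
n = 4

def solve_from_lower_triangle(L, b):
    # Rebuild the symmetric matrix inside the graph from its lower triangle,
    # so verify_grad only perturbs entries the solver actually reads.
    A = pt.tril(L) + pt.tril(L, -1).T
    return pt.linalg.solve(A, b, assume_a="sym", lower=True)

L_val = np.tril(rng.normal(size=(n, n))) + n * np.eye(n)
b_val = rng.normal(size=(n, 1))
verify_grad(solve_from_lower_triangle, [L_val, b_val], rng=rng)

# Explicit regression check: the unread (strictly upper) triangle of the
# gradient w.r.t. A should come out exactly zero.
A = pt.dmatrix("A")
b = pt.dmatrix("b")
g = pt.grad(pt.linalg.solve(A, b, assume_a="sym", lower=True).sum(), A)
g_val = pytensor.function([A, b], g)(np.tril(L_val) + np.tril(L_val, -1).T, b_val)
np.testing.assert_allclose(np.triu(g_val, 1), 0.0)
```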

@ricardoV94 Member

> hey @ricardoV94
> This fixes an error-path imbalance in post_open_standalone() where mutex_ghost was not released on connect() failure. Since the mutex is shared across restore tasks, this could lead to a cross-process deadlock during Unix socket restore. The change is limited to the failure path and does not affect the success flow.

I suppose this is a mistake comment from some other work

@WHOIM1205 Contributor Author (Feb 12, 2026)

> I think the question was whether it's the same nature of issue

Ah, thanks for clarifying.
Yes, it's similar in nature in the sense that both issues stem from the gradient not fully respecting structural assumptions made in the forward solve.
However, they affect different operators and different logic paths:
- #1230 concerns solve_triangular (specifically unit_diag and trans handling)
- This PR addresses Solve when assume_a is structured (sym/pos/her)

So they are conceptually similar (structure-aware gradient handling) but technically independent fixes.

@WHOIM1205 Contributor Author

> > hey @ricardoV94
> > This fixes an error-path imbalance in post_open_standalone() where mutex_ghost was not released on connect() failure. Since the mutex is shared across restore tasks, this could lead to a cross-process deadlock during Unix socket restore. The change is limited to the failure path and does not affect the success flow.
>
> I suppose this is a mistake comment from some other work

Apologies, that was clearly pasted from another PR by mistake. Please ignore that comment.
This PR only concerns the gradient handling in Solve for structured assume_a cases.

Replace manual finite-difference loop with verify_grad using triangular
parameterization. Add explicit zero-gradient assertion for unread triangle.

Signed-off-by: WHOIM1205 <rathourprateek8@gmail.com>
@WHOIM1205 Contributor Author

pre-commit.ci autofix

Review comment on the updated gradient code in slinalg.py:

```python
# Entries of the read triangle contribute to both (i,j) and (j,i) of the effective matrix,
# so we must accumulate the symmetric contribution and zero the unread triangle.
if self.lower:
    res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1)
```
Member

Suggested change:

```diff
-res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.mT, -1)
+res[0] = ptb.tril(A_bar) + ptb.tril(A_bar.conj().mT, -1)
```

Otherwise the Hermitian case is wrong.
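
For intuition only (plain NumPy, not part of the PR): a Hermitian matrix is rebuilt from its lower triangle with the conjugate transpose, not the plain transpose, and the same applies to the gradient accumulation, hence the `.conj()` in the suggestion.

```python
import numpy as np

L = np.array([[2.0 + 0j, 0.0], [1.0 - 1.0j, 3.0 + 0j]])

H_herm = np.tril(L) + np.tril(L, -1).conj().T  # Hermitian: equals its conjugate transpose
H_plain = np.tril(L) + np.tril(L, -1).T        # plain transpose: merely symmetric

assert np.allclose(H_herm, H_herm.conj().T)
assert not np.allclose(H_plain, H_plain.conj().T)
```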

@jessegrabowski Member

I think it would be good to have a simple closed-form test. Given:

$$ A = \begin{bmatrix} 1 & \rho \\ \rho & 1 \end{bmatrix} $$

The inverse is analytically computable:

$$ A^{-1} = \frac{1}{1 - \rho^2} \begin{bmatrix} 1 & -\rho \\ -\rho & 1 \end{bmatrix} $$

Given loss = solve(A, b)[1, 0] with b = np.array([[1.], [0.]]), then the loss is:

$$ \mathcal{L} = \frac{\rho}{\rho^2 - 1} $$

And the gradient with respect to $\rho$ is:

$$ \frac{\partial \mathcal{L}}{\partial \rho} = -\frac{1 + \rho^2}{(1 - \rho^2)^2} $$

Checking pytensor:

```python
import pytensor
import pytensor.tensor as pt

rho = pt.dscalar('rho')
# Only the lower triangle carries rho; the upper triangle is left at zero.
A = pt.stacklists([[1., 0.], [rho, 1.]])
b = pt.stacklists([[1.], [0.]])
x = pt.linalg.solve(A, b, assume_a='pos')
loss = x[1, 0]

dL_drho = pt.grad(loss, rho)
fn = pytensor.function([rho], dL_drho)
fn(0.88)  # array(-19.64815653)

# Analytical:
def expected_grad(rho):
    return -(1 + rho ** 2) / (1 - rho ** 2) ** 2

expected_grad(0.88)  # -34.86368894924802
```

Note though that this only occurs because we passed a non-PSD matrix to solve. It is a numerical quirk that the underlying algorithm is able to treat this matrix as PSD, notwithstanding that it is not. If you instead pass A = pt.stacklists([[1., rho], [rho, 1.]]), you will get the right answer.
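
For completeness, the same graph with the full symmetric construction, which (per the note above) should reproduce the analytic gradient; this is a sketch, the printed value is not reproduced here:

```python
import pytensor
import pytensor.tensor as pt

rho = pt.dscalar('rho')
A = pt.stacklists([[1., rho], [rho, 1.]])  # rho appears in both triangles
b = pt.stacklists([[1.], [0.]])
x = pt.linalg.solve(A, b, assume_a='pos')
loss = x[1, 0]

dL_drho = pt.grad(loss, rho)
fn = pytensor.function([rho], dL_drho)
fn(0.88)  # should be close to expected_grad(0.88) == -34.8636...
```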
