
fix(solar.make_chromaticdelay): NaN's in gradient #91

Open
davecwright3 wants to merge 3 commits into nanograv:main from davecwright3:fix/chromatic-delay-nan-gradients

Conversation

@davecwright3 davecwright3 commented Apr 30, 2025

I found NaNs coming from this function in the gradients of both single-pulsar and full-PTA likelihoods. For very large negative numbers in the exponent, the gradient with respect to the timescale $\tau$ is

$$\propto \frac{\Delta t}{\tau^2} e^{-\Delta t / \tau},$$

where $\Delta t$ is the time after the start of the event. $\Delta t$ can be very large, which creates a product of a very large number and a number very close to zero. Somehow, this produces NaNs (the JAX traceback indicated that the NaN came from the multiplication with the exponential).

My first attempt at fixing this was to rewrite the dip in the form

$$10^{\log_{10}(A)}\, e^{-\Delta t / 10^{\tau}} f^{-\gamma} = \exp\left(\log_{10}(A)\ln(10) - \Delta t / 10^{\tau} - \gamma\ln f\right),$$

which I hoped would stabilize the computation. It can also push the exponent towards more positive values for $f <$ 1400 MHz.
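As a sketch, the log-space form amounts to accumulating one exponent and exponentiating once (the function name, signature, and parameter names here are illustrative, not discovery's actual API):

```python
import jax.numpy as jnp

def chromatic_dip(log10_Amp, log10_tau, gamma, dt, f):
    # Hypothetical sketch: build the full exponent in log-space,
    # then exponentiate once, instead of multiplying three factors.
    exponent = (log10_Amp * jnp.log(10.0)   # 10**log10_Amp
                - dt / 10**log10_tau        # exp(-dt / tau)
                - gamma * jnp.log(f))       # f**-gamma
    return jnp.exp(exponent)
```

For moderate parameter values this agrees with the factored form `10**log10_Amp * jnp.exp(-dt / 10**log10_tau) * f**-gamma`, while letting large positive and negative contributions to the exponent cancel before the exponential is taken.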

There's also an updated docstring included in this PR.

The quote below described the situation when I first opened the PR, but the issue is actually with how `jax.grad` interacts with `jax.numpy.where`. See 683873e, a6ce480, and my comments below.

It fixes some cases, but $\Delta t$ can still become large enough to make NaNs. I then tried a few things like clipping and scaling, but the solution that worked ended up being the simplest one:
I added a `jnp.where` to filter NaNs out of the exponent. If someone has a more intelligent solution to this issue, I'd be happy to change this PR.

Rewrite some portions of the method in log-space, and add a NaN mask.
@davecwright3 davecwright3 requested a review from vallis April 30, 2025 20:25
@davecwright3 davecwright3 added the bug Something isn't working label Apr 30, 2025
davecwright3 commented May 1, 2025

Here's a minimal reproducible example

>>> j1713 = ds.Pulsar.read_feather("./data/v1p1_de440_pint_bipm2019-J1713+0747.feather")
>>> j1713_likelihood = ds.PulsarLikelihood(
    [
        j1713.residuals,
        ds.makenoise_measurement(j1713, j1713.noisedict),
        ds.makegp_ecorr(j1713, j1713.noisedict),
        ds.makegp_timing(j1713, svd=True),
        ds.makedelay(j1713, ds.make_chromaticdelay(j1713), name="dmexp_1"),
    ]
)
>>> j1713_params = {
    'J1713+0747_dmexp_1_idx': 1.5,
    'J1713+0747_dmexp_1_log10_Amp': -6.0,
    'J1713+0747_dmexp_1_log10_tau': 1.5,
    'J1713+0747_dmexp_1_t0': 57509.0,
}


>>> jax.grad(j1713_likelihood.logL)(j1713_params)

{'J1713+0747_dmexp_1_idx': Array(-89.7052646, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_Amp': Array(101.21448275, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_tau': Array(94.12016904, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_t0': Array(1.39003896, dtype=float64, weak_type=True)}

>>> jax.grad(j1713_likelihood.logL)(j1713_params | {'J1713+0747_dmexp_1_log10_tau':0.7 })

{'J1713+0747_dmexp_1_idx': Array(nan, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_Amp': Array(nan, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_tau': Array(nan, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_t0': Array(nan, dtype=float64, weak_type=True)}

>>> jax.grad(j1713_likelihood.logL)(j1713_params | {'J1713+0747_dmexp_1_log10_tau':0.8 })

{'J1713+0747_dmexp_1_idx': Array(-20.85039117, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_Amp': Array(23.70269581, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_tau': Array(35.77951153, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_t0': Array(1.63148113, dtype=float64, weak_type=True)}

For some reason, 0.8 is the limit below which the gradients become NaNs.

The jnp.where nan mask is actually not needed. The separate calculation
of the exponent is what fixes the nans.
@davecwright3
Collaborator Author

The last call to `jnp.where` for the NaN mask is not needed. I made the usual mistake of turning two knobs at once; it turns out that calculating the exponent outside of the call to `jnp.exp` is what fixes the issue. I'd bet it's some difference in the jaxprs that JAX's autodiff creates, but I haven't inspected it in detail.

@davecwright3 davecwright3 marked this pull request as draft May 1, 2025 05:49
davecwright3 commented May 1, 2025

Apologies for the back and forth; I've finally found the root issue. JAX's reverse-mode autodiff still evaluates the gradient at the `False` positions of the expression in a `jnp.where` (see https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where). For our case, this means we evaluate at all `dt`. For `dt` negative enough and `log10_tau` small enough, the expression `jnp.exp(-dt / 10**log10_tau)` overflows and returns an `inf`, giving a NaN gradient.

This explains why ~0.8 was the magic number below which the gradients became NaNs.
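The FAQ behavior is easy to reproduce in isolation. Below is a toy stand-in for the masked exponential (not the actual discovery code): the forward value is finite, but the gradient at `log10_tau = 0.7` comes out NaN because autodiff still differentiates `jnp.exp` at the `False` positions, where it overflows.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def masked_exp(log10_tau, dt):
    # The where masks the forward value, but reverse-mode autodiff still
    # differentiates jnp.exp at the False positions, where -dt / tau is
    # huge and the exponential overflows to inf (0 * inf -> nan).
    return jnp.sum(jnp.where(dt >= 0.0, jnp.exp(-dt / 10**log10_tau), 0.0))

dt = jnp.array([-4115.44, 100.0, 1562.0])  # dt span from the J1713 MWE
print(masked_exp(0.7, dt))            # finite forward value
print(jax.grad(masked_exp)(0.7, dt))  # nan
```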

>>> j1713 = ds.Pulsar.read_feather("./data/v1p1_de440_pint_bipm2019-J1713+0747.feather")
>>> j1713_params = {
    'J1713+0747_dmexp_1_idx': 1.5,
    'J1713+0747_dmexp_1_log10_Amp': -6.0,
    'J1713+0747_dmexp_1_log10_tau': 1.5,
    'J1713+0747_dmexp_1_t0': 57509.0,
}
>>> toadays = jnp.array(j1713.toas / ds.const.day)
>>> dt = toadays - j1713_params["J1713+0747_dmexp_1_t0"]
>>> dt.min(), dt.max()
(Array(-4115.44121685, dtype=float64), Array(1562.00791291, dtype=float64))

>>> tau = 10** 0.8
>>> jnp.exp(-dt.min() / tau)
Array(1.86245505e+283, dtype=float64)

>>> tau = 10** 0.7633
>>> jnp.exp(-dt.min() / tau)
Array(1.77140153e+308, dtype=float64)

>>> tau = 10** 0.76
>>> jnp.exp(-dt.min() / tau)
Array(inf, dtype=float64)

This of course makes perfect sense: we've reached the maximum float that can be represented with 64 bits.

>>> import sys
>>> sys.float_info.max
1.7976931348623157e+308 

The solution

JAX docs (and other autodiff resources) recommend the "double where" trick, which I had accidentally implemented. The first `jnp.where` in this PR keeps the values that can overflow the exponential from ever entering it. A very simple alternative I came up with is an "overflow-proof" exponential that moves the `jnp.where` inside the exponential and returns `-inf` for `dt < 0.0`; however, in my benchmarking this was actually ~100 µs slower than what I had already written.
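A minimal sketch of the double-where pattern, under the assumption that `dt >= 0.0` marks the valid region (names are illustrative, not the exact code in this PR):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def safe_masked_exp(log10_tau, dt):
    mask = dt >= 0.0
    # Inner where: replace invalid dt with a harmless value, so neither
    # the exponential nor its gradient ever sees an overflowing argument.
    safe_dt = jnp.where(mask, dt, 0.0)
    vals = jnp.exp(-safe_dt / 10**log10_tau)
    # Outer where: mask the forward value as before.
    return jnp.sum(jnp.where(mask, vals, 0.0))

dt = jnp.array([-4115.44, 100.0, 1562.0])
print(jax.grad(safe_masked_exp)(0.7, dt))  # finite
```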

This could probably use a squash commit or a squash merge? If you'd like me to squash and open a new PR, just let me know.

@davecwright3 davecwright3 marked this pull request as ready for review May 1, 2025 06:54
@meyers-academic meyers-academic left a comment


Could you add a unit test for this? Just a simple one -- create tests/test_solar.py and basically just run the MWE you gave and check that it's not a NaN?

0.0,
)

return matrix.jnp.where(dt_mask, -1.0 * matrix.jnp.exp(vals), vals)
Collaborator


I realize it's a bit hypocritical, given our current lack of tests, but do you think you could add a unit test to make sure this doesn't happen? This seems like a prime example of where this would be needed. Perhaps you can create a tests/test_solar.py and just give the example in the PR and check that it's not infinite?

Collaborator Author


I can do that!
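For reference, a self-contained sketch of what such a test could assert. Since the feather file from the MWE isn't bundled here, this uses a stand-in masked exponential with the `dt` span from the example above; the helper name and parameter values are assumptions, and a real `tests/test_solar.py` would build the likelihood from the actual data and check `jax.grad(logL)` instead.

```python
# Sketch of tests/test_solar.py (the helper is a stand-in, not discovery's code).
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def chromatic_exponent(log10_tau, dt):
    # Double-where masked exponential, in the spirit of this PR's fix.
    mask = dt >= 0.0
    safe_dt = jnp.where(mask, dt, 0.0)
    return jnp.sum(jnp.where(mask, jnp.exp(-safe_dt / 10**log10_tau), 0.0))

def test_gradient_is_finite_for_small_log10_tau():
    dt = jnp.linspace(-4115.0, 1562.0, 256)  # span seen in the J1713 MWE
    for log10_tau in (1.5, 0.8, 0.7, 0.0):   # includes the old NaN regime
        g = jax.grad(chromatic_exponent)(log10_tau, dt)
        assert bool(jnp.isfinite(g))
```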
