fix(solar.make_chromaticdelay): NaN's in gradient #91

davecwright3 wants to merge 3 commits into nanograv:main
Conversation
Rewrite some portions of the method in log-space, and add a NaN mask.
Here's a minimal reproducible example:

```python
>>> j1713 = ds.Pulsar.read_feather("./data/v1p1_de440_pint_bipm2019-J1713+0747.feather")
>>> j1713_likelihood = ds.PulsarLikelihood(
...     [
...         j1713.residuals,
...         ds.makenoise_measurement(j1713, j1713.noisedict),
...         ds.makegp_ecorr(j1713, j1713.noisedict),
...         ds.makegp_timing(j1713, svd=True),
...         ds.makedelay(j1713, ds.make_chromaticdelay(j1713), name="dmexp_1"),
...     ]
... )
>>> j1713_params = {
...     'J1713+0747_dmexp_1_idx': 1.5,
...     'J1713+0747_dmexp_1_log10_Amp': -6.0,
...     'J1713+0747_dmexp_1_log10_tau': 1.5,
...     'J1713+0747_dmexp_1_t0': 57509.0,
... }
>>> jax.grad(j1713_likelihood.logL)(j1713_params)
{'J1713+0747_dmexp_1_idx': Array(-89.7052646, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_Amp': Array(101.21448275, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_tau': Array(94.12016904, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_t0': Array(1.39003896, dtype=float64, weak_type=True)}
>>> jax.grad(j1713_likelihood.logL)(j1713_params | {'J1713+0747_dmexp_1_log10_tau': 0.7})
{'J1713+0747_dmexp_1_idx': Array(nan, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_Amp': Array(nan, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_tau': Array(nan, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_t0': Array(nan, dtype=float64, weak_type=True)}
>>> jax.grad(j1713_likelihood.logL)(j1713_params | {'J1713+0747_dmexp_1_log10_tau': 0.8})
{'J1713+0747_dmexp_1_idx': Array(-20.85039117, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_Amp': Array(23.70269581, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_log10_tau': Array(35.77951153, dtype=float64, weak_type=True),
 'J1713+0747_dmexp_1_t0': Array(1.63148113, dtype=float64, weak_type=True)}
```

For some reason, 0.8 is the limit below which the gradients become NaN's.
The `jnp.where` NaN mask is actually not needed. The separate calculation of the exponent is what fixes the NaNs.
Apologies for the back and forth, I've finally found the root issue. This explains why ~0.8 was the magic number that produced NaN's:

```python
>>> j1713 = ds.Pulsar.read_feather("./data/v1p1_de440_pint_bipm2019-J1713+0747.feather")
>>> j1713_params = {
...     'J1713+0747_dmexp_1_idx': 1.5,
...     'J1713+0747_dmexp_1_log10_Amp': -6.0,
...     'J1713+0747_dmexp_1_log10_tau': 1.5,
...     'J1713+0747_dmexp_1_t0': 57509.0,
... }
>>> toadays = jnp.array(j1713.toas / ds.const.day)
>>> dt = toadays - j1713_params["J1713+0747_dmexp_1_t0"]
>>> dt.min(), dt.max()
(Array(-4115.44121685, dtype=float64), Array(1562.00791291, dtype=float64))
>>> tau = 10 ** 0.8
>>> jnp.exp(-dt.min() / tau)
Array(1.86245505e+283, dtype=float64)
>>> tau = 10 ** 0.7633
>>> jnp.exp(-dt.min() / tau)
Array(1.77140153e+308, dtype=float64)
>>> tau = 10 ** 0.76
>>> jnp.exp(-dt.min() / tau)
Array(inf, dtype=float64)
```

Which of course makes perfect sense, because we've reached the maximum float that can be represented with 64 bits:

```python
>>> import sys
>>> sys.float_info.max
1.7976931348623157e+308
```

The solution: the JAX docs (and other autodiff resources) recommend the "double where" trick, which I accidentally implemented.

This could probably use a squash commit or a squash merge? If you'd like me to squash and open a new PR just let me know.
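The "double where" failure mode can be shown in isolation (a toy sketch, not the actual discovery code; the `800.0` scale factor is arbitrary, chosen only to force overflow):

```python
import jax
import jax.numpy as jnp

def single_where(x):
    # the discarded exp branch overflows to inf during the forward pass;
    # under jax.grad its cotangent is multiplied by the 0 mask,
    # and 0 * inf = nan poisons the whole gradient
    return jnp.where(x > 0.0, x, jnp.exp(800.0 * x))

def double_where(x):
    # inner where sanitizes the input so the discarded branch stays finite
    safe_x = jnp.where(x > 0.0, 0.0, x)
    return jnp.where(x > 0.0, x, jnp.exp(800.0 * safe_x))

print(jax.grad(single_where)(5.0))  # nan
print(jax.grad(double_where)(5.0))  # 1.0
```

The key point is that `jnp.where` does not short-circuit: both branches are evaluated and differentiated, so any non-finite value in the unselected branch leaks into the gradient unless the input itself is masked first.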
meyers-academic
left a comment
Could you add a unit test for this? Just a simple one -- create tests/test_solar.py and basically just run the MWE you gave and check that it's not a NaN?
```python
        0.0,
    )

    return matrix.jnp.where(dt_mask, -1.0 * matrix.jnp.exp(vals), vals)
```
I realize it's a bit hypocritical, given our current lack of tests, but do you think you could add a unit test to make sure this doesn't happen? This seems like a prime example of where this would be needed. Perhaps you can create a tests/test_solar.py and just give the example in the PR and check that it's not infinite?
I can do that!
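A self-contained version of such a test could look like the sketch below. The `exp_dip` function here is a hypothetical stand-in for the patched kernel, not the actual discovery code; it only exercises the same masked, log-space, double-where pattern, with a `dt` range mirroring the MWE (~-4115 to ~+1562 days around `t0`):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def exp_dip(log10_tau, dt):
    """Hypothetical stand-in for the patched dip kernel."""
    mask = dt >= 0.0                      # dip only acts after t0
    safe_dt = jnp.where(mask, dt, 0.0)    # sanitize before the division
    vals = -safe_dt / 10.0 ** log10_tau   # exponent kept in log-space
    return jnp.where(mask, -jnp.exp(vals), 0.0)

def test_exp_dip_grad_is_finite():
    dt = jnp.array([-4115.44, -100.0, 0.0, 100.0, 1562.01])
    # log10_tau = 0.7 is below the old ~0.76 overflow limit
    g = jax.grad(lambda lt: exp_dip(lt, dt).sum())(0.7)
    assert jnp.isfinite(g)
```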
I found NaN's coming from this function in the gradients of both single-pulsar and full-PTA likelihoods. For very large negative numbers in the exponent, the gradient wrt the timescale $\tau$ is

$$\frac{\partial}{\partial \tau}\, A\, e^{-\Delta t/\tau} = \frac{A\,\Delta t}{\tau^{2}}\, e^{-\Delta t/\tau},$$

where $\Delta t$ is the time after the start of the event. $\Delta t$ can be very large, which creates a product of a very large number and a number very close to zero. Somehow, this produces NaN's (the JAX traceback indicated that the NaN came from the multiplication with the exponential).

My first attempt at fixing this was to rewrite the dip in the form

$$A \left(\frac{f}{1400\ \mathrm{MHz}}\right)^{-\mathrm{idx}} e^{-\Delta t/\tau} = \exp\!\left(\ln A - \mathrm{idx}\,\ln\frac{f}{1400\ \mathrm{MHz}} - \frac{\Delta t}{\tau}\right),$$

which I hoped would stabilize the computation. It also can push the exponent towards more positive values for $f <$ 1400 MHz.

There's also an updated docstring included in this PR.

The quote below was the case when I first opened the PR, but the issue is actually with how `jax.grad` interacts with `jax.numpy.where`. See 683873e, a6ce480, and my comments below.
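For reference, the log-space rewrite of the dip can be sketched standalone (an illustration only; the function names and signatures are made up, not the actual discovery API, and the dip form is assumed to be $A\,(f/1400\,\mathrm{MHz})^{-\mathrm{idx}}\,e^{-\Delta t/\tau}$):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def chromatic_dip_naive(f_mhz, dt, log10_amp, idx, log10_tau):
    # direct form: (f/1400)**-idx can be large while exp(-dt/tau) is
    # tiny, so their product is numerically fragile
    amp = 10.0 ** log10_amp
    return -amp * (f_mhz / 1400.0) ** -idx * jnp.exp(-dt / 10.0 ** log10_tau)

def chromatic_dip_logspace(f_mhz, dt, log10_amp, idx, log10_tau):
    # everything folded into a single exponent before one exp() call;
    # for f < 1400 MHz the -idx*log(f/1400) term pushes the exponent up
    exponent = (log10_amp * jnp.log(10.0)
                - idx * jnp.log(f_mhz / 1400.0)
                - dt / 10.0 ** log10_tau)
    return -jnp.exp(exponent)
```

The two forms agree analytically; the log-space version just avoids forming the intermediate extreme values.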