-
Notifications
You must be signed in to change notification settings - Fork 166
Open
Labels
Description
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode
# Reproducer: indexing the last entry of the raw Scan output buffer returns
# the wrong value when the scan runs for zero steps.
#
# The default Mode applies the "scan_save_mem" rewrite, which triggers the
# bug. To check that the rewrite is responsible, use instead:
#     mode = Mode().excluding("scan_save_mem")
mode = Mode()

n = pt.tensor("n", shape=(), dtype=int)
init_state = pt.tensor("init_state", shape=(3,))

# Each step doubles the previous state. With return_updates=False the call
# yields a single variable (its `.owner` is inspected below).
final_state = pytensor.scan(
    fn=lambda xtm1: xtm1 * 2,
    outputs_info=[init_state],
    n_steps=n,
    return_updates=False,
)

# Access the last state of the Scan output buffer, which also holds the
# initial state; with n == 0 this entry should therefore equal init_state.
res = final_state.owner.inputs[0][-1]
fn = pytensor.function([init_state, n], res, mode=mode, on_unused_input="ignore")

# Expected to fail while the bug is present: scan_save_mem always wants to
# provide a buffer with at least one extra entry, but in this case it ends up
# growing the buffer, so the [-1] index points to uninitialized memory.
np.testing.assert_allclose(fn(init_state=np.ones((3,)), n=0), np.ones((3,)))
This bug was first observed in pymc-devs/pymc#7380.
Reactions are currently unavailable