
scan_save_mem gives wrong results for empty scan #1878

@ricardoV94

Description

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode

mode = Mode()  # swap in Mode().excluding("scan_save_mem") to disable the rewrite

n = pt.tensor("n", shape=(), dtype=int)
init_state = pt.tensor("init_state", shape=(3,))
final_state = pytensor.scan(
    fn=lambda xtm1: xtm1 * 2,
    outputs_info=[init_state],
    n_steps=n,
    return_updates=False,
)
res = final_state.owner.inputs[0][-1]  # Access the last state of the Scan output buffer (which includes the initial state)

fn = pytensor.function([init_state, n], res, mode=mode, on_unused_input="ignore")
np.testing.assert_allclose(fn(init_state=np.ones((3,)), n=0), np.ones((3,)))  # Fails

This gives the wrong result because scan_save_mem always wants to allocate a buffer with at least one extra entry. With n=0 the buffer only needs a single entry (the initial state), so the rewrite ends up growing it, and the [-1] index operation then points to uninitialized memory instead of the initial state.
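To make the failure mode concrete, here is a NumPy-only sketch that models the buffer logic outside PyTensor. It is not PyTensor's actual implementation: `scan_full_buffer`, `scan_padded_buffer`, and the `min_len=2` padding rule are hypothetical stand-ins for "at least one extra entry", and NaN fills stand in for uninitialized memory so the outcome is deterministic.

```python
import numpy as np


def scan_full_buffer(f, init, n_steps):
    """Reference semantics: row 0 is the initial state, then one row per step."""
    buf = np.empty((n_steps + 1,) + init.shape)
    buf[0] = init
    for t in range(n_steps):
        buf[t + 1] = f(buf[t])
    return buf


def scan_padded_buffer(f, init, n_steps, min_len=2):
    """Hypothetical model of the bug: the buffer is grown to at least `min_len`
    rows even though only n_steps + 1 rows are ever written."""
    length = max(n_steps + 1, min_len)
    # NaN marks rows that are never written (stand-in for uninitialized memory)
    buf = np.full((length,) + init.shape, np.nan)
    buf[0] = init
    for t in range(n_steps):
        buf[t + 1] = f(buf[t])
    return buf


double = lambda x: x * 2
init = np.ones(3)
print(scan_full_buffer(double, init, 0)[-1])    # the initial state
print(scan_padded_buffer(double, init, 0)[-1])  # a row that was never written
```

In this model, with `n_steps=0` the reference buffer's last row is the initial state, while the padded buffer's last row was never filled; for `n_steps >= 1` the two buffers have the same length and agree, so only the empty scan is affected.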

Seen in pymc-devs/pymc#7380
