Description
```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

n = pt.iscalar("n")
x0 = pt.vector("x0")
# x_t = x_{t-1} + 1, repeated n times
xs, _ = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=n)
out = xs[-1]  # Invalid when n_steps=0, since xs is then empty

# Default mode: silently returns a value instead of raising
fn = pytensor.function([n, x0], out)
print(fn(n=0, x0=[0, 1]))  # [1. 2.]

# Excluding the "shape_unsafe" rewrites does not change the result
fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("shape_unsafe"))
print(fn(n=0, x0=[0, 1]))  # [1. 2.]

# Excluding "scan_save_mem" gives the expected error
fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("scan_save_mem"))
print(fn(n=0, x0=[0, 1]))  # IndexError: index out of bounds
```
I suspect this comes from the following hack in `pytensor/pytensor/scan/rewriting.py` (lines 1438 to 1445 at commit 8454c3b):
```python
# FIXME: This is not correct. Scan with 0 steps seems to be supported
# Make sure the ScanSaveMem optimization never makes the new
# number of steps to be 0 (this could happen, for instance, if
# the optimization detects that the outputs of the Scan go through
# subtensor nodes that end up taking no elements) because Scan with
# 0 iterations are not supported. Make sure the new number of steps
# is at least 1.
nw_steps = select_max(nw_steps, 1)
```
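To make the suspected mechanism concrete, here is a simplified pure-NumPy model of the reproduction. It is not PyTensor code: `scan_add_one` and its `clamp_to_one` flag are hypothetical names, and the sketch assumes that the effect of `select_max(nw_steps, 1)` is simply to run one step when zero were requested. Under that assumption, the clamped version returns `x0 + 1` for `n_steps=0`, which matches the `[1. 2.]` seen with the default mode, while the unclamped version yields an empty output on which `xs[-1]` raises.

```python
import numpy as np


def scan_add_one(x0, n_steps, clamp_to_one=False):
    """Hypothetical NumPy model of the scan above: x_t = x_{t-1} + 1.

    Returns the stacked outputs xs with shape (steps, len(x0)).
    If clamp_to_one is True, mimic `nw_steps = select_max(nw_steps, 1)`
    (an assumption about the rewrite's effect, not its actual code).
    """
    steps = max(n_steps, 1) if clamp_to_one else n_steps
    x = np.asarray(x0, dtype=float)
    outputs = []
    for _ in range(steps):
        x = x + 1
        outputs.append(x)
    # reshape so that zero steps still gives a 2-D (0, len(x0)) array
    return np.array(outputs).reshape(steps, len(x0))


# Expected semantics: zero steps gives an empty output, so xs[-1] must fail
xs = scan_add_one([0, 1], n_steps=0)
print(xs.shape)  # (0, 2)
try:
    xs[-1]
except IndexError as err:
    print("IndexError:", err)

# With the clamp, one step runs anyway and xs[-1] silently returns x0 + 1
print(scan_add_one([0, 1], n_steps=0, clamp_to_one=True)[-1])  # [1. 2.]
```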
However, removing this hack makes some tests fail, so other code downstream of it (or perhaps inside the rewrite itself) may be making incorrect assumptions.