Skip to content

Commit

Permalink
Do not merge while scans with different until condition
Browse files Browse the repository at this point in the history
The rewrite did not check if nominal variables in the graph of the until condition corresponded to the equivalent outer variables
  • Loading branch information
ricardoV94 committed Nov 23, 2023
1 parent eb552ee commit 91d3b7c
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 48 deletions.
53 changes: 42 additions & 11 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from pytensor.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
ancestors,
apply_depends_on,
equal_computations,
graph_inputs,
Expand Down Expand Up @@ -1950,11 +1952,13 @@ def belongs_to_set(self, node, set_nodes):
Questionable, we should also consider profile ?
"""
rep = set_nodes[0]
op = node.op
rep_node = set_nodes[0]
rep_op = rep_node.op
if (
rep.op.info.as_while != node.op.info.as_while
or node.op.truncate_gradient != rep.op.truncate_gradient
or node.op.mode != rep.op.mode
op.info.as_while != rep_op.info.as_while
or op.truncate_gradient != rep_op.truncate_gradient
or op.mode != rep_op.mode
):
return False

Expand All @@ -1964,7 +1968,7 @@ def belongs_to_set(self, node, set_nodes):
except NotScalarConstantError:
pass

rep_nsteps = rep.inputs[0]
rep_nsteps = rep_node.inputs[0]
try:
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
except NotScalarConstantError:
Expand All @@ -1978,13 +1982,40 @@ def belongs_to_set(self, node, set_nodes):
if apply_depends_on(node, nd) or apply_depends_on(nd, node):
return False

if not node.op.info.as_while:
if not op.info.as_while:
return True
cond = node.op.inner_outputs[-1]
rep_cond = rep.op.inner_outputs[-1]
return equal_computations(
[cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs
)

# We need to check the while conditions are identical
conds = [op.inner_outputs[-1]]
rep_conds = [rep_op.inner_outputs[-1]]
if not equal_computations(
conds, rep_conds, op.inner_inputs, rep_op.inner_inputs
):
return False

# If they depend on inner inputs we need to check for equivalence on the respective outer inputs
nominal_inputs = [a for a in ancestors(conds) if isinstance(a, NominalVariable)]
if not nominal_inputs:
return True
rep_nominal_inputs = [
a for a in ancestors(rep_conds) if isinstance(a, NominalVariable)
]

conds = []
rep_conds = []
mapping = op.get_oinp_iinp_iout_oout_mappings()["outer_inp_from_inner_inp"]
rep_mapping = rep_op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
]
inner_inputs = op.inner_inputs
rep_inner_inputs = rep_op.inner_inputs
for nominal_input, rep_nominal_input in zip(nominal_inputs, rep_nominal_inputs):
conds.append(node.inputs[mapping[inner_inputs.index(nominal_input)]])
rep_conds.append(
rep_node.inputs[rep_mapping[rep_inner_inputs.index(rep_nominal_input)]]
)

return equal_computations(conds, rep_conds)

def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort
Expand Down
165 changes: 128 additions & 37 deletions tests/scan/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until
from pytensor.tensor import stack
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, sigmoid
Expand Down Expand Up @@ -796,7 +797,13 @@ def inner_fct(seq1, seq2, seq3, previous_output):


class TestScanMerge:
mode = get_default_mode().including("scan")
mode = get_default_mode().including("scan").excluding("scan_pushout_seqs_ops")

@staticmethod
def count_scans(fn):
nodes = fn.maker.fgraph.apply_nodes
scans = [node for node in nodes if isinstance(node.op, Scan)]
return len(scans)

def test_basic(self):
x = vector()
Expand All @@ -808,56 +815,38 @@ def sum(s):
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[y])

f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2

sx, upx = scan(sum, sequences=[x], n_steps=2)
sy, upy = scan(sum, sequences=[y], n_steps=3)

f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2

sx, upx = scan(sum, sequences=[x], n_steps=4)
sy, upy = scan(sum, sequences=[y], n_steps=4)

f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1

sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x])

f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1

sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE")

f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1

sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], truncate_gradient=1)

f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2

def test_three_scans(self):
r"""
Expand All @@ -877,12 +866,8 @@ def sum(s):
sy, upy = scan(sum, sequences=[2 * y + 2], n_steps=4, name="Y")
sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z")

f = function(
[x, y], [sy, sz], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x, y], [sy, sz], mode=self.mode)
assert self.count_scans(f) == 2

rng = np.random.default_rng(utt.fetch_seed())
x_val = rng.uniform(size=(4,)).astype(config.floatX)
Expand Down Expand Up @@ -913,6 +898,112 @@ def test_belongs_to_set(self):
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])

@config.change_flags(cxx="") # Just for faster compilation
def test_while_scan(self):
x = vector("x")
y = vector("y")

def add(s):
return s + 1, until(s > 5)

def sub(s):
return s - 1, until(s > 5)

def sub_alt(s):
return s - 1, until(s > 4)

sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub, sequences=[y])

f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2

sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub, sequences=[x])

f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1

sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub_alt, sequences=[x])

f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2

@config.change_flags(cxx="") # Just for faster compilation
def test_while_scan_nominal_dependency(self):
"""Test case where condition depends on nominal variables.
This is a regression test for #509
"""
c1 = scalar("c1")
c2 = scalar("c2")
x = vector("x", shape=(5,))
y = vector("y", shape=(5,))
z = vector("z", shape=(5,))

def add(s1, s2, const):
return s1 + 1, until(s2 > const)

def sub(s1, s2, const):
return s1 - 1, until(s2 > const)

sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1])

f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
res_sx, res_sy = f(
x=[0, 0, 0, 0, 0],
y=[0, 0, 0, 0, 0],
z=[0, 1, 2, 3, 4],
c1=0,
)
np.testing.assert_array_equal(res_sx, [1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1])

sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2])

f = pytensor.function(
inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode
)
assert self.count_scans(f) == 2
res_sx, res_sy = f(
x=[0, 0, 0, 0, 0],
y=[0, 0, 0, 0, 0],
z=[0, 1, 2, 3, 4],
c1=3,
c2=1,
)
np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1])

sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1])

f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 1

def nested_scan(c, x, z):
sx, _ = scan(add, sequences=[x, z], non_sequences=[c])
sy, _ = scan(sub, sequences=[x, z], non_sequences=[c])
return sx.sum() + sy.sum()

sz, _ = scan(
nested_scan,
sequences=[stack([c1, c2])],
non_sequences=[x, z],
mode=self.mode,
)

f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode)
[scan_node] = [
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
inner_f = scan_node.op.fn
assert self.count_scans(inner_f) == 1


class TestScanInplaceOptimizer:
mode = get_default_mode().including("scan_make_inplace", "inplace")
Expand Down

0 comments on commit 91d3b7c

Please sign in to comment.