From 91d3b7c01668ec88b606f9ac7cc3f011272d653c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 22 Nov 2023 11:59:52 +0100 Subject: [PATCH] Do not merge while scans with different until condition The rewrite did not check if nominal variables in the graph of the until condition corresponded to the equivalent outer variables --- pytensor/scan/rewriting.py | 53 ++++++++--- tests/scan/test_rewriting.py | 165 +++++++++++++++++++++++++++-------- 2 files changed, 170 insertions(+), 48 deletions(-) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index d4a59c5b17..b84bcf7bf7 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -17,7 +17,9 @@ from pytensor.graph.basic import ( Apply, Constant, + NominalVariable, Variable, + ancestors, apply_depends_on, equal_computations, graph_inputs, @@ -1950,11 +1952,13 @@ def belongs_to_set(self, node, set_nodes): Questionable, we should also consider profile ? """ - rep = set_nodes[0] + op = node.op + rep_node = set_nodes[0] + rep_op = rep_node.op if ( - rep.op.info.as_while != node.op.info.as_while - or node.op.truncate_gradient != rep.op.truncate_gradient - or node.op.mode != rep.op.mode + op.info.as_while != rep_op.info.as_while + or op.truncate_gradient != rep_op.truncate_gradient + or op.mode != rep_op.mode ): return False @@ -1964,7 +1968,7 @@ def belongs_to_set(self, node, set_nodes): except NotScalarConstantError: pass - rep_nsteps = rep.inputs[0] + rep_nsteps = rep_node.inputs[0] try: rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) except NotScalarConstantError: @@ -1978,13 +1982,40 @@ def belongs_to_set(self, node, set_nodes): if apply_depends_on(node, nd) or apply_depends_on(nd, node): return False - if not node.op.info.as_while: + if not op.info.as_while: return True - cond = node.op.inner_outputs[-1] - rep_cond = rep.op.inner_outputs[-1] - return equal_computations( - [cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs - ) + + # We need to check the while conditions are identical + conds = [op.inner_outputs[-1]] + rep_conds = [rep_op.inner_outputs[-1]] + if not equal_computations( + conds, rep_conds, op.inner_inputs, rep_op.inner_inputs + ): + return False + + # If they depend on inner inputs we need to check for equivalence on the respective outer inputs + nominal_inputs = [a for a in ancestors(conds) if isinstance(a, NominalVariable)] + if not nominal_inputs: + return True + rep_nominal_inputs = [ + a for a in ancestors(rep_conds) if isinstance(a, NominalVariable) + ] + + conds = [] + rep_conds = [] + mapping = op.get_oinp_iinp_iout_oout_mappings()["outer_inp_from_inner_inp"] + rep_mapping = rep_op.get_oinp_iinp_iout_oout_mappings()[ + "outer_inp_from_inner_inp" + ] + inner_inputs = op.inner_inputs + rep_inner_inputs = rep_op.inner_inputs + for nominal_input, rep_nominal_input in zip(nominal_inputs, rep_nominal_inputs): + conds.append(node.inputs[mapping[inner_inputs.index(nominal_input)]]) + rep_conds.append( + rep_node.inputs[rep_mapping[rep_inner_inputs.index(rep_nominal_input)]] + ) + + return equal_computations(conds, rep_conds) def apply(self, fgraph): # Collect all scan nodes ordered according to toposort diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index 8f362b4e50..9dc6e698cf 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -15,6 +15,7 @@ from pytensor.scan.op import Scan from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge from pytensor.scan.utils import until +from pytensor.tensor import stack from pytensor.tensor.blas import Dot22 from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.math import Dot, dot, sigmoid @@ -796,7 +797,13 @@ def inner_fct(seq1, seq2, seq3, previous_output): class TestScanMerge: - mode = get_default_mode().including("scan") + mode = get_default_mode().including("scan").excluding("scan_pushout_seqs_ops") + + @staticmethod + def count_scans(fn): + nodes = fn.maker.fgraph.apply_nodes + scans = [node for node in nodes if isinstance(node.op, Scan)] + return len(scans) def test_basic(self): x = vector() @@ -808,56 +815,38 @@ def sum(s): sx, upx = scan(sum, sequences=[x]) sy, upy = scan(sum, sequences=[y]) - f = function( - [x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops") - ) - topo = f.maker.fgraph.toposort() - scans = [n for n in topo if isinstance(n.op, Scan)] - assert len(scans) == 2 + f = function([x, y], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 2 sx, upx = scan(sum, sequences=[x], n_steps=2) sy, upy = scan(sum, sequences=[y], n_steps=3) - f = function( - [x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops") - ) - topo = f.maker.fgraph.toposort() - scans = [n for n in topo if isinstance(n.op, Scan)] - assert len(scans) == 2 + f = function([x, y], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 2 sx, upx = scan(sum, sequences=[x], n_steps=4) sy, upy = scan(sum, sequences=[y], n_steps=4) - f = function( - [x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops") - ) - topo = f.maker.fgraph.toposort() - scans = [n for n in topo if isinstance(n.op, Scan)] - assert len(scans) == 1 + f = function([x, y], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 1 sx, upx = scan(sum, sequences=[x]) sy, upy = scan(sum, sequences=[x]) - f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")) - topo = f.maker.fgraph.toposort() - scans = [n for n in topo if isinstance(n.op, Scan)] - assert len(scans) == 1 + f = function([x], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 1 sx, upx = scan(sum, sequences=[x]) sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE") - f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")) - topo = f.maker.fgraph.toposort() - scans = [n for n in topo if isinstance(n.op, Scan)] - assert len(scans) == 1 + f = function([x], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 1 sx, upx = scan(sum, sequences=[x]) sy, upy = scan(sum, sequences=[x], truncate_gradient=1) - f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")) - topo = f.maker.fgraph.toposort() - scans = [n for n in topo if isinstance(n.op, Scan)] - assert len(scans) == 2 + f = function([x], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 2 def test_three_scans(self): r""" @@ -877,12 +866,8 @@ def sum(s): sy, upy = scan(sum, sequences=[2 * y + 2], n_steps=4, name="Y") sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z") - f = function( - [x, y], [sy, sz], mode=self.mode.excluding("scan_pushout_seqs_ops") - ) - topo = f.maker.fgraph.toposort() - scans = [n for n in topo if isinstance(n.op, Scan)] - assert len(scans) == 2 + f = function([x, y], [sy, sz], mode=self.mode) + assert self.count_scans(f) == 2 rng = np.random.default_rng(utt.fetch_seed()) x_val = rng.uniform(size=(4,)).astype(config.floatX) @@ -913,6 +898,112 @@ def test_belongs_to_set(self): assert not opt_obj.belongs_to_set(scan_node1, [scan_node2]) assert not opt_obj.belongs_to_set(scan_node2, [scan_node1]) + @config.change_flags(cxx="") # Just for faster compilation + def test_while_scan(self): + x = vector("x") + y = vector("y") + + def add(s): + return s + 1, until(s > 5) + + def sub(s): + return s - 1, until(s > 5) + + def sub_alt(s): + return s - 1, until(s > 4) + + sx, upx = scan(add, sequences=[x]) + sy, upy = scan(sub, sequences=[y]) + + f = function([x, y], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 2 + + sx, upx = scan(add, sequences=[x]) + sy, upy = scan(sub, sequences=[x]) + + f = function([x], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 1 + + sx, upx = scan(add, sequences=[x]) + sy, upy = scan(sub_alt, sequences=[x]) + + f = function([x], [sx, sy], mode=self.mode) + assert self.count_scans(f) == 2 + + @config.change_flags(cxx="") # Just for faster compilation + def test_while_scan_nominal_dependency(self): + """Test case where condition depends on nominal variables. + + This is a regression test for #509 + """ + c1 = scalar("c1") + c2 = scalar("c2") + x = vector("x", shape=(5,)) + y = vector("y", shape=(5,)) + z = vector("z", shape=(5,)) + + def add(s1, s2, const): + return s1 + 1, until(s2 > const) + + def sub(s1, s2, const): + return s1 - 1, until(s2 > const) + + sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) + sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1]) + + f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode) + assert self.count_scans(f) == 2 + res_sx, res_sy = f( + x=[0, 0, 0, 0, 0], + y=[0, 0, 0, 0, 0], + z=[0, 1, 2, 3, 4], + c1=0, + ) + np.testing.assert_array_equal(res_sx, [1, 1]) + np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1]) + + sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) + sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2]) + + f = pytensor.function( + inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode + ) + assert self.count_scans(f) == 2 + res_sx, res_sy = f( + x=[0, 0, 0, 0, 0], + y=[0, 0, 0, 0, 0], + z=[0, 1, 2, 3, 4], + c1=3, + c2=1, + ) + np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1]) + np.testing.assert_array_equal(res_sy, [-1, -1, -1]) + + sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) + sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1]) + + f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode) + assert self.count_scans(f) == 1 + + def nested_scan(c, x, z): + sx, _ = scan(add, sequences=[x, z], non_sequences=[c]) + sy, _ = scan(sub, sequences=[x, z], non_sequences=[c]) + return sx.sum() + sy.sum() + + sz, _ = scan( + nested_scan, + sequences=[stack([c1, c2])], + non_sequences=[x, z], + mode=self.mode, + ) + + f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode) + [scan_node] = [ + node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ] + inner_f = scan_node.op.fn + assert self.count_scans(inner_f) == 1 + class TestScanInplaceOptimizer: mode = get_default_mode().including("scan_make_inplace", "inplace")