From 189ba03a1444e2f9438789f01697851b2148c070 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 24 Nov 2023 15:30:12 +0100
Subject: [PATCH] Faster vectorize by walking sorted nodes

---
 pytensor/graph/replace.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py
index 8bdf348ccd..8caee200c8 100644
--- a/pytensor/graph/replace.py
+++ b/pytensor/graph/replace.py
@@ -3,7 +3,13 @@
 from functools import partial, singledispatch
 from typing import Optional, Union, cast, overload
 
-from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
+from pytensor.graph.basic import (
+    Apply,
+    Constant,
+    Variable,
+    io_toposort,
+    truncated_graph_inputs,
+)
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 
@@ -295,19 +301,14 @@ def vectorize_graph(
     inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
     new_inputs = [replace.get(inp, inp) for inp in inputs]
 
-    def transform(var: Variable) -> Variable:
-        if var in inputs:
-            return new_inputs[inputs.index(var)]
+    vect_vars = dict(zip(inputs, new_inputs))
+    for node in io_toposort(inputs, seq_outputs):
+        vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
+        vect_node = vectorize_node(node, *vect_inputs)
+        for output, vect_output in zip(node.outputs, vect_node.outputs):
+            vect_vars[output] = vect_output
 
-        node = var.owner
-        batched_inputs = [transform(inp) for inp in node.inputs]
-        batched_node = vectorize_node(node, *batched_inputs)
-        batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-        return cast(Variable, batched_var)
-
-    # TODO: MergeOptimization or node caching?
-    seq_vect_outputs = [transform(out) for out in seq_outputs]
+    seq_vect_outputs = [vect_vars[out] for out in seq_outputs]
 
     if isinstance(outputs, Sequence):
         return seq_vect_outputs