Faster vectorize by walking sorted nodes
ricardoV94 committed Nov 27, 2023
1 parent e2d0751 commit 189ba03
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions pytensor/graph/replace.py
@@ -3,7 +3,13 @@
 from functools import partial, singledispatch
 from typing import Optional, Union, cast, overload
 
-from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
+from pytensor.graph.basic import (
+    Apply,
+    Constant,
+    Variable,
+    io_toposort,
+    truncated_graph_inputs,
+)
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 
@@ -295,19 +301,14 @@ def vectorize_graph(
     inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
     new_inputs = [replace.get(inp, inp) for inp in inputs]
 
-    def transform(var: Variable) -> Variable:
-        if var in inputs:
-            return new_inputs[inputs.index(var)]
+    vect_vars = dict(zip(inputs, new_inputs))
+    for node in io_toposort(inputs, seq_outputs):
+        vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
+        vect_node = vectorize_node(node, *vect_inputs)
+        for output, vect_output in zip(node.outputs, vect_node.outputs):
+            vect_vars[output] = vect_output
 
-        node = var.owner
-        batched_inputs = [transform(inp) for inp in node.inputs]
-        batched_node = vectorize_node(node, *batched_inputs)
-        batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-        return cast(Variable, batched_var)
-
-    # TODO: MergeOptimization or node caching?
-    seq_vect_outputs = [transform(out) for out in seq_outputs]
+    seq_vect_outputs = [vect_vars[out] for out in seq_outputs]
 
     if isinstance(outputs, Sequence):
         return seq_vect_outputs
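
For context, a sketch of how the public entry point touched here is typically called; the variable names and shapes below are illustrative, not taken from the commit:

# Hedged usage sketch of pytensor.graph.replace.vectorize_graph.
# After this commit, the batched graph is built in a single pass
# over the topologically sorted nodes.
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x")
loss = pt.exp(x).sum()

batched_x = pt.matrix("batched_x")  # adds a leading batch dimension
batched_loss = vectorize_graph(loss, replace={x: batched_x})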
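Why this is faster: the old transform recursed from each output with no cache, so a variable consumed by k downstream nodes was re-vectorized k times (hence the removed "MergeOptimization or node caching?" TODO); on graphs with shared subexpressions that degenerates to exponential work. Walking io_toposort once and memoizing results in vect_vars visits each node exactly once. A minimal self-contained sketch of the two strategies (not pytensor code; Node, toposort, and rewrite_* are invented for illustration):

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    inputs: tuple["Node", ...] = ()


def toposort(outputs: list[Node]) -> list[Node]:
    """Return the ancestors of `outputs` in dependency order (inputs first)."""
    order: list[Node] = []
    seen: set[Node] = set()

    def visit(node: Node) -> None:
        if node in seen:
            return
        seen.add(node)
        for inp in node.inputs:
            visit(inp)
        order.append(node)

    for out in outputs:
        visit(out)
    return order


calls = 0


def rewrite_recursive(node: Node) -> Node:
    """Old style: uncached recursion; shared nodes are rewritten once per consumer."""
    global calls
    calls += 1
    return Node(node.name + "'", tuple(rewrite_recursive(i) for i in node.inputs))


def rewrite_sorted(outputs: list[Node]) -> list[Node]:
    """New style: one pass over sorted nodes, memoized in a dict."""
    global calls
    done: dict[Node, Node] = {}
    for node in toposort(outputs):
        calls += 1
        done[node] = Node(node.name + "'", tuple(done[i] for i in node.inputs))
    return [done[out] for out in outputs]


# A diamond chain: each level consumes the previous one twice.
x = Node("x")
for i in range(10):
    x = Node(f"n{i}", (x, x))

calls = 0
rewrite_recursive(x)
print(calls)  # 2047 rewrites: every shared path is retraced

calls = 0
rewrite_sorted([x])
print(calls)  # 11 rewrites: one per node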