Skip to content

Commit

Permalink
remove list_of_nodes in favor of similar applys_between
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi authored and ricardoV94 committed May 20, 2024
1 parent 8c157a2 commit 880a57a
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 46 deletions.
32 changes: 0 additions & 32 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,38 +1789,6 @@ def view_roots(node: Variable) -> list[Variable]:
return [node]


def list_of_nodes(
inputs: Collection[Variable], outputs: Iterable[Variable]
) -> list[Apply]:
r"""Return the `Apply` nodes of the graph between `inputs` and `outputs`.
Parameters
----------
inputs : list of Variable
Input `Variable`\s.
outputs : list of Variable
Output `Variable`\s.
"""

def expand(o: Apply) -> list[Apply]:
return [
inp.owner
for inp in o.inputs
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
]

return list(
cast(
Iterable[Apply],
walk(
[o.owner for o in outputs if o.owner],
expand,
),
)
)


def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool:
"""Determine if any `depends_on` is in the graph given by ``apply``.
Expand Down
4 changes: 2 additions & 2 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pytensor import printing
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from pytensor.graph.basic import Apply, Constant, Variable, applys_between, clone
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.graph.rewriting.basic import MergeOptimizer
Expand Down Expand Up @@ -4125,7 +4125,7 @@ def c_support_code_apply(self, node, name):

def prepare_node(self, node, storage_map, compute_map, impl):
if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs):
for n in applys_between(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl)

Expand Down
12 changes: 0 additions & 12 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
get_var_by_name,
graph_inputs,
io_toposort,
list_of_nodes,
orphans_between,
truncated_graph_inputs,
variable_depends_on,
Expand Down Expand Up @@ -567,17 +566,6 @@ def test_ops():
assert res_list == [o3.owner, o2.owner, o1.owner]


def test_list_of_nodes():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"

res = list_of_nodes([r1, r2], [o2])
assert res == [o2.owner, o1.owner]


def test_apply_depends_on():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
Expand Down

0 comments on commit 880a57a

Please sign in to comment.