From 880a57aa3a483320462eb7d682137536e23bf8a2 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 19 May 2024 12:13:04 +0530 Subject: [PATCH] remove list_of_nodes in favor of similar applys_between --- pytensor/graph/basic.py | 32 -------------------------------- pytensor/scalar/basic.py | 4 ++-- tests/graph/test_basic.py | 12 ------------ 3 files changed, 2 insertions(+), 46 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index ac1716ee06..6dc10afcbc 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1789,38 +1789,6 @@ def view_roots(node: Variable) -> list[Variable]: return [node] -def list_of_nodes( - inputs: Collection[Variable], outputs: Iterable[Variable] -) -> list[Apply]: - r"""Return the `Apply` nodes of the graph between `inputs` and `outputs`. - - Parameters - ---------- - inputs : list of Variable - Input `Variable`\s. - outputs : list of Variable - Output `Variable`\s. - - """ - - def expand(o: Apply) -> list[Apply]: - return [ - inp.owner - for inp in o.inputs - if inp.owner and not any(i in inp.owner.outputs for i in inputs) - ] - - return list( - cast( - Iterable[Apply], - walk( - [o.owner for o in outputs if o.owner], - expand, - ), - ) - ) - - def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool: """Determine if any `depends_on` is in the graph given by ``apply``. diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 5d7ba66748..56a3629dc5 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -24,7 +24,7 @@ from pytensor import printing from pytensor.configdefaults import config from pytensor.gradient import DisconnectedType, grad_undefined -from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes +from pytensor.graph.basic import Apply, Constant, Variable, applys_between, clone from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import HasInnerGraph from pytensor.graph.rewriting.basic import MergeOptimizer @@ -4125,7 +4125,7 @@ def c_support_code_apply(self, node, name): def prepare_node(self, node, storage_map, compute_map, impl): if impl not in self.prepare_node_called: - for n in list_of_nodes(self.inputs, self.outputs): + for n in applys_between(self.inputs, self.outputs): n.op.prepare_node(n, None, None, impl) self.prepare_node_called.add(impl) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 5dc9789727..08c352ab71 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -23,7 +23,6 @@ get_var_by_name, graph_inputs, io_toposort, - list_of_nodes, orphans_between, truncated_graph_inputs, variable_depends_on, @@ -567,17 +566,6 @@ def test_ops(): assert res_list == [o3.owner, o2.owner, o1.owner] -def test_list_of_nodes(): - r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) - o1.name = "o1" - o2 = MyOp(r3, o1) - o2.name = "o2" - - res = list_of_nodes([r1, r2], [o2]) - assert res == [o2.owner, o1.owner] - - def test_apply_depends_on(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) o1 = MyOp(r1, r2)