From ba0f7c08b9a55a12ac3b1f467f869a6698c0bb7d Mon Sep 17 00:00:00 2001 From: Silas Dilkes <36165522+sjdilkes@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:23:04 +0000 Subject: [PATCH] Add `BasePass.get_pre_conditions` and `BasePass.get_post_conditions` (#1689) * Add new methods for getting predicates from python passes * bump * Update predicates_test.py * Update pytket/binders/passes.cpp Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> * Update pytket/binders/passes.cpp Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> * add get gate_set, check tests * Update predicates_test.py * Update predicates_test.py * Update predicates_test.py * bump, fix ruff issue * regen stubs * Update predicates_test.py * Update predicates_test.py * Update predicates_test.py * addresss comments * Update predicates_test.py --------- Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> --- pytket/binders/passes.cpp | 57 ++++++++++++++++++++++++++++++++- pytket/binders/predicates.cpp | 2 +- pytket/docs/changelog.rst | 4 +++ pytket/pytket/_tket/passes.pyi | 15 +++++++++ pytket/tests/predicates_test.py | 21 ++++++++++++ 5 files changed, 97 insertions(+), 2 deletions(-) diff --git a/pytket/binders/passes.cpp b/pytket/binders/passes.cpp index ef398f8a0d..634d046845 100644 --- a/pytket/binders/passes.cpp +++ b/pytket/binders/passes.cpp @@ -145,6 +145,30 @@ const PassPtr &DecomposeClassicalExp() { return pp; } +std::optional get_gate_set(const BasePass &base_pass) { + std::optional allowed_ops; + for (const std::pair> + &p : base_pass.get_conditions().first) { + std::shared_ptr gsp_ptr = + std::dynamic_pointer_cast(p.second); + if (!gsp_ptr) { + continue; + } + OpTypeSet candidate_allowed_ops = gsp_ptr->get_allowed_types(); + if (!allowed_ops) { + allowed_ops = candidate_allowed_ops; + } else { + OpTypeSet intersection; + std::set_intersection( + candidate_allowed_ops.begin(), candidate_allowed_ops.end(), + allowed_ops->begin(), allowed_ops->end(), + std::inserter(intersection, intersection.begin())); + allowed_ops = intersection; + } + } + return allowed_ops; +} + PYBIND11_MODULE(passes, m) { py::module_::import("pytket._tket.predicates"); m.def( @@ -212,7 +236,6 @@ PYBIND11_MODULE(passes, m) { ); } }; - py::class_( m, "BasePass", "Base class for passes.") .def( @@ -268,6 +291,38 @@ PYBIND11_MODULE(passes, m) { return py::cast(serialise(base_pass)); }, ":return: A JSON serializable dictionary representation of the Pass.") + .def( + "get_preconditions", + [](const BasePass &base_pass) { + std::vector pre_conditions; + for (const std::pair< + const std::type_index, std::shared_ptr> + &p : base_pass.get_conditions().first) { + pre_conditions.push_back(p.second); + } + return pre_conditions; + }, + "Returns the precondition Predicates for the given pass." + "\n:return: A list of Predicate") + .def( + "get_postconditions", + [](const BasePass &base_pass) { + std::vector post_conditions; + for (const std::pair< + const std::type_index, std::shared_ptr> & + p : base_pass.get_conditions().second.specific_postcons_) { + post_conditions.push_back(p.second); + } + return post_conditions; + }, + "Returns the postcondition Predicates for the given pass." + "\n\n:return: A list of :py:class:`Predicate`") + .def( + "get_gate_set", &get_gate_set, + "Returns the intersection of all set of OpType for all " + "GateSetPredicate in the `BasePass` preconditions, or `None` " + "if there are no gate-set predicates.", + "\n\n:return: A set of allowed OpType") .def_static( "from_dict", [](const py::dict &base_pass_dict, diff --git a/pytket/binders/predicates.cpp b/pytket/binders/predicates.cpp index 127ef81c20..80bab6e426 100644 --- a/pytket/binders/predicates.cpp +++ b/pytket/binders/predicates.cpp @@ -83,7 +83,7 @@ PYBIND11_MODULE(predicates, m) { "implies", &Predicate::implies, ":return: True if predicate implies another one, else False", py::arg("other")) - .def("__str__", [](const Predicate &) { return ""; }) + .def("__str__", &Predicate::to_string) .def("__repr__", &Predicate::to_string) .def( "to_dict", diff --git a/pytket/docs/changelog.rst b/pytket/docs/changelog.rst index 5fd2aac7d4..b2fb2fb5e2 100644 --- a/pytket/docs/changelog.rst +++ b/pytket/docs/changelog.rst @@ -4,6 +4,10 @@ Changelog Unreleased ---------- +Features: + +* Add `BasePass.get_preconditions()` and `BasePass.getpost_conditions()`. + API changes: * Remove the deprecated methods `auto_rebase_pass()` and `auto_squash_pass()`. diff --git a/pytket/pytket/_tket/passes.pyi b/pytket/pytket/_tket/passes.pyi index 9a59fae388..c8a35e24e5 100644 --- a/pytket/pytket/_tket/passes.pyi +++ b/pytket/pytket/_tket/passes.pyi @@ -54,6 +54,21 @@ class BasePass: :param after_apply: Invoked after a pass is applied. The CompilationUnit and a summary of the pass configuration are passed into the callback. :return: True if pass modified the circuit, else False """ + def get_gate_set(self) -> set[pytket._tket.circuit.OpType] | None: + """ + :return: A set of allowed OpType + """ + def get_postconditions(self) -> list[pytket._tket.predicates.Predicate]: + """ + Returns the postcondition Predicates for the given pass. + + :return: A list of :py:class:`Predicate` + """ + def get_preconditions(self) -> list[pytket._tket.predicates.Predicate]: + """ + Returns the precondition Predicates for the given pass. + :return: A list of Predicate + """ def to_dict(self) -> typing.Any: """ :return: A JSON serializable dictionary representation of the Pass. diff --git a/pytket/tests/predicates_test.py b/pytket/tests/predicates_test.py index c09098ffd1..806d3b8551 100644 --- a/pytket/tests/predicates_test.py +++ b/pytket/tests/predicates_test.py @@ -1092,6 +1092,27 @@ def test_greedy_pauli_synth() -> None: assert GreedyPauliSimp().apply(c) +def test_get_pre_conditions() -> None: + pre_cons = GreedyPauliSimp().get_preconditions() + gate_set = pre_cons[0].gate_set # type: ignore + assert OpType.CX in gate_set + assert OpType.Measure in gate_set + + +def test_get_post_conditions() -> None: + gate_set = {OpType.CX, OpType.Rz, OpType.H, OpType.Reset, OpType.Measure} + post_cons = AutoRebase(gate_set).get_postconditions() + assert post_cons[0].gate_set == gate_set # type: ignore + + +def test_get_gate_set() -> None: + gate_set = GreedyPauliSimp().get_gate_set() + assert gate_set is not None + assert OpType.CX in gate_set + assert OpType.Measure in gate_set + assert CliffordPushThroughMeasures().get_gate_set() is None + + if __name__ == "__main__": test_predicate_generation() test_compilation_unit_generation()