Skip to content

Commit

Permalink
Add BasePass.get_pre_conditions and BasePass.get_post_conditions (#…
Browse files Browse the repository at this point in the history
…1689)

* Add new methods for getting predicates from python passes

* bump

* Update predicates_test.py

* Update pytket/binders/passes.cpp

Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com>

* Update pytket/binders/passes.cpp

Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com>

* add get gate_set, check tests

* Update predicates_test.py

* Update predicates_test.py

* Update predicates_test.py

* bump, fix ruff issue

* regen stubs

* Update predicates_test.py

* Update predicates_test.py

* Update predicates_test.py

* addresss comments

* Update predicates_test.py

---------

Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com>
  • Loading branch information
sjdilkes and cqc-alec authored Nov 25, 2024
1 parent 57e10dc commit ba0f7c0
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 2 deletions.
57 changes: 56 additions & 1 deletion pytket/binders/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,30 @@ const PassPtr &DecomposeClassicalExp() {
return pp;
}

std::optional<OpTypeSet> get_gate_set(const BasePass &base_pass) {
std::optional<OpTypeSet> allowed_ops;
for (const std::pair<const std::type_index, std::shared_ptr<tket::Predicate>>
&p : base_pass.get_conditions().first) {
std::shared_ptr<GateSetPredicate> gsp_ptr =
std::dynamic_pointer_cast<GateSetPredicate>(p.second);
if (!gsp_ptr) {
continue;
}
OpTypeSet candidate_allowed_ops = gsp_ptr->get_allowed_types();
if (!allowed_ops) {
allowed_ops = candidate_allowed_ops;
} else {
OpTypeSet intersection;
std::set_intersection(
candidate_allowed_ops.begin(), candidate_allowed_ops.end(),
allowed_ops->begin(), allowed_ops->end(),
std::inserter(intersection, intersection.begin()));
allowed_ops = intersection;
}
}
return allowed_ops;
}

PYBIND11_MODULE(passes, m) {
py::module_::import("pytket._tket.predicates");
m.def(
Expand Down Expand Up @@ -212,7 +236,6 @@ PYBIND11_MODULE(passes, m) {
);
}
};

py::class_<BasePass, PassPtr, PyBasePass>(
m, "BasePass", "Base class for passes.")
.def(
Expand Down Expand Up @@ -268,6 +291,38 @@ PYBIND11_MODULE(passes, m) {
return py::cast(serialise(base_pass));
},
":return: A JSON serializable dictionary representation of the Pass.")
.def(
"get_preconditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> pre_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>>
&p : base_pass.get_conditions().first) {
pre_conditions.push_back(p.second);
}
return pre_conditions;
},
"Returns the precondition Predicates for the given pass."
"\n:return: A list of Predicate")
.def(
"get_postconditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> post_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>> &
p : base_pass.get_conditions().second.specific_postcons_) {
post_conditions.push_back(p.second);
}
return post_conditions;
},
"Returns the postcondition Predicates for the given pass."
"\n\n:return: A list of :py:class:`Predicate`")
.def(
"get_gate_set", &get_gate_set,
"Returns the intersection of all set of OpType for all "
"GateSetPredicate in the `BasePass` preconditions, or `None` "
"if there are no gate-set predicates.",
"\n\n:return: A set of allowed OpType")
.def_static(
"from_dict",
[](const py::dict &base_pass_dict,
Expand Down
2 changes: 1 addition & 1 deletion pytket/binders/predicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ PYBIND11_MODULE(predicates, m) {
"implies", &Predicate::implies,
":return: True if predicate implies another one, else False",
py::arg("other"))
.def("__str__", [](const Predicate &) { return "<tket::Predicate>"; })
.def("__str__", &Predicate::to_string)
.def("__repr__", &Predicate::to_string)
.def(
"to_dict",
Expand Down
4 changes: 4 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog
Unreleased
----------

Features:

* Add `BasePass.get_preconditions()` and `BasePass.getpost_conditions()`.

API changes:

* Remove the deprecated methods `auto_rebase_pass()` and `auto_squash_pass()`.
Expand Down
15 changes: 15 additions & 0 deletions pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ class BasePass:
:param after_apply: Invoked after a pass is applied. The CompilationUnit and a summary of the pass configuration are passed into the callback.
:return: True if pass modified the circuit, else False
"""
def get_gate_set(self) -> set[pytket._tket.circuit.OpType] | None:
"""
:return: A set of allowed OpType
"""
def get_postconditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the postcondition Predicates for the given pass.
:return: A list of :py:class:`Predicate`
"""
def get_preconditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the precondition Predicates for the given pass.
:return: A list of Predicate
"""
def to_dict(self) -> typing.Any:
"""
:return: A JSON serializable dictionary representation of the Pass.
Expand Down
21 changes: 21 additions & 0 deletions pytket/tests/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,27 @@ def test_greedy_pauli_synth() -> None:
assert GreedyPauliSimp().apply(c)


def test_get_pre_conditions() -> None:
pre_cons = GreedyPauliSimp().get_preconditions()
gate_set = pre_cons[0].gate_set # type: ignore
assert OpType.CX in gate_set
assert OpType.Measure in gate_set


def test_get_post_conditions() -> None:
gate_set = {OpType.CX, OpType.Rz, OpType.H, OpType.Reset, OpType.Measure}
post_cons = AutoRebase(gate_set).get_postconditions()
assert post_cons[0].gate_set == gate_set # type: ignore


def test_get_gate_set() -> None:
gate_set = GreedyPauliSimp().get_gate_set()
assert gate_set is not None
assert OpType.CX in gate_set
assert OpType.Measure in gate_set
assert CliffordPushThroughMeasures().get_gate_set() is None


if __name__ == "__main__":
test_predicate_generation()
test_compilation_unit_generation()
Expand Down

0 comments on commit ba0f7c0

Please sign in to comment.