Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix symbol substitution for classical operations #1538

Merged
merged 5 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def package(self):
cmake.install()

def requirements(self):
self.requires("tket/1.3.17@tket/stable")
self.requires("tket/1.3.18@tket/stable")
self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tkassert/0.3.4@tket/stable")
Expand Down
6 changes: 6 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

Unreleased
----------

* Fix symbol substitution for classical operations.


1.31.1 (August 2024)
--------------------

Expand Down
11 changes: 11 additions & 0 deletions pytket/tests/classical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@

from pytket.passes import DecomposeClassicalExp, FlattenRegisters

from sympy import Symbol

from strategies import reg_name_regex, binary_digits, uint32, uint64 # type: ignore

curr_file_path = Path(__file__).resolve().parent
Expand Down Expand Up @@ -1481,5 +1483,14 @@ def test_box_equality_check() -> None:
assert ceb1 == ClassicalExpBox(2, 0, 1, exp1)


def test_sym_sub_range_pred() -> None:
c = Circuit(1, 2)
c.H(0, condition=reg_eq(BitRegister("c", 2), 3))
c1 = c.copy()
c.symbol_substitution({Symbol("a"): 0.5})

assert c == c1


if __name__ == "__main__":
test_wasm()
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.3.17"
version = "1.3.18"
package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
46 changes: 41 additions & 5 deletions tket/include/tket/Ops/ClassicalOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <memory>

#include "Op.hpp"
#include "OpPtr.hpp"
#include "tket/Utils/Json.hpp"

namespace tket {
Expand All @@ -47,11 +48,6 @@ class ClassicalOp : public Op {
OpType type, unsigned n_i, unsigned n_io, unsigned n_o,
const std::string &name = "");

// Trivial overrides
Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<ClassicalOp>(*this);
}
SymSet free_symbols() const override { return {}; }
unsigned n_qubits() const override { return 0; }

Expand Down Expand Up @@ -141,6 +137,11 @@ class ClassicalTransformOp : public ClassicalEvalOp {
unsigned n, const std::vector<uint64_t> &values,
const std::string &name = "ClassicalTransform");

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<ClassicalTransformOp>(*this);
}

std::vector<bool> eval(const std::vector<bool> &x) const override;

std::vector<uint64_t> get_values() const { return values_; }
Expand Down Expand Up @@ -171,6 +172,11 @@ class WASMOp : public ClassicalOp {
std::vector<unsigned> _width_o_parameter, const std::string &_func_name,
const std::string &_wasm_uid);

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<WASMOp>(*this);
}

/**
* return if the op is external
*/
Expand Down Expand Up @@ -284,6 +290,11 @@ class SetBitsOp : public ClassicalEvalOp {
: ClassicalEvalOp(OpType::SetBits, 0, 0, values.size(), "SetBits"),
values_(values) {}

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<SetBitsOp>(*this);
}

std::string get_name(bool latex) const override;

std::vector<bool> get_values() const { return values_; }
Expand All @@ -304,6 +315,11 @@ class CopyBitsOp : public ClassicalEvalOp {
explicit CopyBitsOp(unsigned n)
: ClassicalEvalOp(OpType::CopyBits, n, 0, n, "CopyBits") {}

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<CopyBitsOp>(*this);
}

std::vector<bool> eval(const std::vector<bool> &x) const override;
};

Expand Down Expand Up @@ -345,6 +361,11 @@ class RangePredicateOp : public PredicateOp {
uint64_t b = std::numeric_limits<uint64_t>::max())
: PredicateOp(OpType::RangePredicate, n, "RangePredicate"), a(a), b(b) {}

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<RangePredicateOp>(*this);
}

std::string get_name(bool latex) const override;

uint64_t upper() const { return b; }
Expand Down Expand Up @@ -384,6 +405,11 @@ class ExplicitPredicateOp : public PredicateOp {
unsigned n, const std::vector<bool> &values,
const std::string &name = "ExplicitPredicate");

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<ExplicitPredicateOp>(*this);
}

std::vector<bool> eval(const std::vector<bool> &x) const override;

std::vector<bool> get_values() const { return values_; }
Expand Down Expand Up @@ -430,6 +456,11 @@ class ExplicitModifierOp : public ModifyingOp {
unsigned n, const std::vector<bool> &values,
const std::string &name = "ExplicitModifier");

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<ExplicitModifierOp>(*this);
}

std::vector<bool> eval(const std::vector<bool> &x) const override;

std::vector<bool> get_values() const { return values_; }
Expand All @@ -448,6 +479,11 @@ class MultiBitOp : public ClassicalEvalOp {
public:
MultiBitOp(std::shared_ptr<const ClassicalEvalOp> op, unsigned n);

Op_ptr symbol_substitution(
const SymEngine::map_basic_basic &) const override {
return std::make_shared<MultiBitOp>(*this);
}

std::string get_name(bool latex) const override;

std::shared_ptr<const ClassicalEvalOp> get_op() const { return op_; }
Expand Down
24 changes: 24 additions & 0 deletions tket/test/src/Circuit/test_Symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <catch2/catch_test_macros.hpp>
#include <tket/Circuit/Circuit.hpp>
#include <tket/Ops/ClassicalOps.hpp>
#include <tket/Transformations/BasicOptimisation.hpp>
#include <tket/Transformations/CliffordOptimisation.hpp>
#include <tket/Transformations/OptimisationPass.hpp>
Expand Down Expand Up @@ -259,5 +260,28 @@ SCENARIO("Symbolic GPI, GPI2, AAMS") {
}
}

SCENARIO("Symbolic substitution for classical operations") {
std::vector<uint64_t> and_table = {0, 1, 2, 7, 0, 1, 2, 7};
std::shared_ptr<ClassicalTransformOp> and_ttop =
std::make_shared<ClassicalTransformOp>(3, and_table);
std::shared_ptr<RangePredicateOp> rpop =
std::make_shared<RangePredicateOp>(3, 2, 6);
Circuit circ(1, 4);
circ.add_op<unsigned>(OpType::H, {0});
circ.add_op<unsigned>(and_ttop, {0, 1, 2});
circ.add_op<unsigned>(and_ttop, {1, 2, 3});
circ.add_op<unsigned>(rpop, {0, 1, 2, 3});
circ.add_op<unsigned>(AndOp(), {2, 3, 0});
circ.add_op<unsigned>(OrOp(), {0, 1, 2});
circ.add_op<unsigned>(NotOp(), {2, 3});
circ.add_op<unsigned>(ClassicalX(), {1});
circ.add_op<unsigned>(ClassicalCX(), {0, 1});
circ.add_op<unsigned>(AndWithOp(), {2, 3});
Circuit circ1(circ);
symbol_map_t symmap;
circ1.symbol_substitution(symmap);
REQUIRE(circ == circ1);
}

} // namespace test_Symbolic
} // namespace tket
Loading