diff --git a/pytket/CMakeLists.txt b/pytket/CMakeLists.txt index a99d09ef13..4c58fc2a6c 100644 --- a/pytket/CMakeLists.txt +++ b/pytket/CMakeLists.txt @@ -113,6 +113,7 @@ pybind11_add_module(circuit binders/circuit/Circuit/add_op.cpp binders/circuit/Circuit/main.cpp binders/circuit/classical.cpp + binders/circuit/clexpr.cpp binders/circuit/main.cpp ${HEADER_FILES}) target_include_directories(circuit PRIVATE binders/include) diff --git a/pytket/binders/circuit/Circuit/add_op.cpp b/pytket/binders/circuit/Circuit/add_op.cpp index 20447d429d..c8c952d2f3 100644 --- a/pytket/binders/circuit/Circuit/add_op.cpp +++ b/pytket/binders/circuit/Circuit/add_op.cpp @@ -13,8 +13,10 @@ // limitations under the License. #include +#include #include +#include #include #include @@ -33,6 +35,7 @@ #include "tket/Circuit/ToffoliBox.hpp" #include "tket/Converters/PhasePoly.hpp" #include "tket/Gate/OpPtrFunctions.hpp" +#include "tket/Ops/ClExpr.hpp" #include "tket/Utils/UnitID.hpp" #include "typecast.hpp" namespace py = pybind11; @@ -483,6 +486,19 @@ void init_circuit_add_op(py::class_> &c) { ":param args: Indices of the qubits to append the box to" "\n:return: the new :py:class:`Circuit`", py::arg("expression"), py::arg("target")) + .def( + "add_clexpr", + [](Circuit *circ, const WiredClExpr &expr, + const py::tket_custom::SequenceVec &args, + const py::kwargs &kwargs) { + Op_ptr op = std::make_shared(expr); + return add_gate_method(circ, op, args, kwargs); + }, + "Append a :py:class:`WiredClExpr` to the circuit.\n\n" + ":param expr: The expression to append\n" + ":param args: The bits to apply the expression to\n" + ":return: the new :py:class:`Circuit`", + py::arg("expr"), py::arg("args")) .def( "add_custom_gate", [](Circuit *circ, const composite_def_ptr_t &definition, diff --git a/pytket/binders/circuit/clexpr.cpp b/pytket/binders/circuit/clexpr.cpp new file mode 100644 index 0000000000..f45559f585 --- /dev/null +++ b/pytket/binders/circuit/clexpr.cpp @@ -0,0 +1,554 @@ +// Copyright 2019-2024 Cambridge Quantum Computing +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tket/Ops/ClExpr.hpp" + +#include +#include + +#include +#include +#include +#include +#include + +#include "UnitRegister.hpp" +#include "binder_json.hpp" +#include "deleted_hash.hpp" +#include "py_operators.hpp" + +namespace py = pybind11; + +namespace tket { + +static std::string qasm_bit_repr( + const ClExprTerm &term, const std::map &input_bits) { + if (const int *n = std::get_if(&term)) { + switch (*n) { + case 0: + return "0"; + case 1: + return "1"; + default: + throw std::logic_error("Invalid integer in bit operation"); + } + } else { + ClExprVar var = std::get(term); + if (const ClBitVar *bvar = std::get_if(&var)) { + const Bit b = input_bits.at(bvar->index); + return b.repr(); + } else { + throw std::logic_error("Expected bit variable, found register variable"); + } + } +} + +static std::string qasm_reg_repr( + const ClExprTerm &term, const std::map &input_regs) { + if (const int *n = std::get_if(&term)) { + std::stringstream ss; + ss << *n; + return ss.str(); + } else { + ClExprVar var = std::get(term); + if (const ClRegVar *rvar = std::get_if(&var)) { + const BitRegister r = input_regs.at(rvar->index); + return r.name(); + } else { + throw std::logic_error("Expected register variable, found bit variable"); + } + } +} + +enum class ArgValueType { Bit, Reg }; + +static std::string qasm_expr_repr( + const ClExpr &expr, const std::map &input_bits, + const std::map &input_regs); + +static std::string qasm_arg_repr( + const ClExprArg &arg, const std::map &input_bits, + const std::map &input_regs, const ArgValueType typ) { + if (const ClExpr *expr = std::get_if(&arg)) { + return qasm_expr_repr(*expr, input_bits, input_regs); + } else { + if (typ == ArgValueType::Bit) { + return qasm_bit_repr(std::get(arg), input_bits); + } else { + return qasm_reg_repr(std::get(arg), input_regs); + } + } +} + +static std::string qasm_expr_repr( + const ClExpr &expr, const std::map &input_bits, + const std::map &input_regs) { + const ClOp op = expr.get_op(); + const std::vector args = expr.get_args(); + const unsigned n_args = args.size(); + std::stringstream ss; + ss << "("; + switch (op) { + case ClOp::INVALID: + throw std::logic_error("Invalid expression."); + + case ClOp::BitAnd: + if (n_args == 0) { + ss << "1"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Bit); + if (i + 1 < n_args) { + ss << " & "; + } + } + } + break; + + case ClOp::BitOr: + if (n_args == 0) { + ss << "0"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Bit); + if (i + 1 < n_args) { + ss << " | "; + } + } + } + break; + + case ClOp::BitXor: + if (n_args == 0) { + ss << "0"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Bit); + if (i + 1 < n_args) { + ss << " ^ "; + } + } + } + break; + + case ClOp::BitEq: + if (n_args != 2) { + throw std::logic_error("BitEq with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Bit); + ss << " == "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Bit); + break; + + case ClOp::BitNeq: + if (n_args != 2) { + throw std::logic_error("BitNeq with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Bit) + << " != " + << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Bit); + break; + + case ClOp::BitNot: + if (n_args != 1) { + throw std::logic_error("BitNot with != 1 argument"); + } + ss << "~" + << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Bit); + break; + + case ClOp::BitZero: + if (n_args != 0) { + throw std::logic_error("BitZero with != 0 arguments"); + } + ss << "0"; + break; + + case ClOp::BitOne: + if (n_args != 0) { + throw std::logic_error("BitOne with != 0 arguments"); + } + ss << "1"; + break; + + case ClOp::RegAnd: + if (n_args == 0) { + ss << "-1"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Reg); + if (i + 1 < n_args) { + ss << " & "; + } + } + } + break; + + case ClOp::RegOr: + if (n_args == 0) { + ss << "0"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Reg); + if (i + 1 < n_args) { + ss << " | "; + } + } + } + break; + + case ClOp::RegXor: + if (n_args == 0) { + ss << "0"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Reg); + if (i + 1 < n_args) { + ss << " ^ "; + } + } + } + break; + + case ClOp::RegEq: + if (n_args != 2) { + throw std::logic_error("RegEq with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " == "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegNeq: + if (n_args != 2) { + throw std::logic_error("RegNeq with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg) + << " != " + << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegNot: + if (n_args != 1) { + throw std::logic_error("RegNot with != 1 argument"); + } + ss << "~" + << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegZero: + if (n_args != 0) { + throw std::logic_error("RegZero with != 0 arguments"); + } + ss << "0"; + break; + + case ClOp::RegOne: + if (n_args != 0) { + throw std::logic_error("RegOne with != 0 arguments"); + } + ss << "-1"; + break; + + case ClOp::RegLt: + if (n_args != 2) { + throw std::logic_error("RegLt with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " < "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegGt: + if (n_args != 2) { + throw std::logic_error("RegGt with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " > "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegLeq: + if (n_args != 2) { + throw std::logic_error("RegLeq with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " <= "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegGeq: + if (n_args != 2) { + throw std::logic_error("RegGeq with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " >= "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegAdd: + if (n_args == 0) { + ss << "0"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Reg); + if (i + 1 < n_args) { + ss << " + "; + } + } + } + break; + + case ClOp::RegSub: + if (n_args != 2) { + throw std::logic_error("RegSub with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " - "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegMul: + if (n_args == 0) { + ss << "1"; + } else { + for (unsigned i = 0; i < n_args; i++) { + ss << qasm_arg_repr( + args[i], input_bits, input_regs, ArgValueType::Reg); + if (i + 1 < n_args) { + ss << " * "; + } + } + } + break; + + case ClOp::RegDiv: + if (n_args != 2) { + throw std::logic_error("RegDiv with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " / "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegPow: + if (n_args != 2) { + throw std::logic_error("RegPow with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " ** "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegLsh: + if (n_args != 2) { + throw std::logic_error("RegLsh with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " << "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegRsh: + if (n_args != 2) { + throw std::logic_error("RegRsh with != 2 arguments"); + } + ss << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + ss << " >> "; + ss << qasm_arg_repr(args[1], input_bits, input_regs, ArgValueType::Reg); + break; + + case ClOp::RegNeg: + if (n_args != 1) { + throw std::logic_error("RegNeg with != 1 argument"); + } + ss << "-" + << qasm_arg_repr(args[0], input_bits, input_regs, ArgValueType::Reg); + break; + } + ss << ")"; + return ss.str(); +} + +void init_clexpr(py::module &m) { + py::enum_(m, "ClOp", "A classical operation", py::arithmetic()) + .value("INVALID", ClOp::INVALID, "Invalid") + .value("BitAnd", ClOp::BitAnd, "Bitwise AND") + .value("BitOr", ClOp::BitOr, "Bitwise OR") + .value("BitXor", ClOp::BitXor, "Bitwise XOR") + .value("BitEq", ClOp::BitEq, "Bitwise equality") + .value("BitNeq", ClOp::BitNeq, "Bitwise inequality") + .value("BitNot", ClOp::BitNot, "Bitwise NOT") + .value("BitZero", ClOp::BitZero, "Constant zero bit") + .value("BitOne", ClOp::BitOne, "Constant one bit") + .value("RegAnd", ClOp::RegAnd, "Registerwise AND") + .value("RegOr", ClOp::RegOr, "Registerwise OR") + .value("RegXor", ClOp::RegXor, "Registerwise XOR") + .value("RegEq", ClOp::RegEq, "Registerwise equality") + .value("RegNeq", ClOp::RegNeq, "Registerwise inequality") + .value("RegNot", ClOp::RegNot, "Registerwise NOT") + .value("RegZero", ClOp::RegZero, "Constant all-zeros register") + .value("RegOne", ClOp::RegOne, "Constant all-ones register") + .value("RegLt", ClOp::RegLt, "Integer less-than comparison") + .value("RegGt", ClOp::RegGt, "Integer greater-than comparison") + .value("RegLeq", ClOp::RegLeq, "Integer less-than-or-equal comparison") + .value("RegGeq", ClOp::RegGeq, "Integer greater-than-or-equal comparison") + .value("RegAdd", ClOp::RegAdd, "Integer addition") + .value("RegSub", ClOp::RegSub, "Integer subtraction") + .value("RegMul", ClOp::RegMul, "Integer multiplication") + .value("RegDiv", ClOp::RegDiv, "Integer division") + .value("RegPow", ClOp::RegPow, "Integer exponentiation") + .value("RegLsh", ClOp::RegLsh, "Left shift") + .value("RegRsh", ClOp::RegRsh, "Right shift") + .value("RegNeg", ClOp::RegNeg, "Integer negation"); + + py::class_>( + m, "ClBitVar", "A bit variable within an expression") + .def( + py::init(), "Construct from an integer identifier", + py::arg("i")) + .def("__eq__", &py_equals) + .def( + "__str__", + [](const ClBitVar &var) { + std::stringstream ss; + ss << var; + return ss.str(); + }) + .def( + "__repr__", + [](const ClBitVar &var) { + std::stringstream ss; + ss << "ClBitVar(" << var.index << ")"; + return ss.str(); + }) + .def("__hash__", [](const ClBitVar &var) { return var.index; }) + .def_property_readonly( + "index", [](const ClBitVar &var) { return var.index; }, + ":return: integer identifier for the variable"); + + py::class_>( + m, "ClRegVar", "A register variable within an expression") + .def( + py::init(), "Construct from an integer identifier", + py::arg("i")) + .def("__eq__", &py_equals) + .def( + "__str__", + [](const ClRegVar &var) { + std::stringstream ss; + ss << var; + return ss.str(); + }) + .def( + "__repr__", + [](const ClRegVar &var) { + std::stringstream ss; + ss << "ClRegVar(" << var.index << ")"; + return ss.str(); + }) + .def("__hash__", [](const ClRegVar &var) { return var.index; }) + .def_property_readonly( + "index", [](const ClRegVar &var) { return var.index; }, + ":return: integer identifier for the variable"); + + py::class_>( + m, "ClExpr", "A classical expression") + .def( + py::init>(), + "Construct from an operation and a list of arguments", py::arg("op"), + py::arg("args")) + .def("__eq__", &py_equals) + .def( + "__str__", + [](const ClExpr &expr) { + std::stringstream ss; + ss << expr; + return ss.str(); + }) + .def("__hash__", &deletedHash, deletedHashDocstring) + .def_property_readonly("op", &ClExpr::get_op, ":return: main operation") + .def_property_readonly("args", &ClExpr::get_args, ":return: arguments") + .def( + "as_qasm", + [](const ClExpr &expr, const std::map input_bits, + const std::map input_regs) -> std::string { + return qasm_expr_repr(expr, input_bits, input_regs); + }, + "QASM-style string representation given corresponding bits and " + "registers", + py::arg("input_bits"), py::arg("input_regs")); + + py::class_>( + m, "WiredClExpr", + "A classical expression defined over a sequence of bits") + .def( + py::init< + ClExpr, std::map, + std::map>, + std::vector>(), + "Construct from an expression with bit and register positions", + py::arg("expr"), py::arg("bit_posn") = std::map(), + py::arg("reg_posn") = std::map>(), + py::arg("output_posn") = std::vector()) + .def("__eq__", &py_equals) + .def( + "__str__", + [](const WiredClExpr &expr) { + std::stringstream ss; + ss << expr; + return ss.str(); + }) + .def("__hash__", &deletedHash, deletedHashDocstring) + .def_property_readonly( + "expr", &WiredClExpr::get_expr, ":return: expression") + .def_property_readonly( + "bit_posn", &WiredClExpr::get_bit_posn, ":return: bit positions") + .def_property_readonly( + "reg_posn", &WiredClExpr::get_reg_posn, ":return: register positions") + .def_property_readonly( + "output_posn", &WiredClExpr::get_output_posn, + ":return: output positions") + .def( + "to_dict", + [](const WiredClExpr &wexpr) { + return py::object(nlohmann::json(wexpr)).cast(); + }, + ":return: JSON-serializable dict representation") + .def_static( + "from_dict", + [](const py::dict &wexpr_dict) { + return nlohmann::json(wexpr_dict).get(); + }, + "Construct from JSON-serializable dict representation"); + + py::class_, Op>( + m, "ClExprOp", "An operation defined by a classical expression") + .def( + py::init(), + "Construct from a wired classical expression") + .def_property_readonly( + "type", &ClExprOp::get_type, ":return: operation type") + .def_property_readonly( + "expr", &ClExprOp::get_wired_expr, ":return: wired expression"); +} + +} // namespace tket diff --git a/pytket/binders/circuit/main.cpp b/pytket/binders/circuit/main.cpp index 4ada6280bf..ba4d2853f3 100644 --- a/pytket/binders/circuit/main.cpp +++ b/pytket/binders/circuit/main.cpp @@ -25,6 +25,7 @@ #include "tket/Circuit/Command.hpp" #include "tket/Gate/OpPtrFunctions.hpp" #include "tket/Gate/SymTable.hpp" +#include "tket/OpType/OpType.hpp" #include "tket/Ops/BarrierOp.hpp" #include "tket/Ops/MetaOp.hpp" #include "tket/Ops/Op.hpp" @@ -42,6 +43,7 @@ typedef py::tket_custom::SequenceVec py_unit_vector_t; void def_circuit(py::class_> &); void init_classical(py::module &m); void init_boxes(py::module &m); +void init_clexpr(py::module &m); PYBIND11_MODULE(circuit, m) { py::module::import("pytket._tket.unit_id"); @@ -545,6 +547,7 @@ PYBIND11_MODULE(circuit, m) { "DiagonalBox", OpType::DiagonalBox, "A box for synthesising a diagonal unitary matrix into a sequence of " "multiplexed-Rz gates") + .value("ClExpr", OpType::ClExpr, "A classical expression") .def_static( "from_name", [](const py::str &name) { return json(name).get(); }, @@ -688,6 +691,7 @@ PYBIND11_MODULE(circuit, m) { "result in bit 0"); init_boxes(m); init_classical(m); + init_clexpr(m); def_circuit(pyCircuit); m.def( diff --git a/pytket/conanfile.py b/pytket/conanfile.py index bbba711fbf..8674f27a3d 100644 --- a/pytket/conanfile.py +++ b/pytket/conanfile.py @@ -38,7 +38,7 @@ def requirements(self): self.requires("pybind11_json/0.2.14") self.requires("symengine/0.12.0") self.requires("tkassert/0.3.4@tket/stable") - self.requires("tket/1.3.33@tket/stable") + self.requires("tket/1.3.34@tket/stable") self.requires("tklog/0.3.3@tket/stable") self.requires("tkrng/0.3.3@tket/stable") self.requires("tktokenswap/0.3.9@tket/stable") diff --git a/pytket/docs/changelog.rst b/pytket/docs/changelog.rst index b49720e1b1..88e3db78c8 100644 --- a/pytket/docs/changelog.rst +++ b/pytket/docs/changelog.rst @@ -4,9 +4,19 @@ Changelog 1.33.2 (Unreleased) ------------------- -* Support Python 3.13. +Features: + +* Add new `ClExprOp` operation type as an alternative to `ClassicalExpBox`; add + option to use this when converting from QASM. + +Fixes: + * Fix small default display screen for circuit renderer. +General: + +* Support Python 3.13. + 1.33.1 (October 2024) --------------------- diff --git a/pytket/pytket/_tket/circuit.pyi b/pytket/pytket/_tket/circuit.pyi index 30e5106e1c..b89daa701e 100644 --- a/pytket/pytket/_tket/circuit.pyi +++ b/pytket/pytket/_tket/circuit.pyi @@ -11,7 +11,7 @@ import pytket.circuit.logic_exp import pytket.wasm.wasm import sympy import typing -__all__ = ['BarrierOp', 'BasisOrder', 'CXConfigType', 'CircBox', 'Circuit', 'ClassicalEvalOp', 'ClassicalExpBox', 'ClassicalOp', 'Command', 'Conditional', 'ConjugationBox', 'CopyBitsOp', 'CustomGate', 'CustomGateDef', 'DiagonalBox', 'DummyBox', 'EdgeType', 'ExpBox', 'MetaOp', 'MultiBitOp', 'MultiplexedRotationBox', 'MultiplexedTensoredU2Box', 'MultiplexedU2Box', 'MultiplexorBox', 'Op', 'OpType', 'PauliExpBox', 'PauliExpCommutingSetBox', 'PauliExpPairBox', 'PhasePolyBox', 'ProjectorAssertionBox', 'QControlBox', 'RangePredicateOp', 'ResourceBounds', 'ResourceData', 'SetBitsOp', 'StabiliserAssertionBox', 'StatePreparationBox', 'TermSequenceBox', 'ToffoliBox', 'ToffoliBoxSynthStrat', 'Unitary1qBox', 'Unitary2qBox', 'Unitary3qBox', 'WASMOp', 'fresh_symbol'] +__all__ = ['BarrierOp', 'BasisOrder', 'CXConfigType', 'CircBox', 'Circuit', 'ClBitVar', 'ClExpr', 'ClExprOp', 'ClOp', 'ClRegVar', 'ClassicalEvalOp', 'ClassicalExpBox', 'ClassicalOp', 'Command', 'Conditional', 'ConjugationBox', 'CopyBitsOp', 'CustomGate', 'CustomGateDef', 'DiagonalBox', 'DummyBox', 'EdgeType', 'ExpBox', 'MetaOp', 'MultiBitOp', 'MultiplexedRotationBox', 'MultiplexedTensoredU2Box', 'MultiplexedU2Box', 'MultiplexorBox', 'Op', 'OpType', 'PauliExpBox', 'PauliExpCommutingSetBox', 'PauliExpPairBox', 'PhasePolyBox', 'ProjectorAssertionBox', 'QControlBox', 'RangePredicateOp', 'ResourceBounds', 'ResourceData', 'SetBitsOp', 'StabiliserAssertionBox', 'StatePreparationBox', 'TermSequenceBox', 'ToffoliBox', 'ToffoliBoxSynthStrat', 'Unitary1qBox', 'Unitary2qBox', 'Unitary3qBox', 'WASMOp', 'WiredClExpr', 'fresh_symbol'] class BarrierOp(Op): """ Barrier operations. @@ -1451,6 +1451,14 @@ class Circuit: :param args: Indices of the qubits to append the box to :return: the new :py:class:`Circuit` """ + def add_clexpr(self, expr: WiredClExpr, args: typing.Sequence[pytket._tket.unit_id.Bit], **kwargs: Any) -> Circuit: + """ + Append a :py:class:`WiredClExpr` to the circuit. + + :param expr: The expression to append + :param args: The bits to apply the expression to + :return: the new :py:class:`Circuit` + """ @typing.overload def add_conditional_barrier(self, barrier_qubits: typing.Sequence[int], barrier_bits: typing.Sequence[int], condition_bits: typing.Sequence[int], value: int, data: str = '') -> Circuit: """ @@ -2540,6 +2548,239 @@ class Circuit: """ A list of all qubit ids in the circuit """ +class ClBitVar: + """ + A bit variable within an expression + """ + @staticmethod + def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore + ... + def __eq__(self, arg0: typing.Any) -> bool: + ... + def __hash__(self) -> int: + ... + def __init__(self, i: int) -> None: + """ + Construct from an integer identifier + """ + def __repr__(self) -> str: + ... + def __str__(self) -> str: + ... + @property + def index(self) -> int: + """ + :return: integer identifier for the variable + """ +class ClExpr: + """ + A classical expression + """ + @staticmethod + def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore + ... + def __eq__(self, arg0: typing.Any) -> bool: + ... + def __hash__(self) -> int: + """ + Hashing is not implemented for this class, attempting to hash an object will raise a type error + """ + def __init__(self, op: ClOp, args: list[int | ClBitVar | ClRegVar | ClExpr]) -> None: + """ + Construct from an operation and a list of arguments + """ + def __str__(self) -> str: + ... + def as_qasm(self, input_bits: dict[int, pytket._tket.unit_id.Bit], input_regs: dict[int, pytket._tket.unit_id.BitRegister]) -> str: + """ + QASM-style string representation given corresponding bits and registers + """ + @property + def args(self) -> list[int | ClBitVar | ClRegVar | ClExpr]: + """ + :return: arguments + """ + @property + def op(self) -> ClOp: + """ + :return: main operation + """ +class ClExprOp(Op): + """ + An operation defined by a classical expression + """ + @staticmethod + def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore + ... + def __init__(self, arg0: WiredClExpr) -> None: + """ + Construct from a wired classical expression + """ + @property + def expr(self) -> WiredClExpr: + """ + :return: wired expression + """ + @property + def type(self) -> OpType: + """ + :return: operation type + """ +class ClOp: + """ + A classical operation + + Members: + + INVALID : Invalid + + BitAnd : Bitwise AND + + BitOr : Bitwise OR + + BitXor : Bitwise XOR + + BitEq : Bitwise equality + + BitNeq : Bitwise inequality + + BitNot : Bitwise NOT + + BitZero : Constant zero bit + + BitOne : Constant one bit + + RegAnd : Registerwise AND + + RegOr : Registerwise OR + + RegXor : Registerwise XOR + + RegEq : Registerwise equality + + RegNeq : Registerwise inequality + + RegNot : Registerwise NOT + + RegZero : Constant all-zeros register + + RegOne : Constant all-ones register + + RegLt : Integer less-than comparison + + RegGt : Integer greater-than comparison + + RegLeq : Integer less-than-or-equal comparison + + RegGeq : Integer greater-than-or-equal comparison + + RegAdd : Integer addition + + RegSub : Integer subtraction + + RegMul : Integer multiplication + + RegDiv : Integer division + + RegPow : Integer exponentiation + + RegLsh : Left shift + + RegRsh : Right shift + + RegNeg : Integer negation + """ + BitAnd: typing.ClassVar[ClOp] # value = + BitEq: typing.ClassVar[ClOp] # value = + BitNeq: typing.ClassVar[ClOp] # value = + BitNot: typing.ClassVar[ClOp] # value = + BitOne: typing.ClassVar[ClOp] # value = + BitOr: typing.ClassVar[ClOp] # value = + BitXor: typing.ClassVar[ClOp] # value = + BitZero: typing.ClassVar[ClOp] # value = + INVALID: typing.ClassVar[ClOp] # value = + RegAdd: typing.ClassVar[ClOp] # value = + RegAnd: typing.ClassVar[ClOp] # value = + RegDiv: typing.ClassVar[ClOp] # value = + RegEq: typing.ClassVar[ClOp] # value = + RegGeq: typing.ClassVar[ClOp] # value = + RegGt: typing.ClassVar[ClOp] # value = + RegLeq: typing.ClassVar[ClOp] # value = + RegLsh: typing.ClassVar[ClOp] # value = + RegLt: typing.ClassVar[ClOp] # value = + RegMul: typing.ClassVar[ClOp] # value = + RegNeg: typing.ClassVar[ClOp] # value = + RegNeq: typing.ClassVar[ClOp] # value = + RegNot: typing.ClassVar[ClOp] # value = + RegOne: typing.ClassVar[ClOp] # value = + RegOr: typing.ClassVar[ClOp] # value = + RegPow: typing.ClassVar[ClOp] # value = + RegRsh: typing.ClassVar[ClOp] # value = + RegSub: typing.ClassVar[ClOp] # value = + RegXor: typing.ClassVar[ClOp] # value = + RegZero: typing.ClassVar[ClOp] # value = + __members__: typing.ClassVar[dict[str, ClOp]] # value = {'INVALID': , 'BitAnd': , 'BitOr': , 'BitXor': , 'BitEq': , 'BitNeq': , 'BitNot': , 'BitZero': , 'BitOne': , 'RegAnd': , 'RegOr': , 'RegXor': , 'RegEq': , 'RegNeq': , 'RegNot': , 'RegZero': , 'RegOne': , 'RegLt': , 'RegGt': , 'RegLeq': , 'RegGeq': , 'RegAdd': , 'RegSub': , 'RegMul': , 'RegDiv': , 'RegPow': , 'RegLsh': , 'RegRsh': , 'RegNeg': } + @staticmethod + def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore + ... + def __eq__(self, other: typing.Any) -> bool: + ... + def __ge__(self, other: typing.Any) -> bool: + ... + def __getstate__(self) -> int: + ... + def __gt__(self, other: typing.Any) -> bool: + ... + def __hash__(self) -> int: + ... + def __index__(self) -> int: + ... + def __init__(self, value: int) -> None: + ... + def __int__(self) -> int: + ... + def __le__(self, other: typing.Any) -> bool: + ... + def __lt__(self, other: typing.Any) -> bool: + ... + def __ne__(self, other: typing.Any) -> bool: + ... + def __repr__(self) -> str: + ... + def __setstate__(self, state: int) -> None: + ... + def __str__(self) -> str: + ... + @property + def name(self) -> str: + ... + @property + def value(self) -> int: + ... +class ClRegVar: + """ + A register variable within an expression + """ + @staticmethod + def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore + ... + def __eq__(self, arg0: typing.Any) -> bool: + ... + def __hash__(self) -> int: + ... + def __init__(self, i: int) -> None: + """ + Construct from an integer identifier + """ + def __repr__(self) -> str: + ... + def __str__(self) -> str: + ... + @property + def index(self) -> int: + """ + :return: integer identifier for the variable + """ class ClassicalEvalOp(ClassicalOp): """ Evaluatable classical operation. @@ -3378,6 +3619,8 @@ class OpType: StatePreparationBox : A box for preparing quantum states using multiplexed-Ry and multiplexed-Rz gates DiagonalBox : A box for synthesising a diagonal unitary matrix into a sequence of multiplexed-Rz gates + + ClExpr : A classical expression """ AAMS: typing.ClassVar[OpType] # value = BRIDGE: typing.ClassVar[OpType] # value = @@ -3401,6 +3644,7 @@ class OpType: CY: typing.ClassVar[OpType] # value = CZ: typing.ClassVar[OpType] # value = CircBox: typing.ClassVar[OpType] # value = + ClExpr: typing.ClassVar[OpType] # value = ClassicalExpBox: typing.ClassVar[OpType] # value = ClassicalTransform: typing.ClassVar[OpType] # value = CnRx: typing.ClassVar[OpType] # value = @@ -3480,7 +3724,7 @@ class OpType: Z: typing.ClassVar[OpType] # value = ZZMax: typing.ClassVar[OpType] # value = ZZPhase: typing.ClassVar[OpType] # value = - __members__: typing.ClassVar[dict[str, OpType]] # value = {'Phase': , 'Z': , 'X': , 'Y': , 'S': , 'Sdg': , 'T': , 'Tdg': , 'V': , 'Vdg': , 'SX': , 'SXdg': , 'H': , 'Rx': , 'Ry': , 'Rz': , 'U1': , 'U2': , 'U3': , 'GPI': , 'GPI2': , 'AAMS': , 'TK1': , 'TK2': , 'CX': , 'CY': , 'CZ': , 'CH': , 'CV': , 'CVdg': , 'CSX': , 'CSXdg': , 'CS': , 'CSdg': , 'CRz': , 'CRx': , 'CRy': , 'CU1': , 'CU3': , 'CCX': , 'ECR': , 'SWAP': , 'CSWAP': , 'noop': , 'Barrier': , 'Label': , 'Branch': , 'Goto': , 'Stop': , 'BRIDGE': , 'Measure': , 'Reset': , 'CircBox': , 'PhasePolyBox': , 'Unitary1qBox': , 'Unitary2qBox': , 'Unitary3qBox': , 'ExpBox': , 'PauliExpBox': , 'PauliExpPairBox': , 'PauliExpCommutingSetBox': , 'TermSequenceBox': , 'QControlBox': , 'ToffoliBox': , 'ConjugationBox': , 'DummyBox': , 'CustomGate': , 'Conditional': , 'ISWAP': , 'PhasedISWAP': , 'XXPhase': , 'YYPhase': , 'ZZPhase': , 'XXPhase3': , 'PhasedX': , 'NPhasedX': , 'CnRx': , 'CnRy': , 'CnRz': , 'CnX': , 'CnY': , 'CnZ': , 'ZZMax': , 'ESWAP': , 'FSim': , 'Sycamore': , 'ISWAPMax': , 'ClassicalTransform': , 'WASM': , 'SetBits': , 'CopyBits': , 'RangePredicate': , 'ExplicitPredicate': , 'ExplicitModifier': , 'MultiBit': , 'ClassicalExpBox': , 'MultiplexorBox': , 'MultiplexedRotationBox': , 'MultiplexedU2Box': , 'MultiplexedTensoredU2Box': , 'StatePreparationBox': , 'DiagonalBox': } + __members__: typing.ClassVar[dict[str, OpType]] # value = {'Phase': , 'Z': , 'X': , 'Y': , 'S': , 'Sdg': , 'T': , 'Tdg': , 'V': , 'Vdg': , 'SX': , 'SXdg': , 'H': , 'Rx': , 'Ry': , 'Rz': , 'U1': , 'U2': , 'U3': , 'GPI': , 'GPI2': , 'AAMS': , 'TK1': , 'TK2': , 'CX': , 'CY': , 'CZ': , 'CH': , 'CV': , 'CVdg': , 'CSX': , 'CSXdg': , 'CS': , 'CSdg': , 'CRz': , 'CRx': , 'CRy': , 'CU1': , 'CU3': , 'CCX': , 'ECR': , 'SWAP': , 'CSWAP': , 'noop': , 'Barrier': , 'Label': , 'Branch': , 'Goto': , 'Stop': , 'BRIDGE': , 'Measure': , 'Reset': , 'CircBox': , 'PhasePolyBox': , 'Unitary1qBox': , 'Unitary2qBox': , 'Unitary3qBox': , 'ExpBox': , 'PauliExpBox': , 'PauliExpPairBox': , 'PauliExpCommutingSetBox': , 'TermSequenceBox': , 'QControlBox': , 'ToffoliBox': , 'ConjugationBox': , 'DummyBox': , 'CustomGate': , 'Conditional': , 'ISWAP': , 'PhasedISWAP': , 'XXPhase': , 'YYPhase': , 'ZZPhase': , 'XXPhase3': , 'PhasedX': , 'NPhasedX': , 'CnRx': , 'CnRy': , 'CnRz': , 'CnX': , 'CnY': , 'CnZ': , 'ZZMax': , 'ESWAP': , 'FSim': , 'Sycamore': , 'ISWAPMax': , 'ClassicalTransform': , 'WASM': , 'SetBits': , 'CopyBits': , 'RangePredicate': , 'ExplicitPredicate': , 'ExplicitModifier': , 'MultiBit': , 'ClassicalExpBox': , 'MultiplexorBox': , 'MultiplexedRotationBox': , 'MultiplexedU2Box': , 'MultiplexedTensoredU2Box': , 'StatePreparationBox': , 'DiagonalBox': , 'ClExpr': } noop: typing.ClassVar[OpType] # value = @staticmethod def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore @@ -4143,6 +4387,54 @@ class WASMOp(ClassicalOp): """ Wasm module id. """ +class WiredClExpr: + """ + A classical expression defined over a sequence of bits + """ + @staticmethod + def _pybind11_conduit_v1_(*args, **kwargs): # type: ignore + ... + @staticmethod + def from_dict(arg0: dict) -> WiredClExpr: + """ + Construct from JSON-serializable dict representation + """ + def __eq__(self, arg0: typing.Any) -> bool: + ... + def __hash__(self) -> int: + """ + Hashing is not implemented for this class, attempting to hash an object will raise a type error + """ + def __init__(self, expr: ClExpr, bit_posn: dict[int, int] = {}, reg_posn: dict[int, list[int]] = {}, output_posn: list[int] = []) -> None: + """ + Construct from an expression with bit and register positions + """ + def __str__(self) -> str: + ... + def to_dict(self) -> dict: + """ + :return: JSON-serializable dict representation + """ + @property + def bit_posn(self) -> dict[int, int]: + """ + :return: bit positions + """ + @property + def expr(self) -> ClExpr: + """ + :return: expression + """ + @property + def output_posn(self) -> list[int]: + """ + :return: output positions + """ + @property + def reg_posn(self) -> dict[int, list[int]]: + """ + :return: register positions + """ def fresh_symbol(preferred: str = 'a') -> sympy.Symbol: """ Given some preferred symbol, this finds an appropriate suffix that will guarantee it has not yet been used in the current python session. diff --git a/pytket/pytket/circuit/clexpr.py b/pytket/pytket/circuit/clexpr.py new file mode 100644 index 0000000000..b6059c560c --- /dev/null +++ b/pytket/pytket/circuit/clexpr.py @@ -0,0 +1,163 @@ +# Copyright 2019-2024 Cambridge Quantum Computing +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from pytket.circuit import ( + Bit, + BitRegister, + ClBitVar, + ClExpr, + ClOp, + ClRegVar, + WiredClExpr, +) +from pytket.circuit.logic_exp import Ops, BitWiseOp, RegWiseOp, LogicExp + +_reg_output_clops = set( + [ + ClOp.RegAnd, + ClOp.RegOr, + ClOp.RegXor, + ClOp.RegNot, + ClOp.RegZero, + ClOp.RegOne, + ClOp.RegAdd, + ClOp.RegSub, + ClOp.RegMul, + ClOp.RegDiv, + ClOp.RegPow, + ClOp.RegLsh, + ClOp.RegRsh, + ClOp.RegNeg, + ] +) + + +def has_reg_output(op: ClOp) -> bool: + return op in _reg_output_clops + + +def clop_from_ops(op: Ops) -> ClOp: + match op: + case BitWiseOp.AND: + return ClOp.BitAnd + case BitWiseOp.OR: + return ClOp.BitOr + case BitWiseOp.XOR: + return ClOp.BitXor + case BitWiseOp.EQ: + return ClOp.BitEq + case BitWiseOp.NEQ: + return ClOp.BitNeq + case BitWiseOp.NOT: + return ClOp.BitNot + case BitWiseOp.ZERO: + return ClOp.BitZero + case BitWiseOp.ONE: + return ClOp.BitOne + case RegWiseOp.AND: + return ClOp.RegAnd + case RegWiseOp.OR: + return ClOp.RegOr + case RegWiseOp.XOR: + return ClOp.RegXor + case RegWiseOp.EQ: + return ClOp.RegEq + case RegWiseOp.NEQ: + return ClOp.RegNeq + case RegWiseOp.LT: + return ClOp.RegLt + case RegWiseOp.GT: + return ClOp.RegGt + case RegWiseOp.LEQ: + return ClOp.RegLeq + case RegWiseOp.GEQ: + return ClOp.RegGeq + case RegWiseOp.ADD: + return ClOp.RegAdd + case RegWiseOp.SUB: + return ClOp.RegSub + case RegWiseOp.MUL: + return ClOp.RegMul + case RegWiseOp.DIV: + return ClOp.RegDiv + case RegWiseOp.POW: + return ClOp.RegPow + case RegWiseOp.LSH: + return ClOp.RegLsh + case RegWiseOp.RSH: + return ClOp.RegRsh + case RegWiseOp.NOT: + return ClOp.RegNot + case RegWiseOp.NEG: + return ClOp.RegNeg + + +@dataclass +class ExpressionConverter: + bit_indices: dict[Bit, int] + reg_indices: dict[BitRegister, int] + + def convert(self, exp: LogicExp) -> ClExpr: + op: ClOp = clop_from_ops(exp.op) + args: list[int | ClBitVar | ClRegVar | ClExpr] = [] + for arg in exp.args: + if isinstance(arg, LogicExp): + args.append(self.convert(arg)) + elif isinstance(arg, Bit): + args.append(ClBitVar(self.bit_indices[arg])) + elif isinstance(arg, BitRegister): + args.append(ClRegVar(self.reg_indices[arg])) + else: + assert isinstance(arg, int) + args.append(arg) + return ClExpr(op, args) + + +def wired_clexpr_from_logic_exp( + exp: LogicExp, output_bits: list[Bit] +) -> tuple[WiredClExpr, list[Bit]]: + """Convert a :py:class:`LogicExp` to a :py:class:`WiredClExpr` + + :param exp: the LogicExp + :param output_bits: list of output bits of the LogicExp + :return: the WiredClExpr and its full list of arguments + """ + # 1. Construct lists of input bits and registers (where the positions of the items + # in each list will be the indices of the corresponding variables in the ClExpr): + all_vars = exp.all_inputs_ordered() + input_bits: list[Bit] = [var for var in all_vars if isinstance(var, Bit)] + input_regs: list[BitRegister] = [ + var for var in all_vars if isinstance(var, BitRegister) + ] + # 2. Order the arguments: first the input bits, then all the bits in the input + # registers then any remaining output bits: + args = [] + args.extend(input_bits) + for r in input_regs: + args.extend(r.to_list()) + args.extend(b for b in output_bits if b not in args) + # 3. Construct the WiredClExpr and return it with the argument list: + return ( + WiredClExpr( + ExpressionConverter( + {b: i for i, b in enumerate(input_bits)}, + {r: i for i, r in enumerate(input_regs)}, + ).convert(exp), + {i: args.index(b) for i, b in enumerate(input_bits)}, + {i: [args.index(b) for b in r.to_list()] for i, r in enumerate(input_regs)}, + [args.index(b) for b in output_bits], + ), + args, + ) diff --git a/pytket/pytket/qasm/qasm.py b/pytket/pytket/qasm/qasm.py index 9a01d6e167..86cba07ed1 100644 --- a/pytket/pytket/qasm/qasm.py +++ b/pytket/pytket/qasm/qasm.py @@ -54,6 +54,9 @@ MultiBitOp, WASMOp, BarrierOp, + ClExprOp, + WiredClExpr, + ClExpr, ) from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE from pytket.circuit import ( @@ -66,6 +69,7 @@ QubitRegister, UnitID, ) +from pytket.circuit.clexpr import has_reg_output, wired_clexpr_from_logic_exp from pytket.circuit.decompose_classical import int_to_bools from pytket.circuit.logic_exp import ( BitLogicExp, @@ -298,7 +302,12 @@ def __iter__(self) -> Iterable[str]: class CircuitTransformer(Transformer): - def __init__(self, return_gate_dict: bool = False, maxwidth: int = 32) -> None: + def __init__( + self, + return_gate_dict: bool = False, + maxwidth: int = 32, + use_clexpr: bool = False, + ) -> None: super().__init__() self.q_registers: Dict[str, int] = {} self.c_registers: Dict[str, int] = {} @@ -307,6 +316,7 @@ def __init__(self, return_gate_dict: bool = False, maxwidth: int = 32) -> None: self.include = "" self.return_gate_dict = return_gate_dict self.maxwidth = maxwidth + self.use_clexpr = use_clexpr def _fresh_temp_bit(self) -> List: if _TEMP_BIT_NAME in self.c_registers: @@ -715,6 +725,29 @@ def _cexpbox_dict(self, exp: LogicExp, args: List[List]) -> CommandDict: }, } + def _clexpr_dict(self, exp: LogicExp, out_args: List[List]) -> CommandDict: + # Convert the LogicExp to a serialization of a command containing the + # corresponding ClExprOp. + wexpr, args = wired_clexpr_from_logic_exp( + exp, [Bit.from_list(arg) for arg in out_args] + ) + return { + "op": { + "type": "ClExpr", + "expr": wexpr.to_dict(), + }, + "args": [arg.to_list() for arg in args], + } + + def _logic_exp_as_cmd_dict( + self, exp: LogicExp, out_args: List[List] + ) -> CommandDict: + return ( + self._clexpr_dict(exp, out_args) + if self.use_clexpr + else self._cexpbox_dict(exp, out_args) + ) + def assign(self, tree: List) -> Iterable[CommandDict]: child_iter = iter(tree) out_args = list(next(child_iter)) @@ -752,7 +785,7 @@ def assign(self, tree: List) -> Iterable[CommandDict]: args = args_uids[0] if isinstance(out_arg, List): if isinstance(exp, LogicExp): - yield self._cexpbox_dict(exp, args) + yield self._logic_exp_as_cmd_dict(exp, args) elif isinstance(exp, (int, bool)): assert exp in (0, 1, True, False) yield { @@ -769,9 +802,9 @@ def assign(self, tree: List) -> Iterable[CommandDict]: else: reg = out_arg if isinstance(exp, RegLogicExp): - yield self._cexpbox_dict(exp, args) + yield self._logic_exp_as_cmd_dict(exp, args) elif isinstance(exp, BitLogicExp): - yield self._cexpbox_dict(exp, args[:1]) + yield self._logic_exp_as_cmd_dict(exp, args[:1]) elif isinstance(exp, int): yield { "args": args, @@ -926,38 +959,42 @@ def prog(self, tree: Iterable) -> Dict[str, Any]: return outdict -def parser(maxwidth: int) -> Lark: +def parser(maxwidth: int, use_clexpr: bool) -> Lark: return Lark( grammar, start="prog", debug=False, parser="lalr", cache=True, - transformer=CircuitTransformer(maxwidth=maxwidth), + transformer=CircuitTransformer(maxwidth=maxwidth, use_clexpr=use_clexpr), ) g_parser = None g_maxwidth = 32 +g_use_clexpr = False -def set_parser(maxwidth: int) -> None: - global g_parser, g_maxwidth - if (g_parser is None) or (g_maxwidth != maxwidth): # type: ignore - g_parser = parser(maxwidth=maxwidth) +def set_parser(maxwidth: int, use_clexpr: bool) -> None: + global g_parser, g_maxwidth, g_use_clexpr + if (g_parser is None) or (g_maxwidth != maxwidth) or g_use_clexpr != use_clexpr: # type: ignore + g_parser = parser(maxwidth=maxwidth, use_clexpr=use_clexpr) g_maxwidth = maxwidth + g_use_clexpr = use_clexpr def circuit_from_qasm( input_file: Union[str, "os.PathLike[Any]"], encoding: str = "utf-8", maxwidth: int = 32, + use_clexpr: bool = False, ) -> Circuit: """A method to generate a tket Circuit from a qasm file. :param input_file: path to qasm file; filename must have ``.qasm`` extension :param encoding: file encoding (default utf-8) :param maxwidth: maximum allowed width of classical registers (default 32) + :param use_clexpr: whether to use ClExprOp to represent classical expressions :return: pytket circuit """ ext = os.path.splitext(input_file)[-1] @@ -965,21 +1002,24 @@ def circuit_from_qasm( raise TypeError("Can only convert .qasm files") with open(input_file, "r", encoding=encoding) as f: try: - circ = circuit_from_qasm_io(f, maxwidth=maxwidth) + circ = circuit_from_qasm_io(f, maxwidth=maxwidth, use_clexpr=use_clexpr) except QASMParseError as e: raise QASMParseError(e.msg, e.line, str(input_file)) return circ -def circuit_from_qasm_str(qasm_str: str, maxwidth: int = 32) -> Circuit: +def circuit_from_qasm_str( + qasm_str: str, maxwidth: int = 32, use_clexpr: bool = False +) -> Circuit: """A method to generate a tket Circuit from a qasm string. :param qasm_str: qasm string :param maxwidth: maximum allowed width of classical registers (default 32) + :param use_clexpr: whether to use ClExprOp to represent classical expressions :return: pytket circuit """ global g_parser - set_parser(maxwidth=maxwidth) + set_parser(maxwidth=maxwidth, use_clexpr=use_clexpr) assert g_parser is not None cast(CircuitTransformer, g_parser.options.transformer)._reset_context( reset_wasm=False @@ -987,9 +1027,13 @@ def circuit_from_qasm_str(qasm_str: str, maxwidth: int = 32) -> Circuit: return Circuit.from_dict(g_parser.parse(qasm_str)) # type: ignore[arg-type] -def circuit_from_qasm_io(stream_in: TextIO, maxwidth: int = 32) -> Circuit: +def circuit_from_qasm_io( + stream_in: TextIO, maxwidth: int = 32, use_clexpr: bool = False +) -> Circuit: """A method to generate a tket Circuit from a qasm text stream""" - return circuit_from_qasm_str(stream_in.read(), maxwidth=maxwidth) + return circuit_from_qasm_str( + stream_in.read(), maxwidth=maxwidth, use_clexpr=use_clexpr + ) def circuit_from_qasm_wasm( @@ -997,6 +1041,7 @@ def circuit_from_qasm_wasm( wasm_file: Union[str, "os.PathLike[Any]"], encoding: str = "utf-8", maxwidth: int = 32, + use_clexpr: bool = False, ) -> Circuit: """A method to generate a tket Circuit from a qasm string and external WASM module. @@ -1008,10 +1053,12 @@ def circuit_from_qasm_wasm( """ global g_parser wasm_module = WasmFileHandler(str(wasm_file)) - set_parser(maxwidth=maxwidth) + set_parser(maxwidth=maxwidth, use_clexpr=use_clexpr) assert g_parser is not None cast(CircuitTransformer, g_parser.options.transformer).wasm = wasm_module - return circuit_from_qasm(input_file, encoding=encoding, maxwidth=maxwidth) + return circuit_from_qasm( + input_file, encoding=encoding, maxwidth=maxwidth, use_clexpr=use_clexpr + ) def circuit_to_qasm( @@ -1718,6 +1765,51 @@ def add_classical_exp_box(self, op: ClassicalExpBox, args: List[Bit]) -> None: " for writing to a single bit or whole registers." ) + def add_wired_clexpr(self, op: ClExprOp, args: List[Bit]) -> None: + wexpr: WiredClExpr = op.expr + # 1. Determine the mappings from bit variables to bits and from register + # variables to registers. + expr: ClExpr = wexpr.expr + bit_posn: dict[int, int] = wexpr.bit_posn + reg_posn: dict[int, list[int]] = wexpr.reg_posn + output_posn: list[int] = wexpr.output_posn + input_bits: dict[int, Bit] = {i: args[j] for i, j in bit_posn.items()} + input_regs: dict[int, BitRegister] = {} + all_cregs = set(self.cregs.values()) + for i, posns in reg_posn.items(): + reg_args = [args[j] for j in posns] + for creg in all_cregs: + if creg.to_list() == reg_args: + input_regs[i] = creg + break + else: + raise QASMUnsupportedError( + f"ClExprOp ({wexpr}) contains a register variable (r{i}) " + "that is not wired to any BitRegister in the circuit." + ) + # 2. Write the left-hand side of the assignment. + output_repr: Optional[str] = None + output_args: list[Bit] = [args[j] for j in output_posn] + n_output_args = len(output_args) + expect_reg_output = has_reg_output(expr.op) + if n_output_args == 0: + raise QASMUnsupportedError("Expression has no output.") + elif n_output_args == 1: + output_arg = output_args[0] + output_repr = output_arg.reg_name if expect_reg_output else str(output_arg) + else: + if not expect_reg_output: + raise QASMUnsupportedError("Unexpected output for operation.") + for creg in all_cregs: + if creg.to_list() == output_args: + output_repr = creg.name + self.strings.add_string(f"{output_repr} = ") + # 3. Write the right-hand side of the assignment. + self.strings.add_string( + expr.as_qasm(input_bits=input_bits, input_regs=input_regs) + ) + self.strings.add_string(";\n") + def add_wasm(self, op: WASMOp, args: List[Bit]) -> None: inputs: List[str] = [] outputs: List[str] = [] @@ -1821,6 +1913,9 @@ def add_op(self, op: Op, args: Sequence[UnitID]) -> None: elif optype == OpType.ClassicalExpBox: assert isinstance(op, ClassicalExpBox) self.add_classical_exp_box(op, cast(List[Bit], args)) + elif optype == OpType.ClExpr: + assert isinstance(op, ClExprOp) + self.add_wired_clexpr(op, cast(List[Bit], args)) elif optype == OpType.WASM: assert isinstance(op, WASMOp) self.add_wasm(op, cast(List[Bit], args)) diff --git a/pytket/tests/clexpr_test.py b/pytket/tests/clexpr_test.py new file mode 100644 index 0000000000..e078224aa0 --- /dev/null +++ b/pytket/tests/clexpr_test.py @@ -0,0 +1,219 @@ +# Copyright 2019-2024 Cambridge Quantum Computing +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytket.circuit import ( + Bit, + CircBox, + Circuit, + ClBitVar, + ClExpr, + ClExprOp, + ClOp, + ClRegVar, + OpType, + WiredClExpr, +) +from pytket.qasm import circuit_to_qasm_str, circuit_from_qasm_str + + +def test_op() -> None: + reg_add = ClOp.RegAdd + assert str(reg_add) == "ClOp.RegAdd" + reg_sub = ClOp.RegSub + assert reg_add != reg_sub + + +def test_vars() -> None: + bvar3 = ClBitVar(3) + assert bvar3.index == 3 + assert str(bvar3) == "b3" + bvar4 = ClBitVar(4) + rvar3 = ClRegVar(3) + assert rvar3.index == 3 + assert str(rvar3) == "r3" + rvar3a = ClRegVar(3) + assert bvar3 != bvar4 + assert bvar3 != rvar3 + assert rvar3 == rvar3a + + +def test_expr() -> None: + b0 = ClBitVar(0) + r0 = ClRegVar(0) + three = 3 + expr0 = ClExpr(op=ClOp.RegAdd, args=[r0, three]) + expr = ClExpr(op=ClOp.BitXor, args=[expr0, b0]) + assert str(expr) == "xor(add(r0, 3), b0)" + assert expr.op == ClOp.BitXor + args = expr.args + assert len(args) == 2 + assert args[0] == expr0 + assert args[1] == b0 + + +def test_wexpr() -> None: + expr = ClExpr( + op=ClOp.RegDiv, + args=[ClRegVar(0), ClExpr(op=ClOp.RegAdd, args=[2, ClBitVar(0)])], + ) + wexpr = WiredClExpr( + expr=expr, bit_posn={0: 1}, reg_posn={0: [2, 0]}, output_posn=[2, 0] + ) + assert str(wexpr) == "div(r0, add(2, b0)) [b0:1, r0:(2,0) --> (2,0)]" + assert wexpr.expr == expr + assert wexpr.bit_posn == {0: 1} + assert wexpr.reg_posn == {0: [2, 0]} + assert wexpr.output_posn == [2, 0] + wexpr_dict = wexpr.to_dict() + assert wexpr_dict == { + "bit_posn": [[0, 1]], + "expr": { + "args": [ + { + "input": { + "term": {"type": "reg", "var": {"index": 0}}, + "type": "var", + }, + "type": "term", + }, + { + "input": { + "args": [ + {"input": {"term": 2, "type": "int"}, "type": "term"}, + { + "input": { + "term": {"type": "bit", "var": {"index": 0}}, + "type": "var", + }, + "type": "term", + }, + ], + "op": "RegAdd", + }, + "type": "expr", + }, + ], + "op": "RegDiv", + }, + "output_posn": [2, 0], + "reg_posn": [[0, [2, 0]]], + } + wexpr1 = WiredClExpr.from_dict(wexpr_dict) + assert wexpr == wexpr1 + + +def test_adding_to_circuit() -> None: + expr = ClExpr(op=ClOp.BitXor, args=[ClBitVar(0), ClBitVar(1)]) + wexpr = WiredClExpr(expr=expr, bit_posn={0: 0, 1: 1}, output_posn=[2]) + c = Circuit(0, 3) + c.add_clexpr(wexpr, c.bits) + cmds = c.get_commands() + assert len(cmds) == 1 + op = cmds[0].op + assert isinstance(op, ClExprOp) + assert op.expr == wexpr + d = c.to_dict() + c1 = Circuit.from_dict(d) + assert c == c1 + d1 = c1.to_dict() + assert d == d1 + c2 = c.copy() + assert c2 == c + + +def test_qasm_conversion() -> None: + c = Circuit() + c.add_c_register("a", 3) + c.add_c_register("b", 3) + c.add_c_register("c", 3) + c.add_c_register("d", 3) + expr = ClExpr( + op=ClOp.RegSub, + args=[ + ClExpr( + op=ClOp.RegDiv, + args=[ClExpr(ClOp.RegAdd, args=[ClRegVar(0), ClRegVar(1)]), 2], + ), + ClRegVar(2), + ], + ) + wexpr = WiredClExpr( + expr=expr, + reg_posn={0: [0, 1, 2], 1: [3, 4, 5], 2: [6, 7, 8]}, + output_posn=[9, 10, 11], + ) + c.add_clexpr(wexpr, c.bits) + qasm = circuit_to_qasm_str(c, header="hqslib1") + assert ( + qasm + == """OPENQASM 2.0; +include "hqslib1.inc"; + +creg a[3]; +creg b[3]; +creg c[3]; +creg d[3]; +d = (((a + b) / 2) - c); +""" + ) + c1 = circuit_from_qasm_str(qasm, use_clexpr=True) + assert c == c1 + + +def make_circ() -> Circuit: + c = Circuit() + c.add_bit(Bit("x", 0)) + c.add_bit(Bit("x", 1)) + c.add_bit(Bit("y", 0)) + c.add_clexpr( + WiredClExpr( + expr=ClExpr(op=ClOp.BitXor, args=[ClBitVar(0), ClBitVar(1)]), + bit_posn={0: 0, 1: 1}, + output_posn=[2], + ), + [Bit("x", 0), Bit("x", 1), Bit("y", 0)], + ) + return c + + +def test_copy_and_flatten() -> None: + # See https://github.com/CQCL/tket/issues/1544 + c0 = make_circ() + c1 = make_circ() + assert c0 == c1 + c2 = c1.copy() + c2.flatten_registers() + assert c0 == c1 + assert c2.get_commands()[0].op == c0.get_commands()[0].op + qasm = circuit_to_qasm_str(c2, header="hqslib1") + assert ( + qasm + == """OPENQASM 2.0; +include "hqslib1.inc"; + +creg c[3]; +c[2] = (c[0] ^ c[1]); +""" + ) + + +def test_circbox() -> None: + # See https://github.com/CQCL/tket/issues/1544 + c0 = make_circ() + cbox = CircBox(c0) + c1 = Circuit(0, 3) + c1.add_circbox(cbox, [0, 1, 2]) + c2 = c1.copy() + c2.flatten_registers() + assert c1 == c2 diff --git a/pytket/tests/qasm_test.py b/pytket/tests/qasm_test.py index 85b0a9a5db..8ed36b7fb4 100644 --- a/pytket/tests/qasm_test.py +++ b/pytket/tests/qasm_test.py @@ -422,7 +422,7 @@ def test_h1_rzz() -> None: def test_extended_qasm() -> None: fname = str(curr_file_path / "qasm_test_files/test17.qasm") out_fname = str(curr_file_path / "qasm_test_files/test17_output.qasm") - c = circuit_from_qasm_wasm(fname, "testfile.wasm") + c = circuit_from_qasm_wasm(fname, "testfile.wasm", use_clexpr=True) out_qasm = circuit_to_qasm_str(c, "hqslib1") with open(out_fname) as f: @@ -432,15 +432,14 @@ def test_extended_qasm() -> None: assert circuit_to_qasm_str(c2, "hqslib1") - with pytest.raises(DecomposeClassicalError) as e: - DecomposeClassicalExp().apply(c) + assert not DecomposeClassicalExp().apply(c) def test_decomposable_extended() -> None: fname = str(curr_file_path / "qasm_test_files/test18.qasm") out_fname = str(curr_file_path / "qasm_test_files/test18_output.qasm") - c = circuit_from_qasm_wasm(fname, "testfile.wasm", maxwidth=64) + c = circuit_from_qasm_wasm(fname, "testfile.wasm", maxwidth=64, use_clexpr=True) DecomposeClassicalExp().apply(c) out_qasm = circuit_to_qasm_str(c, "hqslib1", maxwidth=64) @@ -654,10 +653,11 @@ def test_qasm_phase() -> None: assert c1 == c0 -def test_CopyBits() -> None: +@pytest.mark.parametrize("use_clexpr", [True, False]) +def test_CopyBits(use_clexpr: bool) -> None: input_qasm = """OPENQASM 2.0;\ninclude "hqslib1.inc";\n\ncreg c0[1]; creg c1[3];\nc0[0] = c1[1];\n""" - c = circuit_from_qasm_str(input_qasm) + c = circuit_from_qasm_str(input_qasm, use_clexpr=use_clexpr) result_circ_qasm = circuit_to_qasm_str(c, "hqslib1") assert input_qasm == result_circ_qasm @@ -831,7 +831,8 @@ def test_max_reg_width() -> None: assert len(circ_out.bits) == 33 -def test_classical_expbox_arg_order() -> None: +@pytest.mark.parametrize("use_clexpr", [True, False]) +def test_classical_expbox_arg_order(use_clexpr: bool) -> None: qasm = """ OPENQASM 2.0; include "hqslib1.inc"; @@ -846,7 +847,7 @@ def test_classical_expbox_arg_order() -> None: c = a ^ b | d; """ - circ = circuit_from_qasm_str(qasm) + circ = circuit_from_qasm_str(qasm, use_clexpr=use_clexpr) args = circ.get_commands()[0].args expected_symbol_order = ["a", "b", "d", "c"] expected_index_order = [0, 1, 2, 3] @@ -1158,5 +1159,6 @@ def test_multibitop() -> None: test_header_stops_gate_definition() test_tk2_definition() test_rxxyyzz_conversion() - test_classical_expbox_arg_order() + test_classical_expbox_arg_order(True) + test_classical_expbox_arg_order(False) test_register_name_check() diff --git a/pytket/tests/qasm_test_files/test17_output.qasm b/pytket/tests/qasm_test_files/test17_output.qasm index 7312e3de10..4e467334f0 100644 --- a/pytket/tests/qasm_test_files/test17_output.qasm +++ b/pytket/tests/qasm_test_files/test17_output.qasm @@ -7,19 +7,19 @@ creg b[3]; creg c[4]; creg d[1]; c = 2; -d[0] = (a << 1); c[0] = a[0]; c[1] = a[1]; if(b!=2) c[1] = ((b[1] & a[1]) | a[0]); c = (b & a); b = (a + b); -b[1] = (b[0] ^ (~ b[2])); +b[1] = (b[0] ^ (~b[2])); c = (a - (b ** c)); -d[0] = (c >> 2); +d = (a << 1); +d = (c >> 2); c[0] = 1; -d[0] = (a[0] ^ 1); -CCE1(c); b = ((a * c) / b); +CCE1(c); +d[0] = (a[0] ^ 1); a = CCE2(a, b); if(c>=2) h q[0]; if(d[0]==1) rx(1.0*pi) q[0]; diff --git a/pytket/tests/qasm_test_files/test18_output.qasm b/pytket/tests/qasm_test_files/test18_output.qasm index f2ab0b7d65..c635f280ab 100644 --- a/pytket/tests/qasm_test_files/test18_output.qasm +++ b/pytket/tests/qasm_test_files/test18_output.qasm @@ -6,22 +6,16 @@ creg a[2]; creg b[3]; creg c[4]; creg d[1]; -creg tk_SCRATCH_BIT[7]; -creg tk_SCRATCH_BITREG_0[64]; c = 2; -tk_SCRATCH_BITREG_0[0] = b[0] & a[0]; -tk_SCRATCH_BITREG_0[1] = b[1] & a[1]; c[0] = a[0]; c[1] = a[1]; -if(b!=2) tk_SCRATCH_BIT[6] = b[1] & a[1]; -c[0] = tk_SCRATCH_BITREG_0[0] | d[0]; -if(b!=2) c[1] = tk_SCRATCH_BIT[6] | a[0]; -tk_SCRATCH_BIT[6] = 1; -d[0] = a[0] ^ tk_SCRATCH_BIT[6]; -if(c>=2) h q[0]; +if(b!=2) c[1] = ((b[1] & a[1]) | a[0]); +c = ((b & a) | d); +d[0] = (a[0] ^ 1); a = CCE(a, b); -if(c<=2) h q[0]; +if(c>=2) h q[0]; CCE(c); +if(c<=2) h q[0]; if(c<=1) h q[0]; if(c>=3) h q[0]; if(c!=2) h q[0]; diff --git a/tket/CMakeLists.txt b/tket/CMakeLists.txt index 2b55de0d7c..13b4c7df97 100644 --- a/tket/CMakeLists.txt +++ b/tket/CMakeLists.txt @@ -218,6 +218,7 @@ target_sources(tket src/MeasurementSetup/MeasurementSetup.cpp src/Ops/BarrierOp.cpp src/Ops/ClassicalOps.cpp + src/Ops/ClExpr.cpp src/Ops/FlowOp.cpp src/Ops/MetaOp.cpp src/Ops/Op.cpp @@ -369,6 +370,7 @@ target_sources(tket include/tket/MeasurementSetup/MeasurementSetup.hpp include/tket/Ops/BarrierOp.hpp include/tket/Ops/ClassicalOps.hpp + include/tket/Ops/ClExpr.hpp include/tket/Ops/FlowOp.hpp include/tket/Ops/MetaOp.hpp include/tket/Ops/Op.hpp diff --git a/tket/conanfile.py b/tket/conanfile.py index 04deefbc46..2e7e0e9b43 100644 --- a/tket/conanfile.py +++ b/tket/conanfile.py @@ -23,7 +23,7 @@ class TketConan(ConanFile): name = "tket" - version = "1.3.33" + version = "1.3.34" package_type = "library" license = "Apache 2" homepage = "https://github.com/CQCL/tket" diff --git a/tket/include/tket/OpType/OpType.hpp b/tket/include/tket/OpType/OpType.hpp index e1b6c8fa5a..12a64a0d55 100644 --- a/tket/include/tket/OpType/OpType.hpp +++ b/tket/include/tket/OpType/OpType.hpp @@ -770,7 +770,12 @@ enum class OpType { /** * See \ref DummyBox */ - DummyBox + DummyBox, + + /** + * Function defined over bits and sequences of bits treated as integers + */ + ClExpr, }; JSON_DECL(OpType) diff --git a/tket/include/tket/OpType/OpTypeFunctions.hpp b/tket/include/tket/OpType/OpTypeFunctions.hpp index f23a544cdc..b07a82e4e2 100644 --- a/tket/include/tket/OpType/OpTypeFunctions.hpp +++ b/tket/include/tket/OpType/OpTypeFunctions.hpp @@ -111,7 +111,7 @@ bool is_clifford_type(OpType optype); /** Test for measurement and reset gates */ bool is_projective_type(OpType optype); -/** Test for purely classical gates derived from ClassicalOp */ +/** Test for purely classical gates */ bool is_classical_type(OpType optype); /** Test for controlled gates */ diff --git a/tket/include/tket/Ops/ClExpr.hpp b/tket/include/tket/Ops/ClExpr.hpp new file mode 100644 index 0000000000..1adf73d696 --- /dev/null +++ b/tket/include/tket/Ops/ClExpr.hpp @@ -0,0 +1,306 @@ +// Copyright 2019-2024 Cambridge Quantum Computing +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +/** + * @file + * @brief Classical expressions involving bits and registers + */ + +#include +#include +#include +#include +#include +#include + +#include "tket/Ops/Op.hpp" + +namespace tket { + +// TODO Use X or list macros to reduce boilerplate. + +/** + * An function acting on bits or bit registers + */ +enum class ClOp { + INVALID, /// Invalid + BitAnd, /// Bitwise AND + BitOr, /// Bitwise OR + BitXor, /// Bitwise XOR + BitEq, /// Bitwise equality + BitNeq, /// Bitwise inequality + BitNot, /// Bitwise NOT + BitZero, /// Constant zero bit + BitOne, /// Constant one bit + RegAnd, /// Registerwise AND + RegOr, /// Registerwise OR + RegXor, /// Registerwise XOR + RegEq, /// Registerwise equality + RegNeq, /// Registerwise inequality + RegNot, /// Registerwise NOT + RegZero, /// Constant all-zeros register + RegOne, /// Constant all-ones register + RegLt, /// Integer less-than comparison + RegGt, /// Integer greater-than comparison + RegLeq, /// Integer less-than-or-equal comparison + RegGeq, /// Integer greater-than-or-equal comparison + RegAdd, /// Integer addition + RegSub, /// Integer subtraction + RegMul, /// Integer multiplication + RegDiv, /// Integer division + RegPow, /// Integer exponentiation + RegLsh, /// Left shift + RegRsh, /// Right shift + RegNeg /// Integer negation +}; + +std::ostream& operator<<(std::ostream& os, ClOp fn); + +NLOHMANN_JSON_SERIALIZE_ENUM( + ClOp, { + {ClOp::INVALID, "INVALID"}, {ClOp::BitAnd, "BitAnd"}, + {ClOp::BitOr, "BitOr"}, {ClOp::BitXor, "BitXor"}, + {ClOp::BitEq, "BitEq"}, {ClOp::BitNeq, "BitNeq"}, + {ClOp::BitNot, "BitNot"}, {ClOp::BitZero, "BitZero"}, + {ClOp::BitOne, "BitOne"}, {ClOp::RegAnd, "RegAnd"}, + {ClOp::RegOr, "RegOr"}, {ClOp::RegXor, "RegXor"}, + {ClOp::RegEq, "RegEq"}, {ClOp::RegNeq, "RegNeq"}, + {ClOp::RegNot, "RegNot"}, {ClOp::RegZero, "RegZero"}, + {ClOp::RegOne, "RegOne"}, {ClOp::RegLt, "RegLt"}, + {ClOp::RegGt, "RegGt"}, {ClOp::RegLeq, "RegLeq"}, + {ClOp::RegGeq, "RegGeq"}, {ClOp::RegAdd, "RegAdd"}, + {ClOp::RegSub, "RegSub"}, {ClOp::RegMul, "RegMul"}, + {ClOp::RegDiv, "RegDiv"}, {ClOp::RegPow, "RegPow"}, + {ClOp::RegLsh, "RegLsh"}, {ClOp::RegRsh, "RegRsh"}, + {ClOp::RegNeg, "RegNeg"}, + }) + +/** + * A bit variable within an expression + */ +typedef struct ClBitVar { + unsigned index; /// Identifier for the variable within the expression + bool operator==(const ClBitVar& other) const { return index == other.index; } + friend std::ostream& operator<<(std::ostream& os, const ClBitVar& var); +} ClBitVar; + +NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(ClBitVar, index) + +/** + * A register variable within an expression + */ +typedef struct ClRegVar { + unsigned index; /// Identifier for the variable within the expression + bool operator==(const ClRegVar& other) const { return index == other.index; } + friend std::ostream& operator<<(std::ostream& os, const ClRegVar& var); +} ClRegVar; + +NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(ClRegVar, index) + +/** + * A (bit or register) variable within an expression + */ +typedef std::variant ClExprVar; + +std::ostream& operator<<(std::ostream& os, const ClExprVar& var); + +void to_json(nlohmann::json& j, const ClExprVar& var); + +void from_json(const nlohmann::json& j, ClExprVar& var); + +/** + * A term in a classical expression (either a constant or a variable) + */ +typedef std::variant ClExprTerm; + +std::ostream& operator<<(std::ostream& os, const ClExprTerm& term); + +void to_json(nlohmann::json& j, const ClExprTerm& term); + +void from_json(const nlohmann::json& j, ClExprTerm& term); + +class ClExpr; + +/** + * An argument to a classical operation in an expression + */ +typedef std::variant ClExprArg; + +std::ostream& operator<<(std::ostream& os, const ClExprArg& arg); + +void to_json(nlohmann::json& j, const ClExprArg& arg); + +void from_json(const nlohmann::json& j, ClExprArg& arg); + +/** + * A classical expression + * + * It may be composed of subexpressions. + */ +class ClExpr { + public: + /** + * Default constructor + */ + ClExpr(); + + /** + * Construct a classical expression from an operation and its arguments + * + * @param op Operation + * @param args Arguments + */ + ClExpr(ClOp op, std::vector args); + + bool operator==(const ClExpr& other) const; + + friend std::ostream& operator<<(std::ostream& os, const ClExpr& expr); + + /** + * Main operation + */ + ClOp get_op() const; + + /** + * Arguments + */ + std::vector get_args() const; + + /** + * All bit variables occurring within the expression + */ + std::set all_bit_variables() const; + + /** + * All register variables occurring within the expression + */ + std::set all_reg_variables() const; + + private: + ClOp op; + std::vector args; + std::set all_bit_vars; + std::set all_reg_vars; +}; + +void to_json(nlohmann::json& j, const ClExpr& expr); + +void from_json(const nlohmann::json& j, ClExpr& expr); + +/** + * A classical expression defined over a sequence of bits + * + * This defines an operation on a finite number of bits. Bit variables within + * the expression are mapped to specific bit indices and register variables are + * mapped to specific (disjoint) sequences of bit indices. The output of the + * expression is also mapped to a specific bit index or sequence of bit indices. + * If the output is a register, it must either be disjoint from all of the input + * registers or exactly match one of them. + */ +class WiredClExpr { + public: + /** + * Default constructor + */ + WiredClExpr(); + + /** + * Construct by specifying the bit, register and output positions + * + * @param expr Expression + * @param bit_posn Map from identifiers of bit variables to bit positions + * @param reg_posn Map from identifiers of register variables to sequences of + * bit positions. + * @param output_posn Sequence of bit positions for the output + * @throws ClExprWiringError if wiring is not valid + */ + WiredClExpr( + const ClExpr& expr, const std::map& bit_posn, + const std::map>& reg_posn, + const std::vector output_posn); + + bool operator==(const WiredClExpr& other) const; + + friend std::ostream& operator<<(std::ostream& os, const WiredClExpr& expr); + + /** + * Expression + */ + ClExpr get_expr() const; + + /** + * Bit positions + */ + std::map get_bit_posn() const; + + /** + * Register positions + */ + std::map> get_reg_posn() const; + + /** + * Output positions + */ + std::vector get_output_posn() const; + + /** + * Total number of bits including bit and register inputs and output + */ + unsigned get_total_n_bits() const; + + private: + ClExpr expr; + std::map bit_posn; + std::map> reg_posn; + std::set all_bit_posns; + std::set> all_reg_posns; + std::vector output_posn; + unsigned total_n_bits; +}; + +void to_json(nlohmann::json& j, const WiredClExpr& expr); + +void from_json(const nlohmann::json& j, WiredClExpr& expr); + +class ClExprWiringError : public std::logic_error { + public: + explicit ClExprWiringError(const std::string& message) + : std::logic_error(message) {} +}; + +class ClExprOp : public Op { + public: + ClExprOp(const WiredClExpr& expr); + + Op_ptr symbol_substitution( + const SymEngine::map_basic_basic& sub_map) const override; + SymSet free_symbols() const override; + op_signature_t get_signature() const override; + + /** + * Wired classical expression + */ + WiredClExpr get_wired_expr() const; + + nlohmann::json serialize() const override; + + static Op_ptr deserialize(const nlohmann::json& j); + + private: + WiredClExpr expr; +}; + +} // namespace tket diff --git a/tket/src/Circuit/OpJson.cpp b/tket/src/Circuit/OpJson.cpp index 4774aecc9a..9a56098305 100644 --- a/tket/src/Circuit/OpJson.cpp +++ b/tket/src/Circuit/OpJson.cpp @@ -18,6 +18,7 @@ #include "tket/OpType/OpType.hpp" #include "tket/OpType/OpTypeFunctions.hpp" #include "tket/Ops/BarrierOp.hpp" +#include "tket/Ops/ClExpr.hpp" #include "tket/Ops/ClassicalOps.hpp" #include "tket/Ops/OpPtr.hpp" #include "tket/Utils/Json.hpp" @@ -34,6 +35,8 @@ void from_json(const nlohmann::json& j, Op_ptr& op) { op = Conditional::deserialize(j); } else if (optype == OpType::WASM) { op = WASMOp::deserialize(j); + } else if (optype == OpType::ClExpr) { + op = ClExprOp::deserialize(j); } else if (is_classical_type(optype)) { op = ClassicalOp::deserialize(j); } else if (is_gate_type(optype)) { diff --git a/tket/src/OpType/OpTypeFunctions.cpp b/tket/src/OpType/OpTypeFunctions.cpp index 5a564c3171..0c2a0c98a6 100644 --- a/tket/src/OpType/OpTypeFunctions.cpp +++ b/tket/src/OpType/OpTypeFunctions.cpp @@ -108,7 +108,7 @@ const OpTypeSet& all_classical_types() { OpType::CopyBits, OpType::RangePredicate, OpType::ExplicitPredicate, OpType::ExplicitModifier, OpType::MultiBit, OpType::WASM, - OpType::ClassicalExpBox, + OpType::ClassicalExpBox, OpType::ClExpr, }; static std::unique_ptr gates = std::make_unique(optypes); diff --git a/tket/src/OpType/OpTypeInfo.cpp b/tket/src/OpType/OpTypeInfo.cpp index 19e292bb70..b708903fa0 100644 --- a/tket/src/OpType/OpTypeInfo.cpp +++ b/tket/src/OpType/OpTypeInfo.cpp @@ -15,6 +15,7 @@ #include "tket/OpType/OpTypeInfo.hpp" #include +#include #include "tket/OpType/OpType.hpp" @@ -180,7 +181,8 @@ const std::map& optypeinfo() { {"ClassicalExpBox", "ClassicalExpBox", {}, std::nullopt}}, {OpType::MultiBit, {"MultiBit", "MultiBit", {}, std::nullopt}}, {OpType::UnitaryTableauBox, - {"UnitaryTableauBox", "UnitaryTableauBox", {}, std::nullopt}}}; + {"UnitaryTableauBox", "UnitaryTableauBox", {}, std::nullopt}}, + {OpType::ClExpr, {"ClExpr", "ClExpr", {}, std::nullopt}}}; static std::unique_ptr> opinfo = std::make_unique>(typeinfo); return *opinfo; diff --git a/tket/src/Ops/ClExpr.cpp b/tket/src/Ops/ClExpr.cpp new file mode 100644 index 0000000000..0e0f1fdd5a --- /dev/null +++ b/tket/src/Ops/ClExpr.cpp @@ -0,0 +1,445 @@ +// Copyright 2019-2024 Cambridge Quantum Computing +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tket/Ops/ClExpr.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include "tket/OpType/OpType.hpp" + +namespace tket { + +std::ostream& operator<<(std::ostream& os, ClOp fn) { + switch (fn) { + case ClOp::INVALID: + return os << "INVALID"; + case ClOp::BitAnd: + return os << "and"; + case ClOp::BitOr: + return os << "or"; + case ClOp::BitXor: + return os << "xor"; + case ClOp::BitEq: + return os << "eq"; + case ClOp::BitNeq: + return os << "neq"; + case ClOp::BitNot: + return os << "not"; + case ClOp::BitZero: + return os << "zero"; + case ClOp::BitOne: + return os << "one"; + case ClOp::RegAnd: + return os << "and"; + case ClOp::RegOr: + return os << "or"; + case ClOp::RegXor: + return os << "xor"; + case ClOp::RegEq: + return os << "eq"; + case ClOp::RegNeq: + return os << "neq"; + case ClOp::RegNot: + return os << "not"; + case ClOp::RegZero: + return os << "zero"; + case ClOp::RegOne: + return os << "one"; + case ClOp::RegLt: + return os << "lt"; + case ClOp::RegGt: + return os << "gt"; + case ClOp::RegLeq: + return os << "leq"; + case ClOp::RegGeq: + return os << "geq"; + case ClOp::RegAdd: + return os << "add"; + case ClOp::RegSub: + return os << "sub"; + case ClOp::RegMul: + return os << "mul"; + case ClOp::RegDiv: + return os << "div"; + case ClOp::RegPow: + return os << "pow"; + case ClOp::RegLsh: + return os << "lsh"; + case ClOp::RegRsh: + return os << "rsh"; + case ClOp::RegNeg: + return os << "neg"; + } + throw std::logic_error("Invalid data"); +} + +std::ostream& operator<<(std::ostream& os, const ClBitVar& var) { + return os << "b" << var.index; +} + +std::ostream& operator<<(std::ostream& os, const ClRegVar& var) { + return os << "r" << var.index; +} + +std::ostream& operator<<(std::ostream& os, const ClExprVar& var) { + if (const ClBitVar* bvar = std::get_if(&var)) { + return os << *bvar; + } else { + ClRegVar rvar = std::get(var); + return os << rvar; + } +} + +void to_json(nlohmann::json& j, const ClExprVar& var) { + nlohmann::json inner_j; + if (const ClBitVar* bvar = std::get_if(&var)) { + j["type"] = "bit"; + to_json(inner_j, *bvar); + } else { + j["type"] = "reg"; + ClRegVar rvar = std::get(var); + to_json(inner_j, rvar); + } + j["var"] = inner_j; +} + +void from_json(const nlohmann::json& j, ClExprVar& var) { + const std::string vartype = j.at("type").get(); + if (vartype == "bit") { + var = j.at("var").get(); + } else { + TKET_ASSERT(vartype == "reg"); + var = j.at("var").get(); + } +} + +std::ostream& operator<<(std::ostream& os, const ClExprTerm& term) { + if (const int* n = std::get_if(&term)) { + return os << *n; + } else { + ClExprVar var = std::get(term); + return os << var; + } +} + +void to_json(nlohmann::json& j, const ClExprTerm& term) { + nlohmann::json inner_j; + if (const int* n = std::get_if(&term)) { + j["type"] = "int"; + inner_j = *n; + } else { + j["type"] = "var"; + ClExprVar var = std::get(term); + to_json(inner_j, var); + } + j["term"] = inner_j; +} + +void from_json(const nlohmann::json& j, ClExprTerm& term) { + const std::string termtype = j.at("type").get(); + if (termtype == "int") { + term = j.at("term").get(); + } else { + TKET_ASSERT(termtype == "var"); + term = j.at("term").get(); + } +} + +std::ostream& operator<<(std::ostream& os, const ClExprArg& arg) { + if (const ClExprTerm* term = std::get_if(&arg)) { + return os << *term; + } else { + ClExpr expr = std::get(arg); + return os << expr; + } +} + +void to_json(nlohmann::json& j, const ClExprArg& arg) { + nlohmann::json inner_j; + if (const ClExprTerm* term = std::get_if(&arg)) { + j["type"] = "term"; + to_json(inner_j, *term); + } else { + j["type"] = "expr"; + ClExpr expr = std::get(arg); + to_json(inner_j, expr); + } + j["input"] = inner_j; +} + +void from_json(const nlohmann::json& j, ClExprArg& arg) { + const std::string inputtype = j.at("type").get(); + if (inputtype == "term") { + arg = j.at("input").get(); + } else { + TKET_ASSERT(inputtype == "expr"); + ClExpr expr; + from_json(j.at("input"), expr); + arg = expr; + } +} + +ClExpr::ClExpr() : ClExpr(ClOp::INVALID, {}) {} + +ClExpr::ClExpr(ClOp op, std::vector args) + : op(op), args(args), all_bit_vars(), all_reg_vars() { + for (const ClExprArg& input : args) { + if (std::holds_alternative(input)) { + ClExprTerm basic_input = std::get(input); + if (std::holds_alternative(basic_input)) { + ClExprVar var = std::get(basic_input); + if (std::holds_alternative(var)) { + ClBitVar bit_var = std::get(var); + all_bit_vars.insert(bit_var.index); + } else { + ClRegVar reg_var = std::get(var); + all_reg_vars.insert(reg_var.index); + } + } + } else { + ClExpr expr = std::get(input); + std::set expr_bit_vars = expr.all_bit_variables(); + std::set expr_reg_vars = expr.all_reg_variables(); + all_bit_vars.insert(expr_bit_vars.begin(), expr_bit_vars.end()); + all_reg_vars.insert(expr_reg_vars.begin(), expr_reg_vars.end()); + } + } +} + +bool ClExpr::operator==(const ClExpr& other) const { + return op == other.op && args == other.args; +} + +std::ostream& operator<<(std::ostream& os, const ClExpr& expr) { + os << expr.get_op() << "("; + const std::vector& args = expr.get_args(); + unsigned n_args = args.size(); + for (unsigned i = 0; i < n_args; i++) { + os << args[i]; + if (i + 1 < n_args) { + os << ", "; + } + } + os << ")"; + return os; +} + +ClOp ClExpr::get_op() const { return op; } + +std::vector ClExpr::get_args() const { return args; } + +std::set ClExpr::all_bit_variables() const { return all_bit_vars; } + +std::set ClExpr::all_reg_variables() const { return all_reg_vars; } + +void to_json(nlohmann::json& j, const ClExpr& expr) { + nlohmann::json j_op = expr.get_op(); + nlohmann::json j_args = expr.get_args(); + j["op"] = j_op; + j["args"] = j_args; +} + +void from_json(const nlohmann::json& j, ClExpr& expr) { + ClOp op = j.at("op").get(); + std::vector args = j.at("args").get>(); + expr = ClExpr(op, args); +} + +WiredClExpr::WiredClExpr() : WiredClExpr({}, {}, {}, {}) {} + +WiredClExpr::WiredClExpr( + const ClExpr& expr, const std::map& bit_posn, + const std::map>& reg_posn, + const std::vector output_posn) + : expr(expr), + bit_posn(bit_posn), + reg_posn(reg_posn), + output_posn(output_posn) { + std::set b; + std::set r; + std::set posns; + for (const auto& pair : bit_posn) { + b.insert(pair.first); + unsigned bit_pos = pair.second; + if (posns.contains(bit_pos)) { + throw ClExprWiringError("Invalid maps constructing WiredClExpr"); + } + posns.insert(bit_pos); + all_bit_posns.insert(bit_pos); + } + for (const auto& pair : reg_posn) { + r.insert(pair.first); + for (unsigned bit_pos : pair.second) { + if (posns.contains(bit_pos)) { + throw ClExprWiringError("Invalid maps constructing WiredClExpr"); + } + posns.insert(bit_pos); + } + all_reg_posns.insert(pair.second); + } + total_n_bits = posns.size(); + for (const unsigned& posn : output_posn) { + if (!posns.contains(posn)) { + total_n_bits++; + } + } + if (output_posn.size() == 1) { + // It mustn't be one of an input register of size > 1 + unsigned i = output_posn[0]; + for (const std::vector& reg : all_reg_posns) { + if (reg.size() > 1 && std::any_of( + reg.begin(), reg.end(), + [&i](const unsigned& j) { return i == j; })) { + throw ClExprWiringError( + "Output bit contained in a larger input register"); + } + } + } else { + // It must either be disjoint from everything or match one of the registers + if (std::any_of( + output_posn.begin(), output_posn.end(), + [&posns](const unsigned& j) { return posns.contains(j); })) { + if (!std::any_of( + all_reg_posns.begin(), all_reg_posns.end(), + [&output_posn](const std::vector& reg) { + return output_posn == reg; + })) { + throw ClExprWiringError("Output register inconsistent with inputs"); + } + } + } + if (b != expr.all_bit_variables()) { + throw ClExprWiringError( + "Mismatch of bit variables constructing WiredClExpr"); + } + if (r != expr.all_reg_variables()) { + throw ClExprWiringError( + "Mismatch of register variables constructing WiredClExpr"); + } +} + +bool WiredClExpr::operator==(const WiredClExpr& other) const { + return expr == other.expr && bit_posn == other.bit_posn && + reg_posn == other.reg_posn && output_posn == other.output_posn; +} + +std::ostream& operator<<(std::ostream& os, const WiredClExpr& expr) { + os << expr.expr << " ["; + unsigned n_vars = expr.bit_posn.size() + expr.reg_posn.size(); + unsigned i = 0; + for (const std::pair pair : expr.bit_posn) { + os << "b" << pair.first << ":" << pair.second; + i++; + if (i < n_vars) { + os << ", "; + } + } + for (const std::pair> pair : expr.reg_posn) { + os << "r" << pair.first << ":("; + unsigned reg_size = pair.second.size(); + for (unsigned j = 0; j < reg_size; j++) { + os << pair.second[j]; + if (j + 1 < reg_size) { + os << ","; + } + } + os << ")"; + i++; + if (i < n_vars) { + os << ", "; + } + } + os << " --> ("; + unsigned n_outs = expr.output_posn.size(); + for (unsigned i = 0; i < n_outs; i++) { + os << expr.output_posn[i]; + if (i + 1 < n_outs) { + os << ","; + } + } + os << ")]"; + return os; +} + +ClExpr WiredClExpr::get_expr() const { return expr; } + +std::map WiredClExpr::get_bit_posn() const { + return bit_posn; +} + +std::map> WiredClExpr::get_reg_posn() const { + return reg_posn; +} + +std::vector WiredClExpr::get_output_posn() const { + return output_posn; +} + +unsigned WiredClExpr::get_total_n_bits() const { return total_n_bits; } + +void to_json(nlohmann::json& j, const WiredClExpr& expr) { + nlohmann::json j_expr = expr.get_expr(); + nlohmann::json j_bit_posn = expr.get_bit_posn(); + nlohmann::json j_reg_posn = expr.get_reg_posn(); + nlohmann::json j_output_posn = expr.get_output_posn(); + j["expr"] = j_expr; + j["bit_posn"] = j_bit_posn; + j["reg_posn"] = j_reg_posn; + j["output_posn"] = j_output_posn; +} + +void from_json(const nlohmann::json& j, WiredClExpr& expr) { + ClExpr e = j.at("expr").get(); + std::map bit_posn = + j.at("bit_posn").get>(); + std::map> reg_posn = + j.at("reg_posn").get>>(); + std::vector output_posn = + j.at("output_posn").get>(); + expr = WiredClExpr(e, bit_posn, reg_posn, output_posn); +} + +ClExprOp::ClExprOp(const WiredClExpr& expr) : Op(OpType::ClExpr), expr(expr) {} + +Op_ptr ClExprOp::symbol_substitution(const SymEngine::map_basic_basic&) const { + return std::make_shared(*this); +} + +SymSet ClExprOp::free_symbols() const { return SymSet(); } + +op_signature_t ClExprOp::get_signature() const { + return op_signature_t(expr.get_total_n_bits(), EdgeType::Classical); +} + +WiredClExpr ClExprOp::get_wired_expr() const { return expr; } + +nlohmann::json ClExprOp::serialize() const { + nlohmann::json j; + j["type"] = get_type(); + j["expr"] = get_wired_expr(); + return j; +} + +Op_ptr ClExprOp::deserialize(const nlohmann::json& j) { + ClExprOp exprop{j.at("expr").get()}; + return std::make_shared(exprop); +} + +} // namespace tket diff --git a/tket/test/CMakeLists.txt b/tket/test/CMakeLists.txt index f7d895fb00..4e020c4d16 100644 --- a/tket/test/CMakeLists.txt +++ b/tket/test/CMakeLists.txt @@ -107,6 +107,7 @@ add_executable(test-tket src/test_Assertion.cpp src/test_BoxDecompRoutingMethod.cpp src/test_ChoiMixTableau.cpp + src/test_ClExpr.cpp src/test_Clifford.cpp src/test_Combinators.cpp src/test_CompilerPass.cpp diff --git a/tket/test/src/test_ClExpr.cpp b/tket/test/src/test_ClExpr.cpp new file mode 100644 index 0000000000..a49888ce5e --- /dev/null +++ b/tket/test/src/test_ClExpr.cpp @@ -0,0 +1,217 @@ +// Copyright 2019-2024 Cambridge Quantum Computing +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "tket/Circuit/Circuit.hpp" +#include "tket/Circuit/Command.hpp" +#include "tket/OpType/EdgeType.hpp" +#include "tket/Ops/ClExpr.hpp" +#include "tket/Ops/OpPtr.hpp" +#include "tket/Utils/UnitID.hpp" + +namespace tket { + +SCENARIO("Circuit containing a ClExprOp") { + GIVEN("A simple classical expression") { + // AND of two bits: + ClExpr expr(ClOp::BitAnd, {ClBitVar{0}, ClBitVar{1}}); + // First two bits are inputs; last bit is output: + WiredClExpr wexpr(expr, {{0, 0}, {1, 1}}, {}, {2}); + Op_ptr op = std::make_shared(wexpr); + REQUIRE(op->get_signature() == op_signature_t(3, EdgeType::Classical)); + Circuit circ(0, 3); + circ.add_op(op, {0, 1, 2}); + std::vector cmds = circ.get_commands(); + REQUIRE(cmds.size() == 1); + } + GIVEN("A complicated classical expression") { + // d[0,1,2] <-- (a[2,1,0] + b[2,3,4]) / (c[1,0,3] * d[0,1,2]) + ClExpr numer(ClOp::RegAdd, {ClRegVar{0}, ClRegVar{1}}); + ClExpr denom(ClOp::RegMul, {ClRegVar{2}, ClRegVar{3}}); + ClExpr expr(ClOp::RegDiv, {numer, denom}); + std::vector a_pos{0, 3, 4}; + std::vector b_pos{1, 11, 5}; + std::vector c_pos{10, 2, 7}; + std::vector d_pos{8, 9, 6}; + WiredClExpr wexpr( + expr, {}, {{0, a_pos}, {1, b_pos}, {2, c_pos}, {3, d_pos}}, d_pos); + std::vector e_pos{0, 1, 2}; + REQUIRE_THROWS_AS( + WiredClExpr( + expr, {}, {{0, a_pos}, {1, b_pos}, {2, e_pos}, {3, d_pos}}, d_pos), + ClExprWiringError); + Op_ptr op = std::make_shared(wexpr); + Circuit circ; + register_t preg = circ.add_c_register("p", 6); + register_t qreg = circ.add_c_register("q", 6); + circ.add_op( + op, {Bit{"p", 2}, Bit{"q", 2}, Bit{"p", 1}, Bit{"q", 3}, Bit{"p", 0}, + Bit{"q", 4}, Bit{"p", 5}, Bit{"q", 5}, Bit{"p", 4}, Bit{"q", 0}, + Bit{"p", 3}, Bit{"q", 1}}); + std::vector cmds = circ.get_commands(); + REQUIRE(cmds.size() == 1); + } +} + +SCENARIO("Serialization and stringification") { + GIVEN("ClOp") { + ClOp op = ClOp::RegEq; + std::stringstream ss; + ss << op; + REQUIRE(ss.str() == "eq"); + nlohmann::json j = op; + ClOp op1 = j.get(); + REQUIRE(op1 == op); + } + GIVEN("All ClOps") { + std::stringstream ss; + ss << ClOp::INVALID << " " << ClOp::BitAnd << " " << ClOp::BitOr << " " + << ClOp::BitXor << " " << ClOp::BitEq << " " << ClOp::BitNeq << " " + << ClOp::BitNot << " " << ClOp::BitZero << " " << ClOp::BitOne << " " + << ClOp::RegAnd << " " << ClOp::RegOr << " " << ClOp::RegXor << " " + << ClOp::RegEq << " " << ClOp::RegNeq << " " << ClOp::RegNot << " " + << ClOp::RegZero << " " << ClOp::RegOne << " " << ClOp::RegLt << " " + << ClOp::RegGt << " " << ClOp::RegLeq << " " << ClOp::RegGeq << " " + << ClOp::RegAdd << " " << ClOp::RegSub << " " << ClOp::RegMul << " " + << ClOp::RegDiv << " " << ClOp::RegPow << " " << ClOp::RegLsh << " " + << ClOp::RegRsh << " " << ClOp::RegNeg; + REQUIRE( + ss.str() == + "INVALID and or xor eq neq not zero one and or xor eq neq not zero one " + "lt gt leq geq add sub mul div pow lsh rsh neg"); + } + GIVEN("ClBitVar") { + ClBitVar var{3}; + std::stringstream ss; + ss << var; + REQUIRE(ss.str() == "b3"); + nlohmann::json j = var; + ClBitVar var1 = j.get(); + REQUIRE(var1 == var); + } + GIVEN("ClRegVar") { + ClRegVar var{4}; + std::stringstream ss; + ss << var; + REQUIRE(ss.str() == "r4"); + nlohmann::json j = var; + ClRegVar var1 = j.get(); + REQUIRE(var1 == var); + } + GIVEN("ClExprVar") { + ClExprVar var_bit = ClBitVar{3}; + ClExprVar var_reg = ClRegVar{4}; + std::stringstream ss; + ss << var_bit << ", " << var_reg; + REQUIRE(ss.str() == "b3, r4"); + nlohmann::json j_bit = var_bit; + nlohmann::json j_reg = var_reg; + ClExprVar var_bit1 = j_bit.get(); + ClExprVar var_reg1 = j_reg.get(); + REQUIRE(var_bit1 == var_bit); + REQUIRE(var_reg1 == var_reg); + } + GIVEN("ClExprTerm") { + ClExprTerm term_int = 7; + ClExprTerm term_var = ClRegVar{5}; + std::stringstream ss; + ss << term_int << ", " << term_var; + REQUIRE(ss.str() == "7, r5"); + nlohmann::json j_int = term_int; + nlohmann::json j_var = term_var; + ClExprTerm term_int1 = j_int.get(); + ClExprTerm term_var1 = j_var.get(); + REQUIRE(term_int1 == term_int); + REQUIRE(term_var1 == term_var); + } + GIVEN("Vector of ClExprArg (1)") { + std::vector args{ClRegVar{2}, int{3}}; + nlohmann::json j = args; + std::vector args1 = j.get>(); + REQUIRE(args == args1); + } + GIVEN("ClExpr (1)") { + // r0 + 7 + ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{7}}); + std::stringstream ss; + ss << expr; + REQUIRE(ss.str() == "add(r0, 7)"); + nlohmann::json j = expr; + ClExpr expr1 = j.get(); + REQUIRE(expr1 == expr); + } + GIVEN("Vector of ClExprArg (2)") { + ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{8}}); + std::vector args{expr}; + nlohmann::json j = args; + std::vector args1 = j.get>(); + REQUIRE(args == args1); + } + GIVEN("ClExpr (2)") { + // (r0 + r1) / (r2 * 3) + ClExpr numer(ClOp::RegAdd, {ClRegVar{0}, ClRegVar{1}}); + ClExpr denom(ClOp::RegMul, {ClRegVar{2}, int{3}}); + ClExpr expr(ClOp::RegDiv, {numer, denom}); + std::stringstream ss; + ss << expr; + REQUIRE(ss.str() == "div(add(r0, r1), mul(r2, 3))"); + nlohmann::json j = expr; + ClExpr expr1 = j.get(); + REQUIRE(expr1 == expr); + } + GIVEN("WiredClExpr") { + ClExpr numer(ClOp::RegAdd, {ClRegVar{0}, ClRegVar{1}}); + ClExpr denom(ClOp::RegMul, {ClRegVar{2}, ClRegVar{3}}); + ClExpr expr(ClOp::RegDiv, {numer, denom}); + std::vector a_pos{0, 3, 4}; + std::vector b_pos{1, 11, 5}; + std::vector c_pos{10, 2, 7}; + std::vector d_pos{8, 9, 6}; + WiredClExpr wexpr( + expr, {}, {{0, a_pos}, {1, b_pos}, {2, c_pos}, {3, d_pos}}, d_pos); + std::stringstream ss; + ss << wexpr; + REQUIRE( + ss.str() == + "div(add(r0, r1), mul(r2, r3)) [r0:(0,3,4), r1:(1,11,5), r2:(10,2,7), " + "r3:(8,9,6) --> (8,9,6)]"); + nlohmann::json j = wexpr; + WiredClExpr wexpr1 = j.get(); + REQUIRE(wexpr1 == wexpr); + } + GIVEN("ClExprOp") { + ClExpr numer(ClOp::RegAdd, {ClRegVar{0}, ClRegVar{1}}); + ClExpr denom(ClOp::RegMul, {ClRegVar{2}, ClRegVar{3}}); + ClExpr expr(ClOp::RegDiv, {numer, denom}); + std::vector a_pos{0, 3, 4}; + std::vector b_pos{1, 11, 5}; + std::vector c_pos{10, 2, 7}; + std::vector d_pos{8, 9, 6}; + WiredClExpr wexpr( + expr, {}, {{0, a_pos}, {1, b_pos}, {2, c_pos}, {3, d_pos}}, d_pos); + Op_ptr op = std::make_shared(wexpr); + nlohmann::json j = op; + Op_ptr op1 = j.get(); + const ClExprOp& exprop = static_cast(*op1); + REQUIRE(exprop.get_wired_expr() == wexpr); + Op_ptr op2 = op->symbol_substitution({}); + REQUIRE(op2->free_symbols().empty()); + } +} + +} // namespace tket