diff --git a/qiskit/circuit/__init__.py b/qiskit/circuit/__init__.py index bf51754a458e..52848e2e77b1 100644 --- a/qiskit/circuit/__init__.py +++ b/qiskit/circuit/__init__.py @@ -290,9 +290,16 @@ IfElseOp WhileLoopOp ForLoopOp + SwitchCaseOp BreakLoopOp ContinueLoopOp +The :class:`.SwitchCaseOp` also understands a special value: + +.. py:data: CASE_DEFAULT + Used as a possible "label" in a :class:`.SwitchCaseOp` to represent the default case. This will + always match, if it is tried. + Parametric Quantum Circuits --------------------------- @@ -340,6 +347,8 @@ WhileLoopOp, ForLoopOp, IfElseOp, + SwitchCaseOp, + CASE_DEFAULT, BreakLoopOp, ContinueLoopOp, ) diff --git a/qiskit/circuit/controlflow/__init__.py b/qiskit/circuit/controlflow/__init__.py index 89483aacf21b..60df9c2a370a 100644 --- a/qiskit/circuit/controlflow/__init__.py +++ b/qiskit/circuit/controlflow/__init__.py @@ -20,3 +20,4 @@ from .if_else import IfElseOp from .while_loop import WhileLoopOp from .for_loop import ForLoopOp +from .switch_case import SwitchCaseOp, CASE_DEFAULT diff --git a/qiskit/circuit/controlflow/switch_case.py b/qiskit/circuit/controlflow/switch_case.py new file mode 100644 index 000000000000..7eec5474bec8 --- /dev/null +++ b/qiskit/circuit/controlflow/switch_case.py @@ -0,0 +1,170 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Circuit operation representing an ``switch/case`` statement.""" + +__all__ = ("SwitchCaseOp", "CASE_DEFAULT") + +from typing import Union, Iterable, Any, Tuple, Optional, List, Literal + +from qiskit.circuit import ClassicalRegister, Clbit, QuantumCircuit +from qiskit.circuit.exceptions import CircuitError + +from .control_flow import ControlFlowOp + + +class _DefaultCaseType: + """The type of the default-case singleton. This is used instead of just having + ``CASE_DEFAULT = object()`` so we can set the pretty-printing properties, which are class-level + only.""" + + def __repr__(self): + return "" + + +CASE_DEFAULT = _DefaultCaseType() +"""A special object that represents the "default" case of a switch statement. If you use this +as a case target, it must be the last case, and will match anything that wasn't already matched. +When using the builder interface of :meth:`.QuantumCircuit.switch`, this can also be accessed as the +``DEFAULT`` attribute of the bound case-builder object. +""" + + +class SwitchCaseOp(ControlFlowOp): + """A circuit operation that executes one particular circuit block based on matching a given + ``target`` against an ordered list of ``values``. The special value :data:`.CASE_DEFAULT` can + be used to represent a default condition. + + This is the low-level interface for creating a switch-case statement; in general, the circuit + method :meth:`.QuantumCircuit.switch_case` should be used as a context manager to access the + builder interface. At the low level, you must ensure that all the circuit blocks contain equal + numbers of qubits and clbits, and that the order the virtual bits of the containing circuit + should be bound is the same for all blocks. This will likely mean that each circuit block is + wider than its natural width, as each block must span the space covered by _any_ of the blocks. + + Args: + target: the runtime value to switch on. + cases: an ordered iterable of the corresponding value of the ``target`` and the circuit + block that should be executed if this is matched. There is no fall-through between + blocks, and the order matters. + """ + + def __init__( + self, + target: Union[Clbit, ClassicalRegister], + cases: Iterable[Tuple[Any, QuantumCircuit]], + *, + label: Optional[str] = None, + ): + if not isinstance(target, (Clbit, ClassicalRegister)): + raise CircuitError("the switch target must be a classical bit or register") + + target_bits = 1 if isinstance(target, Clbit) else len(target) + target_max = (1 << target_bits) - 1 + + case_ids = set() + num_qubits, num_clbits = None, None + self.target = target + self._case_map = {} + """Mapping of individual jump values to block indices. This level of indirection is to let + us more easily track the case of multiple labels pointing to the same circuit object, so + it's easier for things like `assign_parameters`, which need to touch each circuit object + exactly once, to function.""" + self._label_spec: List[Tuple[Union[int, Literal[CASE_DEFAULT]], ...]] = [] + """List of the normalised jump value specifiers. This is a list of tuples, where each tuple + contains the values, and the indexing is the same as the values of `_case_map` and + `_params`.""" + self._params = [] + """List of the circuit bodies used. This form makes it simpler for things like + :meth:`.replace_blocks` and :class:`.QuantumCircuit.assign_parameters` to do their jobs + without accidentally mutating the same circuit instance more than once.""" + for i, (value_spec, case_) in enumerate(cases): + values = tuple(value_spec) if isinstance(value_spec, (tuple, list)) else (value_spec,) + for value in values: + if value in self._case_map: + raise CircuitError(f"duplicate case value {value}") + if CASE_DEFAULT in self._case_map: + raise CircuitError("cases after the default are unreachable") + if value is not CASE_DEFAULT: + if not isinstance(value, int) or value < 0: + raise CircuitError("case values must be Booleans or non-negative integers") + if value > target_max: + raise CircuitError( + f"switch target '{target}' has {target_bits} bit(s) of precision," + f" but case {value} is larger than the maximum of {target_max}." + ) + self._case_map[value] = i + self._label_spec.append(values) + if not isinstance(case_, QuantumCircuit): + raise CircuitError("case blocks must be QuantumCircuit instances") + if id(case_) in case_ids: + raise CircuitError("separate cases cannot point to the same block") + case_ids.add(id(case_)) + if num_qubits is None: + num_qubits, num_clbits = case_.num_qubits, case_.num_clbits + if case_.num_qubits != num_qubits or case_.num_clbits != num_clbits: + raise CircuitError("incompatible bits between cases") + self._params.append(case_) + if not self._params: + # This condition also implies that `num_qubits` and `num_clbits` must be non-None. + raise CircuitError("must have at least one case to run") + + super().__init__("switch_case", num_qubits, num_clbits, self._params, label=label) + + def __eq__(self, other): + # The general __eq__ will compare the blocks in the right order, so we just need to ensure + # that all the labels point the right way as well. + return super().__eq__(other) and all( + set(labels_self) == set(labels_other) + for (labels_self, _), (labels_other, _) in zip(self._label_spec, other._label_spec) + ) + + def cases_specifier(self) -> Iterable[Tuple[Tuple, QuantumCircuit]]: + """Return an iterable where each element is a 2-tuple whose first element is a tuple of the + all the jump values that are associated with the circuit block in the second element. + + This is an abstract specification of the jump table suitable for creating new + :class:`.SwitchCaseOp` instances. + + .. seealso:: + :meth:`.SwitchCaseOp.cases` + Create a lookup table that you can use for your own purposes to jump from values to + the circuit that would be executed.""" + return zip(self._label_spec, self._params) + + def cases(self): + """Return a lookup table from case labels to the circuit that would be executed in that + case. This object is not generally suitable for creating a new :class:`.SwitchCaseOp` + because any keys that point to the same object will not be grouped. + + .. seealso:: + :meth:`.SwitchCaseOp.cases_specifier` + An alternate method that produces its output in a suitable format for creating new + :class:`.SwitchCaseOp` instances. + """ + return {key: self._params[index] for key, index in self._case_map.items()} + + @property + def blocks(self): + return tuple(self._params) + + def replace_blocks(self, blocks: Iterable[QuantumCircuit]) -> "SwitchCaseOp": + blocks = tuple(blocks) + if len(blocks) != len(self._params): + raise CircuitError(f"needed {len(self._case_map)} blocks but received {len(blocks)}") + return SwitchCaseOp(self.target, zip(self._label_spec, blocks)) + + def c_if(self, classical, val): + raise NotImplementedError( + "SwitchCaseOp cannot be classically controlled through Instruction.c_if. " + "Please nest it in an IfElseOp instead." + ) diff --git a/qiskit/circuit/quantumcircuit.py b/qiskit/circuit/quantumcircuit.py index 7c304f63af4a..7a032025e703 100644 --- a/qiskit/circuit/quantumcircuit.py +++ b/qiskit/circuit/quantumcircuit.py @@ -4591,6 +4591,47 @@ def if_else( condition = (self._resolve_classical_resource(condition[0]), condition[1]) return self.append(IfElseOp(condition, true_body, false_body, label), qubits, clbits) + def switch( + self, + target: Union[ClbitSpecifier, ClassicalRegister], + cases: Iterable[Tuple[typing.Any, QuantumCircuit]], + qubits: Sequence[QubitSpecifier], + clbits: Sequence[ClbitSpecifier], + *, + label: Optional[str] = None, + ) -> InstructionSet: + """Create a ``switch``/``case`` structure on this circuit. + + There are two forms for calling this function. If called with all its arguments (with the + possible exception of ``label``), it will create a + :class:`.SwitchCaseOp` with the given case structure. If ``cases`` (and + ``qubits`` and ``clbits``) are *not* passed, then this acts as a context manager, which + will automatically build a :class:`.SwitchCaseOp` when the scope finishes. In this form, + you do not need to keep track of the qubits or clbits you are using, because the scope will + handle it for you. + + Args: + target (Union[ClassicalRegister, Clbit]): The classical value to switch one. This must + be integer valued. + cases (Iterable[Tuple[typing.Any, QuantumCircuit]]): A sequence of case specifiers. Each + tuple defines one case body (the second item). The first item of the tuple can be + either a single integer value, the special value :data:`.CASE_DEFAULT`, or a tuple + of several integer values. Each of the integer values will be tried in turn; control + will then pass to the body corresponding to the first match. :data:`.CASE_DEFAULT` + matches all possible values. + qubits (Sequence[Qubit]): The circuit qubits over which all case bodies execute. + clbits (Sequence[Clbit]): The circuit clbits over which all case bodies execute. + label (Optional[str]): The string label of the instruction in the circuit. + + Returns: + InstructionSet: A handle to the instruction created. + """ + # pylint: disable=cyclic-import + from qiskit.circuit.controlflow.switch_case import SwitchCaseOp + + target = self._resolve_classical_resource(target) + return self.append(SwitchCaseOp(target, cases, label=label), qubits, clbits) + def break_loop(self) -> InstructionSet: """Apply :class:`~qiskit.circuit.BreakLoopOp`. diff --git a/releasenotes/notes/switch-case-9b6611d0603d36c0.yaml b/releasenotes/notes/switch-case-9b6611d0603d36c0.yaml new file mode 100644 index 000000000000..28f91c5d32ac --- /dev/null +++ b/releasenotes/notes/switch-case-9b6611d0603d36c0.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Qiskit now supports the representation of ``switch`` statements, using the new :class:`.SwitchCaseOp` + instruction and the :meth:`.QuantumCircuit.switch` method. This allows switching on a numeric + input (such as a classical register or bit) and executing the circuit that corresponds to the + matching value. Multiple values can point to the same circuit, and :data:`.CASE_DEFAULT` can be + used as an always-matching label. diff --git a/test/python/circuit/test_control_flow.py b/test/python/circuit/test_control_flow.py index f6e2372eb50d..84e9caca5e52 100644 --- a/test/python/circuit/test_control_flow.py +++ b/test/python/circuit/test_control_flow.py @@ -17,7 +17,8 @@ from ddt import ddt, data from qiskit.test import QiskitTestCase -from qiskit.circuit import Clbit, ClassicalRegister, Instruction, Parameter, QuantumCircuit +from qiskit.circuit import Clbit, ClassicalRegister, Instruction, Parameter, QuantumCircuit, Qubit +from qiskit.circuit.controlflow import CASE_DEFAULT from qiskit.circuit.library import XGate, RXGate from qiskit.circuit.exceptions import CircuitError @@ -28,6 +29,7 @@ IfElseOp, ContinueLoopOp, BreakLoopOp, + SwitchCaseOp, ) @@ -292,6 +294,138 @@ def test_break_loop_instantiation(self): self.assertEqual(op.num_clbits, 1) self.assertEqual(op.params, []) + def test_switch_clbit(self): + """Test that a switch statement can be constructed with a bit as a condition.""" + qubit = Qubit() + clbit = Clbit() + case1 = QuantumCircuit([qubit, clbit]) + case1.x(0) + case2 = QuantumCircuit([qubit, clbit]) + case2.z(0) + + op = SwitchCaseOp(clbit, [(True, case1), (False, case2)]) + self.assertIsInstance(op, Instruction) + self.assertEqual(op.name, "switch_case") + self.assertEqual(op.num_qubits, 1) + self.assertEqual(op.num_clbits, 1) + self.assertEqual(op.target, clbit) + self.assertEqual(op.cases(), {True: case1, False: case2}) + self.assertEqual(list(op.blocks), [case1, case2]) + + def test_switch_register(self): + """Test that a switch statement can be constructed with a register as a condition.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + case3 = QuantumCircuit([qubit], creg) + case3.z(0) + + op = SwitchCaseOp(creg, [(0, case1), (1, case2), (2, case3)]) + self.assertIsInstance(op, Instruction) + self.assertEqual(op.name, "switch_case") + self.assertEqual(op.num_qubits, 1) + self.assertEqual(op.num_clbits, 2) + self.assertEqual(op.target, creg) + self.assertEqual(op.cases(), {0: case1, 1: case2, 2: case3}) + self.assertEqual(list(op.blocks), [case1, case2, case3]) + + def test_switch_with_default(self): + """Test that a switch statement can be constructed with a default case at the end.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + case3 = QuantumCircuit([qubit], creg) + case3.z(0) + + op = SwitchCaseOp(creg, [(0, case1), (1, case2), (CASE_DEFAULT, case3)]) + self.assertIsInstance(op, Instruction) + self.assertEqual(op.name, "switch_case") + self.assertEqual(op.num_qubits, 1) + self.assertEqual(op.num_clbits, 2) + self.assertEqual(op.target, creg) + self.assertEqual(op.cases(), {0: case1, 1: case2, CASE_DEFAULT: case3}) + self.assertEqual(list(op.blocks), [case1, case2, case3]) + + def test_switch_multiple_cases_to_same_block(self): + """Test that it is possible to add multiple cases that apply to the same block, if they are + given as a compound value. This is an allowed special case of block fall-through.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + + op = SwitchCaseOp(creg, [(0, case1), ((1, 2), case2)]) + self.assertIsInstance(op, Instruction) + self.assertEqual(op.name, "switch_case") + self.assertEqual(op.num_qubits, 1) + self.assertEqual(op.num_clbits, 2) + self.assertEqual(op.target, creg) + self.assertEqual(op.cases(), {0: case1, 1: case2, 2: case2}) + self.assertEqual(list(op.blocks), [case1, case2]) + + def test_switch_rejects_separate_cases_to_same_block(self): + """Test that the switch statement rejects cases that are supplied separately, but point to + the same QuantumCircuit.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + + with self.assertRaisesRegex(CircuitError, "separate cases cannot point to the same block"): + SwitchCaseOp(creg, [(0, case1), (1, case2), (2, case1)]) + + def test_switch_rejects_cases_over_different_bits(self): + """Test that a switch statement fails to build if its individual cases are not all defined + over the same numbers of bits.""" + qubits = [Qubit() for _ in [None] * 3] + clbits = [Clbit(), Clbit()] + case1 = QuantumCircuit(qubits, clbits) + case2 = QuantumCircuit(qubits[1:], clbits) + + for case in (case1, case2): + case.h(1) + case.cx(1, 0) + case.measure(0, 0) + + with self.assertRaisesRegex(CircuitError, r"incompatible bits between cases"): + SwitchCaseOp(Clbit(), [(True, case1), (False, case2)]) + + def test_switch_rejects_cases_with_bad_types(self): + """Test that a switch statement will fail to build if it contains cases whose types are not + matched to the switch expression.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + + with self.assertRaisesRegex(CircuitError, "case values must be"): + SwitchCaseOp(creg, [(1.3, case1), (4.5, case2)]) + + def test_switch_rejects_cases_after_default(self): + """Test that a switch statement will fail to build if there are cases after the default + case.""" + qubit = Qubit() + creg = ClassicalRegister(2) + case1 = QuantumCircuit([qubit], creg) + case1.x(0) + case2 = QuantumCircuit([qubit], creg) + case2.y(0) + + with self.assertRaisesRegex(CircuitError, "cases after the default are unreachable"): + SwitchCaseOp(creg, [(CASE_DEFAULT, case1), (1, case2)]) + @ddt class TestAddingControlFlowOperations(QiskitTestCase):