diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index c1f8cc14b95..f3093b0d316 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -185,6 +185,7 @@ AmplitudeDampingChannel, AnyIntegerPowerGateFamily, AnyUnitaryGateFamily, + ArithmeticGate, ArithmeticOperation, asymmetric_depolarize, AsymmetricDepolarizingChannel, diff --git a/cirq-core/cirq/interop/quirk/__init__.py b/cirq-core/cirq/interop/quirk/__init__.py index b2b97b58775..a3fa4bc5daf 100644 --- a/cirq-core/cirq/interop/quirk/__init__.py +++ b/cirq-core/cirq/interop/quirk/__init__.py @@ -20,6 +20,7 @@ # Imports from cells are only to ensure operation reprs work correctly. from cirq.interop.quirk.cells import ( + QuirkArithmeticGate, QuirkArithmeticOperation, QuirkInputRotationOperation, QuirkQubitPermutationGate, diff --git a/cirq-core/cirq/interop/quirk/cells/__init__.py b/cirq-core/cirq/interop/quirk/cells/__init__.py index 5f7cf8ccbb9..b03af2a493f 100644 --- a/cirq-core/cirq/interop/quirk/cells/__init__.py +++ b/cirq-core/cirq/interop/quirk/cells/__init__.py @@ -21,7 +21,7 @@ from cirq.interop.quirk.cells.qubit_permutation_cells import QuirkQubitPermutationGate -from cirq.interop.quirk.cells.arithmetic_cells import QuirkArithmeticOperation +from cirq.interop.quirk.cells.arithmetic_cells import QuirkArithmeticGate, QuirkArithmeticOperation from cirq.interop.quirk.cells.input_rotation_cells import QuirkInputRotationOperation diff --git a/cirq-core/cirq/interop/quirk/cells/arithmetic_cells.py b/cirq-core/cirq/interop/quirk/cells/arithmetic_cells.py index 8f89d6c99b8..0072686cf15 100644 --- a/cirq-core/cirq/interop/quirk/cells/arithmetic_cells.py +++ b/cirq-core/cirq/interop/quirk/cells/arithmetic_cells.py @@ -28,12 +28,14 @@ ) from cirq import ops, value +from cirq._compat import deprecated_class from cirq.interop.quirk.cells.cell import Cell, CellMaker, CELL_SIZES if TYPE_CHECKING: import cirq +@deprecated_class(deadline='v0.15', fix='Use cirq.QuirkArithmeticGate') @value.value_equality class QuirkArithmeticOperation(ops.ArithmeticOperation): """Applies arithmetic to a target and some inputs. @@ -148,6 +150,110 @@ def __repr__(self) -> str: ) +@value.value_equality +class QuirkArithmeticGate(ops.ArithmeticGate): + """Applies arithmetic to a target and some inputs. + + Implements Quirk-specific implicit effects like assuming that the presence + of an 'r' input implies modular arithmetic. + + In Quirk, modular operations have no effect on values larger than the + modulus. This convention is used because unitarity forces *some* convention + on out-of-range values (they cannot simply disappear or raise exceptions), + and the simplest is to do nothing. This call handles ensuring that happens, + and ensuring the new target register value is normalized modulo the modulus. + """ + + def __init__( + self, identifier: str, target: Sequence[int], inputs: Sequence[Union[Sequence[int], int]] + ): + """Inits QuirkArithmeticGate. + + Args: + identifier: The quirk identifier string for this operation. + target: The target qubit register. + inputs: Qubit registers, which correspond to the qid shape of the + qubits from which the input will be read, or classical + constants, that determine what happens to the target. + + Raises: + ValueError: If the target is too small for a modular operation with + too small modulus. + """ + self.identifier = identifier + self.target: Tuple[int, ...] = tuple(target) + self.inputs: Tuple[Union[Sequence[int], int], ...] = tuple( + e if isinstance(e, int) else tuple(e) for e in inputs + ) + + if self.operation.is_modular: + r = inputs[-1] + if isinstance(r, int): + over = r > 1 << len(target) + else: + over = len(cast(Sequence, r)) > len(target) + if over: + raise ValueError(f'Target too small for modulus.\nTarget: {target}\nModulus: {r}') + + @property + def operation(self) -> '_QuirkArithmeticCallable': + return ARITHMETIC_OP_TABLE[self.identifier] + + def _value_equality_values_(self) -> Any: + return self.identifier, self.target, self.inputs + + def registers(self) -> Sequence[Union[int, Sequence[int]]]: + return [self.target, *self.inputs] + + def with_registers(self, *new_registers: Union[int, Sequence[int]]) -> 'QuirkArithmeticGate': + if len(new_registers) != len(self.inputs) + 1: + raise ValueError( + 'Wrong number of registers.\n' + f'New registers: {repr(new_registers)}\n' + f'Operation: {repr(self)}' + ) + + if isinstance(new_registers[0], int): + raise ValueError( + 'The first register is the mutable target. ' + 'It must be a list of qubits, not the constant ' + f'{new_registers[0]}.' + ) + + return QuirkArithmeticGate(self.identifier, new_registers[0], new_registers[1:]) + + def apply(self, *registers: int) -> Union[int, Iterable[int]]: + return self.operation(*registers) + + def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs') -> List[str]: + lettered_args = list(zip(self.operation.letters, self.inputs)) + + result: List[str] = [] + + # Target register labels. + consts = ''.join( + f',{letter}={reg}' for letter, reg in lettered_args if isinstance(reg, int) + ) + result.append(f'Quirk({self.identifier}{consts})') + result.extend(f'#{i}' for i in range(2, len(self.target) + 1)) + + # Input register labels. + for letter, reg in lettered_args: + if not isinstance(reg, int): + result.extend(f'{letter.upper()}{i}' for i in range(len(cast(Sequence, reg)))) + + return result + + def __repr__(self) -> str: + return ( + 'cirq.interop.quirk.QuirkArithmeticGate(\n' + f' {repr(self.identifier)},\n' + f' target={repr(self.target)},\n' + f' inputs={_indented_list_lines_repr(self.inputs)},\n' + ')' + ) + + _IntsToIntCallable = Union[ Callable[[int], int], Callable[[int, int], int], @@ -244,11 +350,13 @@ def operations(self) -> 'cirq.OP_TREE': if missing_inputs: raise ValueError(f'Missing input: {sorted(missing_inputs)}') - return QuirkArithmeticOperation( + inputs = cast(Sequence[Union[Sequence['cirq.Qid'], int]], self.inputs) + qubits = self.target + tuple(q for i in self.inputs if isinstance(i, Sequence) for q in i) + return QuirkArithmeticGate( self.identifier, - self.target, - cast(Sequence[Union[Sequence['cirq.Qid'], int]], self.inputs), - ) + [q.dimension for q in self.target], + [i if isinstance(i, int) else [q.dimension for q in i] for i in inputs], + ).on(*qubits) def _indented_list_lines_repr(items: Sequence[Any]) -> str: diff --git a/cirq-core/cirq/interop/quirk/cells/arithmetic_cells_test.py b/cirq-core/cirq/interop/quirk/cells/arithmetic_cells_test.py index ff93fef0f3b..2ed1f3522dc 100644 --- a/cirq-core/cirq/interop/quirk/cells/arithmetic_cells_test.py +++ b/cirq-core/cirq/interop/quirk/cells/arithmetic_cells_test.py @@ -344,7 +344,7 @@ def test_with_registers(): '["+=AB3",1,1,"inputB2"]' ']}' ) - op = cast(cirq.ArithmeticOperation, circuit[0].operations[0]) + op = cast(cirq.ArithmeticGate, circuit[0].operations[0].gate) with pytest.raises(ValueError, match='number of registers'): _ = op.with_registers() @@ -353,11 +353,11 @@ def test_with_registers(): _ = op.with_registers(1, 2, 3) op2 = op.with_registers([], 5, 5) - np.testing.assert_allclose(cirq.unitary(cirq.Circuit(op2)), np.array([[1]]), atol=1e-8) + np.testing.assert_allclose(cirq.unitary(cirq.Circuit(op2())), np.array([[1]]), atol=1e-8) - op2 = op.with_registers([*cirq.LineQubit.range(3)], 5, 5) + op2 = op.with_registers([2, 2, 2], 5, 5) np.testing.assert_allclose( - cirq.final_state_vector(cirq.Circuit(op2), initial_state=0), + cirq.final_state_vector(cirq.Circuit(op2(*cirq.LineQubit.range(3))), initial_state=0), cirq.one_hot(index=25 % 8, shape=8, dtype=np.complex64), atol=1e-8, ) diff --git a/cirq-core/cirq/interop/quirk/cells/composite_cell_test.py b/cirq-core/cirq/interop/quirk/cells/composite_cell_test.py index db13e29e8b1..3a68a9bf866 100644 --- a/cirq-core/cirq/interop/quirk/cells/composite_cell_test.py +++ b/cirq-core/cirq/interop/quirk/cells/composite_cell_test.py @@ -92,13 +92,17 @@ def test_custom_circuit_gate(): # With internal input. assert_url_to_circuit_returns( '{"cols":[["~a5ls"]],"gates":[{"id":"~a5ls","circuit":{"cols":[["inputA1","+=A1"]]}}]}', - cirq.Circuit(cirq.interop.quirk.QuirkArithmeticOperation('+=A1', target=[b], inputs=[[a]])), + cirq.Circuit( + cirq.interop.quirk.QuirkArithmeticGate('+=A1', target=[2], inputs=[[2]]).on(b, a) + ), ) # With external input. assert_url_to_circuit_returns( '{"cols":[["inputA1","~r79k"]],"gates":[{"id":"~r79k","circuit":{"cols":[["+=A1"]]}}]}', - cirq.Circuit(cirq.interop.quirk.QuirkArithmeticOperation('+=A1', target=[b], inputs=[[a]])), + cirq.Circuit( + cirq.interop.quirk.QuirkArithmeticGate('+=A1', target=[2], inputs=[[2]]).on(b, a) + ), ) # With external control. @@ -127,9 +131,15 @@ def test_custom_circuit_gate(): '{"cols":[["~q1fh",1,1,"inputA2"]],"gates":[{"id":"~q1fh",' '"circuit":{"cols":[["+=A2"],[1,"+=A2"],[1,"+=A2"]]}}]}', cirq.Circuit( - cirq.interop.quirk.QuirkArithmeticOperation('+=A2', target=[a, b], inputs=[[d, e]]), - cirq.interop.quirk.QuirkArithmeticOperation('+=A2', target=[b, c], inputs=[[d, e]]), - cirq.interop.quirk.QuirkArithmeticOperation('+=A2', target=[b, c], inputs=[[d, e]]), + cirq.interop.quirk.QuirkArithmeticGate('+=A2', target=[2, 2], inputs=[[2, 2]]).on( + a, b, d, e + ), + cirq.interop.quirk.QuirkArithmeticGate('+=A2', target=[2, 2], inputs=[[2, 2]]).on( + b, c, d, e + ), + cirq.interop.quirk.QuirkArithmeticGate('+=A2', target=[2, 2], inputs=[[2, 2]]).on( + b, c, d, e + ), ), ) diff --git a/cirq-core/cirq/interop/quirk/cells/input_cells_test.py b/cirq-core/cirq/interop/quirk/cells/input_cells_test.py index 2453d73fd84..23a72b5ea72 100644 --- a/cirq-core/cirq/interop/quirk/cells/input_cells_test.py +++ b/cirq-core/cirq/interop/quirk/cells/input_cells_test.py @@ -35,7 +35,7 @@ def test_input_cell(): ) # Overlaps with effect. - with pytest.raises(ValueError, match='Overlapping registers'): + with pytest.raises(ValueError, match='Duplicate qids'): _ = quirk_url_to_circuit( 'https://algassert.com/quirk#circuit={"cols":[["+=A3","inputA3"]]}' ) @@ -53,7 +53,7 @@ def test_reversed_input_cell(): ) # Overlaps with effect. - with pytest.raises(ValueError, match='Overlapping registers'): + with pytest.raises(ValueError, match='Duplicate qids'): _ = quirk_url_to_circuit( 'https://algassert.com/quirk#circuit={"cols":[["+=A3","revinputA3"]]}' ) diff --git a/cirq-core/cirq/ops/__init__.py b/cirq-core/cirq/ops/__init__.py index f3a649fd4aa..086a27d775c 100644 --- a/cirq-core/cirq/ops/__init__.py +++ b/cirq-core/cirq/ops/__init__.py @@ -14,7 +14,7 @@ """Gates (unitary and non-unitary), operations, base types, and gate sets. """ -from cirq.ops.arithmetic_operation import ArithmeticOperation +from cirq.ops.arithmetic_operation import ArithmeticGate, ArithmeticOperation from cirq.ops.clifford_gate import CliffordGate, PauliTransform, SingleQubitCliffordGate diff --git a/cirq-core/cirq/ops/arithmetic_operation.py b/cirq-core/cirq/ops/arithmetic_operation.py index 7f3b010b58b..57a6665d797 100644 --- a/cirq-core/cirq/ops/arithmetic_operation.py +++ b/cirq-core/cirq/ops/arithmetic_operation.py @@ -15,11 +15,12 @@ import abc import itertools -from typing import Union, Iterable, List, Sequence, cast, TypeVar, TYPE_CHECKING +from typing import Union, Iterable, List, Sequence, cast, Tuple, TypeVar, TYPE_CHECKING import numpy as np -from cirq.ops.raw_types import Operation +from cirq._compat import deprecated_class +from cirq.ops.raw_types import Operation, Gate if TYPE_CHECKING: import cirq @@ -28,6 +29,7 @@ TSelf = TypeVar('TSelf', bound='ArithmeticOperation') +@deprecated_class(deadline='v0.15', fix='Use cirq.ArithmeticGate') class ArithmeticOperation(Operation, metaclass=abc.ABCMeta): """A helper class for implementing reversible classical arithmetic. @@ -241,8 +243,217 @@ def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'): return args.target_tensor +TSelfGate = TypeVar('TSelfGate', bound='ArithmeticGate') + + +class ArithmeticGate(Gate, metaclass=abc.ABCMeta): + """A helper gate for implementing reversible classical arithmetic. + + Child classes must override the `registers`, `with_registers`, and `apply` + methods. + + This class handles the details of ensuring that the scaling of implementing + the gate is O(2^n) instead of O(4^n) where n is the number of qubits + being acted on, by implementing an `_apply_unitary_` function in terms of + the registers and the apply function of the child class. + + Examples: + ``` + + >>> class Add(cirq.ArithmeticGate): + ... def __init__( + ... self, + ... target_register: [int, Sequence[int]], + ... input_register: Union[int, Sequence[int]], + ... ): + ... self.target_register = target_register + ... self.input_register = input_register + ... + ... def registers(self) -> Sequence[Union[int, Sequence[int]]]: + ... return self.target_register, self.input_register + ... + ... def with_registers(self, *new_registers: Union[int, Sequence[int]]) -> TSelfGate: + ... return Add(*new_registers) + ... + ... def apply(self, *register_values: int) -> Union[int, Iterable[int]]: + ... return sum(register_values) + >>> cirq.unitary( + ... Add(target_register=[2, 2], + ... input_register=1).on(*cirq.LineQubit.range(2)) + ... ).astype(np.int32) + array([[0, 0, 0, 1], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0]], dtype=int32) + >>> c = cirq.Circuit( + ... cirq.X(cirq.LineQubit(3)), + ... cirq.X(cirq.LineQubit(2)), + ... cirq.X(cirq.LineQubit(6)), + ... cirq.measure(*cirq.LineQubit.range(4, 8), key='before:in'), + ... cirq.measure(*cirq.LineQubit.range(4), key='before:out'), + ... + ... Add(target_register=[2] * 4, + ... input_register=[2] * 4).on(*cirq.LineQubit.range(8)), + ... + ... cirq.measure(*cirq.LineQubit.range(4, 8), key='after:in'), + ... cirq.measure(*cirq.LineQubit.range(4), key='after:out'), + ... ) + >>> cirq.sample(c).data + before:in before:out after:in after:out + 0 2 3 2 5 + + ``` + """ + + @abc.abstractmethod + def registers(self) -> Sequence[Union[int, Sequence[int]]]: + """The data acted upon by the arithmetic gate. + + Each register in the list can either be a classical constant (an `int`), + or else a list of qubit/qudit dimensions. Registers that are set to a + classical constant must not be mutated by the arithmetic gate + (their value must remain fixed when passed to `apply`). + + Registers are big endian. The first qubit is the most significant, the + last qubit is the 1s qubit, the before last qubit is the 2s qubit, etc. + + Returns: + A list of constants and qubit groups that the gate will act upon. + """ + raise NotImplementedError() + + @abc.abstractmethod + def with_registers(self: TSelfGate, *new_registers: Union[int, Sequence[int]]) -> TSelfGate: + """Returns the same fate targeting different registers. + + Args: + *new_registers: The new values that should be returned by the + `registers` method. + + Returns: + An instance of the same kind of gate, but acting on different + registers. + """ + raise NotImplementedError() + + @abc.abstractmethod + def apply(self, *register_values: int) -> Union[int, Iterable[int]]: + """Returns the result of the gate operating on classical values. + + For example, an addition takes two values (the target and the source), + adds the source into the target, then returns the target and source + as the new register values. + + The `apply` method is permitted to be sloppy in three ways: + + 1. The `apply` method is permitted to return values that have more bits + than the registers they will be stored into. The extra bits are + simply dropped. For example, if the value 5 is returned for a 2 + qubit register then 5 % 2**2 = 1 will be used instead. Negative + values are also permitted. For example, for a 3 qubit register the + value -2 becomes -2 % 2**3 = 6. + 2. When the value of the last `k` registers is not changed by the + gate, the `apply` method is permitted to omit these values + from the result. That is to say, when the length of the output is + less than the length of the input, it is padded up to the intended + length by copying from the same position in the input. + 3. When only the first register's value changes, the `apply` method is + permitted to return an `int` instead of a sequence of ints. + + The `apply` method *must* be reversible. Otherwise the gate will + not be unitary, and incorrect behavior will result. + + Examples: + + A fully detailed adder: + + ``` + def apply(self, target, offset): + return (target + offset) % 2**len(self.target_register), offset + ``` + + The same adder, with less boilerplate due to the details being + handled by the `ArithmeticGate` class: + + ``` + def apply(self, target, offset): + return target + offset + ``` + """ + raise NotImplementedError() + + def _qid_shape_(self) -> Tuple[int, ...]: + shape = [] + for r in self.registers(): + if isinstance(r, Sequence): + for i in r: + shape.append(i) + return tuple(shape) + + def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'): + registers = self.registers() + input_ranges: List[Sequence[int]] = [] + shape: List[int] = [] + overflow_sizes: List[int] = [] + for register in registers: + if isinstance(register, int): + input_ranges.append([register]) + shape.append(1) + overflow_sizes.append(register + 1) + else: + size = int(np.prod([dim for dim in register], dtype=np.int64).item()) + shape.append(size) + input_ranges.append(range(size)) + overflow_sizes.append(size) + + leftover = args.target_tensor.size // np.prod(shape, dtype=np.int64).item() + new_shape = (*shape, leftover) + + transposed_args = args.with_axes_transposed_to_start() + src = transposed_args.target_tensor.reshape(new_shape) + dst = transposed_args.available_buffer.reshape(new_shape) + for input_seq in itertools.product(*input_ranges): + output = self.apply(*input_seq) + + # Wrap into list. + inputs: List[int] = list(input_seq) + outputs: List[int] = [output] if isinstance(output, int) else list(output) + + # Omitted tail values default to the corresponding input value. + if len(outputs) < len(inputs): + outputs += inputs[len(outputs) - len(inputs) :] + # Get indices into range. + for i in range(len(outputs)): + if isinstance(registers[i], int): + if outputs[i] != registers[i]: + raise ValueError( + _describe_bad_arithmetic_changed_const( + self.registers(), inputs, outputs + ) + ) + # Classical constants go to zero on a unit axe. + outputs[i] = 0 + inputs[i] = 0 + else: + # Quantum values get wrapped into range. + outputs[i] %= overflow_sizes[i] + + # Copy amplitude to new location. + cast(List[Union[int, slice]], outputs).append(slice(None)) + cast(List[Union[int, slice]], inputs).append(slice(None)) + dst[tuple(outputs)] = src[tuple(inputs)] + + # In case the reshaped arrays were copies instead of views. + dst.shape = transposed_args.available_buffer.shape + transposed_args.target_tensor[...] = dst + + return args.target_tensor + + def _describe_bad_arithmetic_changed_const( - registers: Sequence[Union[int, Sequence['cirq.Qid']]], inputs: List[int], outputs: List[int] + registers: Sequence[Union[int, Sequence[Union['cirq.Qid', int]]]], + inputs: List[int], + outputs: List[int], ) -> str: from cirq.circuits import TextDiagramDrawer @@ -258,7 +469,7 @@ def _describe_bad_arithmetic_changed_const( drawer.write(3, i + 1, str(outputs[i])) return ( "A register cannot be set to an int (a classical constant) unless its " - "value is not affected by the operation.\n" + "value is not affected by the gate.\n" "\nExample case where a constant changed:\n" + drawer.render(horizontal_spacing=1, vertical_spacing=0) ) diff --git a/cirq-core/cirq/ops/arithmetic_operation_test.py b/cirq-core/cirq/ops/arithmetic_operation_test.py index 4586c9582a7..7d93362d818 100644 --- a/cirq-core/cirq/ops/arithmetic_operation_test.py +++ b/cirq-core/cirq/ops/arithmetic_operation_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union, Sequence + import pytest import numpy as np @@ -98,29 +100,96 @@ def with_registers(self, *new_registers): def apply(self, target_value, input_value): return target_value + input_value - inc2 = Add(cirq.LineQubit.range(2), 1) + with cirq.testing.assert_deprecated(deadline='v0.15', count=8): + inc2 = Add(cirq.LineQubit.range(2), 1) + np.testing.assert_allclose(cirq.unitary(inc2), shift_matrix(4, 1), atol=1e-8) + + dec3 = Add(cirq.LineQubit.range(3), -1) + np.testing.assert_allclose(cirq.unitary(dec3), shift_matrix(8, -1), atol=1e-8) + + add3from2 = Add(cirq.LineQubit.range(3), cirq.LineQubit.range(2)) + np.testing.assert_allclose(cirq.unitary(add3from2), adder_matrix(8, 4), atol=1e-8) + + add2from3 = Add(cirq.LineQubit.range(2), cirq.LineQubit.range(3)) + np.testing.assert_allclose(cirq.unitary(add2from3), adder_matrix(4, 8), atol=1e-8) + + with pytest.raises(ValueError, match='affected by the gate'): + _ = cirq.unitary(Add(1, cirq.LineQubit.range(2))) + + with pytest.raises(ValueError, match='affected by the gate'): + _ = cirq.unitary(Add(1, 1)) + + np.testing.assert_allclose(cirq.unitary(Add(1, 0)), np.eye(1)) + + cirq.testing.assert_has_consistent_apply_unitary( + Add(cirq.LineQubit.range(2), cirq.LineQubit.range(2)) + ) + + +def test_arithmetic_gate_apply_unitary(): + class Add(cirq.ArithmeticGate): + def __init__( + self, + target_register: Union[int, Sequence[int]], + input_register: Union[int, Sequence[int]], + ): + self.target_register = target_register + self.input_register = input_register + + def registers(self): + return self.target_register, self.input_register + + def with_registers(self, *new_registers): + raise NotImplementedError() + + def apply(self, target_value, input_value): + return target_value + input_value + + qubits = [cirq.LineQubit.range(i) for i in range(6)] + + inc2 = Add([2, 2], 1) np.testing.assert_allclose(cirq.unitary(inc2), shift_matrix(4, 1), atol=1e-8) + np.testing.assert_allclose(cirq.unitary(inc2.on(*qubits[2])), shift_matrix(4, 1), atol=1e-8) - dec3 = Add(cirq.LineQubit.range(3), -1) + dec3 = Add([2, 2, 2], -1) np.testing.assert_allclose(cirq.unitary(dec3), shift_matrix(8, -1), atol=1e-8) + np.testing.assert_allclose(cirq.unitary(dec3.on(*qubits[3])), shift_matrix(8, -1), atol=1e-8) - add3from2 = Add(cirq.LineQubit.range(3), cirq.LineQubit.range(2)) + add3from2 = Add([2, 2, 2], [2, 2]) np.testing.assert_allclose(cirq.unitary(add3from2), adder_matrix(8, 4), atol=1e-8) + np.testing.assert_allclose( + cirq.unitary(add3from2.on(*qubits[5])), adder_matrix(8, 4), atol=1e-8 + ) - add2from3 = Add(cirq.LineQubit.range(2), cirq.LineQubit.range(3)) + add2from3 = Add([2, 2], [2, 2, 2]) np.testing.assert_allclose(cirq.unitary(add2from3), adder_matrix(4, 8), atol=1e-8) + np.testing.assert_allclose( + cirq.unitary(add2from3.on(*qubits[5])), adder_matrix(4, 8), atol=1e-8 + ) + + with pytest.raises(ValueError, match='affected by the gate'): + _ = cirq.unitary(Add(1, [2, 2])) - with pytest.raises(ValueError, match='affected by the operation'): - _ = cirq.unitary(Add(1, cirq.LineQubit.range(2))) + with pytest.raises(ValueError, match='affected by the gate'): + _ = cirq.unitary(Add(1, [2, 2]).on(*qubits[2])) - with pytest.raises(ValueError, match='affected by the operation'): + with pytest.raises(ValueError, match='affected by the gate'): _ = cirq.unitary(Add(1, 1)) + with pytest.raises(ValueError, match='affected by the gate'): + _ = cirq.unitary(Add(1, 1).on()) + np.testing.assert_allclose(cirq.unitary(Add(1, 0)), np.eye(1)) + np.testing.assert_allclose(cirq.unitary(Add(1, 0).on()), np.eye(1)) - cirq.testing.assert_has_consistent_apply_unitary( - Add(cirq.LineQubit.range(2), cirq.LineQubit.range(2)) - ) + cirq.testing.assert_has_consistent_apply_unitary(Add([2, 2], [2, 2])) + cirq.testing.assert_has_consistent_apply_unitary(Add([2, 2], [2, 2]).on(*qubits[4])) + + with pytest.raises(ValueError, match='Wrong number of qubits'): + _ = Add(1, [2, 2]).on(*qubits[3]) + + with pytest.raises(ValueError, match='Wrong shape of qids'): + _ = Add(1, [2, 3]).on(*qubits[2]) def test_arithmetic_operation_qubits(): @@ -139,22 +208,23 @@ def with_registers(self, *new_registers): def apply(self, target_value, input_value): raise NotImplementedError() - q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6) - op = Three([q0], [], [q4, q5]) - assert op.qubits == (q0, q4, q5) - assert op.registers() == ([q0], [], [q4, q5]) + with cirq.testing.assert_deprecated(deadline='v0.15', count=4): + q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6) + op = Three([q0], [], [q4, q5]) + assert op.qubits == (q0, q4, q5) + assert op.registers() == ([q0], [], [q4, q5]) - op2 = op.with_qubits(q2, q4, q1) - assert op2.qubits == (q2, q4, q1) - assert op2.registers() == ([q2], [], [q4, q1]) + op2 = op.with_qubits(q2, q4, q1) + assert op2.qubits == (q2, q4, q1) + assert op2.registers() == ([q2], [], [q4, q1]) - op3 = op.with_registers([q0, q1, q3], [q5], 1) - assert op3.qubits == (q0, q1, q3, q5) - assert op3.registers() == ([q0, q1, q3], [q5], 1) + op3 = op.with_registers([q0, q1, q3], [q5], 1) + assert op3.qubits == (q0, q1, q3, q5) + assert op3.registers() == ([q0, q1, q3], [q5], 1) - op4 = op3.with_qubits(q0, q1, q2, q3) - assert op4.registers() == ([q0, q1, q2], [q3], 1) - assert op4.qubits == (q0, q1, q2, q3) + op4 = op3.with_qubits(q0, q1, q2, q3) + assert op4.registers() == ([q0, q1, q2], [q3], 1) + assert op4.qubits == (q0, q1, q2, q3) def test_reshape_referencing(): @@ -168,6 +238,7 @@ def registers(self): def with_registers(self, *new_registers): raise NotImplementedError() - state = np.ones(4, dtype=np.complex64) / 2 - output = cirq.final_state_vector(cirq.Circuit(Op1()), initial_state=state) - np.testing.assert_allclose(state, output) + with cirq.testing.assert_deprecated(deadline='v0.15'): + state = np.ones(4, dtype=np.complex64) / 2 + output = cirq.final_state_vector(cirq.Circuit(Op1()), initial_state=state) + np.testing.assert_allclose(state, output) diff --git a/examples/examples_test.py b/examples/examples_test.py index d5918e3f021..f7603438671 100644 --- a/examples/examples_test.py +++ b/examples/examples_test.py @@ -155,46 +155,38 @@ def test_example_noisy_simulation(): def test_example_shor_modular_exp_register_size(): with pytest.raises(ValueError): - _ = examples.shor.ModularExp( - target=cirq.LineQubit.range(2), exponent=cirq.LineQubit.range(2, 5), base=4, modulus=5 - ) + _ = examples.shor.ModularExp(target=[2, 2], exponent=[2, 2, 2], base=4, modulus=5) def test_example_shor_modular_exp_register_type(): - operation = examples.shor.ModularExp( - target=cirq.LineQubit.range(3), exponent=cirq.LineQubit.range(3, 5), base=4, modulus=5 - ) + operation = examples.shor.ModularExp(target=[2, 2, 2], exponent=[2, 2], base=4, modulus=5) with pytest.raises(ValueError): - _ = operation.with_registers(cirq.LineQubit.range(3)) + _ = operation.with_registers([2, 2, 2]) with pytest.raises(ValueError): - _ = operation.with_registers(1, cirq.LineQubit.range(3, 6), 4, 5) + _ = operation.with_registers(1, [2, 2, 2], 4, 5) with pytest.raises(ValueError): - _ = operation.with_registers( - cirq.LineQubit.range(3), cirq.LineQubit.range(3, 6), cirq.LineQubit.range(6, 9), 5 - ) + _ = operation.with_registers([2, 2, 2], [2, 2, 2], [2, 2, 2], 5) with pytest.raises(ValueError): - _ = operation.with_registers( - cirq.LineQubit.range(3), cirq.LineQubit.range(3, 6), 4, cirq.LineQubit.range(6, 9) - ) + _ = operation.with_registers([2, 2, 2], [2, 2, 2], 4, [2, 2, 2]) def test_example_shor_modular_exp_registers(): - target = cirq.LineQubit.range(3) - exponent = cirq.LineQubit.range(3, 5) + target = [2, 2, 2] + exponent = [2, 2] operation = examples.shor.ModularExp(target, exponent, 4, 5) assert operation.registers() == (target, exponent, 4, 5) - new_target = cirq.LineQubit.range(5, 8) - new_exponent = cirq.LineQubit.range(8, 12) + new_target = [2, 2, 2] + new_exponent = [2, 2, 2, 2] new_operation = operation.with_registers(new_target, new_exponent, 6, 7) assert new_operation.registers() == (new_target, new_exponent, 6, 7) def test_example_shor_modular_exp_diagram(): - target = cirq.LineQubit.range(3) - exponent = cirq.LineQubit.range(3, 5) - operation = examples.shor.ModularExp(target, exponent, 4, 5) - circuit = cirq.Circuit(operation) + target = [2, 2, 2] + exponent = [2, 2] + gate = examples.shor.ModularExp(target, exponent, 4, 5) + circuit = cirq.Circuit(gate.on(*cirq.LineQubit.range(5))) cirq.testing.assert_has_diagram( circuit, """ @@ -210,8 +202,8 @@ def test_example_shor_modular_exp_diagram(): """, ) - operation = operation.with_registers(target, 2, 4, 5) - circuit = cirq.Circuit(operation) + gate = gate.with_registers(target, 2, 4, 5) + circuit = cirq.Circuit(gate.on(*cirq.LineQubit.range(3))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/examples/shor.py b/examples/shor.py index 50283ff1c4d..bde92625dd3 100644 --- a/examples/shor.py +++ b/examples/shor.py @@ -51,8 +51,7 @@ import fractions import math import random - -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Union import sympy @@ -101,7 +100,7 @@ def naive_order_finder(x: int, n: int) -> Optional[int]: return r -class ModularExp(cirq.ArithmeticOperation): +class ModularExp(cirq.ArithmeticGate): """Quantum modular exponentiation. This class represents the unitary which multiplies base raised to exponent @@ -129,11 +128,7 @@ class ModularExp(cirq.ArithmeticOperation): """ def __init__( - self, - target: Sequence[cirq.Qid], - exponent: Union[int, Sequence[cirq.Qid]], - base: int, - modulus: int, + self, target: Sequence[int], exponent: Union[int, Sequence[int]], base: int, modulus: int ) -> None: if len(target) < modulus.bit_length(): raise ValueError( @@ -144,10 +139,10 @@ def __init__( self.base = base self.modulus = modulus - def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]: + def registers(self) -> Sequence[Union[int, Sequence[int]]]: return self.target, self.exponent, self.base, self.modulus - def with_registers(self, *new_registers: Union[int, Sequence['cirq.Qid']]) -> 'ModularExp': + def with_registers(self, *new_registers: Union[int, Sequence[int]]) -> 'ModularExp': if len(new_registers) != 4: raise ValueError( f'Expected 4 registers (target, exponent, base, ' @@ -171,22 +166,12 @@ def apply(self, *register_values: int) -> int: def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: assert args.known_qubits is not None - wire_symbols: List[str] = [] - t, e = 0, 0 - for qubit in args.known_qubits: - if qubit in self.target: - if t == 0: - if isinstance(self.exponent, Sequence): - e_str = 'e' - else: - e_str = str(self.exponent) - wire_symbols.append(f'ModularExp(t*{self.base}**{e_str} % {self.modulus})') - else: - wire_symbols.append('t' + str(t)) - t += 1 - if isinstance(self.exponent, Sequence) and qubit in self.exponent: - wire_symbols.append('e' + str(e)) - e += 1 + wire_symbols = [f't{i}' for i in range(len(self.target))] + e_str = str(self.exponent) + if isinstance(self.exponent, Sequence): + e_str = 'e' + wire_symbols += [f'e{i}' for i in range(len(self.exponent))] + wire_symbols[0] = f'ModularExp(t*{self.base}**{e_str} % {self.modulus})' return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols)) @@ -224,7 +209,7 @@ def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit: return cirq.Circuit( cirq.X(target[L - 1]), cirq.H.on_each(*exponent), - ModularExp(target, exponent, x, n), + ModularExp([2] * len(target), [2] * len(exponent), x, n).on(*target + exponent), cirq.qft(*exponent, inverse=True), cirq.measure(*exponent, key='exponent'), )