From 10276006599898c1d4203c9739b30d9ac29665a0 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 17 Feb 2022 17:19:08 -0800 Subject: [PATCH 1/6] Add convert_to_target_gateset transformer and CompilationTargetGateset interface --- cirq-core/cirq/__init__.py | 3 + .../cirq/protocols/decompose_protocol.py | 6 +- .../cirq/protocols/decompose_protocol_test.py | 3 +- .../cirq/protocols/json_test_data/spec.py | 2 + cirq-core/cirq/transformers/__init__.py | 9 + .../transformers/convert_to_target_gateset.py | 129 +++++++++++ .../convert_to_target_gateset_test.py | 200 ++++++++++++++++++ .../transformers/target_gatesets/__init__.py | 17 ++ .../compilation_target_gateset.py | 98 +++++++++ .../compilation_target_gateset_test.py | 53 +++++ 10 files changed, 517 insertions(+), 3 deletions(-) create mode 100644 cirq-core/cirq/transformers/convert_to_target_gateset.py create mode 100644 cirq-core/cirq/transformers/convert_to_target_gateset_test.py create mode 100644 cirq-core/cirq/transformers/target_gatesets/__init__.py create mode 100644 cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py create mode 100644 cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index b656f1c48b8..2c868e4454f 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -355,11 +355,14 @@ from cirq.transformers import ( align_left, align_right, + CompilationTargetGateset, compute_cphase_exponents_for_fsim_decomposition, + convert_to_target_gateset, decompose_clifford_tableau_to_operations, decompose_cphase_into_two_fsim, decompose_multi_controlled_x, decompose_multi_controlled_rotation, + decompose_operations_to_target_gateset, decompose_two_qubit_interaction_into_four_fsim_gates, defer_measurements, dephase_measurements, diff --git a/cirq-core/cirq/protocols/decompose_protocol.py b/cirq-core/cirq/protocols/decompose_protocol.py index a9e3ca93307..dda931d1772 100644 --- a/cirq-core/cirq/protocols/decompose_protocol.py +++ b/cirq-core/cirq/protocols/decompose_protocol.py @@ -180,7 +180,11 @@ def decompose( that doesn't satisfy the given `keep` predicate. """ - if on_stuck_raise is not _value_error_describing_bad_operation and keep is None: + if ( + on_stuck_raise is not _value_error_describing_bad_operation + and on_stuck_raise is not None + and keep is None + ): raise ValueError( "Must specify 'keep' if specifying 'on_stuck_raise', because it's " "not possible to get stuck if you don't have a criteria on what's " diff --git a/cirq-core/cirq/protocols/decompose_protocol_test.py b/cirq-core/cirq/protocols/decompose_protocol_test.py index aca655f828e..c4c9bce3631 100644 --- a/cirq-core/cirq/protocols/decompose_protocol_test.py +++ b/cirq-core/cirq/protocols/decompose_protocol_test.py @@ -182,6 +182,7 @@ def test_decompose_on_stuck_raise(): _ = cirq.decompose(NoMethod(), keep=lambda _: False) # Unless there's no operations to be unhappy about. assert cirq.decompose([], keep=lambda _: False) == [] + assert cirq.decompose([], on_stuck_raise=None) == [] # Or you say you're fine. assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=None) == [no_method] assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=lambda _: None) == [ @@ -198,8 +199,6 @@ def test_decompose_on_stuck_raise(): ) # There's a nice warning if you specify `on_stuck_raise` but not `keep`. - with pytest.raises(ValueError, match='on_stuck_raise'): - assert cirq.decompose([], on_stuck_raise=None) with pytest.raises(ValueError, match='on_stuck_raise'): assert cirq.decompose([], on_stuck_raise=TypeError('x')) diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 9df089bc156..f18095db8db 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -92,6 +92,8 @@ 'ApplyMixtureArgs', 'ApplyUnitaryArgs', 'OperationTarget', + # Abstract base class for creating compilation targets. + 'CompilationTargetGateset', # Circuit optimizers are function-like. Only attributes # are ignore_failures, tolerance, and other feature flags 'AlignLeft', diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 8033c6315c2..cc7cf44af3e 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -41,6 +41,10 @@ two_qubit_gate_product_tabulation, ) +from cirq.transformers.target_gatesets import ( + CompilationTargetGateset, +) + from cirq.transformers.align import align_left, align_right from cirq.transformers.stratify import stratified_circuit @@ -49,6 +53,11 @@ from cirq.transformers.eject_phased_paulis import eject_phased_paulis +from cirq.transformers.convert_to_target_gateset import ( + convert_to_target_gateset, + decompose_operations_to_target_gateset, +) + from cirq.transformers.drop_empty_moments import drop_empty_moments from cirq.transformers.drop_negligible_operations import drop_negligible_operations diff --git a/cirq-core/cirq/transformers/convert_to_target_gateset.py b/cirq-core/cirq/transformers/convert_to_target_gateset.py new file mode 100644 index 00000000000..ecc63ece184 --- /dev/null +++ b/cirq-core/cirq/transformers/convert_to_target_gateset.py @@ -0,0 +1,129 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Callable, TYPE_CHECKING + +from cirq import protocols +from cirq.transformers import transformer_api, transformer_primitives +from cirq.protocols.decompose_protocol import DecomposeResult + +if TYPE_CHECKING: + import cirq + + +def _create_on_stuck_raise_error(gateset: 'cirq.Gateset'): + def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError: + return ValueError(f"Unable to convert {op} to target gateset {gateset!r}") + + return _value_error_describing_bad_operation + + +@transformer_api.transformer +def decompose_operations_to_target_gateset( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + gateset: Optional['cirq.Gateset'] = None, + decomposer: Callable[['cirq.Operation', int], DecomposeResult] = lambda *_: NotImplemented, + ignore_failures=True, +) -> 'cirq.Circuit': + """Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`. + + This transformer attempts to decompose every operation `op` in the given circuit to `gateset` + using `cirq.decompose` protocol with `decomposer` used as an intercepting decomposer. This + ensures that `op` is recursively decomposed using implicitly defined known decompositions + (eg: in `_decompose_` magic method on the gaet class) till either `decomposer` knows how to + decompose the given operation or the given operation belongs to `gateset`. + + Args: + circuit: Input circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options for transformers. + gateset: Target gateset, which the decomposed operations should belong to. + decomposer: A callable type which accepts an (operation, moment_index) and returns + - An equivalent `cirq.OP_TREE` implementing `op` using gates from `gateset`. + - `None` or `NotImplemented` if does not know how to decompose a given `op`. + ignore_failures: If set, operations that fail to convert are left unchanged. If not set, + conversion failures raise a TypeError. + + Returns: + An equivalent circuit containing gates accepted by `gateset`. + + Raises: + TypeError: If any input operation fails to convert and `ignore_failures` is False. + """ + + def map_func(op: 'cirq.Operation', moment_index: int): + return protocols.decompose( + op, + intercepting_decomposer=lambda o: decomposer(o, moment_index), + keep=gateset.validate if gateset else None, + on_stuck_raise=( + None + if ignore_failures or gateset is None + else _create_on_stuck_raise_error(gateset) + ), + ) + + return transformer_primitives.map_operations_and_unroll( + circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else () + ).unfreeze(copy=False) + + +@transformer_api.transformer +def convert_to_target_gateset( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + gateset: Optional['cirq.CompilationTargetGateset'] = None, + ignore_failures: bool = True, +) -> 'cirq.Circuit': + """Transforms the given circuit into an equivalent circuit using gates accepted by `gateset`. + + 1. Run all `gateset.preprocess_transformers` + 2. Convert operations using built-in cirq decompose + `gateset.decompose_to_target_gateset`. + 3. Run all `gateset.postprocess_transformers` + + Args: + circuit: Input circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options for transformers. + gateset: Target gateset, which should be an instance of `cirq.CompilationTargetGateset`. + ignore_failures: If set, operations that fail to convert are left unchanged. If not set, + conversion failures raise a TypeError. + + Returns: + An equivalent circuit containing gates accepted by `gateset`. + + Raises: + TypeError: If any input operation fails to convert and `ignore_failures` is False. + """ + if gateset is None: + return decompose_operations_to_target_gateset( + circuit, context=context, ignore_failures=ignore_failures + ) + + for transformer in gateset.preprocess_transformers: + circuit = transformer(circuit, context=context) + + circuit = decompose_operations_to_target_gateset( + circuit, + context=context, + gateset=gateset, + decomposer=gateset.decompose_to_target_gateset, + ignore_failures=ignore_failures, + ) + + for transformer in gateset.postprocess_transformers: + circuit = transformer(circuit, context=context) + + return circuit.unfreeze(copy=False) diff --git a/cirq-core/cirq/transformers/convert_to_target_gateset_test.py b/cirq-core/cirq/transformers/convert_to_target_gateset_test.py new file mode 100644 index 00000000000..c3f99ae389f --- /dev/null +++ b/cirq-core/cirq/transformers/convert_to_target_gateset_test.py @@ -0,0 +1,200 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file + +import cirq +from cirq.protocols.decompose_protocol import DecomposeResult +import pytest + + +def test_decompose_operations_raises_on_stuck(): + c_orig = cirq.Circuit(cirq.X(cirq.NamedQubit("q")).with_tags("ignore")) + gateset = cirq.Gateset(cirq.Y) + with pytest.raises(ValueError, match="Unable to convert"): + _ = cirq.decompose_operations_to_target_gateset( + c_orig, gateset=gateset, ignore_failures=False + ) + + # Gates marked with a no-compile tag are completely ignored. + c_new = cirq.decompose_operations_to_target_gateset( + c_orig, + context=cirq.TransformerContext(tags_to_ignore=("ignore",)), + gateset=gateset, + ignore_failures=False, + ) + cirq.testing.assert_same_circuits(c_orig, c_new) + + +def test_decompose_operations_to_target_gateset_default(): + q = cirq.LineQubit.range(2) + c_orig = cirq.Circuit( + cirq.T(q[0]), + cirq.SWAP(*q), + cirq.T(q[0]), + cirq.SWAP(*q).with_tags("ignore"), + cirq.measure(q[0], key="m"), + cirq.X(q[1]).with_classical_controls("m"), + cirq.Moment(cirq.T.on_each(*q)), + cirq.SWAP(*q), + cirq.T.on_each(*q), + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───T───×───T───×['ignore']───M───────T───×───T─── + │ │ ║ │ +1: ───────×───────×─────────────╫───X───T───×───T─── + ║ ║ +m: ═════════════════════════════@═══^═══════════════''', + ) + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = cirq.decompose_operations_to_target_gateset(c_orig, context=context) + cirq.testing.assert_has_diagram( + c_new, + ''' +0: ───T────────────@───Y^-0.5───@───Y^0.5────@───────────T───×['ignore']───M───────T────────────@───Y^-0.5───@───Y^0.5────@───────────T─── + │ │ │ │ ║ │ │ │ +1: ───────Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───────×─────────────╫───X───T───Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───T─── + ║ ║ +m: ════════════════════════════════════════════════════════════════════════@═══^══════════════════════════════════════════════════════════ +''', + ) + + +def test_decompose_operations_to_target_gateset(): + q = cirq.LineQubit.range(2) + c_orig = cirq.Circuit( + cirq.T(q[0]), + cirq.SWAP(*q), + cirq.T(q[0]), + cirq.SWAP(*q).with_tags("ignore"), + cirq.measure(q[0], key="m"), + cirq.X(q[1]).with_classical_controls("m"), + cirq.Moment(cirq.T.on_each(*q)), + cirq.SWAP(*q), + cirq.T.on_each(*q), + ) + gateset = cirq.Gateset(cirq.H, cirq.CNOT) + decomposer = ( + lambda op, _: cirq.H(op.qubits[0]) + if cirq.has_unitary(op) and cirq.num_qubits(op) == 1 + else NotImplemented + ) + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = cirq.decompose_operations_to_target_gateset( + c_orig, gateset=gateset, decomposer=decomposer, context=context + ) + cirq.testing.assert_has_diagram( + c_new, + ''' +0: ───H───@───X───@───H───×['ignore']───M───────H───@───X───@───H─── + │ │ │ │ ║ │ │ │ +1: ───────X───@───X───────×─────────────╫───X───H───X───@───X───H─── + ║ ║ +m: ═════════════════════════════════════@═══^═══════════════════════''', + ) + + with pytest.raises(ValueError, match="Unable to convert"): + _ = cirq.decompose_operations_to_target_gateset( + c_orig, gateset=gateset, decomposer=decomposer, context=context, ignore_failures=False + ) + + +class MatrixGateTargetGateset(cirq.CompilationTargetGateset): + def __init__(self): + super().__init__(cirq.MatrixGate) + + @property + def num_qubits(self) -> int: + return 2 + + def decompose_to_target_gateset(self, op: 'cirq.Operation', _) -> DecomposeResult: + if cirq.num_qubits(op) != 2 or not cirq.has_unitary(op): + return NotImplemented + return cirq.MatrixGate(cirq.unitary(op), name="M").on(*op.qubits) + + +def test_convert_to_target_gateset_default(): + q = cirq.LineQubit.range(2) + c_orig = cirq.Circuit( + cirq.T(q[0]), + cirq.SWAP(*q), + cirq.T(q[0]), + cirq.SWAP(*q).with_tags("ignore"), + ) + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = cirq.convert_to_target_gateset(c_orig, context=context) + cirq.testing.assert_has_diagram( + c_new, + ''' +0: ───T────────────@───Y^-0.5───@───Y^0.5────@───────────T───×['ignore']─── + │ │ │ │ +1: ───────Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───────×───────────── +''', + ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6) + + +def test_convert_to_target_gateset(): + q = cirq.LineQubit.range(4) + c_orig = cirq.Circuit( + cirq.QuantumFourierTransformGate(4).on(*q), + cirq.Y(q[0]).with_tags("ignore"), + cirq.Y(q[1]).with_tags("ignore"), + cirq.CNOT(*q[2:]).with_tags("ignore"), + cirq.measure(*q[:2], key="m"), + cirq.CZ(*q[2:]).with_classical_controls("m"), + cirq.inverse(cirq.QuantumFourierTransformGate(4).on(*q)), + ) + + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───qft───Y['ignore']───M───────qft^-1─── + │ ║ │ +1: ───#2────Y['ignore']───M───────#2─────── + │ ║ │ +2: ───#3────@['ignore']───╫───@───#3─────── + │ │ ║ ║ │ +3: ───#4────X─────────────╫───@───#4─────── + ║ ║ +m: ═══════════════════════@═══^════════════ +''', + ) + gateset = MatrixGateTargetGateset() + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = cirq.convert_to_target_gateset(c_orig, gateset=gateset, context=context) + cirq.testing.assert_has_diagram( + c_new, + ''' + ┌────────┐ ┌────────┐ ┌────────┐ +0: ───M[1]──────────M[1]──────────────────────M[1]────Y['ignore']───M────────M[1]───────────────────────────M[1]────M[1]───M[1]─── + │ │ │ ║ │ │ │ │ +1: ───M[2]───M[1]───┼─────────────M[1]────M[1]┼───────Y['ignore']───M────────┼───M[1]───────────M[1]────M[1]┼───────┼──────M[2]─── + │ │ │ │ │ ║ │ │ │ │ │ │ +2: ──────────M[2]───M[2]───M[1]───┼───────M[2]┼───────@['ignore']───╫───@────┼───M[2]────M[1]───┼───────M[2]┼───────M[2]────────── + │ │ │ │ ║ ║ │ │ │ │ +3: ────────────────────────M[2]───M[2]────────M[2]────X─────────────╫───@────M[2]────────M[2]───M[2]────────M[2]────────────────── + ║ ║ +m: ═════════════════════════════════════════════════════════════════@═══^═════════════════════════════════════════════════════════ + └────────┘ └────────┘ └────────┘ + ''', + ) + + with pytest.raises(ValueError, match="Unable to convert"): + # Raises an error due to CCO and Measurement gate, which are not part of the gateset. + _ = cirq.convert_to_target_gateset( + c_orig, gateset=gateset, context=context, ignore_failures=False + ) diff --git a/cirq-core/cirq/transformers/target_gatesets/__init__.py b/cirq-core/cirq/transformers/target_gatesets/__init__.py new file mode 100644 index 00000000000..567c2d7c2f5 --- /dev/null +++ b/cirq-core/cirq/transformers/target_gatesets/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gatesets which can act as compilation targets in Cirq.""" + +from cirq.transformers.target_gatesets.compilation_target_gateset import CompilationTargetGateset diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py new file mode 100644 index 00000000000..d588990f89d --- /dev/null +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py @@ -0,0 +1,98 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for creating custom target gatesets which can be used for compilation.""" + +from typing import Optional, List, Hashable, TYPE_CHECKING +import abc + +from cirq import ops, protocols, _import +from cirq.protocols.decompose_protocol import DecomposeResult +from cirq.transformers import ( + merge_k_qubit_gates, + merge_single_qubit_gates, +) + +drop_empty_moments = _import.LazyLoader('drop_empty_moments', globals(), 'cirq.transformers') +drop_negligible = _import.LazyLoader('drop_negligible_operations', globals(), 'cirq.transformers') +expand_composite = _import.LazyLoader('expand_composite', globals(), 'cirq.transformers') + +if TYPE_CHECKING: + import cirq + + +def _create_transformer_with_kwargs(func: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER': + """Hack to capture additional keyword arguments to transformers while preserving mypy type.""" + + def transformer( + circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None + ) -> 'cirq.AbstractCircuit': + return func(circuit, context=context, **kwargs) # type: ignore + + return transformer + + +class CompilationTargetGateset(ops.Gateset, metaclass=abc.ABCMeta): + """Abstract base class to create gatesets that can be used as targets for compilation. + + An instance of this type can be passed to transformers like `cirq.convert_to_target_gateset`, + which can transform any given circuit to contain gates accepted by this gateset. + """ + + @property + @abc.abstractmethod + def num_qubits(self) -> int: + """Maximum number of qubits on which a gate from this gateset can act upon.""" + + @abc.abstractmethod + def decompose_to_target_gateset(self, op: 'cirq.Operation', moment_idx: int) -> DecomposeResult: + """Method to rewrite the given operation using gates from this gateset. + + Args: + op: `cirq.Operation` to be rewritten using gates from this gateset. + moment_idx: Moment index where the given operation `op` occurs in a circuit. + + Returns: + - An equivalent `cirq.OP_TREE` implementing `op` using gates from this gateset. + - `None` or `NotImplemented` if does not know how to decompose `op`. + """ + + @property + def _intermediate_result_tag(self) -> Hashable: + """A tag used to identify intermediate compilation results.""" + return "_default_merged_k_qubit_unitaries" + + @property + def preprocess_transformers(self) -> List['cirq.TRANSFORMER']: + """List of transformers which should be run before decomposing individual operations.""" + return [ + _create_transformer_with_kwargs( + expand_composite.expand_composite, + no_decomp=lambda op: protocols.num_qubits(op) <= self.num_qubits, + ), + _create_transformer_with_kwargs( + merge_k_qubit_gates.merge_k_qubit_unitaries, + k=self.num_qubits, + rewriter=lambda op: op.with_tags(self._intermediate_result_tag), + ), + ] + + @property + def postprocess_transformers(self) -> List['cirq.TRANSFORMER']: + """List of transformers which should be run after decomposing individual operations.""" + return [ + merge_single_qubit_gates.merge_single_qubit_moments_to_phxz, + drop_negligible.drop_negligible_operations, + drop_empty_moments.drop_empty_moments, + ] diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py new file mode 100644 index 00000000000..0ede8d5568b --- /dev/null +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py @@ -0,0 +1,53 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +import cirq +from cirq.protocols.decompose_protocol import DecomposeResult + + +def test_compilation_target_gateset(): + class DummyTargetGateset(cirq.CompilationTargetGateset): + def __init__(self): + super().__init__(cirq.AnyUnitaryGateFamily(2)) + + @property + def num_qubits(self) -> int: + return 2 + + def decompose_to_target_gateset(self, op: 'cirq.Operation', _) -> DecomposeResult: + return op if cirq.num_qubits(op) == 2 and cirq.has_unitary(op) else NotImplemented + + @property + def preprocess_transformers(self) -> List[cirq.TRANSFORMER]: + return [] + + gateset = DummyTargetGateset() + + q = cirq.LineQubit.range(2) + assert cirq.X(q[0]) not in gateset + assert cirq.CNOT(*q) in gateset + assert cirq.measure(*q) not in gateset + + assert gateset.num_qubits == 2 + assert gateset.decompose_to_target_gateset(cirq.X(q[0]), 1) is NotImplemented + assert gateset.decompose_to_target_gateset(cirq.CNOT(*q), 2) == cirq.CNOT(*q) + assert gateset.decompose_to_target_gateset(cirq.measure(*q), 3) is NotImplemented + + assert gateset.preprocess_transformers == [] + assert gateset.postprocess_transformers == [ + cirq.merge_single_qubit_moments_to_phxz, + cirq.drop_negligible_operations, + cirq.drop_empty_moments, + ] From bd3909864e5fd7f34f03e8575ade7b994d61a5c1 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 17 Feb 2022 23:44:43 -0800 Subject: [PATCH 2/6] Address nits --- cirq-core/cirq/transformers/convert_to_target_gateset.py | 9 +++++---- .../cirq/transformers/convert_to_target_gateset_test.py | 3 +-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/transformers/convert_to_target_gateset.py b/cirq-core/cirq/transformers/convert_to_target_gateset.py index ecc63ece184..8a9b2d16912 100644 --- a/cirq-core/cirq/transformers/convert_to_target_gateset.py +++ b/cirq-core/cirq/transformers/convert_to_target_gateset.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Transformers to rewrite a circuit using gates from a given target gateset.""" + from typing import Optional, Callable, TYPE_CHECKING -from cirq import protocols +from cirq.protocols import decompose_protocol as dp from cirq.transformers import transformer_api, transformer_primitives -from cirq.protocols.decompose_protocol import DecomposeResult if TYPE_CHECKING: import cirq @@ -35,7 +36,7 @@ def decompose_operations_to_target_gateset( *, context: Optional['cirq.TransformerContext'] = None, gateset: Optional['cirq.Gateset'] = None, - decomposer: Callable[['cirq.Operation', int], DecomposeResult] = lambda *_: NotImplemented, + decomposer: Callable[['cirq.Operation', int], dp.DecomposeResult] = lambda *_: NotImplemented, ignore_failures=True, ) -> 'cirq.Circuit': """Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`. @@ -64,7 +65,7 @@ def decompose_operations_to_target_gateset( """ def map_func(op: 'cirq.Operation', moment_index: int): - return protocols.decompose( + return dp.decompose( op, intercepting_decomposer=lambda o: decomposer(o, moment_index), keep=gateset.validate if gateset else None, diff --git a/cirq-core/cirq/transformers/convert_to_target_gateset_test.py b/cirq-core/cirq/transformers/convert_to_target_gateset_test.py index c3f99ae389f..3b31a314f08 100644 --- a/cirq-core/cirq/transformers/convert_to_target_gateset_test.py +++ b/cirq-core/cirq/transformers/convert_to_target_gateset_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: skip-file - import cirq from cirq.protocols.decompose_protocol import DecomposeResult import pytest @@ -37,6 +35,7 @@ def test_decompose_operations_raises_on_stuck(): cirq.testing.assert_same_circuits(c_orig, c_new) +# pylint: disable=line-too-long def test_decompose_operations_to_target_gateset_default(): q = cirq.LineQubit.range(2) c_orig = cirq.Circuit( From a6596f442a5e76fa891b7aad796e00ca9504fb8a Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 18 Feb 2022 10:56:14 -0800 Subject: [PATCH 3/6] Override validation_operation to not accept intermediate results --- .../compilation_target_gateset.py | 19 +++++++++++++++++++ .../compilation_target_gateset_test.py | 3 +++ 2 files changed, 22 insertions(+) diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py index d588990f89d..88edbb20cfd 100644 --- a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py @@ -68,6 +68,25 @@ def decompose_to_target_gateset(self, op: 'cirq.Operation', moment_idx: int) -> - `None` or `NotImplemented` if does not know how to decompose `op`. """ + def _validate_operation(self, op: 'cirq.Operation') -> bool: + """Validates whether the given `cirq.Operation` is contained in this Gateset. + + Overrides the method on the base gateset class to ensure that operations which created + as intermediate compilation results are not accepted. + For example, if a preprocessing `merge_k_qubit_unitaries` transformer merges connected + component of 2q unitaries, it should not be accepted in the gateset so that so we can + use `decompose_to_target_gateset` to determine how to expand this component. + + Args: + op: The `cirq.Operation` instance to check containment for. + + Returns: + Whether the given operation is contained in the gateset. + """ + if self._intermediate_result_tag in op.tags: + return False + return super()._validate_operation(op) + @property def _intermediate_result_tag(self) -> Hashable: """A tag used to identify intermediate compilation results.""" diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py index 0ede8d5568b..e2c56bb9ba2 100644 --- a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py @@ -39,6 +39,9 @@ def preprocess_transformers(self) -> List[cirq.TRANSFORMER]: assert cirq.X(q[0]) not in gateset assert cirq.CNOT(*q) in gateset assert cirq.measure(*q) not in gateset + circuit_op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CZ(*q), cirq.CNOT(*q), cirq.CZ(*q))) + assert circuit_op in gateset + assert circuit_op.with_tags(gateset._intermediate_result_tag) not in gateset assert gateset.num_qubits == 2 assert gateset.decompose_to_target_gateset(cirq.X(q[0]), 1) is NotImplemented From 13786d73455172db825bec267263e7c4b0f3ca03 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 18 Feb 2022 11:54:55 -0800 Subject: [PATCH 4/6] Rename transformers --- cirq-core/cirq/__init__.py | 3 +-- cirq-core/cirq/transformers/__init__.py | 5 +--- ...eset.py => optimize_for_target_gateset.py} | 10 ++++---- ...py => optimize_for_target_gateset_test.py} | 23 +++++++++---------- 4 files changed, 18 insertions(+), 23 deletions(-) rename cirq-core/cirq/transformers/{convert_to_target_gateset.py => optimize_for_target_gateset.py} (95%) rename cirq-core/cirq/transformers/{convert_to_target_gateset_test.py => optimize_for_target_gateset_test.py} (93%) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 2c868e4454f..db8d46dbaea 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -357,12 +357,10 @@ align_right, CompilationTargetGateset, compute_cphase_exponents_for_fsim_decomposition, - convert_to_target_gateset, decompose_clifford_tableau_to_operations, decompose_cphase_into_two_fsim, decompose_multi_controlled_x, decompose_multi_controlled_rotation, - decompose_operations_to_target_gateset, decompose_two_qubit_interaction_into_four_fsim_gates, defer_measurements, dephase_measurements, @@ -383,6 +381,7 @@ merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz, merge_single_qubit_moments_to_phxz, + optimize_for_target_gateset, prepare_two_qubit_state_using_cz, prepare_two_qubit_state_using_sqrt_iswap, single_qubit_matrix_to_gates, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index cc7cf44af3e..959a1a7fa1a 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -53,10 +53,7 @@ from cirq.transformers.eject_phased_paulis import eject_phased_paulis -from cirq.transformers.convert_to_target_gateset import ( - convert_to_target_gateset, - decompose_operations_to_target_gateset, -) +from cirq.transformers.optimize_for_target_gateset import optimize_for_target_gateset from cirq.transformers.drop_empty_moments import drop_empty_moments diff --git a/cirq-core/cirq/transformers/convert_to_target_gateset.py b/cirq-core/cirq/transformers/optimize_for_target_gateset.py similarity index 95% rename from cirq-core/cirq/transformers/convert_to_target_gateset.py rename to cirq-core/cirq/transformers/optimize_for_target_gateset.py index 8a9b2d16912..1ae40afeedc 100644 --- a/cirq-core/cirq/transformers/convert_to_target_gateset.py +++ b/cirq-core/cirq/transformers/optimize_for_target_gateset.py @@ -31,13 +31,13 @@ def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError: @transformer_api.transformer -def decompose_operations_to_target_gateset( +def _decompose_operations_to_target_gateset( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, gateset: Optional['cirq.Gateset'] = None, decomposer: Callable[['cirq.Operation', int], dp.DecomposeResult] = lambda *_: NotImplemented, - ignore_failures=True, + ignore_failures: bool = True, ) -> 'cirq.Circuit': """Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`. @@ -82,7 +82,7 @@ def map_func(op: 'cirq.Operation', moment_index: int): @transformer_api.transformer -def convert_to_target_gateset( +def optimize_for_target_gateset( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, @@ -109,14 +109,14 @@ def convert_to_target_gateset( TypeError: If any input operation fails to convert and `ignore_failures` is False. """ if gateset is None: - return decompose_operations_to_target_gateset( + return _decompose_operations_to_target_gateset( circuit, context=context, ignore_failures=ignore_failures ) for transformer in gateset.preprocess_transformers: circuit = transformer(circuit, context=context) - circuit = decompose_operations_to_target_gateset( + circuit = _decompose_operations_to_target_gateset( circuit, context=context, gateset=gateset, diff --git a/cirq-core/cirq/transformers/convert_to_target_gateset_test.py b/cirq-core/cirq/transformers/optimize_for_target_gateset_test.py similarity index 93% rename from cirq-core/cirq/transformers/convert_to_target_gateset_test.py rename to cirq-core/cirq/transformers/optimize_for_target_gateset_test.py index 3b31a314f08..e923ceb60db 100644 --- a/cirq-core/cirq/transformers/convert_to_target_gateset_test.py +++ b/cirq-core/cirq/transformers/optimize_for_target_gateset_test.py @@ -14,6 +14,7 @@ import cirq from cirq.protocols.decompose_protocol import DecomposeResult +from cirq.transformers.optimize_for_target_gateset import _decompose_operations_to_target_gateset import pytest @@ -21,12 +22,10 @@ def test_decompose_operations_raises_on_stuck(): c_orig = cirq.Circuit(cirq.X(cirq.NamedQubit("q")).with_tags("ignore")) gateset = cirq.Gateset(cirq.Y) with pytest.raises(ValueError, match="Unable to convert"): - _ = cirq.decompose_operations_to_target_gateset( - c_orig, gateset=gateset, ignore_failures=False - ) + _ = _decompose_operations_to_target_gateset(c_orig, gateset=gateset, ignore_failures=False) # Gates marked with a no-compile tag are completely ignored. - c_new = cirq.decompose_operations_to_target_gateset( + c_new = _decompose_operations_to_target_gateset( c_orig, context=cirq.TransformerContext(tags_to_ignore=("ignore",)), gateset=gateset, @@ -59,7 +58,7 @@ def test_decompose_operations_to_target_gateset_default(): m: ═════════════════════════════@═══^═══════════════''', ) context = cirq.TransformerContext(tags_to_ignore=("ignore",)) - c_new = cirq.decompose_operations_to_target_gateset(c_orig, context=context) + c_new = _decompose_operations_to_target_gateset(c_orig, context=context) cirq.testing.assert_has_diagram( c_new, ''' @@ -92,7 +91,7 @@ def test_decompose_operations_to_target_gateset(): else NotImplemented ) context = cirq.TransformerContext(tags_to_ignore=("ignore",)) - c_new = cirq.decompose_operations_to_target_gateset( + c_new = _decompose_operations_to_target_gateset( c_orig, gateset=gateset, decomposer=decomposer, context=context ) cirq.testing.assert_has_diagram( @@ -106,7 +105,7 @@ def test_decompose_operations_to_target_gateset(): ) with pytest.raises(ValueError, match="Unable to convert"): - _ = cirq.decompose_operations_to_target_gateset( + _ = _decompose_operations_to_target_gateset( c_orig, gateset=gateset, decomposer=decomposer, context=context, ignore_failures=False ) @@ -125,7 +124,7 @@ def decompose_to_target_gateset(self, op: 'cirq.Operation', _) -> DecomposeResul return cirq.MatrixGate(cirq.unitary(op), name="M").on(*op.qubits) -def test_convert_to_target_gateset_default(): +def test_optimize_for_target_gateset_default(): q = cirq.LineQubit.range(2) c_orig = cirq.Circuit( cirq.T(q[0]), @@ -134,7 +133,7 @@ def test_convert_to_target_gateset_default(): cirq.SWAP(*q).with_tags("ignore"), ) context = cirq.TransformerContext(tags_to_ignore=("ignore",)) - c_new = cirq.convert_to_target_gateset(c_orig, context=context) + c_new = cirq.optimize_for_target_gateset(c_orig, context=context) cirq.testing.assert_has_diagram( c_new, ''' @@ -146,7 +145,7 @@ def test_convert_to_target_gateset_default(): cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6) -def test_convert_to_target_gateset(): +def test_optimize_for_target_gateset(): q = cirq.LineQubit.range(4) c_orig = cirq.Circuit( cirq.QuantumFourierTransformGate(4).on(*q), @@ -174,7 +173,7 @@ def test_convert_to_target_gateset(): ) gateset = MatrixGateTargetGateset() context = cirq.TransformerContext(tags_to_ignore=("ignore",)) - c_new = cirq.convert_to_target_gateset(c_orig, gateset=gateset, context=context) + c_new = cirq.optimize_for_target_gateset(c_orig, gateset=gateset, context=context) cirq.testing.assert_has_diagram( c_new, ''' @@ -194,6 +193,6 @@ def test_convert_to_target_gateset(): with pytest.raises(ValueError, match="Unable to convert"): # Raises an error due to CCO and Measurement gate, which are not part of the gateset. - _ = cirq.convert_to_target_gateset( + _ = cirq.optimize_for_target_gateset( c_orig, gateset=gateset, context=context, ignore_failures=False ) From 9cf8908fb3ca36cccbc10dc9d4ba591f094c38e6 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 18 Feb 2022 11:57:07 -0800 Subject: [PATCH 5/6] Fix typo --- cirq-core/cirq/transformers/optimize_for_target_gateset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/transformers/optimize_for_target_gateset.py b/cirq-core/cirq/transformers/optimize_for_target_gateset.py index 1ae40afeedc..b533081c87f 100644 --- a/cirq-core/cirq/transformers/optimize_for_target_gateset.py +++ b/cirq-core/cirq/transformers/optimize_for_target_gateset.py @@ -61,7 +61,7 @@ def _decompose_operations_to_target_gateset( An equivalent circuit containing gates accepted by `gateset`. Raises: - TypeError: If any input operation fails to convert and `ignore_failures` is False. + ValueError: If any input operation fails to convert and `ignore_failures` is False. """ def map_func(op: 'cirq.Operation', moment_index: int): @@ -100,13 +100,13 @@ def optimize_for_target_gateset( context: `cirq.TransformerContext` storing common configurable options for transformers. gateset: Target gateset, which should be an instance of `cirq.CompilationTargetGateset`. ignore_failures: If set, operations that fail to convert are left unchanged. If not set, - conversion failures raise a TypeError. + conversion failures raise a ValueError. Returns: An equivalent circuit containing gates accepted by `gateset`. Raises: - TypeError: If any input operation fails to convert and `ignore_failures` is False. + ValueError: If any input operation fails to convert and `ignore_failures` is False. """ if gateset is None: return _decompose_operations_to_target_gateset( From cd2a0da6258ec09e554d5d8f6cd9f9c303f68722 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Fri, 18 Feb 2022 11:58:17 -0800 Subject: [PATCH 6/6] Typo --- cirq-core/cirq/transformers/optimize_for_target_gateset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/transformers/optimize_for_target_gateset.py b/cirq-core/cirq/transformers/optimize_for_target_gateset.py index b533081c87f..d028366db14 100644 --- a/cirq-core/cirq/transformers/optimize_for_target_gateset.py +++ b/cirq-core/cirq/transformers/optimize_for_target_gateset.py @@ -55,7 +55,7 @@ def _decompose_operations_to_target_gateset( - An equivalent `cirq.OP_TREE` implementing `op` using gates from `gateset`. - `None` or `NotImplemented` if does not know how to decompose a given `op`. ignore_failures: If set, operations that fail to convert are left unchanged. If not set, - conversion failures raise a TypeError. + conversion failures raise a ValueError. Returns: An equivalent circuit containing gates accepted by `gateset`.