From 22d720a10c85a6d3d89a5475fd9ceac8e2c82ae6 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Tue, 1 Mar 2022 00:47:15 +0530 Subject: [PATCH] Deprecate cirq.ConvertToSqrtIswapGates and cirq.MergeInteractionsToSqrtIswap (#5040) --- .../cirq/optimizers/merge_interactions.py | 4 + .../optimizers/merge_interactions_test.py | 32 ++++++-- .../merge_interactions_to_sqrt_iswap.py | 6 +- .../merge_interactions_to_sqrt_iswap_test.py | 74 +++++++++++++++---- .../optimizers/convert_to_sqrt_iswap.py | 4 + .../optimizers/convert_to_sqrt_iswap_test.py | 72 ++++++++++++++---- .../optimizers/optimize_for_sycamore.py | 32 ++++---- 7 files changed, 168 insertions(+), 56 deletions(-) diff --git a/cirq-core/cirq/optimizers/merge_interactions.py b/cirq-core/cirq/optimizers/merge_interactions.py index 37acab314eb..37c567e053f 100644 --- a/cirq-core/cirq/optimizers/merge_interactions.py +++ b/cirq-core/cirq/optimizers/merge_interactions.py @@ -26,6 +26,10 @@ import cirq +@_compat.deprecated_class( + deadline='v1.0', + fix='Use cirq.optimize_for_target_gateset and cirq.CompilationTargetGateset instead.', +) class MergeInteractionsAbc(circuits.PointOptimizer, metaclass=abc.ABCMeta): """Combines series of adjacent one- and two-qubit, non-parametrized gates operating on a pair of qubits.""" diff --git a/cirq-core/cirq/optimizers/merge_interactions_test.py b/cirq-core/cirq/optimizers/merge_interactions_test.py index 2c0db675a15..d508c50a5fd 100644 --- a/cirq-core/cirq/optimizers/merge_interactions_test.py +++ b/cirq-core/cirq/optimizers/merge_interactions_test.py @@ -22,7 +22,9 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit): actual = cirq.Circuit(before) - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): opt = cirq.MergeInteractions() opt.optimize_circuit(actual) @@ -46,7 +48,9 @@ def assert_optimization_not_broken(circuit): global phase and rounding error) as the unitary matrix of the optimized circuit.""" u_before = circuit.unitary() - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): cirq.MergeInteractions().optimize_circuit(circuit) u_after = circuit.unitary() @@ -161,7 +165,9 @@ def test_optimizes_single_iswap(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.ISWAP(a, b)) assert_optimization_not_broken(c) - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): cirq.MergeInteractions().optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 @@ -170,7 +176,9 @@ def test_optimizes_tagged_partial_cz(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit((cirq.CZ ** 0.5)(a, b).with_tags('mytag')) assert_optimization_not_broken(c) - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): cirq.MergeInteractions(allow_partial_czs=False).optimize_circuit(c) assert ( len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 @@ -182,7 +190,9 @@ def test_not_decompose_czs(): cirq.CZPowGate(exponent=1, global_shift=-0.5).on(*cirq.LineQubit.range(2)) ) circ_orig = circuit.copy() - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): cirq.MergeInteractions(allow_partial_czs=False).optimize_circuit(circuit) assert circ_orig == circuit @@ -200,7 +210,9 @@ def test_not_decompose_czs(): ), ) def test_decompose_partial_czs(circuit): - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): optimizer = cirq.MergeInteractions(allow_partial_czs=False) optimizer.optimize_circuit(circuit) @@ -219,7 +231,9 @@ def test_not_decompose_partial_czs(): circuit = cirq.Circuit( cirq.CZPowGate(exponent=0.1, global_shift=-0.5)(*cirq.LineQubit.range(2)), ) - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): optimizer = cirq.MergeInteractions(allow_partial_czs=True) optimizer.optimize_circuit(circuit) @@ -253,7 +267,9 @@ def clean_up(operations): yield operations yield Marker()(a, b) - with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): optimizer = cirq.MergeInteractions(allow_partial_czs=False, post_clean_up=clean_up) optimizer.optimize_circuit(circuit) circuit = cirq.drop_empty_moments(circuit) diff --git a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap.py b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap.py index 06213b1a121..f58d5ec21e8 100644 --- a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap.py +++ b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap.py @@ -19,7 +19,7 @@ import numpy as np -from cirq import ops +from cirq import ops, _compat from cirq.optimizers import merge_interactions from cirq.transformers.analytical_decompositions import two_qubit_to_sqrt_iswap @@ -27,6 +27,10 @@ import cirq +@_compat.deprecated_class( + deadline='v1.0', + fix='Use cirq.optimize_for_target_gateset and cirq.SqrtIswapTargetGateset instead.', +) class MergeInteractionsToSqrtIswap(merge_interactions.MergeInteractionsAbc): """Combines series of adjacent one- and two-qubit, non-parametrized gates operating on a pair of qubits and replaces each series with the minimum diff --git a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py index 1e6316553c7..26e2f7b60b8 100644 --- a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py +++ b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py @@ -33,8 +33,11 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs): ``MergeInteractionsToSqrtIswap`` constructor. """ actual = before.copy() - opt = cirq.MergeInteractionsToSqrtIswap(**kwargs) - opt.optimize_circuit(actual) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + opt = cirq.MergeInteractionsToSqrtIswap(**kwargs) + opt.optimize_circuit(actual) # Ignore differences that would be caught by follow-up optimizations. followup_transformers: List[cirq.TRANSFORMER] = [ @@ -57,7 +60,10 @@ def assert_optimization_not_broken(circuit: cirq.Circuit, **kwargs): circuit.""" u_before = circuit.unitary(sorted(circuit.all_qubits())) c_sqrt_iswap = circuit.copy() - cirq.MergeInteractionsToSqrtIswap(**kwargs).optimize_circuit(c_sqrt_iswap) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(**kwargs).optimize_circuit(c_sqrt_iswap) u_after = c_sqrt_iswap.unitary(sorted(circuit.all_qubits())) # Not 1e-8 because of some unaccounted accumulated error in some of Cirq's linalg functions @@ -65,7 +71,12 @@ def assert_optimization_not_broken(circuit: cirq.Circuit, **kwargs): # Also test optimization with SQRT_ISWAP_INV c_sqrt_iswap_inv = circuit.copy() - cirq.MergeInteractionsToSqrtIswap(use_sqrt_iswap_inv=True).optimize_circuit(c_sqrt_iswap_inv) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(use_sqrt_iswap_inv=True).optimize_circuit( + c_sqrt_iswap_inv + ) u_after2 = c_sqrt_iswap_inv.unitary(sorted(circuit.all_qubits())) cirq.testing.assert_allclose_up_to_global_phase(u_before, u_after2, atol=1e-6) @@ -230,7 +241,10 @@ def test_optimizes_single_iswap(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.ISWAP(a, b)) assert_optimization_not_broken(c) - cirq.MergeInteractionsToSqrtIswap().optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap().optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 @@ -238,20 +252,29 @@ def test_optimizes_single_inv_sqrt_iswap(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) assert_optimization_not_broken(c) - cirq.MergeInteractionsToSqrtIswap().optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap().optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 1 def test_init_raises(): with pytest.raises(ValueError, match='must be 0, 1, 2, or 3'): - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=4) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=4) def test_optimizes_single_iswap_require0(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.CNOT(a, b), cirq.CNOT(a, b)) # Minimum 0 sqrt-iSWAP assert_optimization_not_broken(c, required_sqrt_iswap_count=0) - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=0).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=0).optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 0 @@ -259,14 +282,20 @@ def test_optimizes_single_iswap_require0_raises(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.CNOT(a, b)) # Minimum 2 sqrt-iSWAP with pytest.raises(ValueError, match='cannot be decomposed into exactly 0 sqrt-iSWAP gates'): - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=0).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=0).optimize_circuit(c) def test_optimizes_single_iswap_require1(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) # Minimum 1 sqrt-iSWAP assert_optimization_not_broken(c, required_sqrt_iswap_count=1) - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=1).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=1).optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 1 @@ -274,14 +303,20 @@ def test_optimizes_single_iswap_require1_raises(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.CNOT(a, b)) # Minimum 2 sqrt-iSWAP with pytest.raises(ValueError, match='cannot be decomposed into exactly 1 sqrt-iSWAP gates'): - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=1).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=1).optimize_circuit(c) def test_optimizes_single_iswap_require2(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) # Minimum 1 sqrt-iSWAP but 2 possible assert_optimization_not_broken(c, required_sqrt_iswap_count=2) - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=2).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=2).optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2 @@ -289,14 +324,20 @@ def test_optimizes_single_iswap_require2_raises(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.SWAP(a, b)) # Minimum 3 sqrt-iSWAP with pytest.raises(ValueError, match='cannot be decomposed into exactly 2 sqrt-iSWAP gates'): - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=2).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=2).optimize_circuit(c) def test_optimizes_single_iswap_require3(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.ISWAP(a, b)) # Minimum 2 sqrt-iSWAP but 3 possible assert_optimization_not_broken(c, required_sqrt_iswap_count=3) - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=3).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=3).optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 3 @@ -304,5 +345,8 @@ def test_optimizes_single_inv_sqrt_iswap_require3(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.SQRT_ISWAP_INV(a, b)) assert_optimization_not_broken(c, required_sqrt_iswap_count=3) - cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=3).optimize_circuit(c) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0', count=2 + ): + cirq.MergeInteractionsToSqrtIswap(required_sqrt_iswap_count=3).optimize_circuit(c) assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 3 diff --git a/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap.py b/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap.py index b9e2feccf37..199b752d999 100644 --- a/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap.py +++ b/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap.py @@ -29,6 +29,10 @@ def _near_mod_2pi(e, t, atol=1e-8): return _near_mod_n(e, t, 2 * np.pi, atol=atol) +@cirq._compat.deprecated_class( + deadline='v1.0', + fix='Use cirq.optimize_for_target_gateset and cirq.SqrtIswapTargetGateset instead.', +) class ConvertToSqrtIswapGates(cirq.PointOptimizer): """Attempts to convert gates into ISWAP**-0.5 gates. diff --git a/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap_test.py b/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap_test.py index 763d3a8207e..ba6f69371ab 100644 --- a/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap_test.py +++ b/cirq-google/cirq_google/optimizers/convert_to_sqrt_iswap_test.py @@ -10,12 +10,10 @@ def _unitaries_allclose(circuit1, circuit2): - unitary1 = cirq.unitary(circuit1) - unitary2 = cirq.unitary(circuit2) - if unitary2.size == 1: - # Resize the unitary of empty circuits to be 4x4 for 2q gates - unitary2 = unitary2 * np.eye(unitary1.shape[0]) - return cirq.allclose_up_to_global_phase(unitary1, unitary2) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuit1, circuit2, atol=1e-6 + ) + return True @pytest.mark.parametrize( @@ -45,10 +43,25 @@ def test_two_qubit_gates(gate: cirq.Gate, expected_length: int): q1 = cirq.GridQubit(5, 4) original_circuit = cirq.Circuit(gate(q0, q1)) converted_circuit = original_circuit.copy() - cgoc.ConvertToSqrtIswapGates().optimize_circuit(converted_circuit) + converted_circuit_iswap_inv = cirq.optimize_for_target_gateset( + original_circuit, gateset=cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True) + ) + converted_circuit_iswap = cirq.optimize_for_target_gateset( + original_circuit, gateset=cirq.SqrtIswapTargetGateset() + ) + with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + cgoc.ConvertToSqrtIswapGates().optimize_circuit(converted_circuit) cig.SQRT_ISWAP_GATESET.serialize(converted_circuit) + cig.SQRT_ISWAP_GATESET.serialize(converted_circuit_iswap) + cig.SQRT_ISWAP_GATESET.serialize(converted_circuit_iswap_inv) assert len(converted_circuit) <= expected_length + assert ( + len(converted_circuit_iswap) <= expected_length + or len(converted_circuit_iswap_inv) <= expected_length + ) assert _unitaries_allclose(original_circuit, converted_circuit) + assert _unitaries_allclose(original_circuit, converted_circuit_iswap) + assert _unitaries_allclose(original_circuit, converted_circuit_iswap_inv) @pytest.mark.parametrize( @@ -70,8 +83,19 @@ def test_two_qubit_gates_with_symbols(gate: cirq.Gate, expected_length: int): q1 = cirq.GridQubit(5, 4) original_circuit = cirq.Circuit(gate(q0, q1)) converted_circuit = original_circuit.copy() - cgoc.ConvertToSqrtIswapGates().optimize_circuit(converted_circuit) + with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'): + cgoc.ConvertToSqrtIswapGates().optimize_circuit(converted_circuit) + converted_circuit_iswap_inv = cirq.optimize_for_target_gateset( + original_circuit, gateset=cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True) + ) + converted_circuit_iswap = cirq.optimize_for_target_gateset( + original_circuit, gateset=cirq.SqrtIswapTargetGateset() + ) assert len(converted_circuit) <= expected_length + assert ( + len(converted_circuit_iswap) <= expected_length + or len(converted_circuit_iswap_inv) <= expected_length + ) # Check if unitaries are the same for val in np.linspace(0, 2 * np.pi, 12): @@ -79,6 +103,14 @@ def test_two_qubit_gates_with_symbols(gate: cirq.Gate, expected_length: int): cirq.resolve_parameters(original_circuit, {'t': val}), cirq.resolve_parameters(converted_circuit, {'t': val}), ) + assert _unitaries_allclose( + cirq.resolve_parameters(original_circuit, {'t': val}), + cirq.resolve_parameters(converted_circuit_iswap, {'t': val}), + ) + assert _unitaries_allclose( + cirq.resolve_parameters(original_circuit, {'t': val}), + cirq.resolve_parameters(converted_circuit_iswap_inv, {'t': val}), + ) def test_cphase(): @@ -102,11 +134,22 @@ def test_givens_rotation(): program = cirq.Circuit(cirq.givens(theta).on(qubits[0], qubits[1])) unitary = cirq.unitary(program) test_program = program.copy() - cgoc.ConvertToSqrtIswapGates().optimize_circuit(test_program) - test_unitary = cirq.unitary(test_program) - np.testing.assert_allclose( - 4, np.abs(np.trace(np.conjugate(np.transpose(test_unitary)) @ unitary)) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0' + ): + cgoc.ConvertToSqrtIswapGates().optimize_circuit(test_program) + converted_circuit_iswap_inv = cirq.optimize_for_target_gateset( + test_program, gateset=cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True) + ) + converted_circuit_iswap = cirq.optimize_for_target_gateset( + test_program, gateset=cirq.SqrtIswapTargetGateset() ) + for circuit in [test_program, converted_circuit_iswap_inv, converted_circuit_iswap]: + circuit.append(cirq.IdentityGate(2).on(*qubits)) + test_unitary = cirq.unitary(circuit) + np.testing.assert_allclose( + 4, np.abs(np.trace(np.conjugate(np.transpose(test_unitary)) @ unitary)) + ) def test_three_qubit_gate(): @@ -119,4 +162,7 @@ class ThreeQubitGate(cirq.testing.ThreeQubitGate): circuit = cirq.Circuit(ThreeQubitGate()(q0, q1, q2)) with pytest.raises(TypeError): - cgoc.ConvertToSqrtIswapGates().optimize_circuit(circuit) + with cirq.testing.assert_deprecated( + "Use cirq.optimize_for_target_gateset", deadline='v1.0' + ): + cgoc.ConvertToSqrtIswapGates().optimize_circuit(circuit) diff --git a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py index 61e4bb024c0..ca8753f75da 100644 --- a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py +++ b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py @@ -22,7 +22,6 @@ from cirq_google.optimizers import ( convert_to_xmon_gates, ConvertToSycamoreGates, - ConvertToSqrtIswapGates, ) if TYPE_CHECKING: @@ -58,22 +57,18 @@ def _get_sycamore_optimizers( return [ConvertToSycamoreGates(tabulation=tabulation).optimize_circuit] -def _get_sqrt_iswap_optimizers( - tolerance: float, tabulation: Optional[cirq.TwoQubitGateTabulation] -) -> List[Callable[[cirq.Circuit], None]]: - if tabulation is not None: - # coverage: ignore - raise ValueError("Gate tabulation not supported for sqrt_iswap") - return [ConvertToSqrtIswapGates().optimize_circuit] - - _OPTIMIZER_TYPES = { 'xmon': _get_xmon_optimizers, 'xmon_partial_cz': _get_xmon_optimizers_part_cz, - 'sqrt_iswap': _get_sqrt_iswap_optimizers, 'sycamore': _get_sycamore_optimizers, } +_TARGET_GATESETS = { + 'sqrt_iswap': lambda atol, _: cirq.SqrtIswapTargetGateset(atol=atol), + 'xmon': lambda atol, _: cirq.CZTargetGateset(atol=atol), + 'xmon_partial_cz': lambda atol, _: cirq.CZTargetGateset(atol=atol, allow_partial_czs=True), +} + @lru_cache() def _gate_product_tabulation_cached( @@ -131,7 +126,7 @@ def optimized_for_sycamore( ValueError: If the `optimizer_type` is not a supported type. """ copy = circuit.copy() - if optimizer_type not in _OPTIMIZER_TYPES: + if optimizer_type not in _OPTIMIZER_TYPES and optimizer_type not in _TARGET_GATESETS: raise ValueError( f'{optimizer_type} is not an allowed type. Allowed ' f'types are: {_OPTIMIZER_TYPES.keys()}' @@ -141,16 +136,15 @@ def optimized_for_sycamore( if tabulation_resolution is not None: tabulation = _gate_product_tabulation_cached(optimizer_type, tabulation_resolution) - opts = _OPTIMIZER_TYPES[optimizer_type](tolerance=tolerance, tabulation=tabulation) - for optimizer in opts: - optimizer(copy) - if optimizer_type.startswith('xmon'): + if optimizer_type in _TARGET_GATESETS: copy = cirq.optimize_for_target_gateset( circuit, - gateset=cirq.CZTargetGateset( - atol=tolerance, allow_partial_czs=optimizer_type.endswith('partial_cz') - ), + gateset=_TARGET_GATESETS[optimizer_type](tolerance, tabulation), ) + if optimizer_type in _OPTIMIZER_TYPES: + opts = _OPTIMIZER_TYPES[optimizer_type](tolerance=tolerance, tabulation=tabulation) + for optimizer in opts: + optimizer(copy) copy = cirq.merge_single_qubit_gates_to_phxz(copy, atol=tolerance) copy = cirq.eject_phased_paulis(copy, atol=tolerance) copy = cirq.eject_z(copy, atol=tolerance)