diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 96af44ba7a2..3060eb4b18c 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -16,6 +16,7 @@ from cirq.testing.circuit_compare import ( assert_circuits_with_terminal_measurements_are_equivalent, + assert_circuits_have_same_unitary_given_final_permutation, assert_has_consistent_apply_unitary, assert_has_consistent_apply_unitary_for_various_exponents, assert_has_diagram, diff --git a/cirq-core/cirq/testing/circuit_compare.py b/cirq-core/cirq/testing/circuit_compare.py index 2636f334c52..55dd9e0daf4 100644 --- a/cirq-core/cirq/testing/circuit_compare.py +++ b/cirq-core/cirq/testing/circuit_compare.py @@ -220,6 +220,40 @@ def _first_differing_moment_index( return None # coverage: ignore +def assert_circuits_have_same_unitary_given_final_permutation( + actual: circuits.AbstractCircuit, + expected: circuits.AbstractCircuit, + qubit_map: Dict[ops.Qid, ops.Qid], +) -> None: + """Asserts two circuits have the same unitary up to a final permuation of qubits. + + Args: + actual: A circuit computed by some code under test. + expected: The circuit that should have been computed. + qubit_map: the permutation of qubits from the beginning to the end of the circuit. + + Raises: + ValueError: if 'qubit_map' is not a mapping from the qubits in 'actual' to themselves. + ValueError: if 'qubit_map' does not have the same set of keys and values. + """ + if set(qubit_map.keys()) != set(qubit_map.values()): + raise ValueError("'qubit_map' must have the same set of of keys and values.") + + if not set(qubit_map.keys()).issubset(actual.all_qubits()): + raise ValueError( + f"'qubit_map' must be a mapping of the qubits in the circuit 'actual' to themselves." + ) + + actual_cp = actual.unfreeze() + initial_qubits, sorted_qubits = zip(*sorted(qubit_map.items(), key=lambda x: x[1])) + inverse_permutation = [sorted_qubits.index(q) for q in initial_qubits] + actual_cp.append(ops.QubitPermutationGate(list(inverse_permutation)).on(*sorted_qubits)) + + lin_alg_utils.assert_allclose_up_to_global_phase( + expected.unitary(), actual_cp.unitary(), atol=1e-8 + ) + + def assert_has_diagram( actual: Union[circuits.AbstractCircuit, circuits.Moment], desired: str, **kwargs ) -> None: diff --git a/cirq-core/cirq/testing/circuit_compare_test.py b/cirq-core/cirq/testing/circuit_compare_test.py index 75e453f7f63..4c8ed7612fc 100644 --- a/cirq-core/cirq/testing/circuit_compare_test.py +++ b/cirq-core/cirq/testing/circuit_compare_test.py @@ -190,6 +190,35 @@ def test_assert_same_circuits(): ) +def test_assert_circuits_have_same_unitary_given_final_permutation(): + q = cirq.LineQubit.range(5) + expected = cirq.Circuit([cirq.Moment(cirq.CNOT(q[2], q[1]), cirq.CNOT(q[3], q[0]))]) + actual = cirq.Circuit( + [ + cirq.Moment(cirq.CNOT(q[2], q[1])), + cirq.Moment(cirq.SWAP(q[0], q[2])), + cirq.Moment(cirq.SWAP(q[0], q[1])), + cirq.Moment(cirq.CNOT(q[3], q[2])), + ] + ) + qubit_map = {q[0]: q[2], q[2]: q[1], q[1]: q[0]} + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + actual, expected, qubit_map + ) + + qubit_map.update({q[2]: q[3]}) + with pytest.raises(ValueError, match="'qubit_map' must have the same set"): + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + actual, expected, qubit_map=qubit_map + ) + + bad_qubit_map = {q[0]: q[2], q[2]: q[4], q[4]: q[0]} + with pytest.raises(ValueError, match="'qubit_map' must be a mapping"): + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + actual, expected, qubit_map=bad_qubit_map + ) + + def test_assert_has_diagram(): a, b = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.CNOT(a, b))