diff --git a/cirq-core/cirq/ops/common_channels.py b/cirq-core/cirq/ops/common_channels.py index 7f28231becf..08767d86258 100644 --- a/cirq-core/cirq/ops/common_channels.py +++ b/cirq-core/cirq/ops/common_channels.py @@ -139,7 +139,9 @@ def __repr__(self) -> str: def __str__(self) -> str: return 'asymmetric_depolarize(' + f"error_probabilities={self._error_probabilities})" - def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs') -> str: + def _circuit_diagram_info_( + self, args: 'protocols.CircuitDiagramInfoArgs' + ) -> Union[str, Iterable[str]]: if self._num_qubits == 1: if args.precision is not None: return ( @@ -154,7 +156,9 @@ def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs') -> st ] else: error_probabilities = [f"{pauli}:{p}" for pauli, p in self._error_probabilities.items()] - return f"A({', '.join(error_probabilities)})" + return [f"A({', '.join(error_probabilities)})"] + [ + f'({i})' for i in range(1, self._num_qubits) + ] @property def p_i(self) -> float: @@ -193,13 +197,9 @@ def _json_dict_(self) -> Dict[str, Any]: return protocols.obj_to_dict_helper(self, ['error_probabilities']) def _approx_eq_(self, other: Any, atol: float) -> bool: - return ( - self._num_qubits == other._num_qubits - and np.isclose(self.p_i, other.p_i, atol=atol).item() - and np.isclose(self.p_x, other.p_x, atol=atol).item() - and np.isclose(self.p_y, other.p_y, atol=atol).item() - and np.isclose(self.p_z, other.p_z, atol=atol).item() - ) + self_keys, self_values = zip(*sorted(self.error_probabilities.items())) + other_keys, other_values = zip(*sorted(other.error_probabilities.items())) + return self_keys == other_keys and protocols.approx_eq(self_values, other_values, atol=atol) def asymmetric_depolarize( diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index 519a455a9a3..5e9a476e31d 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -805,6 +805,38 @@ def test_multi_asymmetric_depolarizing_channel_repr(): ) +def test_multi_asymmetric_depolarizing_eq(): + a = cirq.asymmetric_depolarize(error_probabilities={'I': 0.8, 'X': 0.2}) + b = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2}) + + assert not cirq.approx_eq(a, b) + + a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2}) + b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3}) + + assert not cirq.approx_eq(a, b) + + a = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'ZZ': 1 / 3}) + b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3}) + + assert not cirq.approx_eq(a, b) + + a = cirq.asymmetric_depolarize(0.1, 0.2) + b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3}) + + assert not cirq.approx_eq(a, b) + + a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.667, 'XX': 0.333}) + b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3}) + + assert cirq.approx_eq(a, b, atol=1e-3) + + a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.667, 'XX': 0.333}) + b = cirq.asymmetric_depolarize(error_probabilities={'XX': 1 / 3, 'II': 2 / 3}) + + assert cirq.approx_eq(a, b, atol=1e-3) + + def test_multi_asymmetric_depolarizing_channel_str(): assert str(cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})) == ( "asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})" @@ -814,16 +846,16 @@ def test_multi_asymmetric_depolarizing_channel_str(): def test_multi_asymmetric_depolarizing_channel_text_diagram(): a = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3}) assert cirq.circuit_diagram_info(a, args=no_precision) == cirq.CircuitDiagramInfo( - wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)',) + wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)', '(1)') ) assert cirq.circuit_diagram_info(a, args=round_to_6_prec) == cirq.CircuitDiagramInfo( - wire_symbols=('A(II:0.666667, XX:0.333333)',) + wire_symbols=('A(II:0.666667, XX:0.333333)', '(1)') ) assert cirq.circuit_diagram_info(a, args=round_to_2_prec) == cirq.CircuitDiagramInfo( - wire_symbols=('A(II:0.67, XX:0.33)',) + wire_symbols=('A(II:0.67, XX:0.33)', '(1)') ) assert cirq.circuit_diagram_info(a, args=no_precision) == cirq.CircuitDiagramInfo( - wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)',) + wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)', '(1)') )