Skip to content

Commit

Permalink
Print multi-qubit circuit with asymmetric depolarizing noise correctly (
Browse files Browse the repository at this point in the history
#5931)

Closes #5927. Fixes based on discussion in the issue.
  • Loading branch information
paaige authored Dec 19, 2022
1 parent af1267d commit 7892143
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
18 changes: 9 additions & 9 deletions cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return 'asymmetric_depolarize(' + f"error_probabilities={self._error_probabilities})"

def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs') -> str:
def _circuit_diagram_info_(
self, args: 'protocols.CircuitDiagramInfoArgs'
) -> Union[str, Iterable[str]]:
if self._num_qubits == 1:
if args.precision is not None:
return (
Expand All @@ -154,7 +156,9 @@ def _circuit_diagram_info_(self, args: 'protocols.CircuitDiagramInfoArgs') -> st
]
else:
error_probabilities = [f"{pauli}:{p}" for pauli, p in self._error_probabilities.items()]
return f"A({', '.join(error_probabilities)})"
return [f"A({', '.join(error_probabilities)})"] + [
f'({i})' for i in range(1, self._num_qubits)
]

@property
def p_i(self) -> float:
Expand Down Expand Up @@ -193,13 +197,9 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['error_probabilities'])

def _approx_eq_(self, other: Any, atol: float) -> bool:
return (
self._num_qubits == other._num_qubits
and np.isclose(self.p_i, other.p_i, atol=atol).item()
and np.isclose(self.p_x, other.p_x, atol=atol).item()
and np.isclose(self.p_y, other.p_y, atol=atol).item()
and np.isclose(self.p_z, other.p_z, atol=atol).item()
)
self_keys, self_values = zip(*sorted(self.error_probabilities.items()))
other_keys, other_values = zip(*sorted(other.error_probabilities.items()))
return self_keys == other_keys and protocols.approx_eq(self_values, other_values, atol=atol)


def asymmetric_depolarize(
Expand Down
40 changes: 36 additions & 4 deletions cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,38 @@ def test_multi_asymmetric_depolarizing_channel_repr():
)


def test_multi_asymmetric_depolarizing_eq():
a = cirq.asymmetric_depolarize(error_probabilities={'I': 0.8, 'X': 0.2})
b = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})

assert not cirq.approx_eq(a, b)

a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})

assert not cirq.approx_eq(a, b)

a = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'ZZ': 1 / 3})
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})

assert not cirq.approx_eq(a, b)

a = cirq.asymmetric_depolarize(0.1, 0.2)
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})

assert not cirq.approx_eq(a, b)

a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.667, 'XX': 0.333})
b = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})

assert cirq.approx_eq(a, b, atol=1e-3)

a = cirq.asymmetric_depolarize(error_probabilities={'II': 0.667, 'XX': 0.333})
b = cirq.asymmetric_depolarize(error_probabilities={'XX': 1 / 3, 'II': 2 / 3})

assert cirq.approx_eq(a, b, atol=1e-3)


def test_multi_asymmetric_depolarizing_channel_str():
assert str(cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})) == (
"asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})"
Expand All @@ -814,16 +846,16 @@ def test_multi_asymmetric_depolarizing_channel_str():
def test_multi_asymmetric_depolarizing_channel_text_diagram():
a = cirq.asymmetric_depolarize(error_probabilities={'II': 2 / 3, 'XX': 1 / 3})
assert cirq.circuit_diagram_info(a, args=no_precision) == cirq.CircuitDiagramInfo(
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)',)
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)', '(1)')
)
assert cirq.circuit_diagram_info(a, args=round_to_6_prec) == cirq.CircuitDiagramInfo(
wire_symbols=('A(II:0.666667, XX:0.333333)',)
wire_symbols=('A(II:0.666667, XX:0.333333)', '(1)')
)
assert cirq.circuit_diagram_info(a, args=round_to_2_prec) == cirq.CircuitDiagramInfo(
wire_symbols=('A(II:0.67, XX:0.33)',)
wire_symbols=('A(II:0.67, XX:0.33)', '(1)')
)
assert cirq.circuit_diagram_info(a, args=no_precision) == cirq.CircuitDiagramInfo(
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)',)
wire_symbols=('A(II:0.6666666666666666, XX:0.3333333333333333)', '(1)')
)


Expand Down

0 comments on commit 7892143

Please sign in to comment.