Skip to content

Commit

Permalink
Make _commutes_ consistent (quantumlib#5217)
Browse files Browse the repository at this point in the history
- Requires atol to be a named parameter.
- Also changes atol to be uniformly float around the codebase.
  (not sure why it would be int, are people using an atol=1?)
- Technically a breaking change, but it's unlikely people are using
  this widely as most commutes do not even use atol.

Fixes: quantumlib#3695
  • Loading branch information
dstrain115 authored and rht committed May 1, 2023
1 parent 4948e1c commit a5043ff
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 18 deletions.
4 changes: 1 addition & 3 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,7 @@ def cleanup_key(key: Any) -> Any:

return diagram.render()

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
"""Determines whether Moment commutes with the Operation.
Args:
Expand Down
4 changes: 1 addition & 3 deletions cirq-core/cirq/contrib/acquaintance/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ def __repr__(self) -> str:
def _value_equality_values_(self) -> Any:
return (self.swap_gate,)

def _commutes_(
self, other: Any, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
if (
isinstance(other, ops.Gate)
and isinstance(other, ops.InterchangeableQubitsGate)
Expand Down
6 changes: 4 additions & 2 deletions cirq-core/cirq/ops/clifford_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __pow__(self, exponent) -> 'SingleQubitCliffordGate':

return SingleQubitCliffordGate.from_clifford_tableau(self.clifford_tableau.inverse())

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float = 1e-8) -> Union[bool, NotImplementedType]:
if isinstance(other, SingleQubitCliffordGate):
return self.commutes_with_single_qubit_gate(other)
if isinstance(other, Pauli):
Expand Down Expand Up @@ -838,7 +838,9 @@ def __pow__(self, exponent) -> 'CliffordGate':
def __repr__(self) -> str:
return f"Clifford Gate with Tableau:\n {self.clifford_tableau._str_full_()}"

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
def _commutes_(
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
# Note even if we assume two gates define the tabluea based on the same qubit order,
# the following approach cannot judge it:
# self.clifford_tableau.then(other.clifford_tableau) == other.clifford_tableau.then(
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def __repr__(self) -> str:
)

def _commutes_on_qids_(
self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float
self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
from cirq.ops.parity_gates import ZZPowGate

Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,9 @@ def __repr__(self) -> str:
f'coefficient={proper_repr(self.coefficient)})'
)

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
def _commutes_(
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
if isinstance(other, BaseDensePauliString):
n = min(len(self.pauli_mask), len(other.pauli_mask))
phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n], other.pauli_mask[:n])
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return NotImplemented

def _commutes_(
self, other: Any, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
commutes = self.gate._commutes_on_qids_(self.qubits, other, atol=atol)
if commutes is not NotImplemented:
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/pauli_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def __init__(self, index: int, name: str) -> None:
def num_qubits(self):
return 1

def _commutes_(self, other: Any, atol: float) -> Union[bool, NotImplementedType, None]:
def _commutes_(
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
if not isinstance(other, Pauli):
return NotImplemented
return self is other
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def zip_paulis(
return (paulis for qubit, paulis in self.zip_items(other))

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
if not isinstance(other, PauliString):
return NotImplemented
Expand Down
10 changes: 6 additions & 4 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,11 +399,13 @@ def _qid_shape_(self) -> Tuple[int, ...]:
"""

def _commutes_on_qids_(
self, qids: 'Sequence[cirq.Qid]', other: Any, atol: float
self, qids: 'Sequence[cirq.Qid]', other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
return NotImplemented

def _commutes_(self, other: Any, atol: float) -> Union[None, NotImplementedType, bool]:
def _commutes_(
self, other: Any, *, atol: float = 1e-8
) -> Union[None, NotImplementedType, bool]:
if not isinstance(other, Gate):
return NotImplemented
if protocols.qid_shape(self) != protocols.qid_shape(other):
Expand Down Expand Up @@ -567,7 +569,7 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
_validate_qid_shape(self, qubits)

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
"""Determine if this Operation commutes with the object"""
if not isinstance(other, Operation):
Expand Down Expand Up @@ -771,7 +773,7 @@ def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return protocols.unitary(self.sub_operation, NotImplemented)

def _commutes_(
self, other: Any, *, atol: Union[int, float] = 1e-8
self, other: Any, *, atol: float = 1e-8
) -> Union[bool, NotImplementedType, None]:
return protocols.commutes(self.sub_operation, other, atol=atol)

Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/commutes_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SupportsCommutes(Protocol):
"""An object that can determine commutation relationships vs others."""

@doc_private
def _commutes_(self, other: Any, atol: float) -> Union[None, bool, NotImplementedType]:
def _commutes_(self, other: Any, *, atol: float) -> Union[None, bool, NotImplementedType]:
r"""Determines if this object commutes with the other object.
Can return None to indicate the commutation relationship is
Expand Down

0 comments on commit a5043ff

Please sign in to comment.