Skip to content

Commit

Permalink
Check value equality first when comparing circuits (#6375)
Browse files Browse the repository at this point in the history
Review: @dstrain115
  • Loading branch information
maffoo authored Dec 19, 2023
1 parent 1961207 commit 181d7aa
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
AbstractSet,
Any,
Callable,
Mapping,
MutableSequence,
cast,
Dict,
FrozenSet,
Iterable,
Iterator,
List,
Mapping,
MutableSequence,
Optional,
overload,
Sequence,
Expand All @@ -58,9 +58,9 @@
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.circuit_operation import CircuitOperation
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.circuits.moment import Moment
from cirq.circuits.qasm_output import QasmOutput
from cirq.circuits.text_diagram_drawer import TextDiagramDrawer
from cirq.circuits.moment import Moment
from cirq.protocols import circuit_diagram_info_protocol
from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -203,19 +203,24 @@ def unfreeze(self, copy: bool = True) -> 'cirq.Circuit':
copy: If True and 'self' is a Circuit, returns a copy that circuit.
"""

def __bool__(self):
def __bool__(self) -> bool:
return bool(self.moments)

def __eq__(self, other):
def __eq__(self, other) -> bool:
if not isinstance(other, AbstractCircuit):
return NotImplemented
return tuple(self.moments) == tuple(other.moments)
return other is self or (
len(self.moments) == len(other.moments)
and all(m0 == m1 for m0, m1 in zip(self.moments, other.moments))
)

def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
"""See `cirq.protocols.SupportsApproximateEquality`."""
if not isinstance(other, AbstractCircuit):
return NotImplemented
return cirq.protocols.approx_eq(tuple(self.moments), tuple(other.moments), atol=atol)
return other is self or cirq.protocols.approx_eq(
tuple(self.moments), tuple(other.moments), atol=atol
)

def __ne__(self, other) -> bool:
return not self == other
Expand Down

0 comments on commit 181d7aa

Please sign in to comment.