From 97979a02a2f90aadfa3454346584e7f62a12b653 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 5 Dec 2023 16:10:54 -0800 Subject: [PATCH] Check value equality first when comparing circuits --- cirq-core/cirq/circuits/circuit.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 97777edcd24..f0329266bdb 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -29,14 +29,14 @@ AbstractSet, Any, Callable, - Mapping, - MutableSequence, cast, Dict, FrozenSet, Iterable, Iterator, List, + Mapping, + MutableSequence, Optional, overload, Sequence, @@ -58,9 +58,9 @@ from cirq.circuits._bucket_priority_queue import BucketPriorityQueue from cirq.circuits.circuit_operation import CircuitOperation from cirq.circuits.insert_strategy import InsertStrategy +from cirq.circuits.moment import Moment from cirq.circuits.qasm_output import QasmOutput from cirq.circuits.text_diagram_drawer import TextDiagramDrawer -from cirq.circuits.moment import Moment from cirq.protocols import circuit_diagram_info_protocol from cirq.type_workarounds import NotImplementedType @@ -203,19 +203,24 @@ def unfreeze(self, copy: bool = True) -> 'cirq.Circuit': copy: If True and 'self' is a Circuit, returns a copy that circuit. """ - def __bool__(self): + def __bool__(self) -> bool: return bool(self.moments) - def __eq__(self, other): + def __eq__(self, other) -> bool: if not isinstance(other, AbstractCircuit): return NotImplemented - return tuple(self.moments) == tuple(other.moments) + return other is self or ( + len(self.moments) == len(other.moments) + and all(m0 == m1 for m0, m1 in zip(self.moments, other.moments)) + ) def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool: """See `cirq.protocols.SupportsApproximateEquality`.""" if not isinstance(other, AbstractCircuit): return NotImplemented - return cirq.protocols.approx_eq(tuple(self.moments), tuple(other.moments), atol=atol) + return other is self or cirq.protocols.approx_eq( + tuple(self.moments), tuple(other.moments), atol=atol + ) def __ne__(self, other) -> bool: return not self == other