Skip to content

Commit

Permalink
Enforce same control order in ControlledGate equality check (quantuml…
Browse files Browse the repository at this point in the history
…ib#5131)

**Breaking Change**

Changes ControlledGate equality check to enforce gates have same order.

Fixes quantumlib#5110
  • Loading branch information
daxfohl authored and tonybruguier committed Apr 14, 2022
1 parent 0f0ab76 commit c62e4a3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_specialized_control(input_gate, specialized_output):
)
assert input_gate.controlled(control_qid_shape=(2,)).controlled(
control_qid_shape=(3,)
).controlled(control_qid_shape=(4,)) == specialized_output.controlled(
).controlled(control_qid_shape=(4,)) != specialized_output.controlled(
num_controls=2, control_qid_shape=(3, 4)
)

Expand All @@ -175,7 +175,7 @@ def test_specialized_control(input_gate, specialized_output):
)
assert input_gate.controlled(control_qid_shape=(3,)).controlled(
control_qid_shape=(2,)
).controlled(control_qid_shape=(4,)) == cirq.ControlledGate(
).controlled(control_qid_shape=(4,)) != cirq.ControlledGate(
input_gate, num_controls=3, control_qid_shape=(3, 2, 4)
)

Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def _value_equality_values_(self):
return (
self.sub_gate,
self.num_controls(),
frozenset(zip(self.control_values, self.control_qid_shape)),
self.control_values,
self.control_qid_shape,
)

def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray:
Expand Down
19 changes: 19 additions & 0 deletions cirq-core/cirq/ops/controlled_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,16 +223,29 @@ def test_eq():
eq.add_equality_group(cirq.X)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[1, (0, 2)], control_qid_shape=[2, 3]),
cirq.ControlledGate(cirq.H, control_values=(1, [0, 2]), control_qid_shape=(2, 3)),
)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[(2, 0), 1], control_qid_shape=[3, 2]),
)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[1, 0], control_qid_shape=[2, 3]),
cirq.ControlledGate(cirq.H, control_values=(1, 0), control_qid_shape=(2, 3)),
)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[0, 1], control_qid_shape=[3, 2]),
)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[1, 0]),
cirq.ControlledGate(cirq.H, control_values=(1, 0)),
)
eq.add_equality_group(
cirq.ControlledGate(cirq.H, control_values=[0, 1]),
)
for group in eq._groups:
if isinstance(group[0], cirq.Gate):
for item in group:
np.testing.assert_allclose(cirq.unitary(item), cirq.unitary(group[0]))


def test_control():
Expand Down Expand Up @@ -266,11 +279,17 @@ def test_control():
eq.add_equality_group(
cirq.ControlledGate(g, control_values=[0, 1]),
g.controlled(control_values=[0, 1]),
g.controlled(control_values=[1]).controlled(control_values=[0]),
)
eq.add_equality_group(
g.controlled(control_values=[0]).controlled(control_values=[1]),
)
eq.add_equality_group(
cirq.ControlledGate(g, control_qid_shape=[4, 3]),
g.controlled(control_qid_shape=[4, 3]),
g.controlled(control_qid_shape=[3]).controlled(control_qid_shape=[4]),
)
eq.add_equality_group(
g.controlled(control_qid_shape=[4]).controlled(control_qid_shape=[3]),
)

Expand Down

0 comments on commit c62e4a3

Please sign in to comment.