Skip to content

Commit

Permalink
Move cirq/ops/moment.py to cirq/circuits/moment.py (#4932)
Browse files Browse the repository at this point in the history
* Move cirq/ops/moment.py to cirq/circuits/moment.py

* Deprecate module cirq.ops.moment and cirq.ops.Moment

* Ignore mypy error on setting submodule attribute

mypy seems to interpret left-hand-side `ops.Moment` as attribute access
and complains the attribute is not set.

* Add test for `Moment._with_key_path_prefix_`

* Decouple `with_key_path_prefix` test from `with_key_path`

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
  • Loading branch information
pavoljuhas and tanujkhattar authored Feb 5, 2022
1 parent 6806f6a commit 8b64834
Show file tree
Hide file tree
Showing 40 changed files with 225 additions and 163 deletions.
22 changes: 21 additions & 1 deletion cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
CircuitOperation,
FrozenCircuit,
InsertStrategy,
Moment,
PointOptimizationSummary,
PointOptimizer,
QasmOutput,
Expand Down Expand Up @@ -250,7 +251,6 @@
measure_paulistring_terms,
measure_single_paulistring,
MeasurementGate,
Moment,
MutableDensePauliString,
MutablePauliString,
NamedQubit,
Expand Down Expand Up @@ -686,4 +686,24 @@
contrib,
)

# deprecate cirq.ops.moment and related attributes

from cirq import _compat

_compat.deprecated_submodule(
new_module_name='cirq.circuits.moment',
old_parent='cirq.ops',
old_child='moment',
deadline='v0.16',
create_attribute=True,
)

ops.Moment = Moment # type: ignore
_compat.deprecate_attributes(
'cirq.ops',
{
'Moment': ('v0.16', 'Use cirq.circuits.Moment instead'),
},
)

# pylint: enable=wrong-import-position
4 changes: 4 additions & 0 deletions cirq-core/cirq/circuits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
InsertStrategy,
)

from cirq.circuits.moment import (
Moment,
)

from cirq.circuits.optimization_pass import (
PointOptimizer,
PointOptimizationSummary,
Expand Down
35 changes: 18 additions & 17 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from cirq.circuits.qasm_output import QasmOutput
from cirq.circuits.quil_output import QuilOutput
from cirq.circuits.text_diagram_drawer import TextDiagramDrawer
from cirq.circuits.moment import Moment
from cirq.protocols import circuit_diagram_info_protocol
from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -225,7 +226,7 @@ def __getitem__(self, key):
moment_idx, qubit_idx = key
# moment_idx - int or slice; qubit_idx - Qid or Iterable[Qid].
selected_moments = self.moments[moment_idx]
if isinstance(selected_moments, ops.Moment):
if isinstance(selected_moments, Moment):
return selected_moments[qubit_idx]
if isinstance(qubit_idx, ops.Qid):
qubit_idx = [qubit_idx]
Expand Down Expand Up @@ -1606,7 +1607,7 @@ def _tetris_concat_helper(
shift = _overlap_collision_time(buf[c1_offset : c1_offset + n1], c2, align)
c2_offset = c1_offset + n1 - shift
for k in range(n2):
buf[k + c2_offset] = (buf[k + c2_offset] or ops.Moment()) + c2[k]
buf[k + c2_offset] = (buf[k + c2_offset] or Moment()) + c2[k]
return min(c1_offset, c2_offset), max(n1, n2, n1 + n2 - shift)


Expand Down Expand Up @@ -1768,13 +1769,13 @@ def __setitem__(self, key: slice, value: Iterable['cirq.Moment']):

def __setitem__(self, key, value):
if isinstance(key, int):
if not isinstance(value, ops.Moment):
if not isinstance(value, Moment):
raise TypeError('Can only assign Moments into Circuits.')
self._device.validate_moment(value)

if isinstance(key, slice):
value = list(value)
if any(not isinstance(v, ops.Moment) for v in value):
if any(not isinstance(v, Moment) for v in value):
raise TypeError('Can only assign Moments into Circuits.')
for moment in value:
self._device.validate_moment(moment)
Expand Down Expand Up @@ -1881,7 +1882,7 @@ def with_device(
with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)):
return Circuit(
[
ops.Moment(
Moment(
operation.transform_qubits(qubit_mapping) for operation in moment.operations
)
for moment in self._moments
Expand Down Expand Up @@ -1939,7 +1940,7 @@ def transform_qubits(
raise TypeError('qubit_map must be a function or dict mapping qubits to qubits.')

op_list = [
ops.Moment(operation.transform_qubits(transform) for operation in moment.operations)
Moment(operation.transform_qubits(transform) for operation in moment.operations)
for moment in self._moments
]

Expand Down Expand Up @@ -1989,7 +1990,7 @@ def _pick_or_create_inserted_op_moment_index(
"""

if strategy is InsertStrategy.NEW or strategy is InsertStrategy.NEW_THEN_INLINE:
self._moments.insert(splitter_index, ops.Moment())
self._moments.insert(splitter_index, Moment())
return splitter_index

if strategy is InsertStrategy.INLINE:
Expand Down Expand Up @@ -2074,22 +2075,22 @@ def insert(
)

for moment_or_op in moments_and_operations:
if isinstance(moment_or_op, ops.Moment):
self._device.validate_moment(cast(ops.Moment, moment_or_op))
if isinstance(moment_or_op, Moment):
self._device.validate_moment(cast(Moment, moment_or_op))
else:
self._device.validate_operation(cast(ops.Operation, moment_or_op))

# limit index to 0..len(self._moments), also deal with indices smaller 0
k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0)
for moment_or_op in moments_and_operations:
if isinstance(moment_or_op, ops.Moment):
if isinstance(moment_or_op, Moment):
self._moments.insert(k, moment_or_op)
k += 1
else:
op = cast(ops.Operation, moment_or_op)
p = self._pick_or_create_inserted_op_moment_index(k, op, strategy)
while p >= len(self._moments):
self._moments.append(ops.Moment())
self._moments.append(Moment())
self._moments[p] = self._moments[p].with_operation(op)
self._device.validate_moment(self._moments[p])
k = max(k, p + 1)
Expand Down Expand Up @@ -2186,7 +2187,7 @@ def _push_frontier(
)
if n_new_moments > 0:
insert_index = min(late_frontier.values())
self._moments[insert_index:insert_index] = [ops.Moment()] * n_new_moments
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
for q in update_qubits:
if early_frontier.get(q, 0) > insert_index:
early_frontier[q] += n_new_moments
Expand All @@ -2212,12 +2213,12 @@ def _insert_operations(
"""
if len(operations) != len(insertion_indices):
raise ValueError('operations and insertion_indices must have the same length.')
self._moments += [ops.Moment() for _ in range(1 + max(insertion_indices) - len(self))]
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
moment_to_ops: Dict[int, List['cirq.Operation']] = defaultdict(list)
for op_index, moment_index in enumerate(insertion_indices):
moment_to_ops[moment_index].append(operations[op_index])
for moment_index, new_ops in moment_to_ops.items():
self._moments[moment_index] = ops.Moment(
self._moments[moment_index] = Moment(
self._moments[moment_index].operations + tuple(new_ops)
)

Expand Down Expand Up @@ -2274,7 +2275,7 @@ def batch_remove(self, removals: Iterable[Tuple[int, 'cirq.Operation']]) -> None
for i, op in removals:
if op not in copy._moments[i].operations:
raise ValueError(f"Can't remove {op} @ {i} because it doesn't exist.")
copy._moments[i] = ops.Moment(
copy._moments[i] = Moment(
old_op for old_op in copy._moments[i].operations if op != old_op
)
self._device.validate_circuit(copy)
Expand All @@ -2299,7 +2300,7 @@ def batch_replace(
for i, op, new_op in replacements:
if op not in copy._moments[i].operations:
raise ValueError(f"Can't replace {op} @ {i} because it doesn't exist.")
copy._moments[i] = ops.Moment(
copy._moments[i] = Moment(
old_op if old_op != op else new_op for old_op in copy._moments[i].operations
)
self._device.validate_circuit(copy)
Expand Down Expand Up @@ -2397,7 +2398,7 @@ def _resolve_parameters_(
resolved_moments = []
for moment in self:
resolved_operations = _resolve_operations(moment.operations, resolver, recursive)
new_moment = ops.Moment(resolved_operations)
new_moment = Moment(resolved_operations)
resolved_moments.append(new_moment)
if self._device == devices.UNCONSTRAINED_DEVICE:
return Circuit(resolved_moments)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

class _Foxy(ValidatingTestDevice):
def can_add_operation_into_moment(
self, operation: 'ops.Operation', moment: 'ops.Moment'
self, operation: 'cirq.Operation', moment: 'cirq.Moment'
) -> bool:
if not super().can_add_operation_into_moment(operation, moment):
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from cirq import protocols, ops, qis
from cirq._import import LazyLoader
from cirq.ops import raw_types
from cirq.ops import raw_types, op_tree
from cirq.protocols import circuit_diagram_info_protocol
from cirq.type_workarounds import NotImplementedType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@
import cirq
import cirq.testing

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'


def test_deprecated_submodule():
with cirq.testing.assert_deprecated(
"Use cirq.circuits.moment instead",
deadline="v0.16",
):
_ = cirq.ops.moment.Moment


def test_deprecated_attribute_in_cirq_ops():
with cirq.testing.assert_deprecated(
"Use cirq.circuits.Moment instead",
deadline="v0.16",
):
_ = cirq.ops.Moment


def test_validation():
a = cirq.NamedQubit('a')
Expand Down Expand Up @@ -305,6 +323,16 @@ def test_with_key_path():
)


def test_with_key_path_prefix():
a, b, c = cirq.LineQubit.range(3)
m = cirq.Moment(cirq.measure(a, key='m1'), cirq.measure(b, key='m2'), cirq.X(c))
mb = cirq.with_key_path_prefix(m, ('b',))
mab = cirq.with_key_path_prefix(mb, ('a',))
assert mab.operations[0] == cirq.measure(a, key=cirq.MeasurementKey.parse_serialized('a:b:m1'))
assert mab.operations[1] == cirq.measure(b, key=cirq.MeasurementKey.parse_serialized('a:b:m2'))
assert mab.operations[2] is m.operations[2]


def test_copy():
a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/acquaintance/mutation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from typing import cast, Dict, List, Optional, Sequence, Union, TYPE_CHECKING

from cirq import ops, transformers
from cirq import circuits, ops, transformers

from cirq.contrib.acquaintance.gates import SwapNetworkGate, AcquaintanceOpportunityGate
from cirq.contrib.acquaintance.devices import get_acquaintance_size
Expand Down Expand Up @@ -52,7 +52,7 @@ def rectify_acquaintance_strategy(circuit: 'cirq.Circuit', acquaint_first: bool
rectified_moments.append(moment)
continue
for acquaint_first in sorted(gate_type_to_ops.keys(), reverse=acquaint_first):
rectified_moments.append(ops.Moment(gate_type_to_ops[acquaint_first]))
rectified_moments.append(circuits.Moment(gate_type_to_ops[acquaint_first]))
circuit._moments = rectified_moments


Expand Down Expand Up @@ -92,7 +92,7 @@ def replace_acquaintance_with_swap_network(
qubit_order, moment.operations, acquaintance_size, swap_gate
)
swap_network_op = swap_network_gate(*qubit_order)
moment = ops.Moment([swap_network_op])
moment = circuits.Moment([swap_network_op])
reflected = not reflected
circuit._moments[moment_index] = moment
return reflected
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/contrib/acquaintance/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import cast, FrozenSet, List, Sequence, Set, TYPE_CHECKING

from cirq import ops
from cirq import circuits

from cirq.contrib.acquaintance.gates import acquaint
from cirq.contrib.acquaintance.executor import AcquaintanceOperation
Expand Down Expand Up @@ -50,6 +50,6 @@ def remove_redundant_acquaintance_opportunities(strategy: 'cirq.Circuit') -> int
n_removed += 1
else:
new_moment.append(op)
new_moments.append(ops.Moment(new_moment))
new_moments.append(circuits.Moment(new_moment))
strategy._moments = new_moments
return n_removed
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/acquaintance/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
TYPE_CHECKING,
)

from cirq import ops, protocols, transformers, value
from cirq import circuits, ops, protocols, transformers, value
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
Expand Down Expand Up @@ -131,12 +131,12 @@ def display_mapping(circuit: 'cirq.Circuit', initial_mapping: LogicalMapping) ->

old_moments = circuit._moments
gate = MappingDisplayGate(mapping.get(q) for q in qubits)
new_moments = [ops.Moment([gate(*qubits)])]
new_moments = [circuits.Moment([gate(*qubits)])]
for moment in old_moments:
new_moments.append(moment)
update_mapping(mapping, moment)
gate = MappingDisplayGate(mapping.get(q) for q in qubits)
new_moments.append(ops.Moment([gate(*qubits)]))
new_moments.append(circuits.Moment([gate(*qubits)]))

circuit._moments = new_moments

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/acquaintance/strategies/cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def cubic_acquaintance_strategy(
new_index_order = skip_and_wrap_around(stepped_indices_concatenated)
permutation = {i: new_index_order.index(j) for i, j in enumerate(index_order)}
permutation_gate = LinearPermutationGate(n_qubits, permutation, swap_gate)
moments.append(ops.Moment([permutation_gate(*qubits)]))
moments.append(circuits.Moment([permutation_gate(*qubits)]))
for i in range(n_qubits + 1):
for offset in range(3):
moment = ops.Moment(
moment = circuits.Moment(
acquaint(*qubits[j : j + 3]) for j in range(offset, n_qubits - 2, 3)
)
moments.append(moment)
if i < n_qubits:
moment = ops.Moment(
moment = circuits.Moment(
swap_gate(*qubits[j : j + 2]) for j in range(i % 2, n_qubits - 1, 2)
)
moments.append(moment)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def quartic_paired_acquaintance_strategy(
strategy = circuits.Circuit(swap_network)
expose_acquaintance_gates(strategy)
for i in reversed(range(0, n_qubits, 2)):
moment = ops.Moment([acquaint(*qubits[j : j + 4]) for j in range(i % 4, n_qubits - 3, 4)])
moment = circuits.Moment(
[acquaint(*qubits[j : j + 4]) for j in range(i % 4, n_qubits - 3, 4)]
)
strategy.insert(2 * i, moment)
return strategy, qubits
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/graph_device/graph_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def validate_crosstalk(
):
validator(operation, *crosstalk_operations)

def validate_moment(self, moment: ops.Moment):
def validate_moment(self, moment: 'cirq.Moment'):
super().validate_moment(moment)
ops = moment.operations
for i, op in enumerate(ops):
Expand Down
Loading

0 comments on commit 8b64834

Please sign in to comment.