Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and speed up cirq.transformers.stratify #6013

Merged
merged 40 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
f302aaa
refactor cirq.transformers.stratify
perlinm Feb 20, 2023
6e9c08d
fix one failing test
perlinm Feb 20, 2023
b459b71
nit renaming
perlinm Feb 20, 2023
e2f58b1
fix coverage
perlinm Feb 20, 2023
41d2c48
formatting fix
perlinm Feb 20, 2023
e3030fa
pylint fix
perlinm Feb 20, 2023
8f934ff
fix bug with measurements in stratification
perlinm Feb 20, 2023
e4fc85e
add missing import
perlinm Feb 20, 2023
4eaa489
fix bug with finding time index for op
perlinm Feb 20, 2023
435e606
fix test, and nit change to keeping track of time indices
perlinm Feb 20, 2023
c7394a8
hopefully fix measurement bug
perlinm Feb 20, 2023
8b2ff60
minor fix with ignored ops
perlinm Feb 20, 2023
4c65ac4
fix test
perlinm Feb 20, 2023
605df23
Merge branch 'master' into refactor_stratify
tanujkhattar Feb 28, 2023
6d5bb90
only store shortest circuit found
perlinm Mar 13, 2023
bf65cbc
minor bugfix
perlinm Mar 13, 2023
40fe0ca
nit typing fix
perlinm Mar 13, 2023
d2b2737
one more silly bugfig
perlinm Mar 13, 2023
2e7db1c
store shortest stratified circuit properly
perlinm Mar 13, 2023
4842091
fix bug with overlapping measurements
perlinm Mar 13, 2023
5b5996b
clean up handling of ignored ops
perlinm Mar 14, 2023
784a067
further clean up logic deciding where to put an op
perlinm Mar 14, 2023
090beb6
factor out logic for finding earliest accomodating moment
perlinm Mar 14, 2023
e132a50
fix typo
perlinm Mar 14, 2023
f34d081
fix typo
perlinm Mar 14, 2023
49c286e
remove unnecesaary use of defaultdict
perlinm Mar 14, 2023
3f8e966
Merge branch 'master' into refactor_stratify
perlinm Mar 20, 2023
b3ae21c
Merge branch 'master' into refactor_stratify
perlinm Mar 22, 2023
d567650
further simplify logic in get_earliest_accommodating_moment_index
perlinm Mar 22, 2023
f6de11c
fix lint check
perlinm Mar 22, 2023
695360e
fix minor bug
perlinm Mar 22, 2023
71b969b
nit docstring update
perlinm Mar 22, 2023
c9674d9
separately update qubit/mkey/ckey moments in cirq.stratify
perlinm Mar 22, 2023
3da970c
fix bug with max only getting one argument
perlinm Mar 22, 2023
4f1cfdc
fix coverage check
perlinm Mar 22, 2023
1f35a4e
fix typo
perlinm Mar 22, 2023
0d84391
fix lint check
perlinm Mar 22, 2023
b58e053
Merge branch 'master' into refactor_stratify
tanujkhattar Apr 3, 2023
6a08fed
Merge branch 'master' into refactor_stratify
tanujkhattar Apr 3, 2023
c46cdf5
Merge branch 'master' into refactor_stratify
tanujkhattar Apr 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 86 additions & 43 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,12 +1776,10 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):
Non-moment entries will be inserted according to the EARLIEST
insertion strategy.
"""
# These are dicts from the qubit/key to the greatest moment index that has it. It is safe
# to default to `-1`, as that is interpreted as meaning the zeroth index onward does not
# have this value.
qubit_indexes: Dict['cirq.Qid', int] = defaultdict(lambda: -1)
mkey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
ckey_indexes: Dict['cirq.MeasurementKey', int] = defaultdict(lambda: -1)
# These are dicts from the qubit/key to the greatest moment index that has it.
qubit_indices: Dict['cirq.Qid', int] = {}
mkey_indices: Dict['cirq.MeasurementKey', int] = {}
ckey_indices: Dict['cirq.MeasurementKey', int] = {}

# We also maintain the dict from moment index to moments/ops that go into it, for use when
# building the actual moments at the end.
Expand All @@ -1793,46 +1791,17 @@ def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):

# "mop" means current moment-or-operation
for mop in ops.flatten_to_ops_or_moments(contents):
mop_qubits = mop.qubits
mop_mkeys = protocols.measurement_key_objs(mop)
mop_ckeys = protocols.control_keys(mop)

# Both branches define `i`, the moment index at which to place the mop.
# Identify the index of the moment to place this `mop` into.
placement_index = get_earliest_accommodating_moment_index(
mop, qubit_indices, mkey_indices, ckey_indices, length
)
length = max(length, placement_index + 1) # update the length of the circuit thus far

if isinstance(mop, Moment):
# We always append moment to the end, to be consistent with `self.append`
i = length
moments_by_index[i] = mop
moments_by_index[placement_index] = mop
else:
# Initially we define `i` as the greatest moment index that has a conflict. `-1` is
# the initial conflict, and we search for larger ones. Once we get the largest one,
# we increment i by 1 to set the placement index.
i = -1

# Look for the maximum conflict; i.e. a moment that has a qubit the same as one of
# this op's qubits, that has a measurement or control key the same as one of this
# op's measurement keys, or that has a measurement key the same as one of this op's
# control keys. (Control keys alone can commute past each other). The `ifs` are
# logically unnecessary but seem to make this slightly faster.
if mop_qubits:
i = max(i, *[qubit_indexes[q] for q in mop_qubits])
if mop_mkeys:
i = max(i, *[mkey_indexes[k] for k in mop_mkeys])
i = max(i, *[ckey_indexes[k] for k in mop_mkeys])
if mop_ckeys:
i = max(i, *[mkey_indexes[k] for k in mop_ckeys])
i += 1
op_lists_by_index[i].append(mop)

# Update our dicts with data from the latest mop placement. Note `i` will always be
# greater than the existing value for all of these, by construction, so there is no
# need to do a `max(i, existing)`.
for q in mop_qubits:
qubit_indexes[q] = i
for k in mop_mkeys:
mkey_indexes[k] = i
for k in mop_ckeys:
ckey_indexes[k] = i
length = max(length, i + 1)
op_lists_by_index[placement_index].append(mop)

# Finally, once everything is placed, we can construct and append the actual moments for
# each index.
Expand Down Expand Up @@ -2753,3 +2722,77 @@ def _group_until_different(items: Iterable[_TIn], key: Callable[[_TIn], _TKey],
Tuples containing the group key and item values.
"""
return ((k, [val(i) for i in v]) for (k, v) in itertools.groupby(items, key))


def get_earliest_accommodating_moment_index(
moment_or_operation: Union['cirq.Moment', 'cirq.Operation'],
qubit_indices: Dict['cirq.Qid', int],
mkey_indices: Dict['cirq.MeasurementKey', int],
ckey_indices: Dict['cirq.MeasurementKey', int],
length: Optional[int] = None,
) -> int:
"""Get the index of the earliest moment that can accomodate the given moment or operation.

Updates the dictionaries keeping track of the last moment index addressing a given qubit,
measurement key, and control key.

Args:
moment_or_operation: The moment operation in question.
qubit_indices: A dictionary mapping qubits to the latest moments that address them.
mkey_indices: A dictionary mapping measureent keys to the latest moments that address them.
ckey_indices: A dictionary mapping control keys to the latest moments that address them.
length: The length of the circuit that we are trying to insert a moment or operation into.
Should probably be equal to the maximum of the values in `qubit_indices`,
`mkey_indices`, and `ckey_indices`.

Returns:
The integer index of the earliest moment that can accomodate the given moment or operation.
"""
mop_qubits = moment_or_operation.qubits
mop_mkeys = protocols.measurement_key_objs(moment_or_operation)
mop_ckeys = protocols.control_keys(moment_or_operation)

if isinstance(moment_or_operation, Moment):
# For consistency with `Circuit.append`, moments always get placed at the end of a circuit.
if length is not None:
last_conflict = length - 1
else:
last_conflict = max(
[*qubit_indices.values(), *mkey_indices.values(), *ckey_indices.values(), -1]
)

else:
# We start by searching for the `latest_conflict` moment index, which we will increment by
# `1` to identify the earliest moment that *does not* conflict with the given operation.
# The `latest_conflict` is initialized to `-1` before searching for later conflicting
# moments.
last_conflict = -1

# Look for the maximum conflict; i.e. a moment that has a qubit the same as one of this op's
# qubits, that has a measurement or control key the same as one of this op's measurement
# keys, or that has a measurement key the same as one of this op's control keys. (Control
# keys alone can commute past each other). The `ifs` are logically unnecessary but seem to
# make this slightly faster.
if mop_qubits:
last_conflict = max(
last_conflict, *[qubit_indices.get(qubit, -1) for qubit in mop_qubits]
)
if mop_mkeys:
last_conflict = max(last_conflict, *[mkey_indices.get(key, -1) for key in mop_mkeys])
last_conflict = max(last_conflict, *[ckey_indices.get(key, -1) for key in mop_mkeys])
if mop_ckeys:
last_conflict = max(last_conflict, *[mkey_indices.get(key, -1) for key in mop_ckeys])

# The index of the moment to place this moment or operaton ("mop") into.
mop_index = last_conflict + 1

# Update our dicts with data from this `mop` placement. Note `mop_index` will always be greater
# than the existing value for all of these, by construction.
for qubit in mop_qubits:
qubit_indices[qubit] = mop_index
for key in mop_mkeys:
mkey_indices[key] = mop_index
for key in mop_ckeys:
ckey_indices[key] = mop_index

return mop_index
11 changes: 11 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,17 @@ def test_insert_moment():
assert c.operation_at(qubit, actual_index) == operation[0]


def test_circuit_length_inference():
# tests that `get_earliest_accommodating_moment_index` properly computes circuit length
circuit = cirq.Circuit(cirq.X(cirq.q(0)))
qubit_indices = {cirq.q(0): 0}
mkey_indices = {}
ckey_indices = {}
assert circuits.circuit.get_earliest_accommodating_moment_index(
cirq.Moment(), qubit_indices, mkey_indices, ckey_indices
) == len(circuit)


def test_insert_inline_near_start():
a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
Expand Down
186 changes: 124 additions & 62 deletions cirq-core/cirq/transformers/stratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
"""Transformer pass to repack circuits avoiding simultaneous operations with different classes."""

import itertools
from typing import TYPE_CHECKING, Type, Callable, Optional, Union, Iterable, Sequence, List, Tuple
from typing import TYPE_CHECKING, Type, Callable, Dict, Optional, Union, Iterable, Sequence, List

from cirq import ops, circuits, _import
from cirq.transformers import transformer_api, transformer_primitives
from cirq import ops, circuits, protocols, _import
from cirq.transformers import transformer_api

drop_empty_moments = _import.LazyLoader('drop_empty_moments', globals(), 'cirq.transformers')

Expand Down Expand Up @@ -61,38 +61,36 @@ def stratified_circuit(
Returns:
A copy of the original circuit, but with re-arranged operations.
"""

# Normalize categories into classifier functions.
classifiers = [_category_to_classifier(category) for category in categories]
# Make the classifiers exhaustive by adding an "everything else" bucket.
and_the_rest = lambda op: all(not classifier(op) for classifier in classifiers)
classifiers_and_the_rest = [*classifiers, and_the_rest]
classifiers = _get_classifiers(circuit, categories)

# Try the algorithm with each permutation of the classifiers.
classifiers_permutations = list(itertools.permutations(classifiers_and_the_rest))
smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1
shortest_stratified_circuit = circuits.Circuit()
reversed_circuit = circuit[::-1]
solutions = []
for c in classifiers_permutations:
solutions.append(
_stratify_circuit(
circuit,
classifiers=list(c),
context=context or transformer_api.TransformerContext(),
)
for ordered_classifiers in itertools.permutations(classifiers):
solution = _stratify_circuit(
circuit,
classifiers=ordered_classifiers,
context=context or transformer_api.TransformerContext(),
)
if len(solution) < smallest_depth:
shortest_stratified_circuit = solution
smallest_depth = len(solution)

# Do the same thing, except this time in reverse. This helps for some
# circuits because it inserts operations at the end instead of at the
# beginning.
solutions.append(
_stratify_circuit(
reversed_circuit,
classifiers=list(c),
context=context or transformer_api.TransformerContext(),
)[::-1]
)
solution = _stratify_circuit(
reversed_circuit,
classifiers=ordered_classifiers,
context=context or transformer_api.TransformerContext(),
)[::-1]
if len(solution) < smallest_depth:
shortest_stratified_circuit = solution
smallest_depth = len(solution)

# Return the shortest circuit.
return min(solutions, key=lambda c: len(c))
return shortest_stratified_circuit


def _stratify_circuit(
Expand All @@ -116,43 +114,88 @@ def _stratify_circuit(
Returns:
The stratified circuit.
"""
num_categories = len(classifiers) + 1

def map_func(m: 'cirq.Moment', _) -> Sequence['cirq.Moment']:
stratified_ops: List[List['cirq.Operation']] = [[] for _ in range(num_categories)]
for op in m:
if set(op.tags) & set(context.tags_to_ignore):
stratified_ops[0].append(op)
continue
for i, classifier in enumerate(classifiers):
if classifier(op):
stratified_ops[i + 1].append(op)
break
return [circuits.Moment(op_list) for op_list in stratified_ops]

stratified_circuit = transformer_primitives.map_moments(circuit, map_func).unfreeze(copy=False)
assert len(stratified_circuit) == len(circuit) * num_categories

# Try to move operations to the left to reduce circuit depth, preserving stratification.
for curr_idx, moment in enumerate(stratified_circuit):
curr_category = curr_idx % num_categories
if curr_category == 0:
# Moment containing tagged operations to be ignored.
continue
batch_removals: List[Tuple[int, 'cirq.Operation']] = []
batch_inserts: List[Tuple[int, 'cirq.Operation']] = []
num_classes = len(classifiers) + 1 # include one "extra" category for ignored operations
new_moments: List[List['cirq.Operation']] = []

# Keep track of the the latest time index for each qubit, measurement key, and control key.
qubit_time_index: Dict['cirq.Qid', int] = {}
measurement_time_index: Dict['cirq.MeasurementKey', int] = {}
control_time_index: Dict['cirq.MeasurementKey', int] = {}

# The minimum time index for operations with a tag in context.tags_to_ignore.
last_ignored_ops_time_index = 0

for moment in circuit:
# Identify the new time indices that operations should be moved into.
ignored_ops = []
op_time_indices = {}
for op in moment:
prv_idx = stratified_circuit.earliest_available_moment(op, end_moment_index=curr_idx)
prv_category = prv_idx % num_categories
should_move_to_next_batch = curr_category < prv_category
prv_idx += curr_category - prv_category + num_categories * should_move_to_next_batch
assert prv_idx <= curr_idx and prv_idx % num_categories == curr_idx % num_categories
if prv_idx < curr_idx:
batch_inserts.append((prv_idx, op))
batch_removals.append((curr_idx, op))
stratified_circuit.batch_remove(batch_removals)
stratified_circuit.batch_insert_into(batch_inserts)
return drop_empty_moments.drop_empty_moments(stratified_circuit)

# Identify the earliest moment that can accommodate this op.
min_time_index_for_op = circuits.circuit.get_earliest_accommodating_moment_index(
op, qubit_time_index, measurement_time_index, control_time_index
)

# Identify the "class" of this operation (by index).
ignored_op = any(tag in op.tags for tag in context.tags_to_ignore)
if not ignored_op:
op_class = _get_op_class(op, classifiers)
else:
op_class = len(classifiers)
ignored_ops.append(op)
min_time_index_for_op = max(min_time_index_for_op, last_ignored_ops_time_index + 1)

# Identify the time index to place this operation into.
time_index = (min_time_index_for_op // num_classes) * num_classes + op_class
if time_index < min_time_index_for_op:
time_index += num_classes
op_time_indices[op] = time_index

# Assign ignored operations to the same moment.
if ignored_ops:
last_ignored_ops_time_index = max(op_time_indices[op] for op in ignored_ops)
for op in ignored_ops:
op_time_indices[op] = last_ignored_ops_time_index

# Move the operations into their assigned moments.
for op, time_index in op_time_indices.items():
if time_index >= len(new_moments):
new_moments += [[] for _ in range(num_classes)]
new_moments[time_index].append(op)

# Update qubit, measurment key, and control key moments.
for qubit in op.qubits:
qubit_time_index[qubit] = time_index
for key in protocols.measurement_key_objs(op):
measurement_time_index[key] = time_index
for key in protocols.control_keys(op):
control_time_index[key] = time_index

return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment)


def _get_classifiers(
circuit: circuits.AbstractCircuit, categories: Iterable[Category]
) -> List[Classifier]:
"""Convert a collection of categories into a list of classifiers.

The returned list of classifiers is:
- Exhaustive, meaning every operation in the circuit is classified by at least one classifier.
- Minimal, meaning unused classifiers are forgotten.
"""
# Convert all categories into classifiers, and make the list exhaustive by adding a dummy
# classifier for otherwise unclassified ops.
classifiers = [_category_to_classifier(cat) for cat in categories] + [_dummy_classifier]

# Figure out which classes are actually used in the circuit.
class_is_used = [False for _ in classifiers]
for op in circuit.all_operations():
class_is_used[_get_op_class(op, classifiers)] = True
if all(class_is_used):
break

# Return only the classifiers that are used.
return [classifier for classifier, is_used in zip(classifiers, class_is_used) if is_used]


# No type for `category` because mypy does not keep the return type when
Expand All @@ -177,3 +220,22 @@ def _category_to_classifier(category) -> Classifier:
f'Type[cirq.Gate], Type[cirq.Operation], '
f'or Callable[[cirq.Operation], bool].'
)


def _dummy_classifier(op: 'cirq.Operation') -> bool:
"""Dummy classifier, used to "complete" a collection of classifiers and make it exhaustive."""


def _get_op_class(op: 'cirq.Operation', classifiers: Sequence[Classifier]) -> int:
"""Get the "class" of an operator, by index."""
for class_index, classifier in enumerate(classifiers):
if classifier is _dummy_classifier:
dummy_classifier_index = class_index
elif classifier(op):
return class_index
# If we got this far, the operation did not match any "actual" classifier,
# so return the index of the dummy classifer.
try:
return dummy_classifier_index
except NameError:
raise ValueError(f"Operation {op} not identified by any classifier")
Loading