Refactor and speed up cirq.transformers.stratify #6013

Merged
merged 40 commits into from
Apr 3, 2023

Conversation

perlinm
Contributor

@perlinm perlinm commented Feb 20, 2023

This PR refactors cirq.transformers.stratify.stratified_circuit, making it considerably faster.

I have not contributed to cirq before, so I apologize if I am not doing things the "correct" way. Feedback and suggestions much appreciated 🙂.

Some comments:

  1. If there are no ops_to_ignore, for a given input (circuit + categories), the output of the new stratifier is identical to the old.
  2. Ignored ops are treated a bit differently, though I think the new treatment is consistent with the previous docstring. In the old stratifier, (a) each moment in the circuit is stratified into num_categories strata, and then (b) every operation is moved into the earliest stratum possible that matches its category. Ignored ops are not moved into earlier strata at step (b), which can block non-ignored ops from moving into earlier moments. The new stratifier moves ignored ops into earlier moments as well, which can result in shorter circuits. The ignored ops are still "ignored" in the sense that their relative order is unaffected by stratification. If one ignored op preceded another ignored op before stratification, that will still be true after stratification, and two ignored ops will be in the same moment after stratification if and only if they were in the same moment before stratification.
  3. I claim that the new stratifier is faster, but I am not sure how to run a proper regression test that others can conveniently use to verify my claim. One simple test that I can run locally is:
import cirq
import numpy as np
import time

np.random.seed(0)
circuit = cirq.testing.random_circuit(qubits=4, n_moments=100, op_density=1)
categories = [lambda op: len(op.qubits) == 1, cirq.CNOT]  # single-qubit ops, and CNOTs

start = time.time()
stratified_circuit = cirq.stratified_circuit(circuit, categories=categories)
print("time:", time.time() - start)

On my laptop, this test takes ~0.4 seconds with the old stratifier, and ~0.03 seconds with the new stratifier.
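
A single time.time() measurement can be noisy. One hedged way to make a comparison like the one above easier for others to repeat is to take the best of several runs with the standard-library timeit module. The helper name best_time and the stand-in workload below are hypothetical and are not part of this PR; in practice the callable would wrap cirq.stratified_circuit(circuit, categories=categories) from the snippet above.

```python
import timeit
from typing import Callable


def best_time(func: Callable[[], object], repeats: int = 5, number: int = 1) -> float:
    """Return the best per-call wall-clock time (in seconds) over several runs."""
    # timeit.repeat returns one total per repeat; each total covers `number` calls.
    totals = timeit.repeat(func, repeat=repeats, number=number)
    return min(totals) / number


if __name__ == "__main__":
    # Stand-in workload (an assumption, for illustration only).
    elapsed = best_time(lambda: sorted(range(10_000), reverse=True))
    print(f"best of 5 runs: {elapsed:.6f} s")
```

Taking the minimum rather than the mean reduces the influence of unrelated background load, which matters when comparing timings on the order of tens of milliseconds.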

Edit: it looks like this PR is failing some tests, but at the time of writing all failures appear to be some issue with installing npm. I am not sure what to do about this...

@perlinm perlinm requested review from a team, vtomole and cduck as code owners February 20, 2023 03:05
@perlinm
Contributor Author

perlinm commented Feb 20, 2023

I actually have another stratifier that, among other advantages, (1) is another ~10x faster (so overall ~100x faster than the current stratifier in Cirq) for the example above, and (2) generally yields shorter stratified circuits (with depth provably upper bounded by the current Cirq stratifier). However, this other stratifier is more complex than the one in Cirq, and the corresponding PR would be considerably larger. There is also the question of whether Cirq would want to keep both stratifiers, or only one.

My thinking was to first get through a "minor" stratifier improvement in this PR, and then open a separate, larger PR for the other stratifier. However, maybe it makes more sense to add the other stratifier to this PR as well. Thoughts?

@95-martin-orion 95-martin-orion requested review from tanujkhattar and removed request for 95-martin-orion February 21, 2023 16:09
@tanujkhattar tanujkhattar self-assigned this Feb 21, 2023
Collaborator

@tanujkhattar tanujkhattar left a comment

Thanks a lot for opening the PR! The changes look great overall.

This aligns well with our general roadmap of making Cirq more performant. I have left a round of comments highlighting a bug, along with some suggestions to restructure the code.

Overall, this is looking good. We can merge once the comments have been addressed.


# Try the algorithm with each permutation of the classifiers.
classifiers_permutations = list(itertools.permutations(classifiers_and_the_rest))
classifiers_permutations = list(itertools.permutations(classifiers))
reversed_circuit = circuit[::-1]
solutions = []
for c in classifiers_permutations:
Collaborator

Another potential optimization would be to store only the final "optimized" circuit and, in each iteration of the loop below, replace it with the new circuit if the new circuit's length is shorter.

This ensures we are not storing all 2 * len(classifiers)! circuits in memory till the end of the for-loop to compute the best solution.
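
The suggestion above can be sketched in plain Python. This is a minimal illustration, not the code from this PR: the function name shortest_over_permutations and the build callback (a stand-in for one stratification pass over a given classifier ordering) are hypothetical, and the sketch leaves out the reversed-circuit pass.

```python
import itertools
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def shortest_over_permutations(
    items: Sequence[T], build: Callable[[Tuple[T, ...]], List]
) -> Optional[List]:
    """Return the shortest build(ordering), keeping only one candidate in memory."""
    best: Optional[List] = None
    # Iterate lazily instead of materializing list(itertools.permutations(...)).
    for ordering in itertools.permutations(items):
        candidate = build(ordering)
        # Keep the candidate only if it is strictly shorter than the best so far.
        if best is None or len(candidate) < len(best):
            best = candidate
    return best


if __name__ == "__main__":
    # Toy build function (an assumption): the result is longer the later "x" appears.
    toy_build = lambda ordering: [0] * (ordering.index("x") + 1)
    print(len(shortest_over_permutations(("x", "y", "z"), toy_build)))
```

With this pattern, peak memory holds at most two candidate results at a time (the current best and the one being compared) rather than one per permutation.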

Comment on lines 142 to 147
for key in protocols.control_keys(op) & measurement_time_index.keys():
    time_index = measurement_time_index[key]
    min_time_index_for_op = max(min_time_index_for_op, time_index + 1)
for key in protocols.measurement_key_objs(op) & control_time_index.keys():
    time_index = control_time_index[key]
    min_time_index_for_op = max(min_time_index_for_op, time_index + 1)
Collaborator

There is a bug here because you also need to check the overlap of protocols.measurement_key_objs(op) with measurement_time_index.keys(). This is to make sure that two measurement operations with the same key don't end up in the same moment.

Cirq supports repeated measurements on the same key under the assumption that these measurements occur in different moments. The following code snippet demonstrates the bug:

import cirq

q = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
    cirq.measure(q[0], key="a"),
    cirq.measure(q[1], key="a"),
    cirq.X(q[2]).with_classical_controls("a")
)
print(circuit)
print(cirq.stratified_circuit(circuit))

The output is:

0: ───M───────────
      ║
1: ───╫───M───────
      ║   ║
2: ───╫───╫───X───
      ║   ║   ║
a: ═══@═══@═══^═══
      ┌──┐
0: ────M─────────
       ║
1: ────╫M────────
       ║║
2: ────╫╫────X───
       ║║    ║
a: ════@@════^═══
      └──┘

The stratified circuit should be the same as the original circuit in this case.
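
The missing check can be illustrated with a small model that does not depend on Cirq. Everything below (ToyOp, earliest_index, place, and the three index dictionaries) is a hypothetical sketch of the bookkeeping described in this comment, not the implementation in this PR. Each dictionary maps a qubit or key to the latest moment index that already uses it.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet


@dataclass(frozen=True)
class ToyOp:
    """Hypothetical stand-in for an operation: only the resources it touches."""
    qubits: FrozenSet[int]
    measurement_keys: FrozenSet[str] = frozenset()
    control_keys: FrozenSet[str] = frozenset()


def earliest_index(op: ToyOp, qubit_index: Dict[int, int],
                   mkey_index: Dict[str, int], ckey_index: Dict[str, int]) -> int:
    """Return the earliest moment index at which `op` can be placed."""
    blockers = [qubit_index[q] for q in op.qubits if q in qubit_index]
    # A controlled op must come after the measurement that writes its key.
    blockers += [mkey_index[k] for k in op.control_keys if k in mkey_index]
    # A measurement must come after any earlier op controlled by the same key ...
    blockers += [ckey_index[k] for k in op.measurement_keys if k in ckey_index]
    # ... and after any earlier measurement of the same key (the missing check).
    blockers += [mkey_index[k] for k in op.measurement_keys if k in mkey_index]
    return max(blockers, default=-1) + 1


def place(op: ToyOp, qubit_index: Dict[int, int],
          mkey_index: Dict[str, int], ckey_index: Dict[str, int]) -> int:
    """Pick the earliest index for `op` and record it in the dictionaries."""
    index = earliest_index(op, qubit_index, mkey_index, ckey_index)
    qubit_index.update({q: index for q in op.qubits})
    mkey_index.update({k: index for k in op.measurement_keys})
    ckey_index.update({k: index for k in op.control_keys})
    return index


if __name__ == "__main__":
    # Mirrors the circuit above: two measurements into key "a", then a controlled X.
    ops = [
        ToyOp(frozenset({0}), measurement_keys=frozenset({"a"})),
        ToyOp(frozenset({1}), measurement_keys=frozenset({"a"})),
        ToyOp(frozenset({2}), control_keys=frozenset({"a"})),
    ]
    qi, mi, ci = {}, {}, {}
    print([place(op, qi, mi, ci) for op in ops])  # prints [0, 1, 2]
```

Without the last blockers line, the second measurement would be assigned index 0 in this model, which corresponds to the two measurements sharing a moment in the buggy output above.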

In general, the code here is extremely similar to our original logic of constructing a circuit with earliest strategy in the circuit constructor:

def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'):

Now that we have two uses of the same piece of code, I think it's worth abstracting out this logic into a circuit builder class, an instance of which can be used at both the places.

We have two options for this PR:

  1. We can fix the bug here, add a test and merge this PR and open a separate issue to track adding the circuit builder logic for reducing code duplication.
  2. Add the circuit builder logic as part of this PR.

I will leave it up to you to decide which option to pursue, although I'd have a slight preference towards (2) to make sure the circuit builder abstraction doesn't get stalled.

Contributor Author

@perlinm perlinm Mar 14, 2023

I factored out the logic for finding the earliest available/accommodating moment, and tentatively placed it in circuits.circuit.get_earliest_accommodating_moment_index. Suggestions welcome for where this function should be moved.

UPDATE: I have revised my PR since writing the text below in this comment, as well as the text in the follow-up comment. I am leaving these comments here to document my thoughts throughout the process of changing this PR, but these comments are no longer relevant for making sense of the current code.

[OUTDATED DISCUSSION:]
There are currently two related things that I am dissatisfied with:

  1. circuits/circuit.py currently has to construct mop_qubits/mkeys/ckeys twice: (a) in _load_contents_with_earliest_strategy, and (b) in get_earliest_accommodating_moment_index.
  2. Despite using the same method/logic for finding the earliest accommodating moment, circuits/circuit.py and transformers/stratify.py have to independently update the qubit/mkey/ckey_index dictionaries (here and here).

Fundamentally, both of these issues stem from the fact that _load_contents_with_earliest_strategy has special logic for what to do with Moments rather than Operations, which gets tangled with the logic of identifying an operation's qubits/mkeys/ckeys, finding the earliest accommodating moment, and updating the qubit/mkey/ckey --> index dictionaries.

Ideas for how to fix these issues?

Contributor Author

@perlinm perlinm Mar 14, 2023

[OUTDATED DISCUSSION:]
Two possible resolutions (though there may be better ones):

  • Make get_earliest_accommodating_moment_index accept a Union[cirq.Operation, cirq.Moment], and for moments return something like max(*qubit_indices.values(), *mkey_indices.values(), *ckey_indices.values()). This would make the code relatively "clean", but it would require computing a maximum every time. In contrast, the current implementation keeps track of this maximum "manually".
  • Again accept a cirq.Moment as an input, but also accept an optional circuit_length: Optional[int], and have get_earliest_accommodating_moment_index return the given circuit_length if isinstance(mop, cirq.Moment) and circuit_length is not None. This option avoids recalculating the circuit length by maximizing over indices, but it is potentially "dangerous" in the sense that the method would happily return an invalid circuit length. On the other hand, garbage in garbage out so maybe that's okay?

Contributor Author

@perlinm perlinm Mar 22, 2023

After a bit more thinking through the issues above, I decided to move the special treatment of Moments to get_earliest_accommodating_moment_index. This method now:

  1. Optionally accepts a length argument that is "promised" to be the length of the circuit, and places new moments at the end of the circuit according to that length.
  2. Updates the qubit/mkey/ckey_indices dictionaries in-place.

(1) seemed like a reasonable cost to pay for simplifying shared logic between the circuit builder and the stratifier. (2) was primarily done so that the circuit builder doesn't have to call protocols.measurement_key_objs and protocols.control_keys twice. However, the stratifier has to later update the qubit/mkey/ckey_indices dictionaries again because it post-processes the choice of moment index for op placement. An alternative to (2) would be to leave it to the circuit builder and stratifier to independently update the qubit/mkey/ckey_indices dictionaries, which would come at the (minor) cost of having the circuit builder call protocols.measurement_key_objs and protocols.control_keys twice.

stratified_circuit.batch_remove(batch_removals)
stratified_circuit.batch_insert_into(batch_inserts)
return drop_empty_moments.drop_empty_moments(stratified_circuit)
ignored_op = any(tag in op.tags for tag in context.tags_to_ignore)
Collaborator

Move this just above the if ignored_op: check at line 150; it reduces the cognitive overhead for the reader to track how we are dealing with ignored ops differently inside the loop.

# Get the earliest time index that can accommodate this op, based on its qubits.
min_time_index_for_op = max(qubit_time_index[qubit] + 1 for qubit in op.qubits)
if ignored_op:
    min_time_index_for_op = max(min_time_index_for_op, min_time_index_for_ignored_op)
Collaborator

Similarly, move inside the single if ignored_op: condition at line 150.

control_time_index: Dict['cirq.MeasurementKey', int] = {}

# The minimum time index for operations with a tag in context.tags_to_ignore.
min_time_index_for_ignored_op = 0
Collaborator

Consider renaming min_time_index_for_ignored_op to last_ignored_ops_moment_index because we are essentially keeping track of the latest moment that contains only ignored ops so that we can insert the next batch of ignored ops after that time.

Also, compute min_time_index_for_ignored_op as just the max(min_time_index_for_ignored_op, time_index) in the inner loop and get rid of the check min_time_index_for_op = max(min_time_index_for_op, min_time_index_for_ignored_op). Once you have min_time_index_for_ignored_op in the inner loop, you can take the max with last_ignored_ops_moment_index once outside before assigning op_time_indices to ignored_ops and then update the last_ignored_ops_moment_index with the current index.

In general, I think the current logic to take care of ignored ops is a bit convoluted and it's hard to reason about its correctness. My proposed changes should make it easier for a reader to quickly understand how we are handling the ignored_ops.

Contributor Author

@perlinm perlinm Mar 14, 2023

Thanks for these suggestions for how to clean up the handling of ignored_ops. I'm not sure I implemented them exactly as requested, but the new version is certainly much cleaner. What do you think?

Collaborator

@tanujkhattar tanujkhattar left a comment

My thinking was to first get through a "minor" stratifier improvement in this PR, and then open a separate, larger PR for the other stratifier

I agree with this. Let's get this PR in and then we can open a separate PR to either replace this or add a new stratifier.

@perlinm
Contributor Author

perlinm commented Feb 28, 2023

Thanks for the detailed review and suggested changes!

I agree that it's quite difficult to keep track of how ignored_ops are handled at the moment. Dealing with ignored_ops (and measurement keys) was most of the work in this PR... Hopefully your suggestions clean this up somewhat.

Anyways, I'll try to get to your suggested changes later this week.

@CirqBot CirqBot added the size: M 50< lines changed <250 label Mar 2, 2023
@perlinm perlinm requested a review from tanujkhattar March 14, 2023 04:10
@perlinm perlinm force-pushed the refactor_stratify branch from b71468f to 1f35a4e Compare March 22, 2023 21:00
@perlinm
Contributor Author

perlinm commented Mar 22, 2023

@tanujkhattar I believe that I have addressed all comments/requests. Please let me know of any other suggestions for what to change in this PR!

For example: I suspect that there is a better place for this new function to live.

Collaborator

@tanujkhattar tanujkhattar left a comment

This looks great! Let's merge this and open a separate issue to track what's the best place to put the free function get_earliest_accommodating_moment_index. (xref #6050)

Thank you for the contribution and apologies for the long wait! Let's merge this PR and keep the contributions coming! :)

@perlinm
Contributor Author

perlinm commented Apr 3, 2023

Thank you! Also thanks to @vtomole for some brief consultation on how to structure some of the changes in this PR 🙂

It looks like some of the continuous integration tests are still failing, but that appears to be due to code that was not touched by this PR. What should I do about it?

@tanujkhattar tanujkhattar merged commit 6a97cca into quantumlib:master Apr 3, 2023
harry-phasecraft pushed a commit to PhaseCraft/Cirq that referenced this pull request Oct 31, 2024
* refactor cirq.transformers.stratify

* fix one failing test

* nit renaming

* fix coverage

* formatting fix

* pylint fix

* fix bug with measurements in stratification

* add missing import

* fix bug with finding time index for op

* fix test, and nit change to keeping track of time indices

* hopefully fix measurement bug

* minor fix with ignored ops

* fix test

* only store shortest circuit found

* minor bugfix

* nit typing fix

* one more silly bugfig

* store shortest stratified circuit properly

* fix bug with overlapping measurements

* clean up handling of ignored ops

* further clean up logic deciding where to put an op

* factor out logic for finding earliest accomodating moment

* fix typo

* fix typo

* remove unnecesaary use of defaultdict

* further simplify logic in get_earliest_accommodating_moment_index

* fix lint check

* fix minor bug

* nit docstring update

* separately update qubit/mkey/ckey moments in cirq.stratify

* fix bug with max only getting one argument

* fix coverage check

* fix typo

* fix lint check

---------

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>