-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor and speed up cirq.transformers.stratify
#6013
Conversation
I actually have another stratifier that, among other advantages, (1) is another ~10x faster (so overall ~100x faster than the current stratifier in Cirq) for the example above, and (2) generally yields shorter stratified circuits (with depth provably upper bounded by the current Cirq stratifier). However, this other stratifier is more complex than the one in Cirq, and the corresponding PR would be considerably larger. There is also the question of whether Cirq would want to keep both stratifiers, or only one. My thinking was to first get through a "minor" stratifier improvement in this PR, and then open a separate, larger PR for the other stratifier. However, maybe it makes more sense to add the other stratifier to this PR as well. Thoughts? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for opening the PR! The changes look great overall.
This aligns well with our general roadmap of making Cirq more performant. I have left a round of comments highlighting a bug and with some suggests to restructure the code.
Overall, this is looking good. We can merge once the comments have been addressed.
|
||
# Try the algorithm with each permutation of the classifiers. | ||
classifiers_permutations = list(itertools.permutations(classifiers_and_the_rest)) | ||
classifiers_permutations = list(itertools.permutations(classifiers)) | ||
reversed_circuit = circuit[::-1] | ||
solutions = [] | ||
for c in classifiers_permutations: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another potential optimization would be to only store the final "optimized" circuit and in each iteration of the loop below, update the optimized circuit with the new circuit if it's length is shorter than the optimized circuit.
This ensures we are not storing all 2 * len(classifiers)!
circuits in memory till the end of the for-loop to compute the best solution.
for key in protocols.control_keys(op) & measurement_time_index.keys(): | ||
time_index = measurement_time_index[key] | ||
min_time_index_for_op = max(min_time_index_for_op, time_index + 1) | ||
for key in protocols.measurement_key_objs(op) & control_time_index.keys(): | ||
time_index = control_time_index[key] | ||
min_time_index_for_op = max(min_time_index_for_op, time_index + 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a bug here because you also need to check the overlap of protocols.measurement_key_objs(op)
with measurement_time_index.keys()
. This is to make sure that two measurement operations with the same key don't end up in the same moment.
Cirq supports repeated measurements on the same key under the assumption that these measurements occur in different moments. The following code snippet demonstrates the bug:
q = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.measure(q[0], key="a"),
cirq.measure(q[1], key="a"),
cirq.X(q[2]).with_classical_controls("a")
)
print(circuit)
print(cirq.stratified_circuit(circuit))
The output is:
0: ───M───────────
║
1: ───╫───M───────
║ ║
2: ───╫───╫───X───
║ ║ ║
a: ═══@═══@═══^═══
┌──┐
0: ────M─────────
║
1: ────╫M────────
║║
2: ────╫╫────X───
║║ ║
a: ════@@════^═══
└──┘
The stratified circuit should be the same as the original circuit in this case.
In general, the code here is extremely similar to our original logic of constructing a circuit with earliest strategy in the circuit constructor:
Cirq/cirq-core/cirq/circuits/circuit.py
Line 1760 in af6624d
def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'): |
Now that we have two uses of the same piece of code, I think it's worth abstracting out this logic into a circuit builder class, an instance of which can be used at both the places.
We have two options for this PR:
- We can fix the bug here, add a test and merge this PR and open a separate issue to track adding the circuit builder logic for reducing code duplication.
- Add the circuit builder logic as part of this PR.
I will leave it up to you to decide which option to pursue, although I'd have a slight preference towards (2) to make sure the the circuit builder abstraction doesn't get stalled.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I factored out the logic for finding the earliest available/accommodating moment, and tentatively placed it in circuits.circuit.get_earliest_accommodating_moment_index
. Suggestions welcome for where this function should be moved.
UPDATE: I have revised my PR since writing the text below in this comment, as well as the text in the follow-up comment. I am leaving these comments here to document my thoughts throughout making the process of changing this PR, but these comments are no longer relevant for making sense of the current code.
[OUTDATED DISCUSSION:]
There are currently two related things that I am dissatisfied with:
circuits/circuit.py
currently has to constructmop_qubits/mkeys/ckeys
twice: (a) in_load_contents_with_earliest_strategy
, and (b) inget_earliest_accommodating_moment_index
.- Despite using the same method/logic for finding the earliest accommodating moment,
circuits/circuit.py
andtransformers/stratify.py
have to independently update thequbit/mkey/ckey_index
dictionaries (here and here).
Fundamentally, both of these issues stem from the fact that _load_contents_with_earliest_strategy
has special logic for what to do with Moment
s rather than Operation
s, which gets tangled with the logic of identifying an operation's qubits/mkeys/ckeys
, finding the earliest accommodating moment, and updating the qubit/mkey/ckey --> index
dictionaries.
Ideas for how to fix these issues?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[OUTDATED DISCUSSION:]
Two possible resolutions (though there may be better ones):
- Make
get_earliest_accommodating_moment_index
accept aUnion[cirq.Operation, cirq.Moment]
, and for moments return something likemax(*qubit_indices.values(), *mkey_indices.values(), *ckey_indices.values())
. This would make the code relatively "clean", but it would require to compute a maximum every time. In contrast, the current the current implementation keeps track of this maximum "manually". - Again accept a
cirq.Moment
as an input, but also accept an optionalcircuit_length: Optional[int]
, and haveget_earliest_accommodating_moment_index
return the givencircuit_length
ifisinstance(mop, cirq.Moment) and circuit_length is not None
. This option avoids recalculating the circuit length by maximizing over indices, but it is potentially "dangerous" in the sense that the method would happily return an invalid circuit length. On the other hand, garbage in garbage out so maybe that's okay?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After a bit more thinking through the issues above, I decided to move the special treatment of Moment
s to get_earliest_accommodating_moment_index. This method now:
- Optionally accepts a
length
argument that is "promised" to be the length of the circuit, and places new moments at the end of the circuit according to that length. - Updates the
qubit/mkey/ckey_indices
dictionaries in-place.
(1) seemed like a reasonable cost to pay for simplifying shared logic between the circuit builder and the stratifier. (2) was primarily done so that the circuit builder doesn't have to call protocols.measurement_key_objs
and protocols.control_keys
twice. However, the stratifier has to later update the qubit/mkey/ckey_indices
dictionaries again because it post-processes the choice of moment index for op
placement. An alternative to (2) would be to leave it to the circuit builder and stratifier to independently update the qubit/mkey/ckey_indices
dictionaries, which would come at the (minor) cost of having the circuit builder call protocols.measurement_key_objs
and protocols.control_keys
twice.
stratified_circuit.batch_remove(batch_removals) | ||
stratified_circuit.batch_insert_into(batch_inserts) | ||
return drop_empty_moments.drop_empty_moments(stratified_circuit) | ||
ignored_op = any(tag in op.tags for tag in context.tags_to_ignore) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this just above the if ignored_op:
check at line 150, it reduces the cognitive overhead for the reader to track how we are dealing ignored ops differently inside the loop.
# Get the earliest time index that can accomodate this op, based on its qubits. | ||
min_time_index_for_op = max(qubit_time_index[qubit] + 1 for qubit in op.qubits) | ||
if ignored_op: | ||
min_time_index_for_op = max(min_time_index_for_op, min_time_index_for_ignored_op) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, move inside the single if ignored_op:
condition at line 150.
control_time_index: Dict['cirq.MeasurementKey', int] = {} | ||
|
||
# The minimum time index for operations with a tag in context.tags_to_ignore. | ||
min_time_index_for_ignored_op = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider renaming min_time_index_for_ignored_op
to last_ignored_ops_moment_index
because we are essentially keeping track of the latest moment that contains only ignored ops so that we can insert the next ignores ops batch after that time.
Also, compute min_time_index_for_ignored_op
as just the max(min_time_index_for_ignored_op, time_index)
in the inner loop and get rid of the check min_time_index_for_op = max(min_time_index_for_op, min_time_index_for_ignored_op)
. Once you have min_time_index_for_ignored_op
in the inner loop, you can take the max with last_ignored_ops_moment_index
once outside before assigning op_time_indices
to ignored_ops
and then update the last_ignored_ops_moment_index
with the current index.
In general, I think the current logic to take care of ignored ops is a bit convoluted and it's hard to reason about it's correctness. My proposed changes should make it easier for a reader to quickly understand how we are handling the ignored_ops
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for these suggestions for how to clean up the handling of ignored_ops
. I'm not sure I implemented them exactly as requested, but the new version is certainly much cleaner. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My thinking was to first get through a "minor" stratifier improvement in this PR, and then open a separate, larger PR for the other stratifier
I agree with this. Let's get this PR in and then we can open a separate PR to either replace this or add a new stratifier.
Thanks for the detailed review and suggested changes! I agree that it's quite difficult to keep track of how Anyways, I'll try to get to your suggested changes later this week. |
b71468f
to
1f35a4e
Compare
@tanujkhattar I believe that I have addressed all comments/requests. Please let me know of any other suggestions for what to change in this PR! For example: I suspect that there is a better place for this new function to live. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great! Let's merge this and open a separate issue to track what's the best place to put the free function get_earliest_accommodating_moment_index
. (xref #6050)
Thank you for the contribution and apologies for the long wait! Let's merge this PR and keep the contributions coming! :)
Thank you! Also thanks to @vtomole for some brief consultation on how to structure some of the changes in this PR 🙂 It looks like some of the continuous integration tests are still failing, but that appears to be due to code that was not touched by this PR. What should I do about it? |
* refactor cirq.transformers.stratify * fix one failing test * nit renaming * fix coverage * formatting fix * pylint fix * fix bug with measurements in stratification * add missing import * fix bug with finding time index for op * fix test, and nit change to keeping track of time indices * hopefully fix measurement bug * minor fix with ignored ops * fix test * only store shortest circuit found * minor bugfix * nit typing fix * one more silly bugfig * store shortest stratified circuit properly * fix bug with overlapping measurements * clean up handling of ignored ops * further clean up logic deciding where to put an op * factor out logic for finding earliest accomodating moment * fix typo * fix typo * remove unnecesaary use of defaultdict * further simplify logic in get_earliest_accommodating_moment_index * fix lint check * fix minor bug * nit docstring update * separately update qubit/mkey/ckey moments in cirq.stratify * fix bug with max only getting one argument * fix coverage check * fix typo * fix lint check --------- Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
This PR refactors
cirq.transformers.stratify.stratified_circuit
, making it considerably faster.I have not contributed to
cirq
before, so I apologize if I am not doing things the "correct" way. Feedback and suggestions much appreciated 🙂.Some comments:
ops_to_ignore
, for a given input (circuit + categories), the output of the new stratifier is identical to the old.num_categories
strata, and then (b) every operation is moved into the earliest stratum possible that matches its category. Ignored ops are not moved into earlier strata at step (b), which can block non-ignored ops from moving into earlier moments. The new stratifier moves ignored ops into earlier moments as well, which can result in shorter circuits. The ignored ops are still "ignored" in the sense that their relative order is unaffected by stratification. If one ignored op preceded another ignored op before stratification, that will still be true after stratification, and two ignored ops will be in the same moment after stratification if and only if they were in the same moment before stratification.On my laptop, this test takes ~0.4 seconds with the old stratifier, and ~0.03 seconds with the new stratifier.
Edit: it looks like this PR is failing some tests, but at the time of writing all failures appear to be some issue with installing
npm
. I am not sure what to do about this...