-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GridDevice gateset, gate_duration, and compilation_target_gateset support #5315
Changes from all commits
ff8aebc
486376e
84bdfe5
17262f0
288a4dc
8135a8c
7a1bcd9
e9de7ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,14 @@ | |
|
||
import re | ||
|
||
from typing import Any, Set, Tuple, cast | ||
from typing import Any, Dict, List, Sequence, Set, Tuple, Type, Union, cast | ||
import warnings | ||
|
||
import cirq | ||
from cirq_google import ops | ||
from cirq_google import transformers | ||
from cirq_google.api import v2 | ||
from cirq_google.experimental import ops as experimental_ops | ||
|
||
|
||
def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: | ||
|
@@ -67,6 +72,94 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> | |
raise ValueError("Invalid DeviceSpecification: target_ordering cannot be ASYMMETRIC.") | ||
|
||
|
||
def _build_gateset_and_gate_durations( | ||
proto: v2.device_pb2.DeviceSpecification, | ||
) -> Tuple[cirq.Gateset, Dict[cirq.GateFamily, cirq.Duration]]: | ||
"""Extracts gate set and gate duration information from the given DeviceSpecification proto.""" | ||
|
||
gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] | ||
gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {} | ||
|
||
# TODO(#5050) Describe how to add/remove gates. | ||
|
||
for gate_spec in proto.valid_gates: | ||
gate_name = gate_spec.WhichOneof('gate') | ||
cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] | ||
|
||
if gate_name == 'syc': | ||
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[ops.SYC])] | ||
elif gate_name == 'sqrt_iswap': | ||
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])] | ||
elif gate_name == 'sqrt_iswap_inv': | ||
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])] | ||
elif gate_name == 'cz': | ||
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.CZ])] | ||
elif gate_name == 'phased_xz': | ||
cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate] | ||
elif gate_name == 'virtual_zpow': | ||
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])] | ||
elif gate_name == 'physical_zpow': | ||
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])] | ||
elif gate_name == 'coupler_pulse': | ||
cirq_gates = [experimental_ops.CouplerPulse] | ||
elif gate_name == 'meas': | ||
cirq_gates = [cirq.MeasurementGate] | ||
elif gate_name == 'wait': | ||
cirq_gates = [cirq.WaitGate] | ||
else: | ||
# coverage: ignore | ||
warnings.warn( | ||
f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized" | ||
" by Cirq and will be ignored. This may be due to an out-of-date Cirq version.", | ||
UserWarning, | ||
) | ||
continue | ||
|
||
gates_list.extend(cirq_gates) | ||
|
||
# TODO(#5050) Allow different gate representations of the same gate to be looked up in | ||
# gate_durations. | ||
for g in cirq_gates: | ||
if not isinstance(g, cirq.GateFamily): | ||
g = cirq.GateFamily(g) | ||
gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos) | ||
|
||
# TODO(#4833) Add identity gate support | ||
# TODO(#5050) Add GlobalPhaseGate support | ||
|
||
return cirq.Gateset(*gates_list), gate_durations | ||
|
||
|
||
def _build_compilation_target_gatesets( | ||
gateset: cirq.Gateset, | ||
) -> Sequence[cirq.CompilationTargetGateset]: | ||
"""Detects compilation target gatesets based on what gates are inside the gateset. | ||
|
||
If a device contains gates which yield multiple compilation target gatesets, the user can only | ||
choose one target gateset to compile to. For example, a device may contain both SYC and | ||
SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled to | ||
either SYC or SQRT_ISWAP for its two-qubit gates, not both. | ||
|
||
TODO(#5050) when cirq-google CompilationTargetGateset subclasses are implemented, mention that | ||
gates which are part of the gateset but not the compilation target gateset are untouched when | ||
compiled. | ||
""" | ||
|
||
# TODO(#5050) Subclass core CompilationTargetGatesets in cirq-google. | ||
|
||
target_gatesets: List[cirq.CompilationTargetGateset] = [] | ||
if cirq.CZ in gateset: | ||
target_gatesets.append(cirq.CZTargetGateset()) | ||
if ops.SYC in gateset: | ||
target_gatesets.append(transformers.SycamoreTargetGateset()) | ||
if cirq.SQRT_ISWAP in gateset: | ||
target_gatesets.append( | ||
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=cirq.SQRT_ISWAP_INV in gateset) | ||
) | ||
verult marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return tuple(target_gatesets) | ||
|
||
|
||
@cirq.value_equality | ||
class GridDevice(cirq.Device): | ||
"""Device object representing Google devices with a grid qubit layout. | ||
|
@@ -112,7 +205,24 @@ class GridDevice(cirq.Device): | |
* Get a collection of approximate gate durations for every gate supported by the device. | ||
>>> device.metadata.gate_durations | ||
|
||
TODO(#5050) Add compilation_target_gatesets example. | ||
* Get a collection of valid CompilationTargetGatesets for the device, which can be used to | ||
transform a circuit to one which only contains gates from a native target gateset | ||
supported by the device. | ||
>>> device.metadata.compilation_target_gatesets | ||
|
||
* Assuming valid CompilationTargetGatesets exist for the device, select the first one and | ||
use it to transform a circuit to one which only contains gates from a native target | ||
gateset supported by the device. | ||
>>> cirq.optimize_for_target_gateset( | ||
circuit, | ||
gateset=device.metadata.compilation_target_gatesets[0] | ||
) | ||
|
||
A note about CompilationTargetGatesets: | ||
|
||
A circuit which contains `cirq.WaitGate`s will be dropped if it is transformed using | ||
CompilationTargetGatesets generated by GridDevice. To better control circuit timing, insert | ||
WaitGates after the circuit has been transformed. | ||
Comment on lines
+221
to
+225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this note should go to the docstring of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is that we want to call out Also, if we were to keep this comment in cirq-core, are there other identity-like gates that may not be negligible that we should call out? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is mainly used by the Google devices and we don't need to add it to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, will keep the docstring here then since GridDevice doesn't have a |
||
|
||
Notes for cirq_google internal implementation: | ||
|
||
|
@@ -162,12 +272,15 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': | |
if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC | ||
] | ||
|
||
# TODO(#5050) implement gate durations | ||
gateset, gate_durations = _build_gateset_and_gate_durations(proto) | ||
|
||
try: | ||
metadata = cirq.GridDeviceMetadata( | ||
qubit_pairs=qubit_pairs, | ||
gateset=cirq.Gateset(), # TODO(#5050) implement | ||
gateset=gateset, | ||
gate_durations=gate_durations if len(gate_durations) > 0 else None, | ||
all_qubits=all_qubits, | ||
compilation_target_gatesets=_build_compilation_target_gatesets(gateset), | ||
) | ||
except ValueError as ve: # coverage: ignore | ||
# Spec errors should have been caught in validation above. | ||
|
@@ -194,19 +307,19 @@ def validate_operation(self, operation: cirq.Operation) -> None: | |
Raises: | ||
ValueError: The operation isn't valid for this device. | ||
""" | ||
# TODO(#5050) uncomment once gateset logic is implemented | ||
# if operation not in self._metadata.gateset: | ||
# raise ValueError(f'Operation {operation} is not a supported gate') | ||
|
||
if operation not in self._metadata.gateset: | ||
raise ValueError(f'Operation {operation} contains a gate which is not supported.') | ||
|
||
for q in operation.qubits: | ||
if q not in self._metadata.qubit_set: | ||
raise ValueError(f'Qubit not on device: {q!r}') | ||
raise ValueError(f'Qubit not on device: {q!r}.') | ||
|
||
if ( | ||
len(operation.qubits) == 2 | ||
and frozenset(operation.qubits) not in self._metadata.qubit_pairs | ||
): | ||
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}') | ||
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.') | ||
|
||
def __str__(self) -> str: | ||
diagram = cirq.TextDiagramDrawer() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a private method, but I would still add a docstring since it's a little long.