Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridDevice gateset, gate_duration, and compilation_target_gateset support #5315

Merged
131 changes: 122 additions & 9 deletions cirq-google/cirq_google/devices/grid_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@

import re

from typing import Any, Set, Tuple, cast
from typing import Any, Dict, List, Sequence, Set, Tuple, Type, Union, cast
import warnings

import cirq
from cirq_google import ops
from cirq_google import transformers
from cirq_google.api import v2
from cirq_google.experimental import ops as experimental_ops


def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None:
Expand Down Expand Up @@ -67,6 +72,94 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) ->
raise ValueError("Invalid DeviceSpecification: target_ordering cannot be ASYMMETRIC.")


def _build_gateset_and_gate_durations(
proto: v2.device_pb2.DeviceSpecification,
) -> Tuple[cirq.Gateset, Dict[cirq.GateFamily, cirq.Duration]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a private method, but I would still add a docstring since it's a little long.

"""Extracts gate set and gate duration information from the given DeviceSpecification proto."""

gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = []
gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {}

# TODO(#5050) Describe how to add/remove gates.

for gate_spec in proto.valid_gates:
gate_name = gate_spec.WhichOneof('gate')
cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = []

if gate_name == 'syc':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[ops.SYC])]
elif gate_name == 'sqrt_iswap':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])]
elif gate_name == 'sqrt_iswap_inv':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])]
elif gate_name == 'cz':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.CZ])]
elif gate_name == 'phased_xz':
cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate]
elif gate_name == 'virtual_zpow':
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])]
elif gate_name == 'physical_zpow':
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])]
elif gate_name == 'coupler_pulse':
cirq_gates = [experimental_ops.CouplerPulse]
elif gate_name == 'meas':
cirq_gates = [cirq.MeasurementGate]
elif gate_name == 'wait':
cirq_gates = [cirq.WaitGate]
else:
# coverage: ignore
warnings.warn(
f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized"
" by Cirq and will be ignored. This may be due to an out-of-date Cirq version.",
UserWarning,
)
continue

gates_list.extend(cirq_gates)

# TODO(#5050) Allow different gate representations of the same gate to be looked up in
# gate_durations.
for g in cirq_gates:
if not isinstance(g, cirq.GateFamily):
g = cirq.GateFamily(g)
gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos)

# TODO(#4833) Add identity gate support
# TODO(#5050) Add GlobalPhaseGate support

return cirq.Gateset(*gates_list), gate_durations


def _build_compilation_target_gatesets(
gateset: cirq.Gateset,
) -> Sequence[cirq.CompilationTargetGateset]:
"""Detects compilation target gatesets based on what gates are inside the gateset.

If a device contains gates which yield multiple compilation target gatesets, the user can only
choose one target gateset to compile to. For example, a device may contain both SYC and
SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled to
either SYC or SQRT_ISWAP for its two-qubit gates, not both.

TODO(#5050) when cirq-google CompilationTargetGateset subclasses are implemented, mention that
gates which are part of the gateset but not the compilation target gateset are untouched when
compiled.
"""

# TODO(#5050) Subclass core CompilationTargetGatesets in cirq-google.

target_gatesets: List[cirq.CompilationTargetGateset] = []
if cirq.CZ in gateset:
target_gatesets.append(cirq.CZTargetGateset())
if ops.SYC in gateset:
target_gatesets.append(transformers.SycamoreTargetGateset())
if cirq.SQRT_ISWAP in gateset:
target_gatesets.append(
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=cirq.SQRT_ISWAP_INV in gateset)
)
verult marked this conversation as resolved.
Show resolved Hide resolved

return tuple(target_gatesets)


@cirq.value_equality
class GridDevice(cirq.Device):
"""Device object representing Google devices with a grid qubit layout.
Expand Down Expand Up @@ -112,7 +205,24 @@ class GridDevice(cirq.Device):
* Get a collection of approximate gate durations for every gate supported by the device.
>>> device.metadata.gate_durations

TODO(#5050) Add compilation_target_gatesets example.
* Get a collection of valid CompilationTargetGatesets for the device, which can be used to
transform a circuit to one which only contains gates from a native target gateset
supported by the device.
>>> device.metadata.compilation_target_gatesets

* Assuming valid CompilationTargetGatesets exist for the device, select the first one and
use it to transform a circuit to one which only contains gates from a native target
gateset supported by the device.
>>> cirq.optimize_for_target_gateset(
circuit,
gateset=device.metadata.compilation_target_gatesets[0]
)

A note about CompilationTargetGatesets:

A circuit which contains `cirq.WaitGate`s will be dropped if it is transformed using
CompilationTargetGatesets generated by GridDevice. To better control circuit timing, insert
WaitGates after the circuit has been transformed.
Comment on lines +221 to +225
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this note should go to the docstring of compilation_target_gatesets property. This is a specific detail of compilation targets we are using today, and should not be tied to the description of device.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that we want to call out WaitGate because Google devices all support this gate. Is this concern general enough to all devices to be kept in the compilation_target_gateset property in cirq.GridDeviceMetadata? If so, I would suggest to keep it in CompilationTargetGateset instead, since all CompilationTargetGatesets would have this issue. If not, I would keep it where it is now.

Also, if we were to keep this comment in cirq-core, are there other identity-like gates that may not be negligible that we should call out?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is mainly used by the Google devices and we don't need to add it to the CompilationTargetGateset class.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, will keep the docstring here then since GridDevice doesn't have a compilation_target_gatesets property.


Notes for cirq_google internal implementation:

Expand Down Expand Up @@ -162,12 +272,15 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice':
if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC
]

# TODO(#5050) implement gate durations
gateset, gate_durations = _build_gateset_and_gate_durations(proto)

try:
metadata = cirq.GridDeviceMetadata(
qubit_pairs=qubit_pairs,
gateset=cirq.Gateset(), # TODO(#5050) implement
gateset=gateset,
gate_durations=gate_durations if len(gate_durations) > 0 else None,
all_qubits=all_qubits,
compilation_target_gatesets=_build_compilation_target_gatesets(gateset),
)
except ValueError as ve: # coverage: ignore
# Spec errors should have been caught in validation above.
Expand All @@ -194,19 +307,19 @@ def validate_operation(self, operation: cirq.Operation) -> None:
Raises:
ValueError: The operation isn't valid for this device.
"""
# TODO(#5050) uncomment once gateset logic is implemented
# if operation not in self._metadata.gateset:
# raise ValueError(f'Operation {operation} is not a supported gate')

if operation not in self._metadata.gateset:
raise ValueError(f'Operation {operation} contains a gate which is not supported.')

for q in operation.qubits:
if q not in self._metadata.qubit_set:
raise ValueError(f'Qubit not on device: {q!r}')
raise ValueError(f'Qubit not on device: {q!r}.')

if (
len(operation.qubits) == 2
and frozenset(operation.qubits) not in self._metadata.qubit_pairs
):
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}')
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.')

def __str__(self) -> str:
diagram = cirq.TextDiagramDrawer()
Expand Down
76 changes: 68 additions & 8 deletions cirq-google/cirq_google/devices/grid_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,23 @@ def _create_device_spec_with_horizontal_couplings():
# to verify GridDevice properly handles pair symmetry.
new_target = grid_targets.targets.add()
new_target.ids.extend([v2.qubit_to_proto_id(cirq.GridQubit(row, 1 - j)) for j in range(2)])
gate = spec.valid_gates.add()
gate.syc.SetInParent()
gate.gate_duration_picos = 12000

gate_names = [
'syc',
'sqrt_iswap',
'sqrt_iswap_inv',
'cz',
'phased_xz',
'virtual_zpow',
'physical_zpow',
'coupler_pulse',
'meas',
'wait',
]
for i, g in enumerate(gate_names):
gate = spec.valid_gates.add()
getattr(gate, g).SetInParent()
gate.gate_duration_picos = i * 1000

return grid_qubits, spec

Expand Down Expand Up @@ -153,6 +167,53 @@ def test_grid_device_from_proto():
frozenset((cirq.GridQubit(row, 0), cirq.GridQubit(row, 1))) in device.metadata.qubit_pairs
for row in range(GRID_HEIGHT)
)
assert device.metadata.gateset == cirq.Gateset(
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
cirq.ops.phased_x_z_gate.PhasedXZGate,
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.ops.phased_x_gate.PhasedXPowGate,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
),
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
),
cirq_google.experimental.ops.coupler_pulse.CouplerPulse,
cirq.ops.measurement_gate.MeasurementGate,
cirq.ops.wait_gate.WaitGate,
)
assert tuple(device.metadata.compilation_target_gatesets) == (
cirq.CZTargetGateset(),
cirq_google.SycamoreTargetGateset(),
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True),
)

base_duration = cirq.Duration(picos=1_000)
assert device.metadata.gate_durations == {
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]): base_duration * 0,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]): base_duration * 1,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]): base_duration * 2,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]): base_duration * 3,
cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate): base_duration * 4,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
): base_duration
* 5,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
): base_duration
* 6,
cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse): base_duration * 7,
cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate): base_duration * 8,
cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9,
}


def test_grid_device_validate_operations_positive():
Expand All @@ -166,23 +227,22 @@ def test_grid_device_validate_operations_positive():
for i in range(GRID_HEIGHT):
device.validate_operation(cirq.CZ(grid_qubits[2 * i], grid_qubits[2 * i + 1]))

# TODO(#5050) verify validate_operations gateset support


def test_grid_device_validate_operations_negative():
grid_qubits, spec = _create_device_spec_with_horizontal_couplings()
device = cirq_google.GridDevice.from_proto(spec)

q = cirq.GridQubit(10, 10)
bad_qubit = cirq.GridQubit(10, 10)
with pytest.raises(ValueError, match='Qubit not on device'):
device.validate_operation(cirq.X(q))
device.validate_operation(cirq.X(bad_qubit))

# vertical qubit pair
q00, q10 = grid_qubits[0], grid_qubits[2] # (0, 0), (1, 0)
with pytest.raises(ValueError, match='Qubit pair is not valid'):
device.validate_operation(cirq.CZ(q00, q10))

# TODO(#5050) verify validate_operations gateset errors
with pytest.raises(ValueError, match='gate which is not supported'):
device.validate_operation(cirq.H(grid_qubits[0]))


@pytest.mark.parametrize(
Expand Down