Skip to content

Commit

Permalink
GridDevice gateset, gate_duration, and compilation_target_gateset sup…
Browse files Browse the repository at this point in the history
…port (quantumlib#5315)

* GridDevice gateset, gate_duration, and compilation_target_gateset support

* Displays a warning about old Cirq version if DeviceSpecification contains gates not recognized by the client.
* Added compilation target gateset example in GridDevice docstring

* Addresses Tanuj's comments

* Added GridDevice docstring comment about WaitGates.

I avoided mentioning the alternative of inserting no_compile tags because transforming a circuit which contains WaitGates is the wrong approach as all delays would be scrambled by the transformation.

* Set gate_duration to None if the spec doesn't have duration info

* Addressed Doug's comments

* Addressed Tanuj's 2nd round of comments

* Break FSimGateFamily into one for each 2q gate type
  • Loading branch information
verult authored and rht committed May 1, 2023
1 parent 468e667 commit 85309b8
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 17 deletions.
131 changes: 122 additions & 9 deletions cirq-google/cirq_google/devices/grid_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@

import re

from typing import Any, Set, Tuple, cast
from typing import Any, Dict, List, Sequence, Set, Tuple, Type, Union, cast
import warnings

import cirq
from cirq_google import ops
from cirq_google import transformers
from cirq_google.api import v2
from cirq_google.experimental import ops as experimental_ops


def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None:
Expand Down Expand Up @@ -67,6 +72,94 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) ->
raise ValueError("Invalid DeviceSpecification: target_ordering cannot be ASYMMETRIC.")


def _build_gateset_and_gate_durations(
proto: v2.device_pb2.DeviceSpecification,
) -> Tuple[cirq.Gateset, Dict[cirq.GateFamily, cirq.Duration]]:
"""Extracts gate set and gate duration information from the given DeviceSpecification proto."""

gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = []
gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {}

# TODO(#5050) Describe how to add/remove gates.

for gate_spec in proto.valid_gates:
gate_name = gate_spec.WhichOneof('gate')
cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = []

if gate_name == 'syc':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[ops.SYC])]
elif gate_name == 'sqrt_iswap':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])]
elif gate_name == 'sqrt_iswap_inv':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])]
elif gate_name == 'cz':
cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.CZ])]
elif gate_name == 'phased_xz':
cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate]
elif gate_name == 'virtual_zpow':
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])]
elif gate_name == 'physical_zpow':
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])]
elif gate_name == 'coupler_pulse':
cirq_gates = [experimental_ops.CouplerPulse]
elif gate_name == 'meas':
cirq_gates = [cirq.MeasurementGate]
elif gate_name == 'wait':
cirq_gates = [cirq.WaitGate]
else:
# coverage: ignore
warnings.warn(
f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized"
" by Cirq and will be ignored. This may be due to an out-of-date Cirq version.",
UserWarning,
)
continue

gates_list.extend(cirq_gates)

# TODO(#5050) Allow different gate representations of the same gate to be looked up in
# gate_durations.
for g in cirq_gates:
if not isinstance(g, cirq.GateFamily):
g = cirq.GateFamily(g)
gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos)

# TODO(#4833) Add identity gate support
# TODO(#5050) Add GlobalPhaseGate support

return cirq.Gateset(*gates_list), gate_durations


def _build_compilation_target_gatesets(
gateset: cirq.Gateset,
) -> Sequence[cirq.CompilationTargetGateset]:
"""Detects compilation target gatesets based on what gates are inside the gateset.
If a device contains gates which yield multiple compilation target gatesets, the user can only
choose one target gateset to compile to. For example, a device may contain both SYC and
SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled to
either SYC or SQRT_ISWAP for its two-qubit gates, not both.
TODO(#5050) when cirq-google CompilationTargetGateset subclasses are implemented, mention that
gates which are part of the gateset but not the compilation target gateset are untouched when
compiled.
"""

# TODO(#5050) Subclass core CompilationTargetGatesets in cirq-google.

target_gatesets: List[cirq.CompilationTargetGateset] = []
if cirq.CZ in gateset:
target_gatesets.append(cirq.CZTargetGateset())
if ops.SYC in gateset:
target_gatesets.append(transformers.SycamoreTargetGateset())
if cirq.SQRT_ISWAP in gateset:
target_gatesets.append(
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=cirq.SQRT_ISWAP_INV in gateset)
)

return tuple(target_gatesets)


@cirq.value_equality
class GridDevice(cirq.Device):
"""Device object representing Google devices with a grid qubit layout.
Expand Down Expand Up @@ -112,7 +205,24 @@ class GridDevice(cirq.Device):
* Get a collection of approximate gate durations for every gate supported by the device.
>>> device.metadata.gate_durations
TODO(#5050) Add compilation_target_gatesets example.
* Get a collection of valid CompilationTargetGatesets for the device, which can be used to
transform a circuit to one which only contains gates from a native target gateset
supported by the device.
>>> device.metadata.compilation_target_gatesets
* Assuming valid CompilationTargetGatesets exist for the device, select the first one and
use it to transform a circuit to one which only contains gates from a native target
gateset supported by the device.
>>> cirq.optimize_for_target_gateset(
circuit,
gateset=device.metadata.compilation_target_gatesets[0]
)
A note about CompilationTargetGatesets:
A circuit which contains `cirq.WaitGate`s will be dropped if it is transformed using
CompilationTargetGatesets generated by GridDevice. To better control circuit timing, insert
WaitGates after the circuit has been transformed.
Notes for cirq_google internal implementation:
Expand Down Expand Up @@ -162,12 +272,15 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice':
if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC
]

# TODO(#5050) implement gate durations
gateset, gate_durations = _build_gateset_and_gate_durations(proto)

try:
metadata = cirq.GridDeviceMetadata(
qubit_pairs=qubit_pairs,
gateset=cirq.Gateset(), # TODO(#5050) implement
gateset=gateset,
gate_durations=gate_durations if len(gate_durations) > 0 else None,
all_qubits=all_qubits,
compilation_target_gatesets=_build_compilation_target_gatesets(gateset),
)
except ValueError as ve: # coverage: ignore
# Spec errors should have been caught in validation above.
Expand All @@ -194,19 +307,19 @@ def validate_operation(self, operation: cirq.Operation) -> None:
Raises:
ValueError: The operation isn't valid for this device.
"""
# TODO(#5050) uncomment once gateset logic is implemented
# if operation not in self._metadata.gateset:
# raise ValueError(f'Operation {operation} is not a supported gate')

if operation not in self._metadata.gateset:
raise ValueError(f'Operation {operation} contains a gate which is not supported.')

for q in operation.qubits:
if q not in self._metadata.qubit_set:
raise ValueError(f'Qubit not on device: {q!r}')
raise ValueError(f'Qubit not on device: {q!r}.')

if (
len(operation.qubits) == 2
and frozenset(operation.qubits) not in self._metadata.qubit_pairs
):
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}')
raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.')

def __str__(self) -> str:
diagram = cirq.TextDiagramDrawer()
Expand Down
76 changes: 68 additions & 8 deletions cirq-google/cirq_google/devices/grid_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,23 @@ def _create_device_spec_with_horizontal_couplings():
# to verify GridDevice properly handles pair symmetry.
new_target = grid_targets.targets.add()
new_target.ids.extend([v2.qubit_to_proto_id(cirq.GridQubit(row, 1 - j)) for j in range(2)])
gate = spec.valid_gates.add()
gate.syc.SetInParent()
gate.gate_duration_picos = 12000

gate_names = [
'syc',
'sqrt_iswap',
'sqrt_iswap_inv',
'cz',
'phased_xz',
'virtual_zpow',
'physical_zpow',
'coupler_pulse',
'meas',
'wait',
]
for i, g in enumerate(gate_names):
gate = spec.valid_gates.add()
getattr(gate, g).SetInParent()
gate.gate_duration_picos = i * 1000

return grid_qubits, spec

Expand Down Expand Up @@ -153,6 +167,53 @@ def test_grid_device_from_proto():
frozenset((cirq.GridQubit(row, 0), cirq.GridQubit(row, 1))) in device.metadata.qubit_pairs
for row in range(GRID_HEIGHT)
)
assert device.metadata.gateset == cirq.Gateset(
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]),
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]),
cirq.ops.phased_x_z_gate.PhasedXZGate,
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.ops.phased_x_gate.PhasedXPowGate,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
),
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
),
cirq_google.experimental.ops.coupler_pulse.CouplerPulse,
cirq.ops.measurement_gate.MeasurementGate,
cirq.ops.wait_gate.WaitGate,
)
assert tuple(device.metadata.compilation_target_gatesets) == (
cirq.CZTargetGateset(),
cirq_google.SycamoreTargetGateset(),
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True),
)

base_duration = cirq.Duration(picos=1_000)
assert device.metadata.gate_durations == {
cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]): base_duration * 0,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]): base_duration * 1,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]): base_duration * 2,
cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]): base_duration * 3,
cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate): base_duration * 4,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
): base_duration
* 5,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
): base_duration
* 6,
cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse): base_duration * 7,
cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate): base_duration * 8,
cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9,
}


def test_grid_device_validate_operations_positive():
Expand All @@ -166,23 +227,22 @@ def test_grid_device_validate_operations_positive():
for i in range(GRID_HEIGHT):
device.validate_operation(cirq.CZ(grid_qubits[2 * i], grid_qubits[2 * i + 1]))

# TODO(#5050) verify validate_operations gateset support


def test_grid_device_validate_operations_negative():
grid_qubits, spec = _create_device_spec_with_horizontal_couplings()
device = cirq_google.GridDevice.from_proto(spec)

q = cirq.GridQubit(10, 10)
bad_qubit = cirq.GridQubit(10, 10)
with pytest.raises(ValueError, match='Qubit not on device'):
device.validate_operation(cirq.X(q))
device.validate_operation(cirq.X(bad_qubit))

# vertical qubit pair
q00, q10 = grid_qubits[0], grid_qubits[2] # (0, 0), (1, 0)
with pytest.raises(ValueError, match='Qubit pair is not valid'):
device.validate_operation(cirq.CZ(q00, q10))

# TODO(#5050) verify validate_operations gateset errors
with pytest.raises(ValueError, match='gate which is not supported'):
device.validate_operation(cirq.H(grid_qubits[0]))


@pytest.mark.parametrize(
Expand Down

0 comments on commit 85309b8

Please sign in to comment.