Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridDevice gateset, gate_duration, and compilation_target_gateset support #5315

Merged
117 changes: 110 additions & 7 deletions cirq-google/cirq_google/devices/grid_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@

import re

from typing import Any, Set, Tuple, cast
from typing import Any, Dict, List, Sequence, Set, Tuple, Type, Union, cast
import warnings

import cirq
from cirq_google import ops
from cirq_google import transformers
from cirq_google.api import v2
from cirq_google.experimental import ops as experimental_ops
from cirq_google.ops.fsim_gate_family import POSSIBLE_FSIM_GATES


def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None:
Expand Down Expand Up @@ -77,6 +83,91 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) ->
)


def _build_gateset_and_gate_durations(
proto: v2.device_pb2.DeviceSpecification,
) -> Tuple[cirq.Gateset, Dict[cirq.GateFamily, cirq.Duration]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a private method, but I would still add a docstring since it's a little long.

gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = []
fsim_gates: List[Union[Type[POSSIBLE_FSIM_GATES], POSSIBLE_FSIM_GATES]] = []
gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {}

# TODO(#5050) Describe how to add/remove gates.

for gate_spec in proto.valid_gates:
gate_name = gate_spec.WhichOneof('gate')
cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = []

if gate_name == 'syc':
cirq_gates = [ops.SYC]
fsim_gates.append(ops.SYC)
elif gate_name == 'sqrt_iswap':
cirq_gates = [cirq.SQRT_ISWAP]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not push cirq.SQRT_ISWAP to cirq_gates since gates_list.extend(cirq_gates) below will add it to gates_list; and we will again add the equivalent gates_list.append(ops.FSimGateFamily(gates_to_accept=fsim_gates)).

The gate durations should also be specified for the corresponding fsim gate family; since we are accepting all equivalents operations across types; instead of specifying it only for cirq.GateFamily(cirq.SQRT_ISWAP) which is what is being done right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Different FSimGates have different durations today:

'fsim_pi_4': 32_000,
'inv_fsim_pi_4': 32_000,
'syc': 12_000,

And in general, each gate type in DeviceSpecification can have a separate duration. IIUC including the gate, while redundant, doesn't change gate validation behavior.

Since gate duration is purely informational, IMO it's OK to not have all the equivalent forms of a gate as part of the key. The alternative solution of having duration under the FSimGateFamily and taking the min/max/some other aggregate of the durations of all the different gates necessarily loses some information.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC including the gate, while redundant, doesn't change gate validation behavior.

The purpose of a gateset is description and validation. Having redundant gates makes description more confusing and verbose, which should be avoided.

Different FSimGates have different durations today:

We can insert a separate fsim gate family instance for each fsim gate type in DeviceSpecification; so we can associate gate duration with the corresponding fsim gate family. i.e.

# gate durations dict should contain this.
{cirq.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]):  32_000,
cirq.FSimGateFamily(gates_to_accept=[cg.SYC]):  12_000}
# instead of .
{cirq.GateFamily(cirq.SQRT_ISWAP):  32_000,
cirq.GateFamily(cg.SYC):  12_000}

In general, we should be consistent with the gate families that we are using for description and validation, so that inconsistencies like the following cannot occur:

$> sqrt_iswap_gate = cirq.FSimGate(-np.pi/4, 0)
# Returns True because of validation across types by FSimGateFamily.
$> assert sqrt_iswap_gate in sqrt_iswap_metadata.gateset 
# Should return True but will return False right now because `sqrt_iswap_gate in cirq.GateFamily(cirq.SQRT_ISWAP)` is False
$> assert any(sqrt_iswap_gate in gf for gf in sqrt_iswap_metadata.gate_durations) 

Copy link
Collaborator Author

@verult verult May 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, this makes a lot of sense, and I'll push for a change in the GridDeviceMetadata.gate_duration. Didn't end up using a function, but the proposal captures the same spirit. Will change to the approach you suggested, and leave a TODO to update it once #5427 moves forward.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually your approach currently fails GridDeviceMetadata validation because all gate duration keys are expected to be in the gateset. Let's discuss in #5427 since it potentially involves a change in GridDeviceMetadata. Are you onboard with keeping the gate duration the way it is and leaving a TODO for followup?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Synced offline. Will change the gatesets to have a separate FSimGateFamily for each 2q gate instead of a single FSimGateFamily.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. I considered changing the test to a more black-box approach of checking whether certain gates belong in the gateset rather than asserting via gateset equality, but decided against it because we do want to test the exact gateset elements to verify its string representation.

fsim_gates.append(cirq.SQRT_ISWAP)
elif gate_name == 'sqrt_iswap_inv':
cirq_gates = [cirq.SQRT_ISWAP_INV]
fsim_gates.append(cirq.SQRT_ISWAP_INV)
elif gate_name == 'cz':
cirq_gates = [cirq.CZ]
fsim_gates.append(cirq.CZ)
elif gate_name == 'phased_xz':
cirq_gates = [
cirq.PhasedXZGate,
cirq.XPowGate,
cirq.YPowGate,
cirq.ZPowGate,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a cirq.ZPowGate gate here given that we have special identifiers for virtual_zpow and physical_zpow ? Note that if cirq.GateFamily(cirq.ZPowGate) is present in the gateset, it will accept all (i.e. both tagged and untagged) instances of ZPowGate; in which case the physical_zpow and virtual_zpow gate families will be irrelevant.

For example: What happens if the proto specification includes phased_xz and doesn't contain virtual_zpow ? Should a cirq.Z(q) be accepted or not? What if it contains phased_xz and physical_zpow but doesn't contain virtual_zpow ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the acceptance of ZPowGates should definitely be controlled entirely by virtual_zpow and physical_zpow. Thanks for the catch!

For the case where the spec includes phased_xz but not virtual_zpow, because PhysicalZTag() on a PhasedXZGate is ignored, a Z gate specified as a PhasedXZGate should be accepted and will be applied by the device in some way.

cirq.PhasedXPowGate,
]
elif gate_name == 'virtual_zpow':
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])]
elif gate_name == 'physical_zpow':
cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])]
elif gate_name == 'coupler_pulse':
cirq_gates = [experimental_ops.CouplerPulse]
elif gate_name == 'meas':
cirq_gates = [cirq.MeasurementGate]
elif gate_name == 'wait':
cirq_gates = [cirq.WaitGate]
else:
# coverage: ignore
warnings.warn(
f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized"
" by Cirq and will be ignored. This may be due to an out-of-date Cirq version.",
UserWarning,
)
continue

gates_list.extend(cirq_gates)
for g in cirq_gates:
if not isinstance(g, cirq.GateFamily):
g = cirq.GateFamily(g)
gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos)

if fsim_gates:
gates_list.append(ops.FSimGateFamily(gates_to_accept=fsim_gates))

# TODO(#4833) Add identity gate support
# TODO(#5050) Add GlobalPhaseGate support

return cirq.Gateset(*gates_list), gate_durations


def _build_compilation_target_gatesets(
gateset: cirq.Gateset,
) -> Sequence[cirq.CompilationTargetGateset]:
"""Detects compilation target gatesets based on what gates are inside the gateset."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, if we support multiple gatesets, do we have a way to decompose to a combination of CZ and sqrt-iswap for instance? Or do we have to pick one?

I think picking one is probably fine, since most grids will likely only support one type of gate, but we should make this clear in the documentation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep we would have to pick one. If there's a use case to compile to both CZ and sqrt-iswap in the future we could revisit; doing so now for every target would probably create a combinatorial explosion of target gatesets.

On the other hand, if a device supports both CZ and sqrt-iswap, and someone wants to compile a circuit containing sqrt-iswap + some other arbitrary 2q gates using a CZ target gateset, the intended behavior is to leave sqrt-iswap gates untouched and compile other 2q gates to CZ + 1q gates. This isn't supported yet by existing target gatesets, but I'm planning to change that.

Will clarify this behavior in documentation.


target_gatesets: List[cirq.CompilationTargetGateset] = []
if cirq.CZ in gateset:
target_gatesets.append(cirq.CZTargetGateset())
if ops.SYC in gateset:
target_gatesets.append(transformers.SycamoreTargetGateset())
if cirq.SQRT_ISWAP in gateset:
target_gatesets.append(
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=cirq.SQRT_ISWAP_INV in gateset)
)
verult marked this conversation as resolved.
Show resolved Hide resolved

return tuple(target_gatesets)


@cirq.value_equality
class GridDevice(cirq.Device):
"""Device object representing Google devices with a grid qubit layout.
Expand Down Expand Up @@ -122,7 +213,16 @@ class GridDevice(cirq.Device):
* Get a collection of approximate gate durations for every gate supported by the device.
>>> device.metadata.gate_durations

TODO(#5050) Add compilation_target_gatesets example.
* Get a collection of valid CompilationTargetGatesets for the device, which can be used to
transform a circuit which is invalid for the device to a valid one.
verult marked this conversation as resolved.
Show resolved Hide resolved
>>> device.metadata.compilation_target_gatesets

* Assuming valid CompilationTargetGatesets exist for the device, select the first one and
use it to transform a circuit to an equivalent form which is valid for the device.
>>> cirq.optimize_for_target_gateset(
circuit,
gateset=device.metadata.compilation_target_gatesets[0]
)

Notes for cirq_google internal implementation:

Expand Down Expand Up @@ -187,12 +287,15 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice':
if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC
]

# TODO(#5050) implement gate durations
gateset, gate_durations = _build_gateset_and_gate_durations(proto)

try:
metadata = cirq.GridDeviceMetadata(
qubit_pairs=qubit_pairs,
gateset=cirq.Gateset(), # TODO(#5050) implement
gateset=gateset,
gate_durations=gate_durations,
all_qubits=all_qubits,
compilation_target_gatesets=_build_compilation_target_gatesets(gateset),
)
except ValueError as ve: # coverage: ignore
# Spec errors should have been caught in validation above.
Expand All @@ -219,9 +322,9 @@ def validate_operation(self, operation: cirq.Operation) -> None:
Raises:
ValueError: The operation isn't valid for this device.
"""
# TODO(#5050) uncomment once gateset logic is implemented
# if operation not in self._metadata.gateset:
# raise ValueError(f'Operation {operation} is not a supported gate')

if operation not in self._metadata.gateset:
raise ValueError(f'Operation {operation} is not a supported gate')
verult marked this conversation as resolved.
Show resolved Hide resolved

for q in operation.qubits:
if q not in self._metadata.qubit_set:
Expand Down
82 changes: 73 additions & 9 deletions cirq-google/cirq_google/devices/grid_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,23 @@ def _create_device_spec_with_horizontal_couplings():
# to verify GridDevice properly handles pair symmetry.
new_target = grid_targets.targets.add()
new_target.ids.extend([v2.qubit_to_proto_id(cirq.GridQubit(row, 1 - j)) for j in range(2)])
gate = spec.valid_gates.add()
gate.syc.SetInParent()
gate.gate_duration_picos = 12000
gate.valid_targets.extend(['2_qubit_targets'])

gate_names = [
'syc',
'sqrt_iswap',
'sqrt_iswap_inv',
'cz',
'phased_xz',
'virtual_zpow',
'physical_zpow',
'coupler_pulse',
'meas',
'wait',
]
for i, g in enumerate(gate_names):
gate = spec.valid_gates.add()
getattr(gate, g).SetInParent()
gate.gate_duration_picos = i * 1000

return grid_qubits, spec

Expand Down Expand Up @@ -159,6 +172,58 @@ def test_grid_device_from_proto():
frozenset((cirq.GridQubit(row, 0), cirq.GridQubit(row, 1))) in device.metadata.qubit_pairs
for row in range(GRID_HEIGHT)
)
assert device.metadata.gateset == cirq.Gateset(
cirq_google.SYC,
cirq.SQRT_ISWAP,
cirq.SQRT_ISWAP_INV,
cirq.CZ,
cirq.ops.phased_x_z_gate.PhasedXZGate,
cirq.ops.common_gates.XPowGate,
cirq.ops.common_gates.YPowGate,
cirq.ops.common_gates.ZPowGate,
cirq.ops.phased_x_gate.PhasedXPowGate,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
),
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
),
cirq_google.experimental.ops.coupler_pulse.CouplerPulse,
cirq.ops.measurement_gate.MeasurementGate,
cirq.ops.wait_gate.WaitGate,
cirq_google.FSimGateFamily(
gates_to_accept=[cirq_google.SYC, cirq.SQRT_ISWAP, cirq.SQRT_ISWAP_INV, cirq.CZ]
),
)
assert tuple(device.metadata.compilation_target_gatesets) == (
cirq.CZTargetGateset(),
cirq_google.SycamoreTargetGateset(),
cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True),
)

base_duration = cirq.Duration(picos=1_000)
assert device.metadata.gate_durations == {
cirq.GateFamily(cirq_google.SYC): base_duration * 0,
cirq.GateFamily(cirq.SQRT_ISWAP): base_duration * 1,
cirq.GateFamily(cirq.SQRT_ISWAP_INV): base_duration * 2,
cirq.GateFamily(cirq.CZ): base_duration * 3,
cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.common_gates.ZPowGate): base_duration * 4,
cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate): base_duration * 4,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()]
): base_duration
* 5,
cirq.GateFamily(
cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()]
): base_duration
* 6,
cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse): base_duration * 7,
cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate): base_duration * 8,
cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9,
}


def test_grid_device_validate_operations_positive():
Expand All @@ -172,23 +237,22 @@ def test_grid_device_validate_operations_positive():
for i in range(GRID_HEIGHT):
device.validate_operation(cirq.CZ(grid_qubits[2 * i], grid_qubits[2 * i + 1]))

# TODO(#5050) verify validate_operations gateset support


def test_grid_device_validate_operations_negative():
grid_qubits, spec = _create_device_spec_with_horizontal_couplings()
device = cirq_google.GridDevice.from_proto(spec)

q = cirq.GridQubit(10, 10)
bad_qubit = cirq.GridQubit(10, 10)
with pytest.raises(ValueError, match='Qubit not on device'):
device.validate_operation(cirq.X(q))
device.validate_operation(cirq.X(bad_qubit))

# vertical qubit pair
q00, q10 = grid_qubits[0], grid_qubits[2] # (0, 0), (1, 0)
with pytest.raises(ValueError, match='Qubit pair is not valid'):
device.validate_operation(cirq.CZ(q00, q10))

# TODO(#5050) verify validate_operations gateset errors
with pytest.raises(ValueError, match='not a supported gate'):
device.validate_operation(cirq.H(grid_qubits[0]))


@pytest.mark.parametrize(
Expand Down