Skip to content

Commit

Permalink
Addressed Doug's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
verult committed Jun 3, 2022
1 parent 74f14fc commit 6a2402e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 27 deletions.
40 changes: 26 additions & 14 deletions cirq-google/cirq_google/devices/grid_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@
from cirq_google.experimental import ops as experimental_ops


SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC)
SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP)
SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV)
CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ)
PHASED_XZ_GATE_FAMILY = cirq.GateFamily(cirq.PhasedXZGate)
VIRTUAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])
PHYSICAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])
COUPLER_PULSE_GATE_FAMILY = cirq.GateFamily(experimental_ops.CouplerPulse)
MEASUREMENT_GATE_FAMILY = cirq.GateFamily(cirq.MeasurementGate)
WAIT_GATE_FAMILY = cirq.GateFamily(cirq.WaitGate)


def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None:
"""Raises a ValueError if the `DeviceSpecification` proto is invalid."""

Expand Down Expand Up @@ -459,31 +471,32 @@ def _value_equality_values_(self):
def _set_gate_in_gate_spec(
gate_spec: v2.device_pb2.GateSpecification, gate_family: cirq.GateFamily
) -> None:
if gate_family == cirq.GateFamily(ops.SYC):
if gate_family == SYC_GATE_FAMILY:
gate_spec.syc.SetInParent()
elif gate_family == cirq.GateFamily(cirq.SQRT_ISWAP):
elif gate_family == SQRT_ISWAP_GATE_FAMILY:
gate_spec.sqrt_iswap.SetInParent()
elif gate_family == cirq.GateFamily(cirq.SQRT_ISWAP_INV):
elif gate_family == SQRT_ISWAP_INV_GATE_FAMILY:
gate_spec.sqrt_iswap_inv.SetInParent()
elif gate_family == cirq.GateFamily(cirq.CZ):
elif gate_family == CZ_GATE_FAMILY:
gate_spec.cz.SetInParent()
elif gate_family == cirq.GateFamily(cirq.PhasedXZGate):
elif gate_family == PHASED_XZ_GATE_FAMILY:
gate_spec.phased_xz.SetInParent()
elif gate_family == cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()]):
elif gate_family == VIRTUAL_ZPOW_GATE_FAMILY:
gate_spec.virtual_zpow.SetInParent()
elif gate_family == cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()]):
elif gate_family == PHYSICAL_ZPOW_GATE_FAMILY:
gate_spec.physical_zpow.SetInParent()
elif gate_family == cirq.GateFamily(experimental_ops.CouplerPulse):
elif gate_family == COUPLER_PULSE_GATE_FAMILY:
gate_spec.coupler_pulse.SetInParent()
elif gate_family == cirq.GateFamily(cirq.MeasurementGate):
elif gate_family == MEASUREMENT_GATE_FAMILY:
gate_spec.meas.SetInParent()
elif gate_family == cirq.GateFamily(cirq.WaitGate):
elif gate_family == WAIT_GATE_FAMILY:
gate_spec.wait.SetInParent()
else:
raise ValueError(f'Unrecognized gate {gate_family}.')


def to_proto(
def create_device_specification_proto(
*,
qubits: Collection[cirq.GridQubit],
pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]],
gateset: cirq.Gateset,
Expand Down Expand Up @@ -519,9 +532,8 @@ def to_proto(
out = v2.device_pb2.DeviceSpecification()

# If fields are already filled (i.e. as part of the old DeviceSpecification format), leave them
# as is.
# Fields populated in the new format do not conflict with how they were populated in the old
# format.
# as is. Fields populated in the new format do not conflict with how they were populated in the
# old format.
# TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are
# removed.

Expand Down
28 changes: 15 additions & 13 deletions cirq-google/cirq_google/devices/grid_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,11 @@ def test_to_proto():
cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9,
}

spec = grid_device.to_proto(
device_info.grid_qubits,
device_info.qubit_pairs,
cirq.Gateset(*gate_durations.keys()),
gate_durations,
spec = grid_device.create_device_specification_proto(
qubits=device_info.grid_qubits,
pairs=device_info.qubit_pairs,
gateset=cirq.Gateset(*gate_durations.keys()),
gate_durations=gate_durations,
)

assert text_format.MessageToString(spec) == text_format.MessageToString(expected_spec)
Expand Down Expand Up @@ -412,7 +412,9 @@ def test_to_proto():
)
def test_to_proto_invalid_input(error_match, qubits, qubit_pairs, gateset, gate_durations):
with pytest.raises(ValueError, match=error_match):
grid_device.to_proto(qubits, qubit_pairs, gateset, gate_durations)
grid_device.create_device_specification_proto(
qubits=qubits, pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations
)


def test_to_proto_backward_compatibility():
Expand Down Expand Up @@ -449,12 +451,12 @@ def test_to_proto_backward_compatibility():
)

# Serialize the new way
grid_device.to_proto(
device_info.grid_qubits,
device_info.qubit_pairs,
cirq.Gateset(*gate_durations.keys()),
gate_durations,
spec,
grid_device.create_device_specification_proto(
qubits=device_info.grid_qubits,
pairs=device_info.qubit_pairs,
gateset=cirq.Gateset(*gate_durations.keys()),
gate_durations=gate_durations,
out=spec,
)

# Deserialize both ways
Expand Down Expand Up @@ -487,7 +489,7 @@ def test_to_proto_backward_compatibility():


def test_to_proto_empty():
spec = grid_device.to_proto(
spec = grid_device.create_device_specification_proto(
# Qubits are always expected to be set
qubits=[cirq.GridQubit(0, i) for i in range(5)],
pairs=[],
Expand Down
12 changes: 12 additions & 0 deletions cirq-google/cirq_google/devices/known_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,24 @@ def create_device_proto_for_qubits(
def populate_qubits_in_device_proto(
qubits: Collection[cirq.Qid], out: device_pb2.DeviceSpecification
) -> None:
"""Populates `DeviceSpecification.valid_qubits` with the device's qubits.
Args:
qubits: The collection of the device's qubits.
out: The `DeviceSpecification` to be populated.
"""
out.valid_qubits.extend(v2.qubit_to_proto_id(q) for q in qubits)


def populate_qubit_pairs_in_device_proto(
pairs: Collection[Tuple[cirq.Qid, cirq.Qid]], out: device_pb2.DeviceSpecification
) -> None:
"""Populates `DeviceSpecification.valid_targets` with the device's qubit pairs.
Args:
pairs: The collection of the device's bi-directional qubit pairs.
out: The `DeviceSpecification` to be populated.
"""
grid_targets = out.valid_targets.add()
grid_targets.name = _2_QUBIT_TARGET_SET
grid_targets.target_ordering = device_pb2.TargetSet.SYMMETRIC
Expand Down

0 comments on commit 6a2402e

Please sign in to comment.