diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 07bb8d401d10..27e06df3979d 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -26,6 +26,18 @@ from cirq_google.experimental import ops as experimental_ops +SYC_GATE_FAMILY = cirq.GateFamily(ops.SYC) +SQRT_ISWAP_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP) +SQRT_ISWAP_INV_GATE_FAMILY = cirq.GateFamily(cirq.SQRT_ISWAP_INV) +CZ_GATE_FAMILY = cirq.GateFamily(cirq.CZ) +PHASED_XZ_GATE_FAMILY = cirq.GateFamily(cirq.PhasedXZGate) +VIRTUAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()]) +PHYSICAL_ZPOW_GATE_FAMILY = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()]) +COUPLER_PULSE_GATE_FAMILY = cirq.GateFamily(experimental_ops.CouplerPulse) +MEASUREMENT_GATE_FAMILY = cirq.GateFamily(cirq.MeasurementGate) +WAIT_GATE_FAMILY = cirq.GateFamily(cirq.WaitGate) + + def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: """Raises a ValueError if the `DeviceSpecification` proto is invalid.""" @@ -459,31 +471,32 @@ def _value_equality_values_(self): def _set_gate_in_gate_spec( gate_spec: v2.device_pb2.GateSpecification, gate_family: cirq.GateFamily ) -> None: - if gate_family == cirq.GateFamily(ops.SYC): + if gate_family == SYC_GATE_FAMILY: gate_spec.syc.SetInParent() - elif gate_family == cirq.GateFamily(cirq.SQRT_ISWAP): + elif gate_family == SQRT_ISWAP_GATE_FAMILY: gate_spec.sqrt_iswap.SetInParent() - elif gate_family == cirq.GateFamily(cirq.SQRT_ISWAP_INV): + elif gate_family == SQRT_ISWAP_INV_GATE_FAMILY: gate_spec.sqrt_iswap_inv.SetInParent() - elif gate_family == cirq.GateFamily(cirq.CZ): + elif gate_family == CZ_GATE_FAMILY: gate_spec.cz.SetInParent() - elif gate_family == cirq.GateFamily(cirq.PhasedXZGate): + elif gate_family == PHASED_XZ_GATE_FAMILY: gate_spec.phased_xz.SetInParent() - elif gate_family == cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()]): + elif gate_family == VIRTUAL_ZPOW_GATE_FAMILY: gate_spec.virtual_zpow.SetInParent() - elif gate_family == cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()]): + elif gate_family == PHYSICAL_ZPOW_GATE_FAMILY: gate_spec.physical_zpow.SetInParent() - elif gate_family == cirq.GateFamily(experimental_ops.CouplerPulse): + elif gate_family == COUPLER_PULSE_GATE_FAMILY: gate_spec.coupler_pulse.SetInParent() - elif gate_family == cirq.GateFamily(cirq.MeasurementGate): + elif gate_family == MEASUREMENT_GATE_FAMILY: gate_spec.meas.SetInParent() - elif gate_family == cirq.GateFamily(cirq.WaitGate): + elif gate_family == WAIT_GATE_FAMILY: gate_spec.wait.SetInParent() else: raise ValueError(f'Unrecognized gate {gate_family}.') -def to_proto( +def create_device_specification_proto( + *, qubits: Collection[cirq.GridQubit], pairs: Collection[Tuple[cirq.GridQubit, cirq.GridQubit]], gateset: cirq.Gateset, @@ -519,9 +532,8 @@ def to_proto( out = v2.device_pb2.DeviceSpecification() # If fields are already filled (i.e. as part of the old DeviceSpecification format), leave them - # as is. - # Fields populated in the new format do not conflict with how they were populated in the old - # format. + # as is. Fields populated in the new format do not conflict with how they were populated in the + # old format. # TODO(#5050) remove empty checks below once deprecated fields in DeviceSpecification are # removed. diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index 2b15fe40837f..bc0678447c26 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -372,11 +372,11 @@ def test_to_proto(): cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, } - spec = grid_device.to_proto( - device_info.grid_qubits, - device_info.qubit_pairs, - cirq.Gateset(*gate_durations.keys()), - gate_durations, + spec = grid_device.create_device_specification_proto( + qubits=device_info.grid_qubits, + pairs=device_info.qubit_pairs, + gateset=cirq.Gateset(*gate_durations.keys()), + gate_durations=gate_durations, ) assert text_format.MessageToString(spec) == text_format.MessageToString(expected_spec) @@ -412,7 +412,9 @@ def test_to_proto(): ) def test_to_proto_invalid_input(error_match, qubits, qubit_pairs, gateset, gate_durations): with pytest.raises(ValueError, match=error_match): - grid_device.to_proto(qubits, qubit_pairs, gateset, gate_durations) + grid_device.create_device_specification_proto( + qubits=qubits, pairs=qubit_pairs, gateset=gateset, gate_durations=gate_durations + ) def test_to_proto_backward_compatibility(): @@ -449,12 +451,12 @@ def test_to_proto_backward_compatibility(): ) # Serialize the new way - grid_device.to_proto( - device_info.grid_qubits, - device_info.qubit_pairs, - cirq.Gateset(*gate_durations.keys()), - gate_durations, - spec, + grid_device.create_device_specification_proto( + qubits=device_info.grid_qubits, + pairs=device_info.qubit_pairs, + gateset=cirq.Gateset(*gate_durations.keys()), + gate_durations=gate_durations, + out=spec, ) # Deserialize both ways @@ -487,7 +489,7 @@ def test_to_proto_backward_compatibility(): def test_to_proto_empty(): - spec = grid_device.to_proto( + spec = grid_device.create_device_specification_proto( # Qubits are always expected to be set qubits=[cirq.GridQubit(0, i) for i in range(5)], pairs=[], diff --git a/cirq-google/cirq_google/devices/known_devices.py b/cirq-google/cirq_google/devices/known_devices.py index 2f7e127e1cca..ebd40f992f74 100644 --- a/cirq-google/cirq_google/devices/known_devices.py +++ b/cirq-google/cirq_google/devices/known_devices.py @@ -184,12 +184,24 @@ def create_device_proto_for_qubits( def populate_qubits_in_device_proto( qubits: Collection[cirq.Qid], out: device_pb2.DeviceSpecification ) -> None: + """Populates `DeviceSpecification.valid_qubits` with the device's qubits. + + Args: + qubits: The collection of the device's qubits. + out: The `DeviceSpecification` to be populated. + """ out.valid_qubits.extend(v2.qubit_to_proto_id(q) for q in qubits) def populate_qubit_pairs_in_device_proto( pairs: Collection[Tuple[cirq.Qid, cirq.Qid]], out: device_pb2.DeviceSpecification ) -> None: + """Populates `DeviceSpecification.valid_targets` with the device's qubit pairs. + + Args: + pairs: The collection of the device's bi-directional qubit pairs. + out: The `DeviceSpecification` to be populated. + """ grid_targets = out.valid_targets.add() grid_targets.name = _2_QUBIT_TARGET_SET grid_targets.target_ordering = device_pb2.TargetSet.SYMMETRIC