diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index bfb511ce012..fdffd31233f 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -16,9 +16,14 @@ import re -from typing import Any, Set, Tuple, cast +from typing import Any, Dict, List, Sequence, Set, Tuple, Type, Union, cast +import warnings + import cirq +from cirq_google import ops +from cirq_google import transformers from cirq_google.api import v2 +from cirq_google.experimental import ops as experimental_ops def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: @@ -67,6 +72,94 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> raise ValueError("Invalid DeviceSpecification: target_ordering cannot be ASYMMETRIC.") +def _build_gateset_and_gate_durations( + proto: v2.device_pb2.DeviceSpecification, +) -> Tuple[cirq.Gateset, Dict[cirq.GateFamily, cirq.Duration]]: + """Extracts gate set and gate duration information from the given DeviceSpecification proto.""" + + gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] + gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {} + + # TODO(#5050) Describe how to add/remove gates. + + for gate_spec in proto.valid_gates: + gate_name = gate_spec.WhichOneof('gate') + cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] + + if gate_name == 'syc': + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[ops.SYC])] + elif gate_name == 'sqrt_iswap': + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])] + elif gate_name == 'sqrt_iswap_inv': + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])] + elif gate_name == 'cz': + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.CZ])] + elif gate_name == 'phased_xz': + cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate] + elif gate_name == 'virtual_zpow': + cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])] + elif gate_name == 'physical_zpow': + cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])] + elif gate_name == 'coupler_pulse': + cirq_gates = [experimental_ops.CouplerPulse] + elif gate_name == 'meas': + cirq_gates = [cirq.MeasurementGate] + elif gate_name == 'wait': + cirq_gates = [cirq.WaitGate] + else: + # coverage: ignore + warnings.warn( + f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized" + " by Cirq and will be ignored. This may be due to an out-of-date Cirq version.", + UserWarning, + ) + continue + + gates_list.extend(cirq_gates) + + # TODO(#5050) Allow different gate representations of the same gate to be looked up in + # gate_durations. + for g in cirq_gates: + if not isinstance(g, cirq.GateFamily): + g = cirq.GateFamily(g) + gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos) + + # TODO(#4833) Add identity gate support + # TODO(#5050) Add GlobalPhaseGate support + + return cirq.Gateset(*gates_list), gate_durations + + +def _build_compilation_target_gatesets( + gateset: cirq.Gateset, +) -> Sequence[cirq.CompilationTargetGateset]: + """Detects compilation target gatesets based on what gates are inside the gateset. + + If a device contains gates which yield multiple compilation target gatesets, the user can only + choose one target gateset to compile to. For example, a device may contain both SYC and + SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled to + either SYC or SQRT_ISWAP for its two-qubit gates, not both. + + TODO(#5050) when cirq-google CompilationTargetGateset subclasses are implemented, mention that + gates which are part of the gateset but not the compilation target gateset are untouched when + compiled. + """ + + # TODO(#5050) Subclass core CompilationTargetGatesets in cirq-google. + + target_gatesets: List[cirq.CompilationTargetGateset] = [] + if cirq.CZ in gateset: + target_gatesets.append(cirq.CZTargetGateset()) + if ops.SYC in gateset: + target_gatesets.append(transformers.SycamoreTargetGateset()) + if cirq.SQRT_ISWAP in gateset: + target_gatesets.append( + cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=cirq.SQRT_ISWAP_INV in gateset) + ) + + return tuple(target_gatesets) + + @cirq.value_equality class GridDevice(cirq.Device): """Device object representing Google devices with a grid qubit layout. @@ -112,7 +205,24 @@ class GridDevice(cirq.Device): * Get a collection of approximate gate durations for every gate supported by the device. >>> device.metadata.gate_durations - TODO(#5050) Add compilation_target_gatesets example. + * Get a collection of valid CompilationTargetGatesets for the device, which can be used to + transform a circuit to one which only contains gates from a native target gateset + supported by the device. + >>> device.metadata.compilation_target_gatesets + + * Assuming valid CompilationTargetGatesets exist for the device, select the first one and + use it to transform a circuit to one which only contains gates from a native target + gateset supported by the device. + >>> cirq.optimize_for_target_gateset( + circuit, + gateset=device.metadata.compilation_target_gatesets[0] + ) + + A note about CompilationTargetGatesets: + + A circuit which contains `cirq.WaitGate`s will be dropped if it is transformed using + CompilationTargetGatesets generated by GridDevice. To better control circuit timing, insert + WaitGates after the circuit has been transformed. Notes for cirq_google internal implementation: @@ -162,12 +272,15 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC ] - # TODO(#5050) implement gate durations + gateset, gate_durations = _build_gateset_and_gate_durations(proto) + try: metadata = cirq.GridDeviceMetadata( qubit_pairs=qubit_pairs, - gateset=cirq.Gateset(), # TODO(#5050) implement + gateset=gateset, + gate_durations=gate_durations if len(gate_durations) > 0 else None, all_qubits=all_qubits, + compilation_target_gatesets=_build_compilation_target_gatesets(gateset), ) except ValueError as ve: # coverage: ignore # Spec errors should have been caught in validation above. @@ -194,19 +307,19 @@ def validate_operation(self, operation: cirq.Operation) -> None: Raises: ValueError: The operation isn't valid for this device. """ - # TODO(#5050) uncomment once gateset logic is implemented - # if operation not in self._metadata.gateset: - # raise ValueError(f'Operation {operation} is not a supported gate') + + if operation not in self._metadata.gateset: + raise ValueError(f'Operation {operation} contains a gate which is not supported.') for q in operation.qubits: if q not in self._metadata.qubit_set: - raise ValueError(f'Qubit not on device: {q!r}') + raise ValueError(f'Qubit not on device: {q!r}.') if ( len(operation.qubits) == 2 and frozenset(operation.qubits) not in self._metadata.qubit_pairs ): - raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}') + raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.') def __str__(self) -> str: diagram = cirq.TextDiagramDrawer() diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index 0176ba8847c..5c7c89a97fc 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -47,9 +47,23 @@ def _create_device_spec_with_horizontal_couplings(): # to verify GridDevice properly handles pair symmetry. new_target = grid_targets.targets.add() new_target.ids.extend([v2.qubit_to_proto_id(cirq.GridQubit(row, 1 - j)) for j in range(2)]) - gate = spec.valid_gates.add() - gate.syc.SetInParent() - gate.gate_duration_picos = 12000 + + gate_names = [ + 'syc', + 'sqrt_iswap', + 'sqrt_iswap_inv', + 'cz', + 'phased_xz', + 'virtual_zpow', + 'physical_zpow', + 'coupler_pulse', + 'meas', + 'wait', + ] + for i, g in enumerate(gate_names): + gate = spec.valid_gates.add() + getattr(gate, g).SetInParent() + gate.gate_duration_picos = i * 1000 return grid_qubits, spec @@ -153,6 +167,53 @@ def test_grid_device_from_proto(): frozenset((cirq.GridQubit(row, 0), cirq.GridQubit(row, 1))) in device.metadata.qubit_pairs for row in range(GRID_HEIGHT) ) + assert device.metadata.gateset == cirq.Gateset( + cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]), + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]), + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]), + cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]), + cirq.ops.phased_x_z_gate.PhasedXZGate, + cirq.ops.common_gates.XPowGate, + cirq.ops.common_gates.YPowGate, + cirq.ops.phased_x_gate.PhasedXPowGate, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ), + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ), + cirq_google.experimental.ops.coupler_pulse.CouplerPulse, + cirq.ops.measurement_gate.MeasurementGate, + cirq.ops.wait_gate.WaitGate, + ) + assert tuple(device.metadata.compilation_target_gatesets) == ( + cirq.CZTargetGateset(), + cirq_google.SycamoreTargetGateset(), + cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True), + ) + + base_duration = cirq.Duration(picos=1_000) + assert device.metadata.gate_durations == { + cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]): base_duration * 0, + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]): base_duration * 1, + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]): base_duration * 2, + cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]): base_duration * 3, + cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4, + cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4, + cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4, + cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate): base_duration * 4, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ): base_duration + * 5, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ): base_duration + * 6, + cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse): base_duration * 7, + cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate): base_duration * 8, + cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, + } def test_grid_device_validate_operations_positive(): @@ -166,23 +227,22 @@ def test_grid_device_validate_operations_positive(): for i in range(GRID_HEIGHT): device.validate_operation(cirq.CZ(grid_qubits[2 * i], grid_qubits[2 * i + 1])) - # TODO(#5050) verify validate_operations gateset support - def test_grid_device_validate_operations_negative(): grid_qubits, spec = _create_device_spec_with_horizontal_couplings() device = cirq_google.GridDevice.from_proto(spec) - q = cirq.GridQubit(10, 10) + bad_qubit = cirq.GridQubit(10, 10) with pytest.raises(ValueError, match='Qubit not on device'): - device.validate_operation(cirq.X(q)) + device.validate_operation(cirq.X(bad_qubit)) # vertical qubit pair q00, q10 = grid_qubits[0], grid_qubits[2] # (0, 0), (1, 0) with pytest.raises(ValueError, match='Qubit pair is not valid'): device.validate_operation(cirq.CZ(q00, q10)) - # TODO(#5050) verify validate_operations gateset errors + with pytest.raises(ValueError, match='gate which is not supported'): + device.validate_operation(cirq.H(grid_qubits[0])) @pytest.mark.parametrize(