From ff8aebcceb35ef5c6fe9d1988782f97fc72d1a77 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Thu, 28 Apr 2022 00:12:52 +0000 Subject: [PATCH 1/7] GridDevice gateset, gate_duration, and compilation_target_gateset support * Displays a warning about old Cirq version if DeviceSpecification contains gates not recognized by the client. * Added compilation target gateset example in GridDevice docstring --- .../cirq_google/devices/grid_device.py | 117 ++++++++++++++++-- .../cirq_google/devices/grid_device_test.py | 82 ++++++++++-- 2 files changed, 183 insertions(+), 16 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index c57f8e38013..257abc56b34 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -16,9 +16,15 @@ import re -from typing import Any, Set, Tuple, cast +from typing import Any, Dict, List, Sequence, Set, Tuple, Type, Union, cast +import warnings + import cirq +from cirq_google import ops +from cirq_google import transformers from cirq_google.api import v2 +from cirq_google.experimental import ops as experimental_ops +from cirq_google.ops.fsim_gate_family import POSSIBLE_FSIM_GATES def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: @@ -77,6 +83,91 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> ) +def _build_gateset_and_gate_durations( + proto: v2.device_pb2.DeviceSpecification, +) -> Tuple[cirq.Gateset, Dict[cirq.GateFamily, cirq.Duration]]: + gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] + fsim_gates: List[Union[Type[POSSIBLE_FSIM_GATES], POSSIBLE_FSIM_GATES]] = [] + gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {} + + # TODO(#5050) Describe how to add/remove gates. + + for gate_spec in proto.valid_gates: + gate_name = gate_spec.WhichOneof('gate') + cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] + + if gate_name == 'syc': + cirq_gates = [ops.SYC] + fsim_gates.append(ops.SYC) + elif gate_name == 'sqrt_iswap': + cirq_gates = [cirq.SQRT_ISWAP] + fsim_gates.append(cirq.SQRT_ISWAP) + elif gate_name == 'sqrt_iswap_inv': + cirq_gates = [cirq.SQRT_ISWAP_INV] + fsim_gates.append(cirq.SQRT_ISWAP_INV) + elif gate_name == 'cz': + cirq_gates = [cirq.CZ] + fsim_gates.append(cirq.CZ) + elif gate_name == 'phased_xz': + cirq_gates = [ + cirq.PhasedXZGate, + cirq.XPowGate, + cirq.YPowGate, + cirq.ZPowGate, + cirq.PhasedXPowGate, + ] + elif gate_name == 'virtual_zpow': + cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])] + elif gate_name == 'physical_zpow': + cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[ops.PhysicalZTag()])] + elif gate_name == 'coupler_pulse': + cirq_gates = [experimental_ops.CouplerPulse] + elif gate_name == 'meas': + cirq_gates = [cirq.MeasurementGate] + elif gate_name == 'wait': + cirq_gates = [cirq.WaitGate] + else: + # coverage: ignore + warnings.warn( + f"The DeviceSpecification contains the gate '{gate_name}' which is not recognized" + " by Cirq and will be ignored. This may be due to an out-of-date Cirq version.", + UserWarning, + ) + continue + + gates_list.extend(cirq_gates) + for g in cirq_gates: + if not isinstance(g, cirq.GateFamily): + g = cirq.GateFamily(g) + gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos) + + if fsim_gates: + gates_list.append(ops.FSimGateFamily(gates_to_accept=fsim_gates)) + + # TODO(#4833) Add identity gate support + # TODO(#5050) Add GlobalPhaseGate support + + return cirq.Gateset(*gates_list), gate_durations + + +def _build_compilation_target_gatesets( + gateset: cirq.Gateset, +) -> Sequence[cirq.CompilationTargetGateset]: + """Detects compilation target gatesets based on what gates are inside the gateset.""" + + target_gatesets: List[cirq.CompilationTargetGateset] = [] + if cirq.CZ in gateset: + target_gatesets.append(cirq.CZTargetGateset()) + if ops.SYC in gateset: + target_gatesets.append(transformers.SycamoreTargetGateset()) + if cirq.SQRT_ISWAP in gateset: + target_gatesets.append( + cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=cirq.SQRT_ISWAP_INV in gateset) + ) + + return tuple(target_gatesets) + + @cirq.value_equality class GridDevice(cirq.Device): """Device object representing Google devices with a grid qubit layout. @@ -122,7 +213,16 @@ class GridDevice(cirq.Device): * Get a collection of approximate gate durations for every gate supported by the device. >>> device.metadata.gate_durations - TODO(#5050) Add compilation_target_gatesets example. + * Get a collection of valid CompilationTargetGatesets for the device, which can be used to + transform a circuit which is invalid for the device to a valid one. + >>> device.metadata.compilation_target_gatesets + + * Assuming valid CompilationTargetGatesets exist for the device, select the first one and + use it to transform a circuit to an equivalent form which is valid for the device. + >>> cirq.optimize_for_target_gateset( + circuit, + gateset=device.metadata.compilation_target_gatesets[0] + ) Notes for cirq_google internal implementation: @@ -187,12 +287,15 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': if len(target.ids) == 2 and ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC ] - # TODO(#5050) implement gate durations + gateset, gate_durations = _build_gateset_and_gate_durations(proto) + try: metadata = cirq.GridDeviceMetadata( qubit_pairs=qubit_pairs, - gateset=cirq.Gateset(), # TODO(#5050) implement + gateset=gateset, + gate_durations=gate_durations, all_qubits=all_qubits, + compilation_target_gatesets=_build_compilation_target_gatesets(gateset), ) except ValueError as ve: # coverage: ignore # Spec errors should have been caught in validation above. @@ -219,9 +322,9 @@ def validate_operation(self, operation: cirq.Operation) -> None: Raises: ValueError: The operation isn't valid for this device. """ - # TODO(#5050) uncomment once gateset logic is implemented - # if operation not in self._metadata.gateset: - # raise ValueError(f'Operation {operation} is not a supported gate') + + if operation not in self._metadata.gateset: + raise ValueError(f'Operation {operation} is not a supported gate') for q in operation.qubits: if q not in self._metadata.qubit_set: diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index dfc0a69b465..30811d2b4bc 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -47,10 +47,23 @@ def _create_device_spec_with_horizontal_couplings(): # to verify GridDevice properly handles pair symmetry. new_target = grid_targets.targets.add() new_target.ids.extend([v2.qubit_to_proto_id(cirq.GridQubit(row, 1 - j)) for j in range(2)]) - gate = spec.valid_gates.add() - gate.syc.SetInParent() - gate.gate_duration_picos = 12000 - gate.valid_targets.extend(['2_qubit_targets']) + + gate_names = [ + 'syc', + 'sqrt_iswap', + 'sqrt_iswap_inv', + 'cz', + 'phased_xz', + 'virtual_zpow', + 'physical_zpow', + 'coupler_pulse', + 'meas', + 'wait', + ] + for i, g in enumerate(gate_names): + gate = spec.valid_gates.add() + getattr(gate, g).SetInParent() + gate.gate_duration_picos = i * 1000 return grid_qubits, spec @@ -159,6 +172,58 @@ def test_grid_device_from_proto(): frozenset((cirq.GridQubit(row, 0), cirq.GridQubit(row, 1))) in device.metadata.qubit_pairs for row in range(GRID_HEIGHT) ) + assert device.metadata.gateset == cirq.Gateset( + cirq_google.SYC, + cirq.SQRT_ISWAP, + cirq.SQRT_ISWAP_INV, + cirq.CZ, + cirq.ops.phased_x_z_gate.PhasedXZGate, + cirq.ops.common_gates.XPowGate, + cirq.ops.common_gates.YPowGate, + cirq.ops.common_gates.ZPowGate, + cirq.ops.phased_x_gate.PhasedXPowGate, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ), + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ), + cirq_google.experimental.ops.coupler_pulse.CouplerPulse, + cirq.ops.measurement_gate.MeasurementGate, + cirq.ops.wait_gate.WaitGate, + cirq_google.FSimGateFamily( + gates_to_accept=[cirq_google.SYC, cirq.SQRT_ISWAP, cirq.SQRT_ISWAP_INV, cirq.CZ] + ), + ) + assert tuple(device.metadata.compilation_target_gatesets) == ( + cirq.CZTargetGateset(), + cirq_google.SycamoreTargetGateset(), + cirq.SqrtIswapTargetGateset(use_sqrt_iswap_inv=True), + ) + + base_duration = cirq.Duration(picos=1_000) + assert device.metadata.gate_durations == { + cirq.GateFamily(cirq_google.SYC): base_duration * 0, + cirq.GateFamily(cirq.SQRT_ISWAP): base_duration * 1, + cirq.GateFamily(cirq.SQRT_ISWAP_INV): base_duration * 2, + cirq.GateFamily(cirq.CZ): base_duration * 3, + cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4, + cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4, + cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4, + cirq.GateFamily(cirq.ops.common_gates.ZPowGate): base_duration * 4, + cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate): base_duration * 4, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] + ): base_duration + * 5, + cirq.GateFamily( + cirq.ops.common_gates.ZPowGate, tags_to_accept=[cirq_google.PhysicalZTag()] + ): base_duration + * 6, + cirq.GateFamily(cirq_google.experimental.ops.coupler_pulse.CouplerPulse): base_duration * 7, + cirq.GateFamily(cirq.ops.measurement_gate.MeasurementGate): base_duration * 8, + cirq.GateFamily(cirq.ops.wait_gate.WaitGate): base_duration * 9, + } def test_grid_device_validate_operations_positive(): @@ -172,23 +237,22 @@ def test_grid_device_validate_operations_positive(): for i in range(GRID_HEIGHT): device.validate_operation(cirq.CZ(grid_qubits[2 * i], grid_qubits[2 * i + 1])) - # TODO(#5050) verify validate_operations gateset support - def test_grid_device_validate_operations_negative(): grid_qubits, spec = _create_device_spec_with_horizontal_couplings() device = cirq_google.GridDevice.from_proto(spec) - q = cirq.GridQubit(10, 10) + bad_qubit = cirq.GridQubit(10, 10) with pytest.raises(ValueError, match='Qubit not on device'): - device.validate_operation(cirq.X(q)) + device.validate_operation(cirq.X(bad_qubit)) # vertical qubit pair q00, q10 = grid_qubits[0], grid_qubits[2] # (0, 0), (1, 0) with pytest.raises(ValueError, match='Qubit pair is not valid'): device.validate_operation(cirq.CZ(q00, q10)) - # TODO(#5050) verify validate_operations gateset errors + with pytest.raises(ValueError, match='not a supported gate'): + device.validate_operation(cirq.H(grid_qubits[0])) @pytest.mark.parametrize( From 486376e845f685887a673ae16bd864b1c830fd49 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 18 May 2022 01:15:31 +0000 Subject: [PATCH 2/7] Addresses Tanuj's comments --- .../cirq_google/devices/grid_device.py | 20 ++++++++----------- .../cirq_google/devices/grid_device_test.py | 4 +--- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 257abc56b34..80bd0ef5464 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -109,13 +109,7 @@ def _build_gateset_and_gate_durations( cirq_gates = [cirq.CZ] fsim_gates.append(cirq.CZ) elif gate_name == 'phased_xz': - cirq_gates = [ - cirq.PhasedXZGate, - cirq.XPowGate, - cirq.YPowGate, - cirq.ZPowGate, - cirq.PhasedXPowGate, - ] + cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate] elif gate_name == 'virtual_zpow': cirq_gates = [cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=[ops.PhysicalZTag()])] elif gate_name == 'physical_zpow': @@ -214,11 +208,13 @@ class GridDevice(cirq.Device): >>> device.metadata.gate_durations * Get a collection of valid CompilationTargetGatesets for the device, which can be used to - transform a circuit which is invalid for the device to a valid one. + transform a circuit to one which only contains gates from a native target gateset + supported by the device. >>> device.metadata.compilation_target_gatesets * Assuming valid CompilationTargetGatesets exist for the device, select the first one and - use it to transform a circuit to an equivalent form which is valid for the device. + use it to transform a circuit to one which only contains gates from a native target + gateset supported by the device. >>> cirq.optimize_for_target_gateset( circuit, gateset=device.metadata.compilation_target_gatesets[0] @@ -324,17 +320,17 @@ def validate_operation(self, operation: cirq.Operation) -> None: """ if operation not in self._metadata.gateset: - raise ValueError(f'Operation {operation} is not a supported gate') + raise ValueError(f'Operation {operation} contains a gate which is not supported.') for q in operation.qubits: if q not in self._metadata.qubit_set: - raise ValueError(f'Qubit not on device: {q!r}') + raise ValueError(f'Qubit not on device: {q!r}.') if ( len(operation.qubits) == 2 and frozenset(operation.qubits) not in self._metadata.qubit_pairs ): - raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}') + raise ValueError(f'Qubit pair is not valid on device: {operation.qubits!r}.') def __str__(self) -> str: diagram = cirq.TextDiagramDrawer() diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index 30811d2b4bc..83c08a0cca1 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -180,7 +180,6 @@ def test_grid_device_from_proto(): cirq.ops.phased_x_z_gate.PhasedXZGate, cirq.ops.common_gates.XPowGate, cirq.ops.common_gates.YPowGate, - cirq.ops.common_gates.ZPowGate, cirq.ops.phased_x_gate.PhasedXPowGate, cirq.GateFamily( cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] @@ -210,7 +209,6 @@ def test_grid_device_from_proto(): cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4, cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4, cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4, - cirq.GateFamily(cirq.ops.common_gates.ZPowGate): base_duration * 4, cirq.GateFamily(cirq.ops.phased_x_gate.PhasedXPowGate): base_duration * 4, cirq.GateFamily( cirq.ops.common_gates.ZPowGate, tags_to_ignore=[cirq_google.PhysicalZTag()] @@ -251,7 +249,7 @@ def test_grid_device_validate_operations_negative(): with pytest.raises(ValueError, match='Qubit pair is not valid'): device.validate_operation(cirq.CZ(q00, q10)) - with pytest.raises(ValueError, match='not a supported gate'): + with pytest.raises(ValueError, match='gate which is not supported'): device.validate_operation(cirq.H(grid_qubits[0])) From 84bdfe5d812f3972aad35fdbe291ab0a7fdfb4d9 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 18 May 2022 19:04:10 +0000 Subject: [PATCH 3/7] Added GridDevice docstring comment about WaitGates. I avoided mentioning the alternative of inserting no_compile tags because transforming a circuit which contains WaitGates is the wrong approach as all delays would be scrambled by the transformation. --- cirq-google/cirq_google/devices/grid_device.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 80bd0ef5464..3560ee907de 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -220,6 +220,12 @@ class GridDevice(cirq.Device): gateset=device.metadata.compilation_target_gatesets[0] ) + A note about CompilationTargetGatesets: + + A circuit which contains `cirq.WaitGate`s will be dropped if it is transformed using + CompilationTargetGatesets generated by GridDevice. To better control circuit timing, insert + WaitGates after the circuit has been transformed. + Notes for cirq_google internal implementation: For Google devices, the From 17262f0fb6d31e9ca3a8f16b9567461d9cbd0046 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 18 May 2022 22:01:06 +0000 Subject: [PATCH 4/7] Set gate_duration to None if the spec doesn't have duration info --- cirq-google/cirq_google/devices/grid_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 3560ee907de..0115a595ad0 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -295,7 +295,7 @@ def from_proto(cls, proto: v2.device_pb2.DeviceSpecification) -> 'GridDevice': metadata = cirq.GridDeviceMetadata( qubit_pairs=qubit_pairs, gateset=gateset, - gate_durations=gate_durations, + gate_durations=gate_durations if len(gate_durations) > 0 else None, all_qubits=all_qubits, compilation_target_gatesets=_build_compilation_target_gatesets(gateset), ) From 288a4dcd05171feaf4a25b9917b00de85577fac3 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Tue, 31 May 2022 21:11:23 +0000 Subject: [PATCH 5/7] Addressed Doug's comments --- cirq-google/cirq_google/devices/grid_device.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 0115a595ad0..139b873a403 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -86,6 +86,8 @@ def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> def _build_gateset_and_gate_durations( proto: v2.device_pb2.DeviceSpecification, ) -> Tuple[cirq.Gateset, Dict[cirq.GateFamily, cirq.Duration]]: + """Extracts gate set and gate duration information from the given DeviceSpecification proto.""" + gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] fsim_gates: List[Union[Type[POSSIBLE_FSIM_GATES], POSSIBLE_FSIM_GATES]] = [] gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {} @@ -147,7 +149,16 @@ def _build_gateset_and_gate_durations( def _build_compilation_target_gatesets( gateset: cirq.Gateset, ) -> Sequence[cirq.CompilationTargetGateset]: - """Detects compilation target gatesets based on what gates are inside the gateset.""" + """Detects compilation target gatesets based on what gates are inside the gateset. + + If a device contains gates which yield multiple compilation target gatesets, the user can only + choose one target gateset to compile to. For example, a device may contain both SYC and + SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled to + either SYC or SQRT_ISWAP for its two-qubit gates, not both. + + TODO(verult) when implemented, mention that gates which are part of the gateset but not the + compilation target gateset are untouched when compiled. + """ target_gatesets: List[cirq.CompilationTargetGateset] = [] if cirq.CZ in gateset: From 8135a8cc5b85b87d2dbafd98b43340abfd863036 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Tue, 31 May 2022 21:53:15 +0000 Subject: [PATCH 6/7] Addressed Tanuj's 2nd round of comments --- cirq-google/cirq_google/devices/grid_device.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 139b873a403..65bbbe1f78d 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -132,6 +132,9 @@ def _build_gateset_and_gate_durations( continue gates_list.extend(cirq_gates) + + # TODO(#5050) Allow different gate representations of the same gate to be looked up in + # gate_durations. for g in cirq_gates: if not isinstance(g, cirq.GateFamily): g = cirq.GateFamily(g) @@ -156,10 +159,13 @@ def _build_compilation_target_gatesets( SQRT_ISWAP gates which yield two separate target gatesets, but a circuit can only be compiled to either SYC or SQRT_ISWAP for its two-qubit gates, not both. - TODO(verult) when implemented, mention that gates which are part of the gateset but not the - compilation target gateset are untouched when compiled. + TODO(#5050) when cirq-google CompilationTargetGateset subclasses are implemented, mention that + gates which are part of the gateset but not the compilation target gateset are untouched when + compiled. """ + # TODO(#5050) Subclass core CompilationTargetGatesets in cirq-google. + target_gatesets: List[cirq.CompilationTargetGateset] = [] if cirq.CZ in gateset: target_gatesets.append(cirq.CZTargetGateset()) From 7a1bcd975b449929b6bd188309d03c51bb7fd46f Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 1 Jun 2022 20:10:22 +0000 Subject: [PATCH 7/7] Break FSimGateFamily into one for each 2q gate type --- .../cirq_google/devices/grid_device.py | 17 ++++------------- .../cirq_google/devices/grid_device_test.py | 19 ++++++++----------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/cirq-google/cirq_google/devices/grid_device.py b/cirq-google/cirq_google/devices/grid_device.py index 65bbbe1f78d..aa92a4566f9 100644 --- a/cirq-google/cirq_google/devices/grid_device.py +++ b/cirq-google/cirq_google/devices/grid_device.py @@ -24,7 +24,6 @@ from cirq_google import transformers from cirq_google.api import v2 from cirq_google.experimental import ops as experimental_ops -from cirq_google.ops.fsim_gate_family import POSSIBLE_FSIM_GATES def _validate_device_specification(proto: v2.device_pb2.DeviceSpecification) -> None: @@ -89,7 +88,6 @@ def _build_gateset_and_gate_durations( """Extracts gate set and gate duration information from the given DeviceSpecification proto.""" gates_list: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] - fsim_gates: List[Union[Type[POSSIBLE_FSIM_GATES], POSSIBLE_FSIM_GATES]] = [] gate_durations: Dict[cirq.GateFamily, cirq.Duration] = {} # TODO(#5050) Describe how to add/remove gates. @@ -99,17 +97,13 @@ def _build_gateset_and_gate_durations( cirq_gates: List[Union[Type[cirq.Gate], cirq.Gate, cirq.GateFamily]] = [] if gate_name == 'syc': - cirq_gates = [ops.SYC] - fsim_gates.append(ops.SYC) + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[ops.SYC])] elif gate_name == 'sqrt_iswap': - cirq_gates = [cirq.SQRT_ISWAP] - fsim_gates.append(cirq.SQRT_ISWAP) + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP])] elif gate_name == 'sqrt_iswap_inv': - cirq_gates = [cirq.SQRT_ISWAP_INV] - fsim_gates.append(cirq.SQRT_ISWAP_INV) + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV])] elif gate_name == 'cz': - cirq_gates = [cirq.CZ] - fsim_gates.append(cirq.CZ) + cirq_gates = [ops.FSimGateFamily(gates_to_accept=[cirq.CZ])] elif gate_name == 'phased_xz': cirq_gates = [cirq.PhasedXZGate, cirq.XPowGate, cirq.YPowGate, cirq.PhasedXPowGate] elif gate_name == 'virtual_zpow': @@ -140,9 +134,6 @@ def _build_gateset_and_gate_durations( g = cirq.GateFamily(g) gate_durations[g] = cirq.Duration(picos=gate_spec.gate_duration_picos) - if fsim_gates: - gates_list.append(ops.FSimGateFamily(gates_to_accept=fsim_gates)) - # TODO(#4833) Add identity gate support # TODO(#5050) Add GlobalPhaseGate support diff --git a/cirq-google/cirq_google/devices/grid_device_test.py b/cirq-google/cirq_google/devices/grid_device_test.py index 83c08a0cca1..bdbbacb3ffe 100644 --- a/cirq-google/cirq_google/devices/grid_device_test.py +++ b/cirq-google/cirq_google/devices/grid_device_test.py @@ -173,10 +173,10 @@ def test_grid_device_from_proto(): for row in range(GRID_HEIGHT) ) assert device.metadata.gateset == cirq.Gateset( - cirq_google.SYC, - cirq.SQRT_ISWAP, - cirq.SQRT_ISWAP_INV, - cirq.CZ, + cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]), + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]), + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]), + cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]), cirq.ops.phased_x_z_gate.PhasedXZGate, cirq.ops.common_gates.XPowGate, cirq.ops.common_gates.YPowGate, @@ -190,9 +190,6 @@ def test_grid_device_from_proto(): cirq_google.experimental.ops.coupler_pulse.CouplerPulse, cirq.ops.measurement_gate.MeasurementGate, cirq.ops.wait_gate.WaitGate, - cirq_google.FSimGateFamily( - gates_to_accept=[cirq_google.SYC, cirq.SQRT_ISWAP, cirq.SQRT_ISWAP_INV, cirq.CZ] - ), ) assert tuple(device.metadata.compilation_target_gatesets) == ( cirq.CZTargetGateset(), @@ -202,10 +199,10 @@ def test_grid_device_from_proto(): base_duration = cirq.Duration(picos=1_000) assert device.metadata.gate_durations == { - cirq.GateFamily(cirq_google.SYC): base_duration * 0, - cirq.GateFamily(cirq.SQRT_ISWAP): base_duration * 1, - cirq.GateFamily(cirq.SQRT_ISWAP_INV): base_duration * 2, - cirq.GateFamily(cirq.CZ): base_duration * 3, + cirq_google.FSimGateFamily(gates_to_accept=[cirq_google.SYC]): base_duration * 0, + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP]): base_duration * 1, + cirq_google.FSimGateFamily(gates_to_accept=[cirq.SQRT_ISWAP_INV]): base_duration * 2, + cirq_google.FSimGateFamily(gates_to_accept=[cirq.CZ]): base_duration * 3, cirq.GateFamily(cirq.ops.phased_x_z_gate.PhasedXZGate): base_duration * 4, cirq.GateFamily(cirq.ops.common_gates.XPowGate): base_duration * 4, cirq.GateFamily(cirq.ops.common_gates.YPowGate): base_duration * 4,