-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Validating Sampler - Wrapper around sampler to do device related validation. - This will be used in AbstractEngine in order to centralize and simplify validation.
- Loading branch information
1 parent
2f7dfde
commit d137a7d
Showing
5 changed files
with
191 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Callable, List, Optional, Sequence, Union | ||
|
||
import cirq | ||
|
||
VALIDATOR_TYPE = Callable[ | ||
[Sequence[cirq.AbstractCircuit], Sequence[cirq.Sweepable], Union[int, List[int]]], None | ||
] | ||
|
||
|
||
class ValidatingSampler(cirq.Sampler): | ||
def __init__( | ||
self, | ||
*, | ||
device: Optional[cirq.Device] = None, | ||
validator: Optional[VALIDATOR_TYPE] = None, | ||
sampler: cirq.Sampler = cirq.Simulator(), | ||
): | ||
"""Wrapper around `cirq.Sampler` that performs device validation. | ||
This sampler will delegate to the wrapping sampler after | ||
performing validation on the circuit(s) given to the sampler. | ||
Args: | ||
device: `cirq.Device` that will validate_circuit before sampling. | ||
validator: A callable that will do any additional validation | ||
beyond the device. For instance, this can perform serialization | ||
checks. Note that this function takes a list of circuits and | ||
sweeps so that batch functionality can also be tested. | ||
sampler: sampler wrapped by this object. After validating, | ||
samples will be returned by this enclosed `cirq.Sampler`. | ||
""" | ||
self._device = device | ||
self._validator = validator | ||
self._sampler = sampler | ||
|
||
def _validate_circuit( | ||
self, | ||
circuits: Sequence[cirq.AbstractCircuit], | ||
sweeps: List[cirq.Sweepable], | ||
repetitions: Union[int, List[int]], | ||
): | ||
if self._device: | ||
for circuit in circuits: | ||
self._device.validate_circuit(circuit) | ||
if self._validator: | ||
self._validator(circuits, sweeps, repetitions) | ||
|
||
def run_sweep( | ||
self, | ||
program: cirq.AbstractCircuit, | ||
params: cirq.Sweepable, | ||
repetitions: int = 1, | ||
) -> List['cirq.Result']: | ||
self._validate_circuit([program], [params], repetitions) | ||
return self._sampler.run_sweep(program, params, repetitions) | ||
|
||
def run_batch( | ||
self, | ||
programs: Sequence['cirq.AbstractCircuit'], | ||
params_list: Optional[List['cirq.Sweepable']] = None, | ||
repetitions: Union[int, List[int]] = 1, | ||
) -> List[List['cirq.Result']]: | ||
if params_list is None: | ||
params_list = [None] * len(programs) | ||
self._validate_circuit(programs, params_list, repetitions) | ||
return self._sampler.run_batch(programs, params_list, repetitions) |
106 changes: 106 additions & 0 deletions
106
cirq-google/cirq_google/engine/validating_sampler_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import List | ||
import pytest | ||
import sympy | ||
import numpy as np | ||
|
||
import cirq | ||
import cirq_google as cg | ||
|
||
|
||
def test_device_validation(): | ||
sampler = cg.ValidatingSampler( | ||
device=cg.Sycamore23, validator=lambda c, s, r: True, sampler=cirq.Simulator() | ||
) | ||
|
||
# Good qubit | ||
q = cirq.GridQubit(5, 2) | ||
circuit = cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m')) | ||
sweep = cirq.Points(key='t', points=[1, 0]) | ||
results = sampler.run_sweep(circuit, sweep, repetitions=100) | ||
assert np.all(results[0].measurements['m'] == 1) | ||
assert np.all(results[1].measurements['m'] == 0) | ||
|
||
# Bad qubit | ||
q = cirq.GridQubit(2, 2) | ||
circuit = cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m')) | ||
with pytest.raises(ValueError, match='Qubit not on device'): | ||
results = sampler.run_sweep(circuit, sweep, repetitions=100) | ||
|
||
|
||
def _batch_size_less_than_two( | ||
circuits: List[cirq.Circuit], sweeps: List[cirq.Sweepable], repetitions: int | ||
): | ||
if len(circuits) > 2: | ||
raise ValueError('Too many batches') | ||
|
||
|
||
def test_batch_validation(): | ||
sampler = cg.ValidatingSampler( | ||
device=cirq.UNCONSTRAINED_DEVICE, | ||
validator=_batch_size_less_than_two, | ||
sampler=cirq.Simulator(), | ||
) | ||
q = cirq.GridQubit(2, 2) | ||
circuits = [ | ||
cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m')), | ||
cirq.Circuit(cirq.X(q) ** sympy.Symbol('x'), cirq.measure(q, key='m2')), | ||
] | ||
sweeps = [cirq.Points(key='t', points=[1, 0]), cirq.Points(key='x', points=[0, 1])] | ||
results = sampler.run_batch(circuits, sweeps, repetitions=100) | ||
|
||
assert np.all(results[0][0].measurements['m'] == 1) | ||
assert np.all(results[0][1].measurements['m'] == 0) | ||
assert np.all(results[1][0].measurements['m2'] == 0) | ||
assert np.all(results[1][1].measurements['m2'] == 1) | ||
|
||
circuits = [ | ||
cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m')), | ||
cirq.Circuit(cirq.X(q) ** sympy.Symbol('x'), cirq.measure(q, key='m2')), | ||
cirq.Circuit(cirq.measure(q, key='m3')), | ||
] | ||
sweeps = [cirq.Points(key='t', points=[1, 0]), cirq.Points(key='x', points=[0, 1]), {}] | ||
with pytest.raises(ValueError, match='Too many batches'): | ||
results = sampler.run_batch(circuits, sweeps, repetitions=100) | ||
|
||
|
||
def _too_many_reps(circuits: List[cirq.Circuit], sweeps: List[cirq.Sweepable], repetitions: int): | ||
if repetitions > 10000: | ||
raise ValueError('Too many repetitions') | ||
|
||
|
||
def test_sweeps_validation(): | ||
sampler = cg.ValidatingSampler( | ||
device=cirq.UNCONSTRAINED_DEVICE, | ||
validator=_too_many_reps, | ||
sampler=cirq.Simulator(), | ||
) | ||
q = cirq.GridQubit(2, 2) | ||
circuit = cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m')) | ||
sweeps = [cirq.Points(key='t', points=[1, 0]), cirq.Points(key='x', points=[0, 1])] | ||
with pytest.raises(ValueError, match='Too many repetitions'): | ||
_ = sampler.run_sweep(circuit, sweeps, repetitions=20000) | ||
|
||
|
||
def test_batch_default_sweeps(): | ||
sampler = cg.ValidatingSampler() | ||
q = cirq.GridQubit(2, 2) | ||
circuits = [ | ||
cirq.Circuit(cirq.X(q), cirq.measure(q, key='m')), | ||
cirq.Circuit(cirq.measure(q, key='m2')), | ||
] | ||
results = sampler.run_batch(circuits, None, repetitions=100) | ||
assert np.all(results[0][0].measurements['m'] == 1) | ||
assert np.all(results[1][0].measurements['m2'] == 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters