Skip to content

Commit

Permalink
Validating Sampler (#4609)
Browse files Browse the repository at this point in the history
* Validating Sampler

- Wrapper around sampler to do device related validation.
- This will be used in AbstractEngine in order to centralize
and simplify validation.
  • Loading branch information
dstrain115 authored Nov 3, 2021
1 parent 2f7dfde commit d137a7d
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 0 deletions.
1 change: 1 addition & 0 deletions cirq-google/cirq_google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
EngineProcessor,
ProtoVersion,
QuantumEngineSampler,
ValidatingSampler,
get_engine,
get_engine_calibration,
get_engine_device,
Expand Down
4 changes: 4 additions & 0 deletions cirq-google/cirq_google/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,7 @@
get_engine_sampler,
QuantumEngineSampler,
)

from cirq_google.engine.validating_sampler import (
ValidatingSampler,
)
79 changes: 79 additions & 0 deletions cirq-google/cirq_google/engine/validating_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Sequence, Union

import cirq

VALIDATOR_TYPE = Callable[
[Sequence[cirq.AbstractCircuit], Sequence[cirq.Sweepable], Union[int, List[int]]], None
]


class ValidatingSampler(cirq.Sampler):
def __init__(
self,
*,
device: Optional[cirq.Device] = None,
validator: Optional[VALIDATOR_TYPE] = None,
sampler: cirq.Sampler = cirq.Simulator(),
):
"""Wrapper around `cirq.Sampler` that performs device validation.
This sampler will delegate to the wrapping sampler after
performing validation on the circuit(s) given to the sampler.
Args:
device: `cirq.Device` that will validate_circuit before sampling.
validator: A callable that will do any additional validation
beyond the device. For instance, this can perform serialization
checks. Note that this function takes a list of circuits and
sweeps so that batch functionality can also be tested.
sampler: sampler wrapped by this object. After validating,
samples will be returned by this enclosed `cirq.Sampler`.
"""
self._device = device
self._validator = validator
self._sampler = sampler

def _validate_circuit(
self,
circuits: Sequence[cirq.AbstractCircuit],
sweeps: List[cirq.Sweepable],
repetitions: Union[int, List[int]],
):
if self._device:
for circuit in circuits:
self._device.validate_circuit(circuit)
if self._validator:
self._validator(circuits, sweeps, repetitions)

def run_sweep(
self,
program: cirq.AbstractCircuit,
params: cirq.Sweepable,
repetitions: int = 1,
) -> List['cirq.Result']:
self._validate_circuit([program], [params], repetitions)
return self._sampler.run_sweep(program, params, repetitions)

def run_batch(
self,
programs: Sequence['cirq.AbstractCircuit'],
params_list: Optional[List['cirq.Sweepable']] = None,
repetitions: Union[int, List[int]] = 1,
) -> List[List['cirq.Result']]:
if params_list is None:
params_list = [None] * len(programs)
self._validate_circuit(programs, params_list, repetitions)
return self._sampler.run_batch(programs, params_list, repetitions)
106 changes: 106 additions & 0 deletions cirq-google/cirq_google/engine/validating_sampler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import pytest
import sympy
import numpy as np

import cirq
import cirq_google as cg


def test_device_validation():
sampler = cg.ValidatingSampler(
device=cg.Sycamore23, validator=lambda c, s, r: True, sampler=cirq.Simulator()
)

# Good qubit
q = cirq.GridQubit(5, 2)
circuit = cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m'))
sweep = cirq.Points(key='t', points=[1, 0])
results = sampler.run_sweep(circuit, sweep, repetitions=100)
assert np.all(results[0].measurements['m'] == 1)
assert np.all(results[1].measurements['m'] == 0)

# Bad qubit
q = cirq.GridQubit(2, 2)
circuit = cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m'))
with pytest.raises(ValueError, match='Qubit not on device'):
results = sampler.run_sweep(circuit, sweep, repetitions=100)


def _batch_size_less_than_two(
circuits: List[cirq.Circuit], sweeps: List[cirq.Sweepable], repetitions: int
):
if len(circuits) > 2:
raise ValueError('Too many batches')


def test_batch_validation():
sampler = cg.ValidatingSampler(
device=cirq.UNCONSTRAINED_DEVICE,
validator=_batch_size_less_than_two,
sampler=cirq.Simulator(),
)
q = cirq.GridQubit(2, 2)
circuits = [
cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m')),
cirq.Circuit(cirq.X(q) ** sympy.Symbol('x'), cirq.measure(q, key='m2')),
]
sweeps = [cirq.Points(key='t', points=[1, 0]), cirq.Points(key='x', points=[0, 1])]
results = sampler.run_batch(circuits, sweeps, repetitions=100)

assert np.all(results[0][0].measurements['m'] == 1)
assert np.all(results[0][1].measurements['m'] == 0)
assert np.all(results[1][0].measurements['m2'] == 0)
assert np.all(results[1][1].measurements['m2'] == 1)

circuits = [
cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m')),
cirq.Circuit(cirq.X(q) ** sympy.Symbol('x'), cirq.measure(q, key='m2')),
cirq.Circuit(cirq.measure(q, key='m3')),
]
sweeps = [cirq.Points(key='t', points=[1, 0]), cirq.Points(key='x', points=[0, 1]), {}]
with pytest.raises(ValueError, match='Too many batches'):
results = sampler.run_batch(circuits, sweeps, repetitions=100)


def _too_many_reps(circuits: List[cirq.Circuit], sweeps: List[cirq.Sweepable], repetitions: int):
if repetitions > 10000:
raise ValueError('Too many repetitions')


def test_sweeps_validation():
sampler = cg.ValidatingSampler(
device=cirq.UNCONSTRAINED_DEVICE,
validator=_too_many_reps,
sampler=cirq.Simulator(),
)
q = cirq.GridQubit(2, 2)
circuit = cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q, key='m'))
sweeps = [cirq.Points(key='t', points=[1, 0]), cirq.Points(key='x', points=[0, 1])]
with pytest.raises(ValueError, match='Too many repetitions'):
_ = sampler.run_sweep(circuit, sweeps, repetitions=20000)


def test_batch_default_sweeps():
sampler = cg.ValidatingSampler()
q = cirq.GridQubit(2, 2)
circuits = [
cirq.Circuit(cirq.X(q), cirq.measure(q, key='m')),
cirq.Circuit(cirq.measure(q, key='m2')),
]
results = sampler.run_batch(circuits, None, repetitions=100)
assert np.all(results[0][0].measurements['m'] == 1)
assert np.all(results[1][0].measurements['m2'] == 0)
1 change: 1 addition & 0 deletions cirq-google/cirq_google/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
'SerializingArg',
'THETA_ZETA_GAMMA_FLOQUET_PHASED_FSIM_CHARACTERIZATION',
'QuantumEngineSampler',
'ValidatingSampler',
# Abstract:
'ExecutableSpec',
],
Expand Down

0 comments on commit d137a7d

Please sign in to comment.