Skip to content

Commit

Permalink
Use CIRCUIT_SERIALIZER as default serializer for quantum engine (#4983)
Browse files Browse the repository at this point in the history
`CIRCUIT_SERIALIZER` can serialize gates from any other gateset, uses a more compact proto serialization that doesn't embed redundant string literals, and can serialize more complicated constructs such as `CircuitOperation`. It is meant to provide serialization separate from the notion of which gates are supported by a particular device, two concepts which were previously conflated. Having this as a default is a big UX improvement when using quantum engine, as the `gate_set` argument can be omitted on most calls. The `gate_set` argument could be removed in the future, as specifying the serializer once on `EngineContext`, along with the protobuf version, is sufficient to customize the serialization when needed.
  • Loading branch information
maffoo authored Feb 15, 2022
1 parent 3c89de1 commit 72f9a5b
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 70 deletions.
1 change: 1 addition & 0 deletions cirq-google/cirq_google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@

from cirq_google.serialization import (
arg_from_proto,
CIRCUIT_SERIALIZER,
CircuitSerializer,
CircuitOpDeserializer,
DeserializingArg,
Expand Down
2 changes: 1 addition & 1 deletion cirq-google/cirq_google/engine/abstract_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_processor(self, processor_id: str) -> abstract_processor.AbstractProcess

@abc.abstractmethod
def get_sampler(
self, processor_id: Union[str, List[str]], gate_set: Serializer
self, processor_id: Union[str, List[str]], gate_set: Optional[Serializer] = None
) -> cirq.Sampler:
"""Returns a sampler backed by the engine.
Expand Down
32 changes: 16 additions & 16 deletions cirq-google/cirq_google/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,19 @@
import cirq
from cirq._compat import deprecated
from cirq_google.api import v2
from cirq_google.engine import abstract_engine, abstract_program
from cirq_google.engine.client import quantum
from cirq_google.engine.result_type import ResultType
from cirq_google.serialization import SerializableGateSet, Serializer
from cirq_google.serialization.arg_func_langs import arg_to_proto
from cirq_google.engine import (
abstract_engine,
abstract_program,
engine_client,
engine_program,
engine_job,
engine_processor,
engine_program,
engine_sampler,
)
from cirq_google.engine.client import quantum
from cirq_google.engine.result_type import ResultType
from cirq_google.serialization import CIRCUIT_SERIALIZER, SerializableGateSet, Serializer
from cirq_google.serialization.arg_func_langs import arg_to_proto

if TYPE_CHECKING:
import cirq_google
Expand Down Expand Up @@ -85,6 +86,7 @@ def __init__(
verbose: Optional[bool] = None,
client: 'Optional[engine_client.EngineClient]' = None,
timeout: Optional[int] = None,
serializer: Serializer = CIRCUIT_SERIALIZER,
) -> None:
"""Context and client for using Quantum Engine.
Expand All @@ -99,6 +101,7 @@ def __init__(
created.
timeout: Timeout for polling for results, in seconds. Default is
to never timeout.
serializer: Used to serialize circuits when running jobs.
Raises:
ValueError: If either `service_args` and `verbose` were supplied
Expand All @@ -110,6 +113,7 @@ def __init__(
self.proto_version = proto_version or ProtoVersion.V2
if self.proto_version == ProtoVersion.V1:
raise ValueError('ProtoVersion V1 no longer supported')
self.serializer = serializer

if not client:
client = engine_client.EngineClient(service_args=service_args, verbose=verbose)
Expand Down Expand Up @@ -195,7 +199,7 @@ def run(
param_resolver: cirq.ParamResolver = cirq.ParamResolver({}),
repetitions: int = 1,
processor_ids: Sequence[str] = ('xmonsim',),
gate_set: Optional[Serializer] = None,
gate_set: Serializer = None,
program_description: Optional[str] = None,
program_labels: Optional[Dict[str, str]] = None,
job_description: Optional[str] = None,
Expand Down Expand Up @@ -233,8 +237,6 @@ def run(
Raises:
ValueError: If no gate set is provided.
"""
if not gate_set:
raise ValueError('No gate set provided')
return list(
self.run_sweep(
program=program,
Expand Down Expand Up @@ -301,8 +303,6 @@ def run_sweep(
Raises:
ValueError: If no gate set is provided.
"""
if not gate_set:
raise ValueError('No gate set provided')
engine_program = self.create_program(
program, program_id, gate_set, program_description, program_labels
)
Expand Down Expand Up @@ -502,7 +502,7 @@ def create_program(
ValueError: If no gate set is provided.
"""
if not gate_set:
raise ValueError('No gate set provided')
gate_set = self.context.serializer

if not program_id:
program_id = _make_random_id('prog-')
Expand Down Expand Up @@ -548,7 +548,7 @@ def create_batch_program(
ValueError: If no gate set is provided.
"""
if not gate_set:
raise ValueError('Gate set must be specified.')
gate_set = self.context.serializer
if not program_id:
program_id = _make_random_id('prog-')

Expand Down Expand Up @@ -601,7 +601,7 @@ def create_calibration_program(
ValueError: If not gate set is given.
"""
if not gate_set:
raise ValueError('Gate set must be specified.')
gate_set = self.context.serializer
if not program_id:
program_id = _make_random_id('calibration-')

Expand Down Expand Up @@ -784,7 +784,7 @@ def get_processor(self, processor_id: str) -> engine_processor.EngineProcessor:

@deprecated(deadline="v1.0", fix="Use get_sampler instead.")
def sampler(
self, processor_id: Union[str, List[str]], gate_set: Serializer
self, processor_id: Union[str, List[str]], gate_set: Optional[Serializer] = None
) -> engine_sampler.QuantumEngineSampler:
"""Returns a sampler backed by the engine.
Expand All @@ -802,7 +802,7 @@ def sampler(
return self.get_sampler(processor_id, gate_set)

def get_sampler(
self, processor_id: Union[str, List[str]], gate_set: Serializer
self, processor_id: Union[str, List[str]], gate_set: Optional[Serializer] = None
) -> engine_sampler.QuantumEngineSampler:
"""Returns a sampler backed by the engine.
Expand Down
4 changes: 2 additions & 2 deletions cirq-google/cirq_google/engine/engine_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
calibration_layer,
engine_sampler,
)
from cirq_google.serialization import circuit_serializer, serializable_gate_set, serializer
from cirq_google.serialization import serializable_gate_set, serializer
from cirq_google.serialization import gate_sets as gs

if TYPE_CHECKING:
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_sampler(
return engine_sampler.QuantumEngineSampler(
engine=self.engine(),
processor_id=self.processor_id,
gate_set=gate_set or circuit_serializer.CIRCUIT_SERIALIZER,
gate_set=gate_set,
)

def run_batch(
Expand Down
2 changes: 1 addition & 1 deletion cirq-google/cirq_google/engine/engine_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
*,
engine: 'cirq_google.Engine',
processor_id: Union[str, List[str]],
gate_set: 'cirq_google.serialization.Serializer',
gate_set: Optional['cirq_google.serialization.Serializer'] = None,
):
"""Inits QuantumEngineSampler.
Expand Down
66 changes: 17 additions & 49 deletions cirq-google/cirq_google/engine/engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ def test_run_circuit(client):
program_id='prog',
job_id='job-id',
processor_ids=['mysim'],
gate_set=cg.XMON,
)

assert result.repetitions == 1
Expand Down Expand Up @@ -392,20 +391,14 @@ def test_run_circuit(client):


def test_no_gate_set():
circuit = cirq.Circuit()
engine = cg.Engine(project_id='project-id')
with pytest.raises(ValueError, match='No gate set'):
engine.run(program=circuit)
with pytest.raises(ValueError, match='No gate set'):
engine.run_sweep(program=circuit)
with pytest.raises(ValueError, match='No gate set'):
engine.create_program(program=circuit)
assert engine.context.serializer == cg.CIRCUIT_SERIALIZER


def test_unsupported_program_type():
engine = cg.Engine(project_id='project-id')
with pytest.raises(TypeError, match='program'):
engine.run(program="this isn't even the right type of thing!", gate_set=cg.XMON)
engine.run(program="this isn't even the right type of thing!")


@mock.patch('cirq_google.engine.engine_client.EngineClient')
Expand Down Expand Up @@ -435,7 +428,7 @@ def test_run_circuit_failed(client):
match='Job projects/proj/programs/prog/jobs/job-id on processor'
' myqc failed. SYSTEM_ERROR: Not good',
):
engine.run(program=_CIRCUIT, gate_set=cg.XMON)
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient')
Expand Down Expand Up @@ -464,7 +457,7 @@ def test_run_circuit_failed_missing_processor_name(client):
match='Job projects/proj/programs/prog/jobs/job-id on processor'
' UNKNOWN failed. SYSTEM_ERROR: Not good',
):
engine.run(program=_CIRCUIT, gate_set=cg.XMON)
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient')
Expand All @@ -491,7 +484,7 @@ def test_run_circuit_cancelled(client):
RuntimeError,
match='Job projects/proj/programs/prog/jobs/job-id failed in state CANCELLED.',
):
engine.run(program=_CIRCUIT, gate_set=cg.XMON)
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient')
Expand All @@ -516,7 +509,7 @@ def test_run_circuit_timeout(patched_time_sleep, client):

engine = cg.Engine(project_id='project-id', timeout=600)
with pytest.raises(RuntimeError, match='Timed out'):
engine.run(program=_CIRCUIT, gate_set=cg.XMON)
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient')
Expand All @@ -527,7 +520,6 @@ def test_run_sweep_params(client):
job = engine.run_sweep(
program=_CIRCUIT,
params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})],
gate_set=cg.XMON,
)
results = job.results()
assert len(results) == 2
Expand Down Expand Up @@ -555,7 +547,7 @@ def test_run_multiple_times(client):
setup_run_circuit_with_result_(client, _RESULTS)

engine = cg.Engine(project_id='proj', proto_version=cg.engine.engine.ProtoVersion.V2)
program = engine.create_program(program=_CIRCUIT, gate_set=cg.XMON)
program = engine.create_program(program=_CIRCUIT)
program.run(param_resolver=cirq.ParamResolver({'a': 1}))
run_context = v2.run_context_pb2.RunContext()
client().create_job.call_args[1]['run_context'].Unpack(run_context)
Expand Down Expand Up @@ -589,9 +581,7 @@ def test_run_sweep_v2(client):
project_id='proj',
proto_version=cg.engine.engine.ProtoVersion.V2,
)
job = engine.run_sweep(
program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]), gate_set=cg.XMON
)
job = engine.run_sweep(program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]))
results = job.results()
assert len(results) == 2
for i, v in enumerate([1, 2]):
Expand Down Expand Up @@ -619,7 +609,6 @@ def test_run_batch(client):
proto_version=cg.engine.engine.ProtoVersion.V2,
)
job = engine.run_batch(
gate_set=cg.XMON,
programs=[_CIRCUIT, _CIRCUIT2],
job_id='job-id',
params_list=[cirq.Points('a', [1, 2]), cirq.Points('a', [3, 4])],
Expand Down Expand Up @@ -657,9 +646,7 @@ def test_run_batch_no_params(client):
project_id='proj',
proto_version=cg.engine.engine.ProtoVersion.V2,
)
engine.run_batch(
programs=[_CIRCUIT, _CIRCUIT2], gate_set=cg.XMON, job_id='job-id', processor_ids=['mysim']
)
engine.run_batch(programs=[_CIRCUIT, _CIRCUIT2], job_id='job-id', processor_ids=['mysim'])
# Validate correct number of params have been created and that they
# are empty sweeps.
run_context = v2.batch_pb2.BatchRunContext()
Expand All @@ -681,7 +668,6 @@ def test_batch_size_validation_fails():
with pytest.raises(ValueError, match='Number of circuits and sweeps'):
_ = engine.run_batch(
programs=[_CIRCUIT, _CIRCUIT2],
gate_set=cg.XMON,
job_id='job-id',
params_list=[
cirq.Points('a', [1, 2]),
Expand All @@ -694,19 +680,10 @@ def test_batch_size_validation_fails():
with pytest.raises(ValueError, match='Processor id must be specified'):
_ = engine.run_batch(
programs=[_CIRCUIT, _CIRCUIT2],
gate_set=cg.XMON,
job_id='job-id',
params_list=[cirq.Points('a', [1, 2]), cirq.Points('a', [3, 4])],
)

with pytest.raises(ValueError, match='Gate set must be specified'):
_ = engine.run_batch(
programs=[_CIRCUIT, _CIRCUIT2],
job_id='job-id',
params_list=[cirq.Points('a', [1, 2]), cirq.Points('a', [3, 4])],
processor_ids=['mysim'],
)


def test_bad_sweep_proto():
engine = cg.Engine(project_id='project-id', proto_version=cg.ProtoVersion.UNDEFINED)
Expand All @@ -729,9 +706,7 @@ def test_run_calibration(client):
layer2 = cg.CalibrationLayer(
'readout', cirq.Circuit(cirq.measure(q1, q2)), {'num_samples': 4242}
)
job = engine.run_calibration(
gate_set=cg.FSIM_GATESET, layers=[layer1, layer2], job_id='job-id', processor_id='mysim'
)
job = engine.run_calibration(layers=[layer1, layer2], job_id='job-id', processor_id='mysim')
results = job.calibration_results()
assert len(results) == 2
assert results[0].code == v2.calibration_pb2.SUCCESS
Expand Down Expand Up @@ -768,18 +743,13 @@ def test_run_calibration_validation_fails():
)

with pytest.raises(ValueError, match='Processor id must be specified'):
_ = engine.run_calibration(layers=[layer1, layer2], gate_set=cg.XMON, job_id='job-id')
_ = engine.run_calibration(layers=[layer1, layer2], job_id='job-id')

with pytest.raises(ValueError, match='Gate set must be specified'):
_ = engine.run_calibration(
layers=[layer1, layer2], processor_ids=['mysim'], job_id='job-id'
)
with pytest.raises(ValueError, match='processor_id and processor_ids'):
_ = engine.run_calibration(
layers=[layer1, layer2],
processor_ids=['mysim'],
processor_id='mysim',
gate_set=cg.XMON,
job_id='job-id',
)

Expand All @@ -792,9 +762,7 @@ def test_bad_result_proto(client):
setup_run_circuit_with_result_(client, result)

engine = cg.Engine(project_id='project-id', proto_version=cg.engine.engine.ProtoVersion.V2)
job = engine.run_sweep(
program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]), gate_set=cg.XMON
)
job = engine.run_sweep(program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]))
with pytest.raises(ValueError, match='invalid result proto version'):
job.results()

Expand All @@ -804,9 +772,9 @@ def test_bad_program_proto():
project_id='project-id', proto_version=cg.engine.engine.ProtoVersion.UNDEFINED
)
with pytest.raises(ValueError, match='invalid program proto version'):
engine.run_sweep(program=_CIRCUIT, gate_set=cg.XMON)
engine.run_sweep(program=_CIRCUIT)
with pytest.raises(ValueError, match='invalid program proto version'):
engine.create_program(_CIRCUIT, gate_set=cg.XMON)
engine.create_program(_CIRCUIT)


def test_get_program():
Expand All @@ -832,7 +800,7 @@ def test_list_programs(list_programs):
@mock.patch('cirq_google.engine.engine_client.EngineClient')
def test_create_program(client):
client().create_program.return_value = ('prog', qtypes.QuantumProgram())
result = cg.Engine(project_id='proj').create_program(_CIRCUIT, 'prog', gate_set=cg.XMON)
result = cg.Engine(project_id='proj').create_program(_CIRCUIT, 'prog')
client().create_program.assert_called_once()
assert result.program_id == 'prog'

Expand Down Expand Up @@ -879,7 +847,7 @@ def test_sampler(client):
setup_run_circuit_with_result_(client, _RESULTS)

engine = cg.Engine(project_id='proj')
sampler = engine.get_sampler(processor_id='tmp', gate_set=cg.XMON)
sampler = engine.get_sampler(processor_id='tmp')
results = sampler.run_sweep(
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
)
Expand All @@ -891,7 +859,7 @@ def test_sampler(client):
assert client().create_program.call_args[0][0] == 'proj'

with cirq.testing.assert_deprecated('sampler', deadline='1.0'):
_ = engine.sampler(processor_id='tmp', gate_set=cg.XMON)
_ = engine.sampler(processor_id='tmp')


@mock.patch('cirq_google.engine.client.quantum.QuantumEngineServiceClient')
Expand Down
1 change: 1 addition & 0 deletions cirq-google/cirq_google/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
'CircuitOpDeserializer',
'CircuitOpSerializer',
'CircuitSerializer',
'CIRCUIT_SERIALIZER',
'CircuitWithCalibration',
'ConvertToSqrtIswapGates',
'ConvertToSycamoreGates',
Expand Down
Loading

0 comments on commit 72f9a5b

Please sign in to comment.