Skip to content

Commit

Permalink
Support FrozenCircuit incg.Engine (#4731)
Browse files Browse the repository at this point in the history
* Support FrozenCircuit in QuantumEngineSampler.run_sweep

* Support FrozenCircuit cg.Engine
  • Loading branch information
tanujkhattar authored Dec 7, 2021
1 parent a0497a4 commit 5ee7ff6
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 16 deletions.
12 changes: 7 additions & 5 deletions cirq-google/cirq_google/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def __str__(self) -> str:

def run(
self,
program: cirq.Circuit,
program: cirq.AbstractCircuit,
program_id: Optional[str] = None,
job_id: Optional[str] = None,
param_resolver: cirq.ParamResolver = cirq.ParamResolver({}),
Expand Down Expand Up @@ -253,7 +253,7 @@ def run(

def run_sweep(
self,
program: cirq.Circuit,
program: cirq.AbstractCircuit,
program_id: Optional[str] = None,
job_id: Optional[str] = None,
params: cirq.Sweepable = None,
Expand Down Expand Up @@ -475,7 +475,7 @@ def run_calibration(

def create_program(
self,
program: cirq.Circuit,
program: cirq.AbstractCircuit,
program_id: Optional[str] = None,
gate_set: Optional[Serializer] = None,
description: Optional[str] = None,
Expand Down Expand Up @@ -629,8 +629,10 @@ def create_calibration_program(
result_type=ResultType.Calibration,
)

def _serialize_program(self, program: cirq.Circuit, gate_set: Serializer) -> any_pb2.Any:
if not isinstance(program, cirq.Circuit):
def _serialize_program(
self, program: cirq.AbstractCircuit, gate_set: Serializer
) -> any_pb2.Any:
if not isinstance(program, cirq.AbstractCircuit):
raise TypeError(f'Unrecognized program type: {type(program)}')
program.device.validate_circuit(program)

Expand Down
4 changes: 2 additions & 2 deletions cirq-google/cirq_google/engine/engine_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast, List, Optional, Sequence, TYPE_CHECKING, Union
from typing import List, Optional, Sequence, TYPE_CHECKING, Union

import cirq
from cirq_google import engine
Expand Down Expand Up @@ -60,7 +60,7 @@ def run_sweep(
)
else:
job = self._engine.run_sweep(
program=cast(cirq.Circuit, program),
program=program,
params=params,
repetitions=repetitions,
processor_ids=self._processor_ids,
Expand Down
10 changes: 7 additions & 3 deletions cirq-google/cirq_google/engine/engine_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
import cirq_google.engine.client.quantum


def test_run_circuit():
@pytest.mark.parametrize('circuit', [cirq.Circuit(), cirq.FrozenCircuit()])
def test_run_circuit(circuit):
engine = mock.Mock()
sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp', gate_set=cg.XMON)
circuit = cirq.Circuit()
params = [cirq.ParamResolver({'a': 1})]
sampler.run_sweep(circuit, params, 5)
engine.run_sweep.assert_called_with(
gate_set=cg.XMON, params=params, processor_ids=['tmp'], program=circuit, repetitions=5
gate_set=cg.XMON,
params=params,
processor_ids=['tmp'],
program=circuit,
repetitions=5,
)


Expand Down
2 changes: 1 addition & 1 deletion cirq-google/cirq_google/engine/engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)


_CIRCUIT2 = cirq.Circuit(
_CIRCUIT2 = cirq.FrozenCircuit(
cirq.Y(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result')
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def serialize(
if msg is None:
msg = v2.program_pb2.Program()
msg.language.gate_set = self.name
if isinstance(program, cirq.Circuit):
if isinstance(program, cirq.AbstractCircuit):
constants: Optional[List[v2.program_pb2.Constant]] = [] if use_constants else None
raw_constants: Optional[Dict[Any, int]] = {} if use_constants else None
self._serialize_circuit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ def test_is_supported_operation_can_serialize_predicate():
def test_serialize_deserialize_circuit():
q0 = cirq.GridQubit(1, 1)
q1 = cirq.GridQubit(1, 2)
circuit = cirq.Circuit(cirq.X(q0), cirq.X(q1), cirq.X(q0))

circuit_base = cirq.Circuit(cirq.X(q0), cirq.X(q1), cirq.X(q0))
proto = v2.program_pb2.Program(
language=v2.program_pb2.Language(arg_function_language='', gate_set='my_gate_set'),
circuit=v2.program_pb2.Circuit(
Expand All @@ -174,8 +173,9 @@ def test_serialize_deserialize_circuit():
],
),
)
assert proto == MY_GATE_SET.serialize(circuit)
assert MY_GATE_SET.deserialize(proto) == circuit
for circuit in [circuit_base, circuit_base.freeze()]:
assert proto == MY_GATE_SET.serialize(circuit)
assert MY_GATE_SET.deserialize(proto) == circuit


def test_serialize_deserialize_circuit_with_tokens():
Expand Down

0 comments on commit 5ee7ff6

Please sign in to comment.