Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProcessorSampler: route run_batch to run_sweep #6357

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions cirq-google/cirq_google/engine/processor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,27 +76,6 @@ async def run_batch_async(
params_list: Optional[Sequence[cirq.Sweepable]] = None,
repetitions: Union[int, Sequence[int]] = 1,
) -> Sequence[Sequence['cg.EngineResult']]:
"""Runs the supplied circuits.

In order to gain a speedup from using this method instead of other run
methods, the following conditions must be satisfied:
1. All circuits must measure the same set of qubits.
2. The number of circuit repetitions must be the same for all
circuits. That is, the `repetitions` argument must be an integer,
or else a list with identical values.
"""
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
if len(set(repetitions)) == 1:
# All repetitions are the same so batching can be done efficiently
job = await self._processor.run_batch_async(
programs=programs,
params_list=params_list,
repetitions=repetitions[0],
run_name=self._run_name,
device_config_name=self._device_config_name,
)
return await job.batched_results_async()
# Varying number of repetitions so no speedup
return cast(
Sequence[Sequence['cg.EngineResult']],
await super().run_batch_async(programs, params_list, repetitions),
Expand Down
64 changes: 44 additions & 20 deletions cirq-google/cirq_google/engine/processor_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,28 @@ def test_run_batch(run_name, device_config_name):
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
sampler.run_batch(circuits, params_list, 5)
processor.run_batch_async.assert_called_with(
params_list=params_list,
programs=circuits,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
)

sampler.run_batch([circuit1, circuit2], [params1, params2], 5)

expected_calls = [
mock.call(
program=circuit1,
params=params1,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
mock.call(
program=circuit2,
params=params2,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
]
processor.run_sweep_async.assert_has_calls(expected_calls)


@pytest.mark.parametrize(
Expand All @@ -79,16 +91,28 @@ def test_run_batch_identical_repetitions(run_name, device_config_name):
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
sampler.run_batch(circuits, params_list, [5, 5])
processor.run_batch_async.assert_called_with(
params_list=params_list,
programs=circuits,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
)

sampler.run_batch([circuit1, circuit2], [params1, params2], [5, 5])

expected_calls = [
mock.call(
program=circuit1,
params=params1,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
mock.call(
program=circuit2,
params=params2,
repetitions=5,
run_name=run_name,
device_config_name=device_config_name,
),
mock.call().results_async(),
]
processor.run_sweep_async.assert_has_calls(expected_calls)


def test_run_batch_bad_number_of_repetitions():
Expand Down
Loading