From 65d945ecb760d0d424b7140a095e309a9301dc8d Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 22 Nov 2023 01:33:53 +0000 Subject: [PATCH] ProcessorSampler: route run_batch to run_sweep --- .../cirq_google/engine/processor_sampler.py | 21 ------ .../engine/processor_sampler_test.py | 64 +++++++++++++------ 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/cirq-google/cirq_google/engine/processor_sampler.py b/cirq-google/cirq_google/engine/processor_sampler.py index 305fafb7c19..ecf18230d56 100644 --- a/cirq-google/cirq_google/engine/processor_sampler.py +++ b/cirq-google/cirq_google/engine/processor_sampler.py @@ -76,27 +76,6 @@ async def run_batch_async( params_list: Optional[Sequence[cirq.Sweepable]] = None, repetitions: Union[int, Sequence[int]] = 1, ) -> Sequence[Sequence['cg.EngineResult']]: - """Runs the supplied circuits. - - In order to gain a speedup from using this method instead of other run - methods, the following conditions must be satisfied: - 1. All circuits must measure the same set of qubits. - 2. The number of circuit repetitions must be the same for all - circuits. That is, the `repetitions` argument must be an integer, - or else a list with identical values. - """ - params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions) - if len(set(repetitions)) == 1: - # All repetitions are the same so batching can be done efficiently - job = await self._processor.run_batch_async( - programs=programs, - params_list=params_list, - repetitions=repetitions[0], - run_name=self._run_name, - device_config_name=self._device_config_name, - ) - return await job.batched_results_async() - # Varying number of repetitions so no speedup return cast( Sequence[Sequence['cg.EngineResult']], await super().run_batch_async(programs, params_list, repetitions), diff --git a/cirq-google/cirq_google/engine/processor_sampler_test.py b/cirq-google/cirq_google/engine/processor_sampler_test.py index b3a64d5ca5a..d7f8232e786 100644 --- a/cirq-google/cirq_google/engine/processor_sampler_test.py +++ b/cirq-google/cirq_google/engine/processor_sampler_test.py @@ -54,16 +54,28 @@ def test_run_batch(run_name, device_config_name): circuit2 = cirq.Circuit(cirq.Y(a)) params1 = [cirq.ParamResolver({'t': 1})] params2 = [cirq.ParamResolver({'t': 2})] - circuits = [circuit1, circuit2] - params_list = [params1, params2] - sampler.run_batch(circuits, params_list, 5) - processor.run_batch_async.assert_called_with( - params_list=params_list, - programs=circuits, - repetitions=5, - run_name=run_name, - device_config_name=device_config_name, - ) + + sampler.run_batch([circuit1, circuit2], [params1, params2], 5) + + expected_calls = [ + mock.call( + program=circuit1, + params=params1, + repetitions=5, + run_name=run_name, + device_config_name=device_config_name, + ), + mock.call().results_async(), + mock.call( + program=circuit2, + params=params2, + repetitions=5, + run_name=run_name, + device_config_name=device_config_name, + ), + mock.call().results_async(), + ] + processor.run_sweep_async.assert_has_calls(expected_calls) @pytest.mark.parametrize( @@ -79,16 +91,28 @@ def test_run_batch_identical_repetitions(run_name, device_config_name): circuit2 = cirq.Circuit(cirq.Y(a)) params1 = [cirq.ParamResolver({'t': 1})] params2 = [cirq.ParamResolver({'t': 2})] - circuits = [circuit1, circuit2] - params_list = [params1, params2] - sampler.run_batch(circuits, params_list, [5, 5]) - processor.run_batch_async.assert_called_with( - params_list=params_list, - programs=circuits, - repetitions=5, - run_name=run_name, - device_config_name=device_config_name, - ) + + sampler.run_batch([circuit1, circuit2], [params1, params2], [5, 5]) + + expected_calls = [ + mock.call( + program=circuit1, + params=params1, + repetitions=5, + run_name=run_name, + device_config_name=device_config_name, + ), + mock.call().results_async(), + mock.call( + program=circuit2, + params=params2, + repetitions=5, + run_name=run_name, + device_config_name=device_config_name, + ), + mock.call().results_async(), + ] + processor.run_sweep_async.assert_has_calls(expected_calls) def test_run_batch_bad_number_of_repetitions():