Skip to content

Commit

Permalink
Update Sampler.run_async signature to match Sampler.run (#5036)
Browse files Browse the repository at this point in the history
Also change `run_async` to delegate to `run_sweep_async` by default instead of `run`. This simplifies implementing an async sampler since only `run_sweep_async` needs to be overridden, instead of both `run_sweep_async` and `run_async`.

@verult
  • Loading branch information
maffoo authored Feb 28, 2022
1 parent 8bc6915 commit 8d65806
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 30 deletions.
55 changes: 25 additions & 30 deletions cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,23 @@ def run(
) -> 'cirq.Result':
"""Samples from the given Circuit.
By default, the `run_async` method invokes this method on another
thread. So this method is supposed to be thread safe.
Args:
program: The circuit to sample from.
param_resolver: Parameters to run with the program.
repetitions: The number of times to sample.
Returns:
Result for a run.
"""
return self.run_sweep(program, param_resolver, repetitions)[0]

async def run_async(
self,
program: 'cirq.AbstractCircuit',
param_resolver: 'cirq.ParamResolverOrSimilarType' = None,
repetitions: int = 1,
) -> 'cirq.Result':
"""Asynchronously samples from the given Circuit.
Args:
program: The circuit to sample from.
Expand All @@ -53,7 +68,8 @@ def run(
Returns:
Result for a run.
"""
return self.run_sweep(program, study.ParamResolver(param_resolver), repetitions)[0]
results = await self.run_sweep_async(program, param_resolver, repetitions)
return results[0]

def sample(
self,
Expand Down Expand Up @@ -152,58 +168,37 @@ def run_sweep(
) -> Sequence['cirq.Result']:
"""Samples from the given Circuit.
In contrast to run, this allows for sweeping over different parameter
values.
This allows for sweeping over different parameter values,
unlike the `run` method.
Args:
program: The circuit to sample from.
params: Parameters to run with the program.
repetitions: The number of times to sample.
Returns:
Result list for this run; one for each possible parameter
resolver.
Result list for this run; one for each possible parameter resolver.
"""

async def run_async(
self, program: 'cirq.AbstractCircuit', *, repetitions: int
) -> 'cirq.Result':
"""Asynchronously samples from the given Circuit.
By default, this method invokes `run` synchronously and simply exposes
its result is an awaitable. Child classes that are capable of true
asynchronous sampling should override it to use other strategies.
Args:
program: The circuit to sample from.
repetitions: The number of times to sample.
Returns:
An awaitable Result.
"""
return self.run(program, repetitions=repetitions)

async def run_sweep_async(
self,
program: 'cirq.AbstractCircuit',
params: 'cirq.Sweepable',
repetitions: int = 1,
) -> Sequence['cirq.Result']:
"""Asynchronously sweeps and samples from the given Circuit.
"""Asynchronously samples from the given Circuit.
By default, this method invokes `run_sweep` synchronously and simply
exposes its result is an awaitable. Child classes that are capable of
true asynchronous sampling should override it to use other strategies.
Args:
program: The circuit to sample from.
params: One or more mappings from parameter keys to parameter values
to use. For each parameter assignment, `repetitions` samples
will be taken.
params: Parameters to run with the program.
repetitions: The number of times to sample.
Returns:
An awaitable Result.
Result list for this run; one for each possible parameter resolver.
"""
return self.run_sweep(program, params=params, repetitions=repetitions)

Expand Down
22 changes: 22 additions & 0 deletions cirq-core/cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@
import cirq


@duet.sync
async def test_run_async():
sim = cirq.Simulator()
result = await sim.run_async(
cirq.Circuit(cirq.measure(cirq.GridQubit(0, 0), key='m')), repetitions=10
)
np.testing.assert_equal(result.records['m'], np.zeros((10, 1, 1)))


@duet.sync
async def test_run_sweep_async():
sim = cirq.Simulator()
results = await sim.run_sweep_async(
cirq.Circuit(cirq.measure(cirq.GridQubit(0, 0), key='m')),
cirq.Linspace('foo', 0, 1, 10),
repetitions=10,
)
assert len(results) == 10
for result in results:
np.testing.assert_equal(result.records['m'], np.zeros((10, 1, 1)))


@duet.sync
async def test_sampler_async_fail():
class FailingSampler(cirq.Sampler):
Expand Down

0 comments on commit 8d65806

Please sign in to comment.