diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 70215ad2d..96f118850 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -55,6 +55,9 @@ information about the results. [(#493)](https://github.com/XanaduAI/strawberryfields/pull/493) +* `TDMProgram.run_options` is now correctly used when running a TDM program. + [(#500)](https://github.com/XanaduAI/strawberryfields/pull/500) + * Fixes a bug where a single parameter list passed to the `TDMProgram` context results in an error. [(#503)](https://github.com/XanaduAI/strawberryfields/pull/503) diff --git a/strawberryfields/engine.py b/strawberryfields/engine.py index 82a7343ea..eb349fd4e 100644 --- a/strawberryfields/engine.py +++ b/strawberryfields/engine.py @@ -424,8 +424,9 @@ def run(self, program, *, args=None, compile_options=None, **kwargs): not isinstance(program, collections.abc.Sequence) and program.type == "tdm" ) if valid_tdm_program: + # priority order for the shots value should be kwargs > run_options > 1 + shots = kwargs.get("shots", program.run_options.get("shots", 1)) - shots = kwargs.get("shots", 1) program.unroll(shots=shots) # Shots >1 for a TDM program simply corresponds to creating # multiple copies of the program, and appending them to run sequentially. diff --git a/tests/frontend/test_engine.py b/tests/frontend/test_engine.py index ae3e6792c..c80a6fc9a 100644 --- a/tests/frontend/test_engine.py +++ b/tests/frontend/test_engine.py @@ -69,6 +69,46 @@ def test_bad_backend(self): class TestEngineProgramInteraction: """Test the Engine class and its interaction with Program instances.""" + def test_shots_default(self): + """Test that default shots (1) is used""" + prog = sf.Program(1) + eng = sf.Engine("gaussian") + + with prog.context as q: + ops.Sgate(0.5) | q[0] + ops.MeasureFock() | q + + results = eng.run(prog) + assert results.samples.shape[0] == 1 + + def test_shots_run_options(self): + """Test that run_options takes precedence over default""" + prog = sf.Program(1) + eng = sf.Engine("gaussian") + + with prog.context as q: + ops.Sgate(0.5) | q[0] + ops.MeasureFock() | q + + prog.run_options = {"shots": 5} + results = eng.run(prog) + assert results.samples.shape[0] == 5 + + def test_shots_passed(self): + """Test that shots supplied via eng.run takes precedence over + run_options and that run_options isn't changed""" + prog = sf.Program(1) + eng = sf.Engine("gaussian") + + with prog.context as q: + ops.Sgate(0.5) | q[0] + ops.MeasureFock() | q + + prog.run_options = {"shots": 5} + results = eng.run(prog, shots=2) + assert results.samples.shape[0] == 2 + assert prog.run_options["shots"] == 5 + def test_history(self, eng, prog): """Engine history.""" # no programs have been run diff --git a/tests/frontend/test_tdmprogram.py b/tests/frontend/test_tdmprogram.py index 2bf21dcfe..22c2cf62f 100644 --- a/tests/frontend/test_tdmprogram.py +++ b/tests/frontend/test_tdmprogram.py @@ -724,3 +724,46 @@ def test_move_vac_modes(self, N, crop, expected): res = move_vac_modes(samples, N, crop=crop) assert np.all(res == expected) + +class TestEngineTDMProgramInteraction: + """Test the Engine class and its interaction with TDMProgram instances.""" + + def test_shots_default(self): + """Test that default shots (1) is used""" + prog = sf.TDMProgram(2) + eng = sf.Engine("gaussian") + + with prog.context([1,2], [3,4]) as (p, q): + ops.Sgate(p[0]) | q[0] + ops.MeasureHomodyne(p[1]) | q[0] + + results = eng.run(prog) + assert results.samples.shape[0] == 1 + + def test_shots_run_options(self): + """Test that run_options takes precedence over default""" + prog = sf.TDMProgram(2) + eng = sf.Engine("gaussian") + + with prog.context([1,2], [3,4]) as (p, q): + ops.Sgate(p[0]) | q[0] + ops.MeasureHomodyne(p[1]) | q[0] + + prog.run_options = {"shots": 5} + results = eng.run(prog) + assert results.samples.shape[0] == 5 + + def test_shots_passed(self): + """Test that shots supplied via eng.run takes precedence over + run_options and that run_options isn't changed""" + prog = sf.TDMProgram(2) + eng = sf.Engine("gaussian") + + with prog.context([1,2], [3,4]) as (p, q): + ops.Sgate(p[0]) | q[0] + ops.MeasureHomodyne(p[1]) | q[0] + + prog.run_options = {"shots": 5} + results = eng.run(prog, shots=2) + assert results.samples.shape[0] == 2 + assert prog.run_options["shots"] == 5