Skip to content

Commit

Permalink
Check for complex arguments for symbolic parameters (#439)
Browse files Browse the repository at this point in the history
* Moving the complex check to _apply such that symbolic arguments are evaluated; adding test for MeasureHD

* Adjust docstring

* Updating previous test case so that the program is run on an engine

* Adding TF checks by using the numpy attribute

* Fromatting

Co-authored-by: Josh Izaac <josh146@gmail.com>
  • Loading branch information
antalszava and josh146 authored Aug 6, 2020
1 parent 5449093 commit e38de6e
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
33 changes: 22 additions & 11 deletions strawberryfields/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,14 +588,17 @@ class Coherent(Preparation):
"""

def __init__(self, r=0.0, phi=0.0):
if (np.iscomplex([r, phi])).any():
raise ValueError("The arguments of Coherent(r, phi) cannot be complex")

super().__init__([r, phi])

def _apply(self, reg, backend, **kwargs):
r = par_evaluate(self.p[0])
phi = par_evaluate(self.p[1])

tf_complex = any(hasattr(arg, "numpy") and np.iscomplex(arg.numpy()) for arg in [r, phi])

if (np.iscomplex([r, phi])).any() or tf_complex:
raise ValueError("The arguments of Coherent(r, phi) cannot be complex")

backend.prepare_coherent_state(r, phi, *reg)


Expand Down Expand Up @@ -714,15 +717,20 @@ class DisplacedSqueezed(Preparation):
"""

def __init__(self, r_d=0.0, phi_d=0.0, r_s=0.0, phi_s=0.0):
if (np.iscomplex([r_d, phi_d, r_s, phi_s])).any():
raise ValueError(
"The arguments of DisplacedSqueezed(r_d, phi_d, r_s, phi_s) cannot be complex"
)

super().__init__([r_d, phi_d, r_s, phi_s])

def _apply(self, reg, backend, **kwargs):
p = par_evaluate(self.p)

tf_complex = any(
hasattr(arg, "numpy") and np.iscomplex(arg.numpy()) for arg in [p[0], p[1], p[2], p[3]]
)

if (np.iscomplex([p[0], p[1], p[2], p[3]])).any() or tf_complex:
raise ValueError(
"The arguments of DisplacedSqueezed(r_d, phi_d, r_s, phi_s) cannot be complex"
)

# prepare the displaced squeezed state directly
backend.prepare_displaced_squeezed_state(p[0], p[1], p[2], p[3], *reg)

Expand Down Expand Up @@ -1324,13 +1332,16 @@ class Dgate(Gate):
"""

def __init__(self, r, phi=0.0):
if (np.iscomplex([r, phi])).any():
raise ValueError("The arguments of Dgate(r, phi) cannot be complex")

super().__init__([r, phi])

def _apply(self, reg, backend, **kwargs):
r, phi = par_evaluate(self.p)

tf_complex = any(hasattr(arg, "numpy") and np.iscomplex(arg.numpy()) for arg in [r, phi])

if (np.iscomplex([r, phi])).any() or tf_complex:
raise ValueError("The arguments of Dgate(r, phi) cannot be complex")

backend.displacement(r, phi, *reg)


Expand Down
20 changes: 20 additions & 0 deletions tests/frontend/test_ops_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import strawberryfields.program_utils as pu

from strawberryfields import Engine
from strawberryfields import ops
from strawberryfields.program import Program
from strawberryfields.program_utils import MergeFailure, RegRefError
Expand Down Expand Up @@ -176,6 +177,25 @@ def test_complex_first_argument_error(self, gate):
with prog.context as q:
gate(0.2+1j) | q

eng = Engine("gaussian")
res = eng.run(prog)

def test_complex_symbolic(self, gate):
"""Test that passing a complex value to symbolic parameter of a gate
that previously accepted complex parameters raises an error.
An example here is testing heterodyne measurements.
"""
with pytest.raises(ValueError, match="cannot be complex"):

prog = Program(1)

with prog.context as q:
ops.MeasureHD | q[0]
gate(q[0].par) | q

eng = Engine("gaussian")
res = eng.run(prog)

def test_merge_measured_pars():
"""Test merging two gates with measured parameters."""
Expand Down
18 changes: 18 additions & 0 deletions tests/frontend/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,21 @@ def test_zgate_decompose(self, backend, hbar, applied_cmds):
assert isinstance(applied_cmds[0].op, sf.ops.Dgate)
assert par_evaluate(applied_cmds[0].op.p[0]) == mapping["p"] / np.sqrt(2 * hbar)
assert applied_cmds[0].op.p[1] == np.pi/2

@pytest.mark.parametrize("gate", [sf.ops.Dgate, sf.ops.Coherent, sf.ops.DisplacedSqueezed])
def test_complex_symbolic_tf(self, gate):
"""Test that passing a TF Variable to gates that previously accepted
complex parameters raises an error when using the TF backend."""
import tensorflow as tf
with pytest.raises(ValueError, match="cannot be complex"):

prog = sf.Program(1)
alpha = prog.params("alpha")

with prog.context as q:
gate(alpha) | q[0]

eng = sf.Engine("tf", backend_options={"cutoff_dim":5})

with tf.GradientTape() as tape:
res = eng.run(prog, args={"alpha": tf.Variable(0.5+1j)})

0 comments on commit e38de6e

Please sign in to comment.