Skip to content

Commit

Permalink
Support complex parameters (#5192)
Browse files Browse the repository at this point in the history
* allow complex parameters

* fix tests

* update conjugate + test

* allow ParameterExpression in Hamiltonian gate

* add raise_if_nan and more tests

* add __complex__

Co-authored-by: Kevin Krsulich <kevin.krsulich@ibm.com>
  • Loading branch information
Cryoris and kdk authored Nov 24, 2020
1 parent d26f273 commit da66b1f
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 29 deletions.
9 changes: 8 additions & 1 deletion qiskit/circuit/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,14 @@ def broadcast_arguments(self, qargs: List, cargs: List) -> Tuple[List, List]:

def validate_parameter(self, parameter):
"""Gate parameters should be int, float, or ParameterExpression"""
if isinstance(parameter, (int, float, ParameterExpression)):
if isinstance(parameter, ParameterExpression):
if len(parameter.parameters) > 0:
return parameter # expression has free parameters, we cannot validate it
if not parameter._symbol_expr.is_real:
raise CircuitError("Bound parameter expression is complex in gate {}".format(
self.name))
return parameter # per default assume parameters must be real when bound
if isinstance(parameter, (int, float)):
return parameter
elif isinstance(parameter, (np.integer, np.floating)):
return parameter.item()
Expand Down
27 changes: 17 additions & 10 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def parameters(self) -> Set:
return self._parameters

def conjugate(self) -> 'ParameterExpression':
"""Return the conjugate, which is the ParameterExpression itself, since it is real."""
return self
"""Return the conjugate."""
conjugated = ParameterExpression(self._parameter_symbols, self._symbol_expr.conjugate())
return conjugated

def assign(self, parameter, value: ParameterValueType) -> 'ParameterExpression':
"""
Expand Down Expand Up @@ -91,7 +92,7 @@ def bind(self, parameter_values: Dict) -> 'ParameterExpression':
"""

self._raise_if_passed_unknown_parameters(parameter_values.keys())
self._raise_if_passed_non_real_value(parameter_values)
self._raise_if_passed_nan(parameter_values)

symbol_values = {self._parameter_symbols[parameter]: value
for parameter, value in parameter_values.items()}
Expand Down Expand Up @@ -166,12 +167,12 @@ def _raise_if_passed_unknown_parameters(self, parameters):
raise CircuitError('Cannot bind Parameters ({}) not present in '
'expression.'.format([str(p) for p in unknown_parameters]))

def _raise_if_passed_non_real_value(self, parameter_values):
nonreal_parameter_values = {p: v for p, v in parameter_values.items()
if not isinstance(v, numbers.Real)}
if nonreal_parameter_values:
raise CircuitError('Expression cannot bind non-real or non-numeric '
'values ({}).'.format(nonreal_parameter_values))
def _raise_if_passed_nan(self, parameter_values):
nan_parameter_values = {p: v for p, v in parameter_values.items()
if not isinstance(v, numbers.Number)}
if nan_parameter_values:
raise CircuitError('Expression cannot bind non-numeric values ({})'.format(
nan_parameter_values))

def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parameters=None):
if outbound_parameters is None:
Expand Down Expand Up @@ -219,7 +220,7 @@ def _apply_operation(self, operation: Callable,

parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols}
other_expr = other._symbol_expr
elif isinstance(other, numbers.Real) and numpy.isfinite(other):
elif isinstance(other, numbers.Number) and numpy.isfinite(other):
parameter_symbols = self._parameter_symbols.copy()
other_expr = other
else:
Expand Down Expand Up @@ -319,6 +320,12 @@ def __float__(self):
'cannot be cast to a float.'.format(self.parameters))
return float(self._symbol_expr)

def __complex__(self):
if self.parameters:
raise TypeError('ParameterExpression with unbound parameters ({}) '
'cannot be cast to a complex.'.format(self.parameters))
return complex(self._symbol_expr)

def __int__(self):
if self.parameters:
raise TypeError('ParameterExpression with unbound parameters ({}) '
Expand Down
8 changes: 7 additions & 1 deletion qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,7 +2031,13 @@ def _assign_parameter(self, parameter, value):
replace instances of ``parameter``.
"""
for instr, param_index in self._parameter_table[parameter]:
instr.params[param_index] = instr.params[param_index].assign(parameter, value)
new_param = instr.params[param_index].assign(parameter, value)
# if fully bound, validate
if len(new_param.parameters) == 0:
instr.params[param_index] = instr.validate_parameter(new_param)
else:
instr.params[param_index] = new_param

self._rebind_definition(instr, parameter, value)

if isinstance(value, ParameterExpression):
Expand Down
4 changes: 3 additions & 1 deletion qiskit/extensions/hamiltonian_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy
import scipy.linalg

from qiskit.circuit import Gate, QuantumCircuit, QuantumRegister
from qiskit.circuit import Gate, QuantumCircuit, QuantumRegister, ParameterExpression
from qiskit.quantum_info.operators.predicates import matrix_equal
from qiskit.quantum_info.operators.predicates import is_hermitian_matrix
from qiskit.extensions.exceptions import ExtensionError
Expand Down Expand Up @@ -118,6 +118,8 @@ def validate_parameter(self, parameter):
"""Hamiltonian parameter has to be an ndarray, operator or float."""
if isinstance(parameter, (float, int, numpy.ndarray)):
return parameter
elif isinstance(parameter, ParameterExpression) and len(parameter.parameters) == 0:
return float(parameter)
else:
raise CircuitError("invalid param type {0} for gate "
"{1}".format(type(parameter), self.name))
Expand Down
71 changes: 55 additions & 16 deletions test/python/circuit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def test_parameter_name_conflicts_raises(self):

qc.p(theta1, 0)

self.assertRaises(CircuitError, qc.u1, theta2, 0)
self.assertRaises(CircuitError, qc.p, theta2, 0)

def test_bind_ryrz_vector(self):
"""Test binding a list of floats to a ParameterVector"""
Expand Down Expand Up @@ -903,9 +903,9 @@ def test_circuit_with_ufunc(self):
theta = Parameter(name='theta')

qc = QuantumCircuit(2)
qc.u1(numpy.cos(phi), 0)
qc.u1(numpy.sin(phi), 0)
qc.u1(numpy.tan(phi), 0)
qc.p(numpy.cos(phi), 0)
qc.p(numpy.sin(phi), 0)
qc.p(numpy.tan(phi), 0)
qc.rz(numpy.arccos(theta), 1)
qc.rz(numpy.arctan(theta), 1)
qc.rz(numpy.arcsin(theta), 1)
Expand All @@ -914,9 +914,9 @@ def test_circuit_with_ufunc(self):
inplace=True)

qc_ref = QuantumCircuit(2)
qc_ref.u1(-1, 0)
qc_ref.u1(0, 0)
qc_ref.u1(0, 0)
qc_ref.p(-1, 0)
qc_ref.p(0, 0)
qc_ref.p(0, 0)
qc_ref.rz(0, 1)
qc_ref.rz(pi / 4, 1)
qc_ref.rz(pi / 2, 1)
Expand Down Expand Up @@ -1016,7 +1016,7 @@ def test_raise_if_subbing_in_parameter_name_conflict(self):
def test_expressions_of_parameter_with_constant(self):
"""Verify operating on a Parameter with a constant."""

good_constants = [2, 1.3, 0, -1, -1.0, numpy.pi]
good_constants = [2, 1.3, 0, -1, -1.0, numpy.pi, 1j]

x = Parameter('x')

Expand All @@ -1025,7 +1025,7 @@ def test_expressions_of_parameter_with_constant(self):
expr = op(const, x)
bound_expr = expr.bind({x: 2.3})

self.assertEqual(float(bound_expr),
self.assertEqual(complex(bound_expr),
op(const, 2.3))

# Division by zero will raise. Tested elsewhere.
Expand All @@ -1036,23 +1036,62 @@ def test_expressions_of_parameter_with_constant(self):
expr = op(x, const)
bound_expr = expr.bind({x: 2.3})

self.assertEqual(float(bound_expr),
self.assertEqual(complex(bound_expr),
op(2.3, const))

def test_complex_parameter_bound_to_real(self):
"""Test a complex parameter expression can be real if bound correctly."""

x, y = Parameter('x'), Parameter('y')

with self.subTest('simple 1j * x'):
qc = QuantumCircuit(1)
qc.rx(1j * x, 0)
bound = qc.bind_parameters({x: 1j})
ref = QuantumCircuit(1)
ref.rx(-1, 0)
self.assertEqual(bound, ref)

with self.subTest('more complex expression'):
qc = QuantumCircuit(1)
qc.rx(0.5j * x - y * y + 2 * y, 0)
bound = qc.bind_parameters({x: -4, y: 1j})
ref = QuantumCircuit(1)
ref.rx(1, 0)
self.assertEqual(bound, ref)

def test_complex_angle_raises_when_not_supported(self):
"""Test parameters are validated when fully bound and errors are raised accordingly."""
x = Parameter('x')
qc = QuantumCircuit(1)
qc.r(x, 1j * x, 0)

with self.subTest('binding x to 0 yields real parameters'):
bound = qc.bind_parameters({x: 0})
ref = QuantumCircuit(1)
ref.r(0, 0, 0)
self.assertEqual(bound, ref)

with self.subTest('binding x to 1 yields complex parameters'):
# RGate does not support complex parameters
with self.assertRaises(CircuitError):
bound = qc.bind_parameters({x: 1})

def test_operating_on_a_parameter_with_a_non_float_will_raise(self):
"""Verify operations between a Parameter and a non-float will raise."""

bad_constants = [1j, '1', numpy.Inf, numpy.NaN, None, {}, []]
bad_constants = ['1', numpy.Inf, numpy.NaN, None, {}, []]

x = Parameter('x')

for op in self.supported_operations:
for const in bad_constants:
with self.assertRaises(TypeError):
_ = op(const, x)
with self.subTest(op=op, const=const):
with self.assertRaises(TypeError):
_ = op(const, x)

with self.assertRaises(TypeError):
_ = op(x, const)
with self.assertRaises(TypeError):
_ = op(x, const)

def test_expressions_division_by_zero(self):
"""Verify dividing a Parameter by 0, or binding 0 as a denominator raises."""
Expand Down Expand Up @@ -1386,7 +1425,7 @@ def test_substituting_compound_expression(self):
def test_conjugate(self):
"""Test calling conjugate on a ParameterExpression."""
x = Parameter('x')
self.assertEqual(x, x.conjugate()) # Parameters are real, therefore conjugate returns self
self.assertEqual((x.conjugate() + 1j), (x - 1j).conjugate())

@data(circlib.RGate, circlib.RXGate, circlib.RYGate, circlib.RZGate, circlib.RXXGate,
circlib.RYYGate, circlib.RZXGate, circlib.RZZGate, circlib.CRXGate, circlib.CRYGate,
Expand Down

0 comments on commit da66b1f

Please sign in to comment.