From da66b1fba0cf538e387d22ec86898bbc894c0daf Mon Sep 17 00:00:00 2001 From: Julien Gacon Date: Tue, 24 Nov 2020 11:14:58 +0100 Subject: [PATCH] Support complex parameters (#5192) * allow complex parameters * fix tests * update conjugate + test * allow ParameterExpression in Hamiltonian gate * add raise_if_nan and more tests * add __complex__ Co-authored-by: Kevin Krsulich --- qiskit/circuit/gate.py | 9 +++- qiskit/circuit/parameterexpression.py | 27 ++++++---- qiskit/circuit/quantumcircuit.py | 8 ++- qiskit/extensions/hamiltonian_gate.py | 4 +- test/python/circuit/test_parameters.py | 71 ++++++++++++++++++++------ 5 files changed, 90 insertions(+), 29 deletions(-) diff --git a/qiskit/circuit/gate.py b/qiskit/circuit/gate.py index 3039e08490e4..88740f7807ff 100644 --- a/qiskit/circuit/gate.py +++ b/qiskit/circuit/gate.py @@ -227,7 +227,14 @@ def broadcast_arguments(self, qargs: List, cargs: List) -> Tuple[List, List]: def validate_parameter(self, parameter): """Gate parameters should be int, float, or ParameterExpression""" - if isinstance(parameter, (int, float, ParameterExpression)): + if isinstance(parameter, ParameterExpression): + if len(parameter.parameters) > 0: + return parameter # expression has free parameters, we cannot validate it + if not parameter._symbol_expr.is_real: + raise CircuitError("Bound parameter expression is complex in gate {}".format( + self.name)) + return parameter # per default assume parameters must be real when bound + if isinstance(parameter, (int, float)): return parameter elif isinstance(parameter, (np.integer, np.floating)): return parameter.item() diff --git a/qiskit/circuit/parameterexpression.py b/qiskit/circuit/parameterexpression.py index 78d3395c219b..86b947732d68 100644 --- a/qiskit/circuit/parameterexpression.py +++ b/qiskit/circuit/parameterexpression.py @@ -52,8 +52,9 @@ def parameters(self) -> Set: return self._parameters def conjugate(self) -> 'ParameterExpression': - """Return the conjugate, which is the ParameterExpression itself, since it is real.""" - return self + """Return the conjugate.""" + conjugated = ParameterExpression(self._parameter_symbols, self._symbol_expr.conjugate()) + return conjugated def assign(self, parameter, value: ParameterValueType) -> 'ParameterExpression': """ @@ -91,7 +92,7 @@ def bind(self, parameter_values: Dict) -> 'ParameterExpression': """ self._raise_if_passed_unknown_parameters(parameter_values.keys()) - self._raise_if_passed_non_real_value(parameter_values) + self._raise_if_passed_nan(parameter_values) symbol_values = {self._parameter_symbols[parameter]: value for parameter, value in parameter_values.items()} @@ -166,12 +167,12 @@ def _raise_if_passed_unknown_parameters(self, parameters): raise CircuitError('Cannot bind Parameters ({}) not present in ' 'expression.'.format([str(p) for p in unknown_parameters])) - def _raise_if_passed_non_real_value(self, parameter_values): - nonreal_parameter_values = {p: v for p, v in parameter_values.items() - if not isinstance(v, numbers.Real)} - if nonreal_parameter_values: - raise CircuitError('Expression cannot bind non-real or non-numeric ' - 'values ({}).'.format(nonreal_parameter_values)) + def _raise_if_passed_nan(self, parameter_values): + nan_parameter_values = {p: v for p, v in parameter_values.items() + if not isinstance(v, numbers.Number)} + if nan_parameter_values: + raise CircuitError('Expression cannot bind non-numeric values ({})'.format( + nan_parameter_values)) def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parameters=None): if outbound_parameters is None: @@ -219,7 +220,7 @@ def _apply_operation(self, operation: Callable, parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols} other_expr = other._symbol_expr - elif isinstance(other, numbers.Real) and numpy.isfinite(other): + elif isinstance(other, numbers.Number) and numpy.isfinite(other): parameter_symbols = self._parameter_symbols.copy() other_expr = other else: @@ -319,6 +320,12 @@ def __float__(self): 'cannot be cast to a float.'.format(self.parameters)) return float(self._symbol_expr) + def __complex__(self): + if self.parameters: + raise TypeError('ParameterExpression with unbound parameters ({}) ' + 'cannot be cast to a complex.'.format(self.parameters)) + return complex(self._symbol_expr) + def __int__(self): if self.parameters: raise TypeError('ParameterExpression with unbound parameters ({}) ' diff --git a/qiskit/circuit/quantumcircuit.py b/qiskit/circuit/quantumcircuit.py index 57c358b28374..b7c3a4c9059c 100644 --- a/qiskit/circuit/quantumcircuit.py +++ b/qiskit/circuit/quantumcircuit.py @@ -2031,7 +2031,13 @@ def _assign_parameter(self, parameter, value): replace instances of ``parameter``. """ for instr, param_index in self._parameter_table[parameter]: - instr.params[param_index] = instr.params[param_index].assign(parameter, value) + new_param = instr.params[param_index].assign(parameter, value) + # if fully bound, validate + if len(new_param.parameters) == 0: + instr.params[param_index] = instr.validate_parameter(new_param) + else: + instr.params[param_index] = new_param + self._rebind_definition(instr, parameter, value) if isinstance(value, ParameterExpression): diff --git a/qiskit/extensions/hamiltonian_gate.py b/qiskit/extensions/hamiltonian_gate.py index 452210903777..d30f2fe1469f 100644 --- a/qiskit/extensions/hamiltonian_gate.py +++ b/qiskit/extensions/hamiltonian_gate.py @@ -18,7 +18,7 @@ import numpy import scipy.linalg -from qiskit.circuit import Gate, QuantumCircuit, QuantumRegister +from qiskit.circuit import Gate, QuantumCircuit, QuantumRegister, ParameterExpression from qiskit.quantum_info.operators.predicates import matrix_equal from qiskit.quantum_info.operators.predicates import is_hermitian_matrix from qiskit.extensions.exceptions import ExtensionError @@ -118,6 +118,8 @@ def validate_parameter(self, parameter): """Hamiltonian parameter has to be an ndarray, operator or float.""" if isinstance(parameter, (float, int, numpy.ndarray)): return parameter + elif isinstance(parameter, ParameterExpression) and len(parameter.parameters) == 0: + return float(parameter) else: raise CircuitError("invalid param type {0} for gate " "{1}".format(type(parameter), self.name)) diff --git a/test/python/circuit/test_parameters.py b/test/python/circuit/test_parameters.py index fc1895fea1da..e778022a22d8 100644 --- a/test/python/circuit/test_parameters.py +++ b/test/python/circuit/test_parameters.py @@ -491,7 +491,7 @@ def test_parameter_name_conflicts_raises(self): qc.p(theta1, 0) - self.assertRaises(CircuitError, qc.u1, theta2, 0) + self.assertRaises(CircuitError, qc.p, theta2, 0) def test_bind_ryrz_vector(self): """Test binding a list of floats to a ParameterVector""" @@ -903,9 +903,9 @@ def test_circuit_with_ufunc(self): theta = Parameter(name='theta') qc = QuantumCircuit(2) - qc.u1(numpy.cos(phi), 0) - qc.u1(numpy.sin(phi), 0) - qc.u1(numpy.tan(phi), 0) + qc.p(numpy.cos(phi), 0) + qc.p(numpy.sin(phi), 0) + qc.p(numpy.tan(phi), 0) qc.rz(numpy.arccos(theta), 1) qc.rz(numpy.arctan(theta), 1) qc.rz(numpy.arcsin(theta), 1) @@ -914,9 +914,9 @@ def test_circuit_with_ufunc(self): inplace=True) qc_ref = QuantumCircuit(2) - qc_ref.u1(-1, 0) - qc_ref.u1(0, 0) - qc_ref.u1(0, 0) + qc_ref.p(-1, 0) + qc_ref.p(0, 0) + qc_ref.p(0, 0) qc_ref.rz(0, 1) qc_ref.rz(pi / 4, 1) qc_ref.rz(pi / 2, 1) @@ -1016,7 +1016,7 @@ def test_raise_if_subbing_in_parameter_name_conflict(self): def test_expressions_of_parameter_with_constant(self): """Verify operating on a Parameter with a constant.""" - good_constants = [2, 1.3, 0, -1, -1.0, numpy.pi] + good_constants = [2, 1.3, 0, -1, -1.0, numpy.pi, 1j] x = Parameter('x') @@ -1025,7 +1025,7 @@ def test_expressions_of_parameter_with_constant(self): expr = op(const, x) bound_expr = expr.bind({x: 2.3}) - self.assertEqual(float(bound_expr), + self.assertEqual(complex(bound_expr), op(const, 2.3)) # Division by zero will raise. Tested elsewhere. @@ -1036,23 +1036,62 @@ def test_expressions_of_parameter_with_constant(self): expr = op(x, const) bound_expr = expr.bind({x: 2.3}) - self.assertEqual(float(bound_expr), + self.assertEqual(complex(bound_expr), op(2.3, const)) + def test_complex_parameter_bound_to_real(self): + """Test a complex parameter expression can be real if bound correctly.""" + + x, y = Parameter('x'), Parameter('y') + + with self.subTest('simple 1j * x'): + qc = QuantumCircuit(1) + qc.rx(1j * x, 0) + bound = qc.bind_parameters({x: 1j}) + ref = QuantumCircuit(1) + ref.rx(-1, 0) + self.assertEqual(bound, ref) + + with self.subTest('more complex expression'): + qc = QuantumCircuit(1) + qc.rx(0.5j * x - y * y + 2 * y, 0) + bound = qc.bind_parameters({x: -4, y: 1j}) + ref = QuantumCircuit(1) + ref.rx(1, 0) + self.assertEqual(bound, ref) + + def test_complex_angle_raises_when_not_supported(self): + """Test parameters are validated when fully bound and errors are raised accordingly.""" + x = Parameter('x') + qc = QuantumCircuit(1) + qc.r(x, 1j * x, 0) + + with self.subTest('binding x to 0 yields real parameters'): + bound = qc.bind_parameters({x: 0}) + ref = QuantumCircuit(1) + ref.r(0, 0, 0) + self.assertEqual(bound, ref) + + with self.subTest('binding x to 1 yields complex parameters'): + # RGate does not support complex parameters + with self.assertRaises(CircuitError): + bound = qc.bind_parameters({x: 1}) + def test_operating_on_a_parameter_with_a_non_float_will_raise(self): """Verify operations between a Parameter and a non-float will raise.""" - bad_constants = [1j, '1', numpy.Inf, numpy.NaN, None, {}, []] + bad_constants = ['1', numpy.Inf, numpy.NaN, None, {}, []] x = Parameter('x') for op in self.supported_operations: for const in bad_constants: - with self.assertRaises(TypeError): - _ = op(const, x) + with self.subTest(op=op, const=const): + with self.assertRaises(TypeError): + _ = op(const, x) - with self.assertRaises(TypeError): - _ = op(x, const) + with self.assertRaises(TypeError): + _ = op(x, const) def test_expressions_division_by_zero(self): """Verify dividing a Parameter by 0, or binding 0 as a denominator raises.""" @@ -1386,7 +1425,7 @@ def test_substituting_compound_expression(self): def test_conjugate(self): """Test calling conjugate on a ParameterExpression.""" x = Parameter('x') - self.assertEqual(x, x.conjugate()) # Parameters are real, therefore conjugate returns self + self.assertEqual((x.conjugate() + 1j), (x - 1j).conjugate()) @data(circlib.RGate, circlib.RXGate, circlib.RYGate, circlib.RZGate, circlib.RXXGate, circlib.RYYGate, circlib.RZXGate, circlib.RZZGate, circlib.CRXGate, circlib.CRYGate,