Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fixing symbolic pulse equating for non-unique representations #9257

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions qiskit/pulse/library/symbolic_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import functools
import warnings
from typing import Any, Dict, List, Optional, Union, Callable
from typing import Any, Dict, List, Optional, Union, Callable, Tuple

import numpy as np

Expand Down Expand Up @@ -383,6 +383,8 @@ def Sawtooth(duration, amp, freq, name):
"_envelope",
"_constraints",
"_valid_amp_conditions",
"_canonical_params",
"_excluded_params",
)

# Lambdify caches keyed on sympy expressions. Returns the corresponding callable.
Expand All @@ -400,6 +402,8 @@ def __init__(
envelope: Optional[sym.Expr] = None,
constraints: Optional[sym.Expr] = None,
valid_amp_conditions: Optional[sym.Expr] = None,
canonical_params: Optional[List[Union[ParameterExpression, complex]]] = None,
excluded_params: Optional[Tuple[str]] = None,
):
"""Create a parametric pulse.

Expand All @@ -417,6 +421,11 @@ def __init__(
will investigate the full-waveform and raise an error when the amplitude norm
of any data point exceeds 1.0. If not provided, the validation always
creates a full-waveform.
canonical_params: List of parameters for the equating operation of symbolic
pulses. When two pulses are compared, the two lists have to be identical to
yield `True`.
excluded_params: Tuple of strings matching keys in `parameters` which are to be
ignored when two symbolic pulses are ignored.

Raises:
PulseError: When not all parameters are listed in the attribute :attr:`PARAM_DEF`.
Expand All @@ -436,6 +445,13 @@ def __init__(
self._constraints = constraints
self._valid_amp_conditions = valid_amp_conditions

if canonical_params is None:
canonical_params = []
self._canonical_params = canonical_params
if excluded_params is None:
excluded_params = ()
self._excluded_params = excluded_params

def __getattr__(self, item):
# Get pulse parameters with attribute-like access.
params = object.__getattribute__(self, "_params")
Expand Down Expand Up @@ -536,6 +552,31 @@ def parameters(self) -> Dict[str, Any]:
params.update(self._params)
return params

def _equate_parameters(self, other):
"""Helper function which compares the parameters of two pulses, taking into account
_canonical_params and _excluded_params."""
if len(self._canonical_params) != len(other._canonical_params):
return False

for param1, param2 in zip(self._canonical_params, other._canonical_params):
# Because the values are calculated, we need to compare to within numerical precision,
# and can't use a simple comparison of the lists.
if isinstance(param1, ParameterExpression) or isinstance(param2, ParameterExpression):
if param1 != param2:
return False
else:
if not np.isclose(param1, param2):
return False

if self.parameters.keys() != other.parameters.keys():
return False

for key in self.parameters:
if key not in self._excluded_params and self.parameters[key] != other.parameters[key]:
return False

return True

def __eq__(self, other: "SymbolicPulse") -> bool:

if not isinstance(other, SymbolicPulse):
Expand All @@ -547,8 +588,12 @@ def __eq__(self, other: "SymbolicPulse") -> bool:
if self._envelope != other._envelope:
return False

# _canonical_params is assumed to be a function of parameters. If parameters are the same,
# we don't need to check the _canonical_params. (Also solves the edge case of a pulse with
# no parameters)
if self.parameters != other.parameters:
return False
if not self._equate_parameters(other):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nitpicky, but the logic seems to me bit inefficient. It first fully evaluates the dict equality and then compares the canonicals and rest of dict items. Maybe just return self._equate_paramters(other) without if clause enough? Probably my suggestion is wrong because evaluation of builtin dict equality might be faster.

return False

return True

Expand Down Expand Up @@ -658,6 +703,8 @@ def __new__(
angle = 0

parameters = {"amp": amp, "sigma": sigma, "angle": angle}
canonical_params = [amp * np.exp(1j * angle)]
excluded_params = ("amp", "angle")

# Prepare symbolic expressions
_t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
Expand All @@ -679,6 +726,8 @@ def __new__(
envelope=envelope_expr,
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
canonical_params=canonical_params,
excluded_params=excluded_params,
)
instance.validate_parameters()

Expand Down Expand Up @@ -787,6 +836,8 @@ def __new__(
angle = 0

parameters = {"amp": amp, "sigma": sigma, "width": width, "angle": angle}
canonical_params = [amp * np.exp(1j * angle)]
excluded_params = ("amp", "angle")

# Prepare symbolic expressions
_t, _duration, _amp, _sigma, _width, _angle = sym.symbols(
Expand Down Expand Up @@ -820,6 +871,8 @@ def __new__(
envelope=envelope_expr,
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
canonical_params=canonical_params,
excluded_params=excluded_params,
)
instance.validate_parameters()

Expand Down Expand Up @@ -911,6 +964,8 @@ def __new__(
angle = 0

parameters = {"amp": amp, "sigma": sigma, "beta": beta, "angle": angle}
canonical_params = [amp * np.exp(1j * angle)]
excluded_params = ("amp", "angle")

# Prepare symbolic expressions
_t, _duration, _amp, _sigma, _beta, _angle = sym.symbols(
Expand All @@ -935,6 +990,8 @@ def __new__(
envelope=envelope_expr,
constraints=consts_expr,
valid_amp_conditions=valid_amp_conditions_expr,
canonical_params=canonical_params,
excluded_params=excluded_params,
)
instance.validate_parameters()

Expand Down Expand Up @@ -992,6 +1049,8 @@ def __new__(
angle = 0

parameters = {"amp": amp, "angle": angle}
canonical_params = [amp * np.exp(1j * angle)]
excluded_params = ("amp", "angle")

# Prepare symbolic expressions
_t, _amp, _duration, _angle = sym.symbols("t, amp, duration, angle")
Expand Down Expand Up @@ -1019,6 +1078,8 @@ def __new__(
limit_amplitude=limit_amplitude,
envelope=envelope_expr,
valid_amp_conditions=valid_amp_conditions_expr,
canonical_params=canonical_params,
excluded_params=excluded_params,
)
instance.validate_parameters()

Expand Down
7 changes: 7 additions & 0 deletions qiskit/pulse/parameter_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ def visit_SymbolicPulse(self, node: SymbolicPulse):
if isinstance(pval, ParameterExpression):
new_val = self._assign_parameter_expression(pval)
node._params[name] = new_val
# Assign canonical parameters
for i in range(len(node._canonical_params)):
pval = node._canonical_params[i]
if isinstance(pval, ParameterExpression):
new_val = self._assign_parameter_expression(pval)
node._canonical_params[i] = new_val

node.validate_parameters()

return node
Expand Down
26 changes: 26 additions & 0 deletions test/python/pulse/test_pulse_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
gaussian_square,
drag as pl_drag,
)
from qiskit.pulse import build, play, DriveChannel

from qiskit.pulse import functional_pulse, PulseError
from qiskit.test import QiskitTestCase
Expand Down Expand Up @@ -542,6 +543,31 @@ def local_gaussian(duration, amp, t0, sig):
pulse_wf_inst = local_gaussian(duration=_duration, amp=1, t0=5, sig=1)
self.assertEqual(len(pulse_wf_inst.samples), _duration)

def test_comparison_parameters(self):
"""Test equating of pulses with comparison_parameters."""
# amp,angle comparison for library pulses
gaussian_negamp = Gaussian(duration=25, sigma=4, amp=-0.5, angle=0)
gaussian_piphase = Gaussian(duration=25, sigma=4, amp=0.5, angle=np.pi)
self.assertEqual(gaussian_negamp, gaussian_piphase)

# Parameterized library pulses
amp = Parameter("amp")
gaussian1 = Gaussian(duration=25, sigma=4, amp=amp, angle=0)
gaussian2 = Gaussian(duration=25, sigma=4, amp=amp, angle=0)
self.assertEqual(gaussian1, gaussian2)

# pulses with different parameters
gaussian1._params["sigma"] = 10
self.assertNotEqual(gaussian1, gaussian2)

# Assignment of parameters (to verify computation of comparison_parameters)
angle = Parameter("angle")
with build() as sc:
play(Gaussian(duration=160, amp=amp, sigma=40, angle=angle), DriveChannel(0))
sc_piphase = sc.assign_parameters({amp: 1, angle: np.pi}, inplace=False)
sc_negamp = sc.assign_parameters({amp: -1, angle: 0}, inplace=False)
self.assertEqual(sc_piphase, sc_negamp)


if __name__ == "__main__":
unittest.main()