From 736d70163619ceec942b4e1375f37b27a9e39094 Mon Sep 17 00:00:00 2001 From: Daniel Puzzuoli Date: Fri, 14 Jul 2023 10:05:22 -0400 Subject: [PATCH] Add warning if digital carrier exceeds Nyquist frequency in pulse -> signal conversion (#242) --- qiskit_dynamics/pulse/pulse_to_signals.py | 32 ++++++++++++++++++-- test/dynamics/pulse/test_pulse_to_signals.py | 13 ++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/qiskit_dynamics/pulse/pulse_to_signals.py b/qiskit_dynamics/pulse/pulse_to_signals.py index 1377d5941..9ac451295 100644 --- a/qiskit_dynamics/pulse/pulse_to_signals.py +++ b/qiskit_dynamics/pulse/pulse_to_signals.py @@ -16,6 +16,7 @@ from typing import Callable, Dict, List, Optional import functools +from warnings import warn import numpy as np import sympy as sym @@ -40,6 +41,12 @@ from qiskit_dynamics.array import Array from qiskit_dynamics.signals import DiscreteSignal +try: + import jax + import jax.numpy as jnp +except ImportError: + pass + class InstructionToSignals: """Converts pulse instructions to signals to be used in models. @@ -133,6 +140,9 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: Similarly to ``ShiftFrequency``, the shift rule for :math:`\phi_a` is defined to maintain carrier wave continuity. + If, at any sample point :math:`k`, :math:`\Delta\nu(k)` is larger than the Nyquist sampling + rate given by ``dt``, a warning will be raised. + Args: schedule: The schedule to represent in terms of signals. Instances of :class:`~qiskit.pulse.ScheduleBlock` must first be converted to @@ -188,14 +198,15 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: if isinstance(inst, ShiftPhase): phases[chan] += inst.phase + if isinstance(inst, SetPhase): + phases[chan] = inst.phase + if isinstance(inst, ShiftFrequency): frequency_shifts[chan] = frequency_shifts[chan] + Array(inst.frequency) phase_accumulations[chan] = ( phase_accumulations[chan] - inst.frequency * start_sample * self._dt ) - - if isinstance(inst, SetPhase): - phases[chan] = inst.phase + _nyquist_warn(frequency_shifts[chan], self._dt, chan) if isinstance(inst, SetFrequency): phase_accumulations[chan] = phase_accumulations[chan] - ( @@ -204,6 +215,7 @@ def get_signals(self, schedule: Schedule) -> List[DiscreteSignal]: * self._dt ) frequency_shifts[chan] = inst.frequency - signals[chan].carrier_freq + _nyquist_warn(frequency_shifts[chan], self._dt, chan) # ensure all signals have the same number of samples max_duration = 0 @@ -367,3 +379,17 @@ def _lru_cache_expr(expr: sym.Expr, backend) -> Callable: continue params.append(param) return sym.lambdify(params, expr, modules=backend) + + +def _nyquist_warn(frequency_shift: Array, dt: float, channel: str): + """Raise a warning if the frequency shift is above the Nyquist frequency given by ``dt``.""" + + if ( + Array(frequency_shift).backend != "jax" or not isinstance(jnp.array(0), jax.core.Tracer) + ) and np.abs(frequency_shift) > 0.5 / dt: + warn( + "Due to SetFrequency and ShiftFrequency instructions, the digital carrier frequency " + f"of channel {channel} is larger than the Nyquist frequency of the envelope sample " + "size dt. As shifts of the frequency from the analog frequency are handled digitally, " + "this will result in aliasing effects." + ) diff --git a/test/dynamics/pulse/test_pulse_to_signals.py b/test/dynamics/pulse/test_pulse_to_signals.py index 8652ac60a..47822d68b 100644 --- a/test/dynamics/pulse/test_pulse_to_signals.py +++ b/test/dynamics/pulse/test_pulse_to_signals.py @@ -39,6 +39,19 @@ def setUp(self): # Typical length of samples in units of dt in IBM real backends is 1/4.5. self._dt = 1 / 4.5 + def test_nyquist_warning(self): + """Test Nyquist warning is raised.""" + converter = InstructionToSignals(dt=1, carriers={"d0": 0.0}) + + sched = Schedule(name="Schedule") + sched += pulse.SetFrequency(1.0, pulse.DriveChannel(0)) + sched += pulse.Play( + pulse.Drag(duration=20, amp=0.5, sigma=4, beta=0.5), pulse.DriveChannel(0) + ) + + with self.assertWarnsRegex(Warning, "Due to SetFrequency and ShiftFrequency"): + converter.get_signals(sched) + def test_pulse_to_signals(self): """Generic test."""