Skip to content

Commit

Permalink
Merge pull request #45 from qutech/feature/small_plotting_improvements
Browse files Browse the repository at this point in the history
Small plotting improvements
  • Loading branch information
thangleiter authored Jan 4, 2021
2 parents 54a40c5 + b88a14e commit 3a3bdbc
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 19 deletions.
3 changes: 1 addition & 2 deletions filter_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@

from . import analytic, basis, numeric, pulse_sequence, superoperator, util
from .basis import Basis

from .gradient import infidelity_derivative
from .numeric import error_transfer_matrix, infidelity
from .pulse_sequence import PulseSequence, concatenate, concatenate_periodic, extend, remap
from .superoperator import liouville_representation
from .gradient import infidelity_derivative

__all__ = ['Basis', 'PulseSequence', 'analytic', 'basis', 'concatenate', 'concatenate_periodic',
'error_transfer_matrix', 'extend', 'infidelity', 'liouville_representation', 'numeric',
Expand Down
76 changes: 62 additions & 14 deletions filter_functions/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
from numpy import ndarray

from . import numeric, util
from .types import (Axes, Coefficients, Colormap, Figure, FigureAxes,
FigureAxesLegend, FigureGrid, Grid, Operator, State)
from .types import (Axes, Coefficients, Colormap, Figure, FigureAxes, FigureAxesLegend, FigureGrid,
Grid, Operator, State)

__all__ = ['plot_cumulant_function', 'plot_infidelity_convergence', 'plot_filter_function',
'plot_pulse_correlation_filter_function', 'plot_pulse_train']
Expand All @@ -68,6 +68,24 @@
qt = mock.Mock()


def _make_str_tex_compatible(s: str) -> str:
"""Escape incompatible characters in strings passed to TeX."""
if not plt.rcParams['text.usetex']:
return s

s = str(s)
incompatible = ('_',)
for char in incompatible:
locs = [i for i, c in enumerate(s) if c == char]
# Loop backwards so as not to change locs when modifying s
for loc in locs[::-1]:
# Check if math environment, if not add escape character
if not s.count('$', loc) % 2:
s = s[:loc] + '\\' + s[loc:]

return s


def get_bloch_vector(states: Sequence[State]) -> ndarray:
r"""
Get the Bloch vector from quantum states.
Expand Down Expand Up @@ -247,6 +265,7 @@ def plot_pulse_train(
c_oper_identifiers: Optional[Sequence[int]] = None,
fig: Optional[Figure] = None,
axes: Optional[Axes] = None,
cycler: Optional['cycler.Cycler'] = None,
plot_kw: Optional[dict] = {},
subplot_kw: Optional[dict] = None,
gridspec_kw: Optional[dict] = None,
Expand All @@ -267,6 +286,9 @@ def plot_pulse_train(
A matplotlib figure instance to plot in
axes: matplotlib axes, optional
A matplotlib axes instance to use for plotting.
cycler: cycler.Cycler, optional
A Cycler instance used to set the style cycle if multiple lines
are to be drawn
plot_kw: dict, optional
Dictionary with keyword arguments passed to the plot function
subplot_kw: dict, optional
Expand Down Expand Up @@ -307,10 +329,14 @@ def plot_pulse_train(
elif fig is None and axes is not None:
fig = axes.figure

if cycler is not None:
axes.set_prop_cycle(cycler)

handles = []
for i, c_coeffs in enumerate(pulse.c_coeffs[tuple(c_oper_inds), ...]):
coeffs = np.insert(c_coeffs, 0, c_coeffs[0])
handles += axes.step(pulse.t, coeffs, label=c_oper_identifiers[i], **plot_kw)
handles += axes.step(pulse.t, coeffs,
label=_make_str_tex_compatible(c_oper_identifiers[i]), **plot_kw)

axes.set_xlim(pulse.t[0], pulse.tau)
axes.set_xlabel(r'$t$ / a.u.')
Expand All @@ -330,6 +356,7 @@ def plot_filter_function(
xscale: str = 'log',
yscale: str = 'linear',
omega_in_units_of_tau: bool = True,
cycler: Optional['cycler.Cycler'] = None,
plot_kw: dict = {},
subplot_kw: Optional[dict] = None,
gridspec_kw: Optional[dict] = None,
Expand Down Expand Up @@ -363,6 +390,9 @@ def plot_filter_function(
y-axis scaling. One of ('linear', 'log').
omega_in_units_of_tau: bool, optional
Plot :math:`\omega\tau` or just :math:`\omega` on x-axis.
cycler: cycler.Cycler, optional
A Cycler instance used to set the style cycle if multiple lines
are to be drawn
plot_kw: dict, optional
Dictionary with keyword arguments passed to the plot function
subplot_kw: dict, optional
Expand Down Expand Up @@ -409,6 +439,9 @@ def plot_filter_function(
elif fig is None and axes is not None:
fig = axes.figure

if cycler is not None:
axes.set_prop_cycle(cycler)

if omega_in_units_of_tau:
tau = np.ptp(pulse.t)
z = omega*tau
Expand All @@ -423,7 +456,8 @@ def plot_filter_function(
handles = []
for i, ind in enumerate(n_oper_inds):
handles += axes.plot(z, filter_function[ind],
label=n_oper_identifiers[i], **plot_kw)
label=_make_str_tex_compatible(n_oper_identifiers[i]),
**plot_kw)

# Set the axis scales
axes.set_xscale(xscale)
Expand Down Expand Up @@ -452,6 +486,7 @@ def plot_pulse_correlation_filter_function(
xscale: str = 'log',
yscale: str = 'linear',
omega_in_units_of_tau: bool = True,
cycler: Optional['cycler.Cycler'] = None,
plot_kw: dict = {},
subplot_kw: Optional[dict] = None,
gridspec_kw: Optional[dict] = None,
Expand Down Expand Up @@ -483,6 +518,9 @@ def plot_pulse_correlation_filter_function(
y-axis scaling. One of ('linear', 'log').
omega_in_units_of_tau: bool, optional
Plot :math:`\omega\tau` or just :math:`\omega` on x-axis.
cycler: cycler.Cycler, optional
A Cycler instance used to set the style cycle if multiple lines
are to be drawn in one subplot. Used for all subplots.
plot_kw: dict, optional
Dictionary with keyword arguments passed to the plot function
subplot_kw: dict, optional
Expand Down Expand Up @@ -546,10 +584,13 @@ def plot_pulse_correlation_filter_function(
dashed_line = lines.Line2D([], [], color='gray', linestyle='--')
for i in range(n):
for j in range(n):
if cycler is not None:
axes[i, j].set_prop_cycle(cycler)

handles = []
for k, ind in enumerate(n_oper_inds):
handles += axes[i, j].plot(z, F_pc[i, j, ind].real,
label=n_oper_identifiers[k],
label=_make_str_tex_compatible(n_oper_identifiers[k]),
**plot_kw)
if i != j:
axes[i, j].plot(z, F_pc[i, j, ind].imag, linestyle='--',
Expand All @@ -566,7 +607,8 @@ def plot_pulse_correlation_filter_function(

if i == 0 and j == n-1:
handles += [transparent_line, solid_line, dashed_line]
labels = n_oper_identifiers.tolist() + ['', r'$Re$', r'$Im$']
labels = ([_make_str_tex_compatible(n) for n in n_oper_identifiers]
+ ['', r'$Re$', r'$Im$'])
legend = axes[i, j].legend(handles=handles, labels=labels,
bbox_to_anchor=(1.05, 1), loc=2,
borderaxespad=0., frameon=False)
Expand Down Expand Up @@ -628,11 +670,12 @@ def plot_cumulant_function(
omega: Optional[Coefficients] = None,
cumulant_function: Optional[ndarray] = None,
n_oper_identifiers: Optional[Sequence[int]] = None,
basis_labels: Optional[Sequence[str]] = None,
colorscale: str = 'linear',
linthresh: Optional[float] = None,
cbar_label: str = 'Cumulant Function',
basis_labels: Optional[Sequence[str]] = None,
basis_labelsize: Optional[int] = None,
cbar_label: str = 'Cumulant Function',
cbar_labelsize: Optional[int] = None,
fig: Optional[Figure] = None,
grid: Optional[Grid] = None,
cmap: Optional[Colormap] = None,
Expand Down Expand Up @@ -669,18 +712,20 @@ def plot_cumulant_function(
The identifiers of the noise operators for which the cumulant
function should be plotted. All identifiers can be accessed via
``pulse.n_oper_identifiers``. Defaults to all.
basis_labels: array_like (str), optional
Labels for the elements of the cumulant function (the basis
elements).
colorscale: str, optional
The scale of the color code ('linear' or 'log' (default))
linthresh: float, optional
The threshold below which the colorscale will be linear (only
for 'log') colorscale
cbar_label: str, optional
The label for the colorbar. Default: 'Cumulant Function'.
basis_labels: array_like (str), optional
Labels for the elements of the cumulant function (the basis
elements).
basis_labelsize: int, optional
The size in points for the basis labels.
cbar_label: str, optional
The label for the colorbar. Default: 'Cumulant Function'.
cbar_labelsize: int, optional
The size in points for the colorbar label.
fig: matplotlib figure, optional
A matplotlib figure instance to plot in
grid: matplotlib ImageGrid, optional
Expand Down Expand Up @@ -752,6 +797,8 @@ def plot_cumulant_function(
if len(basis_labels) != K.shape[-1]:
raise ValueError('Invalid number of basis_labels given')

basis_labels = [_make_str_tex_compatible(bl) for bl in basis_labels]

if grid is None:
aspect_ratio = 2/3
n_rows = int(np.round(np.sqrt(aspect_ratio*len(n_oper_inds))))
Expand Down Expand Up @@ -799,6 +846,7 @@ def plot_cumulant_function(
imshow_kw.setdefault('norm', norm)

basis_labelsize = basis_labelsize or 8
cbar_labelsize = cbar_labelsize or plt.rcParams['axes.labelsize']

# Draw the images
for i, n_oper_identifier in enumerate(n_oper_identifiers):
Expand All @@ -818,6 +866,6 @@ def plot_cumulant_function(
cbar_kw = cbar_kw or {}
cbar_kw.setdefault('orientation', 'vertical')
cbar = fig.colorbar(im, cax=grid.cbar_axes[0], **cbar_kw)
cbar.set_label(cbar_label)
cbar.set_label(_make_str_tex_compatible(cbar_label), fontsize=cbar_labelsize)

return fig, grid
2 changes: 1 addition & 1 deletion filter_functions/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from numpy import linalg as nla
from numpy import ndarray

from . import numeric, util, gradient
from . import gradient, numeric, util
from .basis import Basis, equivalent_pauli_basis_elements, remap_pauli_basis_elements
from .superoperator import liouville_representation
from .types import Coefficients, Hamiltonian, Operator, PulseMapping
Expand Down
1 change: 1 addition & 0 deletions tests/gradient_testutil.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np

import filter_functions as ff

sigma_x = np.asarray([[0, 1], [1, 0]]) / 2
Expand Down
3 changes: 2 additions & 1 deletion tests/test_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class MissingExtrasTest(testutil.TestCase):
'fancy_progressbar' in os.environ.get('INSTALL_EXTRAS', all_extras),
reason='Skipping tests for missing fancy progressbar extra in build with requests') # noqa
def test_fancy_progressbar_not_available(self):
from filter_functions import util
from tqdm import tqdm

from filter_functions import util
self.assertEqual(util._NOTEBOOK_NAME, '')
self.assertIs(tqdm, util._tqdm)

Expand Down
1 change: 0 additions & 1 deletion tests/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import tests.gradient_testutil as grad_util
from tests import testutil


np.random.seed(0)
initial_pulse = np.random.rand(grad_util.n_time_steps)
initial_pulse = np.expand_dims(initial_pulse, 0)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
This module tests the plotting functionality of the package.
"""
import string
from copy import copy
from random import sample

import numpy as np
Expand All @@ -38,6 +39,7 @@
reason='Skipping plotting tests for build without matplotlib')
if plotting is not None:
import matplotlib.pyplot as plt
from matplotlib import cycler

simple_pulse = testutil.rand_pulse_sequence(2, 1, 1, 1, btype='Pauli')
complicated_pulse = testutil.rand_pulse_sequence(2, 100, 3, 3)
Expand Down Expand Up @@ -67,6 +69,10 @@ def test_plot_pulse_train(self):
c_oper_identifiers,
fig=fig, axes=ax)

# Test cycler arg
cycle = cycler(color=['r', 'g', 'b'])
fig, ax, leg = plotting.plot_pulse_train(simple_pulse, cycler=cycle)

# invalid identifier
with self.assertRaises(ValueError):
plotting.plot_pulse_train(complicated_pulse,
Expand Down Expand Up @@ -113,6 +119,10 @@ def test_plot_filter_function(self):
fig=fig, axes=ax, omega_in_units_of_tau=False
)

# Test cycler arg
cycle = cycler(color=['r', 'g', 'b'])
fig, ax, leg = plotting.plot_filter_function(simple_pulse, cycler=cycle)

# invalid identifier
with self.assertRaises(ValueError):
plotting.plot_filter_function(complicated_pulse,
Expand Down Expand Up @@ -177,6 +187,11 @@ def test_plot_pulse_correlation_filter_function(self):
omega_in_units_of_tau=False
)

# Test cycler arg
cycle = cycler(color=['r', 'g', 'b'])
fig, ax, leg = plotting.plot_pulse_correlation_filter_function(concatenated_simple_pulse,
cycler=cycle)

# invalid identifiers
with self.assertRaises(ValueError):
plotting.plot_pulse_correlation_filter_function(
Expand Down Expand Up @@ -299,6 +314,17 @@ def spectrum(omega):
fig, ax = plotting.plot_infidelity_convergence(n, infids)


class LaTeXRenderingTest(testutil.TestCase):

def test_plot_filter_function(self):
pulse = copy(simple_pulse)
pulse.c_oper_identifiers = np.array([f'B_{i}' for i in range(len(pulse.c_opers))])
pulse.n_oper_identifiers = np.array([f'B_{i}' for i in range(len(pulse.n_opers))])
with plt.rc_context(rc={'text.usetex': True}):
_ = plotting.plot_pulse_train(pulse)
_ = plotting.plot_filter_function(pulse)


@pytest.mark.skipif(
qutip is None,
reason='Skipping bloch sphere visualization tests for build without qutip')
Expand Down

0 comments on commit 3a3bdbc

Please sign in to comment.