Skip to content

Commit

Permalink
updated plot_traces()
Browse files Browse the repository at this point in the history
more accurate and flexible legend
  • Loading branch information
CommonClimate committed Sep 11, 2024
1 parent 990b8b0 commit 5a428ae
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 10 deletions.
72 changes: 64 additions & 8 deletions pens/ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#from matplotlib import gridspec
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import matplotlib.collections as mcol
from matplotlib.legend_handler import HandlerLineCollection
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
import xarray as xr
Expand All @@ -23,6 +26,48 @@
import properscoring as ps
from more_itertools import distinct_combinations

class HandlerDashedLines(HandlerLineCollection): # helper class for plot_traces legend
"""
Custom Handler for LineCollection instances.
Credit: https://matplotlib.org/stable/gallery/text_labels_and_annotations/legend_demo.html
"""
def create_artists(self, legend, orig_handle,
xdescent, ydescent, width, height, fontsize, trans):
# figure out how many lines there are
numlines = len(orig_handle.get_segments())
xdata, xdata_marker = self.get_xdata(legend, xdescent, ydescent,
width, height, fontsize)
leglines = []
# divide the vertical space where the lines will go
# into equal parts based on the number of lines
ydata = np.full_like(xdata, height / (numlines + 1))
# for each line, create the line at the proper location
# and set the dash pattern
for i in range(numlines):
legline = Line2D(xdata, ydata * (numlines - i) - ydescent)
self.update_prop(legline, orig_handle, legend)
# set color, dash pattern, and linewidth to that
# of the lines in linecollection
try:
color = orig_handle.get_colors()[i]
except IndexError:
color = orig_handle.get_colors()[0]
try:
dashes = orig_handle.get_dashes()[i]
except IndexError:
dashes = orig_handle.get_dashes()[0]
try:
lw = orig_handle.get_linewidths()[i]
except IndexError:
lw = orig_handle.get_linewidths()[0]
if dashes[1] is not None:
legline.set_dashes(dashes[1])
legline.set_color(color)
legline.set_transform(trans)
legline.set_linewidth(lw)
leglines.append(legline)
return leglines


class EnsembleTS:
''' Ensemble Timeseries
Expand Down Expand Up @@ -1170,7 +1215,7 @@ def plot(self, figsize=[12, 4],
def plot_traces(self, num_traces = 5, figsize=[10, 4], title=None, label = None,
seed = None, indices = None, xlim=None, ylim=None,
linestyle='-', ax=None, plot_legend=True, lgd_kwargs=None,
xlabel=None, ylabel=None, color='grey', lw=0.5, alpha=0.1):
xlabel=None, ylabel=None, lw=0.5, alpha=0.1):
'''Plot EnsembleTS as a subset of traces.
Parameters
Expand Down Expand Up @@ -1302,14 +1347,23 @@ def plot_traces(self, num_traces = 5, figsize=[10, 4], title=None, label = None,
trace_idx = range(nts_max)
trace_lbl = label if label is not None else f'sample paths (n={num_traces})'


# define colors
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:num_traces]

# plot the traces
for idx in trace_idx:
for i, idx in enumerate(trace_idx):
ax.plot(self.time, self.value[:,idx], zorder=99, linewidth=lw,
color=color, alpha=alpha, linestyle='-')
# dummy plot for trace labels
ax.plot(np.nan, np.nan, color=color, alpha=alpha, linestyle='-',
label=trace_lbl)

color=colors[i], alpha=alpha, linestyle='-')

# make proxy artists
# make list of one line -- doesn't matter what the coordinates are
line = [[(0, 0)]]
# set up the proxy artist
nlines = np.min([num_traces, 5])
lc = mcol.LineCollection(nlines * line, colors=colors[:nlines],
alpha=alpha, linewidth=1.5*lw) # make slightly thicker to increase visibility

if xlabel is not None:
ax.set_xlabel(xlabel)
else:
Expand All @@ -1332,7 +1386,9 @@ def plot_traces(self, num_traces = 5, figsize=[10, 4], title=None, label = None,
if plot_legend:
lgd_args = {'frameon': False}
lgd_args.update(lgd_kwargs)
ax.legend(**lgd_args)
# create the legend
ax.legend([lc], [trace_lbl], handler_map={type(lc): HandlerDashedLines()},
handlelength=2.5, handleheight=3, **lgd_args)

if 'fig' in locals():
return fig, ax
Expand Down
2 changes: 0 additions & 2 deletions pens/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import numpy as np
from datetime import datetime, timedelta
import cftime
from termcolor import cprint
from scipy.special import rel_entr
import statsmodels.api as sm
Expand Down

0 comments on commit 5a428ae

Please sign in to comment.