Dynamically updating plots in Jupyter notebooks, e.g. for visualizing training progress. Inspired by livelossplot, trainplot aims to be easier to use and to offer better Jupyter notebook support.
pip install trainplot
This is a simple example (example notebook):
from trainplot import plot
from time import sleep

for i in range(50):
    # each call adds the current value of every named metric to the live plot
    plot(loss=1/(i+1), acc=1-1/(.01*i**2+1))
    sleep(.1)
Example for the tf/keras callback (example notebook):
from trainplot import TrainPlotKeras

model = ...  # your compiled tf/keras model; x_train, y_train, x_test, y_test are your data
model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks=[TrainPlotKeras()])
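During fit, Keras calls hook methods such as on_epoch_end(epoch, logs) on every callback, passing that epoch's metrics (e.g. loss, val_loss) in logs. The following is a minimal sketch of how such a callback could forward those metrics to a plot function. ForwardLogsCallback is a hypothetical name and this is not necessarily how TrainPlotKeras is implemented; the sketch subclasses keras.callbacks.Callback when Keras is importable and falls back to object so it also runs without Keras.

```python
try:
    import keras
    _Base = keras.callbacks.Callback
except ImportError:  # lets the sketch run without Keras installed
    _Base = object


class ForwardLogsCallback(_Base):
    """Hypothetical sketch (not TrainPlotKeras itself): forward epoch logs to a plot function."""

    def __init__(self, plot_fn):
        super().__init__()
        self.plot_fn = plot_fn  # any function accepting metrics as keyword arguments

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the epoch's metrics (loss, val_loss, ...) in `logs`
        self.plot_fn(epoch=epoch, **(logs or {}))
```

With Keras installed, an instance could be passed via callbacks=[ForwardLogsCallback(plot_fn)] just like TrainPlotKeras() above.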
It also works together with other outputs, e.g. tqdm.notebook progress bars and print statements (example notebook):
from trainplot import plot
from tqdm.notebook import trange
from time import sleep

for i in trange(50):
    plot(i=i, root=i**.5)
    if i % 10 == 0:
        print(f'currently at {i} iterations')
    sleep(0.1)
For further customization, use a TrainPlot object directly (example notebook):
from trainplot import TrainPlot
from time import sleep

tp = TrainPlot(
    update_period=.2,
    fig_args=dict(nrows=2, ncols=2, figsize=(10, 8), gridspec_kw={'height_ratios': [1, 1], 'width_ratios': [1, 1]}),
    plot_pos={'loss': (0, 0, 0), 'accuracy': (0, 1, 0), 'val_loss': (1, 0, 0), 'val_accuracy': (1, 1, 0)},
    plot_args={
        'loss': {'color': 'orange'},
        'accuracy': {'color': 'green'},
        'val_loss': {'color': 'orange', 'label': 'validation loss'},
        'val_accuracy': {'color': 'green', 'label': 'validation accuracy'},
    },
)

for i in range(100, 200):
    tp(step=i, loss=(i/100-2)**4, accuracy=i/2, val_loss=(i/100-2.1)**4, val_accuracy=i/2.1)
    sleep(0.1)
More:
- When using a TrainPlot object, you can also put the plot in a separate cell from the training loop (example notebook).
- Experimental plotly support via from trainplot.trainplot import TrainPlotPlotlyExperimental (example notebook).
Trainplot renders the matplotlib figure into an ipywidgets.Output widget, so it doesn't interfere with other outputs such as tqdm progress bars or print statements. To avoid wasting resources and flickering, the figure is redrawn at most once per update_period.
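To illustrate the throttling idea, here is a minimal sketch in plain Python. It is not trainplot's code: the class ThrottledPlot, its redraw callback and the injectable clock are hypothetical, and update_period is assumed to be in seconds. New values are always buffered, but the expensive redraw only runs once update_period has elapsed since the previous one; flush() forces a final redraw of anything still pending.

```python
import time


class ThrottledPlot:
    """Hypothetical sketch: buffer every value, redraw at most once per update_period."""

    def __init__(self, redraw, update_period=0.5, clock=time.monotonic):
        self.redraw = redraw              # callback that actually re-renders the figure
        self.update_period = update_period
        self.clock = clock                # injectable so the timing can be tested
        self.data = {}                    # metric name -> list of values
        self._last_draw = None
        self._dirty = False               # True if data changed since the last redraw

    def __call__(self, **values):
        for name, value in values.items():
            self.data.setdefault(name, []).append(value)
        self._dirty = True
        now = self.clock()
        # redraw immediately the first time, then only after update_period has passed
        if self._last_draw is None or now - self._last_draw >= self.update_period:
            self._draw(now)

    def flush(self):
        """Force a final redraw so the figure shows all buffered data."""
        if self._dirty:
            self._draw(self.clock())

    def _draw(self, now):
        self.redraw(self.data)
        self._last_draw = now
        self._dirty = False
```

Because every value is buffered and only intermediate frames are skipped, no data is lost; a final flush() makes the last frame complete.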
A post_run_cell callback is registered on the IPython instance, so that when a cell finishes executing, every updated TrainPlot figure includes all new data.
When you use trainplot.plot, a TrainPlot object is created for the current cell and its execution count.
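A hedged sketch of that per-cell lookup. As a simplifying assumption it keys plot objects only on IPython's execution_count (0 outside IPython), whereas trainplot uses the cell together with its execution count; SimplePlot, plot_for_cell and current_cell_key are hypothetical stand-ins, not trainplot's API.

```python
from collections import defaultdict

# Hypothetical registry: one plot object per cell execution.
_plots_by_cell = {}


class SimplePlot:
    """Hypothetical stand-in for a TrainPlot object: records values per metric."""

    def __init__(self):
        self.data = defaultdict(list)

    def __call__(self, **values):
        for name, value in values.items():
            self.data[name].append(value)


def current_cell_key():
    """Return IPython's execution counter, or 0 outside IPython (assumed key)."""
    try:
        from IPython import get_ipython
    except ImportError:
        return 0
    ip = get_ipython()
    return ip.execution_count if ip is not None else 0


def plot_for_cell(cell_key, **values):
    """Create the plot for cell_key on first use, then append the new values."""
    tp = _plots_by_cell.get(cell_key)
    if tp is None:
        tp = _plots_by_cell[cell_key] = SimplePlot()
    tp(**values)
    return tp


def plot(**values):
    """Sketch of a plot() shortcut that routes values to the current cell's plot."""
    return plot_for_cell(current_cell_key(), **values)
```

In this sketch, re-running a cell produces a new execution count and therefore a fresh plot, instead of appending to the figure from the previous run.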