Skip to content

Commit

Permalink
add plot_timeline
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Jun 13, 2023
1 parent 85ef7cb commit 140daaf
Show file tree
Hide file tree
Showing 173 changed files with 6,071 additions and 245 deletions.
169 changes: 164 additions & 5 deletions atom/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import reduce
from importlib.util import find_spec
from itertools import chain, cycle
Expand All @@ -30,6 +31,7 @@
TrigramCollocationFinder,
)
from optuna.importance import FanovaImportanceEvaluator
from optuna.trial import TrialState
from optuna.visualization._parallel_coordinate import (
_get_dims_from_info, _get_parallel_coordinate_info,
)
Expand All @@ -43,13 +45,14 @@
confusion_matrix, det_curve, precision_recall_curve, roc_curve,
)
from sklearn.utils import _safe_indexing
from sklearn.utils._bunch import Bunch
from sklearn.utils.metaestimators import available_if

from atom.utils import (
FLOAT, INT, INT_TYPES, PALETTE, SCALAR, SEQUENCE, Model, bk, check_canvas,
check_dependency, check_hyperparams, check_predict_proba, composed, crash,
divide, get_best_score, get_corpus, get_custom_scorer, has_attr, has_task,
is_binary, is_multioutput, it, lst, plot_from_model, rnd, to_rgb,
divide, flt, get_best_score, get_corpus, get_custom_scorer, has_attr,
has_task, is_binary, is_multioutput, it, lst, plot_from_model, rnd, to_rgb,
)


Expand Down Expand Up @@ -3598,8 +3601,8 @@ def plot_terminator_improvement(
See Also
--------
atom.plots:HTPlot.plot_edf
atom.plots:HTPlot.plot_pareto_front
atom.plots:HTPlot.plot_timeline
atom.plots:HTPlot.plot_trials
Examples
Expand All @@ -3623,6 +3626,8 @@ def plot_terminator_improvement(
"""
check_dependency("botorch")

models = check_hyperparams(models, "plot_terminator_improvement")

fig = self._get_figure()
xaxis, yaxis = BasePlot._fig.get_axes()
for m in models:
Expand Down Expand Up @@ -3661,6 +3666,160 @@ def plot_terminator_improvement(
display=display,
)

@composed(crash, plot_from_model)
def plot_timeline(
self,
models: INT | str | Model | slice | SEQUENCE | None = None,
*,
title: str | dict | None = None,
legend: str | dict | None = "lower right",
figsize: tuple[INT, INT] = (900, 600),
filename: str | None = None,
display: bool | None = True,
) -> go.Figure | None:
"""Plot the timeline of a study.
This plot is only available for models that ran
[hyperparameter tuning][].
Parameters
----------
models: int, str, Model, slice, sequence or None, default=None
Models to plot. If None, all models that used hyperparameter
tuning are selected.
title: str, dict or None, default=None
Title for the plot.
- If None, no title is shown.
- If str, text for the title.
- If dict, [title configuration][parameters].
legend: str, dict or None, default="lower right",
Legend for the plot. See the [user guide][parameters] for
an extended description of the choices.
- If None: No legend is shown.
- If str: Location where to show the legend.
- If dict: Legend configuration.
figsize: tuple, default=(900, 600)
Figure's size in pixels, format as (x, y)
filename: str or None, default=None
Save the plot using this name. Use "auto" for automatic
naming. The type of the file depends on the provided name
(.html, .png, .pdf, etc...). If `filename` has no file type,
the plot is saved as html. If None, the plot is not saved.
display: bool or None, default=True
Whether to render the plot. If None, it returns the figure.
Returns
-------
[go.Figure][] or None
Plot object. Only returned if `display=None`.
See Also
--------
atom.plots:HTPlot.plot_edf
atom.plots:HTPlot.plot_slice
atom.plots:HTPlot.plot_terminator_improvement
Examples
--------
```pycon
>>> from atom import ATOMClassifier
>>> from optuna.pruners import PatientPruner
>>> X = pd.read_csv("./examples/datasets/weatherAUS.csv")
>>> atom = ATOMClassifier(X, y="RainTomorrow", n_rows=1e4)
>>> atom.impute()
>>> atom.encode()
>>> atom.run(
... models="LGB",
... n_trials=15,
... ht_params={"pruner": PatientPruner(None, patience=100)},
... )
>>> atom.plot_timeline()
```
:: insert:
url: /img/plots/plot_timeline.html
"""
models = check_hyperparams(models, "plot_timeline")

fig = self._get_figure()
xaxis, yaxis = BasePlot._fig.get_axes()

_cm = {
"COMPLETE": BasePlot._fig._palette[0], # Main color
"FAIL": "rgb(255, 0, 0)", # Red
"PRUNED": "rgb(255, 165, 0)", # Orange
"RUNNING": "rgb(124, 252, 0)", # Green
"WAITING": "rgb(220, 220, 220)", # Gray
}

for m in models:
info = []
for trial in m.study.get_trials(deepcopy=False):
date_complete = trial.datetime_complete or datetime.now()
date_start = trial.datetime_start or date_complete
params = "<br>".join([f" --> {k}: {v}" for k, v in trial.params.items()])
info.append(
Bunch(
number=trial.number,
start=date_start,
duration=1000 * (date_complete - date_start).total_seconds(),
state=trial.state,
hovertext=(
f"Trial: {trial.number}<br>"
f"Value: {flt(trial.values)}<br>"
f"Parameters:<br>{params}"
)
)
)

for state in sorted(TrialState, key=lambda x: x.name):
if bars := list(filter(lambda x: x.state == state, info)):
fig.add_trace(
go.Bar(
name=state.name,
x=[b.duration for b in bars],
y=[b.number for b in bars],
base=[b.start.isoformat() for b in bars],
text=[b.hovertext for b in bars],
textposition="none",
hovertemplate=f"%{{text}}<extra>{m.name}</extra>",
orientation="h",
marker=dict(
color=f"rgba({_cm[state.name][4:-1]}, 0.2)",
line=dict(width=2, color=_cm[state.name]),
),
showlegend=BasePlot._fig.showlegend(_cm[state.name], legend),
xaxis=xaxis,
yaxis=yaxis,
)
)

fig.update_layout({f"xaxis{yaxis[1:]}_type": "date", "barmode": "group"})

BasePlot._fig.used_models.extend(models)
return self._plot(
ax=(f"xaxis{xaxis[1:]}", f"yaxis{yaxis[1:]}"),
xlabel="Datetime",
ylabel="Trial",
title=title,
legend=legend,
figsize=figsize,
plotname="plot_timeline",
filename=filename,
display=display,
)

@composed(crash, plot_from_model)
def plot_trials(
self,
Expand Down Expand Up @@ -5806,6 +5965,7 @@ def plot_permutation_importance(
y=list(np.array([[fx] * n_repeats for fx in m.features]).ravel()),
marker_color=BasePlot._fig.get_elem(m.name),
boxpoints="outliers",
orientation="h",
name=m.name,
legendgroup=m.name,
showlegend=BasePlot._fig.showlegend(m.name, legend),
Expand All @@ -5814,7 +5974,6 @@ def plot_permutation_importance(
)
)

fig.update_traces(orientation="h")
fig.update_layout(
{
f"yaxis{yaxis[1:]}": dict(categoryorder="total ascending"),
Expand Down Expand Up @@ -6760,6 +6919,7 @@ def get_std(model: Model, metric: int) -> SCALAR:
y=list(y),
marker_color=color,
boxpoints="outliers",
orientation="h",
name=name,
legendgroup=name,
showlegend=BasePlot._fig.showlegend(name, legend),
Expand Down Expand Up @@ -6790,7 +6950,6 @@ def get_std(model: Model, metric: int) -> SCALAR:
)
)

fig.update_traces(orientation="h")
fig.update_layout(
{
f"yaxis{yaxis[1:]}": dict(categoryorder="total ascending"),
Expand Down
14 changes: 14 additions & 0 deletions docs/404.html
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,20 @@



<li class="md-nav__item">
<a href="/ATOM/API/plots/plot_timeline/" class="md-nav__link">
plot_timeline
</a>
</li>









<li class="md-nav__item">
<a href="/ATOM/API/plots/plot_threshold/" class="md-nav__link">
plot_threshold
Expand Down
20 changes: 17 additions & 3 deletions docs/API/ATOM/atomclassifier/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -2606,6 +2606,20 @@



<li class="md-nav__item">
<a href="../../plots/plot_timeline/" class="md-nav__link">
plot_timeline
</a>
</li>









<li class="md-nav__item">
<a href="../../plots/plot_threshold/" class="md-nav__link">
plot_threshold
Expand Down Expand Up @@ -3792,7 +3806,7 @@ <h2 id="utility-methods">Utility methods</h2>
</ul>
</table>
<p><br><br></p>
<p><a id='atomclassifier-canvas'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>canvas</strong>(rows=1, cols=2, horizontal_spacing=0.05, vertical_spacing=0.07, title=None, legend="out", figsize=None, filename=None, display=True)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L981>[source]</a></span></div>Create a figure with multiple plots.</p>
<p><a id='atomclassifier-canvas'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>canvas</strong>(rows=1, cols=2, horizontal_spacing=0.05, vertical_spacing=0.07, title=None, legend="out", figsize=None, filename=None, display=True)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L986>[source]</a></span></div>Create a figure with multiple plots.</p>
<p>This <code>@contextmanager</code> allows you to draw many plots in one
figure. The default option is to add two plots side by side.
See the <a class="autorefs autorefs-internal" href="../../../user_guide/plots/#canvas">user guide</a> for an example.</p>
Expand Down Expand Up @@ -4121,7 +4135,7 @@ <h2 id="utility-methods">Utility methods</h2>
</div></td></tr></p>
</table>
<p><br><br></p>
<p><a id='atomclassifier-update_layout'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>update_layout</strong>(dict1=None, overwrite=False, **kwargs)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1086>[source]</a></span></div>Update the properties of the plot's layout.</p>
<p><a id='atomclassifier-update_layout'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>update_layout</strong>(dict1=None, overwrite=False, **kwargs)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1091>[source]</a></span></div>Update the properties of the plot's layout.</p>
<p>This recursively updates the structure of the original layout
with the values in the input dict / keyword arguments.</p>
<table class="table_params">
Expand All @@ -4141,7 +4155,7 @@ <h2 id="utility-methods">Utility methods</h2>
<p>Deletes all branches and models. The dataset is also reset
to its form after initialization.</p>
<p><br><br></p>
<p><a id='atomclassifier-reset_aesthetics'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>reset_aesthetics</strong>()<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1074>[source]</a></span></div>Reset the plot <a class="autorefs autorefs-internal" href="../../../user_guide/plots/#aesthetics">aesthetics</a> to their default values.</p>
<p><a id='atomclassifier-reset_aesthetics'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>reset_aesthetics</strong>()<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1079>[source]</a></span></div>Reset the plot <a class="autorefs autorefs-internal" href="../../../user_guide/plots/#aesthetics">aesthetics</a> to their default values.</p>
<p><br><br></p>
<p><a id='atomclassifier-save'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>save</strong>(filename="auto", save_data=True)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/basetransformer.py#L937>[source]</a></span></div>Save the instance to a pickle file.</p>
<table class="table_params">
Expand Down
14 changes: 14 additions & 0 deletions docs/API/ATOM/atommodel/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -2523,6 +2523,20 @@



<li class="md-nav__item">
<a href="../../plots/plot_timeline/" class="md-nav__link">
plot_timeline
</a>
</li>









<li class="md-nav__item">
<a href="../../plots/plot_threshold/" class="md-nav__link">
plot_threshold
Expand Down
20 changes: 17 additions & 3 deletions docs/API/ATOM/atomregressor/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -2606,6 +2606,20 @@



<li class="md-nav__item">
<a href="../../plots/plot_timeline/" class="md-nav__link">
plot_timeline
</a>
</li>









<li class="md-nav__item">
<a href="../../plots/plot_threshold/" class="md-nav__link">
plot_threshold
Expand Down Expand Up @@ -3776,7 +3790,7 @@ <h2 id="utility-methods">Utility methods</h2>
</ul>
</table>
<p><br><br></p>
<p><a id='atomregressor-canvas'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>canvas</strong>(rows=1, cols=2, horizontal_spacing=0.05, vertical_spacing=0.07, title=None, legend="out", figsize=None, filename=None, display=True)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L981>[source]</a></span></div>Create a figure with multiple plots.</p>
<p><a id='atomregressor-canvas'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>canvas</strong>(rows=1, cols=2, horizontal_spacing=0.05, vertical_spacing=0.07, title=None, legend="out", figsize=None, filename=None, display=True)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L986>[source]</a></span></div>Create a figure with multiple plots.</p>
<p>This <code>@contextmanager</code> allows you to draw many plots in one
figure. The default option is to add two plots side by side.
See the <a class="autorefs autorefs-internal" href="../../../user_guide/plots/#canvas">user guide</a> for an example.</p>
Expand Down Expand Up @@ -4105,7 +4119,7 @@ <h2 id="utility-methods">Utility methods</h2>
</div></td></tr></p>
</table>
<p><br><br></p>
<p><a id='atomregressor-update_layout'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>update_layout</strong>(dict1=None, overwrite=False, **kwargs)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1086>[source]</a></span></div>Update the properties of the plot's layout.</p>
<p><a id='atomregressor-update_layout'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>update_layout</strong>(dict1=None, overwrite=False, **kwargs)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1091>[source]</a></span></div>Update the properties of the plot's layout.</p>
<p>This recursively updates the structure of the original layout
with the values in the input dict / keyword arguments.</p>
<table class="table_params">
Expand All @@ -4125,7 +4139,7 @@ <h2 id="utility-methods">Utility methods</h2>
<p>Deletes all branches and models. The dataset is also reset
to its form after initialization.</p>
<p><br><br></p>
<p><a id='atomregressor-reset_aesthetics'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>reset_aesthetics</strong>()<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1074>[source]</a></span></div>Reset the plot <a class="autorefs autorefs-internal" href="../../../user_guide/plots/#aesthetics">aesthetics</a> to their default values.</p>
<p><a id='atomregressor-reset_aesthetics'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>reset_aesthetics</strong>()<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/plots.py#L1079>[source]</a></span></div>Reset the plot <a class="autorefs autorefs-internal" href="../../../user_guide/plots/#aesthetics">aesthetics</a> to their default values.</p>
<p><br><br></p>
<p><a id='atomregressor-save'></a><div class='sign'><em>method</em> <strong style='color:#008AB8'>save</strong>(filename="auto", save_data=True)<span style='float:right'><a href=https://github.com/tvdboom/ATOM/blob/master/atom/basetransformer.py#L937>[source]</a></span></div>Save the instance to a pickle file.</p>
<table class="table_params">
Expand Down
Loading

0 comments on commit 140daaf

Please sign in to comment.