Skip to content

Change Point Interactive #988

Merged
merged 7 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Add `plot_change_points_interactive` ([#988](https://github.com/tinkoff-ai/etna/pull/988))
-
-
-
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from etna.analysis.plotters import plot_anomalies_interactive
from etna.analysis.plotters import plot_backtest
from etna.analysis.plotters import plot_backtest_interactive
from etna.analysis.plotters import plot_change_points_interactive
from etna.analysis.plotters import plot_clusters
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_feature_relevance
Expand Down
139 changes: 139 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import plotly.graph_objects as go
import seaborn as sns
from matplotlib.lines import Line2D
from ruptures.base import BaseCost
from ruptures.base import BaseEstimator
from ruptures.exceptions import BadSegmentationParameters
from scipy.signal import periodogram
from typing_extensions import Literal

Expand Down Expand Up @@ -1819,3 +1822,139 @@ def metric_per_segment_distribution_plot(

plt.title("Metric per-segment distribution plot")
plt.grid()


def plot_change_points_interactive(
    ts,
    change_point_model: BaseEstimator,
    model: BaseCost,
    params_bounds: Dict[str, Tuple[Union[int, float], Union[int, float], Union[int, float]]],
    model_params: List[str],
    predict_params: List[str],
    in_column: str = "target",
    segments: Optional[List[str]] = None,
    columns_num: int = 2,
    figsize: Tuple[int, int] = (10, 5),
    start: Optional[str] = None,
    end: Optional[str] = None,
):
    """Plot a time series with indicated change points.

    Change points are obtained using the specified method. The method parameters values
    can be changed using the corresponding sliders.

    Parameters
    ----------
    ts:
        TSDataset with timeseries data
    change_point_model:
        change point detection algorithm; it is called as
        ``change_point_model(model=model, **params)`` to build the detector,
        so a class such as ``ruptures.detection.Binseg`` is expected
    model:
        cost model forwarded to ``change_point_model`` as its ``model`` argument,
        e.g. ["l1", "l2", "rbf",...]
    params_bounds:
        Parameters ranges of the change points detection. Bounds for the parameter are (min,max,step).
        A slider is created for every key; it is a float slider if any of the three bounds is a float
        and an integer slider otherwise
    model_params:
        List of parameter names (keys of ``params_bounds``) passed to ``change_point_model``
        when the detector is initialized
    predict_params:
        List of parameter names (keys of ``params_bounds``) passed to the ``predict`` method
    in_column:
        column to plot
    segments:
        segments to use; all segments of ``ts`` (sorted) are used if not set
    columns_num:
        number of subplots columns
    figsize:
        size of the figure in inches
    start:
        start timestamp for plot
    end:
        end timestamp for plot

    Notes
    -----
    Jupyter notebook might display the results incorrectly,
    in this case try to use ``!jupyter nbextension enable --py widgetsnbextension``.

    If the detector raises ``BadSegmentationParameters`` for a segment, a "Parameters Error" box
    is drawn on the subplot of this segment instead of the series.

    Examples
    --------
    >>> from etna.datasets import TSDataset
    >>> from etna.datasets import generate_ar_df
    >>> from etna.analysis import plot_change_points_interactive
    >>> from ruptures.detection import Binseg
    >>> classic_df = generate_ar_df(periods=1000, start_time="2021-08-01", n_segments=2)
    >>> df = TSDataset.to_dataset(classic_df)
    >>> ts = TSDataset(df, "D")
    >>> params_bounds = {"n_bkps": [0, 5, 1], "min_size":[1,10,3]}
    >>> plot_change_points_interactive(ts=ts, change_point_model=Binseg, model="l2", params_bounds=params_bounds, model_params=["min_size"], predict_params=["n_bkps"], figsize=(20, 10)) # doctest: +SKIP
    """
    from ipywidgets import FloatSlider
    from ipywidgets import IntSlider
    from ipywidgets import interact

    if segments is None:
        segments = sorted(ts.segments)

    # Maps a combination of slider values to {segment: breakpoints or None}.
    cache = {}

    sliders = dict()
    style = {"description_width": "initial"}
    for param, bounds in params_bounds.items():
        min_, max_, step = bounds
        # A single float bound is enough to require a float slider.
        if any(isinstance(bound, float) for bound in (min_, max_, step)):
            slider_type = FloatSlider
        else:
            slider_type = IntSlider
        sliders[param] = slider_type(min=min_, max=max_, step=step, continuous_update=False, style=style)

    def update(**kwargs):
        """Redraw all segments for the current slider values given in ``kwargs``."""
        _, ax = prepare_axes(num_plots=len(segments), columns_num=columns_num, figsize=figsize)

        # Key on (name, value) pairs so every combination of slider values is unambiguous.
        key = tuple(kwargs.items())
        segment_cache = cache.setdefault(key, {})

        m_params = {name: kwargs[name] for name in model_params}
        p_params = {name: kwargs[name] for name in predict_params}

        for i, segment in enumerate(segments):
            ax[i].cla()
            segment_df = ts[start:end, segment, :][segment]
            timestamp = segment_df.index.values
            target = segment_df[in_column].values

            # Check the cache per segment: if an earlier run for this key was interrupted
            # by an unexpected error, the missing segments are computed now instead of
            # failing with KeyError on a partially filled entry.
            if segment not in segment_cache:
                try:
                    algo = change_point_model(model=model, **m_params).fit(signal=target)
                    bkps = algo.predict(**p_params)
                except BadSegmentationParameters:
                    segment_cache[segment] = None
                else:
                    # Prepend 1 so that consecutive pairs describe every interval,
                    # including the one that starts at the first observation.
                    segment_cache[segment] = [1, *bkps]

            segment_bkps = segment_cache[segment]

            if segment_bkps is not None:
                for left, right in zip(segment_bkps[:-1], segment_bkps[1:]):
                    # Stored values are 1-based positions; convert to array indices.
                    bkp = left - 1
                    start_time = timestamp[bkp]
                    end_time = timestamp[right - 1]
                    selected_indices = (timestamp >= start_time) & (timestamp <= end_time)
                    ax[i].plot(timestamp[selected_indices], target[selected_indices])
                    # The very first point is the series start, not a change point.
                    if bkp != 0:
                        ax[i].axvline(timestamp[bkp], linestyle="dashed", c="grey")
            else:
                box = {"facecolor": "grey", "edgecolor": "red", "boxstyle": "round"}
                # Axes coordinates keep the box centred whatever the data limits are.
                ax[i].text(
                    0.5,
                    0.4,
                    "Parameters\nError",
                    bbox=box,
                    horizontalalignment="center",
                    color="white",
                    fontsize=50,
                    transform=ax[i].transAxes,
                )
            ax[i].set_title(segment)
            ax[i].tick_params("x", rotation=45)
        plt.show()

    interact(update, **sliders)
140 changes: 128 additions & 12 deletions examples/EDA.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ We have prepared a set of tutorials for an easy introduction:
- Outliers
- Median method
- Density method
- Change Points
- Change points plot
- Interactive change points plot
#### 04. [Outliers](https://github.com/tinkoff-ai/etna/tree/master/examples/outliers.ipynb)
- Point outliers
- Median method
Expand Down