Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add curves.Derivative #492

Merged
merged 31 commits into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fde1b05
Add first version of MNIST notebook
gtauzin Aug 1, 2020
88f636b
Extract first running example using plottling API
gtauzin Aug 1, 2020
1e88eaf
Updated content, missing form
gtauzin Aug 1, 2020
f2e8676
ix typo
gtauzin Aug 1, 2020
78e6d08
Improve form
gtauzin Aug 1, 2020
1c4d173
Implement first batch of comment
gtauzin Aug 30, 2020
da78912
Merge branch 'master' of github.com:giotto-ai/giotto-tda
gtauzin Aug 30, 2020
6bc29de
Add a few images as it is done in other notebooks
gtauzin Aug 30, 2020
46a53dc
Add derivative
Sep 14, 2020
eec9b43
Merge branch 'master' of github.com:gtauzin/giotto-tda
gtauzin Sep 15, 2020
d6d6680
Merge branch 'master' of github.com:giotto-ai/giotto-tda
gtauzin Sep 15, 2020
3345413
Merge branch 'master' into derivative
ulupo Sep 19, 2020
94d87a4
Add doc and fix docstrings
gtauzin Sep 20, 2020
67d31c9
Add doc entry
gtauzin Sep 21, 2020
0fb3ed2
Add curves.rst to index
gtauzin Sep 23, 2020
3e7ec3c
Add ValueError in case n_bins is too small
gtauzin Sep 23, 2020
541d71d
Make Derivative a PlotterMixin, fix tests, fix imports, fix linting, …
wreise Sep 24, 2020
65e8a26
Fix linting
gtauzin Sep 29, 2020
1c1971d
Merge branch 'master' into derivative
wreise Oct 4, 2020
3f31d70
Merge with remote master
gtauzin Oct 5, 2020
b385c70
Merge branch 'master' of github.com:giotto-ai/giotto-tda into derivative
gtauzin Oct 5, 2020
c820e10
Rever changes on MNIST notebook
gtauzin Oct 5, 2020
d8999d3
Update test according to @ulupo siggestions
gtauzin Oct 5, 2020
a343814
Merge branch 'master' into derivative
ulupo Oct 5, 2020
076b8c2
Fix doc index
ulupo Oct 5, 2020
780dd7f
Fix init
ulupo Oct 5, 2020
806d79b
Rename n_samplings to n_bins in StandardFeatures
ulupo Oct 5, 2020
b6fb34e
Make Derivative docs and input checks more consistent with StandardFe…
ulupo Oct 5, 2020
409a025
Cover case of None channels in tests
ulupo Oct 5, 2020
abc0f73
Add n_channels_ attribute, make channels=None mean only channels seen…
ulupo Oct 5, 2020
9266158
Cover ValueErrors for non-3d input
ulupo Oct 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/modules/curves.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
:no-members:
:no-inherited-members:

Preprocessing
-------------
.. currentmodule:: gtda

.. autosummary::
:toctree: generated/curves/preprocessing/
:template: class.rst

curves.Derivative

Feature extraction
------------------
.. currentmodule:: gtda
Expand Down
4 changes: 3 additions & 1 deletion gtda/curves/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""The module :mod:`gtda.curves` implements transformers to postprocess
curves."""

from .preprocessing import Derivative
from .features import StandardFeatures

__all__ = [
'StandardFeatures'
"Derivative",
"StandardFeatures"
]
19 changes: 9 additions & 10 deletions gtda/curves/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
class StandardFeatures(BaseEstimator, TransformerMixin):
"""Standard features from multi-channel curves.

A multi-channel (sampled) curve is a 2D array of shape
``(n_channels, n_samplings)``, where each row represents the y-values in
one of channels. This transformer applies scalar or vector-valued functions
channel-wise to extract features from a collection of multi-channel curves,
of shape ``(n_samples, n_channels, n_samplings)``. The output is always a
2D array such that row ``i`` is the concatenation of the outputs of the
chosen functions on the channels in the ``i``-th (multi-)curve in the
collection.
A multi-channel (integer sampled) curve is a 2D array of shape
``(n_channels, n_bins)``, where each row represents the y-values in one of
the channels. This transformer applies scalar or vector-valued functions
channel-wise to extract features from each multi-channel curve in a
collection. The output is always a 2D array such that row ``i`` is the
concatenation of the outputs of the chosen functions on the channels in the
``i``-th (multi-)curve in the collection.

Parameters
----------
Expand Down Expand Up @@ -133,7 +132,7 @@ def fit(self, X, y=None):

Parameters
----------
X : ndarray of shape (n_samples, n_channels, n_samplings)
X : ndarray of shape (n_samples, n_channels, n_bins)
Input data. Collection of multi-channel curves.

y : None
Expand Down Expand Up @@ -213,7 +212,7 @@ def transform(self, X, y=None):

Parameters
----------
X : ndarray of shape (n_samples, n_channels, n_samplings)
X : ndarray of shape (n_samples, n_channels, n_bins)
Input collection of multi-channel curves.

y : None
Expand Down
199 changes: 199 additions & 0 deletions gtda/curves/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Preprocessing transformers for curves."""
# License: GNU AGPLv3

import numpy as np
from joblib import Parallel, delayed, effective_n_jobs
from plotly.graph_objs import Figure, Scatter
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import gen_even_slices
from sklearn.utils.validation import check_is_fitted, check_array

from ..base import PlotterMixin
from ..utils._docs import adapt_fit_transform_docs
from ..utils.intervals import Interval
from ..utils.validation import validate_params


@adapt_fit_transform_docs
class Derivative(BaseEstimator, TransformerMixin, PlotterMixin):
"""Derivatives of multi-channel curves.

A multi-channel (integer sampled) curve is a 2D array of shape
``(n_channels, n_bins)``, where each row represents the y-values in one of
the channels. This transformer computes the n-th order derivative of each
channel in each multi-channel curve in a collection, by discrete
differences. The output is another collection of multi-channel curves.

Parameters
----------
order : int, optional, default: ``1``
Order of the derivative to be taken.

n_jobs : int or None, optional, default: ``None``
The number of jobs to use for the computation. ``None`` means 1 unless
in a :obj:`joblib.parallel_backend` context. ``-1`` means using all
processors.

Attributes
----------
n_channels_ : int
Number of channels present in the 3D array passed to :meth:`fit`.

"""
_hyperparameters = {
'order': {'type': int, 'in': Interval(1, np.inf, closed='left')},
}

def __init__(self, order=1, n_jobs=None):
self.order = order
self.n_jobs = n_jobs

def fit(self, X, y=None):
"""Compute :attr:`n_channels_`. Then, return the estimator.

This function is here to implement the usual scikit-learn API and hence
work in pipelines.

Parameters
----------
X : ndarray of shape (n_samples, n_channels, n_bins)
Input data. Collection of multi-channel curves.

y : None
There is no need for a target in a transformer, yet the pipeline
API requires this parameter.

Returns
-------
self : object

"""
check_array(X, ensure_2d=False, allow_nd=True)
if X.ndim != 3:
raise ValueError("Input must be 3-dimensional.")
validate_params(
self.get_params(), self._hyperparameters, exclude=['n_jobs'])

n_bins = X.shape[2]
if self.order >= n_bins:
raise ValueError(
f"Input channels have length {n_bins} but they must have at "
f"least length {self.order + 1} to calculate derivatives of "
f"order {self.order}."
)

self.n_channels_ = X.shape[1]

return self

def transform(self, X, y=None):
"""Compute derivatives of multi-channel curves.

Parameters
----------
X : ndarray of shape (n_samples, n_channels, n_bins)
Input collection of multi-channel curves.

y : None
There is no need for a target in a transformer, yet the pipeline
API requires this parameter.

Returns
-------
Xt : ndarray of shape (n_samples, n_channels, n_bins - order)
Output collection of multi-channel curves given by taking discrete
differences of order `order` in each channel in the curves in `X`.

"""
check_is_fitted(self)
Xt = check_array(X, ensure_2d=False, allow_nd=True)
if Xt.ndim != 3:
raise ValueError("Input must be 3-dimensional.")

Xt = Parallel(n_jobs=self.n_jobs)(
delayed(np.diff)(Xt[s], n=self.order, axis=-1)
for s in gen_even_slices(len(Xt), effective_n_jobs(self.n_jobs))
)
Xt = np.concatenate(Xt)

return Xt

def plot(self, Xt, sample=0, channels=None, plotly_params=None):
"""Plot a sample from a collection of derivatives of multi-channel
curves arranged as in the output of :meth:`transform`.

Parameters
----------
Xt : ndarray of shape (n_samples, n_channels, n_bins)
Collection of multi-channel curves, such as returned by
:meth:`transform`.

sample : int, optional, default: ``0``
Index of the sample in `Xt` to be plotted.

channels : list, tuple or None, optional, default: ``None``
Which channels to include in the plot. ``None`` means plotting the
first :attr:`n_channels_` channels.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"traces"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
check_is_fitted(self)

layout_axes_common = {
"type": "linear",
"ticks": "outside",
"showline": True,
"zeroline": True,
"linewidth": 1,
"linecolor": "black",
"mirror": False,
"showexponent": "all",
"exponentformat": "e"
}
layout = {
"xaxis1": {
"title": "Sample",
"side": "bottom",
"anchor": "y1",
**layout_axes_common
},
"yaxis1": {
"title": "Derivative",
"side": "left",
"anchor": "x1",
**layout_axes_common
},
"plot_bgcolor": "white",
"title": f"Derivative of sample {sample}"
}

fig = Figure(layout=layout)

if channels is None:
channels = range(self.n_channels_)

samplings = np.arange(Xt[sample].shape[0])
for ix, channel in enumerate(channels):
fig.add_trace(Scatter(x=samplings,
y=Xt[sample][ix],
mode="lines",
showlegend=True,
name=f"Channel {channel}"))

# Update traces and layout according to user input
if plotly_params:
fig.update_traces(plotly_params.get("traces", None))
fig.update_layout(plotly_params.get("layout", None))

return fig
65 changes: 65 additions & 0 deletions gtda/curves/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Testing for curves preprocessing."""

import pytest
import numpy as np
import plotly.io as pio
from numpy.testing import assert_almost_equal
from sklearn.exceptions import NotFittedError
from gtda.curves import Derivative

pio.renderers.default = 'plotly_mimetype'
line_plots_traces_params = {"mode": "lines+markers"}
layout_params = {"title": "New title"}
plotly_params = \
{"traces": line_plots_traces_params, "layout": layout_params}


np.random.seed(0)
X = np.random.rand(1, 2, 5)


def test_derivative_not_fitted():
d = Derivative()

with pytest.raises(NotFittedError):
d.transform(X)


def test_derivative_big_order():
d = Derivative(order=5)

with pytest.raises(ValueError):
d.fit(X)


@pytest.mark.parametrize("shape", [(2,), (2, 3), (2, 3, 4, 5)])
def test_standard_invalid_shape(shape):
sf = Derivative()

with pytest.raises(ValueError, match="Input must be 3-dimensional."):
sf.fit(np.ones(shape))

with pytest.raises(ValueError, match="Input must be 3-dimensional."):
sf.fit(X).transform(np.ones(shape))


X_res = {
1: np.array([[[0.16637586, -0.11242599, -0.05788019, -0.12122838],
[-0.2083069, 0.45418579, 0.07188976, -0.58022124]]]),
2: np.array([[[-0.27880185, 0.0545458, -0.06334819],
[0.66249269, -0.38229603, -0.652111]]]),
}


@pytest.mark.parametrize('order', [1, 2])
def test_derivative_transform(order):
d = Derivative(order)

assert_almost_equal(d.fit_transform(X), X_res[order])


@pytest.mark.parametrize("channels", [None, [1], [0, 1]])
def test_consistent_fit_transform_plot(channels):
d = Derivative()
Xt = d.fit_transform(X)
d.plot(Xt, channels=channels, plotly_params=plotly_params)