Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add curves.Derivative #492

Merged
merged 31 commits into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fde1b05
Add first version of MNIST notebook
gtauzin Aug 1, 2020
88f636b
Extract first running example using plottling API
gtauzin Aug 1, 2020
1e88eaf
Updated content, missing form
gtauzin Aug 1, 2020
f2e8676
ix typo
gtauzin Aug 1, 2020
78e6d08
Improve form
gtauzin Aug 1, 2020
1c4d173
Implement first batch of comment
gtauzin Aug 30, 2020
da78912
Merge branch 'master' of github.com:giotto-ai/giotto-tda
gtauzin Aug 30, 2020
6bc29de
Add a few images as it is done in other notebooks
gtauzin Aug 30, 2020
46a53dc
Add derivative
Sep 14, 2020
eec9b43
Merge branch 'master' of github.com:gtauzin/giotto-tda
gtauzin Sep 15, 2020
d6d6680
Merge branch 'master' of github.com:giotto-ai/giotto-tda
gtauzin Sep 15, 2020
3345413
Merge branch 'master' into derivative
ulupo Sep 19, 2020
94d87a4
Add doc and fix docstrings
gtauzin Sep 20, 2020
67d31c9
Add doc entry
gtauzin Sep 21, 2020
0fb3ed2
Add curves.rst to index
gtauzin Sep 23, 2020
3e7ec3c
Add ValueError in case n_bins is too small
gtauzin Sep 23, 2020
541d71d
Make Derivative a PlotterMixin, fix tests, fix imports, fix linting, …
wreise Sep 24, 2020
65e8a26
Fix linting
gtauzin Sep 29, 2020
1c1971d
Merge branch 'master' into derivative
wreise Oct 4, 2020
3f31d70
Merge with remote master
gtauzin Oct 5, 2020
b385c70
Merge branch 'master' of github.com:giotto-ai/giotto-tda into derivative
gtauzin Oct 5, 2020
c820e10
Rever changes on MNIST notebook
gtauzin Oct 5, 2020
d8999d3
Update test according to @ulupo siggestions
gtauzin Oct 5, 2020
a343814
Merge branch 'master' into derivative
ulupo Oct 5, 2020
076b8c2
Fix doc index
ulupo Oct 5, 2020
780dd7f
Fix init
ulupo Oct 5, 2020
806d79b
Rename n_samplings to n_bins in StandardFeatures
ulupo Oct 5, 2020
b6fb34e
Make Derivative docs and input checks more consistent with StandardFe…
ulupo Oct 5, 2020
409a025
Cover case of None channels in tests
ulupo Oct 5, 2020
abc0f73
Add n_channels_ attribute, make channels=None mean only channels seen…
ulupo Oct 5, 2020
9266158
Cover ValueErrors for non-3d input
ulupo Oct 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions doc/modules/curves.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
:mod:`gtda.curves`: Curves
============================

.. automodule:: gtda.curves
   :no-members:
   :no-inherited-members:

Preprocessing
-------------
.. currentmodule:: gtda

.. autosummary::
   :toctree: generated/curves/preprocessing/
   :template: class.rst

   curves.Derivative
8 changes: 8 additions & 0 deletions gtda/curves/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""The module :mod:`gtda.curves` implements transformers to postprocess
curves."""

from .preprocessing import Derivative

# Names exported by ``from gtda.curves import *``.
__all__ = [
'Derivative',
]
99 changes: 99 additions & 0 deletions gtda/curves/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Preprocessing transformers for curves."""

from numbers import Real
from types import FunctionType

import numpy as np
from joblib import Parallel, delayed, effective_n_jobs
from sklearn.utils import gen_even_slices
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted, check_array

from ..utils._docs import adapt_fit_transform_docs
from ..utils.intervals import Interval
from ..utils.validation import validate_params


@adapt_fit_transform_docs
class Derivative(BaseEstimator, TransformerMixin):
    """Compute the derivative of multi-channel curves.

    Given a collection of multi-channel curves, compute channel by channel
    the discrete derivative of order `order`, i.e. the `order`-th successive
    difference between consecutive bins (:func:`numpy.diff` along the last
    axis). Each differentiation shortens the curves by one bin.

    Parameters
    ----------
    order : int, optional, default: ``1``
        Number of times the multi-channel curves are differentiated. Must be
        at least 1.

    n_jobs : int or None, optional, default: ``None``
        The number of jobs to use for the computation. ``None`` means 1 unless
        in a :obj:`joblib.parallel_backend` context. ``-1`` means using all
        processors.

    """
    _hyperparameters = {
        'order': {'type': int, 'in': Interval(1, np.inf, closed='left')},
    }

    def __init__(self, order=1, n_jobs=None):
        self.order = order
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """Validate the input and the hyperparameters, and return the
        estimator.

        Nothing is learned from `X`. This method is here to implement the
        usual scikit-learn API and hence work in pipelines.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_channels, n_bins)
            Input data. Collection of multi-channel curves.

        y : None
            There is no need for a target in a transformer, yet the pipeline
            API requires this parameter.

        Returns
        -------
        self : object

        Raises
        ------
        ValueError
            If the curves in `X` have `order` bins or fewer, in which case
            the derivative of order `order` would be empty.

        """
        X = check_array(X, allow_nd=True)
        validate_params(
            self.get_params(), self._hyperparameters, exclude=['n_jobs'])

        # np.diff consumes one bin per differentiation: without this check,
        # too-short curves would silently produce empty outputs in transform.
        n_bins = X.shape[-1]
        if self.order >= n_bins:
            raise ValueError(
                "Input channels have length {} but they must have at least "
                "length {} to calculate derivatives of order {}.".format(
                    n_bins, self.order + 1, self.order)
            )

        self._is_fitted = True

        return self

    def transform(self, X, y=None):
        """Compute the derivatives of the input multi-channel curves.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_channels, n_bins)
            Input collection of multi-channel curves.

        y : None
            There is no need for a target in a transformer, yet the pipeline
            API requires this parameter.

        Returns
        -------
        Xt : ndarray of shape (n_samples, n_channels, n_bins - order)
            Output collection of the multi-channel curves' derivatives.

        """
        check_is_fitted(self, '_is_fitted')
        Xt = check_array(X, allow_nd=True)

        # Split the samples into contiguous slices (at most one per effective
        # job) and differentiate each slice along the bin axis.
        Xt = Parallel(n_jobs=self.n_jobs)(
            delayed(np.diff)(Xt[s], n=self.order, axis=-1)
            for s in gen_even_slices(
                Xt.shape[0], effective_n_jobs(self.n_jobs))
        )
        # Stitch the per-slice results back together along the sample axis.
        Xt = np.concatenate(Xt)

        return Xt
1 change: 1 addition & 0 deletions gtda/curves/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

40 changes: 40 additions & 0 deletions gtda/curves/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Testing for curves preprocessing."""

import pytest
import numpy as np
import plotly.io as pio
from numpy.testing import assert_almost_equal
from sklearn.exceptions import NotFittedError
# The transformer lives in the ``gtda`` package (gtda/curves/__init__.py).
from gtda.curves import Derivative

# Make 'plotly_mimetype' the default plotly renderer for this test module.
pio.renderers.default = 'plotly_mimetype'

# Fixed seed so that X is reproducible and matches the hard-coded reference
# values the tests compare against.
np.random.seed(0)
# A single sample with two channels of five bins each.
X = np.random.rand(1, 2, 5)


def test_derivative_not_fitted():
    """Calling ``transform`` before ``fit`` must raise NotFittedError."""
    unfitted = Derivative()

    with pytest.raises(NotFittedError):
        unfitted.transform(X)


# Reference outputs for the fixture X, keyed by derivative order; compared
# against Derivative(order).fit_transform(X) in test_derivative_transform.
X_res = {
1: np.array([[[ 0.16637586, -0.11242599, -0.05788019, -0.12122838],
[-0.2083069, 0.45418579, 0.07188976, -0.58022124]]]),
2: np.array([[[-0.27880185, 0.0545458, -0.06334819],
[ 0.66249269, -0.38229603, -0.652111]]]),
}

@pytest.mark.parametrize('order', [1, 2])
def test_derivative_transform(order):
    """``fit_transform`` must reproduce the reference values for `order`."""
    transformer = Derivative(order)
    actual = transformer.fit_transform(X)
    expected = X_res[order]

    assert_almost_equal(actual, expected)


def test_consistent_fit_transform_plot():
    """Smoke test: plotting the output of ``fit_transform`` must not raise."""
    # NOTE(review): the Derivative class shown in this diff does not define
    # a ``plot`` method itself -- confirm that one is provided (e.g. via a
    # plotting mixin) before relying on this test.
    transformer = Derivative()
    derivatives = transformer.fit_transform(X)
    transformer.plot(derivatives)