Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove incorrect assumptions in Filtering #436

Merged
merged 9 commits into from
Aug 1, 2020
83 changes: 60 additions & 23 deletions gtda/diagrams/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ def _subdiagrams(X, homology_dimensions, remove_dim=False):
list of homology dimensions. It is assumed that all diagrams in X contain
the same number of points in each homology dimension."""
n = len(X)
Xs = np.concatenate([X[X[:, :, 2] == dim].reshape(n, -1, 3)
for dim in homology_dimensions],
axis=1)
if len(homology_dimensions) == 1:
Xs = X[X[:, :, 2] == homology_dimensions[0]].reshape(n, -1, 3)
else:
Xs = np.concatenate([X[X[:, :, 2] == dim].reshape(n, -1, 3)
for dim in homology_dimensions],
axis=1)
if remove_dim:
Xs = Xs[:, :, :2]
return Xs
Expand All @@ -25,38 +28,72 @@ def _pad(X, max_diagram_sizes):
return X_padded


def _sort(Xs):
indices = np.argsort(Xs[:, :, 1] - Xs[:, :, 0], axis=1)
indices = np.stack([indices, indices, indices], axis=2)
Xs = np.flip(np.take_along_axis(Xs, indices, axis=1), axis=1)
return Xs


def _sample_image(image, sampled_diag):
# NOTE: Modifies `image` in-place
unique, counts = np.unique(sampled_diag, axis=0, return_counts=True)
unique = tuple(tuple(row) for row in unique.astype(np.int).T)
image[unique] = counts


def _filter(Xs, filtered_homology_dimensions, cutoff):
homology_dimensions = sorted(list(set(Xs[0, :, 2])))
unfiltered_homology_dimensions = sorted(list(
set(homology_dimensions) - set(filtered_homology_dimensions)))
def _multirange(counts):
"""Given a 1D array of positive integers, generate an array equal to
np.concatenate([np.arange(c) for c in counts]), but in a faster and more
memory-efficient way."""
cumsum = np.cumsum(counts)
reset_index = cumsum[:-1]
incr = np.ones(cumsum[-1], dtype=np.int32)
incr[0] = 0

# For each index in reset_index, we insert the negative value necessary
# to offset the cumsum in the last line
incr[reset_index] = 1 - counts[:-1]
incr.cumsum(out=incr)

return incr


def _filter(X, filtered_homology_dimensions, cutoff):
n = len(X)
homology_dimensions = sorted(list(set(X[0, :, 2])))
unfiltered_homology_dimensions = [dim for dim in homology_dimensions if
dim not in filtered_homology_dimensions]

if len(unfiltered_homology_dimensions) == 0:
Xf = np.empty((Xs.shape[0], 0, 3), dtype=Xs.dtype)
Xuf = np.empty((n, 0, 3), dtype=X.dtype)
else:
Xf = _subdiagrams(Xs, unfiltered_homology_dimensions)
Xuf = _subdiagrams(X, unfiltered_homology_dimensions)

# Compute a global 2D cutoff mask once
cutoff_mask = X[:, :, 1] - X[:, :, 0] > cutoff
Xf = []
for dim in filtered_homology_dimensions:
Xdim = _subdiagrams(Xs, [dim])
min_value = np.min(Xdim[:, :, 0])
mask = (Xdim[:, :, 1] - Xdim[:, :, 0]) <= cutoff
Xdim[mask, :] = [min_value, min_value, dim]
max_points = np.max(np.sum(Xs[:, :, 1] != 0, axis=1))
Xdim = Xdim[:, :max_points, :]
Xf = np.concatenate([Xf, Xdim], axis=1)
# Compute a 2D mask for persistence pairs in dimension dim
dim_mask = X[:, :, 2] == dim
# Need the indices relative to X of persistence triples in dimension
# dim surviving the cutoff
indices = np.nonzero(np.logical_and(dim_mask, cutoff_mask))
if not indices[0].size:
Xdim = np.tile([0., 0., dim], (n, 1, 1))
else:
# A unique element k is repeated N times *consecutively* in
# indices[0] iff there are exactly N valid persistence triples
# in the k-th diagram
unique, counts = np.unique(indices[0], return_counts=True)
max_n_points = np.max(counts)
# Make a global 2D array of all valid triples
X_indices = X[indices]
min_value = np.min(X_indices[:, 0]) # For padding
# Initialise the array of filtered subdiagrams in dimension m
Xdim = np.tile([min_value, min_value, dim], (n, max_n_points, 1))
# Since repeated indices in indices[0] are consecutive and we know
# the counts per unique index, we can fill the top portion of
# each 2D array entry of Xdim with the filtered triples from the
# corresponding entry of X
Xdim[indices[0], _multirange(counts)] = X_indices
Xf.append(Xdim)

Xf.append(Xuf)
Xf = np.concatenate(Xf, axis=1)
return Xf


Expand Down
30 changes: 19 additions & 11 deletions gtda/diagrams/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.utils.validation import check_is_fitted

from ._metrics import _AVAILABLE_AMPLITUDE_METRICS, _parallel_amplitude
from ._utils import _sort, _filter, _bin, _calculate_weights
from ._utils import _filter, _bin, _calculate_weights
from ..base import PlotterMixin
from ..plotting.persistence_diagrams import plot_diagram
from ..utils._docs import adapt_fit_transform_docs
Expand Down Expand Up @@ -314,18 +314,21 @@ class Filtering(BaseEstimator, TransformerMixin, PlotterMixin):
"""Filtering of persistence diagrams.

Filtering a diagram means discarding all points [b, d, q] representing
topological features whose lifetime d - b is less than or equal to a
cutoff value. Technically, discarded points are replaced by points on the
diagonal (i.e. whose birth and death values coincide), which carry no
information.
non-trivial topological features whose lifetime d - b is less than or
equal to a cutoff value. Points on the diagonal (i.e. for which b and d
are equal) may still appear in the output for padding purposes, but carry
no information.

Input collections of persistence diagrams for this transformer must
satisfy certain requirements, see e.g. :meth:`fit`.

Parameters
----------
homology_dimensions : list, tuple, or None, optional, default: ``None``
When set to ``None``, subdiagrams corresponding to all homology
dimensions seen in :meth:`fit` will be filtered.
Otherwise, it contains the homology dimensions (as non-negative
integers) at which filtering should occur.
dimensions seen in :meth:`fit` will be filtered. Otherwise, it contains
the homology dimensions (as non-negative integers) at which filtering
should occur.

epsilon : float, optional, default: ``0.01``
The cutoff value controlling the amount of filtering.
Expand Down Expand Up @@ -368,6 +371,9 @@ def fit(self, X, y=None):
Input data. Array of persistence diagrams, each a collection of
triples [b, d, q] representing persistent topological features
through their birth (b), death (d) and homology dimension (q).
It is important that, for each possible homology dimension, the
gtauzin marked this conversation as resolved.
Show resolved Hide resolved
number of triples for which q equals that homology dimension is
constants across the entries of `X`.

y : None
There is no need for a target in a transformer, yet the pipeline
Expand Down Expand Up @@ -399,23 +405,25 @@ def transform(self, X, y=None):
Input data. Array of persistence diagrams, each a collection of
triples [b, d, q] representing persistent topological features
through their birth (b), death (d) and homology dimension (q).
It is important that, for each possible homology dimension, the
number of triples for which q equals that homology dimension is
constants across the entries of X.

y : None
There is no need for a target in a transformer, yet the pipeline
API requires this parameter.

Returns
-------
Xt : ndarray of shape (n_samples, n_features, 3)
Xt : ndarray of shape (n_samples, n_features_filtered, 3)
Filtered persistence diagrams. Only the subdiagrams corresponding
to dimensions in :attr:`homology_dimensions_` are filtered.
Discarded points are replaced by points on the diagonal.
``n_features_filtered`` is less than or equal to ``n_features`.

"""
check_is_fitted(self)
X = check_diagrams(X)

X = _sort(X)
Xt = _filter(X, self.homology_dimensions_, self.epsilon)
return Xt

Expand Down
2 changes: 1 addition & 1 deletion gtda/diagrams/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_filt_transform_zero(X):
def test_filt_transform(epsilon):
filt = Filtering(epsilon=epsilon)
X_res_1 = filt.fit_transform(X_1)
assert X_res_1.shape == X_1.shape
assert X_res_1.shape[1] <= X_1.shape[1]

lifetimes_res_1 = X_res_1[:, :, 1] - X_res_1[:, :, 0]
assert not ((lifetimes_res_1 > 0.) & (lifetimes_res_1 <= epsilon)).any()