Skip to content

Commit

Permalink
Remove _sort, simplify _filter
Browse files Browse the repository at this point in the history
Arrays are no longer sorted by lifetime before filtering. Persistence pairs to be filtered out are now replaced by padding points in their original locations.
  • Loading branch information
ulupo committed Jul 17, 2020
1 parent 6b32773 commit d853062
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 22 deletions.
27 changes: 7 additions & 20 deletions gtda/diagrams/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@ def _pad(X, max_diagram_sizes):
return X_padded


def _sort(Xs):
indices = np.argsort(Xs[:, :, 1] - Xs[:, :, 0], axis=1)
indices = np.stack([indices, indices, indices], axis=2)
Xs = np.flip(np.take_along_axis(Xs, indices, axis=1), axis=1)
return Xs


def _sample_image(image, sampled_diag):
unique, counts = np.unique(sampled_diag, axis=0, return_counts=True)
unique = tuple(tuple(row) for row in unique.astype(np.int).T)
Expand All @@ -36,22 +29,16 @@ def _sample_image(image, sampled_diag):

def _filter(Xs, filtered_homology_dimensions, cutoff):
homology_dimensions = sorted(list(set(Xs[0, :, 2])))
unfiltered_homology_dimensions = sorted(list(
set(homology_dimensions) - set(filtered_homology_dimensions)))
Xf = np.empty((Xs.shape[0], 0, 3), dtype=Xs.dtype)

if len(unfiltered_homology_dimensions) == 0:
Xf = np.empty((Xs.shape[0], 0, 3), dtype=Xs.dtype)
else:
Xf = _subdiagrams(Xs, unfiltered_homology_dimensions)

for dim in filtered_homology_dimensions:
for dim in homology_dimensions:
Xdim = _subdiagrams(Xs, [dim])
min_value = np.min(Xdim[:, :, 0])
mask = (Xdim[:, :, 1] - Xdim[:, :, 0]) <= cutoff
Xdim[mask, :] = [min_value, min_value, dim]
max_points = np.max(np.sum(Xs[:, :, 1] != 0, axis=1))
Xdim = Xdim[:, :max_points, :]
if dim in filtered_homology_dimensions:
min_value = np.min(Xdim[:, :, 0])
mask = (Xdim[:, :, 1] - Xdim[:, :, 0]) <= cutoff
Xdim[mask, :] = [min_value, min_value, dim]
Xf = np.concatenate([Xf, Xdim], axis=1)

return Xf


Expand Down
3 changes: 1 addition & 2 deletions gtda/diagrams/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.utils.validation import check_is_fitted

from ._metrics import _AVAILABLE_AMPLITUDE_METRICS, _parallel_amplitude
from ._utils import _sort, _filter, _bin, _calculate_weights
from ._utils import _filter, _bin, _calculate_weights
from ..base import PlotterMixin
from ..plotting.persistence_diagrams import plot_diagram
from ..utils._docs import adapt_fit_transform_docs
Expand Down Expand Up @@ -415,7 +415,6 @@ def transform(self, X, y=None):
check_is_fitted(self)
X = check_diagrams(X)

X = _sort(X)
Xt = _filter(X, self.homology_dimensions_, self.epsilon)
return Xt

Expand Down

0 comments on commit d853062

Please sign in to comment.