diff --git a/gtda/diagrams/_utils.py b/gtda/diagrams/_utils.py index c02557d04..476e2d1f4 100644 --- a/gtda/diagrams/_utils.py +++ b/gtda/diagrams/_utils.py @@ -12,9 +12,12 @@ def _subdiagrams(X, homology_dimensions, remove_dim=False): list of homology dimensions. It is assumed that all diagrams in X contain the same number of points in each homology dimension.""" n = len(X) - Xs = np.concatenate([X[X[:, :, 2] == dim].reshape(n, -1, 3) - for dim in homology_dimensions], - axis=1) + if len(homology_dimensions) == 1: + Xs = X[X[:, :, 2] == homology_dimensions[0]].reshape(n, -1, 3) + else: + Xs = np.concatenate([X[X[:, :, 2] == dim].reshape(n, -1, 3) + for dim in homology_dimensions], + axis=1) if remove_dim: Xs = Xs[:, :, :2] return Xs diff --git a/gtda/diagrams/preprocessing.py b/gtda/diagrams/preprocessing.py index 1d50be726..368dedbc0 100644 --- a/gtda/diagrams/preprocessing.py +++ b/gtda/diagrams/preprocessing.py @@ -314,18 +314,21 @@ class Filtering(BaseEstimator, TransformerMixin, PlotterMixin): """Filtering of persistence diagrams. Filtering a diagram means discarding all points [b, d, q] representing - topological features whose lifetime d - b is less than or equal to a - cutoff value. Technically, discarded points are replaced by points on the - diagonal (i.e. whose birth and death values coincide), which carry no - information. + non-trivial topological features whose lifetime d - b is less than or + equal to a cutoff value. Points on the diagonal (i.e. for which b and d + are equal) may still appear in the output for padding purposes, but carry + no information. + + Input collections of persistence diagrams for this transformer must + satisfy certain requirements, see e.g. :meth:`fit`. Parameters ---------- homology_dimensions : list, tuple, or None, optional, default: ``None`` When set to ``None``, subdiagrams corresponding to all homology - dimensions seen in :meth:`fit` will be filtered. - Otherwise, it contains the homology dimensions (as non-negative - integers) at which filtering should occur. + dimensions seen in :meth:`fit` will be filtered. Otherwise, it contains + the homology dimensions (as non-negative integers) at which filtering + should occur. epsilon : float, optional, default: ``0.01`` The cutoff value controlling the amount of filtering. @@ -368,6 +371,9 @@ def fit(self, X, y=None): Input data. Array of persistence diagrams, each a collection of triples [b, d, q] representing persistent topological features through their birth (b), death (d) and homology dimension (q). + It is important that, for each possible homology dimension, the + number of triples for which q equals that homology dimension is + constants across the entries of `X`. y : None There is no need for a target in a transformer, yet the pipeline @@ -399,6 +405,9 @@ def transform(self, X, y=None): Input data. Array of persistence diagrams, each a collection of triples [b, d, q] representing persistent topological features through their birth (b), death (d) and homology dimension (q). + It is important that, for each possible homology dimension, the + number of triples for which q equals that homology dimension is + constants across the entries of X. y : None There is no need for a target in a transformer, yet the pipeline @@ -406,10 +415,10 @@ def transform(self, X, y=None): Returns ------- - Xt : ndarray of shape (n_samples, n_features, 3) + Xt : ndarray of shape (n_samples, n_features_filtered, 3) Filtered persistence diagrams. Only the subdiagrams corresponding to dimensions in :attr:`homology_dimensions_` are filtered. - Discarded points are replaced by points on the diagonal. + ``n_features_filtered`` is less than or equal to ``n_features`. """ check_is_fitted(self)