Skip to content

Commit

Permalink
Make homology_dimensions_ attributes tuples instead of lists, with integers when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed Aug 12, 2020
1 parent a532cc8 commit 17b906e
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 32 deletions.
4 changes: 2 additions & 2 deletions gtda/diagrams/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _multirange(counts):

def _filter(X, filtered_homology_dimensions, cutoff):
n = len(X)
homology_dimensions = sorted(list(set(X[0, :, 2])))
homology_dimensions = sorted(set(X[0, :, 2]))
unfiltered_homology_dimensions = [dim for dim in homology_dimensions if
dim not in filtered_homology_dimensions]

Expand Down Expand Up @@ -91,7 +91,7 @@ def _filter(X, filtered_homology_dimensions, cutoff):


def _bin(X, metric, n_bins=100, **kw_args):
homology_dimensions = sorted(list(set(X[0, :, 2])))
homology_dimensions = sorted(set(X[0, :, 2]))
# For some vectorizations, we force the values to be the same + widest
sub_diags = {dim: _subdiagrams(X, [dim], remove_dim=True)
for dim in homology_dimensions}
Expand Down
7 changes: 5 additions & 2 deletions gtda/diagrams/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class PairwiseDistance(BaseEstimator, TransformerMixin):
Dictionary containing all information present in `metric_params` as
well as relevant quantities computed in :meth:`fit`.
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`, sorted in ascending order.
See also
Expand Down Expand Up @@ -178,7 +178,10 @@ def fit(self, X, y=None):
validate_params(
self.effective_metric_params_, _AVAILABLE_METRICS[self.metric])

self.homology_dimensions_ = sorted(set(X[0, :, 2]))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)

self.effective_metric_params_['samplings'], \
self.effective_metric_params_['step_sizes'] = \
Expand Down
14 changes: 10 additions & 4 deletions gtda/diagrams/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class PersistenceEntropy(BaseEstimator, TransformerMixin):
Attributes
----------
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`, sorted in ascending order.
See also
Expand Down Expand Up @@ -127,7 +127,10 @@ def fit(self, X, y=None):
validate_params(
self.get_params(), self._hyperparameters, exclude=['n_jobs'])

self.homology_dimensions_ = sorted(set(X[0, :, 2]))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)
self._n_dimensions = len(self.homology_dimensions_)

return self
Expand Down Expand Up @@ -257,7 +260,7 @@ class Amplitude(BaseEstimator, TransformerMixin):
Dictionary containing all information present in `metric_params` as
well as relevant quantities computed in :meth:`fit`.
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`, sorted in ascending order.
See also
Expand Down Expand Up @@ -327,7 +330,10 @@ def fit(self, X, y=None):
validate_params(self.effective_metric_params_,
_AVAILABLE_AMPLITUDE_METRICS[self.metric])

self.homology_dimensions_ = sorted(set(X[0, :, 2]))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)

self.effective_metric_params_['samplings'], \
self.effective_metric_params_['step_sizes'] = \
Expand Down
33 changes: 20 additions & 13 deletions gtda/diagrams/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class Scaler(BaseEstimator, TransformerMixin, PlotterMixin):
Dictionary containing all information present in `metric_params` as
well as relevant quantities computed in :meth:`fit`.
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`, sorted in ascending order.
scale_ : float
Expand Down Expand Up @@ -239,7 +239,10 @@ def fit(self, X, y=None):
validate_params(self.effective_metric_params_,
_AVAILABLE_AMPLITUDE_METRICS[self.metric])

self.homology_dimensions_ = sorted(set(X[0, :, 2]))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)

self.effective_metric_params_['samplings'], \
self.effective_metric_params_['step_sizes'] = \
Expand Down Expand Up @@ -356,10 +359,10 @@ class Filtering(BaseEstimator, TransformerMixin, PlotterMixin):
"""Filtering of persistence diagrams.
Filtering a diagram means discarding all points [b, d, q] representing
non-trivial topological features whose lifetime d - b is less than or
equal to a cutoff value. Points on the diagonal (i.e. for which b and d
are equal) may still appear in the output for padding purposes, but carry
no information.
non-trivial topological features whose lifetime d - b is less than or equal
to a cutoff value. Points on the diagonal (i.e. for which b and d are
equal) may still appear in the output for padding purposes, but carry no
information.
**Important note**:
Expand All @@ -379,11 +382,10 @@ class Filtering(BaseEstimator, TransformerMixin, PlotterMixin):
Attributes
----------
homology_dimensions_ : list
If `homology_dimensions` is set to ``None``, then this is the list
of homology dimensions seen in :meth:`fit`, sorted in ascending
order. Otherwise, it is a similarly sorted version of
`homology_dimensions`.
homology_dimensions_ : tuple
If `homology_dimensions` is set to ``None``, contains the homology
dimensions seen in :meth:`fit`, sorted in ascending order. Otherwise,
it is a similarly sorted version of `homology_dimensions`.
See also
--------
Expand Down Expand Up @@ -434,10 +436,15 @@ def fit(self, X, y=None):
self.get_params(), self._hyperparameters)

if self.homology_dimensions is None:
self.homology_dimensions_ = [int(dim) for dim in set(X[0, :, 2])]
self.homology_dimensions_ = [
int(dim) if dim != np.inf else dim for dim in set(X[0, :, 2])
]
else:
self.homology_dimensions_ = self.homology_dimensions
self.homology_dimensions_ = sorted(self.homology_dimensions_)
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)

return self

Expand Down
35 changes: 25 additions & 10 deletions gtda/diagrams/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class BettiCurve(BaseEstimator, TransformerMixin, PlotterMixin):
Attributes
----------
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`, sorted in ascending order.
samplings_ : dict
Expand Down Expand Up @@ -112,7 +112,10 @@ def fit(self, X, y=None):
validate_params(
self.get_params(), self._hyperparameters, exclude=["n_jobs"])

self.homology_dimensions_ = sorted(list(set(X[0, :, 2])))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)
self._n_dimensions = len(self.homology_dimensions_)
self._samplings, _ = _bin(X, "betti", n_bins=self.n_bins)
self.samplings_ = {dim: s.flatten()
Expand Down Expand Up @@ -284,7 +287,7 @@ class PersistenceLandscape(BaseEstimator, TransformerMixin, PlotterMixin):
Attributes
----------
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`.
samplings_ : dict
Expand Down Expand Up @@ -350,7 +353,10 @@ def fit(self, X, y=None):
validate_params(
self.get_params(), self._hyperparameters, exclude=["n_jobs"])

self.homology_dimensions_ = sorted(list(set(X[0, :, 2])))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)
self._n_dimensions = len(self.homology_dimensions_)
self._samplings, _ = _bin(X, "landscape", n_bins=self.n_bins)
self.samplings_ = {dim: s.flatten()
Expand Down Expand Up @@ -546,7 +552,7 @@ class HeatKernel(BaseEstimator, TransformerMixin, PlotterMixin):
Attributes
----------
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`.
samplings_ : dict
Expand Down Expand Up @@ -620,7 +626,10 @@ def fit(self, X, y=None):
validate_params(
self.get_params(), self._hyperparameters, exclude=["n_jobs"])

self.homology_dimensions_ = sorted(list(set(X[0, :, 2])))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)
self._n_dimensions = len(self.homology_dimensions_)
self._samplings, self._step_size = _bin(X, "heat", n_bins=self.n_bins)
self.samplings_ = {dim: s.flatten()
Expand Down Expand Up @@ -767,7 +776,7 @@ class PersistenceImage(BaseEstimator, TransformerMixin, PlotterMixin):
Effective function corresponding to `weight_function`. Set in
:meth:`fit`.
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`.
samplings_ : dict
Expand Down Expand Up @@ -855,7 +864,10 @@ def fit(self, X, y=None):
else:
self.effective_weight_function_ = self.weight_function

self.homology_dimensions_ = sorted(list(set(X[0, :, 2])))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)
self._n_dimensions = len(self.homology_dimensions_)
self._samplings, self._step_size = _bin(X, "persistence_image",
n_bins=self.n_bins)
Expand Down Expand Up @@ -1003,7 +1015,7 @@ class Silhouette(BaseEstimator, TransformerMixin, PlotterMixin):
Attributes
----------
homology_dimensions_ : list
homology_dimensions_ : tuple
Homology dimensions seen in :meth:`fit`, sorted in ascending order.
samplings_ : dict
Expand Down Expand Up @@ -1076,7 +1088,10 @@ def fit(self, X, y=None):
validate_params(
self.get_params(), self._hyperparameters, exclude=["n_jobs"])

self.homology_dimensions_ = sorted(list(set(X[0, :, 2])))
self.homology_dimensions_ = tuple(
sorted([int(dim) if dim != np.inf else dim
for dim in set(X[0, :, 2])])
)
self._n_dimensions = len(self.homology_dimensions_)
self._samplings, _ = _bin(X, "silhouette", n_bins=self.n_bins)
self.samplings_ = {dim: s.flatten()
Expand Down
2 changes: 1 addition & 1 deletion gtda/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check_diagrams(X, copy=False):
f"components, but there are {X_array.shape[2]} components.")

X_array = X_array.astype(float, copy=False)
homology_dimensions = sorted(list(set(X_array[0, :, 2])))
homology_dimensions = sorted(set(X_array[0, :, 2]))
for dim in homology_dimensions:
if dim == np.inf:
if len(homology_dimensions) != 1:
Expand Down

0 comments on commit 17b906e

Please sign in to comment.