From f60608e5428d621f45628841b94a6305db64c1fb Mon Sep 17 00:00:00 2001 From: wreise Date: Thu, 8 Jul 2021 20:38:35 +0200 Subject: [PATCH 1/2] Change the default weight_params in WeightedRips --- gtda/homology/simplicial.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtda/homology/simplicial.py b/gtda/homology/simplicial.py index 18702913d..2cb624a7e 100644 --- a/gtda/homology/simplicial.py +++ b/gtda/homology/simplicial.py @@ -432,7 +432,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin): is a parameter (see `weight_params`). If a callable, it must return non-negative 1D arrays. - weight_params : dict, optional, default: ``None`` + weight_params : dict, optional, default: ``{}`` Additional parameters for the weighted filtration. ``"p"`` determines the power to be used in computing edge weights from vertex weights. It can be one of ``1``, ``2`` or ``np.inf`` and defaults to ``1``. If @@ -525,7 +525,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin): "of": {"type": int, "in": Interval(0, np.inf, closed="left")} }, "weights": {"type": (str, FunctionType)}, - "weight_params": {"type": (dict, type(None))}, + "weight_params": {"type": dict}, "collapse_edges": {"type": bool}, "coeff": {"type": int, "in": Interval(2, np.inf, closed="left")}, "max_edge_weight": {"type": Real}, @@ -534,7 +534,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin): } def __init__(self, metric="euclidean", metric_params={}, - homology_dimensions=(0, 1), weights="DTM", weight_params=None, + homology_dimensions=(0, 1), weights="DTM", weight_params={}, collapse_edges=False, coeff=2, max_edge_weight=np.inf, infinity_values=None, reduced_homology=True, n_jobs=None): self.metric = metric From d76a6d0232b4e7f2dd0febc857d02001c1ac4bb3 Mon Sep 17 00:00:00 2001 From: wreise Date: Thu, 8 Jul 2021 21:13:36 +0200 Subject: [PATCH 2/2] Fix non-mepty check --- gtda/homology/simplicial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtda/homology/simplicial.py b/gtda/homology/simplicial.py index 2cb624a7e..81cd2199f 100644 --- a/gtda/homology/simplicial.py +++ b/gtda/homology/simplicial.py @@ -616,7 +616,7 @@ def fit(self, X, y=None): self.effective_weight_params_.update({"n_neighbors": 3, "r": 2}) else: key = "general" - if self.weight_params is not None: + if self.weight_params: self.effective_weight_params_.update(self.weight_params) validate_params(self.effective_weight_params_, _AVAILABLE_RIPS_WEIGHTS[key])