diff --git a/gtda/homology/simplicial.py b/gtda/homology/simplicial.py index 18702913..81cd2199 100644 --- a/gtda/homology/simplicial.py +++ b/gtda/homology/simplicial.py @@ -432,7 +432,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin): is a parameter (see `weight_params`). If a callable, it must return non-negative 1D arrays. - weight_params : dict, optional, default: ``None`` + weight_params : dict, optional, default: ``{}`` Additional parameters for the weighted filtration. ``"p"`` determines the power to be used in computing edge weights from vertex weights. It can be one of ``1``, ``2`` or ``np.inf`` and defaults to ``1``. If @@ -525,7 +525,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin): "of": {"type": int, "in": Interval(0, np.inf, closed="left")} }, "weights": {"type": (str, FunctionType)}, - "weight_params": {"type": (dict, type(None))}, + "weight_params": {"type": dict}, "collapse_edges": {"type": bool}, "coeff": {"type": int, "in": Interval(2, np.inf, closed="left")}, "max_edge_weight": {"type": Real}, @@ -534,7 +534,7 @@ class WeightedRipsPersistence(BaseEstimator, TransformerMixin, PlotterMixin): } def __init__(self, metric="euclidean", metric_params={}, - homology_dimensions=(0, 1), weights="DTM", weight_params=None, + homology_dimensions=(0, 1), weights="DTM", weight_params={}, collapse_edges=False, coeff=2, max_edge_weight=np.inf, infinity_values=None, reduced_homology=True, n_jobs=None): self.metric = metric @@ -616,7 +616,7 @@ def fit(self, X, y=None): self.effective_weight_params_.update({"n_neighbors": 3, "r": 2}) else: key = "general" - if self.weight_params is not None: + if self.weight_params: self.effective_weight_params_.update(self.weight_params) validate_params(self.effective_weight_params_, _AVAILABLE_RIPS_WEIGHTS[key])