diff --git a/gtda/homology/tests/test_simplicial.py b/gtda/homology/tests/test_simplicial.py index fbc2baa8c..c9d2926aa 100644 --- a/gtda/homology/tests/test_simplicial.py +++ b/gtda/homology/tests/test_simplicial.py @@ -131,6 +131,15 @@ def test_wrp_params(): wrp.fit_transform(X_pc) +def test_wrp_metric_params(): + def metric(x, y, **kwargs): + return np.linalg.norm(x - y) + + metric_params = {"parameter": 0.} + wrp = WeightedRipsPersistence(metric=metric, metric_params=metric_params) + wrp.fit_transform(X_pc) + + def test_wrp_not_fitted(): wrp = WeightedRipsPersistence() diff --git a/gtda/utils/validation.py b/gtda/utils/validation.py index 32cfa6872..37929b140 100644 --- a/gtda/utils/validation.py +++ b/gtda/utils/validation.py @@ -138,7 +138,10 @@ def _validate_params_single(_parameter, _reference, _name): ref_type = _validate_params_single(parameter, reference, name) if ref_type: ref_of = reference.get('of', None) - if ref_type == dict: + if ref_of is None: + # if ref_of is None, the elements are not to be validated + continue + elif ref_type == dict: _validate_params(parameter, ref_of, rec_name=name) else: # List, tuple or ndarray type for i, parameter_elem in enumerate(parameter):