From f7c5c5675addfa638b443d377550ea14b8b4c962 Mon Sep 17 00:00:00 2001 From: wreise Date: Mon, 5 Jul 2021 11:30:14 +0200 Subject: [PATCH 1/5] Treat the case when ref_of is None --- gtda/utils/validation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtda/utils/validation.py b/gtda/utils/validation.py index 32cfa6872..ba656e8d1 100644 --- a/gtda/utils/validation.py +++ b/gtda/utils/validation.py @@ -138,7 +138,9 @@ def _validate_params_single(_parameter, _reference, _name): ref_type = _validate_params_single(parameter, reference, name) if ref_type: ref_of = reference.get('of', None) - if ref_type == dict: + if ref_of is None: # if ref_of is None, the elements are not to be validated + continue + elif ref_type == dict: _validate_params(parameter, ref_of, rec_name=name) else: # List, tuple or ndarray type for i, parameter_elem in enumerate(parameter): From 7357c2bdc620b504eadb6d0a53d9ec98614e6a2d Mon Sep 17 00:00:00 2001 From: wreise Date: Mon, 5 Jul 2021 11:31:30 +0200 Subject: [PATCH 2/5] Linting --- gtda/utils/validation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtda/utils/validation.py b/gtda/utils/validation.py index ba656e8d1..37929b140 100644 --- a/gtda/utils/validation.py +++ b/gtda/utils/validation.py @@ -138,7 +138,8 @@ def _validate_params_single(_parameter, _reference, _name): ref_type = _validate_params_single(parameter, reference, name) if ref_type: ref_of = reference.get('of', None) - if ref_of is None: # if ref_of is None, the elements are not to be validated + if ref_of is None: + # if ref_of is None, the elements are not to be validated continue elif ref_type == dict: _validate_params(parameter, ref_of, rec_name=name) From 270e8f13b000aceca4068103d2f7f2f479f805af Mon Sep 17 00:00:00 2001 From: wreise Date: Wed, 7 Jul 2021 21:43:00 +0200 Subject: [PATCH 3/5] Add test for the metric_params --- gtda/homology/tests/test_simplicial.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gtda/homology/tests/test_simplicial.py b/gtda/homology/tests/test_simplicial.py index fbc2baa8c..4c8c10e00 100644 --- a/gtda/homology/tests/test_simplicial.py +++ b/gtda/homology/tests/test_simplicial.py @@ -131,6 +131,15 @@ def test_wrp_params(): wrp.fit_transform(X_pc) +def test_wrp_metric_params(): + def metric(x, y, parameter): + return np.linalg.norm(x-y) + + metric_params = {"parameter": 0.} + wrp = WeightedRipsPersistence(metric=metric, metric_params=metric_params) + wrp.fit_transform(X_pc) + + def test_wrp_not_fitted(): wrp = WeightedRipsPersistence() From b8b22e6433e6d25f9b8d63d721f6012453c122cf Mon Sep 17 00:00:00 2001 From: Umberto Lupo <46537483+ulupo@users.noreply.github.com> Date: Thu, 8 Jul 2021 15:43:16 +0200 Subject: [PATCH 4/5] Fix linting --- gtda/homology/tests/test_simplicial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtda/homology/tests/test_simplicial.py b/gtda/homology/tests/test_simplicial.py index 4c8c10e00..c7eb2d835 100644 --- a/gtda/homology/tests/test_simplicial.py +++ b/gtda/homology/tests/test_simplicial.py @@ -134,7 +134,7 @@ def test_wrp_params(): def test_wrp_metric_params(): def metric(x, y, parameter): return np.linalg.norm(x-y) - + metric_params = {"parameter": 0.} wrp = WeightedRipsPersistence(metric=metric, metric_params=metric_params) wrp.fit_transform(X_pc) From 75e74ce4f93c2d8f114c090797ad33138ef6d9a4 Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Thu, 8 Jul 2021 17:53:15 +0200 Subject: [PATCH 5/5] Small fix in test_wrp_metric_params --- gtda/homology/tests/test_simplicial.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtda/homology/tests/test_simplicial.py b/gtda/homology/tests/test_simplicial.py index c7eb2d835..c9d2926aa 100644 --- a/gtda/homology/tests/test_simplicial.py +++ b/gtda/homology/tests/test_simplicial.py @@ -132,8 +132,8 @@ def test_wrp_params(): def test_wrp_metric_params(): - def metric(x, y, parameter): - return np.linalg.norm(x-y) + def metric(x, y, **kwargs): + return np.linalg.norm(x - y) metric_params = {"parameter": 0.} wrp = WeightedRipsPersistence(metric=metric, metric_params=metric_params)