ENH duck-typing scikit-learn estimator instead of inheritance #858

Merged 53 commits on Jan 16, 2022
d17b6b5
add duck-type check for KNeighbors-likeness
Sep 2, 2021
379ea7e
removal of KNeighborsMixin type check
Sep 2, 2021
8790628
Added _is_neighbors_object() private validation function
Sep 9, 2021
e997e23
Added PEP 8 blank lines
Sep 9, 2021
94b0725
change isinstance check for SVM estimator to simply clone the estimat…
Sep 10, 2021
9fbf360
remove explicit class-check for KMeans estimator
Sep 13, 2021
f736879
remove explicit class check for KNeighborsClassifier
Sep 13, 2021
fcb118e
remove explicit class check for KNeighborsClassifier in CondensedNear…
Sep 13, 2021
a4e959c
remove explicit class check for ClassifierMixin in InstanceHardnessTh…
Sep 13, 2021
65ae4fd
PEP 8 issue fix
Sep 13, 2021
5b76d49
PEP 8 issue fix - line break before operator
Sep 13, 2021
8284b70
PEP 8 issue fix - no more line break before operator
Sep 13, 2021
e97ae36
Undo changes to _instance_hardness_threshold
Sep 15, 2021
495ec27
revert OneSidedSelection changes
Sep 16, 2021
10456f5
Undo changes to CondensedNearestNeighbour
Sep 24, 2021
93200e1
example NearestNeighbors test
Sep 29, 2021
f104057
Use sklearn.base.clone to validate NN object and throw error
Sep 29, 2021
b82e4d9
undo last commit, and raise nn_object TypeError
Sep 29, 2021
70b6778
remove unused imports
Sep 29, 2021
c67c775
Add test for cuml ADASYN
Oct 4, 2021
010f4d5
Updated check_neighbors_object docstring and error type
Oct 4, 2021
178d0f0
Updated tests
Oct 5, 2021
9868d0f
Merge branch 'master' into ducktype-check_neighbors
NV-jpt Oct 29, 2021
2e1ee17
Merge remote-tracking branch 'origin/master' into pr/NV-jpt/858
glemaitre Dec 7, 2021
8889cfd
duck-typing svm
glemaitre Dec 7, 2021
5e875a0
TST add couple of tests
glemaitre Dec 7, 2021
9545172
better error message with duck-typing
glemaitre Dec 7, 2021
29a414b
iter
glemaitre Dec 7, 2021
12991ba
CI let's try a run on CircleCI with cuML
glemaitre Dec 7, 2021
e24ee06
iter
glemaitre Dec 7, 2021
525002f
iter
glemaitre Dec 7, 2021
189f0e9
iter
glemaitre Dec 7, 2021
2cbe273
iter
glemaitre Dec 7, 2021
cc7fae9
iter
glemaitre Dec 7, 2021
29e4619
ITER
glemaitre Dec 7, 2021
a098e84
iter
glemaitre Dec 7, 2021
0aa328e
iter
glemaitre Dec 7, 2021
8cce474
dbg
glemaitre Dec 7, 2021
8d4ff31
dbg
glemaitre Dec 8, 2021
0ceacfb
MNT move to circleci
glemaitre Dec 8, 2021
ee6b7b0
iter
glemaitre Dec 8, 2021
d089b7b
iter
glemaitre Jan 15, 2022
ac7e00a
Merge remote-tracking branch 'origin/master' into pr/NV-jpt/858
glemaitre Jan 15, 2022
d815e2d
create custom NN class
glemaitre Jan 15, 2022
964d082
add test no dependent on cupy
glemaitre Jan 16, 2022
99d5206
update documentation
glemaitre Jan 16, 2022
48d1fd5
iter
glemaitre Jan 16, 2022
18b6057
iter
glemaitre Jan 16, 2022
76fbd59
revert redirector
glemaitre Jan 16, 2022
8fa97ed
add changelog
glemaitre Jan 16, 2022
615a2bf
remove duplicated test
glemaitre Jan 16, 2022
b75b77d
make testing function private
glemaitre Jan 16, 2022
b627cf1
iter
glemaitre Jan 16, 2022
1 change: 1 addition & 0 deletions build_tools/azure/install.sh
@@ -6,6 +6,7 @@ set -x
UNAMESTR=`uname`

make_conda() {
conda update -yq conda
TO_INSTALL="$@"
if [[ "$DISTRIB" == *"mamba"* ]]; then
mamba create -n $VIRTUALENV --yes $TO_INSTALL
3 changes: 2 additions & 1 deletion doc/developers_utils.rst
@@ -29,7 +29,8 @@ which accepts arrays, matrices, or sparse matrices as arguments, the following
should be used when applicable.

- :func:`check_neighbors_object`: Check the objects is consistent to be a NN.
- :func:`check_target_type`: Check the target types to be conform to the current sam plers.
- :func:`check_target_type`: Check the target types to be conform to the current
samplers.
- :func:`check_sampling_strategy`: Checks that sampling target is onsistent with
the type and return a dictionary containing each targeted class with its
corresponding number of pixel.
8 changes: 8 additions & 0 deletions doc/whats_new/v0.10.rst
@@ -5,3 +5,11 @@ Version 0.10.0 (ongoing)

Changelog
---------

Enhancements
............

- Add support to accept compatible `NearestNeighbors` objects by only
duck-typing. For instance, it allows to accept cuML instances.
:pr:`858` by :user:`NV-jpt <NV-jpt>` and
:user:`Guillaume Lemaitre <glemaitre>`.
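To illustrate what "compatible by duck-typing" can mean here, the following is a minimal sketch of a nearest-neighbors class that inherits from nothing in scikit-learn and only exposes the `kneighbors` and `kneighbors_graph` methods named in the updated docstrings. The class name `BruteForceNeighbors` and its brute-force search are hypothetical and are not taken from this PR. It returns plain Python lists, so it is a toy rather than a drop-in replacement for `NearestNeighbors`.

```python
import math


class BruteForceNeighbors:
    """Hypothetical toy class: no scikit-learn base class, only a compatible API."""

    def __init__(self, n_neighbors=5):
        self.n_neighbors = n_neighbors

    # get_params / set_params are what cloning utilities rely on to copy an estimator.
    def get_params(self, deep=True):
        return {"n_neighbors": self.n_neighbors}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y=None):
        self._fit_X = [list(row) for row in X]
        return self

    def kneighbors(self, X, n_neighbors=None, return_distance=True):
        k = self.n_neighbors if n_neighbors is None else n_neighbors
        distances, indices = [], []
        for query in X:
            # Rank every fitted point by Euclidean distance and keep the k closest.
            ranked = sorted(
                (math.dist(query, row), idx) for idx, row in enumerate(self._fit_X)
            )[:k]
            distances.append([d for d, _ in ranked])
            indices.append([i for _, i in ranked])
        return (distances, indices) if return_distance else indices

    def kneighbors_graph(self, X, n_neighbors=None):
        # Dense 0/1 connectivity rows; scikit-learn returns a sparse matrix instead.
        indices = self.kneighbors(X, n_neighbors, return_distance=False)
        graph = []
        for row_indices in indices:
            row = [0] * len(self._fit_X)
            for i in row_indices:
                row[i] = 1
            graph.append(row)
        return graph
```

A real replacement (for example a GPU-backed class) would presumably need to return arrays in the same layout as scikit-learn, since the samplers index into the result of `kneighbors`.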
16 changes: 11 additions & 5 deletions imblearn/over_sampling/_adasyn.py
@@ -39,10 +39,17 @@ class ADASYN(BaseOverSampler):
{random_state}

n_neighbors : int or estimator object, default=5
If ``int``, number of nearest neighbours to used to construct synthetic
samples. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.
The nearest neighbors used to define the neighborhood of samples to use
to generate the synthetic samples. You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

{n_jobs}

@@ -124,7 +131,6 @@ def _validate_estimator(self):
self.nn_ = check_neighbors_object(
"n_neighbors", self.n_neighbors, additional_neighbor=1
)
self.nn_.set_params(**{"n_jobs": self.n_jobs})

def _fit_resample(self, X, y):
self._validate_estimator()
45 changes: 33 additions & 12 deletions imblearn/over_sampling/_smote/base.py
@@ -224,10 +224,17 @@ class SMOTE(BaseSMOTE):
{random_state}

k_neighbors : int or object, default=5
If ``int``, number of nearest neighbours to used to construct synthetic
samples. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.
The nearest neighbors used to define the neighborhood of samples to use
to generate the synthetic samples. You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

{n_jobs}

@@ -367,10 +374,17 @@ class SMOTENC(SMOTE):
{random_state}

k_neighbors : int or object, default=5
If ``int``, number of nearest neighbours to used to construct synthetic
samples. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.
The nearest neighbors used to define the neighborhood of samples to use
to generate the synthetic samples. You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

{n_jobs}

@@ -636,10 +650,17 @@ class SMOTEN(SMOTE):
{random_state}

k_neighbors : int or object, default=5
If ``int``, number of nearest neighbours to used to construct synthetic
samples. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.
The nearest neighbors used to define the neighborhood of samples to use
to generate the synthetic samples. You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

{n_jobs}

15 changes: 11 additions & 4 deletions imblearn/over_sampling/_smote/cluster.py
@@ -45,10 +45,17 @@ class KMeansSMOTE(BaseSMOTE):
{random_state}

k_neighbors : int or object, default=2
If ``int``, number of nearest neighbours to used to construct synthetic
samples. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.
The nearest neighbors used to define the neighborhood of samples to use
to generate the synthetic samples. You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

{n_jobs}

75 changes: 53 additions & 22 deletions imblearn/over_sampling/_smote/filter.py
@@ -15,7 +15,6 @@
from sklearn.utils import _safe_indexing

from ..base import BaseOverSampler
from ...exceptions import raise_isinstance_error
from ...utils import check_neighbors_object
from ...utils import Substitution
from ...utils._docstring import _n_jobs_docstring
@@ -48,18 +47,32 @@ class BorderlineSMOTE(BaseSMOTE):
{random_state}

k_neighbors : int or object, default=5
If ``int``, number of nearest neighbours to used to construct synthetic
samples. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.
The nearest neighbors used to define the neighborhood of samples to use
to generate the synthetic samples. You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

{n_jobs}

m_neighbors : int or object, default=10
If int, number of nearest neighbours to use to determine if a minority
sample is in danger. If object, an estimator that inherits
from :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used
to find the m_neighbors.
The nearest neighbors used to determine if a minority sample is in
"danger". You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

kind : {{"borderline-1", "borderline-2"}}, default='borderline-1'
The type of SMOTE algorithm to use one of the following options:
@@ -155,7 +168,6 @@ def _validate_estimator(self):
self.nn_m_ = check_neighbors_object(
"m_neighbors", self.m_neighbors, additional_neighbor=1
)
self.nn_m_.set_params(**{"n_jobs": self.n_jobs})
if self.kind not in ("borderline-1", "borderline-2"):
raise ValueError(
f'The possible "kind" of algorithm are '
@@ -263,21 +275,37 @@ class SVMSMOTE(BaseSMOTE):
{random_state}

k_neighbors : int or object, default=5
If ``int``, number of nearest neighbours to used to construct synthetic
samples. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the k_neighbors.
The nearest neighbors used to define the neighborhood of samples to use
to generate the synthetic samples. You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

{n_jobs}

m_neighbors : int or object, default=10
If int, number of nearest neighbours to use to determine if a minority
sample is in danger. If object, an estimator that inherits from
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
find the m_neighbors.
The nearest neighbors used to determine if a minority sample is in
"danger". You can pass:

- an `int` corresponding to the number of neighbors to use. A
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
case.
- an instance of a compatible nearest neighbors algorithm that should
implement both methods `kneighbors` and `kneighbors_graph`. For
instance, it could correspond to a
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
any compatible class.

svm_estimator : estimator object, default=SVC()
A parametrized :class:`~sklearn.svm.SVC` classifier can be passed.
A scikit-learn compatible estimator can be passed but it is required
to expose a `support_` fitted attribute.

out_step : float, default=0.5
Step size when extrapolating.
@@ -381,14 +409,11 @@ def _validate_estimator(self):
self.nn_m_ = check_neighbors_object(
"m_neighbors", self.m_neighbors, additional_neighbor=1
)
self.nn_m_.set_params(**{"n_jobs": self.n_jobs})

if self.svm_estimator is None:
self.svm_estimator_ = SVC(gamma="scale", random_state=self.random_state)
elif isinstance(self.svm_estimator, SVC):
self.svm_estimator_ = clone(self.svm_estimator)
else:
raise_isinstance_error("svm_estimator", [SVC], self.svm_estimator)
self.svm_estimator_ = clone(self.svm_estimator)
Review comment by NV-jpt (Contributor Author), Sep 13, 2021:

This change removes the explicit isinstance check that validated the SVC estimator in SVMSMOTE's _validate_estimator method; the estimator is instead validated via sklearn.base.clone(), similar to KMeansSMOTE.

This enables the integration of SVM estimators that follow the same API contract as scikit-learn, instead of requiring an explicit class check (isinstance(svm_estimator, sklearn.svm.SVC)).

As a motivating example, the integration of a GPU-accelerated SVC from cuML can offer significant performance gains when working with large datasets.

image

Hardware Specs for the Loose Benchmark:
Intel Xeon E5-2698, 2.2 GHz, 16-cores & NVIDIA V100 32 GB GPU

Benchmarking gist:
https://gist.github.com/NV-jpt/039a8d9c7d37365379faa1d7c7aafc5e
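The validation-by-cloning idea described in this comment can be sketched without scikit-learn installed. In the sketch below, `clone_like` is a deliberately simplified stand-in for `sklearn.base.clone` (which likewise raises a `TypeError` for objects that do not implement `get_params`), and `ToySVC`, `NotAnEstimator`, and `validate_svm_estimator` are hypothetical names used only for illustration.

```python
class ToySVC:
    """Hypothetical estimator that follows the get_params convention."""

    def __init__(self, C=1.0):
        self.C = C

    def get_params(self, deep=True):
        return {"C": self.C}


class NotAnEstimator:
    """Hypothetical object with no estimator API at all."""


def clone_like(estimator):
    # Simplified stand-in for sklearn.base.clone: rebuild an unfitted copy
    # from the constructor parameters, or fail if they cannot be read.
    if not hasattr(estimator, "get_params"):
        raise TypeError(
            f"Cannot clone {estimator!r}: it does not implement 'get_params'."
        )
    return type(estimator)(**estimator.get_params())


def validate_svm_estimator(svm_estimator, default_factory):
    # No isinstance check against a concrete class: any object that can be
    # cloned is accepted, and None falls back to the default estimator.
    if svm_estimator is None:
        return default_factory()
    return clone_like(svm_estimator)
```

The design choice is that the class of the estimator no longer matters, only whether it honours the estimator contract, which is what allows a third-party SVC to be passed in.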


def _fit_resample(self, X, y):
self._validate_estimator()
@@ -403,6 +428,12 @@ def _fit_resample(self, X, y):
X_class = _safe_indexing(X, target_class_indices)

self.svm_estimator_.fit(X, y)
if not hasattr(self.svm_estimator_, "support_"):
raise RuntimeError(
"`svm_estimator` is required to exposed a `support_` fitted "
"attribute. Such estimator belongs to the familly of Support "
"Vector Machine."
)
support_index = self.svm_estimator_.support_[
y[self.svm_estimator_.support_] == class_sample
]
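Because `support_` is a fitted attribute, it can only be checked once `fit` has run, which is presumably why this check lives in `_fit_resample` rather than in `_validate_estimator`. A minimal sketch of that pattern follows; the classes `ToySVM` and `ToyLinearModel` and the helper `fitted_support_indices` are hypothetical and only mimic the behaviour shown in the hunk above.

```python
class ToySVM:
    """Hypothetical estimator that exposes `support_` after fitting."""

    def fit(self, X, y):
        # Pretend the first and last samples are the support vectors.
        self.support_ = [0, len(X) - 1]
        return self


class ToyLinearModel:
    """Hypothetical estimator with no notion of support vectors."""

    def fit(self, X, y):
        self.coef_ = [0.0]
        return self


def fitted_support_indices(estimator, X, y):
    # Fit first: the attribute does not exist on an unfitted estimator.
    estimator.fit(X, y)
    if not hasattr(estimator, "support_"):
        raise RuntimeError(
            "The estimator is required to expose a `support_` fitted attribute."
        )
    return list(estimator.support_)
```

With this approach an incompatible estimator is rejected only after it has been fitted once, which trades an early failure for not having to name the accepted classes.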
14 changes: 0 additions & 14 deletions imblearn/over_sampling/_smote/tests/test_smote.py
@@ -4,15 +4,12 @@
# License: MIT

import numpy as np
import pytest

from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_array_equal
from sklearn.neighbors import NearestNeighbors

from imblearn.over_sampling import SMOTE
from imblearn.over_sampling import SVMSMOTE
from imblearn.over_sampling import BorderlineSMOTE


RND_SEED = 0
@@ -153,14 +150,3 @@ def test_sample_regular_with_nn():
)
assert_allclose(X_resampled, X_gt, rtol=R_TOL)
assert_array_equal(y_resampled, y_gt)


@pytest.mark.parametrize(
"smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"]
)
def test_smote_m_neighbors(smote):
# check that m_neighbors is properly set. Regression test for:
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
_ = smote.fit_resample(X, Y)
assert smote.nn_k_.n_neighbors == 6
assert smote.nn_m_.n_neighbors == 11
10 changes: 10 additions & 0 deletions imblearn/over_sampling/_smote/tests/test_svm_smote.py
@@ -1,6 +1,7 @@
import pytest
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors
from sklearn.svm import SVC

@@ -54,3 +55,12 @@ def test_svm_smote(data):

assert_allclose(X_res_1, X_res_2)
assert_array_equal(y_res_1, y_res_2)


def test_svm_smote_not_svm(data):
"""Check that we raise a proper error if passing an estimator that does not
expose a `support_` fitted attribute."""

err_msg = "`svm_estimator` is required to exposed a `support_` fitted attribute."
with pytest.raises(RuntimeError, match=err_msg):
SVMSMOTE(svm_estimator=LogisticRegression()).fit_resample(*data)
6 changes: 5 additions & 1 deletion imblearn/over_sampling/tests/test_adasyn.py
@@ -131,7 +131,11 @@ def test_ada_fit_resample_nn_obj():
{"sampling_strategy": {0: 9, 1: 12}},
"No samples will be generated.",
),
({"n_neighbors": "rnd"}, "has to be one of"),
(
{"n_neighbors": "rnd"},
"n_neighbors must be an interger or an object compatible with the "
"KNeighborsMixin API of scikit-learn",
),
],
)
def test_adasyn_error(adasyn_params, err_msg):
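The updated error message suggests a validation that accepts either an integer or any object exposing a neighbors API. The sketch below shows one way such a dispatch could look; it is not the code from this PR. The names `check_neighbors_sketch` and `ToyNeighbors`, the `ValueError` exception type, and returning the object as-is instead of cloning it are all assumptions made for illustration.

```python
from numbers import Integral


class ToyNeighbors:
    """Hypothetical placeholder standing in for a nearest-neighbors estimator."""

    def __init__(self, n_neighbors=5):
        self.n_neighbors = n_neighbors

    def kneighbors(self, X):
        raise NotImplementedError("placeholder only")

    def kneighbors_graph(self, X):
        raise NotImplementedError("placeholder only")


def check_neighbors_sketch(nn_name, nn_object, additional_neighbor=0):
    # Integer: build a default estimator, optionally asking for extra neighbors.
    if isinstance(nn_object, Integral):
        return ToyNeighbors(n_neighbors=nn_object + additional_neighbor)
    # Otherwise accept anything that exposes the two required methods.
    required = ("kneighbors", "kneighbors_graph")
    if all(callable(getattr(nn_object, name, None)) for name in required):
        return nn_object
    # Exception type is an assumption; the message is adapted from the test above.
    raise ValueError(
        f"{nn_name} must be an integer or an object compatible with the "
        "KNeighborsMixin API of scikit-learn"
    )
```

With `additional_neighbor=1`, an integer of 5 yields an estimator configured for 6 neighbors, which is consistent with the removed `test_smote_m_neighbors` assertions (`nn_k_.n_neighbors == 6`, `nn_m_.n_neighbors == 11`).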