FIX make sure that FunctionSampler will bypass validation in fit #790

Merged
merged 2 commits into from
Feb 9, 2021
4 changes: 4 additions & 0 deletions doc/whats_new/v0.7.rst
@@ -54,6 +54,10 @@ Bug fixes
  the targeted class.
  :pr:`769` by :user:`Guillaume Lemaitre <glemaitre>`.

- Fix a bug in :class:`imblearn.FunctionSampler` where validation was performed
  even with `validate=False` when calling `fit`.
  :pr:`790` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
............

32 changes: 32 additions & 0 deletions imblearn/base.py
@@ -220,6 +220,38 @@ def __init__(self, *, func=None, accept_sparse=True, kw_args=None,
        self.kw_args = kw_args
        self.validate = validate

    def fit(self, X, y):
        """Check inputs and statistics of the sampler.

        You should use ``fit_resample`` in all cases.

        Parameters
        ----------
        X : {array-like, dataframe, sparse matrix} of shape \
                (n_samples, n_features)
            Data array.

        y : array-like of shape (n_samples,)
            Target array.

        Returns
        -------
        self : object
            Return the instance itself.
        """
        # we need to overwrite SamplerMixin.fit to bypass the validation
        if self.validate:
            check_classification_targets(y)
            X, y, _ = self._check_X_y(
                X, y, accept_sparse=self.accept_sparse
            )

        self.sampling_strategy_ = check_sampling_strategy(
            self.sampling_strategy, y, self._sampling_type
        )

        return self

    def fit_resample(self, X, y):
        """Resample the dataset.

15 changes: 15 additions & 0 deletions imblearn/tests/test_base.py
@@ -94,3 +94,18 @@ def dummy_sampler(X, y):
    y_pred = pipeline.fit(X, y).predict(X)

    assert type_of_target(y_pred) == 'continuous'


def test_function_resampler_fit():
    # Check that the validation is bypassed when calling `fit`
    # Non-regression test for:
    # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/782
    X = np.array([[1, np.nan], [2, 3], [np.inf, 4]])
    y = np.array([0, 1, 1])

    def func(X, y):
        return X[:1], y[:1]

    sampler = FunctionSampler(func=func, validate=False)
    sampler.fit(X, y)
    sampler.fit_resample(X, y)
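
For readers who want to see the override pattern in isolation, here is a minimal sketch that runs without imbalanced-learn or NumPy. The classes `BaseSampler` and `PassThroughSampler`, the helper `first_row`, and the simplified `_check_X_y` are hypothetical stand-ins invented for this illustration; they are not the library's implementation and only mimic the idea that an inherited `fit` validates unconditionally unless a subclass overrides it.

```python
import math


class BaseSampler:
    """Hypothetical stand-in for a mixin whose ``fit`` always validates."""

    def _check_X_y(self, X, y):
        # Reject NaN and infinite values, loosely mimicking input validation.
        for row in X:
            for value in row:
                if not math.isfinite(value):
                    raise ValueError("Input contains NaN or infinity.")
        return X, y

    def fit(self, X, y):
        X, y = self._check_X_y(X, y)
        return self


class PassThroughSampler(BaseSampler):
    """Hypothetical sampler that validates only when ``validate`` is true."""

    def __init__(self, func, validate=True):
        self.func = func
        self.validate = validate

    def fit(self, X, y):
        # Override the inherited ``fit`` so that validation becomes optional.
        if self.validate:
            X, y = self._check_X_y(X, y)
        return self

    def fit_resample(self, X, y):
        self.fit(X, y)
        return self.func(X, y)


def first_row(X, y):
    # Same shape of user function as in the regression test: keep one sample.
    return X[:1], y[:1]


X = [[1.0, math.nan], [2.0, 3.0], [math.inf, 4.0]]
y = [0, 1, 1]

# With validate=False, non-finite values reach the user function untouched.
sampler = PassThroughSampler(func=first_row, validate=False)
X_res, y_res = sampler.fit(X, y).fit_resample(X, y)

# With validate=True, the same data is rejected during ``fit``.
try:
    PassThroughSampler(func=first_row, validate=True).fit(X, y)
    strict_fit_raised = False
except ValueError:
    strict_fit_raised = True
```

Without the overridden `fit` in the subclass, the `validate=False` call would fall through to the base class and raise on the NaN and infinite entries, which is the kind of behavior the regression test above guards against.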