Skip to content

Commit

Permalink
Allow using a relative score_threshold
Browse files Browse the repository at this point in the history
This allows stopping the selection early when identical points are fed to
FPS or another selector, by using something like score_threshold=1e-12
together with score_threshold_kind="relative".
  • Loading branch information
Luthaf committed May 16, 2022
1 parent 18bccdf commit 7b9b775
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 6 deletions.
41 changes: 35 additions & 6 deletions skcosmo/_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class GreedySelector(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand Down Expand Up @@ -86,6 +92,7 @@ def __init__(
selection_type,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -94,6 +101,12 @@ def __init__(
self.selection_type = selection_type
self.n_to_select = n_to_select
self.score_threshold = score_threshold
self.score_threshold_kind = score_threshold_kind
self._first_score = None
if self.score_threshold_kind not in ["relative", "absolute"]:
raise ValueError(
"invalid score_threshold_kind, expected one of 'relative' or 'absolute'"
)

self.full = full
self.progress_bar = progress_bar
Expand Down Expand Up @@ -352,14 +365,22 @@ def _continue_greedy_search(self, X, y, n_to_select):
self.selected_idx_[: self.n_selected_] = old_idx

def _get_best_new_selection(self, scorer, X, y):
    """
    Return the index of the best-scoring candidate, or ``None`` to signal
    that the selection should stop.

    Parameters
    ----------
    scorer : callable
        Called as ``scorer(X, y)``; its result is passed to ``np.argmax``
        and indexed with the resulting position.
    X, y
        Forwarded to ``scorer`` unchanged.

    Returns
    -------
    index or None
        Position of the highest score. ``None`` is returned when
        ``self.score_threshold`` is set and the best score (for
        ``score_threshold_kind == "absolute"``) or the ratio of the best
        score to the first recorded best score (for ``"relative"``) is
        below the threshold.
    """
    scores = scorer(X, y)
    max_score_idx = np.argmax(scores)
    best_score = scores[max_score_idx]

    # The first best score ever seen is the reference for "relative"
    # thresholds; it is recorded even when no threshold is configured.
    if self._first_score is None:
        self._first_score = best_score

    if self.score_threshold is None:
        return max_score_idx

    if self.score_threshold_kind == "absolute":
        if best_score < self.score_threshold:
            return None
    elif self.score_threshold_kind == "relative":
        # A zero reference makes the ratio undefined (0/0 -> nan, which never
        # compares below the threshold). Skip the division in that case so
        # the outcome is the same without a floating-point warning or a
        # ZeroDivisionError.
        # NOTE(review): confirm that "keep selecting" is the desired
        # behaviour when the very first score is exactly zero.
        if (
            self._first_score != 0
            and best_score / self._first_score < self.score_threshold
        ):
            return None

    return max_score_idx

def _update_post_selection(self, X, y, last_selected):
"""
Expand Down Expand Up @@ -448,6 +469,7 @@ def __init__(
tolerance=1e-12,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -461,6 +483,7 @@ def __init__(
selection_type=selection_type,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -639,6 +662,7 @@ def __init__(
tolerance=1e-12,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -653,6 +677,7 @@ def __init__(
selection_type=selection_type,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -848,6 +873,7 @@ def __init__(
initialize=0,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -858,6 +884,7 @@ def __init__(
selection_type=selection_type,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -1009,6 +1036,7 @@ def __init__(
initialize=0,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -1026,6 +1054,7 @@ def __init__(
selection_type=selection_type,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down
32 changes: 32 additions & 0 deletions skcosmo/feature_selection/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ class FPS(_FPS):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand All @@ -57,6 +63,7 @@ def __init__(
initialize=0,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -66,6 +73,7 @@ def __init__(
initialize=initialize,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -98,6 +106,12 @@ class PCovFPS(_PCovFPS):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand All @@ -124,6 +138,7 @@ def __init__(
initialize=0,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -134,6 +149,7 @@ def __init__(
initialize=initialize,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -168,6 +184,12 @@ class CUR(_CUR):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand Down Expand Up @@ -198,6 +220,7 @@ def __init__(
tolerance=1e-12,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -209,6 +232,7 @@ def __init__(
tolerance=tolerance,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -247,6 +271,12 @@ class PCovCUR(_PCovCUR):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand Down Expand Up @@ -282,6 +312,7 @@ def __init__(
tolerance=1e-12,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -294,6 +325,7 @@ def __init__(
tolerance=tolerance,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down
32 changes: 32 additions & 0 deletions skcosmo/sample_selection/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ class FPS(_FPS):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand Down Expand Up @@ -59,6 +65,7 @@ def __init__(
initialize=0,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -68,6 +75,7 @@ def __init__(
initialize=initialize,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -100,6 +108,12 @@ class PCovFPS(_PCovFPS):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand Down Expand Up @@ -129,6 +143,7 @@ def __init__(
initialize=0,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -139,6 +154,7 @@ def __init__(
initialize=initialize,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -172,6 +188,12 @@ class CUR(_CUR):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand Down Expand Up @@ -204,6 +226,7 @@ def __init__(
tolerance=1e-12,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -215,6 +238,7 @@ def __init__(
tolerance=tolerance,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down Expand Up @@ -254,6 +278,12 @@ class PCovCUR(_PCovCUR):
n_to_select is chosen. Otherwise will stop when the score falls below the threshold.
Stored in :py:attr:`self.score_threshold`.
score_threshold_kind : str, default="absolute"
How to interpret the ``score_threshold``. "absolute" means the score is compared
directly to the threshold, and "relative" means the ratio of the current score
to the first score is compared to the threshold.
Stored in :py:attr:`self.score_threshold_kind`.
progress_bar: bool, default=False
option to use `tqdm <https://tqdm.github.io/>`_
progress bar to monitor selections. Stored in :py:attr:`self.report_progress`.
Expand Down Expand Up @@ -291,6 +321,7 @@ def __init__(
tolerance=1e-12,
n_to_select=None,
score_threshold=None,
score_threshold_kind="absolute",
progress_bar=False,
full=False,
random_state=0,
Expand All @@ -303,6 +334,7 @@ def __init__(
tolerance=tolerance,
n_to_select=n_to_select,
score_threshold=score_threshold,
score_threshold_kind=score_threshold_kind,
progress_bar=progress_bar,
full=full,
random_state=random_state,
Expand Down

0 comments on commit 7b9b775

Please sign in to comment.