Skip to content

Commit

Permalink
ENH add get_qoi and get_qoi_names (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopfonseca committed Feb 29, 2024
1 parent adbace6 commit aae93e2
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 9 deletions.
7 changes: 7 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ This is the full API documentation of the `sharp` package.
qoi.RankScoreQoI
qoi.TopKQoI

.. autosummary::
:toctree: _generated/
:template: function.rst

qoi.get_qoi
qoi.get_qoi_names


:mod:`sharp.visualization`
--------------------------
Expand Down
6 changes: 4 additions & 2 deletions sharp/qoi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
RankQoI,
RankScoreQoI,
TopKQoI,
QOI_OBJECTS,
get_qoi,
get_qoi_names,
)

__all__ = [
Expand All @@ -19,5 +20,6 @@
"RankQoI",
"RankScoreQoI",
"TopKQoI",
"QOI_OBJECTS",
"get_qoi",
"get_qoi_names",
]
65 changes: 64 additions & 1 deletion sharp/qoi/_qoi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from .base import BaseQoI, BaseRankQoI


Expand Down Expand Up @@ -138,11 +139,73 @@ def _calculate(self, rows1, rows2):
return (self.estimate(rows1) - self.estimate(rows2)).mean()


QOI_OBJECTS = {
_QOI_OBJECTS = {
"diff": DiffQoI,
"flip": FlipQoI,
"likelihood": LikelihoodQoI,
"rank": RankQoI,
"rank_score": RankScoreQoI,
"top_k": TopKQoI,
}


def get_qoi_names():
"""Get the names of all available quantities of interest.
These names can be passed to :func:`~sharp.qoi.get_qoi` to
retrieve the QoI object.
Returns
-------
list of str
Names of all available quantities of interest.
Examples
--------
>>> from sharp.qoi import get_qoi_names
>>> all_qois = get_qoi_names()
>>> type(all_qois)
<class 'list'>
>>> all_qois[:3]
['diff', 'flip', 'likelihood']
>>> "ranking" in all_qois
True
"""
return sorted(_QOI_OBJECTS.keys())


def get_qoi(qoi):
"""Get a quantity of interest from string.
:func:`~sharp.qoi.get_qoi_names` can be used to retrieve the names
of all available quantities of interest.
Parameters
----------
qoi : str, callable or None
Quantity of interest as string. If callable it is returned as is.
If None, returns None.
Returns
-------
quantity : callable
The quantity of interest.
Notes
-----
When passed a string, this function always returns a copy of the scorer
object. Calling `get_qoi` twice for the same scorer results in two
separate QoI objects.
"""
if isinstance(qoi, str):
try:
quantity = copy.deepcopy(_QOI_OBJECTS[qoi])
except KeyError:
raise ValueError(
"%r is not a valid scoring value. "
"Use sklearn.metrics.get_scorer_names() "
"to get valid options." % qoi
)
else:
quantity = qoi
return quantity
6 changes: 3 additions & 3 deletions sharp/tests/test_basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from sklearn.utils import check_random_state
from sharp import ShaRP
from sharp.qoi import QOI_OBJECTS
from sharp.qoi import get_qoi
from sharp._measures import MEASURES

# Set up some envrionment variables
Expand All @@ -12,10 +12,10 @@
rng = check_random_state(RNG_SEED)

rank_qois_str = ["rank", "rank_score", "top_k"]
rank_qois_obj = [QOI_OBJECTS[qoi] for qoi in rank_qois_str]
rank_qois_obj = [get_qoi(qoi) for qoi in rank_qois_str]

clf_qois_str = ["diff", "flip", "likelihood"]
clf_qois_obj = [QOI_OBJECTS[qoi] for qoi in clf_qois_str]
clf_qois_obj = [get_qoi(qoi) for qoi in clf_qois_str]

measures = list(MEASURES.keys())

Expand Down
6 changes: 3 additions & 3 deletions sharp/utils/_checks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from sklearn.utils.validation import check_array, _get_feature_names

from sharp.qoi import QOI_OBJECTS
from sharp.qoi import get_qoi
from sharp._measures import MEASURES


Expand Down Expand Up @@ -53,7 +53,7 @@ def check_qoi(qoi, target_function=None, X=None):
msg = "If `qoi` is of type `str`, `target_function` cannot be None."
raise TypeError(msg)

if QOI_OBJECTS[qoi]._qoi_type == "rank":
if get_qoi(qoi)._qoi_type == "rank":
# Add dataset to list of parameters if QoI is rank-based
params["X"] = X

Expand All @@ -72,5 +72,5 @@ def check_qoi(qoi, target_function=None, X=None):
else:
return qoi

qoi = QOI_OBJECTS[qoi](**params)
qoi = get_qoi(qoi)(**params)
return qoi

0 comments on commit aae93e2

Please sign in to comment.