Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Development #6

Merged
merged 14 commits into from
Nov 17, 2023
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
"myst-parser" # parse md and rst files
]

dev_packages = [
"pytest"
]

setuptools.setup(
name=NAME,
version=version["__version__"],
Expand All @@ -61,6 +65,7 @@
install_requires=base_packages + plotting_packages,
extras_require={
"docs": base_packages + plotting_packages + doc_packages,
"dev": base_packages + plotting_packages + doc_packages + dev_packages
},
include_package_data=True,
license="MIT",
Expand Down
2 changes: 1 addition & 1 deletion shapiq/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.2"
__version__ = "0.0.3"
7 changes: 7 additions & 0 deletions shapiq/approximator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""This module contains the approximators to estimate the Shapley interaction values."""
from .permutation.sii import PermutationSamplingSII
from .permutation.sti import PermutationSamplingSTI

__all__ = [
"PermutationSamplingSII",
]
182 changes: 182 additions & 0 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""This module contains the base approximator classes for the shapiq package."""
import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Union, Optional

import numpy as np


AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"}


@dataclass
class InteractionValues:
""" This class contains the interaction values as estimated by an approximator.

Attributes:
values: The interaction values of the model. Mapping from order to the interaction values.
index: The interaction index estimated. Available indices are 'SII', 'nSII', 'STI', and
'FSI'.
order: The order of the approximation.
"""
values: dict[int, np.ndarray]
index: str
order: int

def __post_init__(self) -> None:
"""Checks if the index is valid."""
if self.index not in ["SII", "nSII", "STI", "FSI"]:
raise ValueError(
f"Index {self.index} is not valid. "
f"Available indices are 'SII', 'nSII', 'STI', and 'FSI'."
)
if self.order < 1 or self.order != max(self.values.keys()):
raise ValueError(
f"Order {self.order} is not valid. "
f"Order should be a positive integer equal to the maximum key of the values."
)


class Approximator(ABC):
"""This class is the base class for all approximators.

Approximators are used to estimate the interaction values of a model or any value function.
Different approximators can be used to estimate different interaction indices. Some can be used
to estimate all indices.

Args:
n: The number of players.
max_order: The interaction order of the approximation.
index: The interaction index to be estimated. Available indices are 'SII', 'nSII', 'STI',
and 'FSI'.
top_order: If True, the approximation is performed only for the top order interactions. If
False, the approximation is performed for all orders up to the specified order.
random_state: The random state to use for the approximation. Defaults to None.

Attributes:
n: The number of players.
N: The set of players (starting from 0 to n - 1).
max_order: The interaction order of the approximation.
index: The interaction index to be estimated.
top_order: If True, the approximation is performed only for the top order interactions. If
False, the approximation is performed for all orders up to the specified order.
min_order: The minimum order of the approximation. If top_order is True, min_order is equal
to max_order. Otherwise, min_order is equal to 1.
"""

def __init__(
self,
n: int,
max_order: int,
index: str,
top_order: bool,
random_state: Optional[int] = None
) -> None:
"""Initializes the approximator."""
self.index: str = index
if self.index not in AVAILABLE_INDICES:
raise ValueError(
f"Index {self.index} is not valid. "
f"Available indices are {AVAILABLE_INDICES}."
)
self.n: int = n
self.N: set = set(range(self.n))
self.max_order: int = max_order
self.top_order: bool = top_order
self.min_order: int = self.max_order if self.top_order else 1
self._random_state: Optional[int] = random_state
self._rng: Optional[np.random.Generator] = np.random.default_rng(seed=self._random_state)

@abstractmethod
def approximate(
self,
budget: int,
game: Callable[[Union[set, tuple]], float],
*args, **kwargs
) -> InteractionValues:
"""Approximates the interaction values. Abstract method that needs to be implemented for
each approximator.

Args:
budget: The budget for the approximation.
game: The game function.

Returns:
The interaction values.

Raises:
NotImplementedError: If the method is not implemented.
"""
raise NotImplementedError

def _init_result(self, dtype=float) -> dict[int, np.ndarray]:
"""Initializes the result dictionary mapping from order to the interaction values.
For order 1 the interaction values are of shape (n,) for order 2 of shape (n, n) and so on.

Args:
dtype: The data type of the result dictionary values. Defaults to float.

Returns:
The result dictionary.
"""
result = {s: self._get_empty_array(self.n, s, dtype=dtype)
for s in self._order_iterator}
return result

@staticmethod
def _get_empty_array(n: int, order: int, dtype=float) -> np.ndarray:
"""Returns an empty array of the appropriate shape for the given order.

Args:
n: The number of players.
order: The order of the array.
dtype: The data type of the array. Defaults to float.

Returns:
The empty array.
"""
return np.zeros(n ** order, dtype=dtype).reshape((n,) * order)

@property
def _order_iterator(self) -> range:
"""Returns an iterator over the orders of the approximation.

Returns:
The iterator.
"""
return range(self.min_order, self.max_order + 1)

def _finalize_result(self, result) -> InteractionValues:
"""Finalizes the result dictionary.

Args:
result: The result dictionary.

Returns:
The interaction values.
"""
return InteractionValues(result, self.index, self.max_order)

@staticmethod
def _smooth_with_epsilon(
interaction_results: Union[dict, np.ndarray],
eps=0.00001
) -> Union[dict, np.ndarray]:
"""Smooth the interaction results with a small epsilon to avoid numerical issues.

Args:
interaction_results: Interaction results.
eps: Small epsilon. Defaults to 0.00001.

Returns:
Union[dict, np.ndarray]: Smoothed interaction results.
"""
if not isinstance(interaction_results, dict):
interaction_results[np.abs(interaction_results) < eps] = 0
return copy.deepcopy(interaction_results)
interactions = {}
for interaction_order, interaction_values in interaction_results.items():
interaction_values[np.abs(interaction_values) < eps] = 0
interactions[interaction_order] = interaction_values
return copy.deepcopy(interactions)
10 changes: 10 additions & 0 deletions shapiq/approximator/permutation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""This module contains all permutation-based sampling algorithms to estimate SII/nSII and STI."""
from ._base import PermutationSampling
from .sii import PermutationSamplingSII
from .sti import PermutationSamplingSTI

__all__ = [
"PermutationSampling",
"PermutationSamplingSII",
"PermutationSamplingSTI",
]
78 changes: 78 additions & 0 deletions shapiq/approximator/permutation/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""This module contains the base permutation sampling algorithms to estimate SII/nSII and STI."""
from typing import Optional, Callable, Union

import numpy as np

from approximator._base import Approximator, InteractionValues


AVAILABLE_INDICES_PERMUTATION = {"SII", "nSII", "STI"}


class PermutationSampling(Approximator):
"""Permutation sampling approximator. This class contains the permutation sampling algorithm to
estimate SII/nSII and STI values.

Args:
n: The number of players.
max_order: The interaction order of the approximation.
index: The interaction index to be estimated. Available indices are 'SII', 'nSII', and
'STI'.
top_order: Whether to approximate only the top order interactions (`True`) or all orders up
to the specified order (`False`).

Attributes:
n (int): The number of players.
N (set): The set of players (starting from 0 to n - 1).
max_order (int): The interaction order of the approximation.
index (str): The interaction index to be estimated.
top_order (bool): Whether to approximate only the top order interactions or all orders up to
the specified order.
min_order (int): The minimum order of the approximation. If top_order is True, min_order is
equal to order. Otherwise, min_order is equal to 1.

"""

def __init__(
self,
n: int,
max_order: int,
index: str,
top_order: bool,
random_state: Optional[int] = None
) -> None:
if index not in AVAILABLE_INDICES_PERMUTATION:
raise ValueError(
f"Index {index} is not valid. "
f"Available indices are {AVAILABLE_INDICES_PERMUTATION}."
)
super().__init__(n, max_order, index, top_order, random_state)

def approximate(
self,
budget: int,
game: Callable[[Union[set, tuple]], float]
) -> InteractionValues:
"""Approximates the interaction values."""
raise NotImplementedError

@staticmethod
def _get_n_iterations(budget: int, batch_size: int, iteration_cost: int) -> tuple[int, int]:
"""Computes the number of iterations and the size of the last batch given the batch size and
the budget.

Args:
budget: The budget for the approximation.
batch_size: The size of the batch.
iteration_cost: The cost of a single iteration.

Returns:
int, int: The number of iterations and the size of the last batch.
"""
n_iterations = budget // (iteration_cost * batch_size)
last_batch_size = batch_size
remaining_budget = budget - n_iterations * iteration_cost * batch_size
if remaining_budget > 0 and remaining_budget // iteration_cost > 0:
last_batch_size = remaining_budget // iteration_cost
n_iterations += 1
return n_iterations, last_batch_size
Loading