diff --git a/setup.py b/setup.py index a91a8c09..21e3da8c 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,10 @@ "myst-parser" # parse md and rst files ] +dev_packages = [ + "pytest" +] + setuptools.setup( name=NAME, version=version["__version__"], @@ -61,6 +65,7 @@ install_requires=base_packages + plotting_packages, extras_require={ "docs": base_packages + plotting_packages + doc_packages, + "dev": base_packages + plotting_packages + doc_packages + dev_packages }, include_package_data=True, license="MIT", diff --git a/shapiq/__version__.py b/shapiq/__version__.py index 3b93d0be..27fdca49 100644 --- a/shapiq/__version__.py +++ b/shapiq/__version__.py @@ -1 +1 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/shapiq/approximator/__init__.py b/shapiq/approximator/__init__.py new file mode 100644 index 00000000..7e6e221d --- /dev/null +++ b/shapiq/approximator/__init__.py @@ -0,0 +1,7 @@ +"""This module contains the approximators to estimate the Shapley interaction values.""" +from .permutation.sii import PermutationSamplingSII +from .permutation.sti import PermutationSamplingSTI + +__all__ = [ + "PermutationSamplingSII", +] diff --git a/shapiq/approximator/_base.py b/shapiq/approximator/_base.py new file mode 100644 index 00000000..49f4850e --- /dev/null +++ b/shapiq/approximator/_base.py @@ -0,0 +1,182 @@ +"""This module contains the base approximator classes for the shapiq package.""" +import copy +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Union, Optional + +import numpy as np + + +AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"} + + +@dataclass +class InteractionValues: + """ This class contains the interaction values as estimated by an approximator. + + Attributes: + values: The interaction values of the model. Mapping from order to the interaction values. + index: The interaction index estimated. Available indices are 'SII', 'nSII', 'STI', and + 'FSI'. + order: The order of the approximation. + """ + values: dict[int, np.ndarray] + index: str + order: int + + def __post_init__(self) -> None: + """Checks if the index is valid.""" + if self.index not in ["SII", "nSII", "STI", "FSI"]: + raise ValueError( + f"Index {self.index} is not valid. " + f"Available indices are 'SII', 'nSII', 'STI', and 'FSI'." + ) + if self.order < 1 or self.order != max(self.values.keys()): + raise ValueError( + f"Order {self.order} is not valid. " + f"Order should be a positive integer equal to the maximum key of the values." + ) + + +class Approximator(ABC): + """This class is the base class for all approximators. + + Approximators are used to estimate the interaction values of a model or any value function. + Different approximators can be used to estimate different interaction indices. Some can be used + to estimate all indices. + + Args: + n: The number of players. + max_order: The interaction order of the approximation. + index: The interaction index to be estimated. Available indices are 'SII', 'nSII', 'STI', + and 'FSI'. + top_order: If True, the approximation is performed only for the top order interactions. If + False, the approximation is performed for all orders up to the specified order. + random_state: The random state to use for the approximation. Defaults to None. + + Attributes: + n: The number of players. + N: The set of players (starting from 0 to n - 1). + max_order: The interaction order of the approximation. + index: The interaction index to be estimated. + top_order: If True, the approximation is performed only for the top order interactions. If + False, the approximation is performed for all orders up to the specified order. + min_order: The minimum order of the approximation. If top_order is True, min_order is equal + to max_order. Otherwise, min_order is equal to 1. + """ + + def __init__( + self, + n: int, + max_order: int, + index: str, + top_order: bool, + random_state: Optional[int] = None + ) -> None: + """Initializes the approximator.""" + self.index: str = index + if self.index not in AVAILABLE_INDICES: + raise ValueError( + f"Index {self.index} is not valid. " + f"Available indices are {AVAILABLE_INDICES}." + ) + self.n: int = n + self.N: set = set(range(self.n)) + self.max_order: int = max_order + self.top_order: bool = top_order + self.min_order: int = self.max_order if self.top_order else 1 + self._random_state: Optional[int] = random_state + self._rng: Optional[np.random.Generator] = np.random.default_rng(seed=self._random_state) + + @abstractmethod + def approximate( + self, + budget: int, + game: Callable[[Union[set, tuple]], float], + *args, **kwargs + ) -> InteractionValues: + """Approximates the interaction values. Abstract method that needs to be implemented for + each approximator. + + Args: + budget: The budget for the approximation. + game: The game function. + + Returns: + The interaction values. + + Raises: + NotImplementedError: If the method is not implemented. + """ + raise NotImplementedError + + def _init_result(self, dtype=float) -> dict[int, np.ndarray]: + """Initializes the result dictionary mapping from order to the interaction values. + For order 1 the interaction values are of shape (n,) for order 2 of shape (n, n) and so on. + + Args: + dtype: The data type of the result dictionary values. Defaults to float. + + Returns: + The result dictionary. + """ + result = {s: self._get_empty_array(self.n, s, dtype=dtype) + for s in self._order_iterator} + return result + + @staticmethod + def _get_empty_array(n: int, order: int, dtype=float) -> np.ndarray: + """Returns an empty array of the appropriate shape for the given order. + + Args: + n: The number of players. + order: The order of the array. + dtype: The data type of the array. Defaults to float. + + Returns: + The empty array. + """ + return np.zeros(n ** order, dtype=dtype).reshape((n,) * order) + + @property + def _order_iterator(self) -> range: + """Returns an iterator over the orders of the approximation. + + Returns: + The iterator. + """ + return range(self.min_order, self.max_order + 1) + + def _finalize_result(self, result) -> InteractionValues: + """Finalizes the result dictionary. + + Args: + result: The result dictionary. + + Returns: + The interaction values. + """ + return InteractionValues(result, self.index, self.max_order) + + @staticmethod + def _smooth_with_epsilon( + interaction_results: Union[dict, np.ndarray], + eps=0.00001 + ) -> Union[dict, np.ndarray]: + """Smooth the interaction results with a small epsilon to avoid numerical issues. + + Args: + interaction_results: Interaction results. + eps: Small epsilon. Defaults to 0.00001. + + Returns: + Union[dict, np.ndarray]: Smoothed interaction results. + """ + if not isinstance(interaction_results, dict): + interaction_results[np.abs(interaction_results) < eps] = 0 + return copy.deepcopy(interaction_results) + interactions = {} + for interaction_order, interaction_values in interaction_results.items(): + interaction_values[np.abs(interaction_values) < eps] = 0 + interactions[interaction_order] = interaction_values + return copy.deepcopy(interactions) diff --git a/shapiq/approximator/permutation/__init__.py b/shapiq/approximator/permutation/__init__.py new file mode 100644 index 00000000..c7067199 --- /dev/null +++ b/shapiq/approximator/permutation/__init__.py @@ -0,0 +1,10 @@ +"""This module contains all permutation-based sampling algorithms to estimate SII/nSII and STI.""" +from ._base import PermutationSampling +from .sii import PermutationSamplingSII +from .sti import PermutationSamplingSTI + +__all__ = [ + "PermutationSampling", + "PermutationSamplingSII", + "PermutationSamplingSTI", +] diff --git a/shapiq/approximator/permutation/_base.py b/shapiq/approximator/permutation/_base.py new file mode 100644 index 00000000..f16a8ae3 --- /dev/null +++ b/shapiq/approximator/permutation/_base.py @@ -0,0 +1,78 @@ +"""This module contains the base permutation sampling algorithms to estimate SII/nSII and STI.""" +from typing import Optional, Callable, Union + +import numpy as np + +from approximator._base import Approximator, InteractionValues + + +AVAILABLE_INDICES_PERMUTATION = {"SII", "nSII", "STI"} + + +class PermutationSampling(Approximator): + """Permutation sampling approximator. This class contains the permutation sampling algorithm to + estimate SII/nSII and STI values. + + Args: + n: The number of players. + max_order: The interaction order of the approximation. + index: The interaction index to be estimated. Available indices are 'SII', 'nSII', and + 'STI'. + top_order: Whether to approximate only the top order interactions (`True`) or all orders up + to the specified order (`False`). + + Attributes: + n (int): The number of players. + N (set): The set of players (starting from 0 to n - 1). + max_order (int): The interaction order of the approximation. + index (str): The interaction index to be estimated. + top_order (bool): Whether to approximate only the top order interactions or all orders up to + the specified order. + min_order (int): The minimum order of the approximation. If top_order is True, min_order is + equal to order. Otherwise, min_order is equal to 1. + + """ + + def __init__( + self, + n: int, + max_order: int, + index: str, + top_order: bool, + random_state: Optional[int] = None + ) -> None: + if index not in AVAILABLE_INDICES_PERMUTATION: + raise ValueError( + f"Index {index} is not valid. " + f"Available indices are {AVAILABLE_INDICES_PERMUTATION}." + ) + super().__init__(n, max_order, index, top_order, random_state) + + def approximate( + self, + budget: int, + game: Callable[[Union[set, tuple]], float] + ) -> InteractionValues: + """Approximates the interaction values.""" + raise NotImplementedError + + @staticmethod + def _get_n_iterations(budget: int, batch_size: int, iteration_cost: int) -> tuple[int, int]: + """Computes the number of iterations and the size of the last batch given the batch size and + the budget. + + Args: + budget: The budget for the approximation. + batch_size: The size of the batch. + iteration_cost: The cost of a single iteration. + + Returns: + int, int: The number of iterations and the size of the last batch. + """ + n_iterations = budget // (iteration_cost * batch_size) + last_batch_size = batch_size + remaining_budget = budget - n_iterations * iteration_cost * batch_size + if remaining_budget > 0 and remaining_budget // iteration_cost > 0: + last_batch_size = remaining_budget // iteration_cost + n_iterations += 1 + return n_iterations, last_batch_size diff --git a/shapiq/approximator/permutation/sii.py b/shapiq/approximator/permutation/sii.py new file mode 100644 index 00000000..27eaf743 --- /dev/null +++ b/shapiq/approximator/permutation/sii.py @@ -0,0 +1,119 @@ +"""This module implements the Permutation Sampling approximator for the SII (and nSII) index. + +# TODO add docstring +""" +from typing import Optional, Callable, Union + +import numpy as np + +from approximator._base import InteractionValues +from approximator.permutation import PermutationSampling +from utils import powerset + + +class PermutationSamplingSII(PermutationSampling): + """ Permutation Sampling approximator for the SII (and nSII) index. + + Args: + n: The number of players. + max_order: The interaction order of the approximation. + top_order: Whether to approximate only the top order interactions (`True`) or all orders up + to the specified order (`False`). + random_state: The random state to use for the permutation sampling. Defaults to `None`. + """ + + def __init__( + self, + n: int, + max_order: int, + top_order: bool, + random_state: Optional[int] = None + ) -> None: + super().__init__(n, max_order, 'SII', top_order, random_state) + self._iteration_cost: int = self._compute_iteration_cost() + + def _compute_iteration_cost(self) -> int: + """Computes the cost of performing a single iteration of the permutation sampling given + the order, the number of players, and the SII index. + + Returns: + int: The cost of a single iteration. + """ + iteration_cost: int = 0 + for s in self._order_iterator: + iteration_cost += (self.n - s + 1) * 2 ** s + return iteration_cost + + def approximate( + self, + budget: int, + game: Callable[[np.ndarray], np.ndarray], + batch_size: Optional[int] = None, + ) -> InteractionValues: + """Approximates the interaction values. + + Args: + budget: The budget for the approximation. + game: The game function as a callable that takes a set of players and returns the value. + batch_size: The size of the batch. If None, the batch size is set to 1. + + Returns: + InteractionValues: The estimated interaction values. + """ + + batch_size = 1 if batch_size is None else batch_size + + result = self._init_result() + counts = self._init_result(dtype=int) + + # compute the number of iterations and size of the last batch (can be smaller than original) + n_iterations, last_batch_size = self._get_n_iterations( + budget, batch_size, self._iteration_cost) + + # main permutation sampling loop + for iteration in range(1, n_iterations + 1): + + batch_size = batch_size if iteration != n_iterations else last_batch_size + + # create the permutations: a 2d matrix of shape (batch_size, n) where each row is a + # permutation of the players + permutations = np.tile(np.arange(self.n), (batch_size, 1)) + self._rng.permuted(permutations, axis=1, out=permutations) + n_permutations = permutations.shape[0] + n_subsets = n_permutations * self._iteration_cost + + # get all subsets to evaluate per iteration + subsets = np.zeros(shape=(n_subsets, self.n), dtype=bool) + subset_index = 0 + for permutation_id in range(n_permutations): + for order in self._order_iterator: + for k in range(self.n - order + 1): + subset = permutations[permutation_id, k:k + order] + previous_subset = permutations[permutation_id, :k] + for subset_ in powerset(subset, min_size=0): + subset_eval = np.concatenate((previous_subset, subset_)).astype(int) + subsets[subset_index, subset_eval] = True + subset_index += 1 + + # evaluate all subsets on the game + game_values: np.ndarray[float] = game(subsets) + + # update the interaction scores by iterating over the permutations again + subset_index = 0 + for permutation_id in range(n_permutations): + for order in self._order_iterator: + for k in range(self.n - order + 1): + subset = permutations[permutation_id, k:k + order] + counts[order][tuple(subset)] += 1 + # update the discrete derivative given the subset + for subset_ in powerset(subset, min_size=0): + game_value = game_values[subset_index] + update = game_value * (-1) ** (order - len(subset_)) + result[order][tuple(subset)] += update + subset_index += 1 + + # compute mean of interactions + for s in self._order_iterator: + result[s] = np.divide(result[s], counts[s], out=result[s], where=counts[s] != 0) + + return self._finalize_result(result) diff --git a/shapiq/approximator/permutation/sti.py b/shapiq/approximator/permutation/sti.py new file mode 100644 index 00000000..eed460be --- /dev/null +++ b/shapiq/approximator/permutation/sti.py @@ -0,0 +1,81 @@ +from typing import Callable, Union + +import numpy as np +from scipy.special import binom + +from approximator._base import InteractionValues +from approximator.permutation import PermutationSampling +from utils import powerset + + +class PermutationSamplingSTI(PermutationSampling): + + def __init__(self, n: int, max_order: int, top_order: bool): + super().__init__(n, max_order, "STI", top_order) + self._iteration_cost: int = self._compute_iteration_cost() + + def _compute_iteration_cost(self) -> int: + """Computes the cost of performing a single iteration of the permutation sampling given + the order, the number of players, and the STI index. + + Returns: + int: The cost of a single iteration. + """ + iteration_cost = int(binom(self.n, self.max_order) * 2 ** self.max_order) + return iteration_cost + + def _compute_lower_order_sti( + self, + game: Callable[[Union[set, tuple]], float], + result: dict[int, np.ndarray] + ) -> dict[int, np.ndarray]: + """Computes all lower order interactions for the STI index up to order max_order - 1. + + Args: + game: The game function as a callable that takes a set of players and returns the value. + result: The result dictionary. + + Returns: + The result dictionary. + """ + # run the game on the whole powerset of players up to order max_order - 1 + game_evaluations = {subset: game(subset) + for subset in powerset(self.N, max_size=self.max_order - 1, min_size=1)} + # inspect all parts of the subsets contained in the powerset and attribute their + # contribution to the corresponding interactions and order + for subset in powerset(self.N, max_size=self.max_order - 1, min_size=1): + subset = tuple(subset) + subset_size = len(subset) + for subset_part in powerset(subset): + subset_part_size = len(subset_part) + update = (-1) ** (subset_size - subset_part_size) * game_evaluations[subset_part] + result[subset_size][subset] += update + return result + + def approximate( + self, + budget: int, + game: Callable[[Union[set, tuple]], float] + ) -> InteractionValues: + raise NotImplementedError + result = self._init_result() + counts = self._init_result(dtype=int) + value_empty = game(set()) + value_full = game(self.N) + + # compute all lower order interactions if budget allows it + lower_order_cost = sum(int(binom(self.n, s)) for s in range(self.min_order, self.max_order)) + if self.max_order > 1 and budget >= lower_order_cost: + budget -= lower_order_cost + result = self._compute_lower_order_sti(game, result) + + # main permutation sampling loop + n_permutations = 0 + while budget >= self._iteration_cost: + budget -= self._iteration_cost + values = np.zeros(self.n + 1) # create array for the current permutation + values[0], values[-1] = value_empty, value_full # init values on the edges + permutation = np.random.permutation(self.n) # create random permutation + # TODO finish this + + return self._finalize_result(result) diff --git a/shapiq/explainer/__init__.py b/shapiq/explainer/__init__.py new file mode 100644 index 00000000..5eb1d8b9 --- /dev/null +++ b/shapiq/explainer/__init__.py @@ -0,0 +1 @@ +"""This module contains the explainer for the shapiq package.""" diff --git a/shapiq/explainer/_base.py b/shapiq/explainer/_base.py new file mode 100644 index 00000000..6b16806b --- /dev/null +++ b/shapiq/explainer/_base.py @@ -0,0 +1,50 @@ +"""This module contains the base explainer classes for the shapiq package.""" +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np + + +@dataclass +class Explanation: + """ This class contains the explanation of the model. + + Attributes: + interaction_values: The interaction values of the model. Mapping from order to the + interaction values. + explanation_type: The type of the explanation. Available types are 'SII', 'nSII', 'STI', + and 'FSI'. + order: The maximum order of the explanation. + """ + interaction_values: dict[int, np.ndarray] + explanation_type: str + order: int + + def __post_init__(self) -> None: + """Checks if the explanation type is valid.""" + if self.explanation_type not in ["SII", "nSII", "STI", "FSI"]: + raise ValueError( + f"Explanation type {self.explanation_type} is not valid. " + f"Available types are 'SII', 'nSII', 'STI', and 'FSI'." + ) + + +class Explainer(ABC): + + def __init__( + self, + model: Any, + X: np.ndarray, + y: np.ndarray + ) -> None: + """Initializes the Explainer class. + + Args: + model: The model to be explained. + X: The input data. + y: The output data. + """ + self.model = model + self.X = X + self.y = y diff --git a/shapiq/games/__init__.py b/shapiq/games/__init__.py new file mode 100644 index 00000000..99e5dc4f --- /dev/null +++ b/shapiq/games/__init__.py @@ -0,0 +1,6 @@ +"""This module contains sample game functions for the shapiq package.""" +from games.dummy import DummyGame + +__all__ = [ + "DummyGame", +] diff --git a/shapiq/games/dummy.py b/shapiq/games/dummy.py new file mode 100644 index 00000000..93b7e709 --- /dev/null +++ b/shapiq/games/dummy.py @@ -0,0 +1,82 @@ +"""This module contains the DummyGame class. The DummyGame class is mainly used for testing +purposes. It returns the size of the coalition divided by the number of players plus an additional +interaction term.""" +from typing import Union + +import numpy as np + + +class DummyGame: + + """Dummy game for testing purposes. When called, it returns the size of the coalition divided by + the number of players plus an additional interaction term. + + Args: + n: The number of players. + interaction: The interaction of the game as a tuple of player indices. Defaults to an empty + tuple. + + Attributes: + n: The number of players. + N: The set of players (starting from 0 to n - 1). + interaction: The interaction of the game as a tuple of player indices. + access_counter: The number of times the game has been called. + """ + + def __init__(self, n: int, interaction: Union[set, tuple] = ()): + self.n = n + self.N = set(range(self.n)) + self.interaction: tuple = tuple(sorted(interaction)) + self.access_counter = 0 + + def __call__( + self, + coalition: np.ndarray + ) -> np.ndarray[float]: + """Returns the size of the coalition divided by the number of players plus the interaction + term. + + Args: + coalition: The coalition as a binary vector of shape (n,) or (batch_size, n). + + Returns: + The worth of the coalition. + """ + if coalition.ndim == 1: + coalition = coalition.reshape((1, self.n)) + worth = np.sum(coalition, axis=1) / self.n + if len(self.interaction) > 0: + interaction = coalition[:, self.interaction] + worth += np.prod(interaction, axis=1) + # update access counter given rows in coalition + self.access_counter += coalition.shape[0] + return worth + + def __repr__(self): + return f"DummyGame(n={self.n}, interaction={self.interaction})" + + def __str__(self): + return f"DummyGame(n={self.n}, interaction={self.interaction})" + + def __eq__(self, other): + return self.n == other.n and self.interaction == other.interaction + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.n, self.interaction)) + + def __copy__(self): + return DummyGame(n=self.n, interaction=self.interaction) + + def __deepcopy__(self, memo): + return DummyGame(n=self.n, interaction=self.interaction) + + def __getstate__(self): + return {'n': self.n, 'interaction': self.interaction} + + def __setstate__(self, state): + self.n = state['n'] + self.interaction = state['interaction'] + self.N = set(range(self.n)) diff --git a/shapiq/utils/game_theory.py b/shapiq/utils/game_theory.py index 1dd90cec..09eb9957 100644 --- a/shapiq/utils/game_theory.py +++ b/shapiq/utils/game_theory.py @@ -1,7 +1,7 @@ """This module contains utility functions for dealing with sets, coalitions and game theory.""" from itertools import chain, combinations -from typing import Iterable, Any, Optional, Union, Callable, TypeVar, Tuple +from typing import Iterable, Any, Optional __all__ = [ @@ -13,7 +13,7 @@ def powerset( iterable: Iterable[Any], min_size: Optional[int] = 0, max_size: Optional[int] = None -) -> Iterable[tuple[Any, ...]]: +) -> Iterable[tuple]: """Return a powerset of an iterable as tuples with optional size limits. Args: diff --git a/tests/test_approximator_permutation_sii.py b/tests/test_approximator_permutation_sii.py new file mode 100644 index 00000000..d32672e1 --- /dev/null +++ b/tests/test_approximator_permutation_sii.py @@ -0,0 +1,68 @@ +"""This test module contains all tests regarding the SII permutation sampling approximator.""" +import numpy as np +import pytest + +from approximator._base import InteractionValues +from approximator.permutation import PermutationSamplingSII +from games import DummyGame + + +@pytest.mark.parametrize( + "n, max_order, top_order, expected", + [ + (3, 1, True, 6), + (3, 1, False, 6), + (3, 2, True, 8), + (3, 2, False, 14), + (10, 3, False, 120), + ], +) +def test_initialization(n, max_order, top_order, expected): + """Tests the initialization of the PermutationSamplingSII approximator.""" + approximator = PermutationSamplingSII(n, max_order, top_order) + assert approximator.n == n + assert approximator.max_order == max_order + assert approximator.top_order == top_order + assert approximator.min_order == (max_order if top_order else 1) + assert approximator._iteration_cost == expected + + +@pytest.mark.parametrize( + "n, max_order, top_order, budget, batch_size, expected", + [ + (7, 2, False, 380, 10, {"iteration_cost": 38, "access_counter": 380}), + (7, 2, False, 500, 10, {"iteration_cost": 38, "access_counter": 494}), + ] +) +def test_approximate(n, max_order, top_order, budget, batch_size, expected): + """Tests the approximation of the PermutationSamplingSII approximator.""" + interaction = (1, 2) + game = DummyGame(n, interaction) + approximator = PermutationSamplingSII(n, max_order, top_order) + assert approximator._iteration_cost == expected["iteration_cost"] + sii_estimates = approximator.approximate(budget, game, batch_size=batch_size) + assert isinstance(sii_estimates, InteractionValues) + assert len(sii_estimates.values) == max_order + + # check that the budget is respected + assert game.access_counter <= budget + assert game.access_counter == expected["access_counter"] + + # check that the estimates are correct + # for order 1 player 1 and 2 are the most important the rest should be somewhat equal + # for order 2 the interaction between player 1 and 2 is the most important + first_order: np.ndarray = sii_estimates.values[1] + # sort the players by their importance + sorted_players: np.ndarray = np.argsort(first_order)[::-1] + assert (sorted_players[0] == 1 or sorted_players[0] == 2) and \ + (sorted_players[1] == 1 or sorted_players[1] == 2) + for player_one in sorted_players[2:]: + for player_two in sorted_players[2:]: + pytest.approx(player_one, player_two, 0.1) + + second_order: np.ndarray = sii_estimates.values[2] + pytest.approx(second_order[interaction], 1.0, 0.0001) + + # check efficiency + efficiency = np.sum(first_order) + pytest.approx(efficiency, 2.0, 0.001) diff --git a/tests/test_games_dummy.py b/tests/test_games_dummy.py new file mode 100644 index 00000000..c705df9a --- /dev/null +++ b/tests/test_games_dummy.py @@ -0,0 +1,40 @@ +"""This test module contains the tests for the DummyGame class.""" +import numpy as np +import pytest + +from games import DummyGame + + +@pytest.mark.parametrize( + "n, interaction, expected", + [ + (3, (1, 2), {(): 0, (0,): 1 / 3, (1,): 1 / 3, (2,): 1 / 3, + (0, 1): 2 / 3, (0, 2): 2 / 3, (1, 2): 2 / 3 + 1, (0, 1, 2): 3 / 3 + 1}), + (4, (1, 2), {(): 0, (0,): 1 / 4, (1,): 1 / 4, (2,): 1 / 4, (3,): 1 / 4, + (0, 1): 2 / 4, (1, 2): 2 / 4 + 1, (2, 3): 2 / 4, + (0, 1, 2): 3 / 4 + 1, (1, 2, 3): 3 / 4 + 1, (0, 1, 2, 3): 4 / 4 + 1}), + ] +) +def test_dummy_game(n, interaction, expected): + """Test the DummyGame class.""" + game = DummyGame(n=n, interaction=interaction) + for coalition in expected.keys(): + x_input = np.zeros(shape=(n,), dtype=bool) + x_input[list(coalition)] = True + assert game(x_input)[0] == expected[coalition] + + +def test_dummy_game_access_counts(): + """Test how often the game was called.""" + game = DummyGame(n=10, interaction=(1, 2)) + assert game.access_counter == 0 + game(np.asarray([True, False, False, False, False, False, False, False, False, False])) + assert game.access_counter == 1 + game(np.asarray([True, False, False, False, False, False, False, False, False, False])) + assert game.access_counter == 2 + game(np.asarray([ + [True, False, False, False, False, False, False, False, False, False], + [False, True, False, False, False, False, False, False, False, False], + [False, False, True, False, False, False, False, False, False, False], + ])) + assert game.access_counter == 5 diff --git a/tests/test_utils_game_theory.py b/tests/test_utils_game_theory.py new file mode 100644 index 00000000..911942bc --- /dev/null +++ b/tests/test_utils_game_theory.py @@ -0,0 +1,18 @@ +"""This test module contains the test cases for the utils game theory module.""" + +import pytest + +from utils.game_theory import powerset + + +@pytest.mark.parametrize( + "iterable, min_size, max_size, expected", + [ + ([1, 2, 3], 0, None, [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]), + ([1, 2, 3], 1, None, [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]), + ([1, 2, 3], 0, 2, [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3)]), + ], +) +def test_powerset(iterable, min_size, max_size, expected): + """Tests the powerset function.""" + assert list(powerset(iterable, min_size, max_size)) == expected