From 45669b743ae9fc94d11d1a7d5ef6df4274bda5f0 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 21 Nov 2023 09:46:14 +0100 Subject: [PATCH 01/10] add basic dunder methods to approximator and interaction values --- shapiq/approximator/_base.py | 177 +++++++++++++++++++++++++++++++---- 1 file changed, 158 insertions(+), 19 deletions(-) diff --git a/shapiq/approximator/_base.py b/shapiq/approximator/_base.py index 49f4850e..8a61429e 100644 --- a/shapiq/approximator/_base.py +++ b/shapiq/approximator/_base.py @@ -12,7 +12,7 @@ @dataclass class InteractionValues: - """ This class contains the interaction values as estimated by an approximator. + """This class contains the interaction values as estimated by an approximator. Attributes: values: The interaction values of the model. Mapping from order to the interaction values. @@ -20,6 +20,7 @@ class InteractionValues: 'FSI'. order: The order of the approximation. """ + values: dict[int, np.ndarray] index: str order: int @@ -37,6 +38,78 @@ def __post_init__(self) -> None: f"Order should be a positive integer equal to the maximum key of the values." ) + def __repr__(self) -> str: + """Returns the representation of the InteractionValues object.""" + representation = f"InteractionValues(\n" + representation += f" index={self.index}, order={self.order}, values=" + "{" + for order, values in self.values.items(): + representation += "\n" + representation += f" {order}: " + string_values: str = str(np.round(values, 4)) + representation += string_values.replace("\n", "\n" + " " * 11) + representation += "})" + return representation + + def __str__(self) -> str: + """Returns the string representation of the InteractionValues object.""" + return self.__repr__() + + def __getitem__(self, item: int) -> np.ndarray: + """Returns the interaction values for the given order. + + Args: + item: The order of the interaction values. + + Returns: + The interaction values. + """ + return self.values[item] + + def __eq__(self, other: object) -> bool: + """Checks if two InteractionValues objects are equal. + + Args: + other: The other InteractionValues object. + + Returns: + True if the two objects are equal, False otherwise. + """ + if not isinstance(other, InteractionValues): + raise NotImplementedError("Cannot compare InteractionValues with other types.") + if self.index != other.index or self.order != other.order: + return False + for order, values in self.values.items(): + if not np.allclose(values, other.values[order]): + return False + return True + + def __ne__(self, other: object) -> bool: + """Checks if two InteractionValues objects are not equal. + + Args: + other: The other InteractionValues object. + + Returns: + True if the two objects are not equal, False otherwise. + """ + return not self.__eq__(other) + + def __hash__(self) -> int: + """Returns the hash of the InteractionValues object.""" + return hash((self.index, self.order, tuple(self.values.values()))) + + def __copy__(self) -> "InteractionValues": + """Returns a copy of the InteractionValues object.""" + return InteractionValues( + values=copy.deepcopy(self.values), index=self.index, order=self.order + ) + + def __deepcopy__(self, memo) -> "InteractionValues": + """Returns a deep copy of the InteractionValues object.""" + return InteractionValues( + values=copy.deepcopy(self.values), index=self.index, order=self.order + ) + class Approximator(ABC): """This class is the base class for all approximators. @@ -66,19 +139,18 @@ class Approximator(ABC): """ def __init__( - self, - n: int, - max_order: int, - index: str, - top_order: bool, - random_state: Optional[int] = None + self, + n: int, + max_order: int, + index: str, + top_order: bool, + random_state: Optional[int] = None, ) -> None: """Initializes the approximator.""" self.index: str = index if self.index not in AVAILABLE_INDICES: raise ValueError( - f"Index {self.index} is not valid. " - f"Available indices are {AVAILABLE_INDICES}." + f"Index {self.index} is not valid. " f"Available indices are {AVAILABLE_INDICES}." ) self.n: int = n self.N: set = set(range(self.n)) @@ -90,10 +162,7 @@ def __init__( @abstractmethod def approximate( - self, - budget: int, - game: Callable[[Union[set, tuple]], float], - *args, **kwargs + self, budget: int, game: Callable[[np.ndarray], np.ndarray], *args, **kwargs ) -> InteractionValues: """Approximates the interaction values. Abstract method that needs to be implemented for each approximator. @@ -108,7 +177,7 @@ def approximate( Raises: NotImplementedError: If the method is not implemented. """ - raise NotImplementedError + raise NotImplementedError("The approximate method needs to be implemented.") def _init_result(self, dtype=float) -> dict[int, np.ndarray]: """Initializes the result dictionary mapping from order to the interaction values. @@ -120,8 +189,7 @@ def _init_result(self, dtype=float) -> dict[int, np.ndarray]: Returns: The result dictionary. """ - result = {s: self._get_empty_array(self.n, s, dtype=dtype) - for s in self._order_iterator} + result = {s: self._get_empty_array(self.n, s, dtype=dtype) for s in self._order_iterator} return result @staticmethod @@ -136,7 +204,7 @@ def _get_empty_array(n: int, order: int, dtype=float) -> np.ndarray: Returns: The empty array. """ - return np.zeros(n ** order, dtype=dtype).reshape((n,) * order) + return np.zeros(n**order, dtype=dtype).reshape((n,) * order) @property def _order_iterator(self) -> range: @@ -160,8 +228,7 @@ def _finalize_result(self, result) -> InteractionValues: @staticmethod def _smooth_with_epsilon( - interaction_results: Union[dict, np.ndarray], - eps=0.00001 + interaction_results: Union[dict, np.ndarray], eps=0.00001 ) -> Union[dict, np.ndarray]: """Smooth the interaction results with a small epsilon to avoid numerical issues. @@ -180,3 +247,75 @@ def _smooth_with_epsilon( interaction_values[np.abs(interaction_values) < eps] = 0 interactions[interaction_order] = interaction_values return copy.deepcopy(interactions) + + def __repr__(self) -> str: + """Returns the representation of the Approximator object.""" + return ( + f"{self.__class__.__name__}(" + f" n={self.n},\n" + f" max_order={self.max_order},\n" + f" index={self.index},\n" + f" top_order={self.top_order},\n" + f" random_state={self._random_state}\n" + f")" + ) + + def __str__(self) -> str: + """Returns the string representation of the Approximator object.""" + return self.__repr__() + + def __eq__(self, other: object) -> bool: + """Checks if two Approximator objects are equal. + + Args: + other: The other Approximator object. + + Returns: + True if the two objects are equal, False otherwise. + """ + if not isinstance(other, Approximator): + raise NotImplementedError("Cannot compare Approximator with other types.") + if ( + self.n != other.n + or self.max_order != other.max_order + or self.index != other.index + or self.top_order != other.top_order + or self._random_state != other._random_state + ): + return False + return True + + def __ne__(self, other: object) -> bool: + """Checks if two Approximator objects are not equal. + + Args: + other: The other Approximator object. + + Returns: + True if the two objects are not equal, False otherwise. + """ + return not self.__eq__(other) + + def __hash__(self) -> int: + """Returns the hash of the Approximator object.""" + return hash((self.n, self.max_order, self.index, self.top_order, self._random_state)) + + def __copy__(self) -> "Approximator": + """Returns a copy of the Approximator object.""" + return self.__class__( + n=self.n, + max_order=self.max_order, + index=self.index, + top_order=self.top_order, + random_state=self._random_state, + ) + + def __deepcopy__(self, memo) -> "Approximator": + """Returns a deep copy of the Approximator object.""" + return self.__class__( + n=self.n, + max_order=self.max_order, + index=self.index, + top_order=self.top_order, + random_state=self._random_state, + ) From 383f36a26927b9552dbfcbac57aedff4943f421a Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 21 Nov 2023 09:46:58 +0100 Subject: [PATCH 02/10] add black as code sanitizer --- pyproject.toml | 4 ++++ setup.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4aff98bb..dd759a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ [build-system] requires = ["setuptools>=61.0.0", "wheel"] build-backend = "setuptools.build_meta" + +[tool.black] +line-length = 100 +target-version = ['py39'] \ No newline at end of file diff --git a/setup.py b/setup.py index 21e3da8c..83fb04a1 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,10 @@ ] dev_packages = [ + "build", + "black", "pytest" + ] setuptools.setup( From 10fa658232f98d0bce3523b2c959902b5192d87e Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 21 Nov 2023 09:48:47 +0100 Subject: [PATCH 03/10] add docstring to SII permutation --- shapiq/approximator/permutation/_base.py | 23 +++++---- shapiq/approximator/permutation/sii.py | 61 +++++++++++++++++------- 2 files changed, 56 insertions(+), 28 deletions(-) diff --git a/shapiq/approximator/permutation/_base.py b/shapiq/approximator/permutation/_base.py index f16a8ae3..d722451e 100644 --- a/shapiq/approximator/permutation/_base.py +++ b/shapiq/approximator/permutation/_base.py @@ -30,16 +30,15 @@ class PermutationSampling(Approximator): the specified order. min_order (int): The minimum order of the approximation. If top_order is True, min_order is equal to order. Otherwise, min_order is equal to 1. - """ def __init__( - self, - n: int, - max_order: int, - index: str, - top_order: bool, - random_state: Optional[int] = None + self, + n: int, + max_order: int, + index: str, + top_order: bool, + random_state: Optional[int] = None, ) -> None: if index not in AVAILABLE_INDICES_PERMUTATION: raise ValueError( @@ -47,11 +46,10 @@ def __init__( f"Available indices are {AVAILABLE_INDICES_PERMUTATION}." ) super().__init__(n, max_order, index, top_order, random_state) + self._iteration_cost: int = -1 def approximate( - self, - budget: int, - game: Callable[[Union[set, tuple]], float] + self, budget: int, game: Callable[[np.ndarray], np.ndarray] ) -> InteractionValues: """Approximates the interaction values.""" raise NotImplementedError @@ -76,3 +74,8 @@ def _get_n_iterations(budget: int, batch_size: int, iteration_cost: int) -> tupl last_batch_size = remaining_budget // iteration_cost n_iterations += 1 return n_iterations, last_batch_size + + @property + def iteration_cost(self) -> int: + """The cost of a single iteration of the permutation sampling mechanism.""" + return self._iteration_cost diff --git a/shapiq/approximator/permutation/sii.py b/shapiq/approximator/permutation/sii.py index 27eaf743..28fc8716 100644 --- a/shapiq/approximator/permutation/sii.py +++ b/shapiq/approximator/permutation/sii.py @@ -2,7 +2,7 @@ # TODO add docstring """ -from typing import Optional, Callable, Union +from typing import Optional, Callable import numpy as np @@ -12,7 +12,7 @@ class PermutationSamplingSII(PermutationSampling): - """ Permutation Sampling approximator for the SII (and nSII) index. + """Permutation Sampling approximator for the SII (and nSII) index. Args: n: The number of players. @@ -20,16 +20,41 @@ class PermutationSamplingSII(PermutationSampling): top_order: Whether to approximate only the top order interactions (`True`) or all orders up to the specified order (`False`). random_state: The random state to use for the permutation sampling. Defaults to `None`. + + Attributes: + n: The number of players. + max_order: The interaction order of the approximation. + top_order: Whether to approximate only the top order interactions (`True`) or all orders up + to the specified order (`False`). + min_order: The minimum order to approximate. + iteration_cost: The cost of a single iteration of the permutation sampling. + + Example: + >>> from games import DummyGame + >>> from approximator import PermutationSamplingSII + >>> game = DummyGame(n=7, interaction=(0, 1)) + >>> approximator = PermutationSamplingSII(n=7, max_order=2, top_order=False) + >>> approximator.approximate(budget=1000, game=game) + InteractionValues( + index=SII, order=2, values={ + 1: [0.1429 0.6429 0.6429 0.1429 0.1429 0.1429 0.1429] + 2: [[ 0. 0. 0. 0. 0. 0. 0.] + [ 0. 0. 1. 0. 0. 0. 0.] + [ 0. 1. 0. 0. 0. 0. 0.] + [ 0. 0. 0. 0. 0. 0. 0.] + [ 0. 0. 0. 0. 0. 0. 0.] + [ 0. 0. 0. 0. -0. 0. 0.] + [ 0. 0. 0. 0. 0. 0. 0.]]}) """ def __init__( - self, - n: int, - max_order: int, - top_order: bool, - random_state: Optional[int] = None + self, + n: int, + max_order: int, + top_order: bool, + random_state: Optional[int] = None, ) -> None: - super().__init__(n, max_order, 'SII', top_order, random_state) + super().__init__(n, max_order, "SII", top_order, random_state) self._iteration_cost: int = self._compute_iteration_cost() def _compute_iteration_cost(self) -> int: @@ -41,21 +66,21 @@ def _compute_iteration_cost(self) -> int: """ iteration_cost: int = 0 for s in self._order_iterator: - iteration_cost += (self.n - s + 1) * 2 ** s + iteration_cost += (self.n - s + 1) * 2**s return iteration_cost def approximate( - self, - budget: int, - game: Callable[[np.ndarray], np.ndarray], - batch_size: Optional[int] = None, + self, + budget: int, + game: Callable[[np.ndarray], np.ndarray], + batch_size: Optional[int] = 5, ) -> InteractionValues: """Approximates the interaction values. Args: budget: The budget for the approximation. game: The game function as a callable that takes a set of players and returns the value. - batch_size: The size of the batch. If None, the batch size is set to 1. + batch_size: The size of the batch. If None, the batch size is set to 1. Defaults to 5. Returns: InteractionValues: The estimated interaction values. @@ -68,11 +93,11 @@ def approximate( # compute the number of iterations and size of the last batch (can be smaller than original) n_iterations, last_batch_size = self._get_n_iterations( - budget, batch_size, self._iteration_cost) + budget, batch_size, self._iteration_cost + ) # main permutation sampling loop for iteration in range(1, n_iterations + 1): - batch_size = batch_size if iteration != n_iterations else last_batch_size # create the permutations: a 2d matrix of shape (batch_size, n) where each row is a @@ -88,7 +113,7 @@ def approximate( for permutation_id in range(n_permutations): for order in self._order_iterator: for k in range(self.n - order + 1): - subset = permutations[permutation_id, k:k + order] + subset = permutations[permutation_id, k : k + order] previous_subset = permutations[permutation_id, :k] for subset_ in powerset(subset, min_size=0): subset_eval = np.concatenate((previous_subset, subset_)).astype(int) @@ -103,7 +128,7 @@ def approximate( for permutation_id in range(n_permutations): for order in self._order_iterator: for k in range(self.n - order + 1): - subset = permutations[permutation_id, k:k + order] + subset = permutations[permutation_id, k : k + order] counts[order][tuple(subset)] += 1 # update the discrete derivative given the subset for subset_ in powerset(subset, min_size=0): From 24be8afb294065ce83b4017de2dc908fc1d4c778 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 21 Nov 2023 09:49:17 +0100 Subject: [PATCH 04/10] add check of correct index --- tests/test_approximator_permutation_sii.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_approximator_permutation_sii.py b/tests/test_approximator_permutation_sii.py index d32672e1..076a6acc 100644 --- a/tests/test_approximator_permutation_sii.py +++ b/tests/test_approximator_permutation_sii.py @@ -24,7 +24,8 @@ def test_initialization(n, max_order, top_order, expected): assert approximator.max_order == max_order assert approximator.top_order == top_order assert approximator.min_order == (max_order if top_order else 1) - assert approximator._iteration_cost == expected + assert approximator.iteration_cost == expected + assert approximator.index == "SII" @pytest.mark.parametrize( @@ -32,7 +33,7 @@ def test_initialization(n, max_order, top_order, expected): [ (7, 2, False, 380, 10, {"iteration_cost": 38, "access_counter": 380}), (7, 2, False, 500, 10, {"iteration_cost": 38, "access_counter": 494}), - ] + ], ) def test_approximate(n, max_order, top_order, budget, batch_size, expected): """Tests the approximation of the PermutationSamplingSII approximator.""" @@ -54,8 +55,9 @@ def test_approximate(n, max_order, top_order, budget, batch_size, expected): first_order: np.ndarray = sii_estimates.values[1] # sort the players by their importance sorted_players: np.ndarray = np.argsort(first_order)[::-1] - assert (sorted_players[0] == 1 or sorted_players[0] == 2) and \ - (sorted_players[1] == 1 or sorted_players[1] == 2) + assert (sorted_players[0] == 1 or sorted_players[0] == 2) and ( + sorted_players[1] == 1 or sorted_players[1] == 2 + ) for player_one in sorted_players[2:]: for player_two in sorted_players[2:]: pytest.approx(player_one, player_two, 0.1) From 550bab9354c7280fd9a6dfa65a4a9944ed81397e Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 21 Nov 2023 10:04:03 +0100 Subject: [PATCH 05/10] add test of split_subsets_budget --- shapiq/utils/sets.py | 4 +++- tests/test_utils_sets.py | 19 +++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/shapiq/utils/sets.py b/shapiq/utils/sets.py index c19ce84b..344b4db8 100644 --- a/shapiq/utils/sets.py +++ b/shapiq/utils/sets.py @@ -75,7 +75,9 @@ def pair_subset_sizes(order: int, n: int) -> tuple[list[tuple[int, int]], Option return paired_subsets, unpaired_subset -def split_subsets_budget(order: int, n: int, budget: int, q: np.ndarray) -> tuple[list, list, int]: +def split_subsets_budget( + order: int, n: int, budget: int, q: np.ndarray[float] +) -> tuple[list, list, int]: """Determines which subset sizes can be computed explicitly and which sizes need to be sampled. Given a computational budget, determines the complete subsets that can be computed explicitly diff --git a/tests/test_utils_sets.py b/tests/test_utils_sets.py index 362f2f4c..90dbcb60 100644 --- a/tests/test_utils_sets.py +++ b/tests/test_utils_sets.py @@ -1,8 +1,8 @@ """This test module contains the test cases for the utils sets module.""" - +import numpy as np import pytest -from utils.sets import powerset, pair_subset_sizes +from utils.sets import powerset, pair_subset_sizes, split_subsets_budget @pytest.mark.parametrize( @@ -32,3 +32,18 @@ def test_powerset(iterable, min_size, max_size, expected): def test_pairing(order, n, expected): """Tests the get_paired_subsets function.""" assert pair_subset_sizes(order, n) == expected + + +@pytest.mark.parametrize( + "order, n, budget, q, expected", + [ + (1, 6, 100, [0, 1, 1, 1, 1, 1, 0], ([1, 5, 2, 4, 3], [], 38)), + (1, 6, 60, [0, 1, 1, 1, 1, 1, 0], ([1, 5, 2, 4], [3], 18)), + (1, 6, 100, [0, 0, 0, 0, 0, 0, 0], ([], [1, 2, 3, 4, 5], 100)), + ], +) +def test_split_subsets_budget(budget, order, n, q, expected): + """Tests the split_subsets_budget function.""" + q_arr = np.asarray(q, dtype=float) + out = split_subsets_budget(order, n, budget, q_arr) + assert out == expected From 3489fa22480dfac4c09009496cfab439cbab56e0 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 21 Nov 2023 10:08:11 +0100 Subject: [PATCH 06/10] rename q to sampling_weights --- shapiq/utils/sets.py | 12 ++++++------ tests/test_utils_sets.py | 7 ++++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/shapiq/utils/sets.py b/shapiq/utils/sets.py index 344b4db8..f9d2a433 100644 --- a/shapiq/utils/sets.py +++ b/shapiq/utils/sets.py @@ -76,7 +76,7 @@ def pair_subset_sizes(order: int, n: int) -> tuple[list[tuple[int, int]], Option def split_subsets_budget( - order: int, n: int, budget: int, q: np.ndarray[float] + order: int, n: int, budget: int, sampling_weights: np.ndarray[float] ) -> tuple[list, list, int]: """Determines which subset sizes can be computed explicitly and which sizes need to be sampled. @@ -87,20 +87,20 @@ def split_subsets_budget( order: interaction order. n: number of players. budget: total allowed budget for the computation. - q: weight vector of the sampling distribution in shape (n + 1,). The first and last element + sampling_weights: weight vector of the sampling distribution in shape (n + 1,). The first and last element constituting the empty and full subsets are not used. Returns: complete subsets, incomplete subsets, remaining budget Examples: - >>> split_subsets_budget(order=1, n=6, budget=100, q=np.ones(shape=(6,))) + >>> split_subsets_budget(order=1, n=6, budget=100, sampling_weights=np.ones(shape=(6,))) ([1, 5, 2, 4, 3], [], 38) - >>> split_subsets_budget(order=1, n=6, budget=60, q=np.ones(shape=(6,))) + >>> split_subsets_budget(order=1, n=6, budget=60, sampling_weights=np.ones(shape=(6,))) ([1, 5, 2, 4], [3], 18) - >>> split_subsets_budget(order=1, n=6, budget=100, q=np.zeros(shape=(6,))) + >>> split_subsets_budget(order=1, n=6, budget=100, sampling_weights=np.zeros(shape=(6,))) ([], [1, 2, 3, 4, 5], 100) """ # determine paired and unpaired subsets @@ -109,7 +109,7 @@ def split_subsets_budget( incomplete_subsets = list(range(order, n - order + 1)) # turn weight vector into probability vector - weight_vector = copy.copy(q) + weight_vector = copy.copy(sampling_weights) weight_vector[0], weight_vector[-1] = 0, 0 # zero out the empty and full subsets sum_weight_vector = np.sum(weight_vector) weight_vector = np.divide( diff --git a/tests/test_utils_sets.py b/tests/test_utils_sets.py index 90dbcb60..ebe8b017 100644 --- a/tests/test_utils_sets.py +++ b/tests/test_utils_sets.py @@ -44,6 +44,7 @@ def test_pairing(order, n, expected): ) def test_split_subsets_budget(budget, order, n, q, expected): """Tests the split_subsets_budget function.""" - q_arr = np.asarray(q, dtype=float) - out = split_subsets_budget(order, n, budget, q_arr) - assert out == expected + sampling_weights = np.asarray(q, dtype=float) + assert split_subsets_budget(order, n, budget, sampling_weights) == expected + assert (split_subsets_budget(order=order, n=n, budget=budget, sampling_weights=sampling_weights) + == expected) From 154f4714107b3fc265ee940115a60f18969209ae Mon Sep 17 00:00:00 2001 From: Maximilian Date: Tue, 21 Nov 2023 10:08:26 +0100 Subject: [PATCH 07/10] rename q to sampling_weights --- tests/test_utils_sets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_utils_sets.py b/tests/test_utils_sets.py index ebe8b017..4b97e633 100644 --- a/tests/test_utils_sets.py +++ b/tests/test_utils_sets.py @@ -46,5 +46,7 @@ def test_split_subsets_budget(budget, order, n, q, expected): """Tests the split_subsets_budget function.""" sampling_weights = np.asarray(q, dtype=float) assert split_subsets_budget(order, n, budget, sampling_weights) == expected - assert (split_subsets_budget(order=order, n=n, budget=budget, sampling_weights=sampling_weights) - == expected) + assert ( + split_subsets_budget(order=order, n=n, budget=budget, sampling_weights=sampling_weights) + == expected + ) From a0aab36059af382e2153b664935def7a4cdc511d Mon Sep 17 00:00:00 2001 From: Maximilian Date: Wed, 22 Nov 2023 19:04:20 +0100 Subject: [PATCH 08/10] extend base Approximator and InteractionValues classes --- shapiq/approximator/_base.py | 90 +++++++++++++++++++++--- shapiq/approximator/permutation/_base.py | 23 +----- shapiq/approximator/permutation/sii.py | 10 +-- 3 files changed, 88 insertions(+), 35 deletions(-) diff --git a/shapiq/approximator/_base.py b/shapiq/approximator/_base.py index 8a61429e..6b17c718 100644 --- a/shapiq/approximator/_base.py +++ b/shapiq/approximator/_base.py @@ -1,11 +1,12 @@ """This module contains the base approximator classes for the shapiq package.""" import copy +import itertools from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Callable, Union, Optional import numpy as np - +from scipy.special import binom AVAILABLE_INDICES = {"SII", "nSII", "STI", "FSI"} @@ -19,11 +20,15 @@ class InteractionValues: index: The interaction index estimated. Available indices are 'SII', 'nSII', 'STI', and 'FSI'. order: The order of the approximation. + estimated: Whether the interaction values are estimated or not. Defaults to True. + estimation_budget: The budget used for the estimation. Defaults to None. """ values: dict[int, np.ndarray] index: str order: int + estimated: bool = True + estimation_budget: Optional[int] = None def __post_init__(self) -> None: """Checks if the index is valid.""" @@ -41,13 +46,21 @@ def __post_init__(self) -> None: def __repr__(self) -> str: """Returns the representation of the InteractionValues object.""" representation = f"InteractionValues(\n" - representation += f" index={self.index}, order={self.order}, values=" + "{" + representation += ( + f" index={self.index}, order={self.order}, estimated={self.estimated}" + f", estimation_budget={self.estimation_budget},\n" + ) + " values={" for order, values in self.values.items(): representation += "\n" representation += f" {order}: " string_values: str = str(np.round(values, 4)) representation += string_values.replace("\n", "\n" + " " * 11) - representation += "})" + representation += "\n }" + # if self.approximator: + # representation += ",\n" + # string_approximator = str(self.approximator).replace("\n", "\n" + " " * 4) + # representation += f" approximator={string_approximator}" + representation += "\n)" return representation def __str__(self) -> str: @@ -101,13 +114,21 @@ def __hash__(self) -> int: def __copy__(self) -> "InteractionValues": """Returns a copy of the InteractionValues object.""" return InteractionValues( - values=copy.deepcopy(self.values), index=self.index, order=self.order + values=copy.deepcopy(self.values), + index=self.index, + order=self.order, + estimated=self.estimated, + estimation_budget=self.estimation_budget, ) def __deepcopy__(self, memo) -> "InteractionValues": """Returns a deep copy of the InteractionValues object.""" return InteractionValues( - values=copy.deepcopy(self.values), index=self.index, order=self.order + values=copy.deepcopy(self.values), + index=self.index, + order=self.order, + estimated=self.estimated, + estimation_budget=self.estimation_budget, ) @@ -130,6 +151,7 @@ class Approximator(ABC): Attributes: n: The number of players. N: The set of players (starting from 0 to n - 1). + N_arr: The array of players (starting from 0 to n). max_order: The interaction order of the approximation. index: The interaction index to be estimated. top_order: If True, the approximation is performed only for the top order interactions. If @@ -154,6 +176,7 @@ def __init__( ) self.n: int = n self.N: set = set(range(self.n)) + self.N_arr: np.ndarray[int] = np.arange(self.n + 1) self.max_order: int = max_order self.top_order: bool = top_order self.min_order: int = self.max_order if self.top_order else 1 @@ -215,16 +238,26 @@ def _order_iterator(self) -> range: """ return range(self.min_order, self.max_order + 1) - def _finalize_result(self, result) -> InteractionValues: + def _finalize_result( + self, result, estimated: bool = True, budget: Optional[int] = None + ) -> InteractionValues: """Finalizes the result dictionary. Args: result: The result dictionary. + estimated: Whether the interaction values are estimated or not. Defaults to True. + budget: The budget used for the estimation. Defaults to None. Returns: The interaction values. """ - return InteractionValues(result, self.index, self.max_order) + return InteractionValues( + values=result, + index=self.index, + order=self.max_order, + estimated=estimated, + estimation_budget=budget, + ) @staticmethod def _smooth_with_epsilon( @@ -248,10 +281,51 @@ def _smooth_with_epsilon( interactions[interaction_order] = interaction_values return copy.deepcopy(interactions) + @staticmethod + def _get_n_iterations(budget: int, batch_size: int, iteration_cost: int) -> tuple[int, int]: + """Computes the number of iterations and the size of the last batch given the batch size and + the budget. + + Args: + budget: The budget for the approximation. + batch_size: The size of the batch. + iteration_cost: The cost of a single iteration. + + Returns: + int, int: The number of iterations and the size of the last batch. + """ + n_iterations = budget // (iteration_cost * batch_size) + last_batch_size = batch_size + remaining_budget = budget - n_iterations * iteration_cost * batch_size + if remaining_budget > 0 and remaining_budget // iteration_cost > 0: + last_batch_size = remaining_budget // iteration_cost + n_iterations += 1 + return n_iterations, last_batch_size + + @staticmethod + def _get_explicit_subsets(n: int, subset_sizes: list[int]) -> np.ndarray[bool]: + """Enumerates all subsets of the given sizes and returns a one-hot matrix. + + Args: + n: number of players. + subset_sizes: list of subset sizes. + + Returns: + one-hot matrix of all subsets of certain sizes. + """ + total_subsets = int(sum(binom(n, size) for size in subset_sizes)) + subset_matrix = np.zeros(shape=(total_subsets, n), dtype=bool) + subset_index = 0 + for subset_size in subset_sizes: + for subset in itertools.combinations(range(n), subset_size): + subset_matrix[subset_index, subset] = True + subset_index += 1 + return subset_matrix + def __repr__(self) -> str: """Returns the representation of the Approximator object.""" return ( - f"{self.__class__.__name__}(" + f"{self.__class__.__name__}(\n" f" n={self.n},\n" f" max_order={self.max_order},\n" f" index={self.index},\n" diff --git a/shapiq/approximator/permutation/_base.py b/shapiq/approximator/permutation/_base.py index d722451e..89e4a7f9 100644 --- a/shapiq/approximator/permutation/_base.py +++ b/shapiq/approximator/permutation/_base.py @@ -1,5 +1,5 @@ """This module contains the base permutation sampling algorithms to estimate SII/nSII and STI.""" -from typing import Optional, Callable, Union +from typing import Optional, Callable import numpy as np @@ -54,27 +54,6 @@ def approximate( """Approximates the interaction values.""" raise NotImplementedError - @staticmethod - def _get_n_iterations(budget: int, batch_size: int, iteration_cost: int) -> tuple[int, int]: - """Computes the number of iterations and the size of the last batch given the batch size and - the budget. - - Args: - budget: The budget for the approximation. - batch_size: The size of the batch. - iteration_cost: The cost of a single iteration. - - Returns: - int, int: The number of iterations and the size of the last batch. - """ - n_iterations = budget // (iteration_cost * batch_size) - last_batch_size = batch_size - remaining_budget = budget - n_iterations * iteration_cost * batch_size - if remaining_budget > 0 and remaining_budget // iteration_cost > 0: - last_batch_size = remaining_budget // iteration_cost - n_iterations += 1 - return n_iterations, last_batch_size - @property def iteration_cost(self) -> int: """The cost of a single iteration of the permutation sampling mechanism.""" diff --git a/shapiq/approximator/permutation/sii.py b/shapiq/approximator/permutation/sii.py index 28fc8716..eb88eb68 100644 --- a/shapiq/approximator/permutation/sii.py +++ b/shapiq/approximator/permutation/sii.py @@ -1,7 +1,4 @@ -"""This module implements the Permutation Sampling approximator for the SII (and nSII) index. - -# TODO add docstring -""" +"""This module implements the Permutation Sampling approximator for the SII (and nSII) index.""" from typing import Optional, Callable import numpy as np @@ -87,6 +84,7 @@ def approximate( """ batch_size = 1 if batch_size is None else batch_size + used_budget = 0 result = self._init_result() counts = self._init_result(dtype=int) @@ -137,8 +135,10 @@ def approximate( result[order][tuple(subset)] += update subset_index += 1 + used_budget += self._iteration_cost * batch_size + # compute mean of interactions for s in self._order_iterator: result[s] = np.divide(result[s], counts[s], out=result[s], where=counts[s] != 0) - return self._finalize_result(result) + return self._finalize_result(result, budget=used_budget) From 2ac2bb3b027e645160335a284b93f653b7461041 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Wed, 22 Nov 2023 19:06:18 +0100 Subject: [PATCH 09/10] add regression estimator for FSI and closes #3 --- shapiq/approximator/__init__.py | 3 + shapiq/approximator/regression/__init__.py | 6 + shapiq/approximator/regression/_base.py | 71 ++++++ shapiq/approximator/regression/fsi.py | 245 +++++++++++++++++++++ 4 files changed, 325 insertions(+) create mode 100644 shapiq/approximator/regression/__init__.py create mode 100644 shapiq/approximator/regression/_base.py create mode 100644 shapiq/approximator/regression/fsi.py diff --git a/shapiq/approximator/__init__.py b/shapiq/approximator/__init__.py index 7e6e221d..ade71522 100644 --- a/shapiq/approximator/__init__.py +++ b/shapiq/approximator/__init__.py @@ -1,7 +1,10 @@ """This module contains the approximators to estimate the Shapley interaction values.""" from .permutation.sii import PermutationSamplingSII from .permutation.sti import PermutationSamplingSTI +from .regression import RegressionFSI __all__ = [ "PermutationSamplingSII", + "PermutationSamplingSTI", + "RegressionFSI", ] diff --git a/shapiq/approximator/regression/__init__.py b/shapiq/approximator/regression/__init__.py new file mode 100644 index 00000000..4cc783b9 --- /dev/null +++ b/shapiq/approximator/regression/__init__.py @@ -0,0 +1,6 @@ +"""This module contains the regression-based approximators to estimate Shapley interaction values. +""" +from ._base import Regression +from .fsi import RegressionFSI + +__all__ = ["RegressionFSI", "Regression"] diff --git a/shapiq/approximator/regression/_base.py b/shapiq/approximator/regression/_base.py new file mode 100644 index 00000000..1f9bb85d --- /dev/null +++ b/shapiq/approximator/regression/_base.py @@ -0,0 +1,71 @@ +"""This module contains the base regression algorithms to estimate FSI scores.""" +from typing import Optional, Callable, Union + +import numpy as np +from scipy.special import binom + +from approximator._base import Approximator, InteractionValues + +AVAILABLE_INDICES_REGRESSION = {"FSI"} + + +class Regression(Approximator): + def __init__( + self, + n: int, + max_order: int, + index: str, + random_state: Optional[int] = None, + ) -> None: + if index not in AVAILABLE_INDICES_REGRESSION: + raise ValueError( + f"Index {index} is not valid. " + f"Available indices are {AVAILABLE_INDICES_REGRESSION}." + ) + super().__init__(n, max_order, index, False, random_state) + self._big_M = float(1_000_000) + + def approximate( + self, budget: int, game: Callable[[Union[set, tuple]], float], *args, **kwargs + ) -> InteractionValues: + """Approximates the interaction values.""" + raise NotImplementedError + + def _init_ksh_sampling_weights(self) -> np.ndarray[float]: + """Initializes the weights for sampling subsets. + + The sampling weights are of size n + 1 and indexed by the size of the subset. The edges + (the first, empty coalition, and the last element, full coalition) are set to 0. + + Returns: + The weights for sampling subsets. + """ + weight_vector = np.zeros(shape=self.n - 1, dtype=float) + for subset_size in range(1, self.n): + weight_vector[subset_size - 1] = (self.n - 1) / (subset_size * (self.n - subset_size)) + sampling_weight = (np.asarray([0] + [*weight_vector] + [0])) / sum(weight_vector) + return sampling_weight + + def _get_ksh_subset_weights(self, subsets: np.ndarray[bool]) -> np.ndarray[float]: + """Computes the KernelSHAP regression weights for the given subsets. + + The weights for the subsets of size s are set to ksh_weights[s] / binom(n, s). The weights + for the empty and full sets are set to a big number. + + Args: + subsets: one-hot matrix of subsets for which to compute the weights in shape + (n_subsets, n). + + Returns: + The KernelSHAP regression weights in shape (n_subsets,). + """ + # set the weights for each subset to ksh_weights[|S|] / binom(n, |S|) + ksh_weights = self._init_ksh_sampling_weights() # indexed by subset size + subset_sizes = np.sum(subsets, axis=1) + weights = ksh_weights[subset_sizes] # set the weights for each subset size + weights /= binom(self.n, subset_sizes) # divide by the number of subsets of the same size + + # set the weights for the empty and full sets to big M + weights[np.logical_not(subsets).all(axis=1)] = self._big_M + weights[subsets.all(axis=1)] = self._big_M + return weights diff --git a/shapiq/approximator/regression/fsi.py b/shapiq/approximator/regression/fsi.py new file mode 100644 index 00000000..40f71c0b --- /dev/null +++ b/shapiq/approximator/regression/fsi.py @@ -0,0 +1,245 @@ +"""This module contains the base regression algorithms to estimate FSI scores.""" +import copy +import itertools +from typing import Optional, Callable + +import numpy as np +from scipy.special import binom + +from approximator._base import InteractionValues +from utils import split_subsets_budget, powerset +from ._base import Regression + + +class RegressionFSI(Regression): + """Estimates the FSI values using the weighted least square approach. + + Args: + n: The number of players. + max_order: The interaction order of the approximation. + random_state: The random state of the estimator. Defaults to `None`. + + Attributes: + n: The number of players. + N: The set of players (starting from 0 to n - 1). + max_order: The interaction order of the approximation. + min_order: The minimum order of the approximation. For FSI, min_order is equal to 1. + + Example: + >>> from games import DummyGame + >>> from approximator import RegressionFSI + >>> game = DummyGame(n=7, interaction=(0, 1)) + >>> approximator = RegressionFSI(n=7, max_order=2) + >>> approximator.approximate(budget=100, game=game) + InteractionValues( + index=FSI, order=2, values={ + 1: [0.1429 0.1429 0.1429 0.1429 0.1429 0.1429 0.1429] + 2: [[ 0. 0. 0. 0. 0. 0. 0.] + [ 0. 0. 1. -0. -0. 0. 0.] + [ 0. 1. 0. 0. 0. 0. -0.] + [ 0. -0. 0. 0. 0. -0. 0.] + [ 0. -0. 0. 0. 0. -0. -0.] + [ 0. 0. 0. -0. -0. 0. -0.] + [ 0. 0. -0. 0. -0. -0. 0.]]}) + """ + + def __init__( + self, + n: int, + max_order: int, + random_state: Optional[int] = None, + ) -> None: + super().__init__(n, max_order=max_order, index="FSI", random_state=random_state) + + def approximate( + self, + budget: int, + game: Callable[[np.ndarray], np.ndarray], + batch_size: Optional[int] = None, + replacement: bool = False, + pairing: bool = True, + ) -> InteractionValues: + """Approximates the interaction values. + + Args: + budget: The budget of the approximation. + game: The game to be approximated. + batch_size: The batch size for the approximation. Defaults to `None`. If `None` the + batch size is set to the budget. + replacement: Whether to sample subsets with replacement (`True`) or without replacement + (`False`). Defaults to `False`. + pairing: Whether to use the pairing sampling strategy or not. If paired sampling + (`True`) is used a subset is always paired with its complement subset and sampled + together. This may increase approximation quality. Defaults to `True`. + + Returns: + The interaction values. + """ + # validate input parameters + batch_size = budget + 2 if batch_size is None else batch_size + used_budget = 0 + n_iterations, last_batch_size = self._get_n_iterations( + budget + 2, batch_size, iteration_cost=1 + ) + + # create storage array for given budget + all_subsets: np.ndarray[bool] = np.zeros(shape=(budget, self.n), dtype=bool) + + # split the subset sizes into explicit and sampling parts + sampling_weights: np.ndarray[float] = self._init_ksh_sampling_weights() + explicit_sizes, sampling_sizes, remaining_budget = split_subsets_budget( + order=1, n=self.n, budget=budget, sampling_weights=sampling_weights + ) + + # enumerate all explicit subsets + explicit_subsets: np.ndarray[bool] = self._get_explicit_subsets(self.n, explicit_sizes) + all_subsets[: len(explicit_subsets)] = explicit_subsets + # zero out the sampling weights for the explicit sizes + sampling_weights[explicit_sizes] = 0.0 + + # sample the remaining subsets with the remaining budget + if len(sampling_sizes) > 0 and remaining_budget > 0: + sampling_subsets: np.ndarray[bool] = self._sample_subsets( + budget=remaining_budget, + sampling_weights=sampling_weights, + replacement=replacement, + pairing=pairing, + ) + all_subsets[len(explicit_subsets) :] = sampling_subsets + + # add empty and full set to all_subsets in the beginning + all_subsets = np.concatenate( + ( + np.zeros(shape=(1, self.n), dtype=bool), # empty set + np.ones(shape=(1, self.n), dtype=bool), # full set + all_subsets, # explicit and sampled subsets + ) + ) + n_subsets = all_subsets.shape[0] + + # get the fsi representation of the subsets + regression_subsets, num_players = self._get_fsi_subset_representation(all_subsets) # S, m + regression_weights = self._get_ksh_subset_weights(all_subsets) # W(|S|) + + # initialize the regression variables + game_values: np.ndarray[float] = np.zeros(shape=(n_subsets,), dtype=float) # \nu(S) + fsi_values: np.ndarray[float] = np.zeros(shape=(num_players,), dtype=float) + + # main regression loop computing the FSI values + for iteration in range(1, n_iterations + 1): + batch_size = batch_size if iteration != n_iterations else last_batch_size + batch_index = (iteration - 1) * batch_size + + # query the game for the batch of subsets + batch_subsets = all_subsets[0 : batch_index + batch_size] + game_values[batch_index : batch_index + batch_size] = game(batch_subsets) + + # compute the FSI values up to now + A = regression_subsets[0 : batch_index + batch_size] + B = game_values[0 : batch_index + batch_size] + W = regression_weights[0 : batch_index + batch_size] + W = np.sqrt(np.diag(W)) + Aw = np.dot(W, A) + Bw = np.dot(W, B) + fsi_values = np.linalg.lstsq(Aw, Bw, rcond=None)[0] # \phi_i + + used_budget += batch_size + + return self._finalize_fsi_result(fsi_values, budget=used_budget) + + def _sample_subsets( + self, + budget: int, + sampling_weights: np.ndarray[float], + replacement: bool = False, + pairing: bool = True, + ) -> np.ndarray[bool]: + """Samples subsets with the given budget. + + Args: + budget: budget for the sampling. + sampling_weights: weights for sampling subsets of certain sizes and indexed by the size. + The shape is expected to be (n + 1,). A size that is not to be sampled has weight 0. + pairing: whether to use pairing (`True`) sampling or not (`False`). Defaults to `False`. + + Returns: + sampled subsets. + """ + # sanitize input parameters + sampling_weights = copy.copy(sampling_weights) + sampling_weights /= np.sum(sampling_weights) + + # adjust budget for paired sampling + if pairing: + budget = budget - budget % 2 # must be even for pairing + budget = int(budget / 2) + + # create storage array for given budget + subset_matrix = np.zeros(shape=(budget, self.n), dtype=bool) + + # sample subsets + sampled_sizes = self._rng.choice(self.N_arr, size=budget, p=sampling_weights).astype(int) + if replacement: # sample subsets with replacement + permutations = np.tile(np.arange(self.n), (budget, 1)) + self._rng.permuted(permutations, axis=1, out=permutations) + for i, subset_size in enumerate(sampled_sizes): + subset = permutations[i, :subset_size] + subset_matrix[i, subset] = True + else: # sample subsets without replacement + sampled_subsets, n_sampled = set(), 0 # init sampling variables + while n_sampled < budget: + subset_size = sampled_sizes[n_sampled] + subset = tuple(sorted(self._rng.choice(np.arange(0, self.n), size=subset_size))) + sampled_subsets.add(subset) + if len(sampled_subsets) != n_sampled: # subset was not already sampled + subset_matrix[n_sampled, subset] = True + n_sampled += 1 # continue sampling + + if pairing: + subset_matrix = np.repeat(subset_matrix, repeats=2, axis=0) # extend the subset matrix + subset_matrix[1::2] = np.logical_not(subset_matrix[1::2]) # flip sign of paired subsets + + return subset_matrix + + def _get_fsi_subset_representation( + self, all_subsets: np.ndarray[bool] + ) -> tuple[np.ndarray[bool], int]: + """Transforms a subset matrix into the FSI representation. + + The FSI representation is a matrix of shape (n_subsets, num_players) where each interaction + up to the maximum order is an individual player. + + Args: + all_subsets: subset matrix in shape (n_subsets, n). + + Returns: + FSI representation of the subset matrix in shape (n_subsets, num_players). + """ + n_subsets = all_subsets.shape[0] + num_players = sum(int(binom(self.n, order)) for order in range(1, self.max_order + 1)) + regression_subsets = np.zeros(shape=(n_subsets, num_players), dtype=bool) + for interaction_index, interaction in enumerate( + powerset(self.N, min_size=1, max_size=self.max_order) + ): + regression_subsets[:, interaction_index] = all_subsets[:, interaction].all(axis=1) + return regression_subsets, num_players + + def _finalize_fsi_result( + self, fsi_values: np.ndarray[float], budget: Optional[int] = None + ) -> InteractionValues: + """Transforms the FSI values into the output interaction values. + + Args: + fsi_values: FSI values in shape (num_players,). + budget: The budget of the approximation. Defaults to `None`. + + Returns: + InteractionValues: The estimated interaction values. + """ + result = self._init_result() + fsi_index = 0 + for interaction in powerset(self.N, min_size=1, max_size=self.max_order): + for interaction_ordering in itertools.permutations(interaction): # all permutations + result[len(interaction)][interaction_ordering] = fsi_values[fsi_index] + fsi_index += 1 + return self._finalize_result(result, budget=budget) From dc663dac94667f2950bde83d0a9ed2ccc2fe2de7 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Wed, 22 Nov 2023 19:06:41 +0100 Subject: [PATCH 10/10] add permutation estimator for STI and closes #2 --- shapiq/approximator/permutation/sti.py | 137 +++++++++++++++++++------ 1 file changed, 105 insertions(+), 32 deletions(-) diff --git a/shapiq/approximator/permutation/sti.py b/shapiq/approximator/permutation/sti.py index eed460be..acf1a312 100644 --- a/shapiq/approximator/permutation/sti.py +++ b/shapiq/approximator/permutation/sti.py @@ -1,4 +1,5 @@ -from typing import Callable, Union +from typing import Callable +import warnings import numpy as np from scipy.special import binom @@ -9,7 +10,6 @@ class PermutationSamplingSTI(PermutationSampling): - def __init__(self, n: int, max_order: int, top_order: bool): super().__init__(n, max_order, "STI", top_order) self._iteration_cost: int = self._compute_iteration_cost() @@ -21,13 +21,11 @@ def _compute_iteration_cost(self) -> int: Returns: int: The cost of a single iteration. """ - iteration_cost = int(binom(self.n, self.max_order) * 2 ** self.max_order) + iteration_cost = int(binom(self.n, self.max_order) * 2**self.max_order) return iteration_cost def _compute_lower_order_sti( - self, - game: Callable[[Union[set, tuple]], float], - result: dict[int, np.ndarray] + self, game: Callable[[np.ndarray], np.ndarray], result: dict[int, np.ndarray] ) -> dict[int, np.ndarray]: """Computes all lower order interactions for the STI index up to order max_order - 1. @@ -38,44 +36,119 @@ def _compute_lower_order_sti( Returns: The result dictionary. """ - # run the game on the whole powerset of players up to order max_order - 1 - game_evaluations = {subset: game(subset) - for subset in powerset(self.N, max_size=self.max_order - 1, min_size=1)} - # inspect all parts of the subsets contained in the powerset and attribute their - # contribution to the corresponding interactions and order - for subset in powerset(self.N, max_size=self.max_order - 1, min_size=1): - subset = tuple(subset) - subset_size = len(subset) - for subset_part in powerset(subset): - subset_part_size = len(subset_part) - update = (-1) ** (subset_size - subset_part_size) * game_evaluations[subset_part] + # get all game values on the whole powerset of players up to order max_order - 1 + lower_order_sizes = list(range(0, self.max_order)) + subsets: np.ndarray[bool] = self._get_explicit_subsets(self.n, lower_order_sizes) + game_values = game(subsets) + game_values_lookup = { + tuple(np.where(subsets[index])[0]): float(game_values[index]) + for index in range(subsets.shape[0]) + } + + # compute the discrete derivatives of all subsets + for subset in powerset(self.N, min_size=1, max_size=self.max_order - 1): # S + subset_size = len(subset) # |S| + for subset_part in powerset(subset): # L + subset_part_size = len(subset_part) # |L| + game_value = game_values_lookup[subset_part] # \nu(L) + update = (-1) ** (subset_size - subset_part_size) * game_value result[subset_size][subset] += update return result def approximate( - self, - budget: int, - game: Callable[[Union[set, tuple]], float] + self, budget: int, game: Callable[[np.ndarray], np.ndarray], batch_size: int = 1 ) -> InteractionValues: - raise NotImplementedError + """Approximates the interaction values. + + Args: + budget: The budget for the approximation. + game: The game function as a callable that takes a set of players and returns the value. + batch_size: The size of the batch. If None, the batch size is set to 1. Defaults to 1. + + Returns: + InteractionValues: The estimated interaction values. + """ + batch_size = 1 if batch_size is None else batch_size + used_budget = 0 + result = self._init_result() counts = self._init_result(dtype=int) - value_empty = game(set()) - value_full = game(self.N) # compute all lower order interactions if budget allows it lower_order_cost = sum(int(binom(self.n, s)) for s in range(self.min_order, self.max_order)) if self.max_order > 1 and budget >= lower_order_cost: budget -= lower_order_cost + used_budget += lower_order_cost result = self._compute_lower_order_sti(game, result) + else: + warnings.warn( + message=f"The budget {budget} is too small to compute the lower order interactions " + f"of the STI index, which requires {lower_order_cost} evaluations. Consider " + f"increasing the budget.", + category=UserWarning, + ) + return self._finalize_result(result, budget=used_budget) + + # compute the number of iterations and size of the last batch (can be smaller than original) + n_iterations, last_batch_size = self._get_n_iterations( + budget, batch_size, self._iteration_cost + ) + + # warn the user if the budget is too small + if n_iterations == 0: + warnings.warn( + message=f"The budget {budget} is too small to perform a single iteration, which " + f"requires {self._iteration_cost + lower_order_cost} evaluations. Consider " + f"increasing the budget.", + category=UserWarning, + ) + return self._finalize_result(result, budget=used_budget) # main permutation sampling loop - n_permutations = 0 - while budget >= self._iteration_cost: - budget -= self._iteration_cost - values = np.zeros(self.n + 1) # create array for the current permutation - values[0], values[-1] = value_empty, value_full # init values on the edges - permutation = np.random.permutation(self.n) # create random permutation - # TODO finish this - - return self._finalize_result(result) + for iteration in range(1, n_iterations + 1): + batch_size = batch_size if iteration != n_iterations else last_batch_size + + # create the permutations: a 2d matrix of shape (batch_size, n) where each row is a + # permutation of the players + permutations = np.tile(np.arange(self.n), (batch_size, 1)) + self._rng.permuted(permutations, axis=1, out=permutations) + n_permutations = permutations.shape[0] + n_subsets = n_permutations * self._iteration_cost + + # get all subsets to evaluate per iteration + subsets = np.zeros(shape=(n_subsets, self.n), dtype=bool) + subset_index = 0 + for permutation_id in range(n_permutations): + for interaction in powerset(self.N, self.max_order, self.max_order): + idx = 0 + for i in permutations[permutation_id]: + if i in interaction: + break + else: + idx += 1 + subset = tuple(permutations[permutation_id][:idx]) + for L in powerset(interaction): + subsets[subset_index, tuple(subset + L)] = True + subset_index += 1 + + # evaluate all subsets on the game + game_values: np.ndarray[float] = game(subsets) + + # update the interaction scores by iterating over the permutations again + subset_index = 0 + for permutation_id in range(n_permutations): + for interaction in powerset(self.N, self.max_order, self.max_order): + counts[self.max_order][interaction] += 1 + for L in powerset(interaction): + game_value = game_values[subset_index] + update = game_value * (-1) ** (self.max_order - len(L)) + result[self.max_order][interaction] += update + subset_index += 1 + + used_budget += self._iteration_cost * batch_size + + # compute mean of interactions + for s in self._order_iterator: + result[s] = np.divide(result[s], counts[s], out=result[s], where=counts[s] != 0) + + return self._finalize_result(result, budget=used_budget)