Skip to content

Commit

Permalink
feat: store all responsibility maps in Explanation object
Browse files Browse the repository at this point in the history
  • Loading branch information
liz-is committed Nov 29, 2024
1 parent 24a2f84 commit 71b90a7
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 30 deletions.
2 changes: 1 addition & 1 deletion rex_xai/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def update_database(
explanation.args,
target.classification,
target.confidence,
explanation.map,
explanation.maps.get(explanation.data.target.classification),
explanation.explanation.detach().cpu().numpy(), # type: ignore
time_taken,
total_passing,
Expand Down
2 changes: 1 addition & 1 deletion rex_xai/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def calculate_responsibility(data: Data, args: CausalArgs, prediction_func):
'avg_box_size': avg_box_size
}

exp = Explanation(maps, prediction_func, data.target, data, args, run_stats)
exp = Explanation(maps, prediction_func, data.target, data, args, run_stats, keep_all_maps=False)

return exp

Expand Down
23 changes: 15 additions & 8 deletions rex_xai/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,20 @@
class Explanation:
def __init__(
self,
map,
maps,
prediction_func,
target: Prediction,
data: Data,
args: CausalArgs,
run_stats: dict,
keep_all_maps = False
) -> None:
self.map = map.get(target.classification)
if keep_all_maps:
self.maps = maps
else:
maps.subset(target.classification)
self.maps = maps

self.explanation: Optional[tt.Tensor] = None
self.final_mask = None
self.prediction_func = prediction_func
Expand Down Expand Up @@ -89,7 +95,7 @@ def set_to_true(self, coords, mask=None):

def __global(self, map=None, wipe=False):
if map is None:
map = self.map
map = self.maps.get(self.data.target.classification)
ranking = get_map_locations(map)

mutant = tt.zeros(
Expand Down Expand Up @@ -133,8 +139,9 @@ def __circle(self, centre, radius: int):

def __spatial(self, centre=None, expansion_limit=None) -> Optional[int]:
# we don't have a search location to start from, so we try to isolate one
map = self.maps.get(self.data.target.classification)
if centre is None:
centre = np.unravel_index(np.argmax(self.map), self.map.shape)
centre = np.unravel_index(np.argmax(map), map.shape)

start_radius = self.args.spatial_radius
mask = tt.zeros(
Expand All @@ -161,7 +168,7 @@ def __spatial(self, centre=None, expansion_limit=None) -> Optional[int]:
p = self.prediction_func(d)[0]
if p.classification == self.target.classification:
return self.__global(
map=np.where(circle.detach().cpu().numpy(), self.map, 0)
map=np.where(circle.detach().cpu().numpy(), map, 0)
)
start_radius = int(start_radius * (1 + self.args.spatial_eta))
circle = self.__circle(centre, start_radius)
Expand All @@ -182,7 +189,7 @@ def save(self, path):
spectral_plot(
self.explanation,
self.data,
self.map,
self.maps.get(self.data.target.classification),
self.args.heatmap_colours,
path = path
)
Expand All @@ -196,7 +203,7 @@ def heatmap_plot(self, path=None):
if self.data.mode in ("RGB", "L"):
heatmap_plot(
self.data,
self.map,
self.maps.get(self.data.target.classification),
self.args.heatmap_colours,
self.target,
path=path,
Expand All @@ -208,7 +215,7 @@ def surface_plot(self, path=None):
if self.data.mode in ("RGB", "L"):
surface_plot(
self.args,
self.map,
self.maps.get(self.data.target.classification),
self.target,
path=path,
)
Expand Down
6 changes: 6 additions & 0 deletions rex_xai/resp_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,9 @@ def update_maps(
)
section = 0.001
self.maps[k] = resp_map

def subset(self, id):
m = self.maps.get(id)
c = self.counts.get(id)
self.maps = {id: m}
self.counts = {id: c}
8 changes: 1 addition & 7 deletions tests/__snapshots__/_explanation_onnx_test.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,7 @@
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]]]),
map=array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
maps={209: 3},
mask_func=0,
run_stats=dict({
'avg_box_size': 12544.0,
Expand Down
14 changes: 1 addition & 13 deletions tests/__snapshots__/_explanation_test.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,7 @@
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]]),
map=array([[2. , 2. , 2. , ..., 4.6666665, 4.6666665,
4.6666665],
[2. , 2. , 2. , ..., 4.6666665, 4.6666665,
4.6666665],
[2. , 2. , 2. , ..., 4.6666665, 4.6666665,
4.6666665],
...,
[2. , 2. , 2. , ..., 3.4999998, 3.4999998,
3.4999998],
[2. , 2. , 2. , ..., 3.4999998, 3.4999998,
3.4999998],
[2. , 2. , 2. , ..., 3.4999998, 3.4999998,
3.4999998]], dtype=float32),
maps={207: 3},
mask_func=0,
run_stats=dict({
'avg_box_size': 216.0,
Expand Down

0 comments on commit 71b90a7

Please sign in to comment.