Skip to content

Commit

Permalink
fix introspection methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm committed Aug 12, 2024
1 parent e4ff078 commit 2ead6c8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions hbw/ml/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def sensitivity_analysis(model, inputs, output_node: int = 0, input_features: li
"""
Sensitivity analysis of *model* between batch-normalized *inputs* and pre-softmax output *output_node*
"""
gradients = get_gradients(model, inputs, output_node)
gradients = get_gradients(model, inputs, output_node, skip_batch_norm=True)

sum_gradients = np.sum(np.abs(gradients), axis=0)
sum_gradients = dict(zip(input_features, sum_gradients))
Expand All @@ -89,7 +89,7 @@ def gradient_times_input(model, inputs, output_node: int = 0, input_features: li
"""
Gradient * Input of *model* between batch-normalized *inputs* and pre-softmax output *output_node*
"""
gradients = get_gradients(model, inputs, output_node)
gradients = get_gradients(model, inputs, output_node, skip_batch_norm=True)
inputs = get_input_post_batchnorm(model, inputs)

# NOTE: remove np.abs?
Expand All @@ -109,7 +109,6 @@ def shap_ranking(model, inputs, output_node: int = 0, input_features: list | Non
# calculate shap values
shap_values = explainer(inputs[:20])
shap_values.feature_names = list(input_features)
shap_values.shape

shap_ranking = dict(zip(shap_values.feature_names, shap_values[:, :, output_node].abs.mean(axis=0).values))
shap_ranking = dict(sorted(shap_ranking.items(), key=lambda x: abs(x[1]), reverse=True))
Expand Down
2 changes: 1 addition & 1 deletion hbw/ml/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def barplot_from_multidict(dict_of_rankings: dict[str, dict], normalize_weights:
The first sub-directory is used for the sorting of variables.
:param normalize_weights: whether to normalize the sum of weights per ranking to 1.
"""
fig, ax = plt.subplots(figsize=(8, 10))
plt.style.use("seaborn-v0_8")
fig, ax = plt.subplots(figsize=(8, 10))

num_dicts = len(dict_of_rankings.keys())
num_labels = len(dict_of_rankings[list(dict_of_rankings.keys())[0]].keys())
Expand Down

0 comments on commit 2ead6c8

Please sign in to comment.