From 2ead6c8597fda7b517f41a938a83a94a231c5012 Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Mon, 12 Aug 2024 12:36:04 +0200 Subject: [PATCH] fix introspection methods --- hbw/ml/introspection.py | 5 ++--- hbw/ml/plotting.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hbw/ml/introspection.py b/hbw/ml/introspection.py index 8ac9fa66..ba2abb95 100644 --- a/hbw/ml/introspection.py +++ b/hbw/ml/introspection.py @@ -76,7 +76,7 @@ def sensitivity_analysis(model, inputs, output_node: int = 0, input_features: li """ Sensitivity analysis of *model* between batch-normalized *inputs* and pre-softmax output *output_node* """ - gradients = get_gradients(model, inputs, output_node) + gradients = get_gradients(model, inputs, output_node, skip_batch_norm=True) sum_gradients = np.sum(np.abs(gradients), axis=0) sum_gradients = dict(zip(input_features, sum_gradients)) @@ -89,7 +89,7 @@ def gradient_times_input(model, inputs, output_node: int = 0, input_features: li """ Gradient * Input of *model* between batch-normalized *inputs* and pre-softmax output *output_node* """ - gradients = get_gradients(model, inputs, output_node) + gradients = get_gradients(model, inputs, output_node, skip_batch_norm=True) inputs = get_input_post_batchnorm(model, inputs) # NOTE: remove np.abs? @@ -109,7 +109,6 @@ def shap_ranking(model, inputs, output_node: int = 0, input_features: list | Non # calculate shap values shap_values = explainer(inputs[:20]) shap_values.feature_names = list(input_features) - shap_values.shape shap_ranking = dict(zip(shap_values.feature_names, shap_values[:, :, output_node].abs.mean(axis=0).values)) shap_ranking = dict(sorted(shap_ranking.items(), key=lambda x: abs(x[1]), reverse=True)) diff --git a/hbw/ml/plotting.py b/hbw/ml/plotting.py index ccb18492..29e1ea35 100644 --- a/hbw/ml/plotting.py +++ b/hbw/ml/plotting.py @@ -31,8 +31,8 @@ def barplot_from_multidict(dict_of_rankings: dict[str, dict], normalize_weights: The first sub-directory is used for the sorting of variables. :param normalize_weights: whether to normalize the sum of weights per ranking to 1. """ - fig, ax = plt.subplots(figsize=(8, 10)) plt.style.use("seaborn-v0_8") + fig, ax = plt.subplots(figsize=(8, 10)) num_dicts = len(dict_of_rankings.keys()) num_labels = len(dict_of_rankings[list(dict_of_rankings.keys())[0]].keys())