
Commit

fix gradient*input
mafrahm committed Aug 5, 2024
1 parent e48a422 commit 69a5ea2
Showing 1 changed file with 37 additions and 14 deletions.
51 changes: 37 additions & 14 deletions hbw/ml/introspection.py
@@ -12,34 +12,56 @@
import numpy as np


-def get_gradients(model: tf.keras.models.Model, inputs: np.array, output_node: int = 0) -> np.array:
+def get_input_post_batchnorm(model: tf.keras.models.Model, inputs: np.array) -> np.array:
    """
-    Calculate gradients of *model* between batch-normalized *inputs* and pre-softmax output *output_node*
+    Get the input data after passing through the batch normalization layer of *model*
-    :param model: The Keras model for which the gradients are to be calculated. The first layer of
+    :param model: The Keras model for which the input data is to be retrieved. The first layer of
        the model is expected to be a BatchNormalization layer.
    :type model: keras.Model
    :param inputs: The input data for the model. This should be a numpy array of the inputs to the
        model, which will be batch-normalized before being passed through the model.
    :type inputs: np.array
-    :param output_node: The index of the output node for which the gradients are to be calculated.
-        This refers to the index of the node in the final layer of the model, before the softmax activation. Defaults to 0.
-    :type output_node: int, optional
-    :return: A numpy array of the gradients of the model with respect to the inputs.
+    :return: A numpy array of the input data after passing through the batch normalization layer.
        The shape of this array will be the same as the shape of the inputs.
    :rtype: np.array
    :raises Exception: If the first layer of the model is not a BatchNormalization layer, an exception is raised.
    """
    batch_norm = model.layers[0]
    if batch_norm.__class__.__name__ != "BatchNormalization":
        raise Exception(f"First layer is expected to be BatchNormalization but is {batch_norm.__class__.__name__}")
    inp = batch_norm(tf.convert_to_tensor(inputs, dtype=tf.float32))
+    return inp


+def get_gradients(
+    model: tf.keras.models.Model,
+    inputs: np.array,
+    output_node: int = 0,
+    skip_batch_norm: bool = False,
+) -> np.array:
"""
Calculate gradients of *model* between batch-normalized *inputs* and pre-softmax output *output_node*
:param model: The Keras model for which the gradients are to be calculated. The first layer of
the model is expected to be a BatchNormalization layer.
:param inputs: The input data for the model. This should be a numpy array of the inputs to the
model, which will be batch-normalized before being passed through the model.
:param output_node: The index of the output node for which the gradients are to be calculated.
This refers to the index of the node in the final layer of the model, before the softmax activation. Defaults to 0.
:param skip_batch_norm: If True, the input data is not passed through the batch normalization layer.
:return: A numpy array of the gradients of the model with respect to the inputs.
The shape of this array will be the same as the shape of the inputs.
"""
+    if skip_batch_norm:
+        inp = get_input_post_batchnorm(model, inputs)
+        layers = model.layers[1:]
+    else:
+        inp = tf.convert_to_tensor(inputs, dtype=tf.float32)
+        layers = model.layers
    with tf.GradientTape() as tape:
        tape.watch(inp)
        outp = inp
-        for layer in model.layers[1:]:
+        for layer in layers:
+            layer = copy.copy(layer)
-            if layer.name == model.layers[-1].name:
+            if layer.name == layers[-1].name:
                # for the final layer, copy the layer and remove the softmax activation
-                layer = copy.copy(layer)
                layer.activation = None
@@ -68,9 +90,10 @@ def gradient_times_input(model, inputs, output_node: int = 0, input_features: li
    Gradient * Input of *model* between batch-normalized *inputs* and pre-softmax output *output_node*
    """
    gradients = get_gradients(model, inputs, output_node)
+    inputs = get_input_post_batchnorm(model, inputs)

-    # sum_gradients = np.sum(np.abs(gradients, axis=0))
-    sum_gradients = np.abs(np.sum(gradients, axis=0))
+    # NOTE: remove np.abs?
+    sum_gradients = np.abs(np.sum(gradients * inputs, axis=0))
    sum_gradients = dict(zip(input_features, sum_gradients))
    sum_gradients = dict(sorted(sum_gradients.items(), key=lambda x: abs(x[1]), reverse=True))

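For context, a minimal usage sketch of the two helpers touched by this commit. The toy model, feature count, and dummy data below are illustrative assumptions, not part of the repository; the only requirement taken from the code is that the model's first layer is a BatchNormalization layer.

```python
import numpy as np
import tensorflow as tf

from hbw.ml.introspection import get_input_post_batchnorm, get_gradients

# toy stand-in for the analysis DNN: batch norm first, softmax head last
n_features, n_classes = 4, 3
model = tf.keras.Sequential([
    tf.keras.layers.BatchNormalization(input_shape=(n_features,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(n_classes, activation="softmax"),
])

# dummy events in place of the real NN input features
inputs = np.random.normal(size=(128, n_features)).astype(np.float32)

# inputs after the batch normalization layer (same shape as *inputs*)
inputs_bn = get_input_post_batchnorm(model, inputs)

# gradients of pre-softmax output node 0 with respect to the raw inputs (default)
grads = get_gradients(model, inputs, output_node=0)

# with skip_batch_norm=True, batch norm is applied outside the gradient tape,
# so the gradients are taken with respect to the batch-normalized inputs
grads_bn = get_gradients(model, inputs, output_node=0, skip_batch_norm=True)
```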

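Read off the second hunk, the fixed gradient_times_input score for input feature j and output node c is roughly

s_j = \left| \sum_{\text{events } i} \frac{\partial f_c}{\partial x_{ij}} \, \hat{x}_{ij} \right|, \qquad \hat{x} = \text{BatchNorm}(x),

where f_c is the pre-softmax output, the gradients come from get_gradients with its default skip_batch_norm=False (i.e. taken with respect to the raw inputs x), and \hat{x} is the batch-normalized input from get_input_post_batchnorm; previously the score was |\sum_i \partial f_c / \partial x_{ij}| without the input factor. The in-code "# NOTE: remove np.abs?" flags the outer absolute value as an open question.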