
Compute attributions w.r.t. the predicted logit, not the predicted loss #4882

Open · wants to merge 1 commit into `main`
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fixed typo with `LabelField` string representation: removed trailing apostrophe.

- Gradient attribution in AllenNLP Interpret now computed as a function of the predicted class' logit, not its loss.

## [v1.3.0](https://github.com/allenai/allennlp/releases/tag/v1.3.0) - 2020-12-15

@@ -17,7 +17,7 @@ class SimpleGradient(SaliencyInterpreter):

def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
"""
- Interprets the model's prediction for inputs. Gets the gradients of the loss with respect
+ Interprets the model's prediction for inputs. Gets the gradients of the logits with respect
to the input and returns those gradients normalized and sanitized.
"""
labeled_instances = self.predictor.json_to_labeled_instances(inputs)
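
As context for the "normalized and sanitized" step the docstring mentions, here is a minimal sketch of the usual simple-gradient recipe: dot each token's embedding gradient with its embedding, then L1-normalize the absolute scores. The function name and array shapes are assumptions for illustration, not AllenNLP's exact implementation:

```python
import numpy as np


def simple_gradient_saliency(embedding_grads: np.ndarray, embeddings: np.ndarray) -> list:
    """Per-token saliency from embedding gradients (generic sketch, shapes assumed).

    Both arrays are expected to have shape (num_tokens, embedding_dim).
    """
    # Dot each token's gradient with its embedding to get one score per token.
    scores = np.sum(embedding_grads * embeddings, axis=1)
    # L1-normalize the absolute scores so they sum to 1 across tokens.
    norm = np.linalg.norm(scores, ord=1)
    return [float(abs(s)) / norm for s in scores]
```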
8 changes: 4 additions & 4 deletions allennlp/predictors/predictor.py
@@ -73,7 +73,7 @@ def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:

def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
- Gets the gradients of the loss with respect to the model inputs.
+ Gets the gradients of the logits with respect to the model inputs.

# Parameters

@@ -91,7 +91,7 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
Takes a `JsonDict` representing the inputs of the model and converts
them to [`Instances`](../data/instance.md)), sends these through
the model [`forward`](../models/model.md#forward) function after registering hooks on the embedding
- layer of the model. Calls `backward` on the loss and then removes the
+ layer of the model. Calls `backward` on the logits and then removes the
hooks.
"""
# set requires_grad to true for all parameters, but save original values to
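
As a brief aside on the hook mechanism described in the docstring above, the sketch below shows one way to capture embedding gradients with a backward hook. It assumes the model exposes an `embedding` submodule; AllenNLP locates the embedding layer through its own utilities, so treat the attribute name as a placeholder:

```python
from typing import List

import torch
from torch.utils.hooks import RemovableHandle


def register_embedding_gradient_hook(
    model: torch.nn.Module, captured_grads: List[torch.Tensor]
) -> RemovableHandle:
    """Capture d(backward target)/d(embedding output) on each backward pass."""

    def hook(module, grad_input, grad_output):
        captured_grads.append(grad_output[0].detach())

    # Assumes the embedding layer is reachable as `model.embedding`.
    return model.embedding.register_full_backward_hook(hook)
```

The returned handle is what `get_gradients` removes once `backward` has run.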
@@ -113,13 +113,13 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
self._model.forward(**dataset_tensor_dict) # type: ignore
)

- loss = outputs["loss"]
+ predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))]
Contributor:

The trouble with doing it this way is that it hard-codes assumptions about the model's outputs which may not be true. The test failure you're getting is because of this. This method has to be generic enough to work for any model. This is ok when we query the `loss` key, because that key is already required by the `Trainer`. Nothing else is guaranteed to be in the output, so we can't hard-code anything else.

Maybe a better way of accomplishing what you want is to allow the caller to specify the output key, with a default value of `"loss"`. Then it would be the model's responsibility to make sure that the value under that key is a single number on which we can call `.backward()`. E.g., you could imagine adding a `target_logit` key in your model class, and then using that key when calling `get_gradients()`.

We could get by with less model modification if we add a second flag that says whether to take an argmax of the values in that key, but that gets a bit messy, because then you're always getting gradients of the model's prediction, completely ignoring whatever label was given in the input instance. This breaks a lot of assumptions in other methods in the code (which I think is what you were referring to when you said this breaks hotflip), so I don't really like this option.

Author:

Thanks for the feedback! I agree that using a key is straightforward. I'll refactor.
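
To make the proposal above concrete, here is a rough, standalone sketch of the key-based approach; `backward_from_key`, `ToyClassifier`, and the `target_logit` output are illustrative names, not AllenNLP's actual API:

```python
from typing import Any, Dict

import torch


def backward_from_key(model: torch.nn.Module, model_inputs: Dict[str, Any], key: str = "loss") -> None:
    """Run forward, then backpropagate from whichever scalar the output dict stores under `key`."""
    outputs = model(**model_inputs)
    target = outputs[key]  # the model must guarantee this is a scalar tensor
    for p in model.parameters():
        p.grad = None
    target.backward()


class ToyClassifier(torch.nn.Module):
    """Toy model exposing an extra scalar output for logit-based attributions."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 3)

    def forward(self, features: torch.Tensor, label: torch.Tensor) -> Dict[str, torch.Tensor]:
        logits = self.linear(features)
        return {
            "logits": logits,
            "loss": torch.nn.functional.cross_entropy(logits, label),
            # Sum of each example's predicted-class logit: a scalar we can call .backward() on.
            "target_logit": logits.gather(1, logits.argmax(dim=-1, keepdim=True)).sum(),
        }


# Usage (illustrative): backpropagate from the predicted logits instead of the loss.
# backward_from_key(
#     ToyClassifier(), {"features": torch.randn(2, 4), "label": torch.tensor([0, 2])}, key="target_logit"
# )
```

With the default `key="loss"` the current behaviour is unchanged, and only models that opt in by exposing an extra scalar output need to be modified.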

# Zero gradients.
# NOTE: this is actually more efficient than calling `self._model.zero_grad()`
# because it avoids a read op when the gradients are first updated below.
for p in self._model.parameters():
p.grad = None
- loss.backward()
+ predicted_logit.backward()

for hook in hooks:
hook.remove()
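
Since the point of the change is that the backward target determines what the attributions explain, the following self-contained toy comparison shows that the two targets generally produce different gradients; the hard-coded `logits` and `probs` keys in the diff above are exactly the model-specific assumption raised in the review:

```python
import torch

torch.manual_seed(0)

# Toy single-example classifier: a 4-dim "embedding" fed into a 3-class linear layer.
embedding = torch.randn(1, 4, requires_grad=True)
classifier = torch.nn.Linear(4, 3)

logits = classifier(embedding)
probs = torch.softmax(logits, dim=-1)
gold_label = torch.tensor([2])

# Previous behaviour: gradients of the loss, which depend on the gold label.
loss = torch.nn.functional.cross_entropy(logits, gold_label)
loss_grad = torch.autograd.grad(loss, embedding, retain_graph=True)[0]

# Behaviour after this change: gradients of the predicted class's logit, label-free.
predicted_logit = logits.squeeze(0)[int(torch.argmax(probs))]
logit_grad = torch.autograd.grad(predicted_logit, embedding)[0]

print("d(loss)/d(embedding):           ", loss_grad)
print("d(predicted logit)/d(embedding):", logit_grad)
```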