diff --git a/CHANGELOG.md b/CHANGELOG.md
index 06512f99dc1..efe2df80d79 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 
 - Fixed typo with `LabelField` string representation: removed trailing apostrophe.
-
+- Gradient attribution in AllenNLP Interpret now computed as a function of the predicted class' logit, not its loss.
 
 
 ## [v1.3.0](https://github.com/allenai/allennlp/releases/tag/v1.3.0) - 2020-12-15
diff --git a/allennlp/interpret/saliency_interpreters/simple_gradient.py b/allennlp/interpret/saliency_interpreters/simple_gradient.py
index 639da42e824..e4711f942f6 100644
--- a/allennlp/interpret/saliency_interpreters/simple_gradient.py
+++ b/allennlp/interpret/saliency_interpreters/simple_gradient.py
@@ -17,7 +17,7 @@ class SimpleGradient(SaliencyInterpreter):
     def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
         """
-        Interprets the model's prediction for inputs. Gets the gradients of the loss with respect
+        Interprets the model's prediction for inputs. Gets the gradients of the logits with respect
         to the input and returns those gradients normalized and sanitized.
         """
         labeled_instances = self.predictor.json_to_labeled_instances(inputs)
diff --git a/allennlp/predictors/predictor.py b/allennlp/predictors/predictor.py
index 3ea94182edb..df269cbeb45 100644
--- a/allennlp/predictors/predictor.py
+++ b/allennlp/predictors/predictor.py
@@ -73,7 +73,7 @@ def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:
 
     def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
-        Gets the gradients of the loss with respect to the model inputs.
+        Gets the gradients of the logits with respect to the model inputs.
 
         # Parameters
 
@@ -91,7 +91,7 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
         Takes a `JsonDict` representing the inputs of the model and converts them to
         [`Instances`](../data/instance.md)), sends these through the model
         [`forward`](../models/model.md#forward) function after registering hooks on the embedding
-        layer of the model. Calls `backward` on the loss and then removes the
+        layer of the model. Calls `backward` on the logits and then removes the
         hooks.
         """
         # set requires_grad to true for all parameters, but save original values to
@@ -113,13 +113,13 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
             self._model.forward(**dataset_tensor_dict)  # type: ignore
         )
 
-        loss = outputs["loss"]
+        predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))]
         # Zero gradients.
         # NOTE: this is actually more efficient than calling `self._model.zero_grad()`
         # because it avoids a read op when the gradients are first updated below.
         for p in self._model.parameters():
             p.grad = None
-        loss.backward()
+        predicted_logit.backward()
         for hook in hooks:
             hook.remove()
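
For context, here is a minimal standalone sketch (not AllenNLP code; the toy embedding, classifier, and tensor names are illustrative assumptions) contrasting the two backward targets: gradients of the loss, which depend on the gold label, versus gradients of the predicted class's logit, which depend only on what the model actually predicted.

```python
# Sketch: loss gradients vs. predicted-logit gradients on a toy classifier.
# Everything here (model, sizes, names) is a stand-in, not AllenNLP's API.
import torch
import torch.nn as nn

torch.manual_seed(0)

embedding = nn.Embedding(10, 4)   # toy embedding layer
classifier = nn.Linear(4, 3)      # toy 3-class classifier

token_ids = torch.tensor([[1, 2, 3]])   # batch of one sequence
gold_label = torch.tensor([0])


def embed(ids):
    emb = embedding(ids)
    emb.retain_grad()  # keep gradients on the embedding output for inspection
    return emb


# --- Old behaviour: backprop the loss (tied to the gold label). ---
emb = embed(token_ids)
logits = classifier(emb.mean(dim=1))                     # shape (1, num_classes)
loss = nn.functional.cross_entropy(logits, gold_label)
loss.backward()
loss_grads = emb.grad.clone()

# --- New behaviour: backprop the predicted class's logit. ---
emb = embed(token_ids)
logits = classifier(emb.mean(dim=1))
probs = torch.softmax(logits, dim=-1)
predicted_logit = logits.squeeze(0)[int(torch.argmax(probs))]
predicted_logit.backward()
logit_grads = emb.grad.clone()

print(loss_grads)   # attribution w.r.t. the training objective
print(logit_grads)  # attribution w.r.t. the model's own prediction
```

The second variant mirrors the diff above: select the logit of the argmax class and call `backward` on that scalar, so the saliency reflects the predicted class rather than the loss against a (possibly dummy) label.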