This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial design of the multi-task model * PR comments, more implementation * changelog and docs fix * More tests, and fixes for those tests * mypy and make test less flaky * Update allennlp/models/multitask.py * Update allennlp/models/multitask.py Co-authored-by: Dirk Groeneveld <groeneveld@gmail.com> * Update allennlp/models/multitask.py Co-authored-by: James Barry <james.barry26@mail.dcu.ie> * respect active heads in get_metrics * Clean up changelog * black (apparently github UI doesn't add newlines?) Co-authored-by: Dirk Groeneveld <dirkg@allenai.org> Co-authored-by: Dirk Groeneveld <groeneveld@gmail.com> Co-authored-by: James Barry <james.barry26@mail.dcu.ie>
- Loading branch information
1 parent
fa22f73
commit f1e46fd
Showing
11 changed files
with
649 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from allennlp.models.heads.head import Head | ||
from allennlp.models.heads.classifier_head import ClassifierHead |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from typing import Dict, Optional | ||
|
||
from overrides import overrides | ||
import torch | ||
|
||
from allennlp.data import Vocabulary | ||
from allennlp.models.heads.head import Head | ||
from allennlp.modules import FeedForward, Seq2VecEncoder | ||
from allennlp.modules.seq2vec_encoders import ClsPooler | ||
from allennlp.training.metrics import CategoricalAccuracy | ||
|
||
|
||
@Head.register("classifier")
class ClassifierHead(Head):
    """
    A classification `Head`.  Takes encoded text, gets a single vector out of it, runs an optional
    feedforward layer on that vector, then classifies it into some label space.

    Registered as a `Head` with name "classifier".

    # Parameters

    vocab : `Vocabulary`
        Used to get the number of labels, if `num_labels` is not provided, and to translate label
        indices to strings in `make_output_human_readable`.
    seq2vec_encoder : `Seq2VecEncoder`, optional (default = `ClsPooler`)
        The input to this module is assumed to be a sequence of encoded vectors.  We use a
        `Seq2VecEncoder` to compress this into a single vector on which we can perform
        classification.  If nothing is provided, we will use a `ClsPooler`, which simply takes the
        first element of the sequence as the single vector (and is the standard thing to do when you
        are running a classifier on top of a transformer).
    feedforward : `FeedForward`, optional, (default = `None`)
        An optional feedforward layer to apply on the pooled output before performing the
        classification.
    input_dim : `int`, optional (default = `None`)
        We need to know how many dimensions to use for the final classification weight matrix.  If
        you have provided either a `seq2vec_encoder` or a `feedforward` module, we can get the
        correct size from those objects.  If you use default values for both of those parameters,
        then you must provide this parameter, so that we know the size of that encoding.
    dropout : `float`, optional (default = `None`)
        Dropout percentage to use.  A falsy value (`None` or `0`) disables dropout.
    num_labels : `int`, optional (default = `None`)
        Number of labels to project to in classification layer.  By default, the classification
        layer will project to the size of the vocabulary namespace corresponding to labels.
    label_namespace : `str`, optional (default = `"labels"`)
        Vocabulary namespace corresponding to labels.  By default, we use the "labels" namespace.

    # Raises

    ValueError
        If the size of the classifier input cannot be determined from `seq2vec_encoder`,
        `feedforward` or `input_dim`.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        seq2vec_encoder: Optional[Seq2VecEncoder] = None,
        feedforward: Optional[FeedForward] = None,
        input_dim: Optional[int] = None,
        dropout: Optional[float] = None,
        num_labels: Optional[int] = None,
        label_namespace: str = "labels",
    ) -> None:

        super().__init__(vocab)
        self._feedforward = feedforward

        if seq2vec_encoder is None:
            # `ClsPooler` needs to be told the size of the vectors it pools, so it cannot be
            # built without arguments.  Prefer the explicit `input_dim`; otherwise the pooled
            # vector feeds the feedforward layer (dropout does not change its size), so that
            # layer's input size tells us the same thing.
            pooler_dim = input_dim
            if pooler_dim is None and feedforward is not None:
                pooler_dim = feedforward.get_input_dim()
            if pooler_dim is None:
                raise ValueError("No input dimension given!")
            seq2vec_encoder = ClsPooler(pooler_dim)
        self._seq2vec_encoder = seq2vec_encoder

        if feedforward is not None:
            self._classifier_input_dim = feedforward.get_output_dim()
        else:
            # Fall back to `input_dim` for encoders that do not report an output size.
            self._classifier_input_dim = self._seq2vec_encoder.get_output_dim() or input_dim

        if self._classifier_input_dim is None:
            raise ValueError("No input dimension given!")

        self._dropout = torch.nn.Dropout(dropout) if dropout else None
        self._label_namespace = label_namespace

        if num_labels:
            self._num_labels = num_labels
        else:
            self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace)
        self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels)
        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

    def forward(  # type: ignore
        self,
        encoded_text: torch.FloatTensor,
        encoded_text_mask: torch.BoolTensor,
        label: Optional[torch.IntTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Pools `encoded_text` into a single vector, applies dropout and the feedforward layer when
        they are configured, and projects the result to label logits.

        # Parameters

        encoded_text : `torch.FloatTensor`
            The already-encoded sequence, passed to the `Seq2VecEncoder`.
        encoded_text_mask : `torch.BoolTensor`
            Mask for `encoded_text`, passed to the `Seq2VecEncoder` as `mask`.
        label : `torch.IntTensor`, optional (default = `None`)
            Gold labels.  When given, a cross-entropy loss is computed and the accuracy metric
            is updated.

        # Returns

        A dictionary with `"logits"` and `"probs"` (softmax over the last dimension of the
        logits), plus `"loss"` when `label` is provided.
        """
        encoding = self._seq2vec_encoder(encoded_text, mask=encoded_text_mask)

        if self._dropout is not None:
            encoding = self._dropout(encoding)

        if self._feedforward is not None:
            encoding = self._feedforward(encoding)

        logits = self._classification_layer(encoding)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}
        if label is not None:
            # The loss wants a flat vector of `long` class indices.
            output_dict["loss"] = self._loss(logits, label.long().view(-1))
            self._accuracy(logits, label)

        return output_dict

    @overrides
    def make_output_human_readable(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the probabilities, converts index to string label, and
        add `"label"` key to the dictionary with the result.

        If `"probs"` is not in `output_dict` (for instance because `forward` was not called on
        this head), the dictionary is returned unchanged.
        """
        if "probs" in output_dict:
            predictions = output_dict["probs"]
            # A 2-dimensional tensor is treated as one row of probabilities per instance;
            # anything else is treated as a single instance.
            if predictions.dim() == 2:
                predictions_list = list(predictions)
            else:
                predictions_list = [predictions]

            # Look the mapping up once rather than once per instance.
            index_to_label = self.vocab.get_index_to_token_vocabulary(self._label_namespace)
            classes = []
            for prediction in predictions_list:
                label_idx = prediction.argmax(dim=-1).item()
                # Indices missing from the vocabulary are rendered as their string form.
                classes.append(index_to_label.get(label_idx, str(label_idx)))
            output_dict["label"] = classes
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Returns the accuracy accumulated over calls to `forward` that were given a `label`,
        under the key `"accuracy"`.  `reset` is passed through to the underlying metric.
        """
        metrics = {"accuracy": self._accuracy.get_metric(reset)}
        return metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from allennlp.models.model import Model | ||
|
||
|
||
class Head(Model):
    """
    A `Model` that operates on _already encoded input_, typically doing only a small amount of
    computation before returning a loss.

    The API of a `Head` is currently the same as that of a `Model`.  The separate type exists
    to signal what to expect from a class that is a `Head`, and to give a more informative type
    annotation to models that take `Heads` as inputs.

    When implementing a `Head`, keep in mind that `make_output_human_readable` may be called
    without `forward` having been called on this head first.  At the point where
    `make_output_human_readable` is called, it is not known which heads were used in `forward`,
    and saving that state is messy.  So always guard the logic in `make_output_human_readable`
    with a check that the outputs you expect are actually present.
    """
Oops, something went wrong.