diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d05bf5845f..7c5afd089a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,27 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## Unreleased +## Unreleased (2.0 branch) + +### Added + +- A new high-performance default `DataLoader`: `MultiProcessDataLoading`. +- A `MultiTaskModel` and abstractions to use with it, including `Backbone` and `Head`. The + `MultiTaskModel` first runs its inputs through the `Backbone`, then passes the result (and +whatever other relevant inputs it got) to each `Head` that's in use. This is intended for +multi-task learning, but so far it is incomplete, as there are no corresponding dataset readers or +data loaders. Those are coming soon. + +### Changed + +- `DatasetReader`s are now always lazy. This means there is no `lazy` parameter in the base + class, and the `_read()` method should always be a generator. +- The `DataLoader` now decides whether to load instances lazily or not. + With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with + the `MultiProcessDataLoading` this is controlled by the `max_instances_in_memory` setting. + + +## Unreleased (1.x branch) ### Added @@ -38,16 +58,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `Predictor.capture_model_internals()` now accepts a regex specifying which modules to capture -- A new high-performance default `DataLoader`: `MultiProcessDataLoading`. - -### Changed - -- `DatasetReader`s are now always lazy. This means there is no `lazy` parameter in the base - class, and the `_read()` method should always be a generator. -- The `DataLoader` now decides whether to load instances lazily or not. - With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with - the `MultiProcessDataLoading` this is controlled by the `max_instances_in_memory` setting. 
- ## [v1.1.0rc4](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc4) - 2020-08-20 diff --git a/allennlp/models/__init__.py b/allennlp/models/__init__.py index a4301da88c1..14bac98337e 100644 --- a/allennlp/models/__init__.py +++ b/allennlp/models/__init__.py @@ -5,6 +5,7 @@ from allennlp.models.model import Model from allennlp.models.archival import archive_model, load_archive, Archive -from allennlp.models.simple_tagger import SimpleTagger from allennlp.models.basic_classifier import BasicClassifier +from allennlp.models.multitask import MultiTaskModel +from allennlp.models.simple_tagger import SimpleTagger from allennlp.models.vilbert import Nlvr2Vilbert diff --git a/allennlp/models/heads/__init__.py b/allennlp/models/heads/__init__.py new file mode 100644 index 00000000000..0108faf262f --- /dev/null +++ b/allennlp/models/heads/__init__.py @@ -0,0 +1,2 @@ +from allennlp.models.heads.head import Head +from allennlp.models.heads.classifier_head import ClassifierHead diff --git a/allennlp/models/heads/classifier_head.py b/allennlp/models/heads/classifier_head.py new file mode 100644 index 00000000000..39d5eefc44b --- /dev/null +++ b/allennlp/models/heads/classifier_head.py @@ -0,0 +1,136 @@ +from typing import Dict, Optional + +from overrides import overrides +import torch + +from allennlp.data import Vocabulary +from allennlp.models.heads.head import Head +from allennlp.modules import FeedForward, Seq2VecEncoder +from allennlp.modules.seq2vec_encoders import ClsPooler +from allennlp.training.metrics import CategoricalAccuracy + + +@Head.register("classifier") +class ClassifierHead(Head): + """ + A classification `Head`. Takes encoded text, gets a single vector out of it, runs an optional + feedforward layer on that vector, then classifies it into some label space. + + Registered as a `Head` with name "classifier". + + # Parameters + + vocab : `Vocabulary` + Used to get the number of labels, if `num_labels` is not provided, and to translate label + indices to strings in `make_output_human_readable`. + seq2vec_encoder : `Seq2VecEncoder`, optional (default = `ClsPooler`) + The input to this module is assumed to be a sequence of encoded vectors. We use a + `Seq2VecEncoder` to compress this into a single vector on which we can perform + classification. If nothing is provided, we will use a `ClsPooler`, which simply takes the + first element of the sequence as the single vector (and is the standard thing to do when you + are running a classifier on top of a transformer). + feedforward : `FeedForward`, optional, (default = `None`) + An optional feedforward layer to apply on the pooled output before performing the + classification. + input_dim : `int`, optional (default = `None`) + We need to know how many dimensions to use for the final classification weight matrix. If + you have provided either a `seq2vec_encoder` or a `feedforward` module, we can get the + correct size from those objects. If you use default values for both of those parameters, + then you must provide this parameter, so that we know the size of that encoding. + dropout : `float`, optional (default = `None`) + Dropout percentage to use. + num_labels : `int`, optional (default = `None`) + Number of labels to project to in classification layer. By default, the classification layer will + project to the size of the vocabulary namespace corresponding to labels. + label_namespace : `str`, optional (default = `"labels"`) + Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace. 
+ """ + + def __init__( + self, + vocab: Vocabulary, + seq2vec_encoder: Seq2VecEncoder = None, + feedforward: Optional[FeedForward] = None, + input_dim: int = None, + dropout: float = None, + num_labels: int = None, + label_namespace: str = "labels", + ) -> None: + + super().__init__(vocab) + self._seq2vec_encoder = seq2vec_encoder or ClsPooler() + self._feedforward = feedforward + if feedforward is not None: + self._classifier_input_dim = self._feedforward.get_output_dim() + else: + self._classifier_input_dim = self._seq2vec_encoder.get_output_dim() or input_dim + + if self._classifier_input_dim is None: + raise ValueError("No input dimension given!") + + if dropout: + self._dropout = torch.nn.Dropout(dropout) + else: + self._dropout = None + self._label_namespace = label_namespace + + if num_labels: + self._num_labels = num_labels + else: + self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace) + self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels) + self._accuracy = CategoricalAccuracy() + self._loss = torch.nn.CrossEntropyLoss() + + def forward( # type: ignore + self, + encoded_text: torch.FloatTensor, + encoded_text_mask: torch.BoolTensor, + label: torch.IntTensor = None, + ) -> Dict[str, torch.Tensor]: + encoding = self._seq2vec_encoder(encoded_text, mask=encoded_text_mask) + + if self._dropout: + encoding = self._dropout(encoding) + + if self._feedforward is not None: + encoding = self._feedforward(encoding) + + logits = self._classification_layer(encoding) + probs = torch.nn.functional.softmax(logits, dim=-1) + + output_dict = {"logits": logits, "probs": probs} + if label is not None: + loss = self._loss(logits, label.long().view(-1)) + output_dict["loss"] = loss + self._accuracy(logits, label) + + return output_dict + + @overrides + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Does a simple argmax over the probabilities, converts index to string label, and + add `"label"` key to the dictionary with the result. + """ + if "probs" in output_dict: + predictions = output_dict["probs"] + if predictions.dim() == 2: + predictions_list = [predictions[i] for i in range(predictions.shape[0])] + else: + predictions_list = [predictions] + classes = [] + for prediction in predictions_list: + label_idx = prediction.argmax(dim=-1).item() + label_str = self.vocab.get_index_to_token_vocabulary(self._label_namespace).get( + label_idx, str(label_idx) + ) + classes.append(label_str) + output_dict["label"] = classes + return output_dict + + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + metrics = {"accuracy": self._accuracy.get_metric(reset)} + return metrics diff --git a/allennlp/models/heads/head.py b/allennlp/models/heads/head.py new file mode 100644 index 00000000000..6fe45342aeb --- /dev/null +++ b/allennlp/models/heads/head.py @@ -0,0 +1,21 @@ +from allennlp.models.model import Model + + +class Head(Model): + """ + A `Head` is a `Model` that takes _already encoded input_ and typically does simple computation + before returning a loss. + + There isn't currently any difference in API between a `Model` and a `Head`, but we have this + separate type as both a signaling mechanism for what to expect when looking at a `Head` class, + and so that we can use this as a more informative type annotation when building models that use + `Heads` as inputs. 
+ + One additional consideration in a `Head` is that `make_output_human_readable` needs to account + for the case where it gets called without first having `forward` be called on the head. This is + because at the point where we call `make_output_human_readable`, we don't know which heads were + used in `forward`, and trying to save the state is messy. So just make sure that you always + have conditional logic in `make_output_human_readable` when you implement a `Head`. + """ + + pass diff --git a/allennlp/models/multitask.py b/allennlp/models/multitask.py new file mode 100644 index 00000000000..201d800d9aa --- /dev/null +++ b/allennlp/models/multitask.py @@ -0,0 +1,198 @@ +from collections import defaultdict +import inspect +from typing import Any, Dict, List, Set + +from overrides import overrides +import torch + +from allennlp.data import Vocabulary +from allennlp.modules import Backbone +from allennlp.models.model import Model +from allennlp.models.heads import Head +from allennlp.nn import InitializerApplicator + + +def get_forward_arguments(module: torch.nn.Module) -> Set[str]: + signature = inspect.signature(module.forward) + return set([arg for arg in signature.parameters if arg != "self"]) + + +@Model.register("multitask") +class MultiTaskModel(Model): + """ + A `MultiTaskModel` consists of a `Backbone` that encodes its inputs in some way, then a + collection of `Heads` that make predictions from the backbone-encoded inputs. The predictions + of each `Head` are combined to compute a joint loss, which is then used for training. + + This model works by taking `**kwargs` in `forward`, and passing the right arguments from that to + the backbone and to each head. By default, we use `inspect` to try to figure out getting the + right arguments to the right modules, but we allow you to specify these arguments yourself in + case our inference code gets it wrong. + + It is the caller's responsibility to make sure that the backbone and all heads are compatible with + each other, and with the input data that comes from a `MultiTaskDatasetReader`. We give some + arguments in this class and in `MultiTaskDatasetReader` to help with plumbing the arguments in + complex cases (e.g., you can change argument names so that they match what the backbone and + heads expect). + + # Parameters + + vocab: `Vocab` + backbone: `Backbone` + heads: `Dict[str, Head]` + loss_weights: `Dict[str, float]`, optional (default = `equal weighting`) + If you want, you can specify a weight for each head, which we will multiply the loss by when + aggregating across heads. This is equivalent in many cases to specifying a separate + learning rate per head, and just putting a weighting on the loss is much easier than + figuring out the right way to specify that in the optimizer. + arg_name_mapping: `Dict[str, Dict[str, str]]`, optional (default = `identity mapping`) + The mapping changes the names in the `**kwargs` dictionary passed to `forward` before + passing on the arguments to the backbone and heads. This is keyed by component, and the + top-level keys must match the keys passed in the `heads` parameter, plus a "backbone" key + for the backbone. If you are using dataset readers that use dataset-specific names for + their keys, this lets you change them to be consistent. For example, this dictionary might + end up looking like this: `{"backbone": {"question": "text", "review": "text"}, + "classifier1": {"sentiment": "label"}, "classifier2": {"topic": "label"}}`. 
+ Though in this particular example, we have two different inputs mapping to the same key in + the backbone; this will work, as long as you are careful that you don't give both of those + inputs in the same batch. If we see overlapping keys, we will crash. If you want to be able + to do this kind of mixed training in the same batch, you need to handle that in your data + code, not here; we won't handle complex batching inside this model. + allowed_arguments: `Dict[str, Set[str]]`, optional (default = `inferred`) + The list of arguments that should be passed from `**kwargs` to the `forward` method for the + backbone and each head. If you provide this, the keys in here should match the keys given + in the `heads` parameter, plus a "backbone" key for the backbone arguments. If not given, + we will use the `inspect` module to figure this out. The only time that this inference + might fail is if you have optional arguments that you want to be ignored, or + something. You very likely don't need to worry about this argument. + initializer: `InitializerApplicator`, optional (default=`InitializerApplicator()`) + If provided, will be used to initialize the model parameters. + """ + + def __init__( + self, + vocab: Vocabulary, + backbone: Backbone, + heads: Dict[str, Head], + *, + loss_weights: Dict[str, float] = None, + arg_name_mapping: Dict[str, Dict[str, str]] = None, + allowed_arguments: Dict[str, Set[str]] = None, + initializer: InitializerApplicator = InitializerApplicator(), + **kwargs, + ): + super().__init__(vocab, **kwargs) + self._backbone = backbone + self._heads = torch.nn.ModuleDict(heads) + self._arg_name_mapping = arg_name_mapping or defaultdict(dict) + + self._allowed_arguments = allowed_arguments or { + "backbone": get_forward_arguments(backbone), + **{key: get_forward_arguments(heads[key]) for key in heads}, + } + self._loss_weights = loss_weights or defaultdict(lambda: 1.0) + self._active_heads: List[str] = None + initializer(self) + + def set_active_heads(self, active_heads: List[str]) -> None: + """ + By default, the `MultiTaskModel` will try to infer which heads to run from the arguments + passed to `forward`. During training, we will only run a head if we have all of its + arguments, including optional arguments, which typically means the argument is the + prediction target; if we don't have it, we typically can't compute a loss, so running during + training is pointless. During evaluation, we will run all heads. + + If you want to limit which heads are run during evaluation, or if the inference for which + task to run during training is incorrect (e.g., if your head has multiple optional + arguments, and only some are actually required to compute a loss), then you can use this + method to override our inference and force the use of whatever set of heads you want. + + To get back to the default mode of operation, call this method with `None` as an argument.
+ """ + self._active_heads = active_heads + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore + backbone_arguments = self._get_arguments(kwargs, "backbone") + backbone_outputs = self._backbone(**backbone_arguments) + + outputs = {**backbone_outputs} + loss = None + for head_name in self._heads: + if self._active_heads is not None and head_name not in self._active_heads: + continue + + combined_arguments = {**backbone_outputs, **kwargs} + head_arguments = self._get_arguments(combined_arguments, head_name) + + if ( + self._active_heads is None + and self.training + and head_arguments.keys() != self._allowed_arguments[head_name] + ): + continue + + head_outputs = self._heads[head_name](**head_arguments) + for key in head_outputs: + outputs[f"{head_name}_{key}"] = head_outputs[key] + + if "loss" in head_outputs: + head_loss = self._loss_weights[head_name] * head_outputs["loss"] + if loss is None: + loss = head_loss + else: + loss += head_loss + + if loss is not None: + outputs["loss"] = loss + + return outputs + + def _get_arguments(self, available_args: Dict[str, Any], component: str) -> Dict[str, Any]: + """ + Given a list of things we might want to pass to a component (where "component" is either the + backbone or a head), this method figures out which things we should actually pass, by + mapping names and looking at allowed arguments. + """ + allowed_args = self._allowed_arguments[component] + name_mapping = self._arg_name_mapping[component] + kept_arguments = {} + for key, value in available_args.items(): + new_key = name_mapping.get(key, key) + if new_key in allowed_args: + if new_key in kept_arguments: + raise ValueError( + f"Got duplicate argument {new_key} for {component}. This likely means that" + " you mapped multiple inputs to the same name. This is generally ok for" + " the backbone, but you have to be sure each batch only gets one of those" + " inputs. This is typically not ok for heads, and means something is not" + " set up right." 
+ ) + kept_arguments[new_key] = value + return kept_arguments + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + metrics = {} + for head_name in self._heads: + if self._active_heads is not None and head_name not in self._active_heads: + continue + for key, value in self._heads[head_name].get_metrics(reset).items(): + metrics[f"{head_name}_{key}"] = value + return metrics + + @overrides + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + output_dict = self._backbone.make_output_human_readable(output_dict) + for head_name in self._heads: + if self._active_heads is not None and head_name not in self._active_heads: + continue + head_outputs = {} + for key, value in output_dict.items(): + if key.startswith(head_name): + head_outputs[key.replace(f"{head_name}_", "")] = value + readable_head_outputs = self._heads[head_name].make_output_human_readable(head_outputs) + for key, value in readable_head_outputs.items(): + output_dict[f"{head_name}_{key}"] = value + return output_dict diff --git a/allennlp/modules/__init__.py b/allennlp/modules/__init__.py index 2292ceabd73..0e47f36d0f6 100644 --- a/allennlp/modules/__init__.py +++ b/allennlp/modules/__init__.py @@ -5,6 +5,7 @@ """ from allennlp.modules.attention import Attention +from allennlp.modules.backbones import Backbone from allennlp.modules.bimpm_matching import BiMpmMatching from allennlp.modules.conditional_random_field import ConditionalRandomField from allennlp.modules.elmo import Elmo diff --git a/allennlp/modules/backbones/__init__.py b/allennlp/modules/backbones/__init__.py new file mode 100644 index 00000000000..5b8c98a55bf --- /dev/null +++ b/allennlp/modules/backbones/__init__.py @@ -0,0 +1,2 @@ +from allennlp.modules.backbones.backbone import Backbone +from allennlp.modules.backbones.pretrained_transformer_backbone import PretrainedTransformerBackbone diff --git a/allennlp/modules/backbones/backbone.py b/allennlp/modules/backbones/backbone.py new file mode 100644 index 00000000000..e4bb14f605b --- /dev/null +++ b/allennlp/modules/backbones/backbone.py @@ -0,0 +1,41 @@ +from typing import Dict + +import torch + +from allennlp.common import Registrable + + +class Backbone(Registrable, torch.nn.Module): + """ + A `Backbone` operates on basic model inputs and produces some encoding of those inputs that will + be shared among one or more `Heads` in a multi-task setting. For plain text inputs, this is + often a transformer. + + The main purpose of this class is to give us a `Registrable` class that we can use as a type + annotation on `Model` classes that want to use a backbone. The expectation is that this will + take the same inputs as a typical model, but return intermediate representations. These should + generally be returned as a dictionary, from which the caller will have to pull out what they + want and use as desired. As a convention that these modules should generally follow, their + outputs should have the same name as the given input, prepended with `encoded_`. So, a backbone + that encodes a `text` input should return an output called `encoded_text`. This convention + allows easier exchangeability of these backbone modules. + + Additionally, as downstream `Heads` will typically need mask information, but after encoding + have no way of computing it, a `Backbone` should also return a mask for each of its outputs, + with the same name as the output but with `_mask` appended. 
So in our example of `text` as + input, the output should have an entry called `encoded_text_mask`. + + Because a `Backbone` handles model inputs, if you want to make those inputs human readable + (e.g., for displaying them in a demo), then it's typically only the `Backbone` object that knows + how to do that. So we also implement the `make_output_human_readable` function from the `Model` + class. The implementation in the base class does nothing, but concrete classes should generally + convert whatever input indices are saved to the output into text. + """ + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + raise NotImplementedError + + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + return output_dict diff --git a/allennlp/modules/backbones/pretrained_transformer_backbone.py b/allennlp/modules/backbones/pretrained_transformer_backbone.py new file mode 100644 index 00000000000..f55c3521302 --- /dev/null +++ b/allennlp/modules/backbones/pretrained_transformer_backbone.py @@ -0,0 +1,116 @@ +from typing import Dict, Optional + +from overrides import overrides +import torch + +from allennlp.data import TextFieldTensors, Vocabulary +from allennlp.modules.backbones.backbone import Backbone +from allennlp.modules.token_embedders.pretrained_transformer_embedder import ( + PretrainedTransformerEmbedder, +) +from allennlp.nn import util + + +@Backbone.register("pretrained_transformer") +class PretrainedTransformerBackbone(Backbone): + """ + Uses a pretrained model from `transformers` as a `Backbone`. + + This class passes most of its arguments to a `PretrainedTransformerEmbedder`, which it uses to + implement the underlying encoding logic (we duplicate the arguments here instead of taking an + `Embedder` as a constructor argument just to simplify the user-facing API). + + Registered as a `Backbone` with name "pretrained_transformer". + + # Parameters + + vocab : `Vocabulary` + Necessary for converting input ids to strings in `make_output_human_readable`. If you set + `output_token_strings` to `False`, or if you never call `make_output_human_readable`, then + this will not be used and can be safely set to `None`. + model_name : `str` + The name of the `transformers` model to use. Should be the same as the corresponding + `PretrainedTransformerIndexer`. + max_length : `int`, optional (default = `None`) + If positive, folds input token IDs into multiple segments of this length, pass them + through the transformer model independently, and concatenate the final representations. + Should be set to the same value as the `max_length` option on the + `PretrainedTransformerIndexer`. + sub_module: `str`, optional (default = `None`) + The name of a submodule of the transformer to be used as the embedder. Some transformers naturally act + as embedders such as BERT. However, other models consist of encoder and decoder, in which case we just + want to use the encoder. + train_parameters: `bool`, optional (default = `True`) + If this is `True`, the transformer weights get updated during training. + last_layer_only: `bool`, optional (default = `True`) + When `True` (the default), only the final layer of the pretrained transformer is taken + for the embeddings. But if set to `False`, a scalar mix of all of the layers + is used. 
+ output_token_strings : `bool`, optional (default = `True`) + If `True`, we will add the input token ids to the output dictionary in `forward` (with key + "token_ids"), and convert them to strings in `make_output_human_readable` (with key + "tokens"). This is necessary for certain demo functionality, and it adds only a trivial + amount of computation if you are not using a demo. + vocab_namespace : `str`, optional (default = `"tags"`) + The namespace to use in conjunction with the `Vocabulary` above. We use a somewhat + confusing default of "tags" here, to match what is done in `PretrainedTransformerIndexer`. + """ + + def __init__( + self, + vocab: Vocabulary, + model_name: str, + *, + max_length: int = None, + sub_module: str = None, + train_parameters: bool = True, + last_layer_only: bool = True, + override_weights_file: Optional[str] = None, + override_weights_strip_prefix: Optional[str] = None, + output_token_strings: bool = True, + vocab_namespace: str = "tags", + ) -> None: + super().__init__() + self._vocab = vocab + self._namespace = vocab_namespace + self._embedder = PretrainedTransformerEmbedder( + model_name=model_name, + max_length=max_length, + sub_module=sub_module, + train_parameters=train_parameters, + last_layer_only=last_layer_only, + override_weights_file=override_weights_file, + override_weights_strip_prefix=override_weights_strip_prefix, + ) + self._output_token_strings = output_token_strings + + def forward(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]: # type: ignore + if len(text) != 1: + raise ValueError( + "PretrainedTransformerBackbone is only compatible with using a single TokenIndexer" + ) + text_inputs = next(iter(text.values())) + mask = util.get_text_field_mask(text) + encoded_text = self._embedder(**text_inputs) + outputs = {"encoded_text": encoded_text, "encoded_text_mask": mask} + if self._output_token_strings: + outputs["token_ids"] = util.get_token_ids_from_text_field_tensors(text) + return outputs + + @overrides + def make_output_human_readable( + self, output_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + if not self._output_token_strings: + return output_dict + + tokens = [] + for instance_tokens in output_dict["token_ids"]: + tokens.append( + [ + self._vocab.get_token_from_index(token_id.item(), namespace=self._namespace) + for token_id in instance_tokens + ] + ) + output_dict["tokens"] = tokens + return output_dict diff --git a/tests/models/multitask_test.py b/tests/models/multitask_test.py new file mode 100644 index 00000000000..36157be5661 --- /dev/null +++ b/tests/models/multitask_test.py @@ -0,0 +1,109 @@ +import pytest + +from allennlp.common.testing import ModelTestCase +from allennlp.data import Instance, Vocabulary +from allennlp.data.fields import LabelField, TextField +from allennlp.data.token_indexers import PretrainedTransformerIndexer +from allennlp.data.tokenizers import PretrainedTransformerTokenizer +from allennlp.models.heads import ClassifierHead +from allennlp.models import MultiTaskModel +from allennlp.modules.backbones import PretrainedTransformerBackbone + + +class TestMultiTaskModel(ModelTestCase): + def test_forward_works(self): + # Setting up the model. + transformer_name = "epwalsh/bert-xsmall-dummy" + vocab = Vocabulary() + backbone = PretrainedTransformerBackbone(vocab, transformer_name) + head1 = ClassifierHead(vocab, input_dim=20, num_labels=3) + head2 = ClassifierHead(vocab, input_dim=20, num_labels=4) + # We'll start with one head, and add another later. 
+ model = MultiTaskModel(vocab, backbone, {"cls": head1}) + + # Setting up the data. + tokenizer = PretrainedTransformerTokenizer(model_name=transformer_name) + token_indexers = PretrainedTransformerIndexer(model_name=transformer_name) + tokens = tokenizer.tokenize("This is a test") + text_field = TextField(tokens, {"tokens": token_indexers}) + label_field1 = LabelField(1, skip_indexing=True) + label_field2 = LabelField(3, skip_indexing=True) + instance = Instance({"text": text_field, "label": label_field1}) + + # Now we run some tests. First, the default. + outputs = model.forward_on_instance(instance) + assert "encoded_text" in outputs + assert "cls_logits" in outputs + assert "loss" in outputs + assert "cls_loss" in outputs + + # When we force the model not to use a head, even when we have all of its inputs. + model.set_active_heads([]) + outputs = model.forward_on_instance(instance) + assert "encoded_text" in outputs + assert "loss" not in outputs + assert "cls_logits" not in outputs + model.set_active_heads(None) + + # When we don't have all of the inputs for a head. + instance = Instance({"text": text_field}) + outputs = model.forward_on_instance(instance) + assert "encoded_text" in outputs + assert "cls_logits" not in outputs + assert "loss" not in outputs + + # When we don't have all of the inputs for a head, but we run it anyway. We should run it + # anyway in two scenarios: (1) when active_heads is set, and (2) when we're in eval mode. + model.set_active_heads(["cls"]) + outputs = model.forward_on_instance(instance) + assert "encoded_text" in outputs + assert "loss" not in outputs # no loss because we have no labels + assert "cls_logits" in outputs # but we can compute logits + model.set_active_heads(None) + + model.eval() + outputs = model.forward_on_instance(instance) + assert "encoded_text" in outputs + assert "loss" not in outputs # no loss because we have no labels + assert "cls_logits" in outputs # but we can compute logits + model.train() + + # Now for two-headed and other more complex tests. + model = MultiTaskModel( + vocab, + backbone, + {"cls1": head1, "cls2": head2}, + arg_name_mapping={ + "cls1": {"label1": "label"}, + "cls2": {"label2": "label"}, + "backbone": {"question": "text"}, + }, + ) + + # Basic case where things should work, with two heads that both need label inputs. + instance = Instance({"text": text_field, "label1": label_field1, "label2": label_field2}) + outputs = model.forward_on_instance(instance) + assert "encoded_text" in outputs + assert "cls1_logits" in outputs + assert "cls1_loss" in outputs + assert "cls2_logits" in outputs + assert "cls2_loss" in outputs + assert "loss" in outputs + combined_loss = outputs["cls1_loss"].item() + outputs["cls2_loss"].item() + assert abs(outputs["loss"].item() - combined_loss) <= 1e-6 + + # This should fail, because we are using the same label field for both heads, but it's the + # wrong label for cls1, and the sizes don't match. This shows up as an IndexError in this + # case. It'd be nice to catch this kind of error more cleanly in the model class, but I'm + # not sure how. + instance = Instance({"text": text_field, "label": label_field2}) + with pytest.raises(IndexError): + outputs = model.forward_on_instance(instance) + + # This one should fail because we now have two things that map to "text" in the backbone, + # and they would clobber each other.
The name mapping that we have in the model is ok, as + # long as our data loader is set up such that we don't batch instances that have both of + # these fields at the same time. + instance = Instance({"question": text_field, "text": text_field}) + with pytest.raises(ValueError, match="duplicate argument text"): + outputs = model.forward_on_instance(instance)
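For a quick sense of how the new pieces fit together, here is a condensed usage sketch distilled from `tests/models/multitask_test.py` above. The tiny `epwalsh/bert-xsmall-dummy` transformer and its hidden size of 20 come from that test; the `sentiment`/`topic` head and field names are illustrative stand-ins (analogous to the test's `cls1`/`cls2` heads and the `arg_name_mapping` example in the `MultiTaskModel` docstring), not identifiers that appear elsewhere in this diff.

```python
from allennlp.data import Instance, Vocabulary
from allennlp.data.fields import LabelField, TextField
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.models import MultiTaskModel
from allennlp.models.heads import ClassifierHead
from allennlp.modules.backbones import PretrainedTransformerBackbone

transformer_name = "epwalsh/bert-xsmall-dummy"  # tiny model used by the test suite
vocab = Vocabulary()

# The backbone encodes the "text" input once; its outputs ("encoded_text" and
# "encoded_text_mask", following the Backbone naming convention) are shared by every head.
backbone = PretrainedTransformerBackbone(vocab, transformer_name)

# Two independent classification heads. input_dim must match the backbone's output size
# (20 for the dummy transformer above).
heads = {
    "sentiment": ClassifierHead(vocab, input_dim=20, num_labels=3),
    "topic": ClassifierHead(vocab, input_dim=20, num_labels=4),
}

# arg_name_mapping renames dataset-specific keys ("sentiment", "topic") to the "label"
# argument that ClassifierHead.forward expects. As written in this diff, a provided mapping
# should also include a "backbone" entry; ours is empty because the field is already "text".
model = MultiTaskModel(
    vocab,
    backbone,
    heads,
    arg_name_mapping={
        "backbone": {},
        "sentiment": {"sentiment": "label"},
        "topic": {"topic": "label"},
    },
)

# Build a single instance the same way the test does.
tokenizer = PretrainedTransformerTokenizer(model_name=transformer_name)
indexer = PretrainedTransformerIndexer(model_name=transformer_name)
text_field = TextField(tokenizer.tokenize("This is a test"), {"tokens": indexer})
instance = Instance(
    {
        "text": text_field,
        "sentiment": LabelField(1, skip_indexing=True),
        "topic": LabelField(3, skip_indexing=True),
    }
)

# Head outputs come back namespaced ("sentiment_logits", "sentiment_loss", "topic_logits",
# ...), and "loss" is the (optionally weighted) sum of the per-head losses.
outputs = model.forward_on_instance(instance)

# Restrict which heads run (e.g., to get logits when no labels are available); passing None
# afterwards returns to the default argument-based inference described in set_active_heads.
model.set_active_heads(["sentiment"])
```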