Add support for BertGeneration #480

Merged
merged 26 commits on Feb 21, 2023
Changes from 19 commits
40 changes: 40 additions & 0 deletions adapter_docs/classes/models/bert-generation.rst
@@ -0,0 +1,40 @@
..
Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

BertGeneration
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The BertGeneration model is a BERT model that can be leveraged for sequence-to-sequence tasks using
EncoderDecoderModel as proposed in `Leveraging Pre-trained Checkpoints for Sequence Generation
Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.

The abstract from the paper is the following:

*Unsupervised pretraining of large neural models has recently revolutionized Natural Language Processing. By
warm-starting from the publicly released checkpoints, NLP practitioners have pushed the state-of-the-art on multiple
benchmarks while saving significant amounts of compute time. So far the focus has been mainly on the Natural Language
Understanding tasks. In this paper, we demonstrate the efficacy of pre-trained checkpoints for Sequence Generation. We
developed a Transformer-based sequence-to-sequence model that is compatible with publicly available pre-trained BERT,
GPT-2 and RoBERTa checkpoints and conducted an extensive empirical study on the utility of initializing our model, both
encoder and decoder, with these checkpoints. Our models result in new state-of-the-art results on Machine Translation,
Text Summarization, Sentence Splitting, and Sentence Fusion.*
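
As a rough illustration of the warm-starting described above (following the upstream Hugging Face documentation; the checkpoint name and special token ids below are illustrative), an encoder-decoder model can be assembled from BERT checkpoints like this:

from transformers import BertGenerationDecoder, BertGenerationEncoder, EncoderDecoderModel

# warm-start encoder and decoder from a public BERT checkpoint
encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102)
decoder = BertGenerationDecoder.from_pretrained(
    "bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)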


BertGenerationAdapterModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.adapters.BertGenerationAdapterModel
:members:
:inherited-members: BertGenerationPreTrainedModel
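
A minimal usage sketch for the new flexible-head class added in this PR (the checkpoint name is illustrative; add_adapter and train_adapter come from the shared adapter mixins):

from transformers.adapters import BertGenerationAdapterModel

model = BertGenerationAdapterModel.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
model.add_adapter("seq_gen")            # add a new task adapter
model.add_causal_lm_head("seq_gen")     # flexible head defined by this PR
model.train_adapter("seq_gen")          # freeze the base model, train only the adapter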
1 change: 1 addition & 0 deletions adapter_docs/index.rst
@@ -60,6 +60,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/bart
classes/models/beit
classes/models/bert
classes/models/bert-generation
classes/models/clip
classes/models/deberta
classes/models/deberta_v2
1 change: 1 addition & 0 deletions adapter_docs/model_overview.md
@@ -14,6 +14,7 @@ The table below further shows which model architectures support which adaptation
| --------------------------------------- | - | - | - | - | - | - | - |
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | |
| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
3 changes: 3 additions & 0 deletions src/transformers/__init__.py
@@ -2248,6 +2248,8 @@
"BartModelWithHeads",
"BeitAdapterModel",
"BertAdapterModel",
"BertGenerationAdapterModel",
"BertModelWithHeads",
"BertModelWithHeads",
"CompacterConfig",
"CompacterPlusPlusConfig",
@@ -5038,6 +5040,7 @@
BartModelWithHeads,
BeitAdapterModel,
BertAdapterModel,
BertGenerationAdapterModel,
BertModelWithHeads,
CompacterConfig,
CompacterPlusPlusConfig,
2 changes: 2 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -101,6 +101,7 @@
"BertAdapterModel",
"BertModelWithHeads",
],
"models.bert_generation": ["BertGenerationAdapterModel"],
"models.deberta": ["DebertaAdapterModel"],
"models.debertaV2": ["DebertaV2AdapterModel"],
"models.distilbert": [
@@ -209,6 +210,7 @@
from .models.bart import BartAdapterModel, BartModelWithHeads
from .models.beit import BeitAdapterModel
from .models.bert import BertAdapterModel, BertModelWithHeads
from .models.bert_generation import BertGenerationAdapterModel
from .models.deberta import DebertaAdapterModel
from .models.debertaV2 import DebertaV2AdapterModel
from .models.distilbert import DistilBertAdapterModel, DistilBertModelWithHeads
1 change: 1 addition & 0 deletions src/transformers/adapters/composition.py
@@ -112,6 +112,7 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], b
"t5",
"vit",
"xlm-roberta",
"bert-generation",
],
}

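Registering bert-generation here enables parallel adapter composition for the architecture; a hedged sketch of what that allows (adapter names are illustrative):

from transformers.adapters.composition import Parallel

model.add_adapter("task_a")
model.add_adapter("task_b")
# both adapters are executed side by side on a replicated batch
model.set_active_adapters(Parallel("task_a", "task_b"))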
12 changes: 12 additions & 0 deletions src/transformers/adapters/head_utils.py
@@ -84,6 +84,18 @@
"cls.predictions.decoder",
],
},
# BertGeneration
"BertGenerationDecoder": {
"config": {
"head_type": "causal_lm",
"layers": 1,
"activation_function": None,
"bias": True,
},
"layers": [
"lm_head.decoder",
],
},
# RoBERTa
"RobertaForSequenceClassification": {
"config": {
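The mapping above is what allows checkpoints saved with the static BertGenerationDecoder head to be converted into a flexible causal_lm head when loaded into an adapter model; roughly (the path below is a placeholder):

from transformers.adapters import AutoAdapterModel

# the static lm_head.decoder weights are re-mapped onto a flexible "causal_lm" head on load
model = AutoAdapterModel.from_pretrained("path/to/bert-generation-decoder-checkpoint")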
9 changes: 6 additions & 3 deletions src/transformers/adapters/model_mixin.py
@@ -279,6 +279,7 @@ def set_active_embeddings(self, name):
"""
self.loaded_embeddings[self.active_embeddings] = self.get_input_embeddings()
self.set_input_embeddings(self.loaded_embeddings[name])
self.config.vocab_size = self.loaded_embeddings[name].num_embeddings
self._active_embedding = name

@property
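
With this change, switching the active embedding also keeps config.vocab_size consistent with the newly activated embedding matrix; a hedged sketch (the embedding name and tokenizer are illustrative):

model.add_embeddings("custom", tokenizer)   # register an additional embedding for a new tokenizer
model.set_active_embeddings("custom")
assert model.config.vocab_size == model.get_input_embeddings().num_embeddings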
@@ -487,9 +488,11 @@ def _add_adapter_weights(self, adapter_name: str):
self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
# PHM Layer
if self.config.adapters.match(adapter_name, AdapterConfig, location_key="phm_layer"):
self.base_model.shared_parameters[adapter_name] = (
list(self.get_adapter(adapter_name)[0].values())[0].adapter_down[0].init_shared_parameters()
)
adapter_module = list(self.get_adapter(adapter_name)[0].values())[0]
# if multiple adapters with same location key exist they are returned as a modulelist
if isinstance(adapter_module, nn.ModuleList):
adapter_module = adapter_module[0]
self.base_model.shared_parameters[adapter_name] = adapter_module.adapter_down[0].init_shared_parameters()
# Prefix Tuning
for module in self.modules():
if isinstance(module, PrefixTuningPool):
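The ModuleList handling above matters for adapter configurations that use a PHM layer, such as Compacter; a minimal sketch of the code path that triggers it (the adapter name is illustrative):

from transformers.adapters import CompacterConfig

# Compacter uses a PHM layer, so adding it goes through init_shared_parameters() above
model.add_adapter("compacter_task", config=CompacterConfig())
model.train_adapter("compacter_task")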
1 change: 1 addition & 0 deletions src/transformers/adapters/models/auto/adapter_model.py
@@ -12,6 +12,7 @@
("roberta", "RobertaAdapterModel"),
("beit", "BeitAdapterModel"),
("bert", "BertAdapterModel"),
("bert-generation", "BertGenerationAdapterModel"),
("distilbert", "DistilBertAdapterModel"),
("deberta-v2", "DebertaV2AdapterModel"),
("deberta", "DebertaAdapterModel"),
39 changes: 39 additions & 0 deletions src/transformers/adapters/models/bert_generation/__init__.py
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ....utils import _LazyModule


_import_structure = {
"adapter_model": ["BertGenerationAdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import BertGenerationAdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
161 changes: 161 additions & 0 deletions src/transformers/adapters/models/bert_generation/adapter_model.py
@@ -0,0 +1,161 @@
from ....models.bert_generation.modeling_bert_generation import (
BERT_GENERATION_INPUTS_DOCSTRING,
BERT_GENERATION_START_DOCSTRING,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
)
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...context import AdapterSetup
from ...heads import (
BertStyleMaskedLMHead,
BiaffineParsingHead,
CausalLMHead,
ClassificationHead,
ModelWithFlexibleHeadsAdaptersMixin,
MultiLabelClassificationHead,
MultipleChoiceHead,
QuestionAnsweringHead,
TaggingHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"""Bert Model transformer with the option to add multiple flexible heads on top.""",
BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationAdapterModel(
EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertGenerationPreTrainedModel
):
_keys_to_ignore_on_load_unexpected = [r"lm_head.bias"]

def __init__(self, config):
super().__init__(config)

self.bert = BertGenerationEncoder(config)

self._init_head_modules()

self.init_weights()

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
else:
head_inputs = outputs

if head or AdapterSetup.get_context_head_setup() or self.active_head:
head_outputs = self.forward_head(
head_inputs,
head_name=head,
attention_mask=attention_mask,
return_dict=return_dict,
**kwargs,
)
return head_outputs
else:
# in case no head is used just return the output of the base model (BertGenerationEncoder has no pooler output)
return outputs

# Copied from BertLMHeadModel
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
"tagging": TaggingHead,
"multiple_choice": MultipleChoiceHead,
"question_answering": QuestionAnsweringHead,
"dependency_parsing": BiaffineParsingHead,
"masked_lm": BertStyleMaskedLMHead,
"causal_lm": CausalLMHead,
}

def add_masked_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False):
"""
Adds a masked language modeling head on top of the model.

Args:
head_name (str): The name of the head.
activation_function (str, optional): Activation function. Defaults to 'gelu'.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = BertStyleMaskedLMHead(self, head_name, activation_function=activation_function)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)

def add_causal_lm_head(self, head_name, activation_function=None, overwrite_ok=False):
"""
Adds a causal language modeling head on top of the model.

Args:
head_name (str): The name of the head.
activation_function (str, optional): Activation function. Defaults to None.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = CausalLMHead(
self, head_name, layers=1, activation_function=activation_function, layer_norm=True, bias=True
)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)
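
Putting the pieces together, a hedged end-to-end sketch of generating with the new model class and a causal LM head (checkpoint, head name, and prompt are illustrative):

from transformers import AutoTokenizer
from transformers.adapters import BertGenerationAdapterModel

tokenizer = AutoTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
model = BertGenerationAdapterModel.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
model.add_causal_lm_head("lm")

inputs = tokenizer("Adapters make sequence generation", return_tensors="pt")
# prepare_inputs_for_generation() above makes model.generate() usable with the causal LM head
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))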