Add adapter support for Hubert #551

Open · wants to merge 4 commits into base: legacy
5 changes: 5 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -114,6 +114,10 @@
"GPT2ModelWithHeads",
],
"models.gptj": ["GPTJAdapterModel"],
"models.hubert": [
"HubertAdapterModel",
"HubertModelWithHeads",
],
"models.mbart": [
"MBartAdapterModel",
"MBartModelWithHeads",
@@ -218,6 +222,7 @@
from .models.distilbert import DistilBertAdapterModel, DistilBertModelWithHeads
from .models.gpt2 import GPT2AdapterModel, GPT2ModelWithHeads
from .models.gptj import GPTJAdapterModel
from .models.hubert import HubertAdapterModel, HubertModelWithHeads
from .models.mbart import MBartAdapterModel, MBartModelWithHeads
from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads
from .models.t5 import T5AdapterModel, T5ModelWithHeads
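With these entries registered in `_import_structure` and mirrored in the eager import branch, the new classes are exposed from the top-level adapters namespace. A minimal check, assuming a build of this branch is installed:

```python
# Minimal import check, assuming this PR's branch of adapter-transformers is
# installed; HubertModelWithHeads is the deprecated v2-style alias.
from transformers.adapters import HubertAdapterModel, HubertModelWithHeads

print(HubertAdapterModel.__name__, HubertModelWithHeads.__name__)
```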
38 changes: 38 additions & 0 deletions src/transformers/adapters/mixins/hubert.py
@@ -0,0 +1,38 @@
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin


class HubertEncoderLayerAdaptersMixin:
"""Adds adapters to the Encoder Layer module of Hubert."""

def _init_adapter_modules(self):
self.attention_adapters = AdapterLayer("mh_adapter", self.config)
self.output_adapters = AdapterLayer("output_adapter", self.config)
self.attention_adapters._init_adapter_modules()
self.output_adapters._init_adapter_modules()


class HubertEncoderLayerStableLayerNormAdaptersMixin:
"""Adds adapters to the Encoder Layer Stable Layer Norm module of Hubert."""

def _init_adapter_modules(self):
self.attention_adapters = AdapterLayer("mh_adapter", self.config)
self.output_adapters = AdapterLayer("output_adapter", self.config)
self.attention_adapters._init_adapter_modules()
self.output_adapters._init_adapter_modules()


class HubertModelAdaptersMixin(ModelAdaptersMixin):
"""Adds adapters to the Hubert module."""

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        # HubertModel keeps its transformer blocks under encoder.layers.
        for i, layer in enumerate(self.encoder.layers):
yield i, layer


class HubertModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
pass
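The model-level mixin only needs to enumerate the transformer layers in order; injecting the `AdapterLayer` modules and calling them in each layer's forward is handled by the shared adapter machinery and the per-layer mixins above. A self-contained sketch of the `iter_layers` contract, using dummy modules as stand-ins for the real Hubert classes (the `encoder.layers` layout mirrors `HubertModel`):

```python
# Sketch of the iter_layers() contract; _DummyEncoder/_DummyHubert are
# illustrative stand-ins, not the real transformers classes.
from typing import Iterable, Tuple

import torch.nn as nn


class _DummyEncoder(nn.Module):
    def __init__(self, num_layers: int = 2, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))


class _DummyHubert(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = _DummyEncoder()

    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        # Yield (index, layer) pairs in forward order, as the mixin does.
        for i, layer in enumerate(self.encoder.layers):
            yield i, layer


for idx, layer in _DummyHubert().iter_layers():
    print(idx, layer)
```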
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/auto/adapter_model.py
@@ -21,6 +21,7 @@
("mbart", "MBartAdapterModel"),
("gpt2", "GPT2AdapterModel"),
("gptj", "GPTJAdapterModel"),
("hubert", "HubertAdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
]
@@ -34,6 +35,7 @@
("bart", "BartModelWithHeads"),
("mbart", "MBartModelWithHeads"),
("gpt2", "GPT2ModelWithHeads"),
("hubert", "HubertModelWithHeads"),
("t5", "T5ModelWithHeads"),
]
)
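With the two mapping entries in place, the auto classes should dispatch Hubert checkpoints to the new adapter model. A hedged example; the checkpoint name is only an illustration and is not part of this PR:

```python
# AutoAdapterModel resolves the "hubert" model type via the mapping above.
from transformers.adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("facebook/hubert-base-ls960")
print(type(model).__name__)  # expected: HubertAdapterModel
```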
42 changes: 42 additions & 0 deletions src/transformers/adapters/models/hubert/__init__.py
@@ -0,0 +1,42 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module while preserving the other warnings, so don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ....utils import _LazyModule


_import_structure = {
"adapter_model": [
"HubertAdapterModel",
"HubertModelWithHeads",
],
}


if TYPE_CHECKING:
from .adapter_model import HubertAdapterModel, HubertModelWithHeads

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
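The `_LazyModule` indirection keeps importing the package cheap; `adapter_model` is only imported when one of its names is first accessed. A small sketch of that behaviour, assuming the package from this branch is installed:

```python
# First attribute access on the lazy package triggers the real import of
# adapter_model, which then shows up in sys.modules.
import importlib
import sys

pkg = importlib.import_module("transformers.adapters.models.hubert")
_ = pkg.HubertAdapterModel  # resolves the lazy import
print("transformers.adapters.models.hubert.adapter_model" in sys.modules)  # True
```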
208 changes: 208 additions & 0 deletions src/transformers/adapters/models/hubert/adapter_model.py
@@ -0,0 +1,208 @@
import warnings

import torch.nn as nn

from ....models.hubert.modeling_hubert import (
HUBERT_INPUTS_DOCSTRING,
HUBERT_START_DOCSTRING,
HubertModel,
HubertPreTrainedModel,
)
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...heads import (
ClassificationHead,
ModelWithFlexibleHeadsAdaptersMixin,
MultiLabelClassificationHead,
MultipleChoiceHead,
)


@add_start_docstrings(
"""Hubert Model with the option to add multiple flexible heads on top.""",
HUBERT_START_DOCSTRING,
)
class HubertAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, HubertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.hubert = HubertModel(config)

self._init_head_modules()

self.init_weights()

def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings
"""
return self.hubert.get_position_embeddings()

def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
config.max_position_embeddings`.

Arguments:
new_num_position_embeddings (:obj:`int`):
                The new number of position embeddings. If position embeddings are learned, increasing the size
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
the size will remove vectors from the end.
"""
self.hubert.resize_position_embeddings(new_num_position_embeddings)

    @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_values=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        head=None,
        output_adapter_gating_scores=False,
        output_adapter_fusion_attentions=False,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Hubert operates on raw waveforms (input_values); it does not accept
        # input_ids, inputs_embeds, or head_mask arguments.
        hubert_output = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            output_adapter_gating_scores=output_adapter_gating_scores,
            output_adapter_fusion_attentions=output_adapter_fusion_attentions,
            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
        )

outputs = self.forward_head(
hubert_output, head_name=head, attention_mask=attention_mask, return_dict=return_dict, **kwargs
)

return outputs

# Copied from RobertaForCausalLM
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
}

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
"multiple_choice": MultipleChoiceHead,
}

def add_classification_head(
self,
head_name,
num_labels=2,
layers=2,
activation_function="tanh",
overwrite_ok=False,
multilabel=False,
id2label=None,
use_pooler=False,
):
"""
Adds a sequence classification head on top of the model.

Args:
head_name (str): The name of the head.
num_labels (int, optional): Number of classification labels. Defaults to 2.
layers (int, optional): Number of layers. Defaults to 2.
activation_function (str, optional): Activation function. Defaults to 'tanh'.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
multilabel (bool, optional): Enable multilabel classification setup. Defaults to False.
"""

if multilabel:
head = MultiLabelClassificationHead(
self, head_name, num_labels, layers, activation_function, id2label, use_pooler
)
else:
head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label, use_pooler)
self.add_prediction_head(head, overwrite_ok)

def add_multiple_choice_head(
self,
head_name,
num_choices=2,
layers=2,
activation_function="tanh",
overwrite_ok=False,
id2label=None,
use_pooler=False,
):
"""
Adds a multiple choice head on top of the model.

Args:
head_name (str): The name of the head.
num_choices (int, optional): Number of choices. Defaults to 2.
layers (int, optional): Number of layers. Defaults to 2.
activation_function (str, optional): Activation function. Defaults to 'tanh'.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = MultipleChoiceHead(self, head_name, num_choices, layers, activation_function, id2label, use_pooler)
self.add_prediction_head(head, overwrite_ok)


class HubertModelWithHeads(HubertAdapterModel):
def __init__(self, *args, **kwargs):
warnings.warn(
"This class has been renamed to `{}` in v3. "
"Please use the new class instead as this class might be removed in a future version.".format(
self.__class__.__bases__[0].__name__
),
FutureWarning,
)
super().__init__(*args, **kwargs)

@classmethod
def from_config(cls, config):
warnings.warn(
"This class has been renamed to `{}` in v3. "
"Please use the new class instead as this class might be removed in a future version.".format(
cls.__bases__[0].__name__
),
FutureWarning,
)
return super().from_config(config)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
"This class has been renamed to `{}` in v3. "
"Please use the new class instead as this class might be removed in a future version.".format(
cls.__bases__[0].__name__
),
FutureWarning,
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
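A hedged end-to-end sketch of how the flexible-head API added here is meant to be used. The checkpoint, adapter name, and label count are illustrative only, and it assumes the rest of this PR (the mixins wired into `modeling_hubert.py`) is installed; Hubert consumes raw 16 kHz waveforms rather than token ids:

```python
# Usage sketch for HubertAdapterModel; checkpoint, adapter name, and label
# count are examples, not part of this PR.
import torch

from transformers.adapters import HubertAdapterModel

model = HubertAdapterModel.from_pretrained("facebook/hubert-base-ls960")

# Add a task adapter plus a matching classification head, then train only them.
model.add_adapter("keyword_spotting")
model.add_classification_head("keyword_spotting", num_labels=4)
model.train_adapter("keyword_spotting")

# Two one-second dummy clips of 16 kHz audio stand in for real input values.
dummy_audio = torch.randn(2, 16000)
outputs = model(dummy_audio, head="keyword_spotting")
print(outputs.logits.shape)  # expected: torch.Size([2, 4])
```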
1 change: 1 addition & 0 deletions src/transformers/adapters/wrappers/configuration.py
@@ -39,6 +39,7 @@
"hidden_dropout_prob": "resid_pdrop",
"attention_probs_dropout_prob": "attn_pdrop",
},
"hubert": {},
"mbart": {
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model",
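The empty attribute map records that `HubertConfig` already uses the canonical attribute names (`hidden_size`, `num_attention_heads`, `num_hidden_layers`, ...), so the configuration wrapper needs no translation for Hubert, unlike e.g. GPT-2 or mBART above. A small illustration:

```python
# HubertConfig exposes the canonical names directly, so "hubert": {} suffices.
from transformers import HubertConfig

config = HubertConfig()
print(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
```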