Add UniPELT implementation (#407)
calpt authored Sep 9, 2022
1 parent a38e1f1 commit 8440ab4
Showing 49 changed files with 480 additions and 59 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -74,10 +74,11 @@ Currently, adapter-transformers integrates all architectures and methods listed
| AdapterDrop | [Rücklé et al. (2021)](https://arxiv.org/pdf/2010.11918.pdf) | [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/05_Adapter_Drop_Training.ipynb) |
| MAD-X 2.0,<br> Embedding training | [Pfeiffer et al. (2021)](https://arxiv.org/pdf/2012.15562.pdf) | [Docs: Embeddings](https://docs.adapterhub.ml/embeddings.html), [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/08_NER_Wikiann.ipynb) |
| Prefix Tuning | [Li and Liang (2021)](https://arxiv.org/pdf/2101.00190.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#prefix-tuning) |
| Parallel adapters,<br> Mix-and-Match adapters | [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#combinations-mix-and-match-adapters) |
| Parallel adapters,<br> Mix-and-Match adapters | [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#mix-and-match-adapters) |
| Compacter | [Mahabadi et al. (2021)](https://arxiv.org/pdf/2106.04647.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#compacter) |
| LoRA | [Hu et al. (2021)](https://arxiv.org/pdf/2106.09685.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#lora) |
| (IA)^3 | [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#ia3) |
| (IA)^3 | [Liu et al. (2022)](https://arxiv.org/pdf/2205.05638.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#ia-3) |
| UniPELT | [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) | [Docs](https://docs.adapterhub.ml/overview.html#unipelt) |

## Supported Models

3 changes: 3 additions & 0 deletions adapter_docs/classes/adapter_config.rst
@@ -65,6 +65,9 @@ Combined configurations
.. autoclass:: transformers.MAMConfig
:members:

.. autoclass:: transformers.UniPELTConfig
:members:

Adapter Fusion
~~~~~~~~~~~~~~~

73 changes: 72 additions & 1 deletion adapter_docs/overview.md
@@ -271,7 +271,7 @@ model.reset_adapter("ia3_adapter")
_Papers:_
- [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022)

## Combinations - Mix-and-Match Adapters
## Method Combinations

_Configuration class_: [`ConfigUnion`](transformers.ConfigUnion)

@@ -290,6 +290,10 @@ config = ConfigUnion(
model.add_adapter("union_adapter", config=config)
```

### Mix-and-Match Adapters

_Configuration class_: [`MAMConfig`](transformers.MAMConfig)

[He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) study various variants and combinations of efficient fine-tuning methods.
Among others, they propose _Mix-and-Match Adapters_ as a combination of Prefix Tuning and parallel bottleneck adapters.
This configuration is supported by adapter-transformers out-of-the-box:
@@ -315,3 +319,70 @@ model.add_adapter("mam_adapter", config=config)

_Papers:_
- [Towards a Unified View of Parameter-Efficient Transfer Learning](https://arxiv.org/pdf/2110.04366.pdf) (He et al., 2021)

### UniPELT

_Configuration class_: [`UniPELTConfig`](transformers.UniPELTConfig)

An approach similar to the work of [He et al. (2021)](https://arxiv.org/pdf/2110.04366.pdf) is taken by [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) in their _UniPELT_ framework.
They, too, combine multiple efficient fine-tuning methods, namely LoRA, Prefix Tuning and bottleneck adapters, in a single unified setup.
_UniPELT_ additionally introduces a gating mechanism that controls the activation of the different submodules.

Concretely, for each adapted module $m$, UniPELT adds a trainable gating value $\mathcal{G}_m \in (0, 1)$ that is computed via a feed-forward network ($W_{\mathcal{G}_m}$) and sigmoid activation ($\sigma$) from the Transformer layer input states ($x$):

$$\mathcal{G}_m \leftarrow \sigma(W_{\mathcal{G}_m} \cdot x)$$

These gating values are then used to scale the output activations of the injected adapter modules, e.g. for a LoRA layer:

$$
h \leftarrow W_0 x + \mathcal{G}_{LoRA} B A x
$$
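
To make the formulas above concrete, here is a minimal PyTorch sketch of a gated LoRA layer. It is illustrative only: the class and attribute names are assumptions made for this example and do not correspond to the internal modules of `adapter-transformers`.

```python
import torch
import torch.nn as nn


class GatedLoRALayer(nn.Module):
    """Illustrative sketch of a LoRA module with a UniPELT-style gate."""

    def __init__(self, d_model: int, r: int = 8):
        super().__init__()
        self.lora_A = nn.Linear(d_model, r, bias=False)  # A: d_model -> r
        self.lora_B = nn.Linear(r, d_model, bias=False)  # B: r -> d_model
        self.gate = nn.Linear(d_model, 1)  # W_G, one gating logit per token

    def forward(self, x: torch.Tensor, base_output: torch.Tensor) -> torch.Tensor:
        # G_LoRA = sigmoid(W_G * x), averaged over the sequence dimension in this sketch
        gating = torch.sigmoid(self.gate(x)).mean(dim=1, keepdim=True)
        # h = W_0 x + G_LoRA * B A x, where base_output plays the role of W_0 x
        return base_output + gating * self.lora_B(self.lora_A(x))
```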

In the configuration classes of `adapter-transformers`, these gating mechanisms can be activated via `use_gating=True`.
The full UniPELT setup can be instantiated using `UniPELTConfig`[^unipelt]:

[^unipelt]: Note that the implementation of UniPELT in `adapter-transformers` follows the original reference implementation, which is slightly different from the description in the paper. See [here](https://github.com/morningmoni/UniPELT/issues/1) for more details.

```python
from transformers.adapters import UniPELTConfig

config = UniPELTConfig()
model.add_adapter("unipelt", config=config)
```

which is identical to the following `ConfigUnion`:

```python
from transformers.adapters import ConfigUnion, LoRAConfig, PrefixTuningConfig, PfeifferConfig

config = ConfigUnion(
LoRAConfig(r=8, use_gating=True),
PrefixTuningConfig(prefix_length=10, use_gating=True),
PfeifferConfig(reduction_factor=16, use_gating=True),
)
model.add_adapter("unipelt", config=config)
```

Finally, as the gating values of each adapter module can provide interesting insights for analysis, `adapter-transformers` offers an integrated mechanism for returning all gating values computed during a model forward pass via the `output_adapter_gating_scores` parameter:

```python
outputs = model(**inputs, output_adapter_gating_scores=True)
gating_scores = outputs.adapter_gating_scores
```
Note that this parameter is only available to base model classes and [AdapterModel classes](prediction_heads.md#adaptermodel-classes).
In the example, `gating_scores` holds a dictionary of the following form:
```
{
'<adapter_name>': {
<layer_id>: {
'<module_location>': np.array([...]),
...
},
...
},
...
}
```
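
As a quick illustration (a sketch only, assuming `gating_scores` was obtained as in the example above), the nested dictionary can be traversed to aggregate the gate activations per module:

```python
import numpy as np

# Mean gating value per adapter, layer and module location (illustrative only).
for adapter_name, layers in gating_scores.items():
    for layer_id, modules in layers.items():
        for module_location, scores in modules.items():
            print(f"{adapter_name} | layer {layer_id} | {module_location}: mean gate = {np.mean(scores):.3f}")
```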

_Papers:_
- [UNIPELT: A Unified Framework for Parameter-Efficient Language Model Tuning](https://arxiv.org/pdf/2110.07577.pdf) (Mao et al., 2022)
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2062,6 +2062,7 @@
"StaticAdapterFusionConfig",
"T5AdapterModel",
"T5ModelWithHeads",
"UniPELTConfig",
"ViTAdapterModel",
"XLMRobertaAdapterModel",
"XLMRobertaModelWithHeads",
@@ -4600,6 +4601,7 @@
StaticAdapterFusionConfig,
T5AdapterModel,
T5ModelWithHeads,
UniPELTConfig,
ViTAdapterModel,
XLMRobertaAdapterModel,
XLMRobertaModelWithHeads,
2 changes: 2 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -57,6 +57,7 @@
"PfeifferInvConfig",
"PrefixTuningConfig",
"StaticAdapterFusionConfig",
"UniPELTConfig",
],
"context": [
"AdapterSetup",
@@ -175,6 +176,7 @@
PfeifferInvConfig,
PrefixTuningConfig,
StaticAdapterFusionConfig,
UniPELTConfig,
)
from .context import AdapterSetup, ForwardContext
from .heads import (
39 changes: 39 additions & 0 deletions src/transformers/adapters/configuration.py
@@ -162,6 +162,9 @@ class AdapterConfig(AdapterConfigBase):
Scaling factor to use for scaled addition of adapter outputs as done by He et al. (2021). Can be either a
constant factor (float) or the string "learned", in which case the scaling factor is learned. Defaults to
1.0.
use_gating (:obj:`bool`, optional):
Place a trainable gating module beside the added parameter module to control module activation. This is
used e.g. for UniPELT. Defaults to False.
residual_before_ln (:obj:`bool`, optional):
If True, take the residual connection around the adapter bottleneck before the layer normalization. Only
applicable if :obj:`original_ln_before` is True.
@@ -224,6 +227,7 @@ class AdapterConfig(AdapterConfigBase):
init_weights: str = "bert"
is_parallel: bool = False
scaling: Union[float, str] = 1.0
use_gating: bool = False
residual_before_ln: bool = True
adapter_residual_before_ln: bool = False
inv_adapter: Optional[str] = None
@@ -362,6 +366,12 @@ class PrefixTuningConfig(AdapterConfigBase):
non_linearity (str): If flat=False, the non-linearity used in the bottleneck MLP.
dropout (float): The dropout rate used in the prefix tuning layer.
leave_out (List[int]): The IDs of the layers (starting at 0) where NO prefix should be added.
use_gating (:obj:`bool`, optional):
Place a trainable gating module beside the added parameter module to control module activation. This is
used e.g. for UniPELT. Defaults to False.
shared_gating (:obj:`bool`, optional):
Whether to use a shared gate for the prefixes of all attention matrices. Only applicable if
`use_gating=True`. Defaults to True.
"""

architecture: Optional[str] = "prefix_tuning"
@@ -375,6 +385,8 @@ class PrefixTuningConfig(AdapterConfigBase):
bottleneck_size: int = 512
non_linearity: str = "tanh"
dropout: float = 0.0
use_gating: bool = False
shared_gating: bool = True


@dataclass(eq=False)
Expand Down Expand Up @@ -402,6 +414,10 @@ class LoRAConfig(AdapterConfigBase):
(IA)^3). "scale" can only be used together with r=1. Defaults to "add".
init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules.
Currently, this can be either "lora" (default) or "bert".
use_gating (:obj:`bool`, optional):
Place a trainable gating module beside the added parameter module to control module activation. This is
used e.g. for UniPELT. Defaults to False. Note that modules with `use_gating=True` cannot be merged using
`merge_adapter()`.
"""

architecture: Optional[str] = "lora"
@@ -416,6 +432,7 @@ class LoRAConfig(AdapterConfigBase):
attn_matrices: List[str] = field(default_factory=lambda: ["q", "v"])
composition_mode: str = "add"
init_weights: str = "lora"
use_gating: bool = False


@dataclass(eq=False)
Expand All @@ -436,6 +453,7 @@ class IA3Config(LoRAConfig):
attn_matrices: List[str] = field(default_factory=lambda: ["k", "v"])
composition_mode: str = "scale"
init_weights: str = "ia3"
use_gating: bool = False


class ConfigUnion(AdapterConfigBase):
@@ -548,6 +566,26 @@ def adapter(self):
return self[1]


class UniPELTConfig(ConfigUnion):
"""
The UniPELT adapter architecture proposed by Mao et al. (2022). See https://arxiv.org/pdf/2110.07577.pdf.
"""

def __init__(
self,
prefix_tuning: Optional[PrefixTuningConfig] = None,
adapter: Optional[AdapterConfig] = None,
lora: Optional[LoRAConfig] = None,
):
components = [
prefix_tuning or PrefixTuningConfig(prefix_length=10),
adapter or PfeifferConfig(reduction_factor=16),
lora or LoRAConfig(r=8),
]

super().__init__(*[c.replace(use_gating=True) for c in components])


ADAPTER_CONFIG_MAP = {
"pfeiffer": PfeifferConfig(),
"houlsby": HoulsbyConfig(),
@@ -562,6 +600,7 @@ def adapter(self):
"lora": LoRAConfig(),
"ia3": IA3Config(),
"mam": MAMConfig(),
"unipelt": UniPELTConfig(),
}

DEFAULT_ADAPTER_CONFIG = "pfeiffer"
17 changes: 16 additions & 1 deletion src/transformers/adapters/context.py
@@ -78,6 +78,8 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = ["adapter_gating_scores"]

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
if hasattr(model, "forward_context"):
@@ -99,8 +101,21 @@ def wrap(cls, f):
@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
if self.config.adapters is not None:
with cls(self, *args, **kwargs):
with cls(self, *args, **kwargs) as ctx:
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
results = f(self, *args, **kwargs)

# append output attributes
if isinstance(results, tuple):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results = results + (dict(getattr(ctx, attr)),)
else:
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))
return results
else:
return f(self, *args, **kwargs)
16 changes: 11 additions & 5 deletions src/transformers/adapters/heads/base.py
@@ -18,7 +18,7 @@
)
from ...utils import ModelOutput
from ..composition import AdapterCompositionBlock, BatchSplit, Parallel, parse_heads_from_composition
from ..context import AdapterSetup
from ..context import AdapterSetup, ForwardContext
from ..model_mixin import ModelWithHeadsAdaptersMixin
from ..modeling import Activation_Function_Class

@@ -790,7 +790,7 @@ def _get_head_input(outputs, cls_out, batch):
if all("loss" in out and out["loss"] is not None for out in head_outputs)
else None
)
return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
elif self.has_parallel_adapters or isinstance(self.active_head, Parallel):
if len(self.active_head) != self.config.adapters.active_setup.parallel_channels:
raise ValueError("The number of parallel adapters and the number of active heads must match.")
@@ -807,16 +807,22 @@ def _get_head_input(outputs, cls_out, batch):
if all("loss" in out and out["loss"] is not None for out in head_outputs)
else None
)
return MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
return_output = MultiHeadOutput(head_outputs=head_outputs, loss=combined_loss)
elif len(used_heads) > 1:
head_outputs = []
for head in used_heads:
head_module = self.heads[head]
head_outputs.append(head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs))
return head_outputs
return_output = MultiHeadOutput(head_outputs=head_outputs)
else:
head_module = self.heads[used_heads[0]]
return head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)
return_output = head_module(all_outputs, cls_output, attention_mask, return_dict, **kwargs)

if isinstance(return_output, ModelOutput):
for attr in ForwardContext.context_attributes:
if attr not in return_output and attr in all_outputs:
return_output[attr] = all_outputs[attr]
return return_output

def get_labels_dict(self, head_name=None):
"""