From cbc7842c3b5b2a1e9f7ccba7221e1daf43bc0c7d Mon Sep 17 00:00:00 2001
From: Martin Krasser
Date: Mon, 27 Feb 2023 06:31:41 +0100
Subject: [PATCH] Optional bias term for `TiedTextOutputAdapter`

- closes #40

---
 perceiver/model/text/clm.py    |  4 +++-
 perceiver/model/text/common.py | 12 +++++++++---
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/perceiver/model/text/clm.py b/perceiver/model/text/clm.py
index 7cdc792..556a46f 100644
--- a/perceiver/model/text/clm.py
+++ b/perceiver/model/text/clm.py
@@ -22,6 +22,7 @@ class CausalLanguageModelConfig(PerceiverARConfig):
     max_latents: int = 512
     num_channels: int = 512
     output_norm: bool = False
+    output_bias: bool = True
     init_scale: float = 0.02
 
     @classmethod
@@ -61,7 +62,7 @@ def __init__(self, config: CausalLanguageModelConfig):
         if config.output_norm:
             self.out_norm = nn.LayerNorm(config.num_channels)
 
-        self.output_adapter = common.TiedTextOutputAdapter(vocab_size=config.vocab_size)
+        self.output_adapter = common.TiedTextOutputAdapter(vocab_size=config.vocab_size, emb_bias=config.output_bias)
         self._init_parameters(config.init_scale)
 
     def _init_parameters(self, init_scale: float):
@@ -171,6 +172,7 @@ def __init__(
         cross_attention_dropout: float = 0.5,
         post_attention_dropout: float = 0.0,
         output_norm: bool = False,
+        output_bias: bool = True,
         init_scale: float = 0.02,
         activation_checkpointing=False,
         activation_offloading=False,
diff --git a/perceiver/model/text/common.py b/perceiver/model/text/common.py
index b649e33..22f0319 100644
--- a/perceiver/model/text/common.py
+++ b/perceiver/model/text/common.py
@@ -46,12 +46,18 @@ def forward(self, x, abs_pos=None):
 
 
 class TiedTextOutputAdapter(OutputAdapter):
-    def __init__(self, vocab_size: int):
+    def __init__(self, vocab_size: int, emb_bias: bool = True):
         super().__init__()
-        self.bias = nn.Parameter(torch.zeros(vocab_size))
+        self._emb_bias = emb_bias
+        if emb_bias:
+            self.bias = nn.Parameter(torch.zeros(vocab_size))
 
     def forward(self, x, txt_embedding: nn.Embedding):
-        return torch.matmul(x, txt_embedding.weight.T) + self.bias
+        result = torch.matmul(x, txt_embedding.weight.T)
+        if self._emb_bias:
+            return result + self.bias
+        else:
+            return result
 
 
 class TextEncoder(PerceiverEncoder):
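
For reference, a minimal usage sketch of the new `emb_bias` flag. The class name, constructor signature, and import path are taken from the patch above; the vocab size, tensor shapes, and the assumption that `OutputAdapter` is an ordinary `torch.nn.Module` subclass are illustrative, not part of the patch:

    import torch
    import torch.nn as nn

    from perceiver.model.text import common

    vocab_size, num_channels = 262, 512  # illustrative values only

    # With emb_bias=False, __init__ registers no `bias` parameter at all
    adapter = common.TiedTextOutputAdapter(vocab_size=vocab_size, emb_bias=False)

    # Embedding whose weight matrix is reused (tied) as the output projection
    txt_embedding = nn.Embedding(vocab_size, num_channels)

    x = torch.randn(2, 64, num_channels)               # (batch, seq, channels)
    logits = adapter(x, txt_embedding=txt_embedding)   # tied projection, no bias added
    print(logits.shape)                                # torch.Size([2, 64, 262])

Because a bias-free adapter has no `bias` parameter in its state dict, a checkpoint saved with the default `emb_bias=True` will not load into it under strict state-dict loading; the `output_bias: bool = True` default keeps existing configs and checkpoints working unchanged.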