Skip to content

Commit

Permalink
Optional bias Term for TiedTextOutputAdapter
Browse files Browse the repository at this point in the history
- closes #40
  • Loading branch information
krasserm committed Feb 27, 2023
1 parent 409ea33 commit cbc7842
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 3 additions & 1 deletion perceiver/model/text/clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class CausalLanguageModelConfig(PerceiverARConfig):
max_latents: int = 512
num_channels: int = 512
output_norm: bool = False
output_bias: bool = True
init_scale: float = 0.02

@classmethod
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, config: CausalLanguageModelConfig):
if config.output_norm:
self.out_norm = nn.LayerNorm(config.num_channels)

self.output_adapter = common.TiedTextOutputAdapter(vocab_size=config.vocab_size)
self.output_adapter = common.TiedTextOutputAdapter(vocab_size=config.vocab_size, emb_bias=config.output_bias)
self._init_parameters(config.init_scale)

def _init_parameters(self, init_scale: float):
Expand Down Expand Up @@ -171,6 +172,7 @@ def __init__(
cross_attention_dropout: float = 0.5,
post_attention_dropout: float = 0.0,
output_norm: bool = False,
output_bias: bool = True,
init_scale: float = 0.02,
activation_checkpointing=False,
activation_offloading=False,
Expand Down
12 changes: 9 additions & 3 deletions perceiver/model/text/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,18 @@ def forward(self, x, abs_pos=None):


class TiedTextOutputAdapter(OutputAdapter):
    """Output adapter that projects onto the transposed weights of a text embedding.

    The projection matrix is not owned by this module: it is the weight of the
    ``nn.Embedding`` passed to :meth:`forward`. The only parameter this adapter
    may own is an optional additive bias of size ``vocab_size``.

    Args:
        vocab_size: Length of the bias vector (one entry per vocabulary item).
        emb_bias: If ``True`` (default), a zero-initialized learnable bias is
            added to the projection. If ``False``, no bias parameter is created
            and ``self.bias`` is ``None``.
    """

    def __init__(self, vocab_size: int, emb_bias: bool = True):
        super().__init__()
        # Kept for backward compatibility with code that reads this flag.
        self._emb_bias = emb_bias
        if emb_bias:
            self.bias = nn.Parameter(torch.zeros(vocab_size))
        else:
            # Register the name as an absent parameter (same convention as
            # ``nn.Linear(bias=False)``) so ``self.bias`` always exists and
            # can be checked against ``None`` instead of raising AttributeError.
            self.register_parameter("bias", None)

    def forward(self, x, txt_embedding: nn.Embedding):
        """Multiply ``x`` by ``txt_embedding.weight.T`` and add the bias, if any.

        Args:
            x: Input tensor; its last dimension presumably equals the embedding
                dimension of ``txt_embedding`` -- TODO confirm against callers.
            txt_embedding: Embedding module whose weight matrix is reused
                (transposed) as the projection.

        Returns:
            ``torch.matmul(x, txt_embedding.weight.T)``, plus ``self.bias`` when
            the adapter was constructed with ``emb_bias=True``.
        """
        projected = torch.matmul(x, txt_embedding.weight.T)
        if self.bias is None:
            return projected
        return projected + self.bias


class TextEncoder(PerceiverEncoder):
Expand Down

0 comments on commit cbc7842

Please sign in to comment.