Commit
Added convolution and updated README
tatp22 committed Jul 31, 2020
1 parent f1c944d commit 072f34c
Showing 3 changed files with 73 additions and 18 deletions.
30 changes: 28 additions & 2 deletions README.md
@@ -55,6 +55,7 @@ model = LinformerLM(
        w_o_intermediate_dim=None, # If not None, have 2 w_o matrices, such that instead of `dim*nhead,channels`, you have `dim*nhead,w_o_int`, and `w_o_int,channels`
        emb_dim=128, # If you want the embedding dimension to be different than the channels for the Linformer
        causal=False, # If you want this to be a causal Linformer, where the upper right of the P_bar matrix is masked out.
        convolution=False, # If True, use a 1d convolution (with stride and kernel size of n/k) instead of a linear projection
).cuda()
x = torch.randint(1,10000,(1,512)).cuda()
y = model(x)
@@ -265,6 +266,30 @@ vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like
)
```

## Encoder Decoder Module
Similar to the [Reformer](https://github.com/lucidrains/reformer-pytorch#reformer-encoder-decoder-architecture), I am working on an Encoder/Decoder module so that training can be simplified. It works like two `LinformerLM` classes whose hyperparameters can be adjusted individually, with the encoder taking the `enc_` prefix for all of its hyperparams and the decoder taking the `dec_` prefix in the same fashion. So far, this is what is implemented:

```python3
import torch
from linformer_pytorch import LinformerEncDec

encdec = LinformerEncDec(
        enc_num_tokens=10000,
        enc_input_size=512,
        enc_channels=16,
        dec_num_tokens=10000,
        dec_input_size=512,
        dec_channels=16,
)

x = torch.randint(1,10000,(1,512))
y = torch.randint(1,10000,(1,512))

output = encdec(x,y)
```
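
As a rough illustration of the `enc_`/`dec_` prefix convention, the constructor in this commit also exposes prefixed versions of the usual hyperparameters (for example `enc_dim_k`, `enc_depth`, `dec_dim_k`, and `dec_depth`). The values below are only a sketch, not recommended settings:

```python3
import torch
from linformer_pytorch import LinformerEncDec

# Illustrative sketch only: the values here are arbitrary.
encdec = LinformerEncDec(
        enc_num_tokens=10000,
        enc_input_size=512,
        enc_channels=16,
        enc_dim_k=32,          # encoder-specific projection dimension
        enc_depth=2,           # encoder-specific depth
        dec_num_tokens=10000,
        dec_input_size=512,
        dec_channels=16,
        dec_dim_k=32,          # decoder-specific projection dimension
        dec_depth=2,           # decoder-specific depth
        activation="gelu",     # shared by encoder and decoder
)

x = torch.randint(1, 10000, (1, 512))
y = torch.randint(1, 10000, (1, 512))

output = encdec(x, y)
```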

I am planning to add a way to generate text sequences with this module.

## Practical Tips
* Note that the Linformer has O(nk) time and space complexity. So, while it is linear in n, make sure that k is not too large either. These are set with `input_size` and `dim_k`, respectively.
* Speaking of k, the authors found empirically that "the performance of Linformer model is mainly determined by the projected dimension k instead of the ratio n/k". Therefore, even when increasing the sequence length, it may be fine to keep a relatively low, constant k (the authors showed that with k=256, the model still performed almost as well as a vanilla transformer).
@@ -273,8 +298,9 @@ vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like
* In practice, I found that the memory and time requirements are more on the order of O(nkd), with n=`input_size`, k=`dim_k`, and d=`dim_d`. A rough estimate is sketched below.
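
For a concrete sense of scale, here is a minimal back-of-the-envelope sketch (my own illustration, not code from this repo) of the per-head attention-matrix memory in float32, comparing vanilla O(n^2) attention against the O(nk) projected version:

```python3
# Rough, illustrative estimate only: this covers just the P_bar / QK^T term.
# Activations, weights, and the d factor from O(nkd) add to the real footprint.
def attn_matrix_megabytes(n, k=None):
    cols = n if k is None else k        # vanilla: (n, n); Linformer: (n, k)
    return n * cols * 4 / 1e6           # 4 bytes per float32 entry

n = 8192
print(attn_matrix_megabytes(n))         # vanilla attention: ~268 MB per head
print(attn_matrix_megabytes(n, k=256))  # Linformer with k=256: ~8.4 MB per head
```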

## Future work
* Run some benchmark tests to see what the performance is
* Instead of matrix multiplication to bring the dimensions down to k (With EKW and FVW), try to do convolution, as mentioned in the paper, with a stride length and kernel size of n/k.
* Run some benchmark tests to see what the performance is (Doing that now)
* Complete the `LinformerEncDec` class
* ~~Instead of matrix multiplication to bring the dimensions down to k (With EKW and FVW), try to do convolution, as mentioned in the paper, with a stride length and kernel size of n/k.~~

## Disclaimer
This is the first time that I am reproducing a result from a paper, so some things may be wrong. If you see a problem, please open up an issue, and I will attempt to work on it.
24 changes: 24 additions & 0 deletions examples/example_conv.py
@@ -0,0 +1,24 @@
import sys
import torch

sys.path.insert(0, "../")
from linformer_pytorch import Linformer

model = Linformer(
        input_size=510,
        channels=21,
        dim_d=26,
        dim_k=61,
        dim_ff=32,
        nhead=4,
        depth=3,
        activation="relu",
        checkpoint_level="C1",
        parameter_sharing="none",
        k_reduce_by_layer=1,
        include_ff=True,
        convolution=True,
)
x = torch.randn(1, 510, 21)
y = model(x)
print(y) # (1, 510, 21)
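
With these settings, each E and F projection is an `nn.Conv1d` with kernel size and stride `int(510/61) = 8` (see `get_EF` in the diff below), and the output keeps the input's `(1, 510, 21)` shape.
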
37 changes: 21 additions & 16 deletions linformer_pytorch/linformer_pytorch.py
@@ -23,11 +23,15 @@ def gen_causal_mask(input_size, dim_k, full_attention=False):
return (torch.triu(torch.ones(input_size, input_size))==1).transpose(0,1)
return (torch.triu(torch.ones(dim_k, input_size))==1).transpose(0,1)

def get_EF(input_size, dim, bias=True):
def get_EF(input_size, dim, convolution=False, head_dim=None, bias=True):
"""
Returns the E or F matrix, initialized via Xavier initialization.
This is the recommended way to do it, according to the authors of the paper.
Also includes an option for convolution, as proposed by the authors.
"""
if convolution:
conv = nn.Conv1d(head_dim, head_dim, kernel_size=int(input_size/dim), stride=int(input_size/dim))
return conv
lin = nn.Linear(input_size, dim, bias)
torch.nn.init.xavier_normal_(lin.weight)
return lin
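
As a quick standalone sanity check of this convolution branch (illustrative values only, not part of the library), a `Conv1d` whose kernel size and stride are `n/k` reduces a length-`n` sequence to length `k`, playing the same role as the `(n, k)` linear projection:

```python3
import torch
import torch.nn as nn

# Assumed example sizes: n (sequence length), k (projected dimension), head_dim (channels per head).
n, k, head_dim = 512, 64, 32
conv = nn.Conv1d(head_dim, head_dim, kernel_size=n // k, stride=n // k)

keys = torch.randn(1, head_dim, n)   # Conv1d expects (batch, channels, length)
projected = conv(keys)
print(projected.shape)               # torch.Size([1, 32, 64]): length reduced from n=512 to k=64
```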
@@ -149,6 +153,7 @@ def forward(self, Q, K, V, **kwargs):
if not self.full_attention:
K = self.E(K)
Q = torch.matmul(Q, K)

P_bar = Q/torch.sqrt(torch.tensor(self.dim).type(Q.type()))
if self.causal_mask is not None:
@@ -176,7 +181,7 @@ class MHAttention(nn.Module):
This feeds directly into a feed forward head
"""
def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, checkpoint_level,
parameter_sharing, E_proj, F_proj, full_attention, causal_mask, w_o_intermediate_dim=None, decoder_mode=False):
parameter_sharing, E_proj, F_proj, full_attention, causal_mask, w_o_intermediate_dim=None, decoder_mode=False, convolution=False):
super(MHAttention, self).__init__()
self.heads = nn.ModuleList()
self.input_size = input_size
@@ -186,8 +191,8 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation,
self.checkpoint_level = checkpoint_level
self.w_o_intermediate_dim = w_o_intermediate_dim
if parameter_sharing != "layerwise":
E_proj = get_EF(input_size, dim_k)
F_proj = get_EF(input_size, dim_k) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj
E_proj = get_EF(input_size, dim_k, convolution, dim)
F_proj = get_EF(input_size, dim_k, convolution, dim) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj

self.decoder_mode = decoder_mode
self.to_q = nn.ModuleList()
@@ -196,8 +201,8 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation,

for _ in range(nhead):
if parameter_sharing == "none":
E_proj = get_EF(input_size, dim_k)
F_proj = get_EF(input_size, dim_k)
E_proj = get_EF(input_size, dim_k, convolution, dim)
F_proj = get_EF(input_size, dim_k, convolution, dim)
attn = LinearAttentionHead(dim, dropout, E_proj, F_proj, causal_mask, full_attention)
self.heads.append(attn)
self.to_q.append(nn.Linear(channels, dim, bias=False))
@@ -238,7 +243,7 @@ class Linformer(nn.Module):
My attempt at reproducing the Linformer Paper
https://arxiv.org/pdf/2006.04768.pdf
"""
def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False):
def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False, convolution=False):
super(Linformer, self).__init__()
assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
@@ -255,13 +260,13 @@ def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_

head_dim = channels // nhead if dim_d is None else dim_d

E_proj = get_EF(input_size, dim_k)
E_proj = get_EF(input_size, dim_k, convolution, head_dim)
causal_mask = gen_causal_mask(input_size, dim_k, full_attention) if causal else None
# If we want causal but only with the encoder
causal_enc = gen_causal_mask(input_size, dim_k, full_attention) if (causal and not decoder_mode) else None

get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False)
get_attn_context = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True)
get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False, convolution=convolution)
get_attn_context = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True, convolution=convolution)
get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)

for index in range(depth):
@@ -317,7 +322,7 @@ def __init__(self, num_tokens, input_size, channels,
dropout=0.05, activation="gelu", checkpoint_level="C0",
parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False,
include_ff=True, w_o_intermediate_dim=None, emb_dim=None,
return_emb=False, decoder_mode=False, causal=False):
return_emb=False, decoder_mode=False, causal=False, convolution=False):
super(LinformerLM, self).__init__()
emb_dim = channels if emb_dim is None else emb_dim

@@ -330,7 +335,7 @@ def __init__(self, num_tokens, input_size, channels,
nhead=nhead, depth=depth, dropout=dropout,
activation=activation, checkpoint_level=checkpoint_level, parameter_sharing=parameter_sharing,
k_reduce_by_layer=k_reduce_by_layer, full_attention=full_attention, include_ff=include_ff,
w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal)
w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal, convolution=convolution)

if emb_dim != channels:
self.linformer = ProjectInOut(self.linformer, emb_dim, channels)
@@ -354,20 +359,20 @@ class LinformerEncDec(nn.Module):
def __init__(self, enc_num_tokens, enc_input_size, enc_channels, dec_num_tokens, dec_input_size, dec_channels,
enc_dim_k=64, enc_dim_ff=1024, enc_dim_d=None,
enc_dropout_ff=0.1, enc_nhead=4, enc_depth=2, enc_dropout=0.05, enc_parameter_sharing="layerwise", enc_k_reduce_by_layer=0,
enc_full_attention=False, enc_include_ff=True, enc_w_o_intermediate_dim=None, enc_emb_dim=None,
enc_full_attention=False, enc_include_ff=True, enc_w_o_intermediate_dim=None, enc_emb_dim=None, enc_convolution=False,
dec_dim_k=64, dec_dim_ff=1024, dec_dim_d=None, dec_dropout_ff=0.1, dec_nhead=4, dec_depth=2, dec_dropout=0.05,
dec_parameter_sharing="layerwise", dec_k_reduce_by_layer=0, dec_full_attention=False, dec_include_ff=True,
dec_w_o_intermediate_dim=None, dec_emb_dim=None, activation="gelu", checkpoint_level="C0"):
dec_w_o_intermediate_dim=None, dec_emb_dim=None, dec_convolution=False, activation="gelu", checkpoint_level="C0"):

super(LinformerEncDec, self).__init__()
self.encoder = LinformerLM(num_tokens=enc_num_tokens, input_size=enc_input_size, channels=enc_channels, dim_d=enc_dim_d, dim_ff=enc_dim_ff,
dim_k=enc_dim_k, dropout_ff=enc_dropout_ff, nhead=enc_nhead, depth=enc_depth, dropout=enc_dropout,
parameter_sharing=enc_parameter_sharing, k_reduce_by_layer=enc_k_reduce_by_layer,
full_attention=enc_full_attention, include_ff=enc_include_ff, w_o_intermediate_dim=enc_w_o_intermediate_dim,
emb_dim=enc_emb_dim, return_emb=True, activation=activation, checkpoint_level=checkpoint_level)
emb_dim=enc_emb_dim, return_emb=True, activation=activation, checkpoint_level=checkpoint_level, convolution=enc_convolution)
self.decoder = LinformerLM(num_tokens=dec_num_tokens, input_size=dec_input_size, channels=dec_channels, dim_d=dec_dim_d, dim_ff=dec_dim_ff,
dim_k=dec_dim_k, dropout_ff=dec_dropout_ff, nhead=dec_nhead, depth=dec_depth, dropout=dec_dropout,
parameter_sharing=dec_parameter_sharing, k_reduce_by_layer=dec_k_reduce_by_layer,
parameter_sharing=dec_parameter_sharing, k_reduce_by_layer=dec_k_reduce_by_layer, convolution=dec_convolution,
full_attention=dec_full_attention, include_ff=dec_include_ff, w_o_intermediate_dim=dec_w_o_intermediate_dim,
emb_dim=dec_emb_dim, decoder_mode=True, causal=True, activation=activation, checkpoint_level=checkpoint_level)

