From 072f34c07286020ec3ca7edfc551b7638c7424a1 Mon Sep 17 00:00:00 2001
From: Peter Tatkowski <tatp22@gmail.com>
Date: Fri, 31 Jul 2020 11:22:18 +0200
Subject: [PATCH] Added convolution and updated README

---
 README.md                              | 30 +++++++++++++++++++--
 examples/example_conv.py               | 24 +++++++++++++++++
 linformer_pytorch/linformer_pytorch.py | 37 +++++++++++++++-----------
 3 files changed, 73 insertions(+), 18 deletions(-)
 create mode 100644 examples/example_conv.py

diff --git a/README.md b/README.md
index 9a60436..2684846 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@ model = LinformerLM(
         w_o_intermediate_dim=None, # If not None, have 2 w_o matrices, such that instead of `dim*nead,channels`, you have `dim*nhead,w_o_int`, and `w_o_int,channels`
         emb_dim=128, # If you want the embedding dimension to be different than the channels for the Linformer
         causal=False, # If you want this to be a causal Linformer, where the upper right of the P_bar matrix is masked out.
+        convolution=False, # Instead of a linear layer, perform a 1D convolution, with a stride and kernel size of n/k
         ).cuda()
 x = torch.randint(1,10000,(1,512)).cuda()
 y = model(x)
@@ -265,6 +266,30 @@ vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like
                    )
 ```
 
+## Encoder Decoder Module
+Similar to the [Reformer](https://github.com/lucidrains/reformer-pytorch#reformer-encoder-decoder-architecture), I am attempting to make an Encoder/Decoder module so that training can be simplified. This works like two `LinformerLM` classes, whose parameters can be adjusted individually: the encoder's hyperparameters all take the `enc_` prefix, and the decoder's take the `dec_` prefix. So far, what is implemented is:
+
+```python3
+import torch
+from linformer_pytorch import LinformerEncDec
+
+encdec = LinformerEncDec(
+    enc_num_tokens=10000,
+    enc_input_size=512,
+    enc_channels=16,
+    dec_num_tokens=10000,
+    dec_input_size=512,
+    dec_channels=16,
+)
+
+x = torch.randint(1,10000,(1,512))
+y = torch.randint(1,10000,(1,512))
+
+output = encdec(x,y)
+```
+
+I am planning to add a way to generate text sequences with this module.
+
 ## Practical Tips
 * Note that the Linformer has O(nk) time and space complexity. So, while it may be linear in n, make sure that your k is not too large as well. These are editable with `input_size` and `dim_k`, respectively.
 * Speaking about k, the authors found that empirical evidence supports the fact that "the performance of Linformer model is mainly determined by the projected dimension k instead of the ratio n/k". Therefore, even when increasing sequence lengths, it may be fine to keep a relatively low, constant k (the authors showed with k=256, that it still performed almost as good as a vanilla transformer).
@@ -273,8 +298,9 @@ vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like
 * In practice, I found that the memory and time requirements are more on the order of O(nkd), with n=`input_size`, k=`dim_k`, and d=`dim_d`.
 
 ## Future work
-* Run some benchmark tests to see what the performance is
-* Instead of matrix multiplication to bring the dimensions down to k (With EKW and FVW), try to do convolution, as mentioned in the paper, with a stride length and kernel size of n/k.
+* Run some benchmark tests to see what the performance is (in progress)
+* Complete the `LinformerEncDec` class
+* ~~Instead of matrix multiplication to bring the dimensions down to k (With EKW and FVW), try to do convolution, as mentioned in the paper, with a stride length and kernel size of n/k.~~
 
 ## Disclaimer
 This is the first time that I am reproducing a result from a paper, so some things may be wrong. If you see a problem, please open up an issue, and I will attempt to work on it.

diff --git a/examples/example_conv.py b/examples/example_conv.py
new file mode 100644
index 0000000..07ded42
--- /dev/null
+++ b/examples/example_conv.py
@@ -0,0 +1,24 @@
+import sys
+import torch
+
+sys.path.insert(0, "../")
+from linformer_pytorch import Linformer
+
+model = Linformer(
+        input_size=510,
+        channels=21,
+        dim_d=26,
+        dim_k=61,
+        dim_ff=32,
+        nhead=4,
+        depth=3,
+        activation="relu",
+        checkpoint_level="C1",
+        parameter_sharing="none",
+        k_reduce_by_layer=1,
+        include_ff=True,
+        convolution=True,
+        )
+x = torch.randn(1, 510, 21)
+y = model(x)
+print(y) # (1, 510, 21)
diff --git a/linformer_pytorch/linformer_pytorch.py b/linformer_pytorch/linformer_pytorch.py
index a70a52b..55ba64e 100644
--- a/linformer_pytorch/linformer_pytorch.py
+++ b/linformer_pytorch/linformer_pytorch.py
@@ -23,11 +23,15 @@ def gen_causal_mask(input_size, dim_k, full_attention=False):
         return (torch.triu(torch.ones(input_size, input_size))==1).transpose(0,1)
     return (torch.triu(torch.ones(dim_k, input_size))==1).transpose(0,1)
 
-def get_EF(input_size, dim, bias=True):
+def get_EF(input_size, dim, convolution=False, head_dim=None, bias=True):
     """
     Retuns the E or F matrix, initialized via xavier initialization.
     This is the recommended way to do it according to the authors of the paper.
+    Also includes an option for convolution, as proposed by the authors.
""" + if convolution: + conv = nn.Conv1d(head_dim, head_dim, kernel_size=int(input_size/dim), stride=int(input_size/dim)) + return conv lin = nn.Linear(input_size, dim, bias) torch.nn.init.xavier_normal_(lin.weight) return lin @@ -149,6 +153,7 @@ def forward(self, Q, K, V, **kwargs): if not self.full_attention: K = self.E(K) Q = torch.matmul(Q, K) + print(Q.shape) P_bar = Q/torch.sqrt(torch.tensor(self.dim).type(Q.type())) if self.causal_mask is not None: @@ -176,7 +181,7 @@ class MHAttention(nn.Module): This feeds directly into a feed forward head """ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, checkpoint_level, - parameter_sharing, E_proj, F_proj, full_attention, causal_mask, w_o_intermediate_dim=None, decoder_mode=False): + parameter_sharing, E_proj, F_proj, full_attention, causal_mask, w_o_intermediate_dim=None, decoder_mode=False, convolution=False): super(MHAttention, self).__init__() self.heads = nn.ModuleList() self.input_size = input_size @@ -186,8 +191,8 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, self.checkpoint_level = checkpoint_level self.w_o_intermediate_dim = w_o_intermediate_dim if parameter_sharing != "layerwise": - E_proj = get_EF(input_size, dim_k) - F_proj = get_EF(input_size, dim_k) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj + E_proj = get_EF(input_size, dim_k, convolution, dim) + F_proj = get_EF(input_size, dim_k, convolution, dim) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj self.decoder_mode = decoder_mode self.to_q = nn.ModuleList() @@ -196,8 +201,8 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, for _ in range(nhead): if parameter_sharing == "none": - E_proj = get_EF(input_size, dim_k) - F_proj = get_EF(input_size, dim_k) + E_proj = get_EF(input_size, dim_k, convolution, dim) + F_proj = get_EF(input_size, dim_k, convolution, dim) attn = LinearAttentionHead(dim, dropout, E_proj, F_proj, causal_mask, full_attention) self.heads.append(attn) self.to_q.append(nn.Linear(channels, dim, bias=False)) @@ -238,7 +243,7 @@ class Linformer(nn.Module): My attempt at reproducing the Linformer Paper https://arxiv.org/pdf/2006.04768.pdf """ - def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False): + def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False, convolution=False): super(Linformer, self).__init__() assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now" assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2." 
@@ -255,13 +260,13 @@ def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_
 
         head_dim = channels // nhead if dim_d is None else dim_d
 
-        E_proj = get_EF(input_size, dim_k)
+        E_proj = get_EF(input_size, dim_k, convolution, head_dim)
         causal_mask = gen_causal_mask(input_size, dim_k, full_attention) if causal else None
         # If we want causal but only with the encoder
         causal_enc = gen_causal_mask(input_size, dim_k, full_attention) if (causal and not decoder_mode) else None
 
-        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False)
-        get_attn_context = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True)
+        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False, convolution=convolution)
+        get_attn_context = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True, convolution=convolution)
         get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
 
         for index in range(depth):
@@ -317,7 +322,7 @@ def __init__(self, num_tokens, input_size, channels,
                  dropout=0.05, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise",
                  k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, emb_dim=None,
-                 return_emb=False, decoder_mode=False, causal=False):
+                 return_emb=False, decoder_mode=False, causal=False, convolution=False):
         super(LinformerLM, self).__init__()
         emb_dim = channels if emb_dim is None else emb_dim
 
@@ -330,7 +335,7 @@ def __init__(self, num_tokens, input_size, channels,
                                    nhead=nhead, depth=depth, dropout=dropout, activation=activation,
                                    checkpoint_level=checkpoint_level, parameter_sharing=parameter_sharing,
                                    k_reduce_by_layer=k_reduce_by_layer, full_attention=full_attention, include_ff=include_ff,
-                                   w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal)
+                                   w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal, convolution=convolution)
         if emb_dim != channels:
             self.linformer = ProjectInOut(self.linformer, emb_dim, channels)
 
@@ -354,20 +359,20 @@ class LinformerEncDec(nn.Module):
     def __init__(self, enc_num_tokens, enc_input_size, enc_channels, dec_num_tokens, dec_input_size,
                  dec_channels, enc_dim_k=64, enc_dim_ff=1024, enc_dim_d=None, enc_dropout_ff=0.1, enc_nhead=4,
                  enc_depth=2, enc_dropout=0.05, enc_parameter_sharing="layerwise", enc_k_reduce_by_layer=0,
-                 enc_full_attention=False, enc_include_ff=True, enc_w_o_intermediate_dim=None, enc_emb_dim=None,
+                 enc_full_attention=False, enc_include_ff=True, enc_w_o_intermediate_dim=None, enc_emb_dim=None, enc_convolution=False,
                  dec_dim_k=64, dec_dim_ff=1024, dec_dim_d=None, dec_dropout_ff=0.1, dec_nhead=4, dec_depth=2,
                  dec_dropout=0.05, dec_parameter_sharing="layerwise", dec_k_reduce_by_layer=0, dec_full_attention=False, dec_include_ff=True,
-                 dec_w_o_intermediate_dim=None, dec_emb_dim=None, activation="gelu", checkpoint_level="C0"):
+                 dec_w_o_intermediate_dim=None, dec_emb_dim=None, dec_convolution=False, activation="gelu", checkpoint_level="C0"):
         super(LinformerEncDec, self).__init__()
         self.encoder = LinformerLM(num_tokens=enc_num_tokens, input_size=enc_input_size, channels=enc_channels,
                                    dim_d=enc_dim_d, dim_ff=enc_dim_ff, dim_k=enc_dim_k, dropout_ff=enc_dropout_ff,
                                    nhead=enc_nhead, depth=enc_depth, dropout=enc_dropout,
                                    parameter_sharing=enc_parameter_sharing, k_reduce_by_layer=enc_k_reduce_by_layer,
                                    full_attention=enc_full_attention, include_ff=enc_include_ff, w_o_intermediate_dim=enc_w_o_intermediate_dim,
-                                   emb_dim=enc_emb_dim, return_emb=True, activation=activation, checkpoint_level=checkpoint_level)
+                                   emb_dim=enc_emb_dim, return_emb=True, activation=activation, checkpoint_level=checkpoint_level, convolution=enc_convolution)
         self.decoder = LinformerLM(num_tokens=dec_num_tokens, input_size=dec_input_size, channels=dec_channels,
                                    dim_d=dec_dim_d, dim_ff=dec_dim_ff, dim_k=dec_dim_k, dropout_ff=dec_dropout_ff, nhead=dec_nhead, depth=dec_depth, dropout=dec_dropout,
-                                   parameter_sharing=dec_parameter_sharing, k_reduce_by_layer=dec_k_reduce_by_layer,
+                                   parameter_sharing=dec_parameter_sharing, k_reduce_by_layer=dec_k_reduce_by_layer, convolution=dec_convolution,
                                    full_attention=dec_full_attention, include_ff=dec_include_ff, w_o_intermediate_dim=dec_w_o_intermediate_dim,
                                    emb_dim=dec_emb_dim, decoder_mode=True, causal=True, activation=activation, checkpoint_level=checkpoint_level)
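
For reference, the idea behind the new `convolution` option in `get_EF` can be sketched outside of the patch: instead of a learned `nn.Linear(n, k)` projecting the length-n key/value sequence down to k, an `nn.Conv1d` with a kernel size and stride of n/k performs the downsampling. The snippet below is a minimal illustration only; the shapes (batch size, head dimension, n, k) are arbitrary assumptions and are not taken from the patch.

```python3
import torch
import torch.nn as nn

# Illustrative values only (assumptions, not from the patch)
n, k, head_dim, batch = 512, 64, 32, 1

# Linear projection: maps the sequence dimension n down to k
linear_proj = nn.Linear(n, k)

# Convolutional alternative: kernel size and stride of n/k, as proposed in the paper
conv_proj = nn.Conv1d(head_dim, head_dim, kernel_size=n // k, stride=n // k)

# Both act on a (batch, head_dim, n) tensor, e.g. a transposed K or V
x = torch.randn(batch, head_dim, n)
print(linear_proj(x).shape)  # torch.Size([1, 32, 64])
print(conv_proj(x).shape)    # torch.Size([1, 32, 64])
```

Note that with a kernel size and stride of `int(n/k)`, the projected length equals k exactly only when k divides n; for other combinations the convolution yields a slightly different length (e.g. 63 rather than 61 for the 510/61 pair used in `examples/example_conv.py`).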