From 072f34c07286020ec3ca7edfc551b7638c7424a1 Mon Sep 17 00:00:00 2001
From: Peter Tatkowski <tatp22@gmail.com>
Date: Fri, 31 Jul 2020 11:22:18 +0200
Subject: [PATCH] Added convolution and updated README

---
 README.md                              | 30 +++++++++++++++++++--
 examples/example_conv.py               | 24 +++++++++++++++++
 linformer_pytorch/linformer_pytorch.py | 37 +++++++++++++++-----------
 3 files changed, 73 insertions(+), 18 deletions(-)
 create mode 100644 examples/example_conv.py

diff --git a/README.md b/README.md
index 9a60436..2684846 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@ model = LinformerLM(
         w_o_intermediate_dim=None, # If not None, have 2 w_o matrices, such that instead of `dim*nhead,channels`, you have `dim*nhead,w_o_int`, and `w_o_int,channels`
         emb_dim=128, # If you want the embedding dimension to be different than the channels for the Linformer
         causal=False, # If you want this to be a causal Linformer, where the upper right of the P_bar matrix is masked out.
+        convolution=False, # If True, use a 1D convolution (with kernel size and stride of n/k) instead of a linear layer for the E and F projections
         ).cuda()
 x = torch.randint(1,10000,(1,512)).cuda()
 y = model(x)
@@ -265,6 +266,30 @@ vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like
                    )
 ```
 
+## Encoder Decoder Module
+Similar to the [Reformer](https://github.com/lucidrains/reformer-pytorch#reformer-encoder-decoder-architecture), I will be attempting to make an Encoder/Decoder module so that training can be simplified. This works like two `LinformerLM` classes: parameters can be adjusted individually for each one, with all of the encoder hyperparams taking the `enc_` prefix and all of the decoder hyperparams taking the `dec_` prefix. So far, what is implemented is:
+
+```python3
+import torch
+from linformer_pytorch import LinformerEncDec
+
+encdec = LinformerEncDec(
+    enc_num_tokens=10000,
+    enc_input_size=512,
+    enc_channels=16,
+    dec_num_tokens=10000,
+    dec_input_size=512,
+    dec_channels=16,
+)
+
+x = torch.randint(1,10000,(1,512))
+y = torch.randint(1,10000,(1,512))
+
+output = encdec(x,y)
+```
+
+I am planning to add a way to generate text sequences with this module.
+
 ## Practical Tips
 * Note that the Linformer has O(nk) time and space complexity. So, while it may be linear in n, make sure that your k is not too large as well. These are editable with `input_size` and `dim_k`, respectively.
 * Speaking about k, the authors found that empirical evidence supports the fact that "the performance of Linformer model is mainly determined by the projected dimension k instead of the ratio n/k". Therefore, even when increasing sequence lengths, it may be fine to keep a relatively low, constant k (the authors showed, with k=256, that it still performed almost as well as a vanilla transformer).
@@ -273,8 +298,9 @@ vis.plot_all_heads(title="All P_bar matrices", # Change the title if you'd like
 * In practice, I found that the memory and time requirements are more on the order of O(nkd), with n=`input_size`, k=`dim_k`, and d=`dim_d`.
 
 ## Future work
-* Run some benchmark tests to see what the performance is
-* Instead of matrix multiplication to bring the dimensions down to k (With EKW and FVW), try to do convolution, as mentioned in the paper, with a stride length and kernel size of n/k.
+* Run some benchmark tests to see what the performance is (Doing that now)
+* Complete the `LinformerEncDec` class
+* ~~Instead of matrix multiplication to bring the dimensions down to k (With EKW and FVW), try to do convolution, as mentioned in the paper, with a stride length and kernel size of n/k.~~
 
 ## Disclaimer
 This is the first time that I am reproducing a result from a paper, so some things may be wrong. If you see a problem, please open up an issue, and I will attempt to work on it.
diff --git a/examples/example_conv.py b/examples/example_conv.py
new file mode 100644
index 0000000..07ded42
--- /dev/null
+++ b/examples/example_conv.py
@@ -0,0 +1,24 @@
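+# Minimal example of running the Linformer with the convolution option enabled
+# for the E and F projections (see get_EF in linformer_pytorch.py).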
+import sys
+import torch
+
+sys.path.insert(0, "../")
+from linformer_pytorch import Linformer
+
+model = Linformer(
+        input_size=510,
+        channels=21,
+        dim_d=26,
+        dim_k=61,
+        dim_ff=32,
+        nhead=4,
+        depth=3,
+        activation="relu",
+        checkpoint_level="C1",
+        parameter_sharing="none",
+        k_reduce_by_layer=1,
+        include_ff=True,
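+        # Use a Conv1d (kernel size and stride of n/k) for the E and F projections instead of a linear layer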
+        convolution=True,
+        )
+x = torch.randn(1, 510, 21)
+y = model(x)
+print(y) # (1, 510, 21)
diff --git a/linformer_pytorch/linformer_pytorch.py b/linformer_pytorch/linformer_pytorch.py
index a70a52b..55ba64e 100644
--- a/linformer_pytorch/linformer_pytorch.py
+++ b/linformer_pytorch/linformer_pytorch.py
@@ -23,11 +23,15 @@ def gen_causal_mask(input_size, dim_k, full_attention=False):
         return (torch.triu(torch.ones(input_size, input_size))==1).transpose(0,1)
     return (torch.triu(torch.ones(dim_k, input_size))==1).transpose(0,1)
 
-def get_EF(input_size, dim, bias=True):
+def get_EF(input_size, dim, convolution=False, head_dim=None, bias=True):
     """
     Returns the E or F matrix, initialized via xavier initialization.
     This is the recommended way to do it according to the authors of the paper.
+    Also includes an option to use a convolution instead, as mentioned by the authors.
     """
+    if convolution:
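+        # Conv1d with head_dim input/output channels; a kernel size and stride of n/k
+        # downsample the sequence length from n to roughly k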
+        conv = nn.Conv1d(head_dim, head_dim, kernel_size=int(input_size/dim), stride=int(input_size/dim))
+        return conv
     lin = nn.Linear(input_size, dim, bias)
     torch.nn.init.xavier_normal_(lin.weight)
     return lin
@@ -149,6 +153,7 @@ def forward(self, Q, K, V, **kwargs):
         if not self.full_attention:
             K = self.E(K)
         Q = torch.matmul(Q, K)
 
         P_bar = Q/torch.sqrt(torch.tensor(self.dim).type(Q.type()))
         if self.causal_mask is not None:
@@ -176,7 +181,7 @@ class MHAttention(nn.Module):
     This feeds directly into a feed forward head
     """
     def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, checkpoint_level,
-                       parameter_sharing, E_proj, F_proj, full_attention, causal_mask, w_o_intermediate_dim=None, decoder_mode=False):
+                parameter_sharing, E_proj, F_proj, full_attention, causal_mask, w_o_intermediate_dim=None, decoder_mode=False, convolution=False):
         super(MHAttention, self).__init__()
         self.heads = nn.ModuleList()
         self.input_size = input_size
@@ -186,8 +191,8 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation,
         self.checkpoint_level = checkpoint_level
         self.w_o_intermediate_dim = w_o_intermediate_dim
         if parameter_sharing != "layerwise":
-            E_proj = get_EF(input_size, dim_k)
-            F_proj = get_EF(input_size, dim_k) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj
+            E_proj = get_EF(input_size, dim_k, convolution, dim)
+            F_proj = get_EF(input_size, dim_k, convolution, dim) if parameter_sharing == "none" or parameter_sharing == "headwise" else E_proj
 
         self.decoder_mode = decoder_mode
         self.to_q = nn.ModuleList()
@@ -196,8 +201,8 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation,
 
         for _ in range(nhead):
             if parameter_sharing == "none":
-                E_proj = get_EF(input_size, dim_k)
-                F_proj = get_EF(input_size, dim_k)
+                E_proj = get_EF(input_size, dim_k, convolution, dim)
+                F_proj = get_EF(input_size, dim_k, convolution, dim)
             attn = LinearAttentionHead(dim, dropout, E_proj, F_proj, causal_mask, full_attention)
             self.heads.append(attn)
             self.to_q.append(nn.Linear(channels, dim, bias=False))
@@ -238,7 +243,7 @@ class Linformer(nn.Module):
     My attempt at reproducing the Linformer Paper
     https://arxiv.org/pdf/2006.04768.pdf
     """
-    def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False):
+    def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True, w_o_intermediate_dim=None, decoder_mode=False, causal=False, convolution=False):
         super(Linformer, self).__init__()
         assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
         assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
@@ -255,13 +260,13 @@ def __init__(self, input_size, channels, dim_k, dim_ff=256, dim_d=None, dropout_
 
         head_dim = channels // nhead if dim_d is None else dim_d
 
-        E_proj = get_EF(input_size, dim_k)
+        E_proj = get_EF(input_size, dim_k, convolution, head_dim)
         causal_mask = gen_causal_mask(input_size, dim_k, full_attention) if causal else None
         # If we want causal but only with the encoder
         causal_enc = gen_causal_mask(input_size, dim_k, full_attention) if (causal and not decoder_mode) else None
 
-        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False)
-        get_attn_context = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True)
+        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_enc, w_o_intermediate_dim, decoder_mode=False, convolution=convolution)
+        get_attn_context = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, E_proj, full_attention, causal_mask, w_o_intermediate_dim, decoder_mode=True, convolution=convolution)
         get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
 
         for index in range(depth):
@@ -317,7 +322,7 @@ def __init__(self, num_tokens, input_size, channels,
                        dropout=0.05, activation="gelu", checkpoint_level="C0",
                        parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False,
                        include_ff=True, w_o_intermediate_dim=None, emb_dim=None,
-                       return_emb=False, decoder_mode=False, causal=False):
+                       return_emb=False, decoder_mode=False, causal=False, convolution=False):
         super(LinformerLM, self).__init__()
         emb_dim = channels if emb_dim is None else emb_dim
 
@@ -330,7 +335,7 @@ def __init__(self, num_tokens, input_size, channels,
                                    nhead=nhead, depth=depth, dropout=dropout,
                                    activation=activation, checkpoint_level=checkpoint_level, parameter_sharing=parameter_sharing,
                                    k_reduce_by_layer=k_reduce_by_layer, full_attention=full_attention, include_ff=include_ff,
-                                   w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal)
+                                   w_o_intermediate_dim=w_o_intermediate_dim, decoder_mode=decoder_mode, causal=causal, convolution=convolution)
 
         if emb_dim != channels:
             self.linformer = ProjectInOut(self.linformer, emb_dim, channels)
@@ -354,20 +359,20 @@ class LinformerEncDec(nn.Module):
     def __init__(self, enc_num_tokens, enc_input_size, enc_channels, dec_num_tokens, dec_input_size, dec_channels,
                        enc_dim_k=64, enc_dim_ff=1024, enc_dim_d=None,
                        enc_dropout_ff=0.1, enc_nhead=4, enc_depth=2, enc_dropout=0.05, enc_parameter_sharing="layerwise", enc_k_reduce_by_layer=0,
-                       enc_full_attention=False, enc_include_ff=True, enc_w_o_intermediate_dim=None, enc_emb_dim=None,
+                       enc_full_attention=False, enc_include_ff=True, enc_w_o_intermediate_dim=None, enc_emb_dim=None, enc_convolution=False,
                        dec_dim_k=64, dec_dim_ff=1024, dec_dim_d=None, dec_dropout_ff=0.1, dec_nhead=4, dec_depth=2, dec_dropout=0.05,
                        dec_parameter_sharing="layerwise", dec_k_reduce_by_layer=0, dec_full_attention=False, dec_include_ff=True,
-                       dec_w_o_intermediate_dim=None, dec_emb_dim=None, activation="gelu", checkpoint_level="C0"):
+                       dec_w_o_intermediate_dim=None, dec_emb_dim=None, dec_convolution=False, activation="gelu", checkpoint_level="C0"):
 
         super(LinformerEncDec, self).__init__()
         self.encoder = LinformerLM(num_tokens=enc_num_tokens, input_size=enc_input_size, channels=enc_channels, dim_d=enc_dim_d, dim_ff=enc_dim_ff,
                                    dim_k=enc_dim_k, dropout_ff=enc_dropout_ff, nhead=enc_nhead, depth=enc_depth, dropout=enc_dropout,
                                    parameter_sharing=enc_parameter_sharing, k_reduce_by_layer=enc_k_reduce_by_layer,
                                    full_attention=enc_full_attention, include_ff=enc_include_ff, w_o_intermediate_dim=enc_w_o_intermediate_dim,
-                                   emb_dim=enc_emb_dim, return_emb=True, activation=activation, checkpoint_level=checkpoint_level)
+                                   emb_dim=enc_emb_dim, return_emb=True, activation=activation, checkpoint_level=checkpoint_level, convolution=enc_convolution)
         self.decoder = LinformerLM(num_tokens=dec_num_tokens, input_size=dec_input_size, channels=dec_channels, dim_d=dec_dim_d, dim_ff=dec_dim_ff,
                                    dim_k=dec_dim_k, dropout_ff=dec_dropout_ff, nhead=dec_nhead, depth=dec_depth, dropout=dec_dropout,
-                                   parameter_sharing=dec_parameter_sharing, k_reduce_by_layer=dec_k_reduce_by_layer,
+                                   parameter_sharing=dec_parameter_sharing, k_reduce_by_layer=dec_k_reduce_by_layer, convolution=dec_convolution,
                                    full_attention=dec_full_attention, include_ff=dec_include_ff, w_o_intermediate_dim=dec_w_o_intermediate_dim,
                                    emb_dim=dec_emb_dim, decoder_mode=True, causal=True, activation=activation, checkpoint_level=checkpoint_level)