Rebase, added option to plot MHAttention heads
Rebased the code so it looks better, and added the option to plot the
MHAttention module as well as the Linformer module
tatp22 committed Jun 29, 2020
1 parent 7c5c3a0 commit 547b7c9
Showing 6 changed files with 34 additions and 38 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -48,6 +48,7 @@ model = Linformer(
         parameter_sharing="layerwise", # What level of parameter sharing to use. For more information, see below.
         k_reduce_by_layer=0, # Going down `depth`, how much to reduce `dim_k` by, for the `E` and `F` matrices. Will have a minimum value of 1.
         full_attention=False, # Use full attention instead, for O(n^2) time and space complexity. Included here just for comparison
+        include_ff=True, # Whether or not to include the Feed Forward layer
         ).cuda()
 x = torch.randn(1, 262144, 64).cuda()
 y = model(x)
1 change: 1 addition & 0 deletions examples/example_small.py
@@ -15,6 +15,7 @@
         checkpoint_level="C1",
         parameter_sharing="none",
         k_reduce_by_layer=1,
+        include_ff=True,
         )
 x = torch.randn(1, 512, 16)
 y = model(x)
4 changes: 2 additions & 2 deletions examples/example_vis.py
@@ -20,8 +20,8 @@
 y = model(x, visualize=True)
 vis = Visualizer(model)
 vis.plot_all_heads(title="All P_bar matrices",
-                   show=False,
-                   save_file="../head_vis.png",
+                   show=True,
+                   save_file=None,
                    figsize=(8,6),
                    n_limit=256)
 print(y) # (1, 512, 16)
2 changes: 1 addition & 1 deletion linformer_pytorch/__init__.py
@@ -1,3 +1,3 @@
-from linformer_pytorch.linformer_pytorch import Linformer
+from linformer_pytorch.linformer_pytorch import Linformer, MHAttention, LinearAttentionHead, FeedForward, PositionalEmbedding, get_linear
 from linformer_pytorch.padder import Padder
 from linformer_pytorch.visualizer import Visualizer
42 changes: 14 additions & 28 deletions linformer_pytorch/linformer_pytorch.py
@@ -41,14 +41,14 @@ class FeedForward(nn.Module):
     """
     Standard Feed Forward Layer
     """
-    def __init__(self, channels, ff_dim, dropout=0.0, activation="gelu"):
+    def __init__(self, channels, ff_dim, dropout, activation="gelu"):
         super(FeedForward, self).__init__()
         self.w_1 = get_linear(channels, ff_dim)
         self.w_2 = get_linear(ff_dim, channels)
         self.activation = get_act(activation)
         self.dropout = nn.Dropout(dropout)

-    def forward(self, tensor):
+    def forward(self, tensor, **kwargs):
         tensor = self.w_1(tensor)
         if self.activation is not None:
             tensor = self.activation(tensor)
@@ -149,7 +149,7 @@ class Linformer(nn.Module):
     My attempt at reproducing the Linformer Paper
     https://arxiv.org/pdf/2006.04768.pdf
     """
-    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False):
+    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False, include_ff=True):
         super(Linformer, self).__init__()
         assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
         assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
@@ -171,14 +171,12 @@ def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=No

         get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F, full_attention)
         get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
-        norm_attn = lambda: nn.LayerNorm(channels)
-        norm_ff = lambda: nn.LayerNorm(channels)
+        get_norm = lambda: nn.LayerNorm(channels)

         for index in range(depth):
-            self.layers.append(nn.ModuleList([get_attn(max(1, dim_k - index*k_reduce_by_layer)),
-                                              norm_attn(),
-                                              get_ff(),
-                                              norm_ff()]))
+            self.layers.append(nn.ModuleList([get_attn(max(1, dim_k - index*k_reduce_by_layer)), get_norm()]))
+            if include_ff:
+                self.layers.append(nn.ModuleList([get_ff(), get_norm()]))

     def forward(self, tensor, **kwargs):
         """
@@ -193,27 +191,15 @@ def forward(self, tensor, **kwargs):
             tensor += self.pos_emb(tensor).type(tensor.type())

         for layer in self.layers:
-            attn = layer[0]
-            attn_norm = layer[1]
-            ff = layer[2]
-            ff_norm = layer[3]
-
-            # Run attention
-            before_attn_tensor = tensor.clone().detach()
-            if self.checkpoint_level == "C1" or self.checkpoint_level == "C2":
-                tensor = checkpoint(attn, tensor)
-            else:
-                tensor = attn(tensor, **kwargs)
-            tensor += before_attn_tensor
-            tensor = attn_norm(tensor)
+            module = layer[0]
+            norm = layer[1]

-            # Run ff
-            before_ff_tensor = tensor.clone().detach()
+            before_module_tensor = tensor.clone().detach()
             if self.checkpoint_level == "C1" or self.checkpoint_level == "C2":
-                tensor = checkpoint(ff, tensor)
+                tensor = checkpoint(module, tensor)
             else:
-                tensor = ff(tensor)
-            tensor += before_ff_tensor
-            tensor = ff_norm(tensor)
+                tensor = module(tensor, **kwargs)
+            tensor += before_module_tensor
+            tensor = norm(tensor)

         return tensor
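
To make the new layer layout concrete: each depth now contributes an `[MHAttention, LayerNorm]` pair, plus a `[FeedForward, LayerNorm]` pair when `include_ff=True`, so the forward pass applies the same residual-and-norm pattern to both sublayer types. A minimal sketch of inspecting this structure (the constructor values below are illustrative, not taken from this commit):

```python
import torch.nn as nn
from linformer_pytorch import Linformer, MHAttention, FeedForward

# Illustrative configuration; any valid Linformer config is laid out the same way.
model = Linformer(input_size=512, channels=16, dim_k=16, depth=2, include_ff=True)

# With include_ff=True the layers alternate [MHAttention, LayerNorm] and
# [FeedForward, LayerNorm], giving 2 * depth entries; with include_ff=False
# only the attention pairs remain.
assert len(model.layers) == 4
for module, norm in model.layers:
    assert isinstance(module, (MHAttention, FeedForward))
    assert isinstance(norm, nn.LayerNorm)
```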
22 changes: 15 additions & 7 deletions linformer_pytorch/visualizer.py
@@ -5,22 +5,26 @@
 import matplotlib.colors as col
 import matplotlib.pyplot as plt

-from linformer_pytorch import Linformer
+from linformer_pytorch import Linformer, MHAttention

 class Visualizer():
     """
     A way to visualize the attention heads for each layer
     """
     def __init__(self, net):
-        assert isinstance(net, Linformer), "Only the Linformer is supported"
+        assert isinstance(net, (Linformer, MHAttention)), "Only the Linformer and MHAttention is supported"
         self.net = net

     def get_head_visualization(self, depth_no, max_depth, head_no, n_limit, axs):
         """
-        Returns the visualization for one head in the Linformer
+        Returns the visualization for one head in the Linformer or MHAttention
         """
-        curr_mh_attn = self.net.layers[depth_no][0] # First one is mh attn
-        curr_head = curr_mh_attn.heads[head_no]
+        if isinstance(self.net, Linformer):
+            depth_to_use = 2*depth_no if 2*(max_depth+1) == len(self.net.layers) else depth_no
+            curr_mh_attn = self.net.layers[depth_to_use][0] # First one is attn module
+            curr_head = curr_mh_attn.heads[head_no]
+        else:
+            curr_head = self.net.heads[head_no]

         arr = curr_head.P_bar[0].detach().cpu().numpy()
         assert arr is not None, "Cannot visualize a None matrix!"
@@ -47,8 +51,12 @@ def plot_all_heads(self, title="Visualization of Attention Heads", show=True, sa
         which turns out to be an NxK matrix for each of them.
         """

-        self.depth = self.net.depth
-        self.heads = self.net.nhead
+        if isinstance(self.net, Linformer):
+            self.depth = self.net.depth
+            self.heads = self.net.nhead
+        else:
+            self.depth = 1
+            self.heads = len(self.net.heads)

         fig, axs = plt.subplots(self.depth, self.heads, figsize=figsize)
         axs = axs.reshape((self.depth, self.heads)) # In case depth or nheads are 1, bug i think
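
With this change the `Visualizer` also accepts a bare `MHAttention` module rather than only a full `Linformer`. A minimal usage sketch, assuming a config like the one in `examples/example_small.py` (the exact values here are illustrative):

```python
import torch
from linformer_pytorch import Linformer, MHAttention, Visualizer

model = Linformer(input_size=512, channels=16, dim_k=16, include_ff=True)
x = torch.randn(1, 512, 16)
y = model(x, visualize=True)  # visualize=True stores each head's P_bar matrix

# layers[0][0] is the first MHAttention module (its LayerNorm is layers[0][1]).
mh_attn = model.layers[0][0]
assert isinstance(mh_attn, MHAttention)

# Plot only that module's heads; a bare MHAttention is treated as depth 1.
vis = Visualizer(mh_attn)
vis.plot_all_heads(title="P_bar matrices of one MHAttention module",
                   show=True,
                   save_file=None,
                   figsize=(8, 6),
                   n_limit=256)
```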