polish unit test code #127

Open · wants to merge 34 commits into base: master
34 commits
bd3db82
add python wrapper for ls transformer
godweiyang Jul 19, 2021
a3a9a24
Merge branch 'master' into ls_transformer
godweiyang Jul 19, 2021
c693ced
add default activation_fn config
godweiyang Jul 19, 2021
f8e57c4
add demo example using pure ls layers
godweiyang Jul 19, 2021
d9e98e9
Merge branch 'master' into ls_transformer
godweiyang Jul 19, 2021
3d490aa
add new line
godweiyang Jul 19, 2021
6b69495
format demo.py
godweiyang Jul 19, 2021
54fc320
fix embedding block allocation bug
godweiyang Jul 20, 2021
d802233
polish unit test code
godweiyang Jul 20, 2021
0f83d7d
Merge branch 'master' into ls_transformer
godweiyang Jul 20, 2021
2eb6cf2
convert fp16 when dtype changes
godweiyang Jul 20, 2021
5dc5740
style format
godweiyang Jul 20, 2021
c6193d8
polish unit test code
godweiyang Jul 20, 2021
0682ec6
generate random config for unit test
godweiyang Jul 20, 2021
91b80f0
modify demo example using huggingface tokenizer
godweiyang Jul 20, 2021
0508e33
modify demo example using huggingface tokenizer
godweiyang Jul 20, 2021
5b706f0
use cache to accelerate inference in demo example
godweiyang Jul 21, 2021
c68a3e4
use cache to accelerate inference in demo example
godweiyang Jul 21, 2021
c7b77c8
refactor op test
Jul 21, 2021
ed3406c
Merge branch 'master' into ls_transformer
godweiyang Jul 21, 2021
e7c766f
Merge branch 'ls_transformer' of github.com:bytedance/lightseq into l…
godweiyang Jul 21, 2021
9edbdcb
fix test_decoder_bw bug
godweiyang Jul 22, 2021
33df768
add multiprocessing for different shape unit tests
godweiyang Jul 22, 2021
0588e68
Merge branch 'master' into polish_test
godweiyang Jul 22, 2021
52d75ab
move layer creation to the beginning
godweiyang Jul 22, 2021
96c190a
modify inference unit test
godweiyang Aug 9, 2021
0823b40
modify inference unit test
godweiyang Aug 9, 2021
e0962de
modify inference unit test
godweiyang Aug 9, 2021
51e77b3
modify unit test
godweiyang Aug 9, 2021
3daaef7
add fairseq training cli
godweiyang Aug 9, 2021
ab9dd43
modify unit test
godweiyang Aug 9, 2021
a4219e6
Merge branch 'master' into polish_test
godweiyang Aug 20, 2021
316300c
fix merge conflict
godweiyang Sep 16, 2021
ed01a36
Merge branch 'master' into polish_test
Oct 13, 2021
306 changes: 97 additions & 209 deletions tests/fairseq_layers.py
@@ -4,9 +4,8 @@
We use layers from Facebook Fairseq as our baseline for unit tests
"""

from typing import Dict, List, Optional, Callable
from typing import Dict, List, Optional
import math
from copy import deepcopy

import torch
import torch.nn as nn
@@ -18,7 +17,7 @@
from torch import Tensor


class TransformerEncoderLayer(nn.Module):
class FSTransformerEncoderLayer(nn.Module):
"""Encoder layer implemented by fairseq.
    This version only removes the "args" parameter; no other changes.

@@ -165,128 +164,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
return x


class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""

def __init__(
self,
embedding_dim: int = 768,
ffn_embedding_dim: int = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
export: bool = False,
q_noise: float = 0.0,
qn_block_size: int = 8,
init_fn: Callable = None,
) -> None:
super().__init__()

if init_fn is not None:
init_fn()

# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
self.activation_dropout_module = FairseqDropout(
activation_dropout, module_name=self.__class__.__name__
)

# Initialize blocks
self.activation_fn = utils.get_activation_fn(activation_fn)
self.self_attn = self.build_self_attention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
q_noise=q_noise,
qn_block_size=qn_block_size,
)

# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)

self.fc1 = self.build_fc1(
self.embedding_dim,
ffn_embedding_dim,
q_noise=q_noise,
qn_block_size=qn_block_size,
)
self.fc2 = self.build_fc2(
ffn_embedding_dim,
self.embedding_dim,
q_noise=q_noise,
qn_block_size=qn_block_size,
)

# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)

def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

def build_self_attention(
self,
embed_dim,
num_attention_heads,
dropout,
self_attention,
q_noise,
qn_block_size,
):
return MultiheadAttention(
embed_dim,
num_attention_heads,
dropout=dropout,
self_attention=True,
q_noise=q_noise,
qn_block_size=qn_block_size,
)

def forward(
self,
x: torch.Tensor,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer implementation.
"""
residual = x
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = residual + x
x = self.self_attn_layer_norm(x)

residual = x
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = residual + x
x = self.final_layer_norm(x)
return x, attn


class TransformerDecoderLayer(nn.Module):
class FSTransformerDecoderLayer(nn.Module):
"""Decoder layer implemented by fairseq.
    This version only removes the "args" parameter; no other changes.
"""
@@ -544,72 +422,6 @@ def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn


def generate_enc_layer():
hidden_size = 1024
intermediate_size = 1024 * 4
heads = 16
hidden_dropout_ratio = 0.0
attn_dropout_ratio = 0.0
activation_dropout_ratio = 0.0
pre_layer_norm = True
layer = TransformerEncoderLayer(
hidden_size,
intermediate_size,
heads,
hidden_dropout_ratio,
attn_dropout_ratio,
activation_dropout_ratio,
pre_layer_norm,
activation_fn="relu",
)
layer.to(torch.device("cuda:0"), dtype=torch.half)
return layer


def generate_dec_layer():
hidden_size = 1024
intermediate_size = 1024 * 4
heads = 16
hidden_dropout_ratio = 0.0
attn_dropout_ratio = 0.0
activation_dropout_ratio = 0.0
pre_layer_norm = True
layer = TransformerDecoderLayer(
embed_dim=hidden_size,
ffn_embed_dim=intermediate_size,
nhead=heads,
encoder_embed_dim=hidden_size,
dropout=hidden_dropout_ratio,
attn_dropout=attn_dropout_ratio,
activation_dropout=activation_dropout_ratio,
normalize_before=pre_layer_norm,
activation_fn="relu",
)

layer.to(torch.device("cuda:0"), dtype=torch.half)
return layer


def generate_bert_enc_layer():
hidden_size = 1024
intermediate_size = 1024 * 4
heads = 16
hidden_dropout_ratio = 0.0
attn_dropout_ratio = 0.0
activation_dropout_ratio = 0.0
layer = TransformerSentenceEncoderLayer(
hidden_size,
intermediate_size,
heads,
hidden_dropout_ratio,
attn_dropout_ratio,
activation_dropout_ratio,
activation_fn="gelu",
)
layer.to(torch.device("cuda:0"))
return layer


class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.

@@ -674,7 +486,7 @@ def forward(
).detach()
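
# For reference, a minimal sketch (an assumption, not part of this diff) of the
# sinusoidal table fairseq builds: the first half of the dimensions use sin and
# the second half cos, assuming an even embedding dimension. The helper name is
# hypothetical; `math` and `torch` are imported at the top of this file.
def _sinusoidal_table(max_len: int, dim: int) -> torch.Tensor:
    half = dim // 2
    # Geometric progression of frequencies, as in "Attention Is All You Need".
    freqs = torch.exp(
        torch.arange(half, dtype=torch.float) * -(math.log(10000.0) / (half - 1))
    )
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    return torch.cat([torch.sin(pos * freqs), torch.cos(pos * freqs)], dim=1)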


class TransformerEmbeddingLayer(nn.Module):
class FSTransformerEmbeddingLayer(nn.Module):
def __init__(
self, vocab_size, embedding_dim, max_seq_len, padding_idx, dropout, fp16
):
Expand Down Expand Up @@ -703,21 +515,97 @@ def forward(self, input):
return x


def generate_emb_layer(ls_emb_config):
    layer = TransformerEmbeddingLayer(
        ls_emb_config.vocab_size,
        ls_emb_config.embedding_dim,
        ls_emb_config.max_seq_len,
        ls_emb_config.padding_idx,
        ls_emb_config.dropout,
        ls_emb_config.fp16,
    )
    dtype = torch.float16 if ls_emb_config.fp16 else torch.float32
    layer.to(torch.device("cuda:0"), dtype=dtype)

    return layer


if __name__ == "__main__":
    generate_enc_layer()
    generate_dec_layer()


class FSCrossEntropyLayer(nn.Module):
    def __init__(self, epsilon, ignore_index):
        super().__init__()

        self.epsilon = epsilon
        self.ignore_index = ignore_index

def label_smoothed_nll_loss(self, lprobs, target, reduce=True):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if self.ignore_index is not None:
pad_mask = target.eq(self.ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = self.epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - self.epsilon - eps_i) * nll_loss + eps_i * smooth_loss
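        # Since smooth_loss counts the target class once as well, the target
        # effectively keeps weight 1 - epsilon while each of the remaining
        # V - 1 classes receives eps_i = epsilon / (V - 1).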
return loss, nll_loss

def forward(self, inputs, targets):
x = torch.nn.functional.log_softmax(inputs, dim=-1, dtype=torch.float32)
loss, nll_loss = self.label_smoothed_nll_loss(x, targets)
loss = loss.to(inputs)
nll_loss = nll_loss.to(inputs)

return loss, nll_loss
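

# A minimal usage sketch (illustration only, not part of this diff): the layer
# returns the label-smoothed loss and the raw NLL loss. The batch, sequence,
# and vocabulary sizes below are arbitrary assumptions.
def _example_cross_entropy():
    criterion = FSCrossEntropyLayer(epsilon=0.1, ignore_index=0)
    logits = torch.randn(2, 8, 1000)          # (batch, seq_len, vocab_size)
    targets = torch.randint(1, 1000, (2, 8))  # token ids; 0 marks padding
    return criterion(logits, targets)         # -> (loss, nll_loss)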


def get_fairseq_enc_params(fairseq_layer):
initial_weights = []
initial_biases = []

initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn_layer_norm.bias.detach().clone())

initial_weights.append(fairseq_layer.fc1.weight.detach().clone())
initial_biases.append(fairseq_layer.fc1.bias.detach().clone())
initial_weights.append(fairseq_layer.fc2.weight.detach().clone())
initial_biases.append(fairseq_layer.fc2.bias.detach().clone())
initial_weights.append(fairseq_layer.final_layer_norm.weight.detach().clone())
initial_biases.append(fairseq_layer.final_layer_norm.bias.detach().clone())
return initial_weights, initial_biases


def get_fairseq_dec_params(fairseq_layer):
initial_weights = []
initial_biases = []

initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
initial_biases.append(fairseq_layer.self_attn_layer_norm.bias.detach().clone())

initial_weights.append(fairseq_layer.encodec_attn.q_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.encodec_attn.q_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.encodec_attn.k_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.encodec_attn.k_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.encodec_attn.v_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.encodec_attn.v_proj.bias.detach().clone())
initial_weights.append(fairseq_layer.encodec_attn.out_proj.weight.detach().clone())
initial_biases.append(fairseq_layer.encodec_attn.out_proj.bias.detach().clone())
initial_weights.append(
fairseq_layer.encodec_attn_layer_norm.weight.detach().clone()
)
initial_biases.append(fairseq_layer.encodec_attn_layer_norm.bias.detach().clone())

initial_weights.append(fairseq_layer.fc1.weight.detach().clone())
initial_biases.append(fairseq_layer.fc1.bias.detach().clone())
initial_weights.append(fairseq_layer.fc2.weight.detach().clone())
initial_biases.append(fairseq_layer.fc2.bias.detach().clone())
initial_weights.append(fairseq_layer.final_layer_norm.weight.detach().clone())
initial_biases.append(fairseq_layer.final_layer_norm.bias.detach().clone())
return initial_weights, initial_biases
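

# A minimal round-trip sketch (illustration only, not part of this diff): the
# lists above follow a fixed module order, so they can be copied back into a
# compatible encoder layer to sanity-check that ordering. The helper name is
# hypothetical.
def _load_enc_params(layer, weights, biases):
    modules = [
        layer.self_attn.q_proj,
        layer.self_attn.k_proj,
        layer.self_attn.v_proj,
        layer.self_attn.out_proj,
        layer.self_attn_layer_norm,
        layer.fc1,
        layer.fc2,
        layer.final_layer_norm,
    ]
    with torch.no_grad():
        for module, w, b in zip(modules, weights, biases):
            module.weight.copy_(w)
            module.bias.copy_(b)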