From 71c5dda83c925733839e2774006dfdd9e78efe3e Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Wed, 27 Mar 2024 21:45:08 +0800
Subject: [PATCH] refactor: simplify MultiHeadAttention, put dropout,
 layernorm, and residual connection into EncoderLayer;

---
 pypots/nn/modules/transformer/attention.py | 23 +++++++---------------
 pypots/nn/modules/transformer/layers.py    | 22 ++++++++++++---------
 2 files changed, 20 insertions(+), 25 deletions(-)

diff --git a/pypots/nn/modules/transformer/attention.py b/pypots/nn/modules/transformer/attention.py
index 6fc5e21f..89684473 100644
--- a/pypots/nn/modules/transformer/attention.py
+++ b/pypots/nn/modules/transformer/attention.py
@@ -106,12 +106,12 @@ class MultiHeadAttention(nn.Module):
     d_v:
         The dimension of the value tensor.
 
-    dropout:
-        The dropout rate.
-
     attn_dropout:
         The dropout rate for the attention map.
 
+    attn_temperature:
+        The temperature for scaling. Default is None, which means d_k**0.5 will be applied.
+
     """
 
     def __init__(
@@ -120,11 +120,13 @@ def __init__(
         d_model: int,
         d_k: int,
         d_v: int,
-        dropout: float,
         attn_dropout: float,
+        attn_temperature: float = None,
     ):
         super().__init__()
 
+        attn_temperature = d_k**0.5 if attn_temperature is None else attn_temperature
+
         self.n_heads = n_heads
         self.d_k = d_k
         self.d_v = d_v
@@ -133,12 +135,9 @@ def __init__(
         self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
         self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)
 
-        self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout)
+        self.attention = ScaledDotProductAttention(attn_temperature, attn_dropout)
         self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
 
-        self.dropout = nn.Dropout(dropout)
-        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
-
     def forward(
         self,
         q: torch.Tensor,
@@ -177,7 +176,6 @@ def forward(
 
         # keep useful variables
         batch_size, n_steps = q.size(0), q.size(1)
-        residual = v
 
         # now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v]
         q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k)
@@ -198,11 +196,4 @@ def forward(
         v = v.transpose(1, 2).contiguous().view(batch_size, n_steps, -1)
         v = self.fc(v)
 
-        # apply dropout and residual connection
-        v = self.dropout(v)
-        v += residual
-
-        # apply layer-norm
-        v = self.layer_norm(v)
-
         return v, attn_weights
diff --git a/pypots/nn/modules/transformer/layers.py b/pypots/nn/modules/transformer/layers.py
index 6fd1efd2..8f209a2d 100644
--- a/pypots/nn/modules/transformer/layers.py
+++ b/pypots/nn/modules/transformer/layers.py
@@ -103,9 +103,9 @@ def __init__(
         attn_dropout: float = 0.1,
     ):
         super().__init__()
-        self.slf_attn = MultiHeadAttention(
-            n_heads, d_model, d_k, d_v, dropout, attn_dropout
-        )
+        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
+        self.dropout = nn.Dropout(dropout)
+        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
         self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
 
     def forward(
@@ -138,6 +138,14 @@ def forward(
             enc_input,
             attn_mask=src_mask,
         )
+
+        # apply dropout and residual connection
+        enc_output = self.dropout(enc_output)
+        enc_output += enc_input
+
+        # apply layer-norm
+        enc_output = self.layer_norm(enc_output)
+
         enc_output = self.pos_ffn(enc_output)
 
         return enc_output, attn_weights
@@ -181,12 +189,8 @@ def __init__(
         attn_dropout: float = 0.1,
     ):
         super().__init__()
-        self.slf_attn = MultiHeadAttention(
-            n_heads, d_model, d_k, d_v, dropout, attn_dropout
-        )
-        self.enc_attn = MultiHeadAttention(
-            n_heads, d_model, d_k, d_v, dropout, attn_dropout
-        )
+        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
+        self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
         self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
 
     def forward(
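Illustrative usage after this change (not part of the patch): MultiHeadAttention becomes a plain attention block, while EncoderLayer applies the dropout, residual connection, and layer norm around it. The constructor and forward argument names below follow the hunks above where visible; the tensor shapes, dropout rates, and the assumption that attn_mask/src_mask may be passed as None are illustrative only.

import torch

from pypots.nn.modules.transformer.attention import MultiHeadAttention
from pypots.nn.modules.transformer.layers import EncoderLayer

# illustrative shapes/hyperparameters, not taken from the patch
batch_size, n_steps, d_model = 8, 24, 256
n_heads, d_k, d_v, d_inner = 4, 64, 64, 512
x = torch.randn(batch_size, n_steps, d_model)

# MultiHeadAttention no longer takes a general `dropout` argument and no longer
# applies dropout, the residual connection, or layer norm itself; it only projects
# q/k/v, runs scaled dot-product attention (optionally with a custom temperature),
# and projects the concatenated heads back to d_model.
mha = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout=0.1)
attn_out, attn_weights = mha(x, x, x, attn_mask=None)  # passing None here is assumed to be allowed

# EncoderLayer now owns the dropout, residual connection, and layer norm that used
# to live inside MultiHeadAttention, so the layer's end-to-end computation is unchanged.
# Keyword arguments are used because the full EncoderLayer signature is not visible in the hunks.
enc_layer = EncoderLayer(
    d_model=d_model,
    d_inner=d_inner,
    n_heads=n_heads,
    d_k=d_k,
    d_v=d_v,
    dropout=0.1,
    attn_dropout=0.1,
)
enc_out, enc_attn_weights = enc_layer(x, src_mask=None)  # src_mask name taken from the forward() hunk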