Merge pull request #322 from WenjieDu/(refactor)simplify_mha
Simplify MultiHeadAttention
WenjieDu authored Mar 28, 2024
2 parents bfbfcec + 71c5dda commit 0e727af
Showing 2 changed files with 20 additions and 25 deletions.
23 changes: 7 additions & 16 deletions pypots/nn/modules/transformer/attention.py
@@ -106,12 +106,12 @@ class MultiHeadAttention(nn.Module):
d_v:
The dimension of the value tensor.
dropout:
The dropout rate.
attn_dropout:
The dropout rate for the attention map.
attn_temperature:
The temperature for scaling. Default is None, which means d_k**0.5 will be applied.
"""

def __init__(
@@ -120,11 +120,13 @@ def __init__(
d_model: int,
d_k: int,
d_v: int,
dropout: float,
attn_dropout: float,
attn_temperature: float = None,
):
super().__init__()

attn_temperature = d_k**0.5 if attn_temperature is None else attn_temperature

self.n_heads = n_heads
self.d_k = d_k
self.d_v = d_v
@@ -133,12 +135,9 @@ def __init__(
self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)

self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout)
self.attention = ScaledDotProductAttention(attn_temperature, attn_dropout)
self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

def forward(
self,
q: torch.Tensor,
@@ -177,7 +176,6 @@ def forward(

# keep useful variables
batch_size, n_steps = q.size(0), q.size(1)
residual = v

# now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v]
q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k)
@@ -198,11 +196,4 @@ def forward(
v = v.transpose(1, 2).contiguous().view(batch_size, n_steps, -1)
v = self.fc(v)

# apply dropout and residual connection
v = self.dropout(v)
v += residual

# apply layer-norm
v = self.layer_norm(v)

return v, attn_weights
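For reference, a minimal sketch (not part of the diff) of how the simplified module is constructed and called after this change. The import path is inferred from the file path above, and the hyper-parameter values and tensor shapes are illustrative only:

import torch
from pypots.nn.modules.transformer.attention import MultiHeadAttention

# Illustrative sizes; any consistent values work.
mha = MultiHeadAttention(
    n_heads=4,
    d_model=64,
    d_k=16,
    d_v=16,
    attn_dropout=0.1,
    # attn_temperature is omitted, so d_k**0.5 is used for scaling
)

# After this commit, MultiHeadAttention no longer applies dropout, the residual
# connection, or layer normalization itself; it only projects q/k/v, runs scaled
# dot-product attention, and projects the concatenated heads back to d_model.
x = torch.randn(8, 24, 64)  # [batch_size, n_steps, d_model]
output, attn_weights = mha(x, x, x, attn_mask=None)  # no masking in this sketch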
22 changes: 13 additions & 9 deletions pypots/nn/modules/transformer/layers.py
@@ -103,9 +103,9 @@ def __init__(
attn_dropout: float = 0.1,
):
super().__init__()
self.slf_attn = MultiHeadAttention(
n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)

def forward(
@@ -138,6 +138,14 @@ def forward(
enc_input,
attn_mask=src_mask,
)

# apply dropout and residual connection
enc_output = self.dropout(enc_output)
enc_output += enc_input

# apply layer-norm
enc_output = self.layer_norm(enc_output)

enc_output = self.pos_ffn(enc_output)
return enc_output, attn_weights
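The dropout, residual connection, and layer normalization removed from MultiHeadAttention.forward now run here in the layer. A standalone sketch of that post-attention block, with illustrative dimensions and a stand-in tensor in place of the real self-attention output:

import torch
import torch.nn as nn

d_model, dropout = 64, 0.1  # illustrative values
drop = nn.Dropout(dropout)
layer_norm = nn.LayerNorm(d_model, eps=1e-6)

enc_input = torch.randn(8, 24, d_model)   # layer input, [batch_size, n_steps, d_model]
enc_output = torch.randn(8, 24, d_model)  # stand-in for the self-attention output

enc_output = drop(enc_output)        # dropout on the attention output
enc_output = enc_output + enc_input  # residual connection
enc_output = layer_norm(enc_output)  # layer normalization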

@@ -181,12 +189,8 @@ def __init__(
attn_dropout: float = 0.1,
):
super().__init__()
self.slf_attn = MultiHeadAttention(
n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.enc_attn = MultiHeadAttention(
n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)

def forward(
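For code outside these two files that builds MultiHeadAttention directly, a hedged before/after sketch of the constructor call (values illustrative): only the attention-map dropout is passed now, and the enclosing module is expected to own the output dropout and layer norm, as the updated EncoderLayer does.

import torch.nn as nn
from pypots.nn.modules.transformer.attention import MultiHeadAttention

n_heads, d_model, d_k, d_v = 4, 64, 16, 16  # illustrative values
dropout, attn_dropout = 0.1, 0.1

# before this commit:
#   MultiHeadAttention(n_heads, d_model, d_k, d_v, dropout, attn_dropout)
# after this commit, only the attention-map dropout is passed:
slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)

# the output dropout and layer norm now belong to the caller, e.g.:
out_dropout = nn.Dropout(dropout)
layer_norm = nn.LayerNorm(d_model, eps=1e-6)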
