Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify MultiHeadAttention #322

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions pypots/nn/modules/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ class MultiHeadAttention(nn.Module):
d_v:
The dimension of the value tensor.

dropout:
The dropout rate.

attn_dropout:
The dropout rate for the attention map.

attn_temperature:
The temperature for scaling. Default is None, which means d_k**0.5 will be applied.

"""

def __init__(
Expand All @@ -120,11 +120,13 @@ def __init__(
d_model: int,
d_k: int,
d_v: int,
dropout: float,
attn_dropout: float,
attn_temperature: float = None,
):
super().__init__()

attn_temperature = d_k**0.5 if attn_temperature is None else attn_temperature

self.n_heads = n_heads
self.d_k = d_k
self.d_v = d_v
Expand All @@ -133,12 +135,9 @@ def __init__(
self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)

self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout)
self.attention = ScaledDotProductAttention(attn_temperature, attn_dropout)
self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

def forward(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -177,7 +176,6 @@ def forward(

# keep useful variables
batch_size, n_steps = q.size(0), q.size(1)
residual = v

# now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v]
q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k)
Expand All @@ -198,11 +196,4 @@ def forward(
v = v.transpose(1, 2).contiguous().view(batch_size, n_steps, -1)
v = self.fc(v)

# apply dropout and residual connection
v = self.dropout(v)
v += residual

# apply layer-norm
v = self.layer_norm(v)

return v, attn_weights
22 changes: 13 additions & 9 deletions pypots/nn/modules/transformer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def __init__(
attn_dropout: float = 0.1,
):
super().__init__()
self.slf_attn = MultiHeadAttention(
n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)

def forward(
Expand Down Expand Up @@ -138,6 +138,14 @@ def forward(
enc_input,
attn_mask=src_mask,
)

# apply dropout and residual connection
enc_output = self.dropout(enc_output)
enc_output += enc_input

# apply layer-norm
enc_output = self.layer_norm(enc_output)

enc_output = self.pos_ffn(enc_output)
return enc_output, attn_weights

Expand Down Expand Up @@ -181,12 +189,8 @@ def __init__(
attn_dropout: float = 0.1,
):
super().__init__()
self.slf_attn = MultiHeadAttention(
n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.enc_attn = MultiHeadAttention(
n_heads, d_model, d_k, d_v, dropout, attn_dropout
)
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)

def forward(
Expand Down
Loading