layers.py
# -*- coding: utf-8 -*-
import torch.nn as nn

from sub_layers import MultiHeadAttention, PositionwiseFeedForward


class EncoderLayer(nn.Module):
    """A single Transformer encoder layer: self-attention followed by a position-wise feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, dropout=0.1):
        super(EncoderLayer, self).__init__()
        # Split the model dimension evenly across the attention heads.
        d_k = d_model // n_head
        d_v = d_model // n_head
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None, encoding_mask=None):
        # Self-attention: queries, keys, and values all come from the encoder input.
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask, encoding_mask=encoding_mask)
        # Zero out outputs at padding positions (guarded so the None default is usable).
        if non_pad_mask is not None:
            enc_output *= non_pad_mask

        enc_output = self.pos_ffn(enc_output)
        if non_pad_mask is not None:
            enc_output *= non_pad_mask

        return enc_output, enc_slf_attn


class DecoderLayer(nn.Module):
    """A single Transformer decoder layer: masked self-attention, encoder-decoder attention, then a position-wise feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, dropout=0.1):
        super(DecoderLayer, self).__init__()
        d_k = d_model // n_head
        d_v = d_model // n_head
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
        # Masked self-attention over the decoder input.
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        if non_pad_mask is not None:
            dec_output *= non_pad_mask

        # Encoder-decoder attention: queries from the decoder, keys/values from the encoder output.
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        if non_pad_mask is not None:
            dec_output *= non_pad_mask

        dec_output = self.pos_ffn(dec_output)
        if non_pad_mask is not None:
            dec_output *= non_pad_mask

        return dec_output, dec_slf_attn, dec_enc_attn
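

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# how the two layers could be wired together. The hyperparameters, batch
# sizes, and mask shapes below ((batch, len, 1) for non_pad_mask and
# (batch, len, len) for the attention masks) are assumptions about the
# sub_layers implementation, not guarantees from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    d_model, d_inner, n_head = 512, 2048, 8   # assumed hyperparameters
    batch, src_len, tgt_len = 2, 10, 7

    enc_layer = EncoderLayer(d_model, d_inner, n_head)
    dec_layer = DecoderLayer(d_model, d_inner, n_head)

    src = torch.randn(batch, src_len, d_model)
    tgt = torch.randn(batch, tgt_len, d_model)

    # Assume every position is a real (non-pad) token.
    src_non_pad = torch.ones(batch, src_len, 1)
    tgt_non_pad = torch.ones(batch, tgt_len, 1)

    # Causal mask for decoder self-attention: position i may not attend to j > i.
    causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
    causal_mask = causal_mask.unsqueeze(0).expand(batch, -1, -1)

    # Passing no attention mask to the encoder assumes MultiHeadAttention
    # in sub_layers accepts mask=None.
    enc_out, enc_slf_attn = enc_layer(src, non_pad_mask=src_non_pad)
    dec_out, dec_slf_attn, dec_enc_attn = dec_layer(
        tgt, enc_out, non_pad_mask=tgt_non_pad, slf_attn_mask=causal_mask)

    print(enc_out.shape)  # expected: (batch, src_len, d_model)
    print(dec_out.shape)  # expected: (batch, tgt_len, d_model)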