forked from Jo0o0Hyung/Dual-Attention-for-VAD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DNN_Model.py
26 lines (20 loc) · 829 Bytes
/
DNN_Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.nn as nn
# Fully connected network with one hidden layer.
# PostDNN classifies each time step of the LSTM's output sequence.
class PostDNN(nn.Module):
    """Per-time-step classification head.

    A two-layer MLP (Linear -> ReLU -> Linear) applied to each frame of an
    upstream sequence model's output. Returns both the raw logits and their
    sigmoid activations, so callers can train on logits (e.g. with
    BCEWithLogitsLoss) while reading probabilities for inference.
    """

    def __init__(self, input_size, hidden_size, num_classes=1):
        super(PostDNN, self).__init__()
        # Keep constructor arguments around for introspection.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        # Two-layer MLP head; module order matches checkpoint layout.
        self.pDNN = nn.Sequential(
            nn.Linear(in_features=input_size, out_features=hidden_size),
            nn.ReLU(),
            nn.Linear(in_features=hidden_size, out_features=num_classes),
        )
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Score input frames.

        Args:
            x: tensor whose last dimension equals ``input_size``.

        Returns:
            Tuple ``(logits, probs)`` where the trailing class dimension has
            been squeezed out (meaningful when ``num_classes == 1``) and
            ``probs = sigmoid(logits)``.
        """
        logits = self.pDNN(x).squeeze(-1)
        return logits, self.sig(logits)