#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 30 20:22:26 2020
@author: krishna
"""
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import numpy as np
import os
import argparse
from sklearn.metrics import accuracy_score
from SpeechDataGenerator import SpeechDataGenerator
from models.x_vector_Indian_LID import X_vector
from utils.utils import speech_collate

torch.multiprocessing.set_sharing_strategy('file_system')
########## Argument parser
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('-training_filepath', type=str, default='meta/training_feat.txt')
parser.add_argument('-testing_filepath', type=str, default='meta/testing_feat.txt')
parser.add_argument('-validation_filepath', type=str, default='meta/validation_feat.txt')
parser.add_argument('-input_dim', type=int, default=257)
parser.add_argument('-num_classes', type=int, default=8)
parser.add_argument('-lamda_val', type=float, default=0.1)
parser.add_argument('-batch_size', type=int, default=256)
parser.add_argument('-use_gpu', action='store_true', default=True)  # unused; device is picked from torch.cuda.is_available() below
parser.add_argument('-num_epochs', type=int, default=100)
args = parser.parse_args()
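# Example invocation (a sketch; the default meta/*.txt manifests are assumed
# to exist, and any of the flags above can be overridden on the command line):
#   python training_xvector.py -batch_size 64 -num_epochs 50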
### Data related
dataset_train = SpeechDataGenerator(manifest=args.training_filepath, mode='train')
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, collate_fn=speech_collate)
dataset_val = SpeechDataGenerator(manifest=args.validation_filepath, mode='train')
dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, collate_fn=speech_collate)
dataset_test = SpeechDataGenerator(manifest=args.testing_filepath, mode='test')
dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, collate_fn=speech_collate)
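# speech_collate (from utils/utils.py) is assumed to return a pair of lists,
# one of per-utterance feature tensors and one of label tensors; the training
# and validation loops below stack these into batched arrays by hand.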
## Model related
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = X_vector(args.input_dim, args.num_classes).to(device)
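# The model's forward pass returns a pair: class logits (fed to the
# cross-entropy loss below) and the pooled x-vector embedding.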
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0, betas=(0.9, 0.98), eps=1e-9)
loss_fun = nn.CrossEntropyLoss()
def train(dataloader_train, epoch):
    train_loss_list = []
    full_preds = []
    full_gts = []
    model.train()
    for i_batch, sample_batched in enumerate(dataloader_train):
        # Stack the per-utterance features into a (batch, time, feat) float
        # tensor and the labels into a 1-D tensor.
        features = torch.from_numpy(np.asarray([torch_tensor.numpy().T for torch_tensor in sample_batched[0]])).float()
        labels = torch.from_numpy(np.asarray([torch_tensor[0].numpy() for torch_tensor in sample_batched[1]]))
        features, labels = features.to(device), labels.to(device)
        features.requires_grad = True
        optimizer.zero_grad()
        pred_logits, x_vec = model(features)
        #### CE loss
        loss = loss_fun(pred_logits, labels)
        loss.backward()
        optimizer.step()
        train_loss_list.append(loss.item())
        predictions = np.argmax(pred_logits.detach().cpu().numpy(), axis=1)
        for pred in predictions:
            full_preds.append(pred)
        for lab in labels.detach().cpu().numpy():
            full_gts.append(lab)
    mean_acc = accuracy_score(full_gts, full_preds)
    mean_loss = np.mean(np.asarray(train_loss_list))
    print('Total training loss {} and training accuracy {} after {} epochs'.format(mean_loss, mean_acc, epoch))
def validation(dataloader_val, epoch):
    model.eval()
    with torch.no_grad():
        val_loss_list = []
        full_preds = []
        full_gts = []
        for i_batch, sample_batched in enumerate(dataloader_val):
            features = torch.from_numpy(np.asarray([torch_tensor.numpy().T for torch_tensor in sample_batched[0]])).float()
            labels = torch.from_numpy(np.asarray([torch_tensor[0].numpy() for torch_tensor in sample_batched[1]]))
            features, labels = features.to(device), labels.to(device)
            pred_logits, x_vec = model(features)
            #### CE loss
            loss = loss_fun(pred_logits, labels)
            val_loss_list.append(loss.item())
            predictions = np.argmax(pred_logits.detach().cpu().numpy(), axis=1)
            for pred in predictions:
                full_preds.append(pred)
            for lab in labels.detach().cpu().numpy():
                full_gts.append(lab)
        mean_acc = accuracy_score(full_gts, full_preds)
        mean_loss = np.mean(np.asarray(val_loss_list))
        print('Total validation loss {} and validation accuracy {} after {} epochs'.format(mean_loss, mean_acc, epoch))
        # Checkpoint the model, optimizer, and epoch after every validation pass.
        model_save_path = os.path.join('save_model', 'best_check_point_' + str(epoch) + '_' + str(mean_loss))
        state_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state_dict, model_save_path)
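# Resuming from a saved checkpoint (a sketch; the path is hypothetical, but
# the keys match the state_dict written above):
#   checkpoint = torch.load(os.path.join('save_model', 'best_check_point_<epoch>_<loss>'))
#   model.load_state_dict(checkpoint['model'])
#   optimizer.load_state_dict(checkpoint['optimizer'])
#   start_epoch = checkpoint['epoch'] + 1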
if __name__ == '__main__':
    # Create the checkpoint directory before validation() tries to save into it.
    os.makedirs('save_model', exist_ok=True)
    for epoch in range(args.num_epochs):
        train(dataloader_train, epoch)
        validation(dataloader_val, epoch)