main.py
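# Entry point for split-model training and inference on LibriSpeech; this summary
# comment is inferred from the argument parser and control flow below.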
import argparse
from pathlib import Path
import torchaudio
from encoder.AutoEncoders import AutoEncoders
from encoder.huffman import Huffman
from model.save_load_model import *
from train_and_test.test import evaluate_model
from train_and_test.train import train_model
from train_and_test.train_autoencoders import train_autoencoders


def get_encoder(encoder_type, encoder_path):
    # Select the encoder used to compress the data exchanged in split-model inference.
    if encoder_type == 'huffman':
        print('Huffman Encoder is being used!')
        return Huffman()
    elif encoder_type == 'autoencoder':
        print('AutoEncoder is being used!')
        return AutoEncoders(encoder_path)
    else:
        return None


def create_folder(path):
    directory = Path(path)
    if not directory.exists() or not directory.is_dir():
        directory.mkdir(parents=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameters for split model inference')
    parser.add_argument('-split_mode', action='store_true', default=False, required=False,
                        help='Mode - split or full')
    parser.add_argument('-host', metavar='host', action='store',
                        default="node0.grp19-cs744-3.uwmadison744-f20-pg0.wisc.cloudlab.us", required=False,
                        help='Hostname to connect')
    parser.add_argument('-port', metavar='Port', action='store', default=60009, required=False,
                        help='Port to be used')
    parser.add_argument('-test', action='store_true', default=False, required=False, help='Test mode')
    parser.add_argument('-path', metavar='base-path', action='store', default="./", required=False,
                        help='The base path for the project')
    parser.add_argument('-batch', metavar='Batch Size', action='store', default=10, required=False,
                        help='Batch size to be used in training')
    parser.add_argument('-epochs', metavar='Epochs', action='store', default=10, required=False,
                        help='Number of epochs for training')
    parser.add_argument('-savefile', metavar='Save File', action='store', default='model.pth', required=False,
                        help='File for saving the checkpoint')
    parser.add_argument('-encoder', metavar='Encoder type', action='store', default='huffman', required=False,
                        help='Encoder to be used for encoding in split model inference')
    parser.add_argument('-encoderpath', metavar='Path of saved autoencoder model', action='store',
                        default='autoencoder.pth', required=False,
                        help='Path of the saved models of the autoencoder and decoder')
    parser.add_argument('-rank', metavar='Rank of node', action='store', default=0, required=False,
                        help='Rank of the node')
    args = parser.parse_args()
    port = int(args.port)
    host = args.host
    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 29,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": 5e-4,
        "batch_size": int(args.batch),
        "epochs": int(args.epochs),
        "input_layers": 512,
        "hidden_layers": 128,
        "output_layers": 32,
        "leaky_relu": 0.2
    }
    node_rank = int(args.rank)
    if node_rank < 0 or node_rank > 1:
        raise Exception('Rank is incorrect. It should be either 0 or 1!')
    base_dataset_directory = "{}/dataset".format(args.path)
    create_folder(base_dataset_directory)
    # Download the LibriSpeech splits; the training split is only needed when training.
    train_dataset = None
    if not args.test:
        train_dataset = torchaudio.datasets.LIBRISPEECH(base_dataset_directory, url='train-clean-100',
                                                        download=True)
    test_dataset = torchaudio.datasets.LIBRISPEECH(base_dataset_directory, url='test-clean', download=True)
    save_filepath = '{}/{}'.format(args.path, args.savefile)
    encoder_base_path = '{}/{}'.format(args.path, args.encoderpath)
    if args.test:
        # Evaluation: load the saved checkpoint and run inference, end to end or split across nodes.
        model = load_model(save_filepath, hparams)
        sp_model = load_split_model(save_filepath, hparams)
        encoder = get_encoder(args.encoder, encoder_base_path)
        if not bool(args.split_mode):
            print('Evaluating complete model without any splitting')
            evaluate_model(hparams, model, None, test_dataset, encoder, node_rank, host, port)
        else:
            print('Evaluating split model')
            evaluate_model(hparams, None, sp_model, test_dataset, encoder, node_rank, host, port)
    else:
        # Training: either fit the autoencoder pair on top of an existing split model,
        # or train the speech model from scratch and save the checkpoint.
        if args.encoder == 'autoencoder':
            sp_model = load_split_model(save_filepath, hparams)
            model = train_autoencoders(sp_model, hparams, train_dataset)
            save_model(model, encoder_base_path)
        else:
            model = train_model(hparams, train_dataset, test_dataset)
            save_model(model, save_filepath)
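
# Example invocations (illustrative only; every flag is defined by the parser above):
#   python main.py -batch 10 -epochs 10 -savefile model.pth     # train the full model
#   python main.py -test -split_mode -encoder huffman -rank 0   # evaluate the split model on the rank-0 node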