train_on_logmel.py

from core import *
from data_loader import *
from util import *
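# The star imports above are expected to provide the repo-specific helpers used
# below (Config, Freesound_logmel, ToTensor, run_method_by_string, train_on_fold,
# cross_entropy_onehot, create_logging). The standard third-party names are
# imported explicitly here as a safeguard, in case the starred modules do not
# re-export them:
import os
import numpy as np
import pandas as pd
import torch.nn as nn
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader
from torchvision import transforms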
import torch.optim as optim
import torch.backends.cudnn as cudnn
import time

np.random.seed(1001)  # fix the NumPy seed so runs are reproducible


def main():
    train = pd.read_csv('../input/train.csv')

    # build a label -> integer-index mapping for the classifier targets
    LABELS = list(train.label.unique())
    label_idx = {label: i for i, label in enumerate(LABELS)}
    # note: set_index without inplace=True returns a new frame, so this line is
    # a no-op and "fname" remains an ordinary column
    train.set_index("fname")
    train["label_idx"] = train.label.apply(lambda x: label_idx[x])

    # in debug mode, train on a small subset for a quick end-to-end check
    if DEBUG:
        train = train[:500]

    skf = StratifiedKFold(n_splits=config.n_folds)
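
    # StratifiedKFold keeps the per-class proportions of label_idx roughly equal
    # across folds, so every fold sees a similar class balance; each sample lands
    # in the validation split of exactly one fold.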

    for foldNum, (train_split, val_split) in enumerate(skf.split(train, train.label_idx)):
        end = time.time()

        # split the dataset for cross-validation
        train_set = train.iloc[train_split]
        train_set = train_set.reset_index(drop=True)
        val_set = train.iloc[val_split]
        val_set = val_set.reset_index(drop=True)
        logging.info("Fold {0}, Train samples: {1}, val samples: {2}"
                     .format(foldNum, len(train_set), len(val_set)))

        # define train loader and val loader
        trainSet = Freesound_logmel(config=config, frame=train_set,
                                    transform=transforms.Compose([ToTensor()]),
                                    mode="train")
        train_loader = DataLoader(trainSet, batch_size=config.batch_size, shuffle=True, num_workers=4)

        # note: validation also uses mode="train", presumably so Freesound_logmel
        # returns labels for computing the validation loss
        valSet = Freesound_logmel(config=config, frame=val_set,
                                  transform=transforms.Compose([ToTensor()]),
                                  mode="train")
        val_loader = DataLoader(valSet, batch_size=config.batch_size, shuffle=False, num_workers=4)
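
        # only the training loader shuffles; keeping validation order fixed makes
        # fold metrics comparable across epochs. num_workers=4 prefetches batches
        # in background worker processes.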

        # instantiate the architecture named in config.arch, optionally loading
        # pretrained weights
        model = run_method_by_string(config.arch)(pretrained=config.pretrain)

        if config.cuda:
            model.cuda()
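
        # note: the model is moved to the GPU before the optimizer is created
        # below, so the optimizer references the CUDA parameter tensors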

        # define loss function (criterion) and optimizer
        if config.mixup:
            # mixup produces soft targets, so a criterion that accepts
            # one-hot/mixed label distributions is needed
            train_criterion = cross_entropy_onehot
        else:
            train_criterion = nn.CrossEntropyLoss().cuda()
        val_criterion = nn.CrossEntropyLoss().cuda()
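
        # For reference, mixup trains on convex combinations of sample pairs:
        #   x = lam * x_i + (1 - lam) * x_j,   y = lam * y_i + (1 - lam) * y_j
        # with lam ~ Beta(alpha, alpha), which is why soft-target cross-entropy
        # is required when config.mixup is enabled.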

        optimizer = optim.SGD(model.parameters(), lr=config.lr,
                              momentum=config.momentum,
                              weight_decay=config.weight_decay)

        # let cuDNN benchmark and cache the fastest conv algorithms for the
        # fixed input shape
        cudnn.benchmark = True

        train_on_fold(model, train_criterion, val_criterion,
                      optimizer, train_loader, val_loader, config, foldNum)

        # note: this formatting wraps for folds that run longer than 24 hours
        time_on_fold = time.strftime('%Hh:%Mm:%Ss', time.gmtime(time.time() - end))
        logging.info("--------------Time on fold {}: {}--------------\n"
                     .format(foldNum, time_on_fold))

    # Train on the whole training set (no validation split); kept commented out.
    # foldNum = config.n_folds
    # end = time.time()
    # logging.info("Fold {0}, Train samples: {1}."
    #              .format(foldNum, len(train)))
    #
    # # define train loader
    # trainSet = Freesound_logmel(config=config, frame=train,
    #                             transform=transforms.Compose([ToTensor()]),
    #                             mode="train")
    # train_loader = DataLoader(trainSet, batch_size=config.batch_size, shuffle=True, num_workers=4)
    #
    # model = run_method_by_string(config.arch)(pretrained=config.pretrain)
    #
    # if config.cuda:
    #     model.cuda()
    #
    # # define loss function (criterion) and optimizer
    # if config.mixup:
    #     train_criterion = cross_entropy_onehot
    # else:
    #     train_criterion = nn.CrossEntropyLoss().cuda()
    #
    # optimizer = optim.SGD(model.parameters(), lr=config.lr,
    #                       momentum=config.momentum,
    #                       weight_decay=config.weight_decay)
    #
    # cudnn.benchmark = True
    #
    # train_all_data(model, train_criterion, optimizer, train_loader, config, foldNum)
    #
    # time_on_fold = time.strftime('%Hh:%Mm:%Ss', time.gmtime(time.time() - end))
    # logging.info("--------------Time on fold {}: {}--------------\n"
    #              .format(foldNum, time_on_fold))


if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # restrict training to the first GPU
    DEBUG = False

    config = Config(sampling_rate=22050,
                    audio_duration=1.5,
                    batch_size=128,
                    n_folds=5,
                    data_dir="../logmel+delta_w80_s10_m64",
                    model_dir='../model/logmel_delta_resnet101',
                    prediction_dir='../prediction/logmel_delta_resnet101',
                    arch='resnet101_mfcc',
                    lr=0.01,
                    pretrain=True,
                    mixup=False,
                    # epochs=100)
                    epochs=50)
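
    # Reading the settings above (assumptions from the names): data_dir holds
    # precomputed log-mel + delta features ("w80_s10_m64" suggests window 80,
    # stride 10 and 64 mel bins), and arch names a model builder defined in this
    # repo's core module.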

    # configure logging and record this run's settings
    logging = create_logging('../log', filemode='a')
    logging.info(os.path.abspath(__file__))
    attrs = '\n'.join('%s:%s' % item for item in vars(config).items())
    logging.info(attrs)

    main()
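
# Usage sketch (assuming the repo layout implied by the relative paths):
#   python train_on_logmel.py
# expects ../input/train.csv plus precomputed features under config.data_dir,
# and writes logs under ../log; model_dir and prediction_dir are presumably
# used by train_on_fold for checkpoints and per-fold predictions.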