-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·122 lines (106 loc) · 5.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import time
from torch import nn
from models import load_model
from utils.hp import load_hps
from datasets.dataset_torch import Dataset
from utils.plotting import plot
import torch
from utils.callback import Model_checkpoint, EarlyStopping
from adabelief_pytorch import AdaBelief
# Pick the first CUDA GPU when one is available, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def train(ds_train, model, criterion, optimizer, device):
    """Run a single training epoch over ``ds_train``.

    Each batch is moved to ``device``, the loss is computed on the squeezed
    logits, and an optimizer step is taken per batch.

    Returns:
        (model, epoch_loss, epoch_accuracy) where the loss and accuracy are
        per-sample averages over the whole epoch.
    """
    model.train()
    running_loss = 0
    correct = 0
    seen = 0
    for inputs, targets in ds_train:
        optimizer.zero_grad()
        inputs_dev = inputs.to(device)
        targets_dev = targets.to(device)
        outputs = model(inputs_dev)
        batch_loss = criterion(outputs.squeeze(1), targets_dev.long())
        # Weight by batch size so the epoch average is per-sample, not per-batch.
        running_loss += batch_loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += torch.sum(predicted == targets_dev)
        seen += inputs.size(0)
        batch_loss.backward()
        optimizer.step()
    epoch_loss = running_loss / seen
    epoch_acc = correct / seen
    return model, epoch_loss, epoch_acc.item()
def valid(ds_valid, model, criterion, device):
    """Evaluate ``model`` over one pass of ``ds_valid``.

    NOTE(review): gradient tracking is expected to be disabled by the caller
    (``training`` wraps this call in ``torch.no_grad()``).

    Returns:
        (model, epoch_loss, epoch_accuracy) as per-sample averages.
    """
    model.eval()
    running_loss = 0
    correct = 0
    seen = 0
    for inputs, targets in ds_valid:
        inputs_dev = inputs.to(device)
        targets_dev = targets.to(device)
        outputs = model(inputs_dev)
        batch_loss = criterion(outputs.squeeze(1), targets_dev.long())
        # Weight by batch size so the epoch average is per-sample, not per-batch.
        running_loss += batch_loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += torch.sum(predicted == targets_dev)
        seen += inputs.size(0)
    epoch_loss = running_loss / seen
    epoch_acc = correct / seen
    return model, epoch_loss, epoch_acc.item()
def training(model, ds_train, ds_valid, criterion, optimizer, scheduler, device, epochs):
    """Full training loop: per-epoch train/validation, LR scheduling,
    checkpointing on best val accuracy/loss, and early stopping.

    Args:
        model: the network to train (already on ``device``).
        ds_train / ds_valid: iterables of (inputs, targets) batches.
        criterion: loss callable, e.g. an ``nn.CrossEntropyLoss`` instance.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: plateau-style scheduler; stepped with the validation loss.
        device: torch device string.
        epochs: maximum number of epochs.

    Returns:
        (model, history) — ``history`` holds per-epoch metric lists and the
        final learning rate.

    Fixes vs. original: the early-stopping tracker was re-created every epoch
    and its check was invoked twice per epoch (once with the result discarded),
    mutating any internal state twice. It is now created once and queried
    exactly once per epoch.
    """
    train_losses, valid_losses = [], []
    train_accs, valid_accs = [], []
    history = {}
    # One tracker for the whole run so any internal patience state persists.
    # NOTE(review): Early_Stopping is assumed to judge from the metrics history
    # it receives each call — confirm against utils.callback.
    early_stopper = EarlyStopping()
    for epoch in range(epochs):
        since = time.time()
        print('\nEpoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)
        model, total_loss_train, total_acc_train = train(ds_train, model, criterion, optimizer, device)
        train_losses.append(total_loss_train)
        train_accs.append(total_acc_train)
        # No gradients needed for validation.
        with torch.no_grad():
            model, total_loss_valid, total_acc_valid = valid(ds_valid, model, criterion, device)
        valid_losses.append(total_loss_valid)
        valid_accs.append(total_acc_valid)
        # Plateau scheduler is driven by the validation loss.
        scheduler.step(total_loss_valid)
        metrics = {'train_loss': train_losses, 'train_acc': train_accs,
                   'val_loss': valid_losses, 'val_acc': valid_accs}
        Model_checkpoint(path='./', metrics=metrics, model=model,
                         monitor='val_acc', verbose=True,
                         file_name="best_acc.pth")
        Model_checkpoint(path='./', metrics=metrics, model=model,
                         monitor='val_loss', verbose=True,
                         file_name="best_loss.pth")
        # Evaluate early stopping exactly once per epoch and reuse the result.
        should_stop = early_stopper.Early_Stopping(metrics, 20, monitor='val_loss', verbose=True)
        history = {'epoch': epochs, 'accuracy': train_accs, 'loss': train_losses,
                   'val_accuracy': valid_accs, 'val_loss': valid_losses,
                   'LR': optimizer.param_groups[0]['lr']}
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print("Epoch:", epoch + 1, "- Train Loss:", total_loss_train, "- Train Accuracy:", total_acc_train,
              "- Validation Loss:", total_loss_valid, "- Validation Accuracy:", total_acc_valid)
        if should_stop is True:
            return model, history
    return model, history
def train_torch():
    """Entry point: build hyper-parameters, data loaders, model, and optimizer,
    then run the training loop and plot the resulting history.

    Fixes vs. original:
      * ``nn.CrossEntropyLoss`` is instantiated — the original assigned the
        class itself, so the first ``criterion(logits, y)`` call would have
        constructed a loss module (with the logits as its ``weight``) instead
        of computing a loss, and ``loss.item()`` would then fail.
      * ``training()`` returns ``(model, history)``; the original unpacked
        three values, which raises ``ValueError`` at runtime.
      * Non-PyTorch frameworks now fail fast instead of hitting a ``NameError``
        on the undefined ``train_loader``.
    """
    hps = load_hps(dataset_dir="./fake_real-faces/", model_name='regnet', n_epochs=150,
                   batch_size=16,
                   learning_rate=0.001,
                   lr_reducer_factor=0.2,
                   lr_reducer_patience=8, img_size=299, framework='pytorch')
    model = load_model(model_name=hps['model_name'])
    if hps['framework'] != 'pytorch':
        # Everything below is PyTorch-specific; refuse early and clearly.
        raise ValueError(
            "train_torch() only supports framework='pytorch', got {!r}".format(hps['framework']))
    train_loader, val_loader, test_loader = Dataset.pytorch_preprocess(
        dataset_dir=hps['dataset_dir'],
        img_size=hps['img_size'],
        batch_size=hps['batch_size'],
        split_size=0.3, augment=True)
    model.to(device)
    criterion = nn.CrossEntropyLoss()  # instantiate the loss, don't pass the class
    optimizer = AdaBelief(model.parameters(), lr=hps['learning_rate'])
    reduce_on_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=hps['lr_reducer_factor'],
        patience=hps['lr_reducer_patience'],
        verbose=True)
    model, history = training(model, train_loader, val_loader, criterion, optimizer,
                              reduce_on_plateau, device, hps['n_epochs'])
    plot(history)
# Script entry point: run the PyTorch training pipeline only when this file
# is executed directly (not when imported as a module).
if __name__ == '__main__':
    train_torch()