# train_classifier.py
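# Trains a ViT classifier on CIFAR-10, either from scratch or on top of an
# MAE pre-trained encoder supplied via --pretrained_model_path.
#
# Example invocation (the checkpoint name below is an assumed example, not
# fixed by this script):
#   python train_classifier.py --pretrained_model_path vit-t-mae.pt \
#       --output_model_path vit-t-classifier-from_pretrained.pt
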
import os
import argparse
import math

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor, Compose, Normalize
from tqdm import tqdm

from model import MAE_ViT, ViT_Classifier
from utils import setup_seed

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--max_device_batch_size', type=int, default=256)
    parser.add_argument('--base_learning_rate', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--total_epoch', type=int, default=100)
    parser.add_argument('--warmup_epoch', type=int, default=5)
    parser.add_argument('--pretrained_model_path', type=str, default=None)
    parser.add_argument('--output_model_path', type=str, default='vit-t-classifier-from_scratch.pt')
    args = parser.parse_args()
    setup_seed(args.seed)

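    # --batch_size is the effective optimizer batch; if it exceeds what fits
    # on the device (--max_device_batch_size), gradients are accumulated over
    # steps_per_update smaller loads before each optimizer step.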
    batch_size = args.batch_size
    load_batch_size = min(args.max_device_batch_size, batch_size)
    assert batch_size % load_batch_size == 0
    steps_per_update = batch_size // load_batch_size

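    # ToTensor scales CIFAR-10 images to [0, 1]; Normalize(0.5, 0.5) then maps
    # them to [-1, 1] via (x - 0.5) / 0.5 on every channel.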
    train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
    val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, load_batch_size, shuffle=False, num_workers=4)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

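    # Either load a fully pickled MAE_ViT (presumably saved with
    # torch.save(model, ...) by the pre-training script) or build a fresh one;
    # in both cases only its encoder is kept, wrapped in a classification head.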
    if args.pretrained_model_path is not None:
        model = torch.load(args.pretrained_model_path, map_location='cpu')
        writer = SummaryWriter(os.path.join('logs', 'cifar10', 'pretrain-cls'))
    else:
        model = MAE_ViT()
        writer = SummaryWriter(os.path.join('logs', 'cifar10', 'scratch-cls'))
    model = ViT_Classifier(model.encoder, num_classes=10).to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    acc_fn = lambda logit, label: torch.mean((logit.argmax(dim=-1) == label).float())

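    # Linear LR scaling rule: the base LR is defined per 256 samples and scaled
    # by the effective batch size. lr_func takes the minimum of a linear warmup
    # ramp and a half-period cosine decay over total_epoch epochs (the 1e-8
    # avoids division by zero when --warmup_epoch is 0).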
    optim = torch.optim.AdamW(model.parameters(), lr=args.base_learning_rate * args.batch_size / 256, betas=(0.9, 0.999), weight_decay=args.weight_decay)
    lr_func = lambda epoch: min((epoch + 1) / (args.warmup_epoch + 1e-8), 0.5 * (math.cos(epoch / args.total_epoch * math.pi) + 1))
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

    best_val_acc = 0
    step_count = 0
    optim.zero_grad()

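    # Training loop: the optimizer steps once every steps_per_update
    # forward/backward passes, i.e. once per effective batch of batch_size
    # samples; lr_scheduler.step() advances the schedule once per epoch.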
    for e in range(args.total_epoch):
        model.train()
        losses = []
        acces = []
        for img, label in tqdm(iter(train_dataloader)):
            step_count += 1
            img = img.to(device)
            label = label.to(device)
            logits = model(img)
            loss = loss_fn(logits, label)
            acc = acc_fn(logits, label)
            loss.backward()
            if step_count % steps_per_update == 0:
                optim.step()
                optim.zero_grad()
            losses.append(loss.item())
            acces.append(acc.item())
        lr_scheduler.step()
        avg_train_loss = sum(losses) / len(losses)
        avg_train_acc = sum(acces) / len(acces)
        print(f'In epoch {e}, average training loss is {avg_train_loss}, average training acc is {avg_train_acc}.')

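        # Evaluate on the CIFAR-10 test split with gradients disabled. Metrics
        # average the per-batch means, so the last (smaller) batch carries the
        # same weight as a full batch.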
        model.eval()
        with torch.no_grad():
            losses = []
            acces = []
            for img, label in tqdm(iter(val_dataloader)):
                img = img.to(device)
                label = label.to(device)
                logits = model(img)
                loss = loss_fn(logits, label)
                acc = acc_fn(logits, label)
                losses.append(loss.item())
                acces.append(acc.item())
            avg_val_loss = sum(losses) / len(losses)
            avg_val_acc = sum(acces) / len(acces)
            print(f'In epoch {e}, average validation loss is {avg_val_loss}, average validation acc is {avg_val_acc}.')

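        # Keep only the best checkpoint by validation accuracy; the whole model
        # object is pickled, matching how a pre-trained model is loaded above.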
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            print(f'saving best model with acc {best_val_acc} at epoch {e}!')
            torch.save(model, args.output_model_path)

        writer.add_scalars('cls/loss', {'train': avg_train_loss, 'val': avg_val_loss}, global_step=e)
        writer.add_scalars('cls/acc', {'train': avg_train_acc, 'val': avg_val_acc}, global_step=e)