-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path: main.py
112 lines (96 loc) · 3.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import os
import math
import time
import logging
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from torch.autograd import Variable
from utils import *
from datetime import datetime
import dataset
import models
from zero_shot import zero_shot_eval
def main():
    """Entry point: initialise distributed state, build the CLIP-style visual
    model, optionally load a checkpoint, wrap in DDP, and run zero-shot
    evaluation over one or more '+'-separated test datasets.

    NOTE(review): relies on a module-level ``args`` (declared ``global`` here)
    that is never constructed in this file -- presumably an argparse.Namespace
    built by the launcher before ``main()`` runs; confirm against the caller.
    Expects torchrun-style environment variables: RANK, WORLD_SIZE, LOCAL_RANK.
    """
    global args, input_resolution, test_datasets
    #* set random seed
    # Reproducibility vs. speed: a fixed positive seed pins RNG state,
    # otherwise let cudnn benchmark and pick the fastest kernels.
    if args.seed > 0:
        set_seed(args.seed)
    else:
        cudnn.benchmark = True
    test_datasets = args.test_dataset
    args.global_rank = int(os.environ['RANK'])
    #* distribute init
    dist.init_process_group(backend='nccl',
                            world_size=int(os.environ['WORLD_SIZE']),
                            rank=int(os.environ['RANK']))
    torch.cuda.set_device(args.local_rank)
    world_size = dist.get_world_size()
    if args.global_rank == 0:
        print(f'world_size: {world_size}')
    device = torch.device('cuda', args.local_rank)  # NOTE(review): unused below
    #* create model
    model = models.build_model(args.visual_model).cuda()
    input_resolution = model.visual.input_resolution
    #* sync bn
    if args.global_rank != -1 and args.use_bn_sync:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    #* set up criterion
    criterion = nn.CrossEntropyLoss()
    #* evaluate
    # Load checkpoint weights BEFORE wrapping in DDP so the state_dict keys
    # match the bare (un-prefixed) module.
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            print('invalid checkpoint: {}'.format(args.evaluate))
            return
        else:
            checkpoint = torch.load(args.evaluate, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
            if args.local_rank == 0:
                print("loaded checkpoint '{}' (epoch {})".format(
                    args.evaluate, checkpoint.get('epoch', -1)))
    #* DDP
    if args.global_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            output_device=args.local_rank,
            find_unused_parameters=True
        )
    if args.evaluate:
        # '+'-separated dataset names (e.g. 'cifar10+cifar100'), each
        # evaluated in turn; args.test_dataset is mutated per iteration
        # because downstream loaders and forward_test read it.
        test_dataset = list(args.test_dataset.split('+'))
        for idx in range(len(test_dataset)):
            args.test_dataset = test_dataset[idx]
            data_loaders = dataset.load_data(args, input_resolution=input_resolution, _type='test')
            test_loader = data_loaders['test_loader'].dataloader
            with torch.no_grad():
                test_prec1, test_prec5 = test(
                    test_loader, model, criterion, 0)
        return
def forward_test(data_loader, model, criterion, epoch, training=False):
    """Run zero-shot evaluation on the dataset named by ``args.test_dataset``.

    Args:
        data_loader: evaluation data, forwarded to ``zero_shot_eval``.
        model: the model under evaluation.
        criterion: loss function; unused here, kept for interface parity.
        epoch: current epoch index, forwarded to ``zero_shot_eval``.
        training: unused flag, kept for interface compatibility.

    Returns:
        Tuple ``(top1, top5)`` accuracies from the zero-shot metrics.

    Raises:
        ValueError: if ``args.test_dataset`` is not a supported name.
            (Previously an unknown name fell through both branches and
            crashed with a NameError on the undefined ``top1``.)
    """
    supported = ('imagenet', 'dtd', 'flowers', 'cifar10', 'cifar100', 'car',
                 'pet', 'caltech', 'aircraft', 'food', 'sun', 'sat')
    if args.test_dataset not in supported:
        raise ValueError('unsupported test dataset: {}'.format(args.test_dataset))
    # The imagenet and other-dataset branches ran identical evaluation code;
    # only the printed phase label differed, so they are merged here.
    zero_shot_metrics = zero_shot_eval(model, data_loader, epoch, args)
    top1, top5 = zero_shot_metrics['top1'], zero_shot_metrics['top5']
    if args.local_rank == 0:
        phase = ('ImageNet Zeroshot' if args.test_dataset == 'imagenet'
                 else '{} Zeroshot'.format(args.test_dataset))
        print('{phase}\t'
              'Prec@1/5 {top1:.2f}/{top5:.2f} \t'
              .format(phase=phase, top1=top1, top5=top5))
    return top1, top5
def test(data_loader, model, criterion, epoch):
    """Put the model into eval mode, then delegate to ``forward_test``.

    Thin wrapper: all metric computation and logging happens in
    ``forward_test``; this only guarantees evaluation mode first.
    """
    model.eval()
    metrics = forward_test(data_loader, model, criterion, epoch, training=False)
    return metrics
# Script entry point: run the evaluation driver only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()