-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy patheval_LFW.py
65 lines (57 loc) · 2.28 KB
/
eval_LFW.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.backends.cudnn as cudnn
from nets.arcface import Arcface
from utils.dataloader import LFWDataset
from utils.utils_metrics import test
if __name__ == "__main__":
    # --------------------------------------#
    #   Whether to use CUDA.
    #   Set to False if there is no GPU. It is also forced to
    #   False below when torch cannot see a usable GPU.
    # --------------------------------------#
    cuda = True
    # --------------------------------------#
    #   Backbone feature-extraction network, one of:
    #   mobilefacenet
    #   mobilenetv1
    #   iresnet18
    #   iresnet34
    #   iresnet50
    #   iresnet100
    #   iresnet200
    # --------------------------------------#
    backbone = "mobilefacenet"
    # --------------------------------------#
    #   Input image size
    # --------------------------------------#
    input_shape = [112, 112, 3]
    # --------------------------------------#
    #   Trained weight file (must match the backbone above)
    # --------------------------------------#
    model_path = "model_data/arcface_mobilefacenet.pth"
    # --------------------------------------#
    #   LFW evaluation dataset directory
    #   and the corresponding pairs txt file
    # --------------------------------------#
    lfw_dir_path = "lfw"
    lfw_pairs_path = "model_data/lfw_pair.txt"
    # --------------------------------------#
    #   Evaluation batch size and logging interval
    # --------------------------------------#
    batch_size = 256
    log_interval = 1
    # --------------------------------------#
    #   Save path for the ROC plot
    # --------------------------------------#
    png_save_path = "model_data/roc_test.png"

    # Without this check, model.cuda() below raises on a machine with no
    # visible GPU. Degrade to CPU evaluation instead of crashing.
    if cuda and not torch.cuda.is_available():
        print('CUDA requested but not available, falling back to CPU.')
        cuda = False
    # Derive the load location from the same flag that decides where the
    # model runs, so the two cannot disagree.
    device = torch.device('cuda' if cuda else 'cpu')

    test_loader = torch.utils.data.DataLoader(
        LFWDataset(dir=lfw_dir_path, pairs_path=lfw_pairs_path, image_size=input_shape),
        batch_size=batch_size,
        shuffle=False,
    )

    model = Arcface(backbone=backbone, mode="predict")
    print('Loading weights into state dict...')
    # strict=False is kept because the checkpoint may contain keys the
    # predict-mode model does not use. Still surface any mismatch: a wrong
    # backbone/model_path pairing would otherwise silently evaluate
    # randomly initialised weights.
    load_result = model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    if load_result.missing_keys:
        print(f'Warning: missing keys in checkpoint: {load_result.missing_keys}')
    if load_result.unexpected_keys:
        print(f'Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}')
    model = model.eval()
    if cuda:
        model = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        model = model.cuda()

    test(test_loader, model, png_save_path, log_interval, batch_size, cuda)