-
Notifications
You must be signed in to change notification settings - Fork 2
/
extract.py
147 lines (114 loc) · 5.44 KB
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# -*- coding: utf-8 -*-
"""
@date: 2023/8/20 12:17 PM
@file: extract.py
@author: zj
@description:
Usage - Extract Features:
$ python extract.py --arch resnet18 --data toy.yaml
Usage - Reduce dimension:
$ python extract.py --arch resnet18 --data toy.yaml --enhance PCA --reduce 50 --learn-pca
"""
import os
import sys
import argparse
from argparse import Namespace
from pathlib import Path
import torch
from torch.utils.data import DataLoader
import torchvision.models as models
from torchvision.models.resnet import ResNet, resnet18, resnet34, resnet50, resnet101, resnet152
from torchvision.models.mobilenet import MobileNetV2, MobileNetV3, mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small
from simpleir.utils.logger import LOGGER
from simpleir.utils.misc import print_args, colorstr
from simpleir.data.build import build_data
from simpleir.extract.helper import ExtractHelper, AggregateType, EnhanceType
from simpleir.utils.fileutil import increment_path, check_yaml, yaml_load
from simpleir.utils.general import save_features
# Absolute path of this script; the directory that contains it is the SimpleIR root.
FILE = Path(__file__).resolve()
ROOT = FILE.parent  # SimpleIR root directory
# Append the root to the import search path once, so repeated imports do not duplicate it.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
# From here on ROOT is expressed relative to the current working directory;
# it serves as the base of the default --data and --project paths in parse_opt().
ROOT = Path(os.path.relpath(ROOT, os.getcwd()))  # relative
def parse_opt():
    """Define the command-line interface of the extraction script and parse it.

    The accepted ``--arch`` values are discovered from the lowercase names in
    ``torchvision.models.resnet`` (prefix ``resnet``) followed by those in
    ``torchvision.models.mobilenet`` (prefix ``mobilenet_``), each group sorted.
    The ``--aggregate`` and ``--enhance`` choices are the values of the
    ``AggregateType`` and ``EnhanceType`` enums.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """

    def _lowercase_names(module, prefix):
        # Sorted lowercase attribute names of ``module`` that start with ``prefix``.
        return sorted(attr for attr in module.__dict__ if attr.islower() and attr.startswith(prefix))

    arch_choices = _lowercase_names(models.resnet, 'resnet') + _lowercase_names(models.mobilenet, 'mobilenet_')
    aggregate_choices = [member.value for member in AggregateType]
    enhance_choices = [member.value for member in EnhanceType]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--arch', '-a',
        metavar='ARCH',
        default='resnet18',
        choices=arch_choices,
        help='model architecture: ' + ' | '.join(arch_choices) + ' (default: resnet18)',
    )
    parser.add_argument(
        '--data',
        type=str,
        default=ROOT / 'configs/data/toy.yaml',
        help='dataset.yaml path',
    )
    parser.add_argument(
        '--aggregate',
        type=str,
        default='IDENTITY',
        choices=aggregate_choices,
        help='aggregate type: ' + ' | '.join(aggregate_choices) + ' (default: IDENTITY)',
    )
    parser.add_argument(
        '--enhance',
        type=str,
        default='IDENTITY',
        choices=enhance_choices,
        help='enhance type: ' + ' | '.join(enhance_choices) + ' (default: IDENTITY)',
    )
    parser.add_argument('--reduce', type=int, default=512, help='reduce dimension')
    parser.add_argument(
        '--learn-pca',
        action='store_true',
        default=False,
        help='whether to perform PCA learning',
    )
    parser.add_argument('--pca-path', type=str, default=None, help='load the learned PCA model')

    # Output location: results are written below <project>/<name>.
    parser.add_argument('--project', default=ROOT / 'runs/extract', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')

    return parser.parse_args()
def do_extract(data_loader: DataLoader, extract_helper: ExtractHelper, save_dir: str, is_gallery: bool = False):
    """Run the extraction helper on one data split and save what it returns.

    Args:
        data_loader: Loader handed to ``extract_helper.run``. Its ``dataset``
            must expose a ``classes`` attribute (checked with an assertion).
        extract_helper: Object whose ``run(data_loader, is_gallery=...)`` returns
            a triple, named here as image names, labels and feature tensors.
        save_dir: Directory under which the outputs are placed.
        is_gallery: When true the outputs go to ``gallery`` / ``gallery.pkl``,
            otherwise to ``query`` / ``query.pkl``. Also forwarded to ``run``.
    """
    names, labels, feats = extract_helper.run(data_loader, is_gallery=is_gallery)

    # Save
    split_dir, info_file = ('gallery', "gallery.pkl") if is_gallery else ('query', "query.pkl")
    feat_dir = os.path.join(save_dir, split_dir)
    info_path = os.path.join(save_dir, info_file)

    dataset = data_loader.dataset
    assert hasattr(dataset, 'classes')
    save_features(dataset.classes, names, labels, feats, feat_dir, info_path)
def main(opt: Namespace):
    """Extract gallery and query features with a pretrained torchvision model.

    Args:
        opt: Options namespace, normally the result of ``parse_opt()``. The
            attributes read are ``project``, ``name``, ``data``, ``arch``,
            ``aggregate``, ``enhance``, ``reduce``, ``learn_pca`` and
            ``pca_path``. The namespace is modified in place: ``save_dir`` is
            added and ``data`` is replaced by the result of ``yaml_load``.

    Raises:
        AttributeError: If ``opt.arch`` is not an attribute of ``torchvision.models``.
        TypeError: If the built model is not a ResNet, MobileNetV2 or MobileNetV3.
    """
    # Config
    opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=False))
    print_args(vars(opt))
    # exist_ok=True avoids the check-then-create race of a separate os.path.exists() test.
    os.makedirs(opt.save_dir, exist_ok=True)

    opt.data = yaml_load(check_yaml(opt.data))

    # Data
    gallery_loader = build_data(opt.data, is_gallery=True)
    query_loader = build_data(opt.data, is_gallery=False)

    # Model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Look the constructor up by name instead of eval(): main() can be called with a
    # Namespace that never went through argparse's ``choices`` check, and eval() would
    # execute whatever string it is given.
    model_builder = getattr(models, opt.arch)
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in favour of
    # ``weights=``; kept as-is to stay compatible with the version this project uses.
    model = model_builder(pretrained=True).to(device)
    model.eval()

    # The layer passed to ExtractHelper is the final classification layer of each family.
    if isinstance(model, ResNet):
        target_layer = model.fc
    elif isinstance(model, (MobileNetV2, MobileNetV3)):
        target_layer = model.classifier[-1]
    else:
        # Explicit error rather than ``assert`` so the check survives ``python -O``.
        raise TypeError(f"Unsupported model type for architecture {opt.arch!r}: {type(model).__name__}")

    # Extract
    extract_helper = ExtractHelper(
        model=model,
        target_layer=target_layer,
        device=device,
        aggregate_type=opt.aggregate,
        enhance_type=opt.enhance,
        reduce_dimension=opt.reduce,
        learn_pca=opt.learn_pca,
        pca_path=opt.pca_path,
        save_dir=opt.save_dir,
    )

    # The gallery split is processed before the query split.
    LOGGER.info("Extract Gallery")
    do_extract(gallery_loader, extract_helper, opt.save_dir, is_gallery=True)
    LOGGER.info("Extract Query")
    do_extract(query_loader, extract_helper, opt.save_dir, is_gallery=False)

    LOGGER.info(f"Save to {colorstr(opt.save_dir)}")
# Script entry point: parse the command line and run the extraction pipeline.
if __name__ == "__main__":
    main(parse_opt())