-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_imgs.py
executable file
·50 lines (46 loc) · 1.3 KB
/
eval_imgs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import sys
import torch
import argparse
from parse_args import parser_eval
from base_class_ours import IDCLIPScoreCalculator
from clip_eval import IdCLIPEvaluator
# os.chdir('/root/CelebBasis/evaluation')
# sys.path.append('/root/CelebBasis/evaluation')
# Inputs: source (src) images, target (tgt) images, and prompts.
def resolve_device(preferred="cuda:0"):
    """Return ``torch.device(preferred)``, falling back to CPU when CUDA is absent.

    The original script hard-coded ``torch.device('cuda:0')`` and therefore
    crashed on machines without a GPU. This keeps that default but degrades
    gracefully instead of failing.

    Args:
        preferred: Device string to try first (e.g. ``"cuda:0"`` or ``"cpu"``).

    Returns:
        The requested ``torch.device``. If a CUDA device was requested but
        ``torch.cuda.is_available()`` is False, ``torch.device("cpu")`` is
        returned instead.
    """
    requested = torch.device(preferred)
    if requested.type == "cuda" and not torch.cuda.is_available():
        print(
            f"CUDA not available; falling back to CPU instead of {preferred}",
            file=sys.stderr,
        )
        return torch.device("cpu")
    return requested


def build_parser():
    """Build the command-line parser for this evaluation script.

    The script-specific options are registered first. The parser is then
    handed to ``parser_eval`` (from ``parse_args``), which may add further
    options. Its return value is used as the final parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--src_root",
        type=str,
        default="src",
        help="Root folder containing prompts.txt, image_id.txt and the '200' "
             "source-image subfolder.",
    )
    # NOTE(review): despite its name, the default is a file name
    # ("result.csv"); presumably this is the output CSV path — confirm.
    parser.add_argument(
        "--save_dir",
        type=str,
        default="result.csv",
        help="Output location passed to IDCLIPScoreCalculator.",
    )
    parser.add_argument(
        "--eval_folder",
        type=str,
        default="output",
        help="Folder of images to evaluate.",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="weights",
        help="Directory passed to IdCLIPEvaluator as model_dir.",
    )
    return parser_eval(parser)


def main():
    """Parse CLI options, build the evaluator and run the score calculation."""
    opt = build_parser().parse_args()
    src_root = opt.src_root

    # Previously hard-coded to 'cuda:0'; now falls back to CPU if needed.
    id_clip_evaluator = IdCLIPEvaluator(
        resolve_device("cuda:0"),
        clip_model="ViT-B/32",
        model_dir=opt.model_dir,
    )
    id_score_calculator = IDCLIPScoreCalculator(
        opt.eval_folder,
        id_clip_evaluator,
        opt.save_dir,
        prompt_dir=os.path.join(src_root, "prompts.txt"),
        # NOTE(review): '200' is a fixed subfolder name from the original
        # layout; its meaning is not visible here — confirm.
        src_img_dir=os.path.join(src_root, "200"),
        src_img_id=os.path.join(src_root, "image_id.txt"),
    )
    id_score_calculator.start_calc()


if __name__ == "__main__":
    main()