-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
111 lines (100 loc) · 3.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import logging
import os
def main(
    space,
    method,
    data,
    task,
    gpath,
    rpath,
    cpath,
    opath,
    model,
    prompt,
    neg_prompt,
    count,
):
    """Run the handler(s) selected by ``space``, ``data``, ``method`` and ``task``.

    Results returned by metric handlers are collected and printed, one per
    line, after dispatch. Combinations that match no branch run nothing and
    still print a single empty line.

    The handler modules are imported inside the branch that needs them, so
    the function also works when this module is imported rather than executed
    as a script.

    Args:
        space: Top-level selector: ``"metric"``, ``"genai"`` or ``"pipeline"``.
        method: Sub-selector within the space (``"quality"``/``"diversity"``
            for metrics, ``"llm"``/``"config"`` for genai, ``"full"`` for
            pipeline).
        data: Kind of data; metrics are dispatched for ``"image"``, genai and
            pipeline for ``"text"``.
        task: Task name. For metrics, ``"report"`` runs every task of the
            chosen method; otherwise only the task with the matching name.
        gpath: Passed to every metric handler.
        rpath: Passed, after ``gpath``, to the frechet and coverage handlers.
        cpath: Passed to the full-generation pipeline handler.
        opath: Passed to the prompt-generation and pipeline handlers.
        model: Passed to the prompt-generation handler.
        prompt: Passed to the prompt-generation handler.
        neg_prompt: Accepted for interface compatibility; not used here.
        count: Passed to the prompt-generation handler.

    Returns:
        None. Collected metric results are written to stdout.
    """
    all_task = task == "report"
    output = []

    if space == "metric" and data == "image":
        import handlers.metrics as metric_handlers

        # Ordered (task name, thunk) pairs; thunks defer the handler call
        # until the task is known to be selected.
        if method == "quality":
            checks = (
                ("inception", lambda: metric_handlers.inception_handler(gpath)),
                ("frechet", lambda: metric_handlers.frechet_handler(gpath, rpath)),
                ("realism", lambda: metric_handlers.realism_handler(gpath)),
            )
        elif method == "diversity":
            checks = (
                ("perceptual", lambda: metric_handlers.perceptual_handler(gpath)),
                (
                    "coverage",
                    lambda: metric_handlers.coverage_image_handler(gpath, rpath),
                ),
                ("ssim", lambda: metric_handlers.ssim_image_handler(gpath)),
            )
        else:
            checks = ()

        for name, run in checks:
            if all_task or task == name:
                output.append(run())

    elif space == "genai" and data == "text":
        # Note: "report" does not fan out here; genai tasks match by exact name.
        if method == "llm" and task == "diverse":
            import handlers.generators as generator_handlers

            generator_handlers.prompts_llm_handler(opath, model, prompt, count)
        elif method == "config" and task == "llm-diversity":
            import handlers.generators as generator_handlers

            generator_handlers.llm_diversity_handler()

    elif space == "pipeline" and data == "text":
        if method == "full":
            import handlers.pipelines as pipeline_handlers

            pipeline_handlers.full_generation(cpath, opath)

    print("\n".join(output))
if __name__ == "__main__":
    # Raise the "torch" logger threshold and set the TensorFlow-related
    # environment variables first; the handler imports come afterwards
    # (presumably so the libraries they pull in see these settings at load).
    logging.getLogger("torch").setLevel(logging.CRITICAL)
    os.environ.update(
        {
            "TF_CPP_MIN_LOG_LEVEL": "3",
            "TF_ENABLE_ONEDNN_OPTS": "0",
        }
    )

    import handlers.generators as generator_handlers
    import handlers.metrics as metric_handlers
    import handlers.pipelines as pipeline_handlers

    parser = argparse.ArgumentParser(
        description="Sharif ML-Lab Data Generation ToolKit"
    )

    # One row per option, registered in this order:
    # (short flag, long flag, value type, required?, help text).
    option_specs = (
        ("-s", "--space", str, True, "Space Name (e.g. metric, genai)"),
        ("-mt", "--method", str, True, "Method Name (e.g. quality, diversity, sdm)"),
        ("-d", "--data", str, True, "Kind of Data (e.g. image, text)"),
        ("-t", "--task", str, False, "Task Name (e.g. inception, xlarge)"),
        ("-gp", "--gpath", str, False, "Generated Data Path"),
        ("-cp", "--cpath", str, False, "Caption Data Path"),
        ("-rp", "--rpath", str, False, "Real Data Path"),
        ("-op", "--opath", str, False, "Output Data Path"),
        ("-m", "--model", str, False, "Model Name"),
        ("-p", "--prompt", str, False, "Prompt"),
        ("-np", "--neg-prompt", str, False, "Negative Prompt"),
        ("-n", "--count", int, False, "Number of Images To Generate"),
    )
    for short_flag, long_flag, value_type, is_required, help_text in option_specs:
        parser.add_argument(
            short_flag,
            long_flag,
            type=value_type,
            required=is_required,
            help=help_text,
        )

    # Each parsed option becomes a keyword argument of main(); argparse turns
    # "--neg-prompt" into the attribute name "neg_prompt".
    args = parser.parse_args()
    main(**vars(args))