forked from ostris/ai-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 10
/
run.py
90 lines (73 loc) · 2.68 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
# Enable the hf_transfer download backend for the Hugging Face Hub. This is set
# before any third-party/project imports below so it is already in the environment
# when they load.
# NOTE(review): this unconditionally overrides any value the user exported, and
# because it is set before load_dotenv() (which by default does not override
# existing variables) a value in .env is ignored too. Presumably requires the
# `hf_transfer` package to be installed — confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import sys
from typing import Union, OrderedDict
from dotenv import load_dotenv
# Load the .env file if it exists
load_dotenv()
# Prepend the current working directory to the import path so local packages
# (e.g. `toolkit`, imported at the bottom of this block) resolve from there.
sys.path.insert(0, os.getcwd())
# must come before ANY torch or fastai imports
# (currently disabled)
# import toolkit.cuda_malloc
# turn off diffusers telemetry until I can figure out how to make it opt-in
os.environ['DISABLE_TELEMETRY'] = 'YES'
# check if we have DEBUG_TOOLKIT in env
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
    # Enable autograd anomaly detection: helps locate the op that produced
    # NaN/invalid gradients, at a significant runtime cost. Debug use only.
    import torch
    torch.autograd.set_detect_anomaly(True)
import argparse
# Imported last so it sees the sys.path change and environment variables above.
from toolkit.job import get_job
def print_end_message(jobs_completed, jobs_failed):
    """Print a framed summary of the run's job results.

    The completed-job count is always reported. A failure line is added only
    when ``jobs_failed`` is positive. Both counts are pluralized unless they
    equal exactly one.

    Args:
        jobs_completed: number of jobs that finished successfully.
        jobs_failed: number of jobs that raised an exception.
    """
    def _plural(count):
        # Suffix for "job"/"failure": singular only for exactly one.
        return '' if count == 1 else 's'

    completed_string = f"{jobs_completed} completed job{_plural(jobs_completed)}"
    failure_string = f"{jobs_failed} failure{_plural(jobs_failed)}" if jobs_failed > 0 else ""

    banner = "========================================"

    for header_line in ("", banner, "Result:"):
        print(header_line)

    # Empty summaries (no failures) are skipped.
    for summary in (completed_string, failure_string):
        if summary:
            print(f" - {summary}")

    print(banner)
def main():
    """Parse command-line arguments and run each requested job sequentially.

    Each positional argument is a config name or path, resolved by ``get_job``
    together with the optional ``--name`` override. For every job, ``run()``
    and then ``cleanup()`` are called. A results summary is printed when the
    run ends, whether it finishes normally or aborts on a failure.

    Raises:
        Exception: the failing job's exception, re-raised with its original
            traceback, when a job fails and ``--recover`` was not given.
    """
    parser = argparse.ArgumentParser()
    # require at least one config file
    parser.add_argument(
        'config_file_list',
        nargs='+',
        type=str,
        help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
    )
    # flag to continue if failed job
    parser.add_argument(
        '-r', '--recover',
        action='store_true',
        help='Continue running additional jobs even if a job fails'
    )
    # value substituted for the [name] tag in config files
    parser.add_argument(
        '-n', '--name',
        type=str,
        default=None,
        help='Name to replace [name] tag in config file, useful for shared config file'
    )
    args = parser.parse_args()

    config_file_list = args.config_file_list
    # nargs='+' already makes argparse reject an empty list; kept as a defensive check.
    if not config_file_list:
        raise Exception("You must provide at least one config file")

    jobs_completed = 0
    jobs_failed = 0

    print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

    for config_file in config_file_list:
        try:
            job = get_job(config_file, args.name)
            job.run()
            job.cleanup()
            jobs_completed += 1
        except Exception as e:
            print(f"Error running job: {e}")
            jobs_failed += 1
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                # Bare raise re-raises the active exception with its original traceback.
                raise

    # Report the final tally for every run that gets here: fully successful
    # runs, and --recover runs that continued past one or more failures.
    print_end_message(jobs_completed, jobs_failed)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()