forked from d8ahazard/sd_dreambooth_extension
-
Notifications
You must be signed in to change notification settings - Fork 0
/
postinstall.py
247 lines (200 loc) · 8.58 KB
/
postinstall.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import filecmp
import importlib.util
import json
import os
import shutil
import subprocess
import sys
import sysconfig
from dataclasses import dataclass
import git
from packaging.version import Version
from dreambooth import shared
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
def actual_install():
if os.environ.get("PUBLIC_KEY", None):
print("Docker, returning.")
shared.launch_error = None
return
base_dir = os.path.dirname(os.path.realpath(__file__))
try:
repo = git.Repo(base_dir)
revision = repo.rev_parse("HEAD")
except:
revision = ""
print("If submitting an issue on github, please provide the full startup log for debugging purposes.")
print("")
print("Initializing Dreambooth")
print(f"Dreambooth revision: {revision}")
install_requirements()
check_xformers()
check_bitsandbytes()
check_versions()
check_torch_unsafe_load()
def pip_install(*args):
output = subprocess.check_output(
[sys.executable, "-m", "pip", "install"] + list(args),
stderr=subprocess.STDOUT,
)
for line in output.decode().split("\n"):
if "Successfully installed" in line:
print(line)
def install_requirements():
dreambooth_skip_install = os.environ.get("DREAMBOOTH_SKIP_INSTALL", False)
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
req_file_startup_arg = os.environ.get("REQS_FILE", "requirements_versions.txt")
if dreambooth_skip_install or req_file == req_file_startup_arg:
return
has_diffusers = importlib.util.find_spec("diffusers") is not None
has_tqdm = importlib.util.find_spec("tqdm") is not None
transformers_version = importlib_metadata.version("transformers")
try:
pip_install("-r", req_file)
if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"):
print()
print("Does your project take forever to startup?")
print("Repetitive dependency installation may be the reason.")
print("Automatic1111's base project sets strict requirements on outdated dependencies.")
print("If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup.")
print()
except subprocess.CalledProcessError as grepexc:
error_msg = grepexc.stdout.decode()
print_requirement_installation_error(error_msg)
raise grepexc
def check_xformers():
"""
Install xformers 0.0.17 if necessary
"""
try:
xformers_version = importlib_metadata.version("xformers")
xformers_outdated = Version(xformers_version) < Version("0.0.17.dev")
if xformers_outdated:
try:
torch_version = importlib_metadata.version("torch")
is_torch_1 = Version(torch_version) < Version("2")
if is_torch_1:
pip_install("xformers==0.0.17.dev476")
else:
pip_install("xformers", "--pre")
except subprocess.CalledProcessError as grepexc:
error_msg = grepexc.stdout.decode()
print_xformers_installation_error(error_msg)
except:
pass
def check_bitsandbytes():
"""
Check for "different" B&B Files and copy only if necessary
"""
if os.name == "nt":
try:
bnb_src = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bitsandbytes_windows")
bnb_dest = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
filecmp.clear_cache()
for file in os.listdir(bnb_src):
src_file = os.path.join(bnb_src, file)
if file == "main.py" or file == "paths.py":
dest = os.path.join(bnb_dest, "cuda_setup")
else:
dest = bnb_dest
shutil.copy2(src_file, dest)
except:
pass
@dataclass
class Dependency:
module: str
version: str
version_comparison: str = "min"
required: bool = True
def check_versions():
dependencies = [
Dependency(module="xformers", version="0.0.17.dev", required=False),
Dependency(module="torch", version="1.13.1+cu116"),
Dependency(module="torchvision", version="0.14.1+cu116"),
Dependency(module="accelerate", version="0.17.1"),
Dependency(module="diffusers", version="0.14.0"),
Dependency(module="transformers", version="4.25.1"),
Dependency(module="bitsandbytes", version="0.35.4", version_comparison="exact"),
]
launch_errors = []
for dependency in dependencies:
module = dependency.module
has_module = importlib.util.find_spec(module) is not None
installed_ver = str(importlib_metadata.version(module)) if has_module else None
if not installed_ver:
if module != "xformers":
launch_errors.append(f"{module} not installed.")
print(f"[!] {module} NOT installed.")
continue
required_version = dependency.version
required_comparison = dependency.version_comparison
if required_comparison == "min" and Version(installed_ver) < Version(required_version):
if dependency.required:
launch_errors.append(f"{module} is below the required {required_version} version.")
print(f"[!] {module} version {installed_ver} installed.")
elif required_comparison == "exact" and Version(installed_ver) != Version(required_version):
launch_errors.append(f"{module} is not the required {required_version} version.")
print(f"[!] {module} version {installed_ver} installed.")
else:
print(f"[+] {module} version {installed_ver} installed.")
try:
if len(launch_errors):
print_launch_errors(launch_errors)
os.environ["ERRORS"] = json.dumps(launch_errors)
else:
os.environ["ERRORS"] = ""
except Exception as e:
print(e)
def print_requirement_installation_error(err):
print("# Requirement installation exception:")
for line in err.split('\n'):
line = line.strip()
if line:
print(line)
def print_xformers_installation_error(err):
torch_ver = importlib_metadata.version("torch")
print()
print("#######################################################################################################")
print("# XFORMERS ISSUE DETECTED #")
print("#######################################################################################################")
print("#")
print(f"# Dreambooth could not find a compatible version of xformers (>= 0.0.17.dev built with torch {torch_ver})")
print("# xformers will not be available for Dreambooth. Consider upgrading to Torch 2.")
print("#")
print("# Xformers installation exception:")
for line in err.split('\n'):
line = line.strip()
if line:
print(line)
print("#")
print("#######################################################################################################")
def print_launch_errors(launch_errors):
print()
print("#######################################################################################################")
print("# LIBRARY ISSUE DETECTED #")
print("#######################################################################################################")
print("#")
print("# " + "\n# ".join(launch_errors))
print("#")
print("# Dreambooth may not work properly.")
print("#")
print("# TROUBLESHOOTING")
print("# 1. Fully restart your project (not just the webpage)")
print("# 2. Update your A1111 project and extensions")
print("# 3. Dreambooth requirements should have installed automatically, but you can manually install them")
print("# by running the following 4 commands from the A1111 project root:")
print("cd venv/Scripts")
print("activate")
print("cd ../..")
print("pip install -r ./extensions/sd_dreambooth_extension/requirements.txt")
print("#######################################################################################################")
def check_torch_unsafe_load():
try:
from modules import safe
safe.load = safe.unsafe_torch_load
import torch
torch.load = safe.unsafe_torch_load
except:
pass