forked from fofr/cog-face-to-many
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
265 lines (227 loc) · 9.78 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import json
import os
import random
import shutil
from typing import List

from cog import BasePredictor, Input, Path
from PIL import ExifTags, Image, ImageOps

from helpers.comfyui import ComfyUI
# Working directories handed to the ComfyUI server in Predictor.setup and
# wiped/recreated before every prediction in Predictor.cleanup.
OUTPUT_DIR, INPUT_DIR = "/tmp/outputs", "/tmp/inputs"
# Relative path: resolved against the process's current working directory.
COMFYUI_TEMP_OUTPUT_DIR = "ComfyUI/temp"
# ComfyUI workflow template, kept as raw JSON text. Predictor.predict parses a
# fresh copy per prediction (json.loads) so patches never leak between runs.
# The path is relative to the current working directory. Read explicitly as
# UTF-8 (the encoding JSON requires) instead of the locale-dependent default.
with open("face-to-many-api.json", "r", encoding="utf-8") as file:
    workflow_json = file.read()
# Built-in styles: style name -> LoRA weights file downloaded in
# Predictor.download_loras. Insertion order defines the order of the
# `style` choices offered by Predictor.predict.
LORA_WEIGHTS_MAPPING = dict(
    [
        ("3D", "artificialguybr/3DRedmond-3DRenderStyle-3DRenderAF.safetensors"),
        ("Emoji", "fofr/emoji.safetensors"),
        (
            "Video game",
            "artificialguybr/PS1Redmond-PS1Game-Playstation1Graphics.safetensors",
        ),
        ("Pixels", "artificialguybr/PixelArtRedmond-Lite64.safetensors"),
        ("Clay", "artificialguybr/ClayAnimationRedm.safetensors"),
        ("Toy", "artificialguybr/ToyRedmond-FnkRedmAF.safetensors"),
    ]
)
# The style names alone, as a list (used for the `choices` of the style input).
LORA_TYPES = [*LORA_WEIGHTS_MAPPING]
class Predictor(BasePredictor):
    """Cog predictor that restyles a photo of a person with a ComfyUI workflow.

    Each prediction copies the input image into INPUT_DIR, patches a fresh copy
    of the ``workflow_json`` template with the chosen style LoRA, prompts and
    strengths, runs it on a local ComfyUI server and returns every file found
    under OUTPUT_DIR afterwards.
    """

    # A custom LoRA URL must start with one of these prefixes and end with the
    # suffix below; the LoRA id is the part in between.
    _CUSTOM_LORA_PREFIXES = (
        "https://replicate.delivery/pbxt/",
        "https://pbxt.replicate.delivery/",
    )
    _CUSTOM_LORA_SUFFIX = "/trained_model.tar"
    _CUSTOM_LORA_URL_ERROR = (
        "Custom LoRA URL format is not supported. Must be in the format "
        "https://replicate.delivery/pbxt/[id]/trained_model.tar or "
        "https://pbxt.replicate.delivery/[id]/trained_model.tar"
    )

    def setup(self):
        """Start the local ComfyUI server and download the built-in style LoRAs."""
        self.comfyUI = ComfyUI("127.0.0.1:8188")
        self.comfyUI.start_server(OUTPUT_DIR, INPUT_DIR)
        # NOTE(review): check_inputs=False presumably skips validating the
        # template's input files, which only exist per prediction — confirm
        # in helpers.comfyui.
        self.comfyUI.load_workflow(workflow_json, check_inputs=False)
        self.download_loras()

    def parse_custom_lora_url(self, url: str):
        """Return the id part of a Replicate custom LoRA URL.

        Both ``https://pbxt.replicate.delivery/<id>/trained_model.tar`` and
        ``https://replicate.delivery/pbxt/<id>/trained_model.tar`` yield
        ``<id>``.

        Raises:
            IndexError: if the URL contains neither delivery marker.
        """
        if "pbxt.replicate" in url:
            parts_after_pbxt = url.split("/pbxt.replicate.delivery/")[1]
        else:
            parts_after_pbxt = url.split("/pbxt/")[1]
        return parts_after_pbxt.split("/trained_model.tar")[0]

    def add_to_lora_map(self, lora_url: str):
        """Download the custom LoRA at ``lora_url``, keyed by its parsed id."""
        uuid = self.parse_custom_lora_url(lora_url)
        self.comfyUI.weights_downloader.download_lora_from_replicate_url(uuid, lora_url)

    def download_loras(self):
        """Download the weights of every built-in style in LORA_WEIGHTS_MAPPING."""
        for weight in LORA_WEIGHTS_MAPPING.values():
            self.comfyUI.weights_downloader.download_weights(weight)

    def cleanup(self):
        """Clear the ComfyUI queue and recreate empty working directories."""
        self.comfyUI.clear_queue()
        for directory in [OUTPUT_DIR, INPUT_DIR, COMFYUI_TEMP_OUTPUT_DIR]:
            if os.path.exists(directory):
                shutil.rmtree(directory)
            os.makedirs(directory)

    @staticmethod
    def _apply_exif_orientation(image):
        """Return ``image`` rotated/flipped upright according to its EXIF tag.

        Handles all eight EXIF orientations, including the mirrored ones.
        Best effort: if the EXIF data cannot be applied, the image is returned
        unchanged and the problem is logged.
        """
        try:
            return ImageOps.exif_transpose(image)
        except Exception as err:  # malformed EXIF must not fail the prediction
            print(f"Could not apply EXIF orientation, using image as-is: {err}")
            return image

    def handle_input_file(self, input_file: Path):
        """Place the user's image in INPUT_DIR and return its file name there.

        JPEGs are re-encoded as ``input.png`` after applying their EXIF
        orientation. PNG and WebP files are copied unchanged as
        ``input<ext>``.

        Raises:
            ValueError: if the extension is not .jpg/.jpeg/.png/.webp.
        """
        file_extension = os.path.splitext(input_file)[1].lower()
        if file_extension in (".jpg", ".jpeg"):
            filename = "input.png"
            # Context manager closes the source file even if saving fails.
            with Image.open(input_file) as source:
                image = self._apply_exif_orientation(source)
                # PNG cannot store CMYK (which some JPEGs use), so convert first.
                if image.mode == "CMYK":
                    image = image.convert("RGB")
                image.save(os.path.join(INPUT_DIR, filename))
        elif file_extension in (".png", ".webp"):
            filename = f"input{file_extension}"
            shutil.copy(input_file, os.path.join(INPUT_DIR, filename))
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")
        return filename

    def log_and_collect_files(self, directory, prefix=""):
        """Recursively print and collect every file below ``directory``.

        Entries are visited in sorted order so the returned list is
        deterministic (``os.listdir`` order is arbitrary). Entries named
        ``__MACOSX`` are skipped.

        Returns:
            A list of ``Path`` objects, one per regular file found.
        """
        files = []
        for f in sorted(os.listdir(directory)):
            if f == "__MACOSX":
                continue
            path = os.path.join(directory, f)
            if os.path.isfile(path):
                print(f"{prefix}{f}")
                files.append(Path(path))
            elif os.path.isdir(path):
                print(f"{prefix}{f}/")
                files.extend(self.log_and_collect_files(path, prefix=f"{prefix}{f}/"))
        return files

    def update_workflow(self, workflow, **kwargs):
        """Patch node inputs of ``workflow`` in place for one prediction.

        ``workflow`` is the dict parsed from ``workflow_json``; nodes are
        addressed by their ids in that template. When ``lora_url`` is truthy,
        the custom LoRA is used and the prompts are passed through verbatim.
        Otherwise the built-in style LoRA is used, and its trigger words and
        negative terms are added to the prompts.
        """
        style = kwargs["style"]
        prompt = kwargs["prompt"]
        negative_prompt = kwargs["negative_prompt"]
        custom_style = kwargs["lora_url"]
        if custom_style:
            uuid = self.parse_custom_lora_url(custom_style)
            lora_name = f"{uuid}/{uuid}.safetensors"
        else:
            lora_name = LORA_WEIGHTS_MAPPING[style]
            prompt = self.style_to_prompt(style, prompt)
            negative_prompt = self.style_to_negative_prompt(style, negative_prompt)

        load_image = workflow["22"]["inputs"]
        load_image["image"] = kwargs["filename"]

        loader = workflow["2"]["inputs"]
        loader["positive"] = prompt
        loader["negative"] = negative_prompt

        controlnet = workflow["28"]["inputs"]
        controlnet["strength"] = kwargs["control_depth_strength"]

        lora_loader = workflow["3"]["inputs"]
        lora_loader["lora_name_1"] = lora_name
        lora_loader["lora_wt_1"] = kwargs["lora_scale"]

        instant_id = workflow["41"]["inputs"]
        instant_id["weight"] = kwargs["instant_id_strength"]

        sampler = workflow["4"]["inputs"]
        sampler["denoise"] = kwargs["denoising_strength"]
        sampler["seed"] = kwargs["seed"]
        sampler["cfg"] = kwargs["prompt_strength"]

    def style_to_prompt(self, style, prompt):
        """Wrap ``prompt`` with the trigger words of the built-in ``style``."""
        style_prompts = {
            "3D": f"3D Render Style, 3DRenderAF, {prompt}",
            "Emoji": f"memoji, emoji, {prompt}, 3d render, sharp",
            "Video game": f"Playstation 1 Graphics, PS1 Game, {prompt}, Video game screenshot",
            "Pixels": f"Pixel Art, PixArFK, {prompt}",
            "Clay": f"Clay Animation, Clay, {prompt}",
            "Toy": f"FnkRedmAF, {prompt}, toy, miniature",
        }
        return style_prompts[style]

    def style_to_negative_prompt(self, style, negative_prompt=""):
        """Build the negative prompt: style terms, base terms, then the user's."""
        if negative_prompt:
            negative_prompt = f"{negative_prompt}, "

        start_base_negative = "nsfw, nude, oversaturated, "
        end_base_negative = "ugly, broken, watermark"
        specifics = {
            "3D": "photo, photography, ",
            "Emoji": "photo, photography, blurry, soft, ",
            "Video game": "text, photo, ",
            "Pixels": "photo, photography, ",
            "Clay": "",
            "Toy": "",
        }
        return f"{specifics[style]}{start_base_negative}{negative_prompt}{end_base_negative}"

    def _validate_custom_lora_url(self, url: str):
        """Raise ValueError unless ``url`` is a well-formed Replicate LoRA URL.

        The URL must *start* with a Replicate delivery prefix; a substring test
        would accept any host that merely embeds one. The extracted id must be
        non-empty and free of ``..`` segments, because it is used to build the
        LoRA name ``<id>/<id>.safetensors``.
        """
        if not url.startswith(self._CUSTOM_LORA_PREFIXES) or not url.endswith(
            self._CUSTOM_LORA_SUFFIX
        ):
            raise ValueError(self._CUSTOM_LORA_URL_ERROR)
        try:
            lora_id = self.parse_custom_lora_url(url)
        except IndexError:
            raise ValueError(self._CUSTOM_LORA_URL_ERROR) from None
        if not lora_id or ".." in lora_id.split("/"):
            raise ValueError(self._CUSTOM_LORA_URL_ERROR)

    def predict(
        self,
        image: Path = Input(
            description="An image of a person to be converted",
            default=None,
        ),
        style: str = Input(
            default="3D",
            choices=LORA_TYPES,
            description="Style to convert to",
        ),
        prompt: str = Input(default="a person"),
        negative_prompt: str = Input(
            default="",
            description="Things you do not want in the image",
        ),
        denoising_strength: float = Input(
            default=0.65,
            ge=0,
            le=1,
            description="How much of the original image to change. 1 is the complete destruction of the original image, 0 is the original image",
        ),
        prompt_strength: float = Input(
            default=4.5,
            ge=0,
            le=20,
            description="Strength of the prompt. This is the CFG scale, higher numbers lead to stronger prompt, lower numbers will keep more of a likeness to the original.",
        ),
        control_depth_strength: float = Input(
            default=0.8,
            ge=0,
            le=1,
            description="Strength of depth controlnet. The bigger this is, the more controlnet affects the output.",
        ),
        instant_id_strength: float = Input(
            default=1, description="How strong the InstantID will be.", ge=0, le=1
        ),
        seed: int = Input(
            default=None, description="Fix the random seed for reproducibility"
        ),
        custom_lora_url: str = Input(
            default=None,
            description="URL to a Replicate custom LoRA. Must be in the format https://replicate.delivery/pbxt/[id]/trained_model.tar or https://pbxt.replicate.delivery/[id]/trained_model.tar",
        ),
        lora_scale: float = Input(
            default=1, description="How strong the LoRA will be", ge=0, le=1
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Raises:
            ValueError: if no image is given, the image type is unsupported,
                or ``custom_lora_url`` is malformed.
        """
        self.cleanup()

        if image is None:
            raise ValueError("No image provided")

        filename = self.handle_input_file(image)

        # Truthiness (not `is not None`) so an empty string means "no custom
        # LoRA". This matches update_workflow's check.
        if custom_lora_url:
            self._validate_custom_lora_url(custom_lora_url)
            self.add_to_lora_map(custom_lora_url)

        if seed is None:
            seed = random.randint(0, 2**32 - 1)
            print(f"Random seed set to: {seed}")

        # Parse a fresh copy so per-prediction patches never touch the template.
        workflow = json.loads(workflow_json)
        self.update_workflow(
            workflow,
            filename=filename,
            style=style,
            denoising_strength=denoising_strength,
            seed=seed,
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_strength=prompt_strength,
            instant_id_strength=instant_id_strength,
            lora_url=custom_lora_url,
            lora_scale=lora_scale,
            control_depth_strength=control_depth_strength,
        )
        # NOTE(review): check_weights=False presumably skips re-checking
        # weights already fetched in setup/add_to_lora_map — confirm.
        wf = self.comfyUI.load_workflow(workflow, check_weights=False)
        self.comfyUI.connect()
        self.comfyUI.run_workflow(wf)

        print(f"Contents of {OUTPUT_DIR}:")
        return self.log_and_collect_files(OUTPUT_DIR)