-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict.py
116 lines (101 loc) · 3.39 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
from cog import BasePredictor, Input, Path
from diffusers import (
StableDiffusionXLInpaintPipeline,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import load_image
import torch
import os
import shutil
# Local directory holding the pre-downloaded SDXL inpainting weights
# loaded by Predictor.setup(); must exist before the container starts serving.
SDXL_MODEL_CACHE = "sdxl-cache"
class KarrasDPM:
    """Factory shim so "KarrasDPM" can live in SCHEDULERS beside real scheduler classes.

    Mimics the diffusers scheduler classmethod interface: ``from_config(config)``
    returns a DPMSolverMultistepScheduler configured to use Karras sigmas.
    """

    # @staticmethod is required for correctness: as a plain function in the class
    # body, calling it on an instance would bind `config` to `self`.
    @staticmethod
    def from_config(config):
        # use_karras_sigmas=True selects the Karras et al. (2022) noise schedule.
        return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)
# User-facing scheduler name -> diffusers scheduler class (or factory shim)
# whose `from_config` rebuilds it from the pipeline's existing scheduler config.
SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KarrasDPM": KarrasDPM,
    "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
    "K_EULER": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
}
def get_image(path):
    """Copy *path* to a scratch file and load it as an RGB PIL image.

    The copy normalizes whatever path-like object Cog hands us into a plain
    on-disk PNG location that diffusers' load_image can open.
    """
    scratch = "/tmp/image.png"
    shutil.copyfile(path, scratch)
    loaded = load_image(scratch)
    return loaded.convert("RGB")
class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # fp16 weights on CUDA: halves memory and is the standard SDXL serving setup.
        self.__pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            SDXL_MODEL_CACHE,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

    @torch.inference_mode()
    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="An astronaut riding a rainbow unicorn",
        ),
        negative_prompt: str = Input(
            description="Input Negative Prompt",
            default="",
        ),
        image: Path = Input(
            description="Input image for img2img or inpaint mode",
            default=None,
        ),
        mask: Path = Input(
            description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.",
            default=None,
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps",
            ge=1,
            le=500,
            default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance",
            ge=1,
            le=50,
            default=7.5
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image",
            ge=0.0,
            le=1.0,
            default=0.8,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed",
            default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        if seed is None:
            # Draw a random 16-bit seed when the caller didn't supply one.
            seed = int.from_bytes(os.urandom(2), "big")
        # BUG FIX: the generator was previously hard-seeded with 0, so the
        # `seed` input (and the randomization above) was silently ignored.
        generator = torch.Generator("cuda").manual_seed(seed)
        images = self.__pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=get_image(image),
            mask_image=get_image(mask),
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            strength=prompt_strength,
            generator=generator
        ).images
        output_path = "/tmp/out.png"
        images[0].save(output_path)
        print(output_path)
        return Path(output_path)