-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathpredict.py
91 lines (78 loc) · 3.28 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Prediction interface for Cog ⚙️
# https://cog.run/python
import shutil, os
from omegaconf import OmegaConf
from cog import BasePredictor, Input, Path
import numpy as np
from utils import util_common
from sampler_invsr import InvSamplerSR
from basicsr.utils.download_util import load_file_from_url
class Predictor(BasePredictor):
    """Cog predictor that runs the InvSR sampler on one input image."""

    # Fixed timestep schedules used for the small step counts.
    _PRESET_TIMESTEPS = {
        1: (200,),
        2: (200, 100),
        3: (200, 100, 50),
        4: (200, 150, 100, 50),
        5: (250, 200, 150, 100, 50),
    }
    # Start of the evenly spaced schedule; also the largest step count that
    # still yields distinct integer timesteps.
    _MAX_TIMESTEP = 250
    # chopping size -> extra_bs; any other chopping size falls back to 1.
    _EXTRA_BS_BY_CHOPPING_SIZE = {128: 4, 256: 2}

    def setup(self) -> None:
        """Load the base sampling configuration from the YAML file."""
        self.configs = OmegaConf.load("./configs/sample-sd-turbo.yaml")

    def set_configs(self, num_steps=1, chopping_size=128, seed=12345):
        """Update ``self.configs`` in place for one inference run.

        Args:
            num_steps: Number of sampling steps. Values 1-5 use a fixed
                schedule; larger values get ``num_steps`` evenly spaced
                integer timesteps going down from 250 (0 excluded).
            chopping_size: Stored as ``basesr.chopping.pch_size``; it also
                selects ``basesr.chopping.extra_bs`` (128 -> 4, 256 -> 2,
                anything else -> 1).
            seed: Stored as ``configs.seed``.

        Raises:
            ValueError: If ``num_steps`` is outside ``1..250``.

        Side effects:
            Creates ``./weights`` and downloads the noise-predictor
            checkpoint into it when the file is not already present.
        """
        preset = self._PRESET_TIMESTEPS.get(num_steps)
        if preset is not None:
            self.configs.timesteps = list(preset)
        else:
            # Explicit check instead of ``assert`` so it survives ``python -O``
            # and also rejects zero/negative step counts.
            if not 1 <= num_steps <= self._MAX_TIMESTEP:
                raise ValueError(
                    f"num_steps must be in [1, {self._MAX_TIMESTEP}], got {num_steps}"
                )
            # dtype must be the type ``np.int64``, not an instance of it.
            self.configs.timesteps = np.linspace(
                start=self._MAX_TIMESTEP,
                stop=0,
                num=num_steps,
                endpoint=False,
                dtype=np.int64,
            ).tolist()
        print(f'Setting timesteps for inference: {self.configs.timesteps}')

        # Stable Diffusion cache and the noise-predictor checkpoint share one
        # directory, so it only needs to be created once.
        weights_dir = "./weights"
        util_common.mkdir(weights_dir, delete=False, parents=True)
        self.configs.sd_pipe.params.cache_dir = weights_dir

        # Fetch the noise predictor only when it is missing locally.
        started_ckpt_name = "noise_predictor_sd_turbo_v5.pth"
        started_ckpt_path = os.path.join(weights_dir, started_ckpt_name)
        if not os.path.exists(started_ckpt_path):
            load_file_from_url(
                url="https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
                model_dir=weights_dir,
                progress=True,
                file_name=started_ckpt_name,
            )
        self.configs.model_start.ckpt_path = started_ckpt_path

        self.configs.bs = 1
        # Honour the caller's seed; it used to be overwritten with 12345.
        self.configs.seed = seed

        self.configs.basesr.chopping.pch_size = chopping_size
        self.configs.basesr.chopping.extra_bs = self._EXTRA_BS_BY_CHOPPING_SIZE.get(
            chopping_size, 1
        )

    def predict(
        self,
        in_path: Path = Input(description="Input low-quality image"),
        num_steps: int = Input(
            choices=[1, 2, 3, 4, 5], description="Number of sampling steps.", default=1
        ),
        chopping_size: int = Input(
            choices=[128, 256, 512], description="Chopping resolution", default=128
        ),
        # NOTE(review): the description promises a random seed when left blank,
        # but the default is the fixed value 12345 - confirm which is intended.
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=12345
        ),
    ) -> Path:
        """Run the sampler on ``in_path`` and return the resulting image.

        Args:
            in_path: Input low-quality image.
            num_steps: Number of sampling steps, forwarded to ``set_configs``.
            chopping_size: Chopping resolution, forwarded to ``set_configs``.
            seed: Random seed, forwarded to ``set_configs``.

        Returns:
            Path to ``/tmp/out.png``, a copy of the first file (in sorted
            order) that the sampler wrote to its output directory.

        Raises:
            RuntimeError: If the sampler left the output directory empty.
        """
        # setting configurations
        self.set_configs(num_steps, chopping_size, seed)
        sampler = InvSamplerSR(self.configs)

        # Start from a clean output directory so stale results from an
        # earlier call cannot be picked up below.
        out_dir = 'invsr_output'
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        sampler.inference(in_path, out_path=out_dir, bs=1)

        # Sort for a deterministic pick; os.listdir order is arbitrary.
        produced = sorted(os.listdir(out_dir))
        if not produced:
            raise RuntimeError(f"InvSR sampler produced no output in {out_dir!r}")

        out = "/tmp/out.png"
        shutil.copy(os.path.join(out_dir, produced[0]), out)
        return Path(out)