Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upscaler model loading cleanup #10823

Merged
merged 5 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions extensions-builtin/LDSR/scripts/ldsr_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

from basicsr.utils.download_util import load_file_from_url

from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
Expand Down Expand Up @@ -43,20 +42,17 @@ def load_model(self, path: str):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")

yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")

try:
return LDSR(model, yaml)
except Exception:
errors.report("Error importing LDSR", exc_info=True)
return None
return LDSR(model, yaml)

def do_upscale(self, img, path):
ldsr = self.load_model(path)
if ldsr is None:
print("NO LDSR!")
try:
ldsr = self.load_model(path)
except Exception:
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
Expand Down
26 changes: 11 additions & 15 deletions extensions-builtin/ScuNET/scripts/scunet_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os.path
import sys

import PIL.Image
import numpy as np
import torch
from tqdm import tqdm

from basicsr.utils.download_util import load_file_from_url

import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net
from scunet_model_arch import SCUNet

from modules.modelloader import load_file_from_url
from modules.shared import opts


Expand All @@ -28,7 +26,7 @@ def __init__(self, dirname):
scalers = []
add_model2 = True
for file in model_paths:
if "http" in file:
if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
Expand Down Expand Up @@ -89,9 +87,10 @@ def do_upscale(self, img: PIL.Image.Image, selected_file):

torch.cuda.empty_cache()

model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img

device = devices.get_device_for('scunet')
Expand Down Expand Up @@ -119,15 +118,12 @@ def do_upscale(self, img: PIL.Image.Image, selected_file):

def load_model(self, path: str):
device = devices.get_device_for('scunet')
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
if path.startswith("http"):
# TODO: this doesn't use `path` at all?
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None

model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():
Expand Down
57 changes: 29 additions & 28 deletions extensions-builtin/SwinIR/scripts/swinir_model.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,32 @@
import os
import sys

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData

SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

device_swinir = devices.get_device_for('swinir')


class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
self.name = "SwinIR"
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x"
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
if "http" in model:
if model.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(model)
Expand All @@ -37,8 +35,10 @@ def __init__(self, dirname):
self.scalers = scalers

def do_upscale(self, img, model_file):
model = self.load_model(model_file)
if model is None:
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model)
Expand All @@ -49,30 +49,31 @@ def do_upscale(self, img, model_file):
return img

def load_model(self, path, scale=4):
if "http" in path:
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
if path.startswith("http"):
filename = modelloader.load_file_from_url(
url=path,
model_dir=self.model_download_path,
file_name=f"{self.model_name.replace(' ', '_')}.pth",
)
else:
filename = path
if filename is None or not os.path.exists(filename):
return None
if filename.endswith(".v2.pth"):
model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
model = Swin2SR(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
)
params = None
else:
model = net(
model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
Expand Down
23 changes: 10 additions & 13 deletions modules/esrgan_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import os
import sys

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts

from modules.upscaler import Upscaler, UpscalerData


def mod2normal(state_dict):
Expand Down Expand Up @@ -134,7 +132,7 @@ def __init__(self, dirname):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
Expand All @@ -143,26 +141,25 @@ def __init__(self, dirname):
self.scalers.append(scaler_data)

def do_upscale(self, img, selected_model):
model = self.load_model(selected_model)
if model is None:
try:
model = self.load_model(selected_model)
except Exception as e:
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img

def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(
if path.startswith("http"):
# TODO: this doesn't use `path` at all?
filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
progress=True,
)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print(f"Unable to load {self.model_path} from {filename}")
return None

state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

Expand Down
2 changes: 1 addition & 1 deletion modules/gfpgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def gfpgann():
return None

models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
if len(models) == 1 and "http" in models[0]:
if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
Expand Down
31 changes: 27 additions & 4 deletions modules/modelloader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import shutil
import importlib
Expand All @@ -8,6 +10,29 @@
from modules.paths import script_path, models_path


def load_file_from_url(
url: str,
*,
model_dir: str,
progress: bool = True,
file_name: str | None = None,
) -> str:
"""Download a file from `url` into `model_dir`, using the file present if possible.

Returns the path to the downloaded file.
"""
os.makedirs(model_dir, exist_ok=True)
if not file_name:
parts = urlparse(url)
file_name = os.path.basename(parts.path)
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
from torch.hub import download_url_to_file
download_url_to_file(url, cached_file, progress=progress)
return cached_file


def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
Expand Down Expand Up @@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None

if model_url is not None and len(output) == 0:
if download_name is not None:
from basicsr.utils.download_util import load_file_from_url
dl = load_file_from_url(model_url, places[0], True, download_name)
output.append(dl)
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)

Expand All @@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None


def friendly_name(file: str):
if "http" in file:
if file.startswith("http"):
file = urlparse(file).path

file = os.path.basename(file)
Expand Down
33 changes: 15 additions & 18 deletions modules/realesrgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

from modules.upscaler import Upscaler, UpscalerData
Expand Down Expand Up @@ -43,9 +42,10 @@ def do_upscale(self, img, path):
if not self.enable:
return img

info = self.load_model(path)
if not os.path.exists(info.local_data_path):
print(f"Unable to load RealESRGAN model: {info.name}")
try:
info = self.load_model(path)
except Exception:
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img

upsampler = RealESRGANer(
Expand All @@ -63,20 +63,17 @@ def do_upscale(self, img, path):
return image

def load_model(self, path):
try:
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)

if info is None:
print(f"Unable to find model info: {path}")
return None

if info.local_data_path.startswith("http"):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)

return info
except Exception:
errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
for scaler in self.scalers:
if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
scaler.local_data_path = modelloader.load_file_from_url(
scaler.data_path,
model_dir=self.model_download_path,
)
if not os.path.exists(scaler.local_data_path):
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
return scaler
raise ValueError(f"Unable to find model info: {path}")

def load_models(self, _):
return get_realesrgan_models(self)
Expand Down