Skip to content

Commit

Permalink
Upscaler.load_model: don't return None, just use exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Jun 13, 2023
1 parent e3a973a commit bf67a5d
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 64 deletions.
13 changes: 5 additions & 8 deletions extensions-builtin/LDSR/scripts/ldsr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,13 @@ def load_model(self, path: str):

yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")

try:
return LDSR(model, yaml)
except Exception:
errors.report("Error importing LDSR", exc_info=True)
return None
return LDSR(model, yaml)

def do_upscale(self, img, path):
ldsr = self.load_model(path)
if ldsr is None:
print("NO LDSR!")
try:
ldsr = self.load_model(path)
except Exception:
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
Expand Down
16 changes: 6 additions & 10 deletions extensions-builtin/ScuNET/scripts/scunet_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os.path
import sys

import PIL.Image
Expand All @@ -8,7 +7,7 @@

import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net
from scunet_model_arch import SCUNet

from modules.modelloader import load_file_from_url
from modules.shared import opts
Expand Down Expand Up @@ -88,9 +87,10 @@ def do_upscale(self, img: PIL.Image.Image, selected_file):

torch.cuda.empty_cache()

model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img

device = devices.get_device_for('scunet')
Expand Down Expand Up @@ -123,11 +123,7 @@ def load_model(self, path: str):
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None

model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():
Expand Down
40 changes: 20 additions & 20 deletions extensions-builtin/SwinIR/scripts/swinir_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import sys

import numpy as np
import torch
Expand All @@ -7,8 +7,8 @@

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData


Expand Down Expand Up @@ -36,8 +36,10 @@ def __init__(self, dirname):
self.scalers = scalers

def do_upscale(self, img, model_file):
model = self.load_model(model_file)
if model is None:
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model)
Expand All @@ -56,25 +58,23 @@ def load_model(self, path, scale=4):
)
else:
filename = path
if filename is None or not os.path.exists(filename):
return None
if filename.endswith(".v2.pth"):
model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
model = Swin2SR(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
)
params = None
else:
model = net(
model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
Expand Down
14 changes: 6 additions & 8 deletions modules/esrgan_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import os
import sys

import numpy as np
import torch
from PIL import Image

import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts

from modules.upscaler import Upscaler, UpscalerData


def mod2normal(state_dict):
Expand Down Expand Up @@ -142,8 +141,10 @@ def __init__(self, dirname):
self.scalers.append(scaler_data)

def do_upscale(self, img, selected_model):
model = self.load_model(selected_model)
if model is None:
try:
model = self.load_model(selected_model)
except Exception as e:
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
Expand All @@ -159,9 +160,6 @@ def load_model(self, path: str):
)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print(f"Unable to load {self.model_path} from {filename}")
return None

state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

Expand Down
33 changes: 15 additions & 18 deletions modules/realesrgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from modules import modelloader, errors



class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
self.name = "RealESRGAN"
Expand Down Expand Up @@ -43,9 +42,10 @@ def do_upscale(self, img, path):
if not self.enable:
return img

info = self.load_model(path)
if not os.path.exists(info.local_data_path):
print(f"Unable to load RealESRGAN model: {info.name}")
try:
info = self.load_model(path)
except Exception:
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img

upsampler = RealESRGANer(
Expand All @@ -63,20 +63,17 @@ def do_upscale(self, img, path):
return image

def load_model(self, path):
try:
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)

if info is None:
print(f"Unable to find model info: {path}")
return None

if info.local_data_path.startswith("http"):
info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path)

return info
except Exception:
errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
for scaler in self.scalers:
if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
scaler.local_data_path = modelloader.load_file_from_url(
scaler.data_path,
model_dir=self.model_download_path,
)
if not os.path.exists(scaler.local_data_path):
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
return scaler
raise ValueError(f"Unable to find model info: {path}")

def load_models(self, _):
return get_realesrgan_models(self)
Expand Down

0 comments on commit bf67a5d

Please sign in to comment.