Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to spandrel v0.2.1 #2487

Merged
merged 1 commit into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 4 additions & 39 deletions backend/src/nodes/impl/pytorch/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import torch
from spandrel import ImageModelDescriptor, SizeRequirements
from spandrel import ImageModelDescriptor

from ..upscale.auto_split import Split, Tiler, auto_split
from .utils import safe_cuda_cache_empty
Expand Down Expand Up @@ -46,26 +46,6 @@ def _rgb_to_bgr(t: torch.Tensor) -> torch.Tensor:
return t


def _pad(t: torch.Tensor, req: SizeRequirements):
_, _, h, w = t.shape

minimum = req.minimum
multiple_of = req.multiple_of

pad_h = (multiple_of - (h % multiple_of)) % multiple_of
if h + pad_h < minimum:
pad_h = minimum - h

pad_w = (multiple_of - (w % multiple_of)) % multiple_of
if w + pad_w < minimum:
pad_w = minimum - w

if pad_w or pad_h:
return True, torch.nn.functional.pad(t, (0, pad_w, 0, pad_h), "reflect")
else:
return False, t


@torch.inference_mode()
def pytorch_auto_split(
img: np.ndarray,
Expand All @@ -74,38 +54,23 @@ def pytorch_auto_split(
use_fp16: bool,
tiler: Tiler,
) -> np.ndarray:
model = model.to(device)
if use_fp16:
model.model.half()
else:
model.model.float()
dtype = torch.float16 if use_fp16 else torch.float32
model = model.to(device, dtype)

def upscale(img: np.ndarray, _: object):
input_tensor = None
try:
# convert to tensor
input_tensor = torch.from_numpy(np.ascontiguousarray(img)).to(device)
input_tensor = input_tensor.half() if use_fp16 else input_tensor.float()
input_tensor = torch.from_numpy(np.ascontiguousarray(img)).to(device, dtype)
input_tensor = _rgb_to_bgr(input_tensor)
input_tensor = _into_batched_form(input_tensor)

# pad to meat size requirements
_, _, org_h, org_w = input_tensor.shape
did_pad, input_tensor = _pad(input_tensor, model.size_requirements)

# inference
output_tensor = model(input_tensor)

if did_pad:
# crop to original (scaled) size
output_tensor = output_tensor[
:, :, : (org_h * model.scale), : (org_w * model.scale)
]

# convert back to numpy
output_tensor = _into_standard_image_form(output_tensor)
output_tensor = _rgb_to_bgr(output_tensor)
output_tensor = output_tensor.clip_(0, 1)
result = output_tensor.detach().cpu().detach().float().numpy()

return result
Expand Down
4 changes: 2 additions & 2 deletions backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def convert_to_onnx_impl(
raise ValueError(
f"Model of arch {model.architecture} does not support half precision."
)
model.model.half()
model.half()
dummy_input = dummy_input.half()
else:
model.model.float()
model.float()
dummy_input = dummy_input.float()

m = model.model
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_pytorch():
Dependency(
display_name="Spandrel",
pypi_name="spandrel",
version="0.1.7",
version="0.2.1",
size_estimate=287 * KB,
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,15 @@

import navi
from api import NodeContext
from nodes.impl.image_utils import as_3d
from nodes.impl.pytorch.utils import np2tensor, safe_cuda_cache_empty, tensor2np
from nodes.properties.inputs import ImageInput
from nodes.properties.inputs.pytorch_inputs import InpaintModelInput
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c

from ...settings import PyTorchSettings, get_settings
from .. import processing_group


def ceil_modulo(x: int, mod: int) -> int:
if x % mod == 0:
return x
return (x // mod + 1) * mod


def pad_img_to_modulo(
img: np.ndarray,
mod: int,
square: bool,
min_size: int,
):
img = as_3d(img)
h, w, _ = get_h_w_c(img)
out_h = ceil_modulo(max(h, min_size), mod)
out_w = ceil_modulo(max(w, min_size), mod)

if square:
max_size = max(out_h, out_w)
out_h = max_size
out_w = max_size

return np.pad(img, ((0, out_h - h), (0, out_w - w), (0, 0)), mode="symmetric")


def inpaint(
img: np.ndarray,
mask: np.ndarray,
Expand All @@ -53,38 +26,22 @@ def inpaint(
with torch.no_grad():
# TODO: use bfloat16 if RTX
use_fp16 = options.use_fp16 and model.supports_half
dtype = torch.float16 if use_fp16 else torch.float32
device = options.device

model = model.to(device)
model.model.half() if use_fp16 else model.model.float()

orig_height, orig_width, _ = get_h_w_c(img)

img = pad_img_to_modulo(
img,
model.size_requirements.multiple_of,
model.size_requirements.square,
model.size_requirements.minimum,
)
mask = pad_img_to_modulo(
mask,
model.size_requirements.multiple_of,
model.size_requirements.square,
model.size_requirements.minimum,
)
model = model.to(device, dtype)

img_tensor = np2tensor(img, change_range=True)
mask_tensor = np2tensor(mask, change_range=True)

d_img = None
d_mask = None
try:
d_img = img_tensor.to(device)
d_img = d_img.half() if use_fp16 else d_img.float()
d_img = img_tensor.to(device, dtype)

d_mask = mask_tensor.to(device)
d_mask = mask_tensor.to(device, dtype)
d_mask = (d_mask > 0.5) * 1
d_mask = d_mask.half() if use_fp16 else d_mask.float()
d_mask = d_mask.to(dtype)

result = model(d_img, d_mask)
result = tensor2np(
Expand All @@ -96,7 +53,7 @@ def inpaint(
del d_img
del d_mask

return result[0:orig_height, 0:orig_width]
return result
except RuntimeError:
# Collect garbage (clear VRAM)
if d_img is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def convert_to_onnx_node(
if fp16:
assert exec_options.use_fp16, "PyTorch fp16 mode must be supported and turned on in settings to convert model as fp16."

model.model.eval()
model = model.to(device)
model.eval().to(device)

use_half = fp16 and model.supports_half

Expand Down