diff --git a/gradio_apisr.py b/gradio_apisr.py index 7d5b2cf..1334e2d 100644 --- a/gradio_apisr.py +++ b/gradio_apisr.py @@ -1,5 +1,9 @@ +''' + Gradio demo (almost the same code as the one used in Huggingface space) +''' import os, sys import cv2 +import time import gradio as gr import torch import numpy as np @@ -20,6 +24,10 @@ def auto_download_if_needed(weight_path): if not os.path.exists("pretrained"): os.makedirs("pretrained") + if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth": + os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth") + os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained") + if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth": os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth") os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained") @@ -28,6 +36,7 @@ def auto_download_if_needed(weight_path): os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth") os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained") + def inference(img_path, model_name): @@ -41,22 +50,29 @@ def inference(img_path, model_name): auto_download_if_needed(weight_path) generator = load_grl(weight_path, scale=4) # Directly use default way now + elif model_name == "4xRRDB": + weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth" + auto_download_if_needed(weight_path) + generator = load_rrdb(weight_path, scale=4) # Directly use default way now + elif model_name == "2xRRDB": weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth" auto_download_if_needed(weight_path) generator = load_rrdb(weight_path, scale=2) # Directly use default way now else: - raise gr.Error(error) + raise gr.Error("We don't support such Model") generator = generator.to(dtype=weight_dtype) # In default, we will automatically use crop to match 4x size - super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, crop_for_4x=True) - save_image(super_resolved_img, "SR_result.png") - outputs = cv2.imread("SR_result.png") + super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True) + store_name = str(time.time()) + ".png" + save_image(super_resolved_img, store_name) + outputs = cv2.imread(store_name) outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR) + os.remove(store_name) return outputs @@ -70,14 +86,18 @@ def inference(img_path, model_name): MARKDOWN = \ """ - ## APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) - + ##

APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)

+ [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598) - - If APISR is helpful for you, please help star the GitHub Repo. Thanks! + APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios. + + ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720 + ### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight. + + If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! """ - block = gr.Blocks().queue() + block = gr.Blocks().queue(max_size=10) with block: with gr.Row(): gr.Markdown(MARKDOWN) @@ -87,6 +107,7 @@ def inference(img_path, model_name): model_name = gr.Dropdown( [ "2xRRDB", + "4xRRDB", "4xGRL" ], type="value", @@ -106,7 +127,7 @@ def inference(img_path, model_name): ["__assets__/lr_inputs/41.png"], ["__assets__/lr_inputs/f91.jpg"], ["__assets__/lr_inputs/image-00440.png"], - ["__assets__/lr_inputs/image-00164.png"], + ["__assets__/lr_inputs/image-00164.jpg"], ["__assets__/lr_inputs/img_eva.jpeg"], ["__assets__/lr_inputs/naruto.jpg"], ], @@ -115,4 +136,4 @@ def inference(img_path, model_name): run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image]) - block.launch() \ No newline at end of file + block.launch()