From 21ae01812d212bc7bd1974d1f43bcddacdc48c80 Mon Sep 17 00:00:00 2001 From: Mert Cobanov Date: Tue, 24 May 2022 15:24:49 +0300 Subject: [PATCH] Added GPU selection feature to python inference (#321) * Added GPU selection feature to python inference * pylint pep8 fixes * pep8 fixes --- inference_realesrgan.py | 6 +++++- realesrgan/utils.py | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/inference_realesrgan.py b/inference_realesrgan.py index be3977156..cc5d618a4 100644 --- a/inference_realesrgan.py +++ b/inference_realesrgan.py @@ -39,6 +39,9 @@ def main(): type=str, default='auto', help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs') + parser.add_argument( + '-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu') + args = parser.parse_args() # determine models according to model names @@ -71,7 +74,8 @@ def main(): tile=args.tile, tile_pad=args.tile_pad, pre_pad=args.pre_pad, - half=not args.fp32) + half=not args.fp32, + gpu_id=args.gpu_id) if args.face_enhance: # Use GFPGAN for face enhancement from gfpgan import GFPGANer diff --git a/realesrgan/utils.py b/realesrgan/utils.py index 922779c86..24d5d9d98 100644 --- a/realesrgan/utils.py +++ b/realesrgan/utils.py @@ -26,7 +26,16 @@ class RealESRGANer(): half (float): Whether to use half precision during inference. Default: False. """ - def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None): + def __init__(self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): self.scale = scale self.tile_size = tile self.tile_pad = tile_pad @@ -35,7 +44,11 @@ def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=1 self.half = half # initialize model - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + if gpu_id: + self.device = torch.device( + f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device # if the model_path starts with https, it will first download models to the folder: realesrgan/weights if model_path.startswith('https://'): model_path = load_file_from_url(