From 3fd33abc47ba796c4ba9e37dbffa1d848fc10c03 Mon Sep 17 00:00:00 2001 From: Xintao Date: Mon, 12 Sep 2022 23:24:08 +0800 Subject: [PATCH] update cog predict --- cog_predict.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/cog_predict.py b/cog_predict.py index addfd8da..ff37ae3d 100644 --- a/cog_predict.py +++ b/cog_predict.py @@ -42,6 +42,13 @@ def setup(self): if not os.path.exists('gfpgan/weights/GFPGANv1.4.pth'): os.system( 'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./gfpgan/weights') + if not os.path.exists('gfpgan/weights/RestoreFormer.pth'): + os.system( + 'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P ./gfpgan/weights' + ) + if not os.path.exists('gfpgan/weights/CodeFormer.pth'): + os.system( + 'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P ./gfpgan/weights') # background enhancer with RealESRGAN model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') @@ -64,11 +71,18 @@ def predict( img: Path = Input(description='Input'), version: str = Input( description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.', - choices=['v1.2', 'v1.3', 'v1.4'], + choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'], default='v1.4'), - scale: float = Input(description='Rescaling factor', default=2) + scale: float = Input(description='Rescaling factor', default=2), + weight: float = Input( + description='Weight, only for CodeFormer. 0 for better quality, 1 for better identity', + default=0.5, + ge=0, + le=1.0) ) -> Path: - print(img, version, scale) + if not isinstance(weight, (int, float)): + weight = 0.5 + print(img, version, scale, weight) try: extension = os.path.splitext(os.path.basename(str(img)))[1] img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED) @@ -109,14 +123,26 @@ def predict( channel_multiplier=2, bg_upsampler=self.upsampler) self.current_version = 'v1.4' + elif version == 'RestoreFormer': + self.face_enhancer = GFPGANer( + model_path='gfpgan/weights/RestoreFormer.pth', + upscale=2, + arch='RestoreFormer', + channel_multiplier=2, + bg_upsampler=self.upsampler) + elif version == 'CodeFormer': + self.face_enhancer = GFPGANer( + model_path='gfpgan/weights/CodeFormer.pth', + upscale=2, + arch='CodeFormer', + channel_multiplier=2, + bg_upsampler=self.upsampler) try: _, _, output = self.face_enhancer.enhance( - img, has_aligned=False, only_center_face=False, paste_back=True) + img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight) except RuntimeError as error: print('Error', error) - else: - extension = 'png' try: if scale != 2: