1
0
mirror of https://github.com/TencentARC/GFPGAN.git synced 2025-05-15 14:50:11 -07:00

update replicate (#248)

* update util

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* merge replicate update
This commit is contained in:
Xintao 2022-09-04 20:12:31 +08:00 committed by GitHub
parent 3e27784b1b
commit 7272e45887
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 21 deletions

View File

@ -80,7 +80,8 @@ class GFPGANer():
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True,
device=self.device)
device=self.device,
model_rootpath='gfpgan/weights')
if model_path.startswith('https://'):
model_path = load_file_from_url(

View File

@ -1,5 +1,9 @@
# flake8: noqa
# This file is used for deploying replicate models
# running: cog predict -i img=@inputs/whole_imgs/10045.png -i version='v1.4' -i scale=2
# push: cog push r8.im/tencentarc/gfpgan
# push (backup): cog push r8.im/xinntao/gfpgan
import os
os.system('python setup.py develop')
@ -10,6 +14,7 @@ import shutil
import tempfile
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from pathlib import Path
from gfpgan import GFPGANer
@ -24,33 +29,46 @@ class Predictor(BasePredictor):
def setup(self):
# download weights
if not os.path.exists('realesr-general-x4v3.pth'):
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .')
if not os.path.exists('GFPGANv1.2.pth'):
os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .')
if not os.path.exists('GFPGANv1.3.pth'):
os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .')
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./gfpgan/weights'
)
if not os.path.exists('gfpgan/weights/GFPGANv1.2.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/GFPGANv1.3.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/GFPGANv1.4.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./gfpgan/weights')
# background enhancer with RealESRGAN
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'realesr-general-x4v3.pth'
model_path = 'gfpgan/weights/realesr-general-x4v3.pth'
half = True if torch.cuda.is_available() else False
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
# Use GFPGAN for face enhancement
self.face_enhancer_v3 = GFPGANer(
model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
self.face_enhancer_v2 = GFPGANer(
model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)
os.makedirs('output', exist_ok=True)
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.4.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.4'
def predict(
self,
img: Path = Input(description='Input'),
version: str = Input(description='GFPGAN version', choices=['v1.2', 'v1.3'], default='v1.3'),
version: str = Input(
description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.',
choices=['v1.2', 'v1.3', 'v1.4'],
default='v1.4'),
scale: float = Input(description='Rescaling factor', default=2)
) -> Path:
print(img, version, scale)
try:
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
@ -62,12 +80,35 @@ class Predictor(BasePredictor):
if h < 300:
img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
if version == 'v1.2':
face_enhancer = self.face_enhancer_v2
else:
face_enhancer = self.face_enhancer_v3
if self.current_version != version:
if version == 'v1.2':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.2.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.2'
elif version == 'v1.3':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.3.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.3'
elif version == 'v1.4':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.4.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.4'
try:
_, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
_, _, output = self.face_enhancer.enhance(
img, has_aligned=False, only_center_face=False, paste_back=True)
except RuntimeError as error:
print('Error', error)
else:
@ -86,7 +127,7 @@ class Predictor(BasePredictor):
extension = 'jpg'
save_path = f'output/out.{extension}'
cv2.imwrite(save_path, output)
out_path = os.path.join(tempfile.mkdtemp(), 'output.png')
out_path = Path(tempfile.mkdtemp()) / 'output.png'
cv2.imwrite(str(out_path), output)
except Exception as error:
print('global exception', error)