1
0
mirror of https://github.com/TencentARC/GFPGAN.git synced 2025-05-15 23:00:12 -07:00

update cog predict

This commit is contained in:
Xintao 2022-09-12 23:24:08 +08:00
parent d226e86f6c
commit 3fd33abc47

View File

@ -42,6 +42,13 @@ class Predictor(BasePredictor):
if not os.path.exists('gfpgan/weights/GFPGANv1.4.pth'): if not os.path.exists('gfpgan/weights/GFPGANv1.4.pth'):
os.system( os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./gfpgan/weights') 'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/RestoreFormer.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P ./gfpgan/weights'
)
if not os.path.exists('gfpgan/weights/CodeFormer.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P ./gfpgan/weights')
# background enhancer with RealESRGAN # background enhancer with RealESRGAN
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
@ -64,11 +71,18 @@ class Predictor(BasePredictor):
img: Path = Input(description='Input'), img: Path = Input(description='Input'),
version: str = Input( version: str = Input(
description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.', description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.',
choices=['v1.2', 'v1.3', 'v1.4'], choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'],
default='v1.4'), default='v1.4'),
scale: float = Input(description='Rescaling factor', default=2) scale: float = Input(description='Rescaling factor', default=2),
weight: float = Input(
description='Weight, only for CodeFormer. 0 for better quality, 1 for better identity',
default=0.5,
ge=0,
le=1.0)
) -> Path: ) -> Path:
print(img, version, scale) if not isinstance(weight, (int, float)):
weight = 0.5
print(img, version, scale, weight)
try: try:
extension = os.path.splitext(os.path.basename(str(img)))[1] extension = os.path.splitext(os.path.basename(str(img)))[1]
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED) img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
@ -109,14 +123,26 @@ class Predictor(BasePredictor):
channel_multiplier=2, channel_multiplier=2,
bg_upsampler=self.upsampler) bg_upsampler=self.upsampler)
self.current_version = 'v1.4' self.current_version = 'v1.4'
elif version == 'RestoreFormer':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/RestoreFormer.pth',
upscale=2,
arch='RestoreFormer',
channel_multiplier=2,
bg_upsampler=self.upsampler)
elif version == 'CodeFormer':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/CodeFormer.pth',
upscale=2,
arch='CodeFormer',
channel_multiplier=2,
bg_upsampler=self.upsampler)
try: try:
_, _, output = self.face_enhancer.enhance( _, _, output = self.face_enhancer.enhance(
img, has_aligned=False, only_center_face=False, paste_back=True) img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
except RuntimeError as error: except RuntimeError as error:
print('Error', error) print('Error', error)
else:
extension = 'png'
try: try:
if scale != 2: if scale != 2: