1
0
mirror of https://github.com/TencentARC/GFPGAN.git synced 2025-05-20 09:10:20 -07:00

update cog predict

This commit is contained in:
Xintao 2022-09-04 23:27:02 +08:00
parent af7569775d
commit 8d2447a2d9

View File

@ -27,6 +27,7 @@ except Exception:
class Predictor(BasePredictor): class Predictor(BasePredictor):
def setup(self): def setup(self):
os.makedirs('output', exist_ok=True)
# download weights # download weights
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'): if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
os.system( os.system(
@ -69,9 +70,13 @@ class Predictor(BasePredictor):
) -> Path: ) -> Path:
print(img, version, scale) print(img, version, scale)
try: try:
extension = os.path.splitext(os.path.basename(str(img)))[1]
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED) img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4: if len(img.shape) == 3 and img.shape[2] == 4:
img_mode = 'RGBA' img_mode = 'RGBA'
elif len(img.shape) == 2:
img_mode = None
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
else: else:
img_mode = None img_mode = None
@ -120,16 +125,15 @@ class Predictor(BasePredictor):
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation) output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
except Exception as error: except Exception as error:
print('wrong scale input.', error) print('wrong scale input.', error)
if img_mode == 'RGBA': # RGBA images should be saved in png format if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png' extension = 'png'
else: # save_path = f'output/out.{extension}'
extension = 'jpg' # cv2.imwrite(save_path, output)
save_path = f'output/out.{extension}' out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
cv2.imwrite(save_path, output)
out_path = Path(tempfile.mkdtemp()) / 'output.png'
cv2.imwrite(str(out_path), output) cv2.imwrite(str(out_path), output)
except Exception as error: except Exception as error:
print('global exception', error) print('global exception: ', error)
finally: finally:
clean_folder('output') clean_folder('output')
return out_path return out_path