mirror of
https://github.com/TencentARC/GFPGAN.git
synced 2025-05-20 09:10:20 -07:00
update cog predict
This commit is contained in:
parent
af7569775d
commit
8d2447a2d9
@ -27,6 +27,7 @@ except Exception:
|
|||||||
class Predictor(BasePredictor):
|
class Predictor(BasePredictor):
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
|
os.makedirs('output', exist_ok=True)
|
||||||
# download weights
|
# download weights
|
||||||
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
|
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
|
||||||
os.system(
|
os.system(
|
||||||
@ -69,9 +70,13 @@ class Predictor(BasePredictor):
|
|||||||
) -> Path:
|
) -> Path:
|
||||||
print(img, version, scale)
|
print(img, version, scale)
|
||||||
try:
|
try:
|
||||||
|
extension = os.path.splitext(os.path.basename(str(img)))[1]
|
||||||
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
|
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
|
||||||
if len(img.shape) == 3 and img.shape[2] == 4:
|
if len(img.shape) == 3 and img.shape[2] == 4:
|
||||||
img_mode = 'RGBA'
|
img_mode = 'RGBA'
|
||||||
|
elif len(img.shape) == 2:
|
||||||
|
img_mode = None
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
else:
|
else:
|
||||||
img_mode = None
|
img_mode = None
|
||||||
|
|
||||||
@ -120,16 +125,15 @@ class Predictor(BasePredictor):
|
|||||||
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
|
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print('wrong scale input.', error)
|
print('wrong scale input.', error)
|
||||||
|
|
||||||
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||||
extension = 'png'
|
extension = 'png'
|
||||||
else:
|
# save_path = f'output/out.{extension}'
|
||||||
extension = 'jpg'
|
# cv2.imwrite(save_path, output)
|
||||||
save_path = f'output/out.{extension}'
|
out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
|
||||||
cv2.imwrite(save_path, output)
|
|
||||||
out_path = Path(tempfile.mkdtemp()) / 'output.png'
|
|
||||||
cv2.imwrite(str(out_path), output)
|
cv2.imwrite(str(out_path), output)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print('global exception', error)
|
print('global exception: ', error)
|
||||||
finally:
|
finally:
|
||||||
clean_folder('output')
|
clean_folder('output')
|
||||||
return out_path
|
return out_path
|
||||||
|
Loading…
x
Reference in New Issue
Block a user