From 8d2447a2d918f8eba5a4a01463fd48e45126a379 Mon Sep 17 00:00:00 2001 From: Xintao Date: Sun, 4 Sep 2022 23:27:02 +0800 Subject: [PATCH] update cog predict --- cog_predict.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cog_predict.py b/cog_predict.py index 94ac227..addfd8d 100644 --- a/cog_predict.py +++ b/cog_predict.py @@ -27,6 +27,7 @@ except Exception: class Predictor(BasePredictor): def setup(self): + os.makedirs('output', exist_ok=True) # download weights if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'): os.system( @@ -69,9 +70,13 @@ class Predictor(BasePredictor): ) -> Path: print(img, version, scale) try: + extension = os.path.splitext(os.path.basename(str(img)))[1] img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED) if len(img.shape) == 3 and img.shape[2] == 4: img_mode = 'RGBA' + elif len(img.shape) == 2: + img_mode = None + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) else: img_mode = None @@ -120,16 +125,15 @@ class Predictor(BasePredictor): output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation) except Exception as error: print('wrong scale input.', error) + if img_mode == 'RGBA': # RGBA images should be saved in png format extension = 'png' - else: - extension = 'jpg' - save_path = f'output/out.{extension}' - cv2.imwrite(save_path, output) - out_path = Path(tempfile.mkdtemp()) / 'output.png' + # save_path = f'output/out.{extension}' + # cv2.imwrite(save_path, output) + out_path = Path(tempfile.mkdtemp()) / f'out.{extension}' cv2.imwrite(str(out_path), output) except Exception as error: - print('global exception', error) + print('global exception: ', error) finally: clean_folder('output') return out_path