1
0
mirror of https://github.com/TencentARC/GFPGAN.git synced 2025-05-15 23:00:12 -07:00
GFPGAN/predict.py
Chenxi a9a2e3ae15
Add Docker environment & web demo (#67)
* enable cog

* Update README.md

* Update README.md

* refactor

* fix temp input dir bug

Co-authored-by: CJWBW <70536672+CJWBW@users.noreply.github.com>
Co-authored-by: Chenxi <chenxi@Chenxis-MacBook-Pro-2.local>
Co-authored-by: Xintao <wxt1994@126.com>
2022-08-29 17:28:16 +08:00

145 lines
5.3 KiB
Python

import subprocess
# Run the repository's setup script at import time, *before* the third-party
# imports below are resolved. Presumably the script installs/builds what
# those imports need -- confirm against run_setup.sh.
# NOTE(review): the exit status returned by subprocess.call is ignored, so a
# failed setup only surfaces later as an import error.
subprocess.call(["sh", "./run_setup.sh"])
import warnings
import tempfile
import os
from pathlib import Path
import argparse
import glob
import shutil
from basicsr.utils import imwrite
import torch
import cv2
import cog
from realesrgan import RealESRGANer
from gfpgan import GFPGANer
class Predictor(cog.Predictor):
    """Cog predictor that restores the faces of one input image with GFPGAN."""

    def setup(self):
        """Build the fixed run configuration and load the models once.

        The options are declared through ``argparse`` but parsed from a
        hard-coded argument list, so ``self.args`` never depends on the real
        command line. Side effects: creates the temporary input folder
        (``self.args.test_path``) and stores the restorer in
        ``self.restorer``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--upscale", type=int, default=2)
        parser.add_argument("--arch", type=str, default="clean")
        parser.add_argument("--channel", type=int, default=2)
        parser.add_argument(
            "--model_path",
            type=str,
            default="experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth",
        )
        parser.add_argument("--bg_upsampler", type=str, default="realesrgan")
        parser.add_argument("--bg_tile", type=int, default=400)
        parser.add_argument("--test_path", type=str, default="inputs/whole_imgs")
        parser.add_argument(
            "--suffix", type=str, default=None, help="Suffix of the restored faces"
        )
        parser.add_argument("--only_center_face", action="store_true")
        parser.add_argument("--aligned", action="store_true")
        # store_false: paste_back defaults to True; passing the flag disables it.
        parser.add_argument("--paste_back", action="store_false")
        parser.add_argument("--save_root", type=str, default="results")
        self.args = parser.parse_args(
            ["--upscale", "2", "--test_path", "cog_temp", "--save_root", "results"]
        )
        os.makedirs(self.args.test_path, exist_ok=True)

        # background upsampler
        if self.args.bg_upsampler == "realesrgan":
            if not torch.cuda.is_available():  # CPU
                warnings.warn(
                    "The unoptimized RealESRGAN is very slow on CPU. We do not use it. "
                    "If you really want to use it, please modify the corresponding codes."
                )
                bg_upsampler = None
            else:
                bg_upsampler = RealESRGANer(
                    scale=2,
                    model_path="https://github.com/xinntao/Real-ESRGAN/releases"
                    "/download/v0.2.1/RealESRGAN_x2plus.pth",
                    tile=self.args.bg_tile,
                    tile_pad=10,
                    pre_pad=0,
                    half=True,
                )  # need to set False in CPU mode
        else:
            bg_upsampler = None

        # set up GFPGAN restorer
        self.restorer = GFPGANer(
            model_path=self.args.model_path,
            upscale=self.args.upscale,
            arch=self.args.arch,
            channel_multiplier=self.args.channel,
            bg_upsampler=bg_upsampler,
        )

    @cog.input("image", type=Path, help="input image")
    def predict(self, image):
        """Restore ``image`` and return the path of the restored picture.

        The upload is copied into the temporary input folder, run through
        ``self.restorer.enhance`` and the restored image is written to
        ``output.png`` inside a fresh temporary directory. Every cropped /
        restored face pair is additionally saved below
        ``self.args.save_root``. The temporary input folder is emptied
        afterwards, whether or not the restoration succeeded.

        Args:
            image: Path of the input image.

        Returns:
            ``pathlib.Path`` of the written ``output.png``.

        Raises:
            ValueError: If OpenCV cannot decode the input file.
        """
        try:
            input_dir = self.args.test_path
            input_path = os.path.join(input_dir, os.path.basename(image))
            shutil.copy(str(image), input_path)
            os.makedirs(self.args.save_root, exist_ok=True)

            # Only the file that was just copied is processed. Globbing the
            # whole folder would also pick up stale files, and all results
            # share a single output path anyway.
            img_name = os.path.basename(input_path)
            print(f"Processing {img_name} ...")
            basename, _ = os.path.splitext(img_name)

            # read image
            input_img = cv2.imread(input_path, cv2.IMREAD_COLOR)
            if input_img is None:
                # cv2.imread signals an unreadable file by returning None;
                # fail here with a clear message instead of inside enhance().
                raise ValueError(f"Could not read {img_name} as an image.")

            cropped_faces, restored_faces, restored_img = self.restorer.enhance(
                input_img,
                has_aligned=self.args.aligned,
                only_center_face=self.args.only_center_face,
                paste_back=self.args.paste_back,
            )

            # Create the output directory only after inference succeeded so a
            # failed run does not leave an empty temp directory behind.
            out_path = Path(tempfile.mkdtemp()) / "output.png"
            imwrite(restored_img, str(out_path))

            # save faces
            for idx, (cropped_face, restored_face) in enumerate(
                zip(cropped_faces, restored_faces)
            ):
                # save cropped face
                save_crop_path = os.path.join(
                    self.args.save_root, "cropped_faces", f"{basename}_{idx:02d}.png"
                )
                imwrite(cropped_face, save_crop_path)

                # save restored face
                if self.args.suffix is not None:
                    save_face_name = f"{basename}_{idx:02d}_{self.args.suffix}.png"
                else:
                    save_face_name = f"{basename}_{idx:02d}.png"
                save_restore_path = os.path.join(
                    self.args.save_root, "restored_faces", save_face_name
                )
                imwrite(restored_face, save_restore_path)
        finally:
            # Single cleanup point: always drop the copied input.
            clean_folder(self.args.test_path)
        return out_path
def clean_folder(folder):
    """Delete everything inside ``folder`` while keeping the folder itself.

    Regular files and symbolic links are unlinked; sub-directories are
    removed recursively. Deletion is best effort: a failure on one entry is
    printed and the remaining entries are still processed.

    Args:
        folder: Path of the directory to empty. A directory that does not
            exist is treated as already clean.
    """
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        # Nothing to clean. Returning quietly matters when this runs from a
        # ``finally`` block, where raising would mask the original error.
        return

    for filename in entries:
        file_path = os.path.join(folder, filename)
        try:
            # Links are tested before directories so that a symlink pointing
            # at a directory is unlinked rather than followed.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:  # deliberately broad: cleanup is best effort
            print(f"Failed to delete {file_path}. Reason: {e}")