diff --git a/.github/workflows/publish-pip.yml b/.github/workflows/publish-pip.yml new file mode 100644 index 0000000..06047f7 --- /dev/null +++ b/.github/workflows/publish-pip.yml @@ -0,0 +1,30 @@ +name: PyPI Publish + +on: push + +jobs: + build-n-publish: + runs-on: ubuntu-latest + if: startsWith(github.event.ref, 'refs/tags') + + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.8 + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Upgrade pip + run: pip install pip --upgrade + - name: Install PyTorch (cpu) + run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install dependencies + run: pip install -r requirements.txt + - name: Build and install + run: rm -rf .eggs && pip install -e . + - name: Build for distribution + # remove bdist_wheel for pip installation with compiling cuda extensions + run: python setup.py sdist + - name: Publish distribution to PyPI + uses: pypa/gh-action-pypi-publish@master + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index baa370a..c408d46 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -25,5 +25,5 @@ jobs: - name: Lint run: | flake8 . - isort --check-only --diff data/ archs/ models/ train.py inference_gfpgan_full.py - yapf -r -d data/ archs/ models/ train.py inference_gfpgan_full.py + isort --check-only --diff gfpgan/ scripts/ inference_gfpgan.py setup.py + yapf -r -d gfpgan/ scripts/ inference_gfpgan.py setup.py diff --git a/.gitignore b/.gitignore index cfe4bf1..0200900 100644 --- a/.gitignore +++ b/.gitignore @@ -1,21 +1,13 @@ -.vscode +# ignored folders datasets/* experiments/* +results/* tb_logger/* +wandb/* +tmp/* -# ignored files version.py - -# ignored files with suffix -*.html -*.png -*.jpeg -*.jpg -*.gif -*.pth -*.zip - -# template +.vscode # Byte-compiled / optimized / DLL files __pycache__/ @@ -39,6 +31,8 @@ parts/ sdist/ var/ wheels/ +pip-wheel-metadata/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg @@ -57,12 +51,14 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover +*.py,cover .hypothesis/ .pytest_cache/ @@ -74,6 +70,7 @@ coverage.xml *.log local_settings.py db.sqlite3 +db.sqlite3-journal # Flask stuff: instance/ @@ -91,11 +88,26 @@ target/ # Jupyter Notebook .ipynb_checkpoints +# IPython +profile_default/ +ipython_config.py + # pyenv .python-version -# celery beat schedule file +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff celerybeat-schedule +celerybeat.pid # SageMath parsed files *.sage.py @@ -121,3 +133,8 @@ venv.bak/ # mypy .mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..bcaa717 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,8 @@ +include assets/* +include inputs/* +include scripts/*.py +include inference_gfpgan.py +include VERSION +include LICENSE +include requirements.txt +include gfpgan/weights/README.md diff --git a/PaperModel.md b/PaperModel.md index 7131a79..aec81d3 100644 --- a/PaperModel.md +++ b/PaperModel.md @@ -27,6 +27,7 @@ If you want want to use the original model in our paper, please follow the instr pip install facexlib pip install -r requirements.txt + python setup.py develop # remember to set BASICSR_JIT=True before your running commands ``` @@ -45,6 +46,7 @@ If you want want to use the original model in our paper, please follow the instr pip install facexlib pip install -r requirements.txt + python setup.py develop ``` ## :zap: Quick Inference @@ -58,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth - Option 1: Load extensions just-in-time(JIT) ```bash - BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 + BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 # for aligned images - BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned + BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned ``` - Option 2: Have successfully compiled extensions during installation ```bash - python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 + python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 # for aligned images - python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned + python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned ``` diff --git a/README.md b/README.md index dca1557..7d29f06 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,11 @@ # GFPGAN (CVPR 2021) [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases) +[![PyPI](https://img.shields.io/pypi/v/gfpgan)](https://pypi.org/project/gfpgan/) [![Open issue](https://isitmaintained.com/badge/open/TencentARC/GFPGAN.svg)](https://github.com/TencentARC/GFPGAN/issues) [![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE) [![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml) +[![Publish-pip](https://github.com/TencentARC/GFPGAN/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml) 1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN google colab logo; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model) 1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**. @@ -59,6 +61,7 @@ If you want want to use the original model in our paper, please see [PaperModel. pip install facexlib pip install -r requirements.txt + python setup.py develop ``` ## :zap: Quick Inference @@ -72,7 +75,7 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1 **Inference!** ```bash -python inference_gfpgan_full.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results +python inference_gfpgan.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results ``` ## :european_castle: Model Zoo @@ -90,10 +93,9 @@ You could improve it according to your own needs. 1. More high quality faces can improve the restoration quality. 2. You may need to perform some pre-processing, such as beauty makeup. - **Procedures** -(You can try a simple version ( `train_gfpgan_v1_simple.yml`) that does not require face component landmarks.) +(You can try a simple version ( `options/train_gfpgan_v1_simple.yml`) that does not require face component landmarks.) 1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset) @@ -102,11 +104,11 @@ You could improve it according to your own needs. 1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth) 1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth) -1. Modify the configuration file `train_gfpgan_v1.yml` accordingly. +1. Modify the configuration file `options/train_gfpgan_v1.yml` accordingly. 1. Training -> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 train.py -opt train_gfpgan_v1.yml --launcher pytorch +> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 gfpgan/train.py -opt options/train_gfpgan_v1.yml --launcher pytorch ## :scroll: License and Acknowledgement diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..0c62199 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.2.1 diff --git a/gfpgan/__init__.py b/gfpgan/__init__.py new file mode 100644 index 0000000..4ccac57 --- /dev/null +++ b/gfpgan/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .utils import * +from .version import __gitsha__, __version__ diff --git a/archs/__init__.py b/gfpgan/archs/__init__.py similarity index 63% rename from archs/__init__.py rename to gfpgan/archs/__init__.py index b7a5b1c..bec5f17 100644 --- a/archs/__init__.py +++ b/gfpgan/archs/__init__.py @@ -1,12 +1,10 @@ import importlib +from basicsr.utils import scandir from os import path as osp -from basicsr.utils import scandir - # automatically scan and import arch modules for registry -# scan all the files under the 'archs' folder and collect files ending with -# '_arch.py' +# scan all the files that end with '_arch.py' under the archs folder arch_folder = osp.dirname(osp.abspath(__file__)) arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] # import all the arch modules -_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] +_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames] diff --git a/archs/arcface_arch.py b/gfpgan/archs/arcface_arch.py similarity index 99% rename from archs/arcface_arch.py rename to gfpgan/archs/arcface_arch.py index 9411fcf..2623621 100644 --- a/archs/arcface_arch.py +++ b/gfpgan/archs/arcface_arch.py @@ -1,5 +1,4 @@ import torch.nn as nn - from basicsr.utils.registry import ARCH_REGISTRY diff --git a/archs/gfpganv1_arch.py b/gfpgan/archs/gfpganv1_arch.py similarity index 99% rename from archs/gfpganv1_arch.py rename to gfpgan/archs/gfpganv1_arch.py index 38a450f..98966ab 100644 --- a/archs/gfpganv1_arch.py +++ b/gfpgan/archs/gfpganv1_arch.py @@ -1,13 +1,12 @@ import math import random import torch -from torch import nn -from torch.nn import functional as F - from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, StyleGAN2Generator) from basicsr.ops.fused_act import FusedLeakyReLU from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F class StyleGAN2GeneratorSFT(StyleGAN2Generator): diff --git a/archs/gfpganv1_clean_arch.py b/gfpgan/archs/gfpganv1_clean_arch.py similarity index 100% rename from archs/gfpganv1_clean_arch.py rename to gfpgan/archs/gfpganv1_clean_arch.py diff --git a/archs/stylegan2_clean_arch.py b/gfpgan/archs/stylegan2_clean_arch.py similarity index 99% rename from archs/stylegan2_clean_arch.py rename to gfpgan/archs/stylegan2_clean_arch.py index 73ab854..4afff18 100644 --- a/archs/stylegan2_clean_arch.py +++ b/gfpgan/archs/stylegan2_clean_arch.py @@ -1,11 +1,10 @@ import math import random import torch -from torch import nn -from torch.nn import functional as F - from basicsr.archs.arch_util import default_init_weights from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F class NormStyleCode(nn.Module): diff --git a/data/__init__.py b/gfpgan/data/__init__.py similarity index 65% rename from data/__init__.py rename to gfpgan/data/__init__.py index ccb0981..69fd9f9 100644 --- a/data/__init__.py +++ b/gfpgan/data/__init__.py @@ -1,11 +1,10 @@ import importlib +from basicsr.utils import scandir from os import path as osp -from basicsr.utils import scandir - # automatically scan and import dataset modules for registry -# scan all the files under the data folder with '_dataset' in file names +# scan all the files that end with '_dataset.py' under the data folder data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] # import all the dataset modules -_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] +_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames] diff --git a/data/ffhq_degradation_dataset.py b/gfpgan/data/ffhq_degradation_dataset.py similarity index 99% rename from data/ffhq_degradation_dataset.py rename to gfpgan/data/ffhq_degradation_dataset.py index da66ff7..db22665 100644 --- a/data/ffhq_degradation_dataset.py +++ b/gfpgan/data/ffhq_degradation_dataset.py @@ -4,14 +4,13 @@ import numpy as np import os.path as osp import torch import torch.utils.data as data -from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, - normalize) - from basicsr.data import degradations as degradations from basicsr.data.data_util import paths_from_folder from basicsr.data.transforms import augment from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor from basicsr.utils.registry import DATASET_REGISTRY +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, + normalize) @DATASET_REGISTRY.register() diff --git a/models/__init__.py b/gfpgan/models/__init__.py similarity index 63% rename from models/__init__.py rename to gfpgan/models/__init__.py index 904bd99..6afad57 100644 --- a/models/__init__.py +++ b/gfpgan/models/__init__.py @@ -1,12 +1,10 @@ import importlib +from basicsr.utils import scandir from os import path as osp -from basicsr.utils import scandir - # automatically scan and import model modules for registry -# scan all the files under the 'models' folder and collect files ending with -# '_model.py' +# scan all the files that end with '_model.py' under the model folder model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] # import all the model modules -_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] +_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames] diff --git a/models/gfpgan_model.py b/gfpgan/models/gfpgan_model.py similarity index 99% rename from models/gfpgan_model.py rename to gfpgan/models/gfpgan_model.py index 4716be4..378f2a9 100644 --- a/models/gfpgan_model.py +++ b/gfpgan/models/gfpgan_model.py @@ -1,11 +1,6 @@ import math import os.path as osp import torch -from collections import OrderedDict -from torch.nn import functional as F -from torchvision.ops import roi_align -from tqdm import tqdm - from basicsr.archs import build_network from basicsr.losses import build_loss from basicsr.losses.losses import r1_penalty @@ -13,6 +8,10 @@ from basicsr.metrics import calculate_metric from basicsr.models.base_model import BaseModel from basicsr.utils import get_root_logger, imwrite, tensor2img from basicsr.utils.registry import MODEL_REGISTRY +from collections import OrderedDict +from torch.nn import functional as F +from torchvision.ops import roi_align +from tqdm import tqdm @MODEL_REGISTRY.register() diff --git a/gfpgan/train.py b/gfpgan/train.py new file mode 100644 index 0000000..fe5f1f9 --- /dev/null +++ b/gfpgan/train.py @@ -0,0 +1,11 @@ +# flake8: noqa +import os.path as osp +from basicsr.train import train_pipeline + +import gfpgan.archs +import gfpgan.data +import gfpgan.models + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/gfpgan/utils.py b/gfpgan/utils.py new file mode 100644 index 0000000..02b0d40 --- /dev/null +++ b/gfpgan/utils.py @@ -0,0 +1,134 @@ +import cv2 +import os +import torch +from basicsr.utils import img2tensor, tensor2img +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torch.hub import download_url_to_file, get_dir +from torchvision.transforms.functional import normalize +from urllib.parse import urlparse + +from gfpgan.archs.gfpganv1_arch import GFPGANv1 +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class GFPGANer(): + + def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None): + self.upscale = upscale + self.bg_upsampler = bg_upsampler + + # initialize model + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # initialize the GFP-GAN + if arch == 'clean': + self.gfpgan = GFPGANv1Clean( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + else: + self.gfpgan = GFPGANv1( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=True, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + # initialize face helper + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + device=self.device) + + if model_path.startswith('https://'): + model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None) + loadnet = torch.load(model_path) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + self.gfpgan.load_state_dict(loadnet[keyname], strict=True) + self.gfpgan.eval() + self.gfpgan = self.gfpgan.to(self.device) + + @torch.no_grad() + def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True): + self.face_helper.clean_all() + + if has_aligned: + img = cv2.resize(img, (512, 512)) + self.face_helper.cropped_faces = [img] + else: + self.face_helper.read_image(img) + # get face landmarks for each face + self.face_helper.get_face_landmarks_5(only_center_face=only_center_face) + # align and warp each face + self.face_helper.align_warp_face() + + # face restoration + for cropped_face in self.face_helper.cropped_faces: + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) + + try: + output = self.gfpgan(cropped_face_t, return_rgb=False)[0] + # convert to image + restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) + except RuntimeError as error: + print(f'\tFailed inference for GFPGAN: {error}.') + restored_face = cropped_face + + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) + + if not has_aligned and paste_back: + + if self.bg_upsampler is not None: + # Now only support RealESRGAN + bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] + else: + bg_img = None + + self.face_helper.get_inverse_affine(None) + # paste each restored face to the input image + restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) + return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img + else: + return self.face_helper.cropped_faces, self.face_helper.restored_faces, None + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """ + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file diff --git a/gfpgan/weights/README.md b/gfpgan/weights/README.md new file mode 100644 index 0000000..4d7b7e6 --- /dev/null +++ b/gfpgan/weights/README.md @@ -0,0 +1,3 @@ +# Weights + +Put the downloaded weights to this folder. diff --git a/inference_gfpgan.py b/inference_gfpgan.py new file mode 100644 index 0000000..a646c48 --- /dev/null +++ b/inference_gfpgan.py @@ -0,0 +1,96 @@ +import argparse +import cv2 +import glob +import numpy as np +import os +import torch +from basicsr.utils import imwrite + +from gfpgan import GFPGANer + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--upscale', type=int, default=2) + parser.add_argument('--arch', type=str, default='clean') + parser.add_argument('--channel', type=int, default=2) + parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth') + parser.add_argument('--bg_upsampler', type=str, default='realesrgan') + parser.add_argument('--bg_tile', type=int, default=0) + parser.add_argument('--test_path', type=str, default='inputs/whole_imgs') + parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') + parser.add_argument('--only_center_face', action='store_true') + parser.add_argument('--aligned', action='store_true') + parser.add_argument('--paste_back', action='store_false') + parser.add_argument('--save_root', type=str, default='results') + + args = parser.parse_args() + if args.test_path.endswith('/'): + args.test_path = args.test_path[:-1] + os.makedirs(args.save_root, exist_ok=True) + + # background upsampler + if args.bg_upsampler == 'realesrgan': + if not torch.cuda.is_available(): # CPU + import warnings + warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. ' + 'If you really want to use it, please modify the corresponding codes.') + bg_upsampler = None + else: + from realesrgan import RealESRGANer + bg_upsampler = RealESRGANer( + scale=2, + model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + tile=args.bg_tile, + tile_pad=10, + pre_pad=0, + half=True) # need to set False in CPU mode + else: + bg_upsampler = None + # set up GFPGAN restorer + restorer = GFPGANer( + model_path=args.model_path, + upscale=args.upscale, + arch=args.arch, + channel_multiplier=args.channel, + bg_upsampler=bg_upsampler) + + img_list = sorted(glob.glob(os.path.join(args.test_path, '*'))) + for img_path in img_list: + # read image + img_name = os.path.basename(img_path) + print(f'Processing {img_name} ...') + basename, ext = os.path.splitext(img_name) + input_img = cv2.imread(img_path, cv2.IMREAD_COLOR) + + cropped_faces, restored_faces, restored_img = restorer.enhance( + input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back) + + # save faces + for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)): + # save cropped face + save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png') + imwrite(restored_face, save_crop_path) + # save restored face + if args.suffix is not None: + save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png' + else: + save_face_name = f'{basename}_{idx:02d}.png' + save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name) + imwrite(restored_face, save_restore_path) + # save cmp image + cmp_img = np.concatenate((cropped_face, restored_face), axis=1) + imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png')) + # save restored img + if args.suffix is not None: + save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}_{args.suffix}{ext}') + else: + save_restore_path = os.path.join(args.save_root, 'restored_imgs', img_name) + imwrite(restored_img, save_restore_path) + + print(f'Results are in the [{args.save_root}] folder.') + + +if __name__ == '__main__': + main() diff --git a/inference_gfpgan_full.py b/inference_gfpgan_full.py deleted file mode 100644 index 07b3d5a..0000000 --- a/inference_gfpgan_full.py +++ /dev/null @@ -1,153 +0,0 @@ -import argparse -import cv2 -import glob -import numpy as np -import os -import torch -from facexlib.utils.face_restoration_helper import FaceRestoreHelper -from torchvision.transforms.functional import normalize - -from archs.gfpganv1_arch import GFPGANv1 -from archs.gfpganv1_clean_arch import GFPGANv1Clean -from basicsr.utils import img2tensor, imwrite, tensor2img - - -def restoration(gfpgan, - face_helper, - img_path, - save_root, - has_aligned=False, - only_center_face=True, - suffix=None, - paste_back=False, - device='cuda'): - # read image - img_name = os.path.basename(img_path) - print(f'Processing {img_name} ...') - basename, _ = os.path.splitext(img_name) - input_img = cv2.imread(img_path, cv2.IMREAD_COLOR) - face_helper.clean_all() - - if has_aligned: - input_img = cv2.resize(input_img, (512, 512)) - face_helper.cropped_faces = [input_img] - else: - face_helper.read_image(input_img) - # get face landmarks for each face - face_helper.get_face_landmarks_5(only_center_face=only_center_face) - # align and warp each face - save_crop_path = os.path.join(save_root, 'cropped_faces', img_name) - face_helper.align_warp_face(save_crop_path) - - # face restoration - for idx, cropped_face in enumerate(face_helper.cropped_faces): - # prepare data - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(device) - - try: - with torch.no_grad(): - output = gfpgan(cropped_face_t, return_rgb=False)[0] - # convert to image - restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) - except RuntimeError as error: - print(f'\tFailed inference for GFPGAN: {error}.') - restored_face = cropped_face - - restored_face = restored_face.astype('uint8') - face_helper.add_restored_face(restored_face) - - if suffix is not None: - save_face_name = f'{basename}_{idx:02d}_{suffix}.png' - else: - save_face_name = f'{basename}_{idx:02d}.png' - save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name) - imwrite(restored_face, save_restore_path) - - # save cmp image - cmp_img = np.concatenate((cropped_face, restored_face), axis=1) - imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png')) - - if not has_aligned and paste_back: - face_helper.get_inverse_affine(None) - save_restore_path = os.path.join(save_root, 'restored_imgs', img_name) - # paste each restored face to the input image - face_helper.paste_faces_to_input_image(save_restore_path) - - -if __name__ == '__main__': - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - parser = argparse.ArgumentParser() - - parser.add_argument('--upscale_factor', type=int, default=2) - parser.add_argument('--arch', type=str, default='clean') - parser.add_argument('--channel', type=int, default=2) - parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth') - parser.add_argument('--test_path', type=str, default='inputs/whole_imgs') - parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') - parser.add_argument('--only_center_face', action='store_true') - parser.add_argument('--aligned', action='store_true') - parser.add_argument('--paste_back', action='store_false') - parser.add_argument('--save_root', type=str, default='results') - - args = parser.parse_args() - if args.test_path.endswith('/'): - args.test_path = args.test_path[:-1] - os.makedirs(args.save_root, exist_ok=True) - - # initialize the GFP-GAN - if args.arch == 'clean': - gfpgan = GFPGANv1Clean( - out_size=512, - num_style_feat=512, - channel_multiplier=args.channel, - decoder_load_path=None, - fix_decoder=False, - # for stylegan decoder - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - else: - gfpgan = GFPGANv1( - out_size=512, - num_style_feat=512, - channel_multiplier=args.channel, - decoder_load_path=None, - fix_decoder=True, - # for stylegan decoder - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - - gfpgan.load_state_dict(torch.load(args.model_path, map_location=lambda storage, loc: storage)['params_ema']) - gfpgan.to(device).eval() - - # initialize face helper - face_helper = FaceRestoreHelper( - args.upscale_factor, - face_size=512, - crop_ratio=(1, 1), - det_model='retinaface_resnet50', - save_ext='png', - device=device) - - img_list = sorted(glob.glob(os.path.join(args.test_path, '*'))) - for img_path in img_list: - restoration( - gfpgan, - face_helper, - img_path, - args.save_root, - has_aligned=args.aligned, - only_center_face=args.only_center_face, - suffix=args.suffix, - paste_back=args.paste_back, - device=device) - - print(f'Results are in the [{args.save_root}] folder.') diff --git a/train_gfpgan_v1.yml b/options/train_gfpgan_v1.yml similarity index 100% rename from train_gfpgan_v1.yml rename to options/train_gfpgan_v1.yml diff --git a/train_gfpgan_v1_simple.yml b/options/train_gfpgan_v1_simple.yml similarity index 100% rename from train_gfpgan_v1_simple.yml rename to options/train_gfpgan_v1_simple.yml diff --git a/scripts/parse_landmark.py b/scripts/parse_landmark.py index 7ca457e..c6ca4a5 100644 --- a/scripts/parse_landmark.py +++ b/scripts/parse_landmark.py @@ -2,9 +2,8 @@ import cv2 import json import numpy as np import torch -from collections import OrderedDict - from basicsr.utils import FileClient, imfrombytes +from collections import OrderedDict print('Load JSON metadata...') # use the json file in FFHQ dataset diff --git a/setup.cfg b/setup.cfg index 014da11..628b31c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true line_length = 120 multi_line_output = 0 known_standard_library = pkg_resources,setuptools -known_first_party = basicsr -known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm +known_first_party = gfpgan +known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2a72f8f --- /dev/null +++ b/setup.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import time + +version_file = 'gfpgan/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from facexlib.version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def get_requirements(filename='requirements.txt'): + here = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(here, filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + write_version_py() + setup( + name='gfpgan', + version=get_version(), + description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration', + long_description=readme(), + long_description_content_type='text/markdown', + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan', + url='https://github.com/TencentARC/GFPGAN', + include_package_data=True, + packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='Apache License Version 2.0', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + zip_safe=False) diff --git a/train.py b/train.py deleted file mode 100644 index b03e0f3..0000000 --- a/train.py +++ /dev/null @@ -1,10 +0,0 @@ -import os.path as osp - -import archs # noqa: F401 -import data # noqa: F401 -import models # noqa: F401 -from basicsr.train import train_pipeline - -if __name__ == '__main__': - root_path = osp.abspath(osp.join(__file__, osp.pardir)) - train_pipeline(root_path)