Mirror of https://github.com/TencentARC/GFPGAN.git, synced 2025-05-15 23:00:12 -07:00

Major revision: Support Pypi (#37)

* reorganize

* update inference

* update inference

* format
Xintao 2021-08-09 01:28:10 +08:00 committed by GitHub
parent 77dc85b882
commit 996d1e3df9
28 changed files with 470 additions and 221 deletions

.github/workflows/publish-pip.yml (new file, +30)

@ -0,0 +1,30 @@
name: PyPI Publish
on: push
jobs:
  build-n-publish:
    runs-on: ubuntu-latest
    if: startsWith(github.event.ref, 'refs/tags')
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.8
        uses: actions/setup-python@v1
        with:
          python-version: 3.8
      - name: Upgrade pip
        run: pip install pip --upgrade
      - name: Install PyTorch (cpu)
        run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Build and install
        run: rm -rf .eggs && pip install -e .
      - name: Build for distribution
        # sdist only: skipping bdist_wheel lets pip install from source and compile the CUDA extensions
        run: python setup.py sdist
      - name: Publish distribution to PyPI
        uses: pypa/gh-action-pypi-publish@master
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}

.github/workflows/pylint.yml

@ -25,5 +25,5 @@ jobs:
- name: Lint
run: |
flake8 .
isort --check-only --diff data/ archs/ models/ train.py inference_gfpgan_full.py
yapf -r -d data/ archs/ models/ train.py inference_gfpgan_full.py
isort --check-only --diff gfpgan/ scripts/ inference_gfpgan.py setup.py
yapf -r -d gfpgan/ scripts/ inference_gfpgan.py setup.py

.gitignore (45 changed lines)

@ -1,21 +1,13 @@
.vscode
# ignored folders
datasets/*
experiments/*
results/*
tb_logger/*
wandb/*
tmp/*
# ignored files
version.py
# ignored files with suffix
*.html
*.png
*.jpeg
*.jpg
*.gif
*.pth
*.zip
# template
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
@ -39,6 +31,8 @@ parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
@ -57,12 +51,14 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
@ -74,6 +70,7 @@ coverage.xml
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
@ -91,11 +88,26 @@ target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
@ -121,3 +133,8 @@ venv.bak/
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

MANIFEST.in (new file, +8)

@ -0,0 +1,8 @@
include assets/*
include inputs/*
include scripts/*.py
include inference_gfpgan.py
include VERSION
include LICENSE
include requirements.txt
include gfpgan/weights/README.md

PaperModel.md

@ -27,6 +27,7 @@ If you want to use the original model in our paper, please follow the instr
pip install facexlib
pip install -r requirements.txt
python setup.py develop
# remember to set BASICSR_JIT=True before running the commands
```
@ -45,6 +46,7 @@ If you want to use the original model in our paper, please follow the instr
pip install facexlib
pip install -r requirements.txt
python setup.py develop
```
## :zap: Quick Inference
@ -58,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth
- Option 1: Load extensions just-in-time (JIT)
```bash
BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
# for aligned images
BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
```
- Option 2: Extensions have been compiled successfully during installation
```bash
python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
# for aligned images
python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
```

README.md

@ -1,9 +1,11 @@
# GFPGAN (CVPR 2021)
[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
[![PyPI](https://img.shields.io/pypi/v/gfpgan)](https://pypi.org/project/gfpgan/)
[![Open issue](https://isitmaintained.com/badge/open/TencentARC/GFPGAN.svg)](https://github.com/TencentARC/GFPGAN/issues)
[![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
[![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)
[![Publish-pip](https://github.com/TencentARC/GFPGAN/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml)
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model)
1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions, so that it works on **Windows** and in **CPU mode**.
@ -59,6 +61,7 @@ If you want to use the original model in our paper, please see [PaperModel.
pip install facexlib
pip install -r requirements.txt
python setup.py develop
```
## :zap: Quick Inference
@ -72,7 +75,7 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1
**Inference!**
```bash
python inference_gfpgan_full.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results
python inference_gfpgan.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results
```
## :european_castle: Model Zoo
@ -90,10 +93,9 @@ You could improve it according to your own needs.
1. More high-quality faces can improve the restoration quality.
2. You may need to perform some pre-processing, such as beauty makeup.
**Procedures**
(You can try a simple version (`train_gfpgan_v1_simple.yml`) that does not require face component landmarks.)
(You can try a simple version (`options/train_gfpgan_v1_simple.yml`) that does not require face component landmarks.)
1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
@ -102,11 +104,11 @@ You could improve it according to your own needs.
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
1. Modify the configuration file `train_gfpgan_v1.yml` accordingly.
1. Modify the configuration file `options/train_gfpgan_v1.yml` accordingly.
1. Training
> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 train.py -opt train_gfpgan_v1.yml --launcher pytorch
> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 gfpgan/train.py -opt options/train_gfpgan_v1.yml --launcher pytorch
## :scroll: License and Acknowledgement

VERSION (new file, +1)

@ -0,0 +1 @@
0.2.1

gfpgan/__init__.py (new file, +6)

@ -0,0 +1,6 @@
# flake8: noqa
from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__

gfpgan/archs/__init__.py

@ -1,12 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
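
This scan works because of BasicSR's registry pattern: importing a `*_arch.py` module runs its `@ARCH_REGISTRY.register()` decorator, after which the class can be built by name (e.g. from a YAML option file) without being listed anywhere by hand. A minimal sketch of such a module, using a hypothetical `gfpgan/archs/toy_arch.py`:

```python
# hypothetical gfpgan/archs/toy_arch.py -- picked up automatically by the
# scan above because its filename ends with '_arch.py'
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()  # registers 'ToyArch' so configs can refer to it by name
class ToyArch(nn.Module):

    def __init__(self, num_feat=64):
        super().__init__()
        self.body = nn.Conv2d(3, num_feat, 3, 1, 1)

    def forward(self, x):
        return self.body(x)
```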

gfpgan/archs/arcface_arch.py

@ -1,5 +1,4 @@
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY

gfpgan/archs/gfpganv1_arch.py

@ -1,13 +1,12 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2Generator)
from basicsr.ops.fused_act import FusedLeakyReLU
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class StyleGAN2GeneratorSFT(StyleGAN2Generator):

gfpgan/archs/gfpganv1_clean_arch.py

@ -1,11 +1,10 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class NormStyleCode(nn.Module):

gfpgan/data/__init__.py

@ -1,11 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]

gfpgan/data/ffhq_degradation_dataset.py

@ -4,14 +4,13 @@ import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
@DATASET_REGISTRY.register()

gfpgan/models/__init__.py

@ -1,12 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]

gfpgan/models/gfpgan_model.py

@ -1,11 +1,6 @@
import math
import os.path as osp
import torch
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.losses import r1_penalty
@ -13,6 +8,10 @@ from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
@MODEL_REGISTRY.register()

gfpgan/train.py (new file, +11)

@ -0,0 +1,11 @@
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

# these imports register GFPGAN's archs, datasets and models with BasicSR's registries
import gfpgan.archs
import gfpgan.data
import gfpgan.models

if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)

gfpgan/utils.py (new file, +134)

@ -0,0 +1,134 @@
import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torch.hub import download_url_to_file, get_dir
from torchvision.transforms.functional import normalize
from urllib.parse import urlparse

from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class GFPGANer():

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # initialize the GFP-GAN
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        else:
            self.gfpgan = GFPGANv1(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=True,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)

        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=self.device)

        if model_path.startswith('https://'):
            model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
        loadnet = torch.load(model_path)
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        self.face_helper.clean_all()

        if has_aligned:
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face)
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            if self.bg_upsampler is not None:
                # currently only RealESRGAN is supported for background upsampling
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None
            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None


def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py"""
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
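
`GFPGANer` is the programmatic entry point that the PyPI package exposes; the new `inference_gfpgan.py` below is essentially a thin CLI around it. A minimal usage sketch (the checkpoint and image paths are assumptions, substitute your own):

```python
import cv2

from gfpgan import GFPGANer

# assumed paths: point these at your downloaded checkpoint and a test image
restorer = GFPGANer(
    model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
    upscale=2,
    arch='clean',  # 'clean' runs without compiled CUDA extensions
    channel_multiplier=2,
    bg_upsampler=None)  # no background upsampling in this sketch

img = cv2.imread('inputs/whole_imgs/00.jpg', cv2.IMREAD_COLOR)
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True)
cv2.imwrite('results/restored.png', restored_img)
```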

gfpgan/weights/README.md (new file, +3)

@ -0,0 +1,3 @@
# Weights
Put the downloaded weights into this folder.

inference_gfpgan.py (new file, +96)

@ -0,0 +1,96 @@
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from basicsr.utils import imwrite

from gfpgan import GFPGANer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--upscale', type=int, default=2)
    parser.add_argument('--arch', type=str, default='clean')
    parser.add_argument('--channel', type=int, default=2)
    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
    parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
    parser.add_argument('--bg_tile', type=int, default=0)
    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
    parser.add_argument('--only_center_face', action='store_true')
    parser.add_argument('--aligned', action='store_true')
    parser.add_argument('--paste_back', action='store_false')
    parser.add_argument('--save_root', type=str, default='results')
    args = parser.parse_args()

    if args.test_path.endswith('/'):
        args.test_path = args.test_path[:-1]
    os.makedirs(args.save_root, exist_ok=True)

    # background upsampler
    if args.bg_upsampler == 'realesrgan':
        if not torch.cuda.is_available():  # CPU
            import warnings
            warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
                          'If you really want to use it, please modify the corresponding code.')
            bg_upsampler = None
        else:
            from realesrgan import RealESRGANer
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                tile=args.bg_tile,
                tile_pad=10,
                pre_pad=0,
                half=True)  # set to False in CPU mode
    else:
        bg_upsampler = None
    # set up the GFPGAN restorer
    restorer = GFPGANer(
        model_path=args.model_path,
        upscale=args.upscale,
        arch=args.arch,
        channel_multiplier=args.channel,
        bg_upsampler=bg_upsampler)

    img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
    for img_path in img_list:
        # read image
        img_name = os.path.basename(img_path)
        print(f'Processing {img_name} ...')
        basename, ext = os.path.splitext(img_name)
        input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)

        cropped_faces, restored_faces, restored_img = restorer.enhance(
            input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)

        # save faces
        for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
            # save cropped face
            save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
            imwrite(cropped_face, save_crop_path)
            # save restored face
            if args.suffix is not None:
                save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
            else:
                save_face_name = f'{basename}_{idx:02d}.png'
            save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
            imwrite(restored_face, save_restore_path)
            # save comparison image
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))

        # save restored img
        if args.suffix is not None:
            save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}_{args.suffix}{ext}')
        else:
            save_restore_path = os.path.join(args.save_root, 'restored_imgs', img_name)
        imwrite(restored_img, save_restore_path)

    print(f'Results are in the [{args.save_root}] folder.')


if __name__ == '__main__':
    main()

inference_gfpgan_full.py (deleted, -153)

@ -1,153 +0,0 @@
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize

from archs.gfpganv1_arch import GFPGANv1
from archs.gfpganv1_clean_arch import GFPGANv1Clean
from basicsr.utils import img2tensor, imwrite, tensor2img


def restoration(gfpgan,
                face_helper,
                img_path,
                save_root,
                has_aligned=False,
                only_center_face=True,
                suffix=None,
                paste_back=False,
                device='cuda'):
    # read image
    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, _ = os.path.splitext(img_name)
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)

    face_helper.clean_all()

    if has_aligned:
        input_img = cv2.resize(input_img, (512, 512))
        face_helper.cropped_faces = [input_img]
    else:
        face_helper.read_image(input_img)
        # get face landmarks for each face
        face_helper.get_face_landmarks_5(only_center_face=only_center_face)
        # align and warp each face
        save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
        face_helper.align_warp_face(save_crop_path)

    # face restoration
    for idx, cropped_face in enumerate(face_helper.cropped_faces):
        # prepare data
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

        try:
            with torch.no_grad():
                output = gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
        except RuntimeError as error:
            print(f'\tFailed inference for GFPGAN: {error}.')
            restored_face = cropped_face

        restored_face = restored_face.astype('uint8')
        face_helper.add_restored_face(restored_face)

        if suffix is not None:
            save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
        else:
            save_face_name = f'{basename}_{idx:02d}.png'
        save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
        imwrite(restored_face, save_restore_path)

        # save comparison image
        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
        imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))

    if not has_aligned and paste_back:
        face_helper.get_inverse_affine(None)
        save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
        # paste each restored face to the input image
        face_helper.paste_faces_to_input_image(save_restore_path)


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parser = argparse.ArgumentParser()
    parser.add_argument('--upscale_factor', type=int, default=2)
    parser.add_argument('--arch', type=str, default='clean')
    parser.add_argument('--channel', type=int, default=2)
    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
    parser.add_argument('--only_center_face', action='store_true')
    parser.add_argument('--aligned', action='store_true')
    parser.add_argument('--paste_back', action='store_false')
    parser.add_argument('--save_root', type=str, default='results')
    args = parser.parse_args()
    if args.test_path.endswith('/'):
        args.test_path = args.test_path[:-1]
    os.makedirs(args.save_root, exist_ok=True)

    # initialize the GFP-GAN
    if args.arch == 'clean':
        gfpgan = GFPGANv1Clean(
            out_size=512,
            num_style_feat=512,
            channel_multiplier=args.channel,
            decoder_load_path=None,
            fix_decoder=False,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True)
    else:
        gfpgan = GFPGANv1(
            out_size=512,
            num_style_feat=512,
            channel_multiplier=args.channel,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True)
    gfpgan.load_state_dict(torch.load(args.model_path, map_location=lambda storage, loc: storage)['params_ema'])
    gfpgan.to(device).eval()

    # initialize face helper
    face_helper = FaceRestoreHelper(
        args.upscale_factor,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        device=device)

    img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
    for img_path in img_list:
        restoration(
            gfpgan,
            face_helper,
            img_path,
            args.save_root,
            has_aligned=args.aligned,
            only_center_face=args.only_center_face,
            suffix=args.suffix,
            paste_back=args.paste_back,
            device=device)

    print(f'Results are in the [{args.save_root}] folder.')

scripts/parse_landmark.py

@ -2,9 +2,8 @@ import cv2
import json
import numpy as np
import torch
from collections import OrderedDict
from basicsr.utils import FileClient, imfrombytes
from collections import OrderedDict
print('Load JSON metadata...')
# use the json file in FFHQ dataset

setup.cfg

@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr
known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
known_first_party = gfpgan
known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

setup.py (new file, +113)

@ -0,0 +1,113 @@
#!/usr/bin/env python

from setuptools import find_packages, setup

import os
import subprocess
import time

version_file = 'gfpgan/version.py'


def readme():
    with open('README.md', encoding='utf-8') as f:
        content = f.read()
    return content


def get_git_hash():

    def _minimal_ext_cmd(cmd):
        # construct minimal environment
        env = {}
        for k in ['SYSTEMROOT', 'PATH', 'HOME']:
            v = os.environ.get(k)
            if v is not None:
                env[k] = v
        # LANGUAGE is used on win32
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
        return out

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        sha = out.strip().decode('ascii')
    except OSError:
        sha = 'unknown'
    return sha


def get_hash():
    if os.path.exists('.git'):
        sha = get_git_hash()[:7]
    elif os.path.exists(version_file):
        try:
            from gfpgan.version import __version__
            sha = __version__.split('+')[-1]
        except ImportError:
            raise ImportError('Unable to get git version')
    else:
        sha = 'unknown'
    return sha


def write_version_py():
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r') as f:
        SHORT_VERSION = f.read().strip()
    VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])

    version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
    with open(version_file, 'w') as f:
        f.write(version_file_str)


def get_version():
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'))
    return locals()['__version__']


def get_requirements(filename='requirements.txt'):
    here = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(here, filename), 'r') as f:
        requires = [line.replace('\n', '') for line in f.readlines()]
    return requires


if __name__ == '__main__':
    write_version_py()
    setup(
        name='gfpgan',
        version=get_version(),
        description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan',
        url='https://github.com/TencentARC/GFPGAN',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License Version 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        zip_safe=False)
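
For reference, `write_version_py()` above renders `gfpgan/version.py` roughly as follows (the timestamp and short sha below are illustrative, taken from this commit):

```python
# GENERATED VERSION FILE
# TIME: Mon Aug  9 01:28:10 2021
__version__ = '0.2.1'
__gitsha__ = '996d1e3'
version_info = (0, 2, 1)
```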

train.py (deleted, -10)

@ -1,10 +0,0 @@
import os.path as osp

import archs  # noqa: F401
import data  # noqa: F401
import models  # noqa: F401
from basicsr.train import train_pipeline

if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir))
    train_pipeline(root_path)