1
0
mirror of https://github.com/TencentARC/GFPGAN.git synced 2025-05-15 23:00:12 -07:00

Major revision: Support Pypi (#37)

* reorganize

* update inference

* update inference

* format
This commit is contained in:
Xintao 2021-08-09 01:28:10 +08:00 committed by GitHub
parent 77dc85b882
commit 996d1e3df9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 470 additions and 221 deletions

30
.github/workflows/publish-pip.yml vendored Normal file
View File

@ -0,0 +1,30 @@
name: PyPI Publish
on: push
jobs:
build-n-publish:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch (cpu)
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install dependencies
run: pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Build for distribution
# remove bdist_wheel for pip installation with compiling cuda extensions
run: python setup.py sdist
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@ -25,5 +25,5 @@ jobs:
- name: Lint - name: Lint
run: | run: |
flake8 . flake8 .
isort --check-only --diff data/ archs/ models/ train.py inference_gfpgan_full.py isort --check-only --diff gfpgan/ scripts/ inference_gfpgan.py setup.py
yapf -r -d data/ archs/ models/ train.py inference_gfpgan_full.py yapf -r -d gfpgan/ scripts/ inference_gfpgan.py setup.py

45
.gitignore vendored
View File

@ -1,21 +1,13 @@
.vscode # ignored folders
datasets/* datasets/*
experiments/* experiments/*
results/*
tb_logger/* tb_logger/*
wandb/*
tmp/*
# ignored files
version.py version.py
.vscode
# ignored files with suffix
*.html
*.png
*.jpeg
*.jpg
*.gif
*.pth
*.zip
# template
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
@ -39,6 +31,8 @@ parts/
sdist/ sdist/
var/ var/
wheels/ wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/ *.egg-info/
.installed.cfg .installed.cfg
*.egg *.egg
@ -57,12 +51,14 @@ pip-delete-this-directory.txt
# Unit test / coverage reports # Unit test / coverage reports
htmlcov/ htmlcov/
.tox/ .tox/
.nox/
.coverage .coverage
.coverage.* .coverage.*
.cache .cache
nosetests.xml nosetests.xml
coverage.xml coverage.xml
*.cover *.cover
*.py,cover
.hypothesis/ .hypothesis/
.pytest_cache/ .pytest_cache/
@ -74,6 +70,7 @@ coverage.xml
*.log *.log
local_settings.py local_settings.py
db.sqlite3 db.sqlite3
db.sqlite3-journal
# Flask stuff: # Flask stuff:
instance/ instance/
@ -91,11 +88,26 @@ target/
# Jupyter Notebook # Jupyter Notebook
.ipynb_checkpoints .ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv # pyenv
.python-version .python-version
# celery beat schedule file # pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule celerybeat-schedule
celerybeat.pid
# SageMath parsed files # SageMath parsed files
*.sage.py *.sage.py
@ -121,3 +133,8 @@ venv.bak/
# mypy # mypy
.mypy_cache/ .mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

8
MANIFEST.in Normal file
View File

@ -0,0 +1,8 @@
include assets/*
include inputs/*
include scripts/*.py
include inference_gfpgan.py
include VERSION
include LICENSE
include requirements.txt
include gfpgan/weights/README.md

View File

@ -27,6 +27,7 @@ If you want want to use the original model in our paper, please follow the instr
pip install facexlib pip install facexlib
pip install -r requirements.txt pip install -r requirements.txt
python setup.py develop
# remember to set BASICSR_JIT=True before your running commands # remember to set BASICSR_JIT=True before your running commands
``` ```
@ -45,6 +46,7 @@ If you want want to use the original model in our paper, please follow the instr
pip install facexlib pip install facexlib
pip install -r requirements.txt pip install -r requirements.txt
python setup.py develop
``` ```
## :zap: Quick Inference ## :zap: Quick Inference
@ -58,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth
- Option 1: Load extensions just-in-time(JIT) - Option 1: Load extensions just-in-time(JIT)
```bash ```bash
BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
# for aligned images # for aligned images
BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
``` ```
- Option 2: Have successfully compiled extensions during installation - Option 2: Have successfully compiled extensions during installation
```bash ```bash
python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
# for aligned images # for aligned images
python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
``` ```

View File

@ -1,9 +1,11 @@
# GFPGAN (CVPR 2021) # GFPGAN (CVPR 2021)
[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases) [![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
[![PyPI](https://img.shields.io/pypi/v/gfpgan)](https://pypi.org/project/gfpgan/)
[![Open issue](https://isitmaintained.com/badge/open/TencentARC/GFPGAN.svg)](https://github.com/TencentARC/GFPGAN/issues) [![Open issue](https://isitmaintained.com/badge/open/TencentARC/GFPGAN.svg)](https://github.com/TencentARC/GFPGAN/issues)
[![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE) [![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
[![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml) [![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)
[![Publish-pip](https://github.com/TencentARC/GFPGAN/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml)
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model) 1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model)
1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**. 1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**.
@ -59,6 +61,7 @@ If you want want to use the original model in our paper, please see [PaperModel.
pip install facexlib pip install facexlib
pip install -r requirements.txt pip install -r requirements.txt
python setup.py develop
``` ```
## :zap: Quick Inference ## :zap: Quick Inference
@ -72,7 +75,7 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1
**Inference!** **Inference!**
```bash ```bash
python inference_gfpgan_full.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results python inference_gfpgan.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results
``` ```
## :european_castle: Model Zoo ## :european_castle: Model Zoo
@ -90,10 +93,9 @@ You could improve it according to your own needs.
1. More high quality faces can improve the restoration quality. 1. More high quality faces can improve the restoration quality.
2. You may need to perform some pre-processing, such as beauty makeup. 2. You may need to perform some pre-processing, such as beauty makeup.
**Procedures** **Procedures**
(You can try a simple version ( `train_gfpgan_v1_simple.yml`) that does not require face component landmarks.) (You can try a simple version ( `options/train_gfpgan_v1_simple.yml`) that does not require face component landmarks.)
1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset) 1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
@ -102,11 +104,11 @@ You could improve it according to your own needs.
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth) 1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth) 1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
1. Modify the configuration file `train_gfpgan_v1.yml` accordingly. 1. Modify the configuration file `options/train_gfpgan_v1.yml` accordingly.
1. Training 1. Training
> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 train.py -opt train_gfpgan_v1.yml --launcher pytorch > python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 gfpgan/train.py -opt options/train_gfpgan_v1.yml --launcher pytorch
## :scroll: License and Acknowledgement ## :scroll: License and Acknowledgement

1
VERSION Normal file
View File

@ -0,0 +1 @@
0.2.1

6
gfpgan/__init__.py Normal file
View File

@ -0,0 +1,6 @@
# flake8: noqa
from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__

View File

@ -1,12 +1,10 @@
import importlib import importlib
from basicsr.utils import scandir
from os import path as osp from os import path as osp
from basicsr.utils import scandir
# automatically scan and import arch modules for registry # automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with # scan all the files that end with '_arch.py' under the archs folder
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__)) arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules # import all the arch modules
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] _arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]

View File

@ -1,5 +1,4 @@
import torch.nn as nn import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY

View File

@ -1,13 +1,12 @@
import math import math
import random import random
import torch import torch
from torch import nn
from torch.nn import functional as F
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2Generator) StyleGAN2Generator)
from basicsr.ops.fused_act import FusedLeakyReLU from basicsr.ops.fused_act import FusedLeakyReLU
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class StyleGAN2GeneratorSFT(StyleGAN2Generator): class StyleGAN2GeneratorSFT(StyleGAN2Generator):

View File

@ -1,11 +1,10 @@
import math import math
import random import random
import torch import torch
from torch import nn
from torch.nn import functional as F
from basicsr.archs.arch_util import default_init_weights from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class NormStyleCode(nn.Module): class NormStyleCode(nn.Module):

View File

@ -1,11 +1,10 @@
import importlib import importlib
from basicsr.utils import scandir
from os import path as osp from os import path as osp
from basicsr.utils import scandir
# automatically scan and import dataset modules for registry # automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names # scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__)) data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules # import all the dataset modules
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] _dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]

View File

@ -4,14 +4,13 @@ import numpy as np
import os.path as osp import os.path as osp
import torch import torch
import torch.utils.data as data import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
from basicsr.data import degradations as degradations from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY from basicsr.utils.registry import DATASET_REGISTRY
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
@DATASET_REGISTRY.register() @DATASET_REGISTRY.register()

View File

@ -1,12 +1,10 @@
import importlib import importlib
from basicsr.utils import scandir
from os import path as osp from os import path as osp
from basicsr.utils import scandir
# automatically scan and import model modules for registry # automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with # scan all the files that end with '_model.py' under the model folder
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__)) model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules # import all the model modules
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] _model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]

View File

@ -1,11 +1,6 @@
import math import math
import os.path as osp import os.path as osp
import torch import torch
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
from basicsr.archs import build_network from basicsr.archs import build_network
from basicsr.losses import build_loss from basicsr.losses import build_loss
from basicsr.losses.losses import r1_penalty from basicsr.losses.losses import r1_penalty
@ -13,6 +8,10 @@ from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
@MODEL_REGISTRY.register() @MODEL_REGISTRY.register()

11
gfpgan/train.py Normal file
View File

@ -0,0 +1,11 @@
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline
import gfpgan.archs
import gfpgan.data
import gfpgan.models
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)

134
gfpgan/utils.py Normal file
View File

@ -0,0 +1,134 @@
import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torch.hub import download_url_to_file, get_dir
from torchvision.transforms.functional import normalize
from urllib.parse import urlparse
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class GFPGANer():
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialize the GFP-GAN
if arch == 'clean':
self.gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
else:
self.gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=True,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
device=self.device)
if model_path.startswith('https://'):
model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)
@torch.no_grad()
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
self.face_helper.clean_all()
if has_aligned:
img = cv2.resize(img, (512, 512))
self.face_helper.cropped_faces = [img]
else:
self.face_helper.read_image(img)
# get face landmarks for each face
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face)
# align and warp each face
self.face_helper.align_warp_face()
# face restoration
for cropped_face in self.face_helper.cropped_faces:
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:
print(f'\tFailed inference for GFPGAN: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)
if not has_aligned and paste_back:
if self.bg_upsampler is not None:
# Now only support RealESRGAN
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
else:
bg_img = None
self.face_helper.get_inverse_affine(None)
# paste each restored face to the input image
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file

3
gfpgan/weights/README.md Normal file
View File

@ -0,0 +1,3 @@
# Weights
Put the downloaded weights to this folder.

96
inference_gfpgan.py Normal file
View File

@ -0,0 +1,96 @@
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from basicsr.utils import imwrite
from gfpgan import GFPGANer
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--upscale', type=int, default=2)
parser.add_argument('--arch', type=str, default='clean')
parser.add_argument('--channel', type=int, default=2)
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
parser.add_argument('--bg_tile', type=int, default=0)
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true')
parser.add_argument('--aligned', action='store_true')
parser.add_argument('--paste_back', action='store_false')
parser.add_argument('--save_root', type=str, default='results')
args = parser.parse_args()
if args.test_path.endswith('/'):
args.test_path = args.test_path[:-1]
os.makedirs(args.save_root, exist_ok=True)
# background upsampler
if args.bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
import warnings
warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.')
bg_upsampler = None
else:
from realesrgan import RealESRGANer
bg_upsampler = RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
tile=args.bg_tile,
tile_pad=10,
pre_pad=0,
half=True) # need to set False in CPU mode
else:
bg_upsampler = None
# set up GFPGAN restorer
restorer = GFPGANer(
model_path=args.model_path,
upscale=args.upscale,
arch=args.arch,
channel_multiplier=args.channel,
bg_upsampler=bg_upsampler)
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
for img_path in img_list:
# read image
img_name = os.path.basename(img_path)
print(f'Processing {img_name} ...')
basename, ext = os.path.splitext(img_name)
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
cropped_faces, restored_faces, restored_img = restorer.enhance(
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
# save faces
for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
# save cropped face
save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
imwrite(restored_face, save_crop_path)
# save restored face
if args.suffix is not None:
save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
else:
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save cmp image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
# save restored img
if args.suffix is not None:
save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}_{args.suffix}{ext}')
else:
save_restore_path = os.path.join(args.save_root, 'restored_imgs', img_name)
imwrite(restored_img, save_restore_path)
print(f'Results are in the [{args.save_root}] folder.')
if __name__ == '__main__':
main()

View File

@ -1,153 +0,0 @@
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from archs.gfpganv1_arch import GFPGANv1
from archs.gfpganv1_clean_arch import GFPGANv1Clean
from basicsr.utils import img2tensor, imwrite, tensor2img
def restoration(gfpgan,
face_helper,
img_path,
save_root,
has_aligned=False,
only_center_face=True,
suffix=None,
paste_back=False,
device='cuda'):
# read image
img_name = os.path.basename(img_path)
print(f'Processing {img_name} ...')
basename, _ = os.path.splitext(img_name)
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
face_helper.clean_all()
if has_aligned:
input_img = cv2.resize(input_img, (512, 512))
face_helper.cropped_faces = [input_img]
else:
face_helper.read_image(input_img)
# get face landmarks for each face
face_helper.get_face_landmarks_5(only_center_face=only_center_face)
# align and warp each face
save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
face_helper.align_warp_face(save_crop_path)
# face restoration
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = gfpgan(cropped_face_t, return_rgb=False)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:
print(f'\tFailed inference for GFPGAN: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
if suffix is not None:
save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
else:
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save cmp image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
if not has_aligned and paste_back:
face_helper.get_inverse_affine(None)
save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
# paste each restored face to the input image
face_helper.paste_faces_to_input_image(save_restore_path)
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('--upscale_factor', type=int, default=2)
parser.add_argument('--arch', type=str, default='clean')
parser.add_argument('--channel', type=int, default=2)
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true')
parser.add_argument('--aligned', action='store_true')
parser.add_argument('--paste_back', action='store_false')
parser.add_argument('--save_root', type=str, default='results')
args = parser.parse_args()
if args.test_path.endswith('/'):
args.test_path = args.test_path[:-1]
os.makedirs(args.save_root, exist_ok=True)
# initialize the GFP-GAN
if args.arch == 'clean':
gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=args.channel,
decoder_load_path=None,
fix_decoder=False,
# for stylegan decoder
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
else:
gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
channel_multiplier=args.channel,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
gfpgan.load_state_dict(torch.load(args.model_path, map_location=lambda storage, loc: storage)['params_ema'])
gfpgan.to(device).eval()
# initialize face helper
face_helper = FaceRestoreHelper(
args.upscale_factor,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
device=device)
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
for img_path in img_list:
restoration(
gfpgan,
face_helper,
img_path,
args.save_root,
has_aligned=args.aligned,
only_center_face=args.only_center_face,
suffix=args.suffix,
paste_back=args.paste_back,
device=device)
print(f'Results are in the [{args.save_root}] folder.')

View File

@ -2,9 +2,8 @@ import cv2
import json import json
import numpy as np import numpy as np
import torch import torch
from collections import OrderedDict
from basicsr.utils import FileClient, imfrombytes from basicsr.utils import FileClient, imfrombytes
from collections import OrderedDict
print('Load JSON metadata...') print('Load JSON metadata...')
# use the json file in FFHQ dataset # use the json file in FFHQ dataset

View File

@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
line_length = 120 line_length = 120
multi_line_output = 0 multi_line_output = 0
known_standard_library = pkg_resources,setuptools known_standard_library = pkg_resources,setuptools
known_first_party = basicsr known_first_party = gfpgan
known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY

113
setup.py Normal file
View File

@ -0,0 +1,113 @@
#!/usr/bin/env python
from setuptools import find_packages, setup
import os
import subprocess
import time
version_file = 'gfpgan/version.py'
def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content
def get_git_hash():
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
return out
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
sha = out.strip().decode('ascii')
except OSError:
sha = 'unknown'
return sha
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
elif os.path.exists(version_file):
try:
from facexlib.version import __version__
sha = __version__.split('+')[-1]
except ImportError:
raise ImportError('Unable to get git version')
else:
sha = 'unknown'
return sha
def write_version_py():
content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
sha = get_hash()
with open('VERSION', 'r') as f:
SHORT_VERSION = f.read().strip()
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
with open(version_file, 'w') as f:
f.write(version_file_str)
def get_version():
with open(version_file, 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def get_requirements(filename='requirements.txt'):
here = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(here, filename), 'r') as f:
requires = [line.replace('\n', '') for line in f.readlines()]
return requires
if __name__ == '__main__':
write_version_py()
setup(
name='gfpgan',
version=get_version(),
description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration',
long_description=readme(),
long_description_content_type='text/markdown',
author='Xintao Wang',
author_email='xintao.wang@outlook.com',
keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan',
url='https://github.com/TencentARC/GFPGAN',
include_package_data=True,
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
license='Apache License Version 2.0',
setup_requires=['cython', 'numpy'],
install_requires=get_requirements(),
zip_safe=False)

View File

@ -1,10 +0,0 @@
import os.path as osp
import archs # noqa: F401
import data # noqa: F401
import models # noqa: F401
from basicsr.train import train_pipeline
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir))
train_pipeline(root_path)