mirror of
https://github.com/TencentARC/GFPGAN.git
synced 2025-05-20 09:10:20 -07:00
add inference
This commit is contained in:
parent
6ddfed7bde
commit
043dc22027
29
.github/workflows/pylint.yml
vendored
Normal file
29
.github/workflows/pylint.yml
vendored
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
name: Python Lint
|
||||||
|
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: [3.8]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install flake8 yapf isort
|
||||||
|
|
||||||
|
- name: Lint
|
||||||
|
run: |
|
||||||
|
flake8 .
|
||||||
|
isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py
|
||||||
|
yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py
|
120
.gitignore
vendored
Normal file
120
.gitignore
vendored
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
.vscode
|
||||||
|
|
||||||
|
# ignored files
|
||||||
|
version.py
|
||||||
|
|
||||||
|
# ignored files with suffix
|
||||||
|
*.html
|
||||||
|
*.png
|
||||||
|
*.jpeg
|
||||||
|
*.jpg
|
||||||
|
*.gif
|
||||||
|
*.pth
|
||||||
|
*.zip
|
||||||
|
|
||||||
|
# template
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
40
.pre-commit-config.yaml
Normal file
40
.pre-commit-config.yaml
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
repos:
|
||||||
|
# flake8
|
||||||
|
- repo: https://github.com/PyCQA/flake8
|
||||||
|
rev: 3.8.3
|
||||||
|
hooks:
|
||||||
|
- id: flake8
|
||||||
|
args: ["--config=setup.cfg", "--ignore=W504, W503"]
|
||||||
|
|
||||||
|
# modify known_third_party
|
||||||
|
- repo: https://github.com/asottile/seed-isort-config
|
||||||
|
rev: v2.2.0
|
||||||
|
hooks:
|
||||||
|
- id: seed-isort-config
|
||||||
|
|
||||||
|
# isort
|
||||||
|
- repo: https://github.com/timothycrosley/isort
|
||||||
|
rev: 5.2.2
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
|
||||||
|
# yapf
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
|
rev: v0.30.0
|
||||||
|
hooks:
|
||||||
|
- id: yapf
|
||||||
|
|
||||||
|
# pre-commit-hooks
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v3.2.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace # Trim trailing whitespace
|
||||||
|
- id: check-yaml # Attempt to load all yaml files to verify syntax
|
||||||
|
- id: check-merge-conflict # Check for files that contain merge conflict strings
|
||||||
|
- id: double-quote-string-fixer # Replace double quoted strings with single quoted strings
|
||||||
|
- id: end-of-file-fixer # Make sure files end in a newline and only a newline
|
||||||
|
- id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
|
||||||
|
- id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*-
|
||||||
|
args: ["--remove"]
|
||||||
|
- id: mixed-line-ending # Replace or check mixed line ending
|
||||||
|
args: ["--fix=lf"]
|
543
gfpgan_model.py
Normal file
543
gfpgan_model.py
Normal file
@ -0,0 +1,543 @@
|
|||||||
|
import math
|
||||||
|
import os.path as osp
|
||||||
|
import torch
|
||||||
|
from collections import OrderedDict
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torchvision.ops import roi_align
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from basicsr.archs import build_network
|
||||||
|
from basicsr.losses import build_loss
|
||||||
|
from basicsr.losses.losses import r1_penalty
|
||||||
|
from basicsr.metrics import calculate_metric
|
||||||
|
from basicsr.models.base_model import BaseModel
|
||||||
|
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
||||||
|
|
||||||
|
|
||||||
|
class GFPGANModel(BaseModel):
|
||||||
|
"""GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
|
||||||
|
|
||||||
|
def __init__(self, opt):
|
||||||
|
super(GFPGANModel, self).__init__(opt)
|
||||||
|
|
||||||
|
# define network
|
||||||
|
self.net_g = build_network(opt['network_g'])
|
||||||
|
self.net_g = self.model_to_device(self.net_g)
|
||||||
|
self.print_network(self.net_g)
|
||||||
|
|
||||||
|
# load pretrained model
|
||||||
|
load_path = self.opt['path'].get('pretrain_network_g', None)
|
||||||
|
if load_path is not None:
|
||||||
|
param_key = self.opt['path'].get('param_key_g', 'params')
|
||||||
|
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
|
||||||
|
|
||||||
|
self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
|
||||||
|
|
||||||
|
if self.is_train:
|
||||||
|
self.init_training_settings()
|
||||||
|
|
||||||
|
def init_training_settings(self):
|
||||||
|
train_opt = self.opt['train']
|
||||||
|
|
||||||
|
# ----------- define net_d ----------- #
|
||||||
|
self.net_d = build_network(self.opt['network_d'])
|
||||||
|
self.net_d = self.model_to_device(self.net_d)
|
||||||
|
self.print_network(self.net_d)
|
||||||
|
# load pretrained model
|
||||||
|
load_path = self.opt['path'].get('pretrain_network_d', None)
|
||||||
|
if load_path is not None:
|
||||||
|
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
||||||
|
|
||||||
|
# ----------- define net_g with Exponential Moving Average (EMA) ----------- #
|
||||||
|
# net_g_ema only used for testing on one GPU and saving
|
||||||
|
# There is no need to wrap with DistributedDataParallel
|
||||||
|
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
||||||
|
# load pretrained model
|
||||||
|
load_path = self.opt['path'].get('pretrain_network_g', None)
|
||||||
|
if load_path is not None:
|
||||||
|
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
||||||
|
else:
|
||||||
|
self.model_ema(0) # copy net_g weight
|
||||||
|
|
||||||
|
self.net_g.train()
|
||||||
|
self.net_d.train()
|
||||||
|
self.net_g_ema.eval()
|
||||||
|
|
||||||
|
# ----------- facial components networks ----------- #
|
||||||
|
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
|
||||||
|
self.use_facial_disc = True
|
||||||
|
else:
|
||||||
|
self.use_facial_disc = False
|
||||||
|
|
||||||
|
if self.use_facial_disc:
|
||||||
|
# left eye
|
||||||
|
self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
|
||||||
|
self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
|
||||||
|
self.print_network(self.net_d_left_eye)
|
||||||
|
load_path = self.opt['path'].get('pretrain_network_d_left_eye')
|
||||||
|
if load_path is not None:
|
||||||
|
self.load_network(self.net_d_left_eye, load_path, True, 'params')
|
||||||
|
# right eye
|
||||||
|
self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
|
||||||
|
self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
|
||||||
|
self.print_network(self.net_d_right_eye)
|
||||||
|
load_path = self.opt['path'].get('pretrain_network_d_right_eye')
|
||||||
|
if load_path is not None:
|
||||||
|
self.load_network(self.net_d_right_eye, load_path, True, 'params')
|
||||||
|
# mouth
|
||||||
|
self.net_d_mouth = build_network(self.opt['network_d_mouth'])
|
||||||
|
self.net_d_mouth = self.model_to_device(self.net_d_mouth)
|
||||||
|
self.print_network(self.net_d_mouth)
|
||||||
|
load_path = self.opt['path'].get('pretrain_network_d_mouth')
|
||||||
|
if load_path is not None:
|
||||||
|
self.load_network(self.net_d_mouth, load_path, True, 'params')
|
||||||
|
|
||||||
|
self.net_d_left_eye.train()
|
||||||
|
self.net_d_right_eye.train()
|
||||||
|
self.net_d_mouth.train()
|
||||||
|
|
||||||
|
# ----------- define facial component gan loss ----------- #
|
||||||
|
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
|
||||||
|
|
||||||
|
# ----------- define losses ----------- #
|
||||||
|
if train_opt.get('pixel_opt'):
|
||||||
|
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
||||||
|
else:
|
||||||
|
self.cri_pix = None
|
||||||
|
|
||||||
|
if train_opt.get('perceptual_opt'):
|
||||||
|
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
||||||
|
else:
|
||||||
|
self.cri_perceptual = None
|
||||||
|
|
||||||
|
# gan loss (wgan)
|
||||||
|
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
|
||||||
|
|
||||||
|
# ----------- define identity loss ----------- #
|
||||||
|
if 'network_identity' in self.opt:
|
||||||
|
self.use_identity = True
|
||||||
|
else:
|
||||||
|
self.use_identity = False
|
||||||
|
|
||||||
|
if self.use_identity:
|
||||||
|
# define identity network
|
||||||
|
self.network_identity = build_network(self.opt['network_identity'])
|
||||||
|
self.network_identity = self.model_to_device(self.network_identity)
|
||||||
|
self.print_network(self.network_identity)
|
||||||
|
load_path = self.opt['path'].get('pretrain_network_identity')
|
||||||
|
if load_path is not None:
|
||||||
|
self.load_network(self.network_identity, load_path, True, None)
|
||||||
|
self.network_identity.eval()
|
||||||
|
for param in self.network_identity.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# regularization weights
|
||||||
|
self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
|
||||||
|
self.net_d_iters = train_opt.get('net_d_iters', 1)
|
||||||
|
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
|
||||||
|
self.net_d_reg_every = train_opt['net_d_reg_every']
|
||||||
|
|
||||||
|
# set up optimizers and schedulers
|
||||||
|
self.setup_optimizers()
|
||||||
|
self.setup_schedulers()
|
||||||
|
|
||||||
|
def setup_optimizers(self):
|
||||||
|
train_opt = self.opt['train']
|
||||||
|
|
||||||
|
# ----------- optimizer g ----------- #
|
||||||
|
net_g_reg_ratio = 1
|
||||||
|
normal_params = []
|
||||||
|
for _, param in self.net_g.named_parameters():
|
||||||
|
normal_params.append(param)
|
||||||
|
optim_params_g = [{ # add normal params first
|
||||||
|
'params': normal_params,
|
||||||
|
'lr': train_opt['optim_g']['lr']
|
||||||
|
}]
|
||||||
|
optim_type = train_opt['optim_g'].pop('type')
|
||||||
|
lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
|
||||||
|
betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
|
||||||
|
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
|
||||||
|
self.optimizers.append(self.optimizer_g)
|
||||||
|
|
||||||
|
# ----------- optimizer d ----------- #
|
||||||
|
net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
|
||||||
|
normal_params = []
|
||||||
|
for _, param in self.net_d.named_parameters():
|
||||||
|
normal_params.append(param)
|
||||||
|
optim_params_d = [{ # add normal params first
|
||||||
|
'params': normal_params,
|
||||||
|
'lr': train_opt['optim_d']['lr']
|
||||||
|
}]
|
||||||
|
optim_type = train_opt['optim_d'].pop('type')
|
||||||
|
lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
|
||||||
|
betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
|
||||||
|
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
|
||||||
|
self.optimizers.append(self.optimizer_d)
|
||||||
|
|
||||||
|
if self.use_facial_disc:
|
||||||
|
# setup optimizers for facial component discriminators
|
||||||
|
optim_type = train_opt['optim_component'].pop('type')
|
||||||
|
lr = train_opt['optim_component']['lr']
|
||||||
|
# left eye
|
||||||
|
self.optimizer_d_left_eye = self.get_optimizer(
|
||||||
|
optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
|
||||||
|
self.optimizers.append(self.optimizer_d_left_eye)
|
||||||
|
# right eye
|
||||||
|
self.optimizer_d_right_eye = self.get_optimizer(
|
||||||
|
optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
|
||||||
|
self.optimizers.append(self.optimizer_d_right_eye)
|
||||||
|
# mouth
|
||||||
|
self.optimizer_d_mouth = self.get_optimizer(
|
||||||
|
optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
|
||||||
|
self.optimizers.append(self.optimizer_d_mouth)
|
||||||
|
|
||||||
|
def feed_data(self, data):
|
||||||
|
self.lq = data['lq'].to(self.device)
|
||||||
|
if 'gt' in data:
|
||||||
|
self.gt = data['gt'].to(self.device)
|
||||||
|
|
||||||
|
if self.use_facial_disc:
|
||||||
|
# get facial component locations, shape (batch, 4)
|
||||||
|
self.loc_left_eyes = data['loc_left_eye']
|
||||||
|
self.loc_right_eyes = data['loc_right_eye']
|
||||||
|
self.loc_mouths = data['loc_mouth']
|
||||||
|
|
||||||
|
def construct_img_pyramid(self):
|
||||||
|
pyramid_gt = [self.gt]
|
||||||
|
down_img = self.gt
|
||||||
|
for _ in range(0, self.log_size - 3):
|
||||||
|
down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
|
||||||
|
pyramid_gt.insert(0, down_img)
|
||||||
|
return pyramid_gt
|
||||||
|
|
||||||
|
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
|
||||||
|
# hard code
|
||||||
|
face_ratio = int(self.opt['network_g']['out_size'] / 512)
|
||||||
|
eye_out_size *= face_ratio
|
||||||
|
mouth_out_size *= face_ratio
|
||||||
|
|
||||||
|
rois_eyes = []
|
||||||
|
rois_mouths = []
|
||||||
|
for b in range(self.loc_left_eyes.size(0)): # loop for batch size
|
||||||
|
# left eye and right eye
|
||||||
|
img_inds = self.loc_left_eyes.new_full((2, 1), b)
|
||||||
|
bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4)
|
||||||
|
rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5)
|
||||||
|
rois_eyes.append(rois)
|
||||||
|
# mouse
|
||||||
|
img_inds = self.loc_left_eyes.new_full((1, 1), b)
|
||||||
|
rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5)
|
||||||
|
rois_mouths.append(rois)
|
||||||
|
|
||||||
|
rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
|
||||||
|
rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
|
||||||
|
|
||||||
|
# real images
|
||||||
|
all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
|
||||||
|
self.left_eyes_gt = all_eyes[0::2, :, :, :]
|
||||||
|
self.right_eyes_gt = all_eyes[1::2, :, :, :]
|
||||||
|
self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
||||||
|
# output
|
||||||
|
all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
|
||||||
|
self.left_eyes = all_eyes[0::2, :, :, :]
|
||||||
|
self.right_eyes = all_eyes[1::2, :, :, :]
|
||||||
|
self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
||||||
|
|
||||||
|
def _gram_mat(self, x):
|
||||||
|
"""Calculate Gram matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Gram matrix.
|
||||||
|
"""
|
||||||
|
n, c, h, w = x.size()
|
||||||
|
features = x.view(n, c, w * h)
|
||||||
|
features_t = features.transpose(1, 2)
|
||||||
|
gram = features.bmm(features_t) / (c * h * w)
|
||||||
|
return gram
|
||||||
|
|
||||||
|
def gray_resize_for_identity(self, out, size=128):
|
||||||
|
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
|
||||||
|
out_gray = out_gray.unsqueeze(1)
|
||||||
|
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
|
||||||
|
return out_gray
|
||||||
|
|
||||||
|
def optimize_parameters(self, current_iter):
|
||||||
|
# optimize net_g
|
||||||
|
for p in self.net_d.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
self.optimizer_g.zero_grad()
|
||||||
|
|
||||||
|
if self.use_facial_disc:
|
||||||
|
for p in self.net_d_left_eye.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
for p in self.net_d_right_eye.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
for p in self.net_d_mouth.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
# image pyramid loss weight
|
||||||
|
if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
|
||||||
|
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
|
||||||
|
else:
|
||||||
|
pyramid_loss_weight = 1e-12 # very small loss
|
||||||
|
if pyramid_loss_weight > 0:
|
||||||
|
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
|
||||||
|
pyramid_gt = self.construct_img_pyramid()
|
||||||
|
else:
|
||||||
|
self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
|
||||||
|
|
||||||
|
# get roi-align regions
|
||||||
|
if self.use_facial_disc:
|
||||||
|
self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
|
||||||
|
|
||||||
|
l_g_total = 0
|
||||||
|
loss_dict = OrderedDict()
|
||||||
|
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
|
||||||
|
# pixel loss
|
||||||
|
if self.cri_pix:
|
||||||
|
l_g_pix = self.cri_pix(self.output, self.gt)
|
||||||
|
l_g_total += l_g_pix
|
||||||
|
loss_dict['l_g_pix'] = l_g_pix
|
||||||
|
|
||||||
|
# image pyramid loss
|
||||||
|
if pyramid_loss_weight > 0:
|
||||||
|
for i in range(0, self.log_size - 2):
|
||||||
|
l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
|
||||||
|
l_g_total += l_pyramid
|
||||||
|
loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
|
||||||
|
|
||||||
|
# perceptual loss
|
||||||
|
if self.cri_perceptual:
|
||||||
|
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
|
||||||
|
if l_g_percep is not None:
|
||||||
|
l_g_total += l_g_percep
|
||||||
|
loss_dict['l_g_percep'] = l_g_percep
|
||||||
|
if l_g_style is not None:
|
||||||
|
l_g_total += l_g_style
|
||||||
|
loss_dict['l_g_style'] = l_g_style
|
||||||
|
|
||||||
|
# gan loss
|
||||||
|
fake_g_pred = self.net_d(self.output)
|
||||||
|
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
|
||||||
|
l_g_total += l_g_gan
|
||||||
|
loss_dict['l_g_gan'] = l_g_gan
|
||||||
|
|
||||||
|
# facial component loss
|
||||||
|
if self.use_facial_disc:
|
||||||
|
# left eye
|
||||||
|
fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
|
||||||
|
l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
|
||||||
|
l_g_total += l_g_gan
|
||||||
|
loss_dict['l_g_gan_left_eye'] = l_g_gan
|
||||||
|
# right eye
|
||||||
|
fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
|
||||||
|
l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
|
||||||
|
l_g_total += l_g_gan
|
||||||
|
loss_dict['l_g_gan_right_eye'] = l_g_gan
|
||||||
|
# mouth
|
||||||
|
fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
|
||||||
|
l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
|
||||||
|
l_g_total += l_g_gan
|
||||||
|
loss_dict['l_g_gan_mouth'] = l_g_gan
|
||||||
|
|
||||||
|
if self.opt['train'].get('comp_style_weight', 0) > 0:
|
||||||
|
# get gt feat
|
||||||
|
_, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
|
||||||
|
_, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
|
||||||
|
_, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
|
||||||
|
|
||||||
|
def _comp_style(feat, feat_gt, criterion):
|
||||||
|
return criterion(self._gram_mat(feat[0]), self._gram_mat(
|
||||||
|
feat_gt[0].detach())) * 0.5 + criterion(
|
||||||
|
self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
|
||||||
|
|
||||||
|
# facial component style loss
|
||||||
|
comp_style_loss = 0
|
||||||
|
comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
|
||||||
|
comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
|
||||||
|
comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
|
||||||
|
comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
|
||||||
|
l_g_total += comp_style_loss
|
||||||
|
loss_dict['l_g_comp_style_loss'] = comp_style_loss
|
||||||
|
|
||||||
|
# identity loss
|
||||||
|
if self.use_identity:
|
||||||
|
identity_weight = self.opt['train']['identity_weight']
|
||||||
|
# get gray images and resize
|
||||||
|
out_gray = self.gray_resize_for_identity(self.output)
|
||||||
|
gt_gray = self.gray_resize_for_identity(self.gt)
|
||||||
|
|
||||||
|
identity_gt = self.network_identity(gt_gray).detach()
|
||||||
|
identity_out = self.network_identity(out_gray)
|
||||||
|
l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
|
||||||
|
l_g_total += l_identity
|
||||||
|
loss_dict['l_identity'] = l_identity
|
||||||
|
|
||||||
|
l_g_total.backward()
|
||||||
|
self.optimizer_g.step()
|
||||||
|
|
||||||
|
# EMA
|
||||||
|
self.model_ema(decay=0.5**(32 / (10 * 1000)))
|
||||||
|
|
||||||
|
# ----------- optimize net_d ----------- #
|
||||||
|
for p in self.net_d.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
self.optimizer_d.zero_grad()
|
||||||
|
if self.use_facial_disc:
|
||||||
|
for p in self.net_d_left_eye.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
for p in self.net_d_right_eye.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
for p in self.net_d_mouth.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
self.optimizer_d_left_eye.zero_grad()
|
||||||
|
self.optimizer_d_right_eye.zero_grad()
|
||||||
|
self.optimizer_d_mouth.zero_grad()
|
||||||
|
|
||||||
|
fake_d_pred = self.net_d(self.output.detach())
|
||||||
|
real_d_pred = self.net_d(self.gt)
|
||||||
|
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
|
||||||
|
loss_dict['l_d'] = l_d
|
||||||
|
# In wgan, real_score should be positive and fake_score should benegative
|
||||||
|
loss_dict['real_score'] = real_d_pred.detach().mean()
|
||||||
|
loss_dict['fake_score'] = fake_d_pred.detach().mean()
|
||||||
|
l_d.backward()
|
||||||
|
|
||||||
|
if current_iter % self.net_d_reg_every == 0:
|
||||||
|
self.gt.requires_grad = True
|
||||||
|
real_pred = self.net_d(self.gt)
|
||||||
|
l_d_r1 = r1_penalty(real_pred, self.gt)
|
||||||
|
l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
|
||||||
|
loss_dict['l_d_r1'] = l_d_r1.detach().mean()
|
||||||
|
l_d_r1.backward()
|
||||||
|
|
||||||
|
self.optimizer_d.step()
|
||||||
|
|
||||||
|
if self.use_facial_disc:
|
||||||
|
# lefe eye
|
||||||
|
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
|
||||||
|
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
|
||||||
|
l_d_left_eye = self.cri_component(
|
||||||
|
real_d_pred, True, is_disc=True) + self.cri_gan(
|
||||||
|
fake_d_pred, False, is_disc=True)
|
||||||
|
loss_dict['l_d_left_eye'] = l_d_left_eye
|
||||||
|
l_d_left_eye.backward()
|
||||||
|
# right eye
|
||||||
|
fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
|
||||||
|
real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
|
||||||
|
l_d_right_eye = self.cri_component(
|
||||||
|
real_d_pred, True, is_disc=True) + self.cri_gan(
|
||||||
|
fake_d_pred, False, is_disc=True)
|
||||||
|
loss_dict['l_d_right_eye'] = l_d_right_eye
|
||||||
|
l_d_right_eye.backward()
|
||||||
|
# mouth
|
||||||
|
fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
|
||||||
|
real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
|
||||||
|
l_d_mouth = self.cri_component(
|
||||||
|
real_d_pred, True, is_disc=True) + self.cri_gan(
|
||||||
|
fake_d_pred, False, is_disc=True)
|
||||||
|
loss_dict['l_d_mouth'] = l_d_mouth
|
||||||
|
l_d_mouth.backward()
|
||||||
|
|
||||||
|
self.optimizer_d_left_eye.step()
|
||||||
|
self.optimizer_d_right_eye.step()
|
||||||
|
self.optimizer_d_mouth.step()
|
||||||
|
|
||||||
|
self.log_dict = self.reduce_loss_dict(loss_dict)
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
if hasattr(self, 'net_g_ema'):
|
||||||
|
self.net_g_ema.eval()
|
||||||
|
self.output, _ = self.net_g_ema(self.lq)
|
||||||
|
else:
|
||||||
|
logger = get_root_logger()
|
||||||
|
logger.warning('Do not have self.net_g_ema, use self.net_g.')
|
||||||
|
self.net_g.eval()
|
||||||
|
self.output, _ = self.net_g(self.lq)
|
||||||
|
self.net_g.train()
|
||||||
|
|
||||||
|
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||||
|
if self.opt['rank'] == 0:
|
||||||
|
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
||||||
|
|
||||||
|
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||||
|
dataset_name = dataloader.dataset.opt['name']
|
||||||
|
with_metrics = self.opt['val'].get('metrics') is not None
|
||||||
|
if with_metrics:
|
||||||
|
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
||||||
|
pbar = tqdm(total=len(dataloader), unit='image')
|
||||||
|
|
||||||
|
for idx, val_data in enumerate(dataloader):
|
||||||
|
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
||||||
|
self.feed_data(val_data)
|
||||||
|
self.test()
|
||||||
|
|
||||||
|
visuals = self.get_current_visuals()
|
||||||
|
sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
|
||||||
|
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
||||||
|
|
||||||
|
if 'gt' in visuals:
|
||||||
|
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
||||||
|
del self.gt
|
||||||
|
# tentative for out of GPU memory
|
||||||
|
del self.lq
|
||||||
|
del self.output
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
if save_img:
|
||||||
|
if self.opt['is_train']:
|
||||||
|
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
|
||||||
|
f'{img_name}_{current_iter}.png')
|
||||||
|
else:
|
||||||
|
if self.opt['val']['suffix']:
|
||||||
|
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
||||||
|
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
||||||
|
else:
|
||||||
|
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
||||||
|
f'{img_name}_{self.opt["name"]}.png')
|
||||||
|
imwrite(sr_img, save_img_path)
|
||||||
|
|
||||||
|
if with_metrics:
|
||||||
|
# calculate metrics
|
||||||
|
for name, opt_ in self.opt['val']['metrics'].items():
|
||||||
|
metric_data = dict(img1=sr_img, img2=gt_img)
|
||||||
|
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
||||||
|
pbar.update(1)
|
||||||
|
pbar.set_description(f'Test {img_name}')
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
if with_metrics:
|
||||||
|
for metric in self.metric_results.keys():
|
||||||
|
self.metric_results[metric] /= (idx + 1)
|
||||||
|
|
||||||
|
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
||||||
|
|
||||||
|
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
||||||
|
log_str = f'Validation {dataset_name}\n'
|
||||||
|
for metric, value in self.metric_results.items():
|
||||||
|
log_str += f'\t # {metric}: {value:.4f}\n'
|
||||||
|
logger = get_root_logger()
|
||||||
|
logger.info(log_str)
|
||||||
|
if tb_logger:
|
||||||
|
for metric, value in self.metric_results.items():
|
||||||
|
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
||||||
|
|
||||||
|
def get_current_visuals(self):
|
||||||
|
out_dict = OrderedDict()
|
||||||
|
out_dict['gt'] = self.gt.detach().cpu()
|
||||||
|
out_dict['sr'] = self.output.detach().cpu()
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
def save(self, epoch, current_iter):
|
||||||
|
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
||||||
|
self.save_network(self.net_d, 'net_d', current_iter)
|
||||||
|
# save component discriminators
|
||||||
|
if self.use_facial_disc:
|
||||||
|
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
|
||||||
|
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
|
||||||
|
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
|
||||||
|
self.save_training_state(epoch, current_iter)
|
386
gfpganv1_arch.py
Normal file
386
gfpganv1_arch.py
Normal file
@ -0,0 +1,386 @@
|
|||||||
|
import math
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
||||||
|
StyleGAN2Generator)
|
||||||
|
from basicsr.ops.fused_act import FusedLeakyReLU
|
||||||
|
|
||||||
|
|
||||||
|
class StyleGAN2GeneratorSFTV1(StyleGAN2Generator):
|
||||||
|
"""StyleGAN2 Generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
out_size (int): The spatial size of outputs.
|
||||||
|
num_style_feat (int): Channel number of style features. Default: 512.
|
||||||
|
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||||
|
channel_multiplier (int): Channel multiplier for large networks of
|
||||||
|
StyleGAN2. Default: 2.
|
||||||
|
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
||||||
|
magnitude. A cross production will be applied to extent 1D resample
|
||||||
|
kenrel to 2D resample kernel. Default: [1, 3, 3, 1].
|
||||||
|
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
out_size,
|
||||||
|
num_style_feat=512,
|
||||||
|
num_mlp=8,
|
||||||
|
channel_multiplier=2,
|
||||||
|
resample_kernel=[1, 3, 3, 1],
|
||||||
|
lr_mlp=0.01,
|
||||||
|
narrow=1,
|
||||||
|
sft_half=False):
|
||||||
|
super(StyleGAN2GeneratorSFTV1, self).__init__(
|
||||||
|
out_size,
|
||||||
|
num_style_feat=num_style_feat,
|
||||||
|
num_mlp=num_mlp,
|
||||||
|
channel_multiplier=channel_multiplier,
|
||||||
|
resample_kernel=resample_kernel,
|
||||||
|
lr_mlp=lr_mlp,
|
||||||
|
narrow=narrow)
|
||||||
|
self.sft_half = sft_half
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
styles,
|
||||||
|
conditions,
|
||||||
|
input_is_latent=False,
|
||||||
|
noise=None,
|
||||||
|
randomize_noise=True,
|
||||||
|
truncation=1,
|
||||||
|
truncation_latent=None,
|
||||||
|
inject_index=None,
|
||||||
|
return_latents=False):
|
||||||
|
"""Forward function for StyleGAN2Generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
styles (list[Tensor]): Sample codes of styles.
|
||||||
|
input_is_latent (bool): Whether input is latent style.
|
||||||
|
Default: False.
|
||||||
|
noise (Tensor | None): Input noise or None. Default: None.
|
||||||
|
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||||
|
False. Default: True.
|
||||||
|
truncation (float): TODO. Default: 1.
|
||||||
|
truncation_latent (Tensor | None): TODO. Default: None.
|
||||||
|
inject_index (int | None): The injection index for mixing noise.
|
||||||
|
Default: None.
|
||||||
|
return_latents (bool): Whether to return style latents.
|
||||||
|
Default: False.
|
||||||
|
"""
|
||||||
|
# style codes -> latents with Style MLP layer
|
||||||
|
if not input_is_latent:
|
||||||
|
styles = [self.style_mlp(s) for s in styles]
|
||||||
|
# noises
|
||||||
|
if noise is None:
|
||||||
|
if randomize_noise:
|
||||||
|
noise = [None] * self.num_layers # for each style conv layer
|
||||||
|
else: # use the stored noise
|
||||||
|
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
||||||
|
# style truncation
|
||||||
|
if truncation < 1:
|
||||||
|
style_truncation = []
|
||||||
|
for style in styles:
|
||||||
|
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||||
|
styles = style_truncation
|
||||||
|
# get style latent with injection
|
||||||
|
if len(styles) == 1:
|
||||||
|
inject_index = self.num_latent
|
||||||
|
|
||||||
|
if styles[0].ndim < 3:
|
||||||
|
# repeat latent code for all the layers
|
||||||
|
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||||
|
else: # used for encoder with different latent code for each layer
|
||||||
|
latent = styles[0]
|
||||||
|
elif len(styles) == 2: # mixing noises
|
||||||
|
if inject_index is None:
|
||||||
|
inject_index = random.randint(1, self.num_latent - 1)
|
||||||
|
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||||
|
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
||||||
|
latent = torch.cat([latent1, latent2], 1)
|
||||||
|
|
||||||
|
# main generation
|
||||||
|
out = self.constant_input(latent.shape[0])
|
||||||
|
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
||||||
|
skip = self.to_rgb1(out, latent[:, 1])
|
||||||
|
|
||||||
|
i = 1
|
||||||
|
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
||||||
|
noise[2::2], self.to_rgbs):
|
||||||
|
out = conv1(out, latent[:, i], noise=noise1)
|
||||||
|
|
||||||
|
# the conditions may have fewer levels
|
||||||
|
if i < len(conditions):
|
||||||
|
# SFT part to combine the conditions
|
||||||
|
if self.sft_half:
|
||||||
|
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
||||||
|
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
||||||
|
out = torch.cat([out_same, out_sft], dim=1)
|
||||||
|
else:
|
||||||
|
out = out * conditions[i - 1] + conditions[i]
|
||||||
|
|
||||||
|
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||||
|
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||||
|
i += 2
|
||||||
|
|
||||||
|
image = skip
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return image, latent
|
||||||
|
else:
|
||||||
|
return image, None
|
||||||
|
|
||||||
|
|
||||||
|
class ConvUpLayer(nn.Module):
|
||||||
|
"""Conv Up Layer. Bilinear upsample + Conv.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Channel number of the input.
|
||||||
|
out_channels (int): Channel number of the output.
|
||||||
|
kernel_size (int): Size of the convolving kernel.
|
||||||
|
stride (int): Stride of the convolution. Default: 1
|
||||||
|
padding (int): Zero-padding added to both sides of the input.
|
||||||
|
Default: 0.
|
||||||
|
bias (bool): If ``True``, adds a learnable bias to the output.
|
||||||
|
Default: ``True``.
|
||||||
|
bias_init_val (float): Bias initialized value. Default: 0.
|
||||||
|
activate (bool): Whether use activateion. Default: True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
bias_init_val=0,
|
||||||
|
activate=True):
|
||||||
|
super(ConvUpLayer, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
||||||
|
|
||||||
|
if bias and not activate:
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
# activation
|
||||||
|
if activate:
|
||||||
|
if bias:
|
||||||
|
self.activation = FusedLeakyReLU(out_channels)
|
||||||
|
else:
|
||||||
|
self.activation = ScaledLeakyReLU(0.2)
|
||||||
|
else:
|
||||||
|
self.activation = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# bilinear upsample
|
||||||
|
out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
# conv
|
||||||
|
out = F.conv2d(
|
||||||
|
out,
|
||||||
|
self.weight * self.scale,
|
||||||
|
bias=self.bias,
|
||||||
|
stride=self.stride,
|
||||||
|
padding=self.padding,
|
||||||
|
)
|
||||||
|
# activation
|
||||||
|
if self.activation is not None:
|
||||||
|
out = self.activation(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResUpBlock(nn.Module):
|
||||||
|
"""Residual block with upsampling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): Channel number of the input.
|
||||||
|
out_channels (int): Channel number of the output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super(ResUpBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
||||||
|
self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
|
||||||
|
self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.conv2(out)
|
||||||
|
skip = self.skip(x)
|
||||||
|
out = (out + skip) / math.sqrt(2)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GFPGANv1(nn.Module):
|
||||||
|
"""Unet + StyleGAN2 decoder with SFT."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
out_size,
|
||||||
|
num_style_feat=512,
|
||||||
|
channel_multiplier=1,
|
||||||
|
resample_kernel=[1, 3, 3, 1],
|
||||||
|
decoder_load_path=None,
|
||||||
|
fix_decoder=True,
|
||||||
|
# for stylegan decoder
|
||||||
|
num_mlp=8,
|
||||||
|
lr_mlp=0.01,
|
||||||
|
input_is_latent=False,
|
||||||
|
different_w=False,
|
||||||
|
narrow=1,
|
||||||
|
sft_half=False):
|
||||||
|
|
||||||
|
super(GFPGANv1, self).__init__()
|
||||||
|
self.input_is_latent = input_is_latent
|
||||||
|
self.different_w = different_w
|
||||||
|
self.num_style_feat = num_style_feat
|
||||||
|
|
||||||
|
unet_narrow = narrow * 0.5
|
||||||
|
channels = {
|
||||||
|
'4': int(512 * unet_narrow),
|
||||||
|
'8': int(512 * unet_narrow),
|
||||||
|
'16': int(512 * unet_narrow),
|
||||||
|
'32': int(512 * unet_narrow),
|
||||||
|
'64': int(256 * channel_multiplier * unet_narrow),
|
||||||
|
'128': int(128 * channel_multiplier * unet_narrow),
|
||||||
|
'256': int(64 * channel_multiplier * unet_narrow),
|
||||||
|
'512': int(32 * channel_multiplier * unet_narrow),
|
||||||
|
'1024': int(16 * channel_multiplier * unet_narrow)
|
||||||
|
}
|
||||||
|
|
||||||
|
self.log_size = int(math.log(out_size, 2))
|
||||||
|
first_out_size = 2**(int(math.log(out_size, 2)))
|
||||||
|
|
||||||
|
self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
|
||||||
|
|
||||||
|
# downsample
|
||||||
|
in_channels = channels[f'{first_out_size}']
|
||||||
|
self.conv_body_down = nn.ModuleList()
|
||||||
|
for i in range(self.log_size, 2, -1):
|
||||||
|
out_channels = channels[f'{2**(i - 1)}']
|
||||||
|
self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
|
||||||
|
in_channels = out_channels
|
||||||
|
|
||||||
|
self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
|
||||||
|
|
||||||
|
# upsample
|
||||||
|
in_channels = channels['4']
|
||||||
|
self.conv_body_up = nn.ModuleList()
|
||||||
|
for i in range(3, self.log_size + 1):
|
||||||
|
out_channels = channels[f'{2**i}']
|
||||||
|
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
|
||||||
|
in_channels = out_channels
|
||||||
|
|
||||||
|
# to RGB
|
||||||
|
self.toRGB = nn.ModuleList()
|
||||||
|
for i in range(3, self.log_size + 1):
|
||||||
|
self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
|
||||||
|
|
||||||
|
if different_w:
|
||||||
|
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
||||||
|
else:
|
||||||
|
linear_out_channel = num_style_feat
|
||||||
|
|
||||||
|
self.final_linear = EqualLinear(
|
||||||
|
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
|
||||||
|
|
||||||
|
self.stylegan_decoder = StyleGAN2GeneratorSFTV1(
|
||||||
|
out_size=out_size,
|
||||||
|
num_style_feat=num_style_feat,
|
||||||
|
num_mlp=num_mlp,
|
||||||
|
channel_multiplier=channel_multiplier,
|
||||||
|
resample_kernel=resample_kernel,
|
||||||
|
lr_mlp=lr_mlp,
|
||||||
|
narrow=narrow,
|
||||||
|
sft_half=sft_half)
|
||||||
|
|
||||||
|
if decoder_load_path:
|
||||||
|
self.stylegan_decoder.load_state_dict(
|
||||||
|
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||||
|
if fix_decoder:
|
||||||
|
for name, param in self.stylegan_decoder.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# for SFT
|
||||||
|
self.condition_scale = nn.ModuleList()
|
||||||
|
self.condition_shift = nn.ModuleList()
|
||||||
|
for i in range(3, self.log_size + 1):
|
||||||
|
out_channels = channels[f'{2**i}']
|
||||||
|
if sft_half:
|
||||||
|
sft_out_channels = out_channels
|
||||||
|
else:
|
||||||
|
sft_out_channels = out_channels * 2
|
||||||
|
self.condition_scale.append(
|
||||||
|
nn.Sequential(
|
||||||
|
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
||||||
|
ScaledLeakyReLU(0.2),
|
||||||
|
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
|
||||||
|
self.condition_shift.append(
|
||||||
|
nn.Sequential(
|
||||||
|
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
||||||
|
ScaledLeakyReLU(0.2),
|
||||||
|
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
return_latents=False,
|
||||||
|
save_feat_path=None,
|
||||||
|
load_feat_path=None,
|
||||||
|
return_rgb=True,
|
||||||
|
randomize_noise=True):
|
||||||
|
conditions = []
|
||||||
|
unet_skips = []
|
||||||
|
out_rgbs = []
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
feat = self.conv_body_first(x)
|
||||||
|
for i in range(self.log_size - 2):
|
||||||
|
feat = self.conv_body_down[i](feat)
|
||||||
|
unet_skips.insert(0, feat)
|
||||||
|
|
||||||
|
feat = self.final_conv(feat)
|
||||||
|
|
||||||
|
# style code
|
||||||
|
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
||||||
|
if self.different_w:
|
||||||
|
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
||||||
|
|
||||||
|
# decode
|
||||||
|
for i in range(self.log_size - 2):
|
||||||
|
# add unet skip
|
||||||
|
feat = feat + unet_skips[i]
|
||||||
|
# ResUpLayer
|
||||||
|
feat = self.conv_body_up[i](feat)
|
||||||
|
# generate scale and shift for SFT layer
|
||||||
|
scale = self.condition_scale[i](feat)
|
||||||
|
conditions.append(scale.clone())
|
||||||
|
shift = self.condition_shift[i](feat)
|
||||||
|
conditions.append(shift.clone())
|
||||||
|
# generate rgb images
|
||||||
|
if return_rgb:
|
||||||
|
out_rgbs.append(self.toRGB[i](feat))
|
||||||
|
|
||||||
|
if save_feat_path is not None:
|
||||||
|
torch.save(conditions, save_feat_path)
|
||||||
|
if load_feat_path is not None:
|
||||||
|
conditions = torch.load(load_feat_path)
|
||||||
|
conditions = [v.cuda() for v in conditions]
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
image, _ = self.stylegan_decoder([style_code],
|
||||||
|
conditions,
|
||||||
|
return_latents=return_latents,
|
||||||
|
input_is_latent=self.input_is_latent,
|
||||||
|
randomize_noise=randomize_noise)
|
||||||
|
|
||||||
|
return image, out_rgbs
|
102
inference_gfpgan_full.py
Normal file
102
inference_gfpgan_full.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import argparse
|
||||||
|
import cv2
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
|
from torchvision.transforms.functional import normalize
|
||||||
|
|
||||||
|
from basicsr.utils import img2tensor, imwrite, tensor2img
|
||||||
|
from gfpganv1_arch import GFPGANv1
|
||||||
|
|
||||||
|
|
||||||
|
def restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=None):
|
||||||
|
# read image
|
||||||
|
img_name = os.path.basename(img_path)
|
||||||
|
print(f'Processing {img_name} ...')
|
||||||
|
basename, _ = os.path.splitext(img_name)
|
||||||
|
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
||||||
|
face_helper.clean_all()
|
||||||
|
|
||||||
|
if has_aligned:
|
||||||
|
input_img = cv2.resize(input_img, (512, 512))
|
||||||
|
face_helper.cropped_faces = [input_img]
|
||||||
|
else:
|
||||||
|
face_helper.read_image(input_img)
|
||||||
|
# get face landmarks for each face
|
||||||
|
face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
|
||||||
|
# align and warp each face
|
||||||
|
save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
|
||||||
|
face_helper.align_warp_face(save_crop_path)
|
||||||
|
|
||||||
|
# face restoration
|
||||||
|
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
||||||
|
# prepare data
|
||||||
|
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||||
|
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||||
|
cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')
|
||||||
|
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
output = gfpgan(cropped_face_t, return_rgb=False)[0]
|
||||||
|
# convert to image
|
||||||
|
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
||||||
|
except RuntimeError as error:
|
||||||
|
print(f'\tFailed inference for GFPGAN: {error}.')
|
||||||
|
restored_face = cropped_face
|
||||||
|
|
||||||
|
restored_face = restored_face.astype('uint8')
|
||||||
|
face_helper.add_restored_face(restored_face)
|
||||||
|
|
||||||
|
if suffix is not None:
|
||||||
|
save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
|
||||||
|
else:
|
||||||
|
save_face_name = f'{basename}_{idx:02d}.png'
|
||||||
|
save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
|
||||||
|
imwrite(restored_face, save_restore_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument('--upscale_factor', type=int, default=1)
|
||||||
|
parser.add_argument('--model_path', type=str, default='models/GFPGANv1.pth')
|
||||||
|
parser.add_argument('--test_path', type=str, default='inputs')
|
||||||
|
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
|
||||||
|
parser.add_argument('--only_center_face', action='store_true')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.test_path.endswith('/'):
|
||||||
|
args.test_path = args.test_path[:-1]
|
||||||
|
save_root = 'results/'
|
||||||
|
os.makedirs(save_root, exist_ok=True)
|
||||||
|
|
||||||
|
# initialize the GFP-GAN
|
||||||
|
gfpgan = GFPGANv1(
|
||||||
|
out_size=512,
|
||||||
|
num_style_feat=512,
|
||||||
|
channel_multiplier=1,
|
||||||
|
decoder_load_path=None,
|
||||||
|
fix_decoder=True,
|
||||||
|
# for stylegan decoder
|
||||||
|
num_mlp=8,
|
||||||
|
input_is_latent=True,
|
||||||
|
different_w=True,
|
||||||
|
narrow=1,
|
||||||
|
sft_half=True)
|
||||||
|
|
||||||
|
gfpgan.to(device)
|
||||||
|
checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
|
||||||
|
gfpgan.load_state_dict(checkpoint['params_ema'])
|
||||||
|
gfpgan.eval()
|
||||||
|
|
||||||
|
# initialize face helper
|
||||||
|
face_helper = FaceRestoreHelper(
|
||||||
|
upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
|
||||||
|
|
||||||
|
# scan all the jpg and png images
|
||||||
|
img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
|
||||||
|
for img_path in img_list:
|
||||||
|
restoration(
|
||||||
|
gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=args.suffix)
|
22
setup.cfg
Normal file
22
setup.cfg
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
[flake8]
|
||||||
|
ignore =
|
||||||
|
# line break before binary operator (W503)
|
||||||
|
W503,
|
||||||
|
# line break after binary operator (W504)
|
||||||
|
W504,
|
||||||
|
max-line-length=120
|
||||||
|
|
||||||
|
[yapf]
|
||||||
|
based_on_style = pep8
|
||||||
|
column_limit = 120
|
||||||
|
blank_line_before_nested_class_or_def = true
|
||||||
|
split_before_expression_after_opening_paren = true
|
||||||
|
|
||||||
|
[isort]
|
||||||
|
line_length = 120
|
||||||
|
multi_line_output = 0
|
||||||
|
known_standard_library = pkg_resources,setuptools
|
||||||
|
known_first_party = basicsr
|
||||||
|
known_third_party = cv2,facexlib,torch,torchvision,tqdm
|
||||||
|
no_lines_before = STDLIB,LOCALFOLDER
|
||||||
|
default_section = THIRDPARTY
|
Loading…
x
Reference in New Issue
Block a user