1
0
mirror of https://github.com/TencentARC/GFPGAN.git synced 2025-05-18 00:00:15 -07:00

add models and archs

This commit is contained in:
Xintao 2021-05-18 14:25:43 +08:00
parent 043dc22027
commit 110be40ff4
7 changed files with 472 additions and 5 deletions

213
ffhq_degradation_dataset.py Normal file
View File

@ -0,0 +1,213 @@
import cv2
import math
import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
def __init__(self, opt):
super(FFHQDegradationDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.gt_folder = opt['dataroot_gt']
self.mean = opt['mean']
self.std = opt['std']
self.out_size = opt['512']
self.crop_components = opt.get('crop_components', False) # facial components
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
if self.crop_components:
self.components_list = torch.load(opt.get('component_path'))
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
self.paths = paths_from_folder(self.gt_folder)
# degradations
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
self.blur_sigma = opt['blur_sigma']
self.downsample_range = opt['downsample_range']
self.noise_range = opt['noise_range']
self.jpeg_range = opt['jpeg_range']
# color jitter
self.color_jitter_prob = opt.get('color_jitter_prob')
self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
self.color_jitter_shift = opt.get('color_jitter_shift', 20)
# to gray
self.gray_prob = opt.get('gray_prob')
logger = get_root_logger()
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
if self.color_jitter_prob is not None:
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
f'shift: {self.color_jitter_shift}')
if self.gray_prob is not None:
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
@staticmethod
def color_jitter(img, shift):
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
return img
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and brightness is not None:
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = adjust_brightness(img, brightness_factor)
if fn_id == 1 and contrast is not None:
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = adjust_contrast(img, contrast_factor)
if fn_id == 2 and saturation is not None:
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = adjust_saturation(img, saturation_factor)
if fn_id == 3 and hue is not None:
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = adjust_hue(img, hue_factor)
return img
def get_component_coordinates(self, index, status):
components_bbox = self.components_list[f'{index:08d}']
if status[0]: # hflip
# exchange right and left eye
tmp = components_bbox['left_eye']
components_bbox['left_eye'] = components_bbox['right_eye']
components_bbox['right_eye'] = tmp
# modify the width coordinate
components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
# get coordinates
locations = []
for part in ['left_eye', 'right_eye', 'mouth']:
mean = components_bbox[part][0:2]
half_len = components_bbox[part][2]
if 'eye' in part:
half_len *= self.eye_enlarge_ratio
loc = np.hstack((mean - half_len + 1, mean + half_len))
loc = torch.from_numpy(loc).float()
locations.append(loc)
return locations
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
# random horizontal flip
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
h, w, _ = img_gt.shape
if self.crop_components:
locations = self.get_component_coordinates(index, status)
loc_left_eye, loc_right_eye, loc_mouth = locations
# ------------------------ generate lq image ------------------------ #
# blur
kernel = degradations.random_mixed_kernels(
self.kernel_list,
self.kernel_prob,
self.blur_kernel_size,
self.blur_sigma,
self.blur_sigma, [-math.pi, math.pi],
noise_range=None)
img_lq = cv2.filter2D(img_gt, -1, kernel)
# downsample
scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
# noise
if self.noise_range is not None:
img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
# jpeg compression
if self.jpeg_range is not None:
img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
# resize to original size
img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
# random color jitter (only for lq)
if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
# random to gray (only for lq)
if self.gray_prob and np.random.uniform() < self.gray_prob:
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
if self.opt.get('gt_gray'):
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# random color jitter (pytorch version) (only for lq)
if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
brightness = self.opt.get('brightness', (0.5, 1.5))
contrast = self.opt.get('contrast', (0.5, 1.5))
saturation = self.opt.get('saturation', (0, 1.5))
hue = self.opt.get('hue', (-0.1, 0.1))
img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
# round and clip
img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
# normalize
normalize(img_gt, self.mean, self.std, inplace=True)
normalize(img_lq, self.mean, self.std, inplace=True)
if self.crop_components:
return_dict = {
'lq': img_lq,
'gt': img_gt,
'gt_path': gt_path,
'loc_left_eye': loc_left_eye,
'loc_right_eye': loc_right_eye,
'loc_mouth': loc_mouth
}
return return_dict
else:
return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)

View File

@ -12,8 +12,10 @@ from basicsr.losses.losses import r1_penalty
from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
"""GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""

View File

@ -7,9 +7,10 @@ from torch.nn import functional as F
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2Generator)
from basicsr.ops.fused_act import FusedLeakyReLU
from basicsr.utils.registry import ARCH_REGISTRY
class StyleGAN2GeneratorSFTV1(StyleGAN2Generator):
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
"""StyleGAN2 Generator.
Args:
@ -33,7 +34,7 @@ class StyleGAN2GeneratorSFTV1(StyleGAN2Generator):
lr_mlp=0.01,
narrow=1,
sft_half=False):
super(StyleGAN2GeneratorSFTV1, self).__init__(
super(StyleGAN2GeneratorSFT, self).__init__(
out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
@ -221,6 +222,7 @@ class ResUpBlock(nn.Module):
return out
@ARCH_REGISTRY.register()
class GFPGANv1(nn.Module):
"""Unet + StyleGAN2 decoder with SFT."""
@ -294,7 +296,7 @@ class GFPGANv1(nn.Module):
self.final_linear = EqualLinear(
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
self.stylegan_decoder = StyleGAN2GeneratorSFTV1(
self.stylegan_decoder = StyleGAN2GeneratorSFT(
out_size=out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
@ -384,3 +386,33 @@ class GFPGANv1(nn.Module):
randomize_noise=randomize_noise)
return image, out_rgbs
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
def __init__(self):
super(FacialComponentDiscriminator, self).__init__()
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
def forward(self, x, return_feats=False):
feat = self.conv1(x)
feat = self.conv3(self.conv2(feat))
rlt_feats = []
if return_feats:
rlt_feats.append(feat.clone())
feat = self.conv5(self.conv4(feat))
if return_feats:
rlt_feats.append(feat.clone())
out = self.final_conv(feat)
if return_feats:
return out, rlt_feats
else:
return out, None

View File

@ -1,6 +1,7 @@
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@ -55,6 +56,10 @@ def restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False, onl
save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save cmp image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@ -99,4 +104,10 @@ if __name__ == '__main__':
img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
for img_path in img_list:
restoration(
gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=args.suffix)
gfpgan,
face_helper,
img_path,
save_root,
has_aligned=False,
only_center_face=args.only_center_face,
suffix=args.suffix)

View File

@ -17,6 +17,6 @@ line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr
known_third_party = cv2,facexlib,torch,torchvision,tqdm
known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

10
train.py Normal file
View File

@ -0,0 +1,10 @@
import os.path as osp
import ffhq_degradation_dataset # noqa: F401
import gfpgan_model # noqa: F401
import gfpganv1_arch # noqa: F401
from basicsr.train import train_pipeline
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)

199
train_gfpgan_v1.yml Normal file
View File

@ -0,0 +1,199 @@
# general settings
name: debug_train_GFPGANv1_512
model_type: GFPGANModel
num_gpu: 4
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: FFHQ
type: FFHQDegradationDataset
dataroot_gt: datasets/ffhq/ffhq_512.lmdb
io_backend:
type: lmdb
use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
# color jitter and gray
color_jitter_prob: 0.3
color_jitter_shift: 20
color_jitter_pt_prob: 0.3
gray_prob: 0.01
crop_components: true
component_path: models/FFHQ_eye_mouth_landmarks_512.pth
eye_enlarge_ratio: 1.4
# data loader
use_shuffle: true
num_worker_per_gpu: 6
batch_size_per_gpu: 3
dataset_enlarge_ratio: 100
prefetch_mode: ~
val:
name: validation0930real_512
type: PairedImageDataset
dataroot_lq: datasets/faces/validation0930real_512/input # TODO
dataroot_gt: datasets/faces/validation0930real_512/input
io_backend:
type: disk
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
scale: 1
# network structures
network_g:
type: GFPGANv1
out_size: 512
num_style_feat: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
decoder_load_path: models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
fix_decoder: true
num_mlp: 8
lr_mlp: 0.01
input_is_latent: true
different_w: true
narrow: 1
sft_half: true
network_d:
type: StyleGAN2Discriminator
out_size: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
network_d_left_eye:
type: FacialComponentDiscriminator
network_d_right_eye:
type: FacialComponentDiscriminator
network_d_mouth:
type: FacialComponentDiscriminator
network_identity:
type: ResNetArcFace
block: IRBlock
layers: [2, 2, 2, 2]
use_se: False
# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: ~
pretrain_network_d: ~
resume_state: ~
pretrain_network_d_left_eye: ~
pretrain_network_d_right_eye: ~
pretrain_network_d_mouth: ~
pretrain_network_arcface: models/arcface_resnet18.pth
# training settings
train:
optim_g:
type: Adam
lr: !!float 2e-3
optim_d:
type: Adam
lr: !!float 2e-3
optim_component:
type: Adam
lr: !!float 2e-3
scheduler:
type: MultiStepLR
milestones: [600000, 700000]
gamma: 0.5
total_iter: 800000
warmup_iter: -1 # no warm up
# losses
# pixel loss
pixel_opt:
type: L1Loss
loss_weight: !!float 1e-1
reduction: mean
# image pyramid loss
pyramid_loss_weight: 0
remove_pyramid_loss: 50000
# perceptual loss (content and style losses)
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
perceptual_weight: !!float 1
style_weight: 50
range_norm: true
criterion: l1
# gan loss
gan_opt:
type: GANLoss
gan_type: wgan_softplus
loss_weight: !!float 1e-1
# r1 regularization for discriminator
r1_reg_weight: 10
# facial component loss
gan_component_opt:
type: GANLoss
gan_type: vanilla
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1
comp_style_weight: 200
# identity loss
identity_weight: 10
net_d_iters: 1
net_d_init_iters: 0
net_d_reg_every: 16
# validation settings
val:
val_freq: !!float 5e3
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500
find_unused_parameters: true