mirror of https://github.com/TencentARC/GFPGAN.git
synced 2025-05-18 08:10:16 -07:00

add models and archs

This commit is contained in:
parent 043dc22027
commit 110be40ff4

213 ffhq_degradation_dataset.py Normal file

@@ -0,0 +1,213 @@
import cv2
import math
import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
                                               normalize)

from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)

        if self.crop_components:
            self.components_list = torch.load(opt.get('component_path'))

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # degradations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
                    f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
                        f'shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')

        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        # add a random per-channel offset in [-shift, shift], then clip to [0, 1]
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        # apply brightness/contrast/saturation/hue adjustments in a random order
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        components_bbox = self.components_list[f'{index:08d}']
        if status[0]:  # hflip
            # exchange right and left eye
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # modify the width coordinate
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]

        # get coordinates
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
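For orientation, a minimal sketch of driving this dataset standalone with the disk backend; the option values mirror train_gfpgan_v1.yml further below, and the folder path is hypothetical:

# Sketch: smoke-test FFHQDegradationDataset on a folder of 512x512 face crops.
from ffhq_degradation_dataset import FFHQDegradationDataset

opt = {
    'dataroot_gt': 'datasets/ffhq/ffhq_512',  # hypothetical folder of aligned crops
    'io_backend': {'type': 'disk'},
    'use_hflip': True,
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'out_size': 512,
    'blur_kernel_size': 41,
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.1, 10],
    'downsample_range': [0.8, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
}
dataset = FFHQDegradationDataset(opt)
sample = dataset[0]
print(sample['lq'].shape, sample['gt'].shape)  # CHW tensors, e.g. (3, 512, 512)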
@@ -12,8 +12,10 @@ from basicsr.losses.losses import r1_penalty
 from basicsr.metrics import calculate_metric
 from basicsr.models.base_model import BaseModel
 from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY


+@MODEL_REGISTRY.register()
 class GFPGANModel(BaseModel):
     """GFPGAN model for <Towards real-world blind face restoration with generative facial prior>"""

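The hunk above registers GFPGANModel with BasicSR's MODEL_REGISTRY; its import context already pulls in r1_penalty, and the config below pairs a wgan_softplus GAN loss with lazy R1 regularization (r1_reg_weight: 10 every net_d_reg_every: 16 iterations). A sketch of the standard StyleGAN2 formulation those settings refer to, not this model's exact code (which the hunk does not show):

# Sketch: softplus (non-saturating) GAN loss plus lazy R1, as in StyleGAN2.
import torch
import torch.nn.functional as F

def d_loss_wgan_softplus(real_pred, fake_pred):
    # push real predictions up and fake predictions down
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

def g_loss_wgan_softplus(fake_pred):
    return F.softplus(-fake_pred).mean()

def lazy_r1(real_pred, real_img, r1_reg_weight=10, net_d_reg_every=16):
    # gradient penalty on real images; real_img must have requires_grad=True.
    # Applied only every net_d_reg_every iterations, so the weight is scaled
    # by the interval to keep the effective regularization strength unchanged.
    grad = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    penalty = grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()
    return (r1_reg_weight / 2) * penalty * net_d_reg_every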
@@ -7,9 +7,10 @@ from torch.nn import functional as F
 from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
                                           StyleGAN2Generator)
 from basicsr.ops.fused_act import FusedLeakyReLU
+from basicsr.utils.registry import ARCH_REGISTRY


-class StyleGAN2GeneratorSFTV1(StyleGAN2Generator):
+class StyleGAN2GeneratorSFT(StyleGAN2Generator):
     """StyleGAN2 Generator.

     Args:

@@ -33,7 +34,7 @@ class StyleGAN2GeneratorSFTV1(StyleGAN2Generator):
                  lr_mlp=0.01,
                  narrow=1,
                  sft_half=False):
-        super(StyleGAN2GeneratorSFTV1, self).__init__(
+        super(StyleGAN2GeneratorSFT, self).__init__(
             out_size,
             num_style_feat=num_style_feat,
             num_mlp=num_mlp,

@@ -221,6 +222,7 @@ class ResUpBlock(nn.Module):
         return out


+@ARCH_REGISTRY.register()
 class GFPGANv1(nn.Module):
     """Unet + StyleGAN2 decoder with SFT."""


@@ -294,7 +296,7 @@ class GFPGANv1(nn.Module):
         self.final_linear = EqualLinear(
             channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

-        self.stylegan_decoder = StyleGAN2GeneratorSFTV1(
+        self.stylegan_decoder = StyleGAN2GeneratorSFT(
             out_size=out_size,
             num_style_feat=num_style_feat,
             num_mlp=num_mlp,

@@ -384,3 +386,33 @@ class GFPGANv1(nn.Module):
             randomize_noise=randomize_noise)

         return image, out_rgbs
+
+
+@ARCH_REGISTRY.register()
+class FacialComponentDiscriminator(nn.Module):
+
+    def __init__(self):
+        super(FacialComponentDiscriminator, self).__init__()
+
+        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
+
+    def forward(self, x, return_feats=False):
+        feat = self.conv1(x)
+        feat = self.conv3(self.conv2(feat))
+        rlt_feats = []
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        feat = self.conv5(self.conv4(feat))
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        out = self.final_conv(feat)
+
+        if return_feats:
+            return out, rlt_feats
+        else:
+            return out, None
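FacialComponentDiscriminator scores small image patches, and the dataset's __getitem__ returns loc_left_eye / loc_right_eye / loc_mouth boxes when crop_components is on. One plausible way to wire the two together uses torchvision's roi_align; this is an assumption about the intended glue code, which this commit does not show, and out_size=80 is an arbitrary patch size:

# Sketch: crop eye/mouth regions from a batch of images with the dataset's
# (x1, y1, x2, y2) boxes, then score them with a component discriminator.
import torch
from torchvision.ops import roi_align

def crop_component(img, boxes, out_size=80):
    # img: (N, 3, H, W); boxes: (N, 4). roi_align expects (batch_idx, x1, y1, x2, y2).
    batch_idx = torch.arange(img.size(0), device=img.device, dtype=boxes.dtype).unsqueeze(1)
    rois = torch.cat([batch_idx, boxes], dim=1)  # (N, 5)
    return roi_align(img, rois, output_size=(out_size, out_size))

# patch = crop_component(restored, loc_left_eye)           # hypothetical names
# score, feats = net_d_left_eye(patch, return_feats=True)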
@@ -1,6 +1,7 @@
 import argparse
 import cv2
 import glob
+import numpy as np
 import os
 import torch
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper

@@ -55,6 +56,10 @@ def restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False, onl
         save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
         imwrite(restored_face, save_restore_path)

+        # save cmp image
+        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
+        imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))


 if __name__ == '__main__':
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -99,4 +104,10 @@ if __name__ == '__main__':
     img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
     for img_path in img_list:
         restoration(
-            gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=args.suffix)
+            gfpgan,
+            face_helper,
+            img_path,
+            save_root,
+            has_aligned=False,
+            only_center_face=args.only_center_face,
+            suffix=args.suffix)
@@ -17,6 +17,6 @@ line_length = 120
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = basicsr
-known_third_party = cv2,facexlib,torch,torchvision,tqdm
+known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
10 train.py Normal file

@@ -0,0 +1,10 @@
import os.path as osp

import ffhq_degradation_dataset  # noqa: F401
import gfpgan_model  # noqa: F401
import gfpganv1_arch  # noqa: F401
from basicsr.train import train_pipeline

if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
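train.py simply hands control to BasicSR's train_pipeline, which reads the experiment options from a -opt command-line flag. Under that assumption, and given num_gpu: 4 in the config below, a distributed launch would look roughly like:

python -m torch.distributed.launch --nproc_per_node=4 --master_port=29500 \
    train.py -opt train_gfpgan_v1.yml --launcher pytorch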
199 train_gfpgan_v1.yml Normal file

@@ -0,0 +1,199 @@
# general settings
name: debug_train_GFPGANv1_512
model_type: GFPGANModel
num_gpu: 4
manual_seed: 0

# dataset and data loader settings
datasets:
  train:
    name: FFHQ
    type: FFHQDegradationDataset
    dataroot_gt: datasets/ffhq/ffhq_512.lmdb
    io_backend:
      type: lmdb

    use_hflip: true
    mean: [0.5, 0.5, 0.5]
    std: [0.5, 0.5, 0.5]
    out_size: 512

    blur_kernel_size: 41
    kernel_list: ['iso', 'aniso']
    kernel_prob: [0.5, 0.5]
    blur_sigma: [0.1, 10]
    downsample_range: [0.8, 8]
    noise_range: [0, 20]
    jpeg_range: [60, 100]

    # color jitter and gray
    color_jitter_prob: 0.3
    color_jitter_shift: 20
    color_jitter_pt_prob: 0.3
    gray_prob: 0.01

    crop_components: true
    component_path: models/FFHQ_eye_mouth_landmarks_512.pth
    eye_enlarge_ratio: 1.4

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 6
    batch_size_per_gpu: 3
    dataset_enlarge_ratio: 100
    prefetch_mode: ~

  val:
    name: validation0930real_512
    type: PairedImageDataset
    dataroot_lq: datasets/faces/validation0930real_512/input  # TODO
    dataroot_gt: datasets/faces/validation0930real_512/input
    io_backend:
      type: disk
    mean: [0.5, 0.5, 0.5]
    std: [0.5, 0.5, 0.5]
    scale: 1

# network structures
network_g:
  type: GFPGANv1
  out_size: 512
  num_style_feat: 512
  channel_multiplier: 1
  resample_kernel: [1, 3, 3, 1]
  decoder_load_path: models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
  fix_decoder: true
  num_mlp: 8
  lr_mlp: 0.01
  input_is_latent: true
  different_w: true
  narrow: 1
  sft_half: true

network_d:
  type: StyleGAN2Discriminator
  out_size: 512
  channel_multiplier: 1
  resample_kernel: [1, 3, 3, 1]

network_d_left_eye:
  type: FacialComponentDiscriminator

network_d_right_eye:
  type: FacialComponentDiscriminator

network_d_mouth:
  type: FacialComponentDiscriminator

network_identity:
  type: ResNetArcFace
  block: IRBlock
  layers: [2, 2, 2, 2]
  use_se: False

# path
path:
  pretrain_network_g: ~
  param_key_g: params_ema
  strict_load_g: ~
  pretrain_network_d: ~

  resume_state: ~
  pretrain_network_d_left_eye: ~
  pretrain_network_d_right_eye: ~
  pretrain_network_d_mouth: ~
  pretrain_network_arcface: models/arcface_resnet18.pth

# training settings
train:
  optim_g:
    type: Adam
    lr: !!float 2e-3
  optim_d:
    type: Adam
    lr: !!float 2e-3
  optim_component:
    type: Adam
    lr: !!float 2e-3

  scheduler:
    type: MultiStepLR
    milestones: [600000, 700000]
    gamma: 0.5

  total_iter: 800000
  warmup_iter: -1  # no warm up

  # losses
  # pixel loss
  pixel_opt:
    type: L1Loss
    loss_weight: !!float 1e-1
    reduction: mean
  # image pyramid loss
  pyramid_loss_weight: 0
  remove_pyramid_loss: 50000
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1
    style_weight: 50
    range_norm: true
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: wgan_softplus
    loss_weight: !!float 1e-1
  # r1 regularization for discriminator
  r1_reg_weight: 10
  # facial component loss
  gan_component_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1
  comp_style_weight: 200
  # identity loss
  identity_weight: 10

  net_d_iters: 1
  net_d_init_iters: 0
  net_d_reg_every: 16

# validation settings
val:
  val_freq: !!float 5e3
  save_img: true

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500

find_unused_parameters: true
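As a sanity check on the network_g block, the generator can be instantiated directly with the same keyword arguments BasicSR would pass from the YAML. A sketch, assuming this commit's gfpganv1_arch.py is importable, that the constructor accepts exactly the network_g keys, and that forward returns (image, out_rgbs) as the last arch hunk indicates:

# Sketch: build GFPGANv1 from the network_g options and run a dummy forward pass.
import torch
from gfpganv1_arch import GFPGANv1

net_g = GFPGANv1(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=1,
    resample_kernel=(1, 3, 3, 1),
    decoder_load_path=None,  # assumption: None skips loading pretrained decoder weights
    fix_decoder=False,
    num_mlp=8,
    lr_mlp=0.01,
    input_is_latent=True,
    different_w=True,
    narrow=1,
    sft_half=True)

with torch.no_grad():
    image, out_rgbs = net_g(torch.randn(1, 3, 512, 512))
print(image.shape)  # expected: torch.Size([1, 3, 512, 512])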