diff --git a/ffhq_degradation_dataset.py b/ffhq_degradation_dataset.py
new file mode 100644
index 0000000..c092cc9
--- /dev/null
+++ b/ffhq_degradation_dataset.py
@@ -0,0 +1,213 @@
+import cv2
+import math
+import numpy as np
+import os.path as osp
+import torch
+import torch.utils.data as data
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
+                                               normalize)
+
+from basicsr.data import degradations as degradations
+from basicsr.data.data_util import paths_from_folder
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class FFHQDegradationDataset(data.Dataset):
+
+    def __init__(self, opt):
+        super(FFHQDegradationDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        self.gt_folder = opt['dataroot_gt']
+        self.mean = opt['mean']
+        self.std = opt['std']
+        self.out_size = opt['out_size']
+
+        self.crop_components = opt.get('crop_components', False)  # facial components
+        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
+
+        if self.crop_components:
+            self.components_list = torch.load(opt.get('component_path'))
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = self.gt_folder
+            if not self.gt_folder.endswith('.lmdb'):
+                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                self.paths = [line.split('.')[0] for line in fin]
+        else:
+            self.paths = paths_from_folder(self.gt_folder)
+
+        # degradations
+        self.blur_kernel_size = opt['blur_kernel_size']
+        self.kernel_list = opt['kernel_list']
+        self.kernel_prob = opt['kernel_prob']
+        self.blur_sigma = opt['blur_sigma']
+        self.downsample_range = opt['downsample_range']
+        self.noise_range = opt['noise_range']
+        self.jpeg_range = opt['jpeg_range']
+
+        # color jitter
+        self.color_jitter_prob = opt.get('color_jitter_prob')
+        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
+        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+        # to gray
+        self.gray_prob = opt.get('gray_prob')
+
+        logger = get_root_logger()
+        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
+                    f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
+        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
+
+        if self.color_jitter_prob is not None:
+            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
+                        f'shift: {self.color_jitter_shift}')
+        if self.gray_prob is not None:
+            logger.info(f'Use random gray. Prob: {self.gray_prob}')
+
+        self.color_jitter_shift /= 255.
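For reference, a minimal sketch (not part of the patch) of the `opt` dict this dataset consumes. The values mirror the train_gfpgan_v1.yml added later in this diff; the disk-backend folder path is illustrative:

# run after importing FFHQDegradationDataset; the folder should hold 512x512 GT faces
opt = {
    'io_backend': {'type': 'disk'},
    'dataroot_gt': 'datasets/ffhq/ffhq_512',  # illustrative path
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'out_size': 512,
    'use_hflip': True,
    'blur_kernel_size': 41,
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.1, 10],
    'downsample_range': [0.8, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
}
dataset = FFHQDegradationDataset(opt)
sample = dataset[0]  # {'lq': ..., 'gt': ..., 'gt_path': ...}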
+
+    @staticmethod
+    def color_jitter(img, shift):
+        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+        img = img + jitter_val
+        img = np.clip(img, 0, 1)
+        return img
+
+    @staticmethod
+    def color_jitter_pt(img, brightness, contrast, saturation, hue):
+        fn_idx = torch.randperm(4)
+        for fn_id in fn_idx:
+            if fn_id == 0 and brightness is not None:
+                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+                img = adjust_brightness(img, brightness_factor)
+
+            if fn_id == 1 and contrast is not None:
+                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+                img = adjust_contrast(img, contrast_factor)
+
+            if fn_id == 2 and saturation is not None:
+                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+                img = adjust_saturation(img, saturation_factor)
+
+            if fn_id == 3 and hue is not None:
+                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+                img = adjust_hue(img, hue_factor)
+        return img
+
+    def get_component_coordinates(self, index, status):
+        components_bbox = self.components_list[f'{index:08d}']
+        if status[0]:  # hflip
+            # exchange right and left eye
+            tmp = components_bbox['left_eye']
+            components_bbox['left_eye'] = components_bbox['right_eye']
+            components_bbox['right_eye'] = tmp
+            # modify the width coordinate
+            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
+            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
+            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
+
+        # get coordinates
+        locations = []
+        for part in ['left_eye', 'right_eye', 'mouth']:
+            mean = components_bbox[part][0:2]
+            half_len = components_bbox[part][2]
+            if 'eye' in part:
+                half_len *= self.eye_enlarge_ratio
+            loc = np.hstack((mean - half_len + 1, mean + half_len))
+            loc = torch.from_numpy(loc).float()
+            locations.append(loc)
+        return locations
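A quick check (not part of the patch) of the coordinate handling above, assuming each component entry is stored as [center_x, center_y, half_len], which is what the indexing implies:

import numpy as np

out_size = 512
bbox = np.array([170.0, 220.0, 40.0])        # hypothetical left-eye entry
center, half_len = bbox[0:2], bbox[2] * 1.4  # eye_enlarge_ratio = 1.4
loc = np.hstack((center - half_len + 1, center + half_len))  # (x1, y1, x2, y2)
mirrored_x = out_size - bbox[0]              # hflip mirrors only the x center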
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load gt image
+        gt_path = self.paths[index]
+        img_bytes = self.file_client.get(gt_path)
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # random horizontal flip
+        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+        h, w, _ = img_gt.shape
+
+        if self.crop_components:
+            locations = self.get_component_coordinates(index, status)
+            loc_left_eye, loc_right_eye, loc_mouth = locations
+
+        # ------------------------ generate lq image ------------------------ #
+        # blur
+        kernel = degradations.random_mixed_kernels(
+            self.kernel_list,
+            self.kernel_prob,
+            self.blur_kernel_size,
+            self.blur_sigma,
+            self.blur_sigma, [-math.pi, math.pi],
+            noise_range=None)
+        img_lq = cv2.filter2D(img_gt, -1, kernel)
+        # downsample
+        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
+        # noise
+        if self.noise_range is not None:
+            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
+        # jpeg compression
+        if self.jpeg_range is not None:
+            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
+
+        # resize to original size
+        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
+
+        # random color jitter (only for lq)
+        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
+        # random to gray (only for lq)
+        if self.gray_prob and np.random.uniform() < self.gray_prob:
+            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
+            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
+            if self.opt.get('gt_gray'):
+                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
+                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+
+        # random color jitter (pytorch version) (only for lq)
+        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+            brightness = self.opt.get('brightness', (0.5, 1.5))
+            contrast = self.opt.get('contrast', (0.5, 1.5))
+            saturation = self.opt.get('saturation', (0, 1.5))
+            hue = self.opt.get('hue', (-0.1, 0.1))
+            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
+
+        # round and clip
+        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
+
+        # normalize
+        normalize(img_gt, self.mean, self.std, inplace=True)
+        normalize(img_lq, self.mean, self.std, inplace=True)
+
+        if self.crop_components:
+            return_dict = {
+                'lq': img_lq,
+                'gt': img_gt,
+                'gt_path': gt_path,
+                'loc_left_eye': loc_left_eye,
+                'loc_right_eye': loc_right_eye,
+                'loc_mouth': loc_mouth
+            }
+            return return_dict
+        else:
+            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
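The degradation chain in `__getitem__` is blur -> downsample -> noise -> JPEG -> resize back. A standalone sketch (not part of the patch) using plain OpenCV/NumPy stand-ins for the basicsr helpers, with fixed parameters in place of the randomized ranges:

import cv2
import numpy as np

def degrade(img, sigma=3.0, scale=4.0, noise_std=10 / 255., jpeg_q=70):
    # img: float32 BGR in [0, 1]
    h, w = img.shape[:2]
    lq = cv2.GaussianBlur(img, (41, 41), sigma)                        # blur
    lq = cv2.resize(lq, (int(w // scale), int(h // scale)),
                    interpolation=cv2.INTER_LINEAR)                    # downsample
    lq = np.clip(lq + np.random.normal(0, noise_std, lq.shape), 0, 1)  # noise
    ok, buf = cv2.imencode('.jpg', (lq * 255).astype(np.uint8),
                           [cv2.IMWRITE_JPEG_QUALITY, jpeg_q])         # jpeg
    lq = cv2.imdecode(buf, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    return cv2.resize(lq, (w, h), interpolation=cv2.INTER_LINEAR)      # resize back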
diff --git a/gfpgan_model.py b/gfpgan_model.py
index c61d835..f3c3551 100644
--- a/gfpgan_model.py
+++ b/gfpgan_model.py
@@ -12,8 +12,10 @@ from basicsr.losses.losses import r1_penalty
 from basicsr.metrics import calculate_metric
 from basicsr.models.base_model import BaseModel
 from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
 
 
+@MODEL_REGISTRY.register()
 class GFPGANModel(BaseModel):
     """GFPGAN model for blind face restoration."""
 
diff --git a/gfpganv1_arch.py b/gfpganv1_arch.py
index 270da91..48c4840 100644
--- a/gfpganv1_arch.py
+++ b/gfpganv1_arch.py
@@ -7,9 +7,10 @@ from torch.nn import functional as F
 from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
                                           StyleGAN2Generator)
 from basicsr.ops.fused_act import FusedLeakyReLU
+from basicsr.utils.registry import ARCH_REGISTRY
 
 
-class StyleGAN2GeneratorSFTV1(StyleGAN2Generator):
+class StyleGAN2GeneratorSFT(StyleGAN2Generator):
     """StyleGAN2 Generator.
 
     Args:
@@ -33,7 +34,7 @@ class StyleGAN2GeneratorSFTV1(StyleGAN2Generator):
                  lr_mlp=0.01,
                  narrow=1,
                  sft_half=False):
-        super(StyleGAN2GeneratorSFTV1, self).__init__(
+        super(StyleGAN2GeneratorSFT, self).__init__(
             out_size,
             num_style_feat=num_style_feat,
             num_mlp=num_mlp,
@@ -221,6 +222,7 @@ class ResUpBlock(nn.Module):
         return out
 
 
+@ARCH_REGISTRY.register()
 class GFPGANv1(nn.Module):
     """Unet + StyleGAN2 decoder with SFT."""
 
@@ -294,7 +296,7 @@ class GFPGANv1(nn.Module):
         self.final_linear = EqualLinear(
             channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
 
-        self.stylegan_decoder = StyleGAN2GeneratorSFTV1(
+        self.stylegan_decoder = StyleGAN2GeneratorSFT(
             out_size=out_size,
             num_style_feat=num_style_feat,
             num_mlp=num_mlp,
@@ -384,3 +386,33 @@ class GFPGANv1(nn.Module):
                 randomize_noise=randomize_noise)
 
         return image, out_rgbs
+
+
+@ARCH_REGISTRY.register()
+class FacialComponentDiscriminator(nn.Module):
+
+    def __init__(self):
+        super(FacialComponentDiscriminator, self).__init__()
+
+        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
+
+    def forward(self, x, return_feats=False):
+        feat = self.conv1(x)
+        feat = self.conv3(self.conv2(feat))
+        rlt_feats = []
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        feat = self.conv5(self.conv4(feat))
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        out = self.final_conv(feat)
+
+        if return_feats:
+            return out, rlt_feats
+        else:
+            return out, None
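A shape check (not part of the patch) for the component discriminator just added; the 40x40 ROI size is illustrative:

import torch

d_eye = FacialComponentDiscriminator()
eye_patch = torch.randn(1, 3, 40, 40)
out, feats = d_eye(eye_patch, return_feats=True)
# out: (1, 1, 10, 10) realness map after the two stride-2 stages; feats holds
# the 128- and 256-channel intermediates used for the component style loss.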
diff --git a/inference_gfpgan_full.py b/inference_gfpgan_full.py
index 9fef567..54eb786 100644
--- a/inference_gfpgan_full.py
+++ b/inference_gfpgan_full.py
@@ -1,6 +1,7 @@
 import argparse
 import cv2
 import glob
+import numpy as np
 import os
 import torch
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@@ -55,6 +56,10 @@ def restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False, onl
         save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
         imwrite(restored_face, save_restore_path)
 
+        # save a side-by-side comparison of the cropped and restored faces
+        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
+        imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
+
 
 if __name__ == '__main__':
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -99,4 +104,10 @@ if __name__ == '__main__':
     img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
     for img_path in img_list:
         restoration(
-            gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=args.suffix)
+            gfpgan,
+            face_helper,
+            img_path,
+            save_root,
+            has_aligned=False,
+            only_center_face=args.only_center_face,
+            suffix=args.suffix)
diff --git a/setup.cfg b/setup.cfg
index 9788236..014da11 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,6 +17,6 @@ line_length = 120
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = basicsr
-known_third_party = cv2,facexlib,torch,torchvision,tqdm
+known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..e4d52e9
--- /dev/null
+++ b/train.py
@@ -0,0 +1,10 @@
+import os.path as osp
+
+import ffhq_degradation_dataset  # noqa: F401
+import gfpgan_model  # noqa: F401
+import gfpganv1_arch  # noqa: F401
+from basicsr.train import train_pipeline
+
+if __name__ == '__main__':
+    # train.py sits in the repo root, so one os.path.pardir yields the root path
+    root_path = osp.abspath(osp.join(__file__, osp.pardir))
+    train_pipeline(root_path)
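The three side-effect imports in train.py exist only to execute the registry decorators, so that basicsr can resolve the `type:` names in the YAML below. Roughly (not part of the patch):

# these lookups succeed only after the three modules above have been imported
from basicsr.utils.registry import ARCH_REGISTRY, DATASET_REGISTRY, MODEL_REGISTRY

dataset_cls = DATASET_REGISTRY.get('FFHQDegradationDataset')  # ffhq_degradation_dataset.py
model_cls = MODEL_REGISTRY.get('GFPGANModel')                 # gfpgan_model.py
arch_cls = ARCH_REGISTRY.get('GFPGANv1')                      # gfpganv1_arch.py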
diff --git a/train_gfpgan_v1.yml b/train_gfpgan_v1.yml
new file mode 100644
index 0000000..d430571
--- /dev/null
+++ b/train_gfpgan_v1.yml
@@ -0,0 +1,199 @@
+# general settings
+name: debug_train_GFPGANv1_512
+model_type: GFPGANModel
+num_gpu: 4
+manual_seed: 0
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: FFHQ
+    type: FFHQDegradationDataset
+    dataroot_gt: datasets/ffhq/ffhq_512.lmdb
+    io_backend:
+      type: lmdb
+
+    use_hflip: true
+    mean: [0.5, 0.5, 0.5]
+    std: [0.5, 0.5, 0.5]
+    out_size: 512
+
+    blur_kernel_size: 41
+    kernel_list: ['iso', 'aniso']
+    kernel_prob: [0.5, 0.5]
+    blur_sigma: [0.1, 10]
+    downsample_range: [0.8, 8]
+    noise_range: [0, 20]
+    jpeg_range: [60, 100]
+
+    # color jitter and gray
+    color_jitter_prob: 0.3
+    color_jitter_shift: 20
+    color_jitter_pt_prob: 0.3
+    gray_prob: 0.01
+
+    crop_components: true
+    component_path: models/FFHQ_eye_mouth_landmarks_512.pth
+    eye_enlarge_ratio: 1.4
+
+    # data loader
+    use_shuffle: true
+    num_worker_per_gpu: 6
+    batch_size_per_gpu: 3
+    dataset_enlarge_ratio: 100
+    prefetch_mode: ~
+
+  val:
+    name: validation0930real_512
+    type: PairedImageDataset
+    dataroot_lq: datasets/faces/validation0930real_512/input  # TODO
+    dataroot_gt: datasets/faces/validation0930real_512/input
+    io_backend:
+      type: disk
+    mean: [0.5, 0.5, 0.5]
+    std: [0.5, 0.5, 0.5]
+    scale: 1
+
+# network structures
+network_g:
+  type: GFPGANv1
+  out_size: 512
+  num_style_feat: 512
+  channel_multiplier: 1
+  resample_kernel: [1, 3, 3, 1]
+  decoder_load_path: models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
+  fix_decoder: true
+  num_mlp: 8
+  lr_mlp: 0.01
+  input_is_latent: true
+  different_w: true
+  narrow: 1
+  sft_half: true
+
+network_d:
+  type: StyleGAN2Discriminator
+  out_size: 512
+  channel_multiplier: 1
+  resample_kernel: [1, 3, 3, 1]
+
+network_d_left_eye:
+  type: FacialComponentDiscriminator
+
+network_d_right_eye:
+  type: FacialComponentDiscriminator
+
+network_d_mouth:
+  type: FacialComponentDiscriminator
+
+network_identity:
+  type: ResNetArcFace
+  block: IRBlock
+  layers: [2, 2, 2, 2]
+  use_se: false
+
+# path
+path:
+  pretrain_network_g: ~
+  param_key_g: params_ema
+  strict_load_g: ~
+  pretrain_network_d: ~
+
+  resume_state: ~
+  pretrain_network_d_left_eye: ~
+  pretrain_network_d_right_eye: ~
+  pretrain_network_d_mouth: ~
+  pretrain_network_arcface: models/arcface_resnet18.pth
+
+# training settings
+train:
+  optim_g:
+    type: Adam
+    lr: !!float 2e-3
+  optim_d:
+    type: Adam
+    lr: !!float 2e-3
+  optim_component:
+    type: Adam
+    lr: !!float 2e-3
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [600000, 700000]
+    gamma: 0.5
+
+  total_iter: 800000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  # pixel loss
+  pixel_opt:
+    type: L1Loss
+    loss_weight: !!float 1e-1
+    reduction: mean
+  # image pyramid loss
+  pyramid_loss_weight: 0
+  remove_pyramid_loss: 50000
+  # perceptual loss (content and style losses)
+  perceptual_opt:
+    type: PerceptualLoss
+    layer_weights:
+      # before relu
+      'conv1_2': 0.1
+      'conv2_2': 0.1
+      'conv3_4': 1
+      'conv4_4': 1
+      'conv5_4': 1
+    vgg_type: vgg19
+    use_input_norm: true
+    perceptual_weight: !!float 1
+    style_weight: 50
+    range_norm: true
+    criterion: l1
+  # gan loss
+  gan_opt:
+    type: GANLoss
+    gan_type: wgan_softplus
+    loss_weight: !!float 1e-1
+  # r1 regularization for discriminator
+  r1_reg_weight: 10
+  # facial component loss
+  gan_component_opt:
+    type: GANLoss
+    gan_type: vanilla
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    loss_weight: !!float 1
+  comp_style_weight: 200
+  # identity loss
+  identity_weight: 10
+
+  net_d_iters: 1
+  net_d_init_iters: 0
+  net_d_reg_every: 16
+
+# validation settings
+val:
+  val_freq: !!float 5e3
+  save_img: true
+
+  metrics:
+    psnr:  # metric name, can be arbitrary
+      type: calculate_psnr
+      crop_border: 0
+      test_y_channel: false
+
+# logging settings
+logger:
+  print_freq: 100
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: true
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
+
+find_unused_parameters: true
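Finally, a smoke test (not part of the patch) of the generator under the network_g settings above. It skips the pretrained decoder and assumes a basicsr install whose fused ops run on the available device:

import torch

net_g = GFPGANv1(
    out_size=512, num_style_feat=512, channel_multiplier=1,
    resample_kernel=[1, 3, 3, 1], decoder_load_path=None, fix_decoder=True,
    num_mlp=8, lr_mlp=0.01, input_is_latent=True, different_w=True,
    narrow=1, sft_half=True)
with torch.no_grad():
    image, out_rgbs = net_g(torch.randn(1, 3, 512, 512))
print(image.shape)  # expected: torch.Size([1, 3, 512, 512])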