diff --git a/archs/__init__.py b/archs/__init__.py new file mode 100644 index 0000000..b7a5b1c --- /dev/null +++ b/archs/__init__.py @@ -0,0 +1,12 @@ +import importlib +from os import path as osp + +from basicsr.utils import scandir + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] diff --git a/archs/arcface_arch.py b/archs/arcface_arch.py new file mode 100644 index 0000000..a2b41be --- /dev/null +++ b/archs/arcface_arch.py @@ -0,0 +1,198 @@ +import torch.nn as nn + +from basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x diff --git a/gfpganv1_arch.py b/archs/gfpganv1_arch.py similarity index 100% rename from gfpganv1_arch.py rename to archs/gfpganv1_arch.py diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..ccb0981 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,11 @@ +import importlib +from os import path as osp + +from basicsr.utils import scandir + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] diff --git a/ffhq_degradation_dataset.py b/data/ffhq_degradation_dataset.py similarity index 99% rename from ffhq_degradation_dataset.py rename to data/ffhq_degradation_dataset.py index c092cc9..da66ff7 100644 --- a/ffhq_degradation_dataset.py +++ b/data/ffhq_degradation_dataset.py @@ -27,7 +27,7 @@ class FFHQDegradationDataset(data.Dataset): self.gt_folder = opt['dataroot_gt'] self.mean = opt['mean'] self.std = opt['std'] - self.out_size = opt['512'] + self.out_size = opt['out_size'] self.crop_components = opt.get('crop_components', False) # facial components self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) diff --git a/inference_gfpgan_full.py b/inference_gfpgan_full.py index 54eb786..06b1d3f 100644 --- a/inference_gfpgan_full.py +++ b/inference_gfpgan_full.py @@ -7,8 +7,8 @@ import torch from facexlib.utils.face_restoration_helper import FaceRestoreHelper from torchvision.transforms.functional import normalize +from archs.gfpganv1_arch import GFPGANv1 from basicsr.utils import img2tensor, imwrite, tensor2img -from gfpganv1_arch import GFPGANv1 def restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=None): @@ -66,7 +66,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--upscale_factor', type=int, default=1) - parser.add_argument('--model_path', type=str, default='models/GFPGANv1.pth') + parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth') parser.add_argument('--test_path', type=str, default='inputs') parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') parser.add_argument('--only_center_face', action='store_true') diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..904bd99 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,12 @@ +import importlib +from os import path as osp + +from basicsr.utils import scandir + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] diff --git a/gfpgan_model.py b/models/gfpgan_model.py similarity index 97% rename from gfpgan_model.py rename to models/gfpgan_model.py index f3c3551..4b23c58 100644 --- a/gfpgan_model.py +++ b/models/gfpgan_model.py @@ -21,6 +21,7 @@ class GFPGANModel(BaseModel): def __init__(self, opt): super(GFPGANModel, self).__init__(opt) + self.idx = 0 # define network self.net_g = build_network(opt['network_g']) @@ -112,6 +113,9 @@ class GFPGANModel(BaseModel): else: self.cri_perceptual = None + # L1 loss used in pyramid loss, component style loss and identity loss + self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device) + # gan loss (wgan) self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) @@ -198,7 +202,18 @@ class GFPGANModel(BaseModel): if 'gt' in data: self.gt = data['gt'].to(self.device) - if self.use_facial_disc: + import torchvision + if self.opt['rank'] == 0: + import os + os.makedirs('tmp/gt', exist_ok=True) + os.makedirs('tmp/lq', exist_ok=True) + print(self.idx) + torchvision.utils.save_image( + self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + torchvision.utils.save_image( + self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + + if 'loc_left_eye' in data: # get facial component locations, shape (batch, 4) self.loc_left_eyes = data['loc_left_eye'] self.loc_right_eyes = data['loc_right_eye'] diff --git a/train.py b/train.py index e4d52e9..b03e0f3 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,10 @@ import os.path as osp -import ffhq_degradation_dataset # noqa: F401 -import gfpgan_model # noqa: F401 -import gfpganv1_arch # noqa: F401 +import archs # noqa: F401 +import data # noqa: F401 +import models # noqa: F401 from basicsr.train import train_pipeline if __name__ == '__main__': - root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + root_path = osp.abspath(osp.join(__file__, osp.pardir)) train_pipeline(root_path) diff --git a/train_gfpgan_v1.yml b/train_gfpgan_v1.yml index d430571..b34704c 100644 --- a/train_gfpgan_v1.yml +++ b/train_gfpgan_v1.yml @@ -33,7 +33,7 @@ datasets: gray_prob: 0.01 crop_components: true - component_path: models/FFHQ_eye_mouth_landmarks_512.pth + component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth eye_enlarge_ratio: 1.4 # data loader @@ -44,10 +44,10 @@ datasets: prefetch_mode: ~ val: - name: validation0930real_512 + name: validation1020_512 type: PairedImageDataset - dataroot_lq: datasets/faces/validation0930real_512/input # TODO - dataroot_gt: datasets/faces/validation0930real_512/input + dataroot_lq: datasets/faces/validation1020_512/input # TODO: modify before release + dataroot_gt: datasets/faces/validation1020_512/input io_backend: type: disk mean: [0.5, 0.5, 0.5] @@ -61,7 +61,7 @@ network_g: num_style_feat: 512 channel_multiplier: 1 resample_kernel: [1, 3, 3, 1] - decoder_load_path: models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth + decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth fix_decoder: true num_mlp: 8 lr_mlp: 0.01 @@ -102,7 +102,7 @@ path: pretrain_network_d_left_eye: ~ pretrain_network_d_right_eye: ~ pretrain_network_d_mouth: ~ - pretrain_network_arcface: models/arcface_resnet18.pth + pretrain_network_arcface: experiments/pretrained_models/arcface_resnet18.pth # training settings train: @@ -130,6 +130,12 @@ train: type: L1Loss loss_weight: !!float 1e-1 reduction: mean + # L1 loss used in pyramid loss, component style loss and identity loss + L1_opt: + type: L1Loss + loss_weight: 1 + reduction: mean + # image pyramid loss pyramid_loss_weight: 0 remove_pyramid_loss: 50000