From 043dc220279c476dadb827784f2263db0c414424 Mon Sep 17 00:00:00 2001 From: Xintao Date: Mon, 17 May 2021 23:32:41 +0800 Subject: [PATCH] add inference --- .github/workflows/pylint.yml | 29 ++ .gitignore | 120 ++++++++ .pre-commit-config.yaml | 40 +++ gfpgan_model.py | 543 +++++++++++++++++++++++++++++++++++ gfpganv1_arch.py | 386 +++++++++++++++++++++++++ inference_gfpgan_full.py | 102 +++++++ setup.cfg | 22 ++ 7 files changed, 1242 insertions(+) create mode 100644 .github/workflows/pylint.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 gfpgan_model.py create mode 100644 gfpganv1_arch.py create mode 100644 inference_gfpgan_full.py create mode 100644 setup.cfg diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 0000000..0b61a71 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,29 @@ +name: Python Lint + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 yapf isort + + - name: Lint + run: | + flake8 . + isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py + yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5abc87c --- /dev/null +++ b/.gitignore @@ -0,0 +1,120 @@ +.vscode + +# ignored files +version.py + +# ignored files with suffix +*.html +*.png +*.jpeg +*.jpg +*.gif +*.pth +*.zip + +# template + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..03cd47c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,40 @@ +repos: + # flake8 + - repo: https://github.com/PyCQA/flake8 + rev: 3.8.3 + hooks: + - id: flake8 + args: ["--config=setup.cfg", "--ignore=W504, W503"] + + # modify known_third_party + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + + # isort + - repo: https://github.com/timothycrosley/isort + rev: 5.2.2 + hooks: + - id: isort + + # yapf + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.30.0 + hooks: + - id: yapf + + # pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace # Trim trailing whitespace + - id: check-yaml # Attempt to load all yaml files to verify syntax + - id: check-merge-conflict # Check for files that contain merge conflict strings + - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings + - id: end-of-file-fixer # Make sure files end in a newline and only a newline + - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 + - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- + args: ["--remove"] + - id: mixed-line-ending # Replace or check mixed line ending + args: ["--fix=lf"] diff --git a/gfpgan_model.py b/gfpgan_model.py new file mode 100644 index 0000000..c61d835 --- /dev/null +++ b/gfpgan_model.py @@ -0,0 +1,543 @@ +import math +import os.path as osp +import torch +from collections import OrderedDict +from torch.nn import functional as F +from torchvision.ops import roi_align +from tqdm import tqdm + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.losses.losses import r1_penalty +from basicsr.metrics import calculate_metric +from basicsr.models.base_model import BaseModel +from basicsr.utils import get_root_logger, imwrite, tensor2img + + +class GFPGANModel(BaseModel): + """GFPGAN model for """ + + def __init__(self, opt): + super(GFPGANModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + self.log_size = int(math.log(self.opt['network_g']['out_size'], 2)) + + if 
self.is_train: + self.init_training_settings() + + def init_training_settings(self): + train_opt = self.opt['train'] + + # ----------- define net_d ----------- # + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + # ----------- define net_g with Exponential Moving Average (EMA) ----------- # + # net_g_ema only used for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + + self.net_g.train() + self.net_d.train() + self.net_g_ema.eval() + + # ----------- facial components networks ----------- # + if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt): + self.use_facial_disc = True + else: + self.use_facial_disc = False + + if self.use_facial_disc: + # left eye + self.net_d_left_eye = build_network(self.opt['network_d_left_eye']) + self.net_d_left_eye = self.model_to_device(self.net_d_left_eye) + self.print_network(self.net_d_left_eye) + load_path = self.opt['path'].get('pretrain_network_d_left_eye') + if load_path is not None: + self.load_network(self.net_d_left_eye, load_path, True, 'params') + # right eye + self.net_d_right_eye = build_network(self.opt['network_d_right_eye']) + self.net_d_right_eye = self.model_to_device(self.net_d_right_eye) + self.print_network(self.net_d_right_eye) + load_path = self.opt['path'].get('pretrain_network_d_right_eye') + if load_path is not None: + self.load_network(self.net_d_right_eye, load_path, True, 'params') + # mouth + self.net_d_mouth = build_network(self.opt['network_d_mouth']) + self.net_d_mouth = self.model_to_device(self.net_d_mouth) + self.print_network(self.net_d_mouth) + load_path = self.opt['path'].get('pretrain_network_d_mouth') + if load_path is not None: + self.load_network(self.net_d_mouth, load_path, True, 'params') + + self.net_d_left_eye.train() + self.net_d_right_eye.train() + self.net_d_mouth.train() + + # ----------- define facial component gan loss ----------- # + self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device) + + # ----------- define losses ----------- # + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + # gan loss (wgan) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + # ----------- define identity loss ----------- # + if 'network_identity' in self.opt: + self.use_identity = True + else: + self.use_identity = False + + if self.use_identity: + # define identity network + self.network_identity = build_network(self.opt['network_identity']) + self.network_identity = self.model_to_device(self.network_identity) + self.print_network(self.network_identity) + load_path = self.opt['path'].get('pretrain_network_identity') + if load_path is 
not None: + self.load_network(self.network_identity, load_path, True, None) + self.network_identity.eval() + for param in self.network_identity.parameters(): + param.requires_grad = False + + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + self.net_d_reg_every = train_opt['net_d_reg_every'] + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + + # ----------- optimizer g ----------- # + net_g_reg_ratio = 1 + normal_params = [] + for _, param in self.net_g.named_parameters(): + normal_params.append(param) + optim_params_g = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }] + optim_type = train_opt['optim_g'].pop('type') + lr = train_opt['optim_g']['lr'] * net_g_reg_ratio + betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) + self.optimizers.append(self.optimizer_g) + + # ----------- optimizer d ----------- # + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + normal_params = [] + for _, param in self.net_d.named_parameters(): + normal_params.append(param) + optim_params_d = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }] + optim_type = train_opt['optim_d'].pop('type') + lr = train_opt['optim_d']['lr'] * net_d_reg_ratio + betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) + self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) + self.optimizers.append(self.optimizer_d) + + if self.use_facial_disc: + # setup optimizers for facial component discriminators + optim_type = train_opt['optim_component'].pop('type') + lr = train_opt['optim_component']['lr'] + # left eye + self.optimizer_d_left_eye = self.get_optimizer( + optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_left_eye) + # right eye + self.optimizer_d_right_eye = self.get_optimizer( + optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_right_eye) + # mouth + self.optimizer_d_mouth = self.get_optimizer( + optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_mouth) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + if self.use_facial_disc: + # get facial component locations, shape (batch, 4) + self.loc_left_eyes = data['loc_left_eye'] + self.loc_right_eyes = data['loc_right_eye'] + self.loc_mouths = data['loc_mouth'] + + def construct_img_pyramid(self): + pyramid_gt = [self.gt] + down_img = self.gt + for _ in range(0, self.log_size - 3): + down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False) + pyramid_gt.insert(0, down_img) + return pyramid_gt + + def get_roi_regions(self, eye_out_size=80, mouth_out_size=120): + # hard code + face_ratio = int(self.opt['network_g']['out_size'] / 512) + eye_out_size *= face_ratio + mouth_out_size *= face_ratio + + rois_eyes = [] + rois_mouths = [] + for b in range(self.loc_left_eyes.size(0)): # loop for batch size + # left eye and right eye + img_inds = self.loc_left_eyes.new_full((2, 1), b) + bbox = 
torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4) + rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5) + rois_eyes.append(rois) + # mouse + img_inds = self.loc_left_eyes.new_full((1, 1), b) + rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5) + rois_mouths.append(rois) + + rois_eyes = torch.cat(rois_eyes, 0).to(self.device) + rois_mouths = torch.cat(rois_mouths, 0).to(self.device) + + # real images + all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes_gt = all_eyes[0::2, :, :, :] + self.right_eyes_gt = all_eyes[1::2, :, :, :] + self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + # output + all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes = all_eyes[0::2, :, :, :] + self.right_eyes = all_eyes[1::2, :, :, :] + self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + self.optimizer_g.zero_grad() + + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = False + for p in self.net_d_right_eye.parameters(): + p.requires_grad = False + for p in self.net_d_mouth.parameters(): + p.requires_grad = False + + # image pyramid loss weight + if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')): + pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1) + else: + pyramid_loss_weight = 1e-12 # very small loss + if pyramid_loss_weight > 0: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=True) + pyramid_gt = self.construct_img_pyramid() + else: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=False) + + # get roi-align regions + if self.use_facial_disc: + self.get_roi_regions(eye_out_size=80, mouth_out_size=120) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # image pyramid loss + if pyramid_loss_weight > 0: + for i in range(0, self.log_size - 2): + l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight + l_g_total += l_pyramid + loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) 
+ l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + # facial component loss + if self.use_facial_disc: + # left eye + fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_left_eye'] = l_g_gan + # right eye + fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_right_eye'] = l_g_gan + # mouth + fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True) + l_g_gan = self.cri_component(fake_mouth, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_mouth'] = l_g_gan + + if self.opt['train'].get('comp_style_weight', 0) > 0: + # get gt feat + _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True) + _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True) + _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True) + + def _comp_style(feat, feat_gt, criterion): + return criterion(self._gram_mat(feat[0]), self._gram_mat( + feat_gt[0].detach())) * 0.5 + criterion( + self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach())) + + # facial component style loss + comp_style_loss = 0 + comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1) + comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight'] + l_g_total += comp_style_loss + loss_dict['l_g_comp_style_loss'] = comp_style_loss + + # identity loss + if self.use_identity: + identity_weight = self.opt['train']['identity_weight'] + # get gray images and resize + out_gray = self.gray_resize_for_identity(self.output) + gt_gray = self.gray_resize_for_identity(self.gt) + + identity_gt = self.network_identity(gt_gray).detach() + identity_out = self.network_identity(out_gray) + l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight + l_g_total += l_identity + loss_dict['l_identity'] = l_identity + + l_g_total.backward() + self.optimizer_g.step() + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + # ----------- optimize net_d ----------- # + for p in self.net_d.parameters(): + p.requires_grad = True + self.optimizer_d.zero_grad() + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = True + for p in self.net_d_right_eye.parameters(): + p.requires_grad = True + for p in self.net_d_mouth.parameters(): + p.requires_grad = True + self.optimizer_d_left_eye.zero_grad() + self.optimizer_d_right_eye.zero_grad() + self.optimizer_d_mouth.zero_grad() + + fake_d_pred = self.net_d(self.output.detach()) + real_d_pred = self.net_d(self.gt) + l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d'] = l_d + # In wgan, real_score should be positive and fake_score should benegative + loss_dict['real_score'] = real_d_pred.detach().mean() + loss_dict['fake_score'] = fake_d_pred.detach().mean() + l_d.backward() + + if current_iter % self.net_d_reg_every == 0: + self.gt.requires_grad = True + real_pred = self.net_d(self.gt) + l_d_r1 = r1_penalty(real_pred, self.gt) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every 
+ 0 * real_pred[0]) + loss_dict['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizer_d.step() + + if self.use_facial_disc: + # lefe eye + fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach()) + real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt) + l_d_left_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_left_eye'] = l_d_left_eye + l_d_left_eye.backward() + # right eye + fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach()) + real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt) + l_d_right_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_right_eye'] = l_d_right_eye + l_d_right_eye.backward() + # mouth + fake_d_pred, _ = self.net_d_mouth(self.mouths.detach()) + real_d_pred, _ = self.net_d_mouth(self.mouths_gt) + l_d_mouth = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_mouth'] = l_d_mouth + l_d_mouth.backward() + + self.optimizer_d_left_eye.step() + self.optimizer_d_right_eye.step() + self.optimizer_d_mouth.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _ = self.net_g_ema(self.lq) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _ = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['sr']], min_max=(-1, 1)) + gt_img = tensor2img([visuals['gt']], min_max=(-1, 1)) + + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']], min_max=(-1, 1)) + del self.gt + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, 
tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['sr'] = self.output.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + self.save_network(self.net_d, 'net_d', current_iter) + # save component discriminators + if self.use_facial_disc: + self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter) + self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter) + self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/gfpganv1_arch.py b/gfpganv1_arch.py new file mode 100644 index 0000000..270da91 --- /dev/null +++ b/gfpganv1_arch.py @@ -0,0 +1,386 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2Generator) +from basicsr.ops.fused_act import FusedLeakyReLU + + +class StyleGAN2GeneratorSFTV1(StyleGAN2Generator): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel + magnitude. A cross production will be applied to extent 1D resample + kenrel to 2D resample kernel. Default: [1, 3, 3, 1]. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + resample_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorSFTV1, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2Generator. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. + Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is + False. Default: True. + truncation (float): TODO. Default: 1. + truncation_latent (Tensor | None): TODO. Default: None. + inject_index (int | None): The injection index for mixing noise. + Default: None. + return_latents (bool): Whether to return style latents. + Default: False. 
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latent with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ConvUpLayer(nn.Module): + """Conv Up Layer. Bilinear upsample + Conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + activate (bool): Whether use activateion. Default: True. 
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + bias=True, + bias_init_val=0, + activate=True): + super(ConvUpLayer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + + if bias and not activate: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + # activation + if activate: + if bias: + self.activation = FusedLeakyReLU(out_channels) + else: + self.activation = ScaledLeakyReLU(0.2) + else: + self.activation = None + + def forward(self, x): + # bilinear upsample + out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + # conv + out = F.conv2d( + out, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + # activation + if self.activation is not None: + out = self.activation(out) + return out + + +class ResUpBlock(nn.Module): + """Residual block with upsampling. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + """ + + def __init__(self, in_channels, out_channels): + super(ResUpBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True) + self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out + + +class GFPGANv1(nn.Module): + """Unet + StyleGAN2 decoder with SFT.""" + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=[1, 3, 3, 1], + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANv1, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = 
channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + self.stylegan_decoder = StyleGAN2GeneratorSFTV1( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + if fix_decoder: + for name, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) + + def forward(self, + x, + return_latents=False, + save_feat_path=None, + load_feat_path=None, + return_rgb=True, + randomize_noise=True): + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = self.conv_body_first(x) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + + feat = self.final_conv(feat) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layer + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + if save_feat_path is not None: + torch.save(conditions, save_feat_path) + if load_feat_path is not None: + conditions = torch.load(load_feat_path) + conditions = [v.cuda() for v in conditions] + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs diff --git a/inference_gfpgan_full.py b/inference_gfpgan_full.py new file mode 100644 index 0000000..9fef567 --- /dev/null +++ b/inference_gfpgan_full.py @@ -0,0 +1,102 @@ +import argparse +import cv2 +import glob +import os +import torch +from 
facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from torchvision.transforms.functional import normalize
+
+from basicsr.utils import img2tensor, imwrite, tensor2img
+from gfpganv1_arch import GFPGANv1
+
+
+def restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=None):
+    # read image
+    img_name = os.path.basename(img_path)
+    print(f'Processing {img_name} ...')
+    basename, _ = os.path.splitext(img_name)
+    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+    face_helper.clean_all()
+
+    if has_aligned:
+        input_img = cv2.resize(input_img, (512, 512))
+        face_helper.cropped_faces = [input_img]
+    else:
+        face_helper.read_image(input_img)
+        # get face landmarks for each face
+        face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
+        # align and warp each face
+        save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
+        face_helper.align_warp_face(save_crop_path)
+
+    # face restoration
+    for idx, cropped_face in enumerate(face_helper.cropped_faces):
+        # prepare data
+        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+        cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda' if torch.cuda.is_available() else 'cpu')
+
+        try:
+            with torch.no_grad():
+                output = gfpgan(cropped_face_t, return_rgb=False)[0]
+                # convert to image
+                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
+        except RuntimeError as error:
+            print(f'\tFailed inference for GFPGAN: {error}.')
+            restored_face = cropped_face
+
+        restored_face = restored_face.astype('uint8')
+        face_helper.add_restored_face(restored_face)
+
+        if suffix is not None:
+            save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
+        else:
+            save_face_name = f'{basename}_{idx:02d}.png'
+        save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
+        imwrite(restored_face, save_restore_path)
+
+
+if __name__ == '__main__':
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--upscale_factor', type=int, default=1)
+    parser.add_argument('--model_path', type=str, default='models/GFPGANv1.pth')
+    parser.add_argument('--test_path', type=str, default='inputs')
+    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
+    parser.add_argument('--only_center_face', action='store_true')
+
+    args = parser.parse_args()
+    if args.test_path.endswith('/'):
+        args.test_path = args.test_path[:-1]
+    save_root = 'results/'
+    os.makedirs(save_root, exist_ok=True)
+
+    # initialize the GFP-GAN
+    gfpgan = GFPGANv1(
+        out_size=512,
+        num_style_feat=512,
+        channel_multiplier=1,
+        decoder_load_path=None,
+        fix_decoder=True,
+        # for stylegan decoder
+        num_mlp=8,
+        input_is_latent=True,
+        different_w=True,
+        narrow=1,
+        sft_half=True)
+
+    gfpgan.to(device)
+    checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
+    gfpgan.load_state_dict(checkpoint['params_ema'])
+    gfpgan.eval()
+
+    # initialize face helper
+    face_helper = FaceRestoreHelper(
+        upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
+
+    # scan all the jpg and png images
+    img_list = sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')))
+    for img_path in img_list:
+        restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False,
+                    only_center_face=args.only_center_face, suffix=args.suffix)
diff --git a/setup.cfg b/setup.cfg
new file 
mode 100644 index 0000000..9788236 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,22 @@ +[flake8] +ignore = + # line break before binary operator (W503) + W503, + # line break after binary operator (W504) + W504, +max-line-length=120 + +[yapf] +based_on_style = pep8 +column_limit = 120 +blank_line_before_nested_class_or_def = true +split_before_expression_after_opening_paren = true + +[isort] +line_length = 120 +multi_line_output = 0 +known_standard_library = pkg_resources,setuptools +known_first_party = basicsr +known_third_party = cv2,facexlib,torch,torchvision,tqdm +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY
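
Reviewer note (not part of the patch): the new entry point is run as, e.g., python inference_gfpgan_full.py --model_path models/GFPGANv1.pth --test_path inputs --suffix restored; cropped and restored faces are written under results/cropped_faces and results/restored_faces. For a quick local sanity check, the sketch below condenses the aligned-face path of restoration() into a standalone snippet. It uses only names that appear in the patch (GFPGANv1, img2tensor, tensor2img, imwrite, the params_ema checkpoint key); the input path inputs/face.png is a placeholder for any pre-aligned 512x512 face crop.

# Minimal sketch of the aligned-face path of inference_gfpgan_full.py.
# Assumes a pre-aligned 512x512 face crop at the placeholder path inputs/face.png
# and the GFPGANv1.pth checkpoint from --model_path; FaceRestoreHelper detection
# and warping are skipped.
import cv2
import torch
from torchvision.transforms.functional import normalize

from basicsr.utils import img2tensor, imwrite, tensor2img
from gfpganv1_arch import GFPGANv1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# same architecture options as in inference_gfpgan_full.py
gfpgan = GFPGANv1(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=1,
    decoder_load_path=None,
    fix_decoder=True,
    num_mlp=8,
    input_is_latent=True,
    different_w=True,
    narrow=1,
    sft_half=True).to(device)
checkpoint = torch.load('models/GFPGANv1.pth', map_location=lambda storage, loc: storage)
gfpgan.load_state_dict(checkpoint['params_ema'])
gfpgan.eval()

# BGR uint8 image -> normalized float tensor in [-1, 1], shape (1, 3, 512, 512)
face = cv2.resize(cv2.imread('inputs/face.png', cv2.IMREAD_COLOR), (512, 512))
face_t = img2tensor(face / 255., bgr2rgb=True, float32=True)
normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
face_t = face_t.unsqueeze(0).to(device)

# restore and convert back to a BGR uint8 image
with torch.no_grad():
    output = gfpgan(face_t, return_rgb=False)[0]
restored = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)).astype('uint8')
imwrite(restored, 'results/restored_faces/face_restored.png')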