diff --git a/gfpgan/utils.py b/gfpgan/utils.py index 6223e73..f3e163e 100644 --- a/gfpgan/utils.py +++ b/gfpgan/utils.py @@ -2,10 +2,9 @@ import cv2 import os import torch from basicsr.utils import img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url from facexlib.utils.face_restoration_helper import FaceRestoreHelper -from torch.hub import download_url_to_file, get_dir from torchvision.transforms.functional import normalize -from urllib.parse import urlparse from gfpgan.archs.gfpganv1_arch import GFPGANv1 from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean @@ -70,7 +69,8 @@ class GFPGANer(): device=self.device) if model_path.startswith('https://'): - model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None) + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) loadnet = torch.load(model_path) if 'params_ema' in loadnet: keyname = 'params_ema' @@ -128,25 +128,3 @@ class GFPGANer(): return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img else: return self.face_helper.cropped_faces, self.face_helper.restored_faces, None - - -def load_file_from_url(url, model_dir=None, progress=True, file_name=None): - """Load file form http url, will download models if necessary. - - Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py - """ - if model_dir is None: - hub_dir = get_dir() - model_dir = os.path.join(hub_dir, 'checkpoints') - - os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) - - parts = urlparse(url) - filename = os.path.basename(parts.path) - if file_name is not None: - filename = file_name - cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) - if not os.path.exists(cached_file): - print(f'Downloading: "{url}" to {cached_file}\n') - download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) - return cached_file diff --git a/setup.cfg b/setup.cfg index 826aafc..3d90d60 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ line_length = 120 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = gfpgan -known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm +known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY @@ -25,3 +25,9 @@ default_section = THIRDPARTY skip = .git,./docs/build count = quiet-level = 3 + +[aliases] +test=pytest + +[tool:pytest] +addopts=tests/ diff --git a/tests/data/ffhq_gt.lmdb/data.mdb b/tests/data/ffhq_gt.lmdb/data.mdb new file mode 100644 index 0000000..823e0a9 Binary files /dev/null and b/tests/data/ffhq_gt.lmdb/data.mdb differ diff --git a/tests/data/ffhq_gt.lmdb/lock.mdb b/tests/data/ffhq_gt.lmdb/lock.mdb new file mode 100644 index 0000000..c53d2e5 Binary files /dev/null and b/tests/data/ffhq_gt.lmdb/lock.mdb differ diff --git a/tests/data/ffhq_gt.lmdb/meta_info.txt b/tests/data/ffhq_gt.lmdb/meta_info.txt new file mode 100644 index 0000000..8f18d95 --- /dev/null +++ b/tests/data/ffhq_gt.lmdb/meta_info.txt @@ -0,0 +1 @@ +00000000.png (512,512,3) 1 diff --git a/tests/data/gt/00000000.png b/tests/data/gt/00000000.png new file mode 100644 index 0000000..33425aa Binary files /dev/null and b/tests/data/gt/00000000.png differ diff --git a/tests/data/test_eye_mouth_landmarks.pth b/tests/data/test_eye_mouth_landmarks.pth new file mode 100644 index 0000000..35243df Binary files /dev/null and b/tests/data/test_eye_mouth_landmarks.pth differ diff --git a/tests/data/test_ffhq_degradation_dataset.yml b/tests/data/test_ffhq_degradation_dataset.yml new file mode 100644 index 0000000..df50c4b --- /dev/null +++ b/tests/data/test_ffhq_degradation_dataset.yml @@ -0,0 +1,24 @@ +name: UnitTest +type: FFHQDegradationDataset +dataroot_gt: tests/data/gt +io_backend: + type: disk + +use_hflip: true +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] +out_size: 512 + +blur_kernel_size: 41 +kernel_list: ['iso', 'aniso'] +kernel_prob: [0.5, 0.5] +blur_sigma: [0.1, 10] +downsample_range: [0.8, 8] +noise_range: [0, 20] +jpeg_range: [60, 100] + +# color jitter and gray +color_jitter_prob: 1 +color_jitter_shift: 20 +color_jitter_pt_prob: 1 +gray_prob: 1 diff --git a/tests/data/test_gfpgan_model.yml b/tests/data/test_gfpgan_model.yml new file mode 100644 index 0000000..bac650e --- /dev/null +++ b/tests/data/test_gfpgan_model.yml @@ -0,0 +1,140 @@ +num_gpu: 1 +manual_seed: 0 +is_train: True +dist: False + +# network structures +network_g: + type: GFPGANv1 + out_size: 512 + num_style_feat: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + decoder_load_path: ~ + fix_decoder: true + num_mlp: 8 + lr_mlp: 0.01 + input_is_latent: true + different_w: true + narrow: 0.5 + sft_half: true + +network_d: + type: StyleGAN2Discriminator + out_size: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + +network_d_left_eye: + type: FacialComponentDiscriminator + +network_d_right_eye: + type: FacialComponentDiscriminator + +network_d_mouth: + type: FacialComponentDiscriminator + +network_identity: + type: ResNetArcFace + block: IRBlock + layers: [2, 2, 2, 2] + use_se: False + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: ~ + pretrain_network_d: ~ + pretrain_network_d_left_eye: ~ + pretrain_network_d_right_eye: ~ + pretrain_network_d_mouth: ~ + pretrain_network_identity: ~ + # resume + resume_state: ~ + ignore_resume_networks: ['network_identity'] + +# training settings +train: + optim_g: + type: Adam + lr: !!float 2e-3 + optim_d: + type: Adam + lr: !!float 2e-3 + optim_component: + type: Adam + lr: !!float 2e-3 + + scheduler: + type: MultiStepLR + milestones: [600000, 700000] + gamma: 0.5 + + total_iter: 800000 + warmup_iter: -1 # no warm up + + # losses + # pixel loss + pixel_opt: + type: L1Loss + loss_weight: !!float 1e-1 + reduction: mean + # L1 loss used in pyramid loss, component style loss and identity loss + L1_opt: + type: L1Loss + loss_weight: 1 + reduction: mean + + # image pyramid loss + pyramid_loss_weight: 1 + remove_pyramid_loss: 50000 + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1 + style_weight: 50 + range_norm: true + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: wgan_softplus + loss_weight: !!float 1e-1 + # r1 regularization for discriminator + r1_reg_weight: 10 + # facial component loss + gan_component_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1 + comp_style_weight: 200 + # identity loss + identity_weight: 10 + + net_d_iters: 1 + net_d_init_iters: 0 + net_d_reg_every: 1 + +# validation settings +val: + val_freq: !!float 5e3 + save_img: True + use_pbar: True + + metrics: + psnr: # metric name + type: calculate_psnr + crop_border: 0 + test_y_channel: false diff --git a/tests/test_arcface_arch.py b/tests/test_arcface_arch.py new file mode 100644 index 0000000..b4b28d3 --- /dev/null +++ b/tests/test_arcface_arch.py @@ -0,0 +1,49 @@ +import torch + +from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace + + +def test_resnetarcface(): + """Test arch: ResNetArcFace.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval() + img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda() + output = net(img) + assert output.shape == (1, 512) + + # -------------------- without SE block ----------------------- # + net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval() + output = net(img) + assert output.shape == (1, 512) + + +def test_basicblock(): + """Test the BasicBlock in arcface_arch""" + block = BasicBlock(1, 3, stride=1, downsample=None).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 3, 12, 12) + + # ----------------- use the downsmaple module--------------- # + downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda() + block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 3, 6, 6) + + +def test_bottleneck(): + """Test the Bottleneck in arcface_arch""" + block = Bottleneck(1, 1, stride=1, downsample=None).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 4, 12, 12) + + # ----------------- use the downsmaple module--------------- # + downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda() + block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 4, 6, 6) diff --git a/tests/test_ffhq_degradation_dataset.py b/tests/test_ffhq_degradation_dataset.py new file mode 100644 index 0000000..fa56c03 --- /dev/null +++ b/tests/test_ffhq_degradation_dataset.py @@ -0,0 +1,96 @@ +import pytest +import yaml + +from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset + + +def test_ffhq_degradation_dataset(): + + with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + dataset = FFHQDegradationDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 1 # whether to read correct meta info + assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations + assert dataset.color_jitter_prob == 1 + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == 'tests/data/gt/00000000.png' + + # ------------------ test with probability = 0 -------------------- # + opt['color_jitter_prob'] = 0 + opt['color_jitter_pt_prob'] = 0 + opt['gray_prob'] = 0 + opt['io_backend'] = dict(type='disk') + dataset = FFHQDegradationDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 1 # whether to read correct meta info + assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations + assert dataset.color_jitter_prob == 0 + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == 'tests/data/gt/00000000.png' + + # ------------------ test lmdb backend -------------------- # + opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb' + opt['io_backend'] = dict(type='lmdb') + + dataset = FFHQDegradationDataset(opt) + assert dataset.io_backend_opt['type'] == 'lmdb' # io backend + assert len(dataset) == 1 # whether to read correct meta info + assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations + assert dataset.color_jitter_prob == 0 + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == '00000000' + + # ------------------ test with crop_components -------------------- # + opt['crop_components'] = True + opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth' + opt['eye_enlarge_ratio'] = 1.4 + opt['gt_gray'] = True + opt['io_backend'] = dict(type='lmdb') + + dataset = FFHQDegradationDataset(opt) + assert dataset.crop_components is True + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == '00000000' + assert result['loc_left_eye'].shape == (4, ) + assert result['loc_right_eye'].shape == (4, ) + assert result['loc_mouth'].shape == (4, ) + + # ------------------ lmdb backend should have paths ends with lmdb -------------------- # + with pytest.raises(ValueError): + opt['dataroot_gt'] = 'tests/data/gt' + opt['io_backend'] = dict(type='lmdb') + dataset = FFHQDegradationDataset(opt) diff --git a/tests/test_gfpgan_arch.py b/tests/test_gfpgan_arch.py new file mode 100644 index 0000000..cef14a4 --- /dev/null +++ b/tests/test_gfpgan_arch.py @@ -0,0 +1,203 @@ +import torch + +from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT + + +def test_stylegan2generatorsft(): + """Test arch: StyleGAN2GeneratorSFT.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = StyleGAN2GeneratorSFT( + out_size=32, + num_style_feat=512, + num_mlp=8, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1, + sft_half=False).cuda().eval() + style = torch.rand((1, 512), dtype=torch.float32).cuda() + condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda() + condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda() + condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda() + conditions = [condition1, condition1, condition2, condition2, condition3, condition3] + output = net([style], conditions) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with return_latents ----------------------- # + output = net([style], conditions, return_latents=True) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 1 + # check latent + assert output[1][0].shape == (8, 512) + + # -------------------- with randomize_noise = False ----------------------- # + output = net([style], conditions, randomize_noise=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with truncation = 0.5 and mixing----------------------- # + output = net([style, style], conditions, truncation=0.5, truncation_latent=style) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + +def test_gfpganv1(): + """Test arch: GFPGANv1.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = GFPGANv1( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=True).cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) + + # -------------------- with different_w = True ----------------------- # + net = GFPGANv1( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=True, + narrow=1, + sft_half=True).cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) + + +def test_facialcomponentdiscriminator(): + """Test arch: FacialComponentDiscriminator.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = FacialComponentDiscriminator().cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert len(output) == 2 + assert output[0].shape == (1, 1, 8, 8) + assert output[1] is None + + # -------------------- return intermediate features ----------------------- # + output = net(img, return_feats=True) + assert len(output) == 2 + assert output[0].shape == (1, 1, 8, 8) + assert len(output[1]) == 2 + assert output[1][0].shape == (1, 128, 16, 16) + assert output[1][1].shape == (1, 256, 8, 8) + + +def test_stylegan2generatorcsft(): + """Test arch: StyleGAN2GeneratorCSFT.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = StyleGAN2GeneratorCSFT( + out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval() + style = torch.rand((1, 512), dtype=torch.float32).cuda() + condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda() + condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda() + condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda() + conditions = [condition1, condition1, condition2, condition2, condition3, condition3] + output = net([style], conditions) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with return_latents ----------------------- # + output = net([style], conditions, return_latents=True) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 1 + # check latent + assert output[1][0].shape == (8, 512) + + # -------------------- with randomize_noise = False ----------------------- # + output = net([style], conditions, randomize_noise=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with truncation = 0.5 and mixing----------------------- # + output = net([style, style], conditions, truncation=0.5, truncation_latent=style) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + +def test_gfpganv1clean(): + """Test arch: GFPGANv1Clean.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = GFPGANv1Clean( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=True).cuda().eval() + + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) + + # -------------------- with different_w = True ----------------------- # + net = GFPGANv1Clean( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + input_is_latent=False, + different_w=True, + narrow=1, + sft_half=True).cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) diff --git a/tests/test_gfpgan_model.py b/tests/test_gfpgan_model.py new file mode 100644 index 0000000..1408ddd --- /dev/null +++ b/tests/test_gfpgan_model.py @@ -0,0 +1,132 @@ +import tempfile +import torch +import yaml +from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator +from basicsr.data.paired_image_dataset import PairedImageDataset +from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss + +from gfpgan.archs.arcface_arch import ResNetArcFace +from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1 +from gfpgan.models.gfpgan_model import GFPGANModel + + +def test_gfpgan_model(): + with open('tests/data/test_gfpgan_model.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + # build model + model = GFPGANModel(opt) + # test attributes + assert model.__class__.__name__ == 'GFPGANModel' + assert isinstance(model.net_g, GFPGANv1) # generator + assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator + # facial component discriminators + assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator) + assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator) + assert isinstance(model.net_d_mouth, FacialComponentDiscriminator) + # identity network + assert isinstance(model.network_identity, ResNetArcFace) + # losses + assert isinstance(model.cri_pix, L1Loss) + assert isinstance(model.cri_perceptual, PerceptualLoss) + assert isinstance(model.cri_gan, GANLoss) + assert isinstance(model.cri_l1, L1Loss) + # optimizer + assert isinstance(model.optimizers[0], torch.optim.Adam) + assert isinstance(model.optimizers[1], torch.optim.Adam) + + # prepare data + gt = torch.rand((1, 3, 512, 512), dtype=torch.float32) + lq = torch.rand((1, 3, 512, 512), dtype=torch.float32) + loc_left_eye = torch.rand((1, 4), dtype=torch.float32) + loc_right_eye = torch.rand((1, 4), dtype=torch.float32) + loc_mouth = torch.rand((1, 4), dtype=torch.float32) + data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth) + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 512, 512) + assert model.gt.shape == (1, 3, 512, 512) + assert model.loc_left_eyes.shape == (1, 4) + assert model.loc_right_eyes.shape == (1, 4) + assert model.loc_mouths.shape == (1, 4) + + # ----------------- test optimize_parameters -------------------- # + model.feed_data(data) + model.optimize_parameters(1) + assert model.output.shape == (1, 3, 512, 512) + assert isinstance(model.log_dict, dict) + # check returned keys + expected_keys = [ + 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth', + 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye', + 'l_d_right_eye', 'l_d_mouth' + ] + assert set(expected_keys).issubset(set(model.log_dict.keys())) + + # ----------------- remove pyramid_loss_weight-------------------- # + model.feed_data(data) + model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000 + assert model.output.shape == (1, 3, 512, 512) + assert isinstance(model.log_dict, dict) + # check returned keys + expected_keys = [ + 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth', + 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye', + 'l_d_right_eye', 'l_d_mouth' + ] + assert set(expected_keys).issubset(set(model.log_dict.keys())) + + # ----------------- test save -------------------- # + with tempfile.TemporaryDirectory() as tmpdir: + model.opt['path']['models'] = tmpdir + model.opt['path']['training_states'] = tmpdir + model.save(0, 1) + + # ----------------- test the test function -------------------- # + model.test() + assert model.output.shape == (1, 3, 512, 512) + # delete net_g_ema + model.__delattr__('net_g_ema') + model.test() + assert model.output.shape == (1, 3, 512, 512) + assert model.net_g.training is True # should back to training mode after testing + + # ----------------- test nondist_validation -------------------- # + # construct dataloader + dataset_opt = dict( + name='Demo', + dataroot_gt='tests/data/gt', + dataroot_lq='tests/data/gt', + io_backend=dict(type='disk'), + scale=4, + phase='val') + dataset = PairedImageDataset(dataset_opt) + dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + assert model.is_train is True + with tempfile.TemporaryDirectory() as tmpdir: + model.opt['path']['visualization'] = tmpdir + model.nondist_validation(dataloader, 1, None, save_img=True) + assert model.is_train is True + # check metric_results + assert 'psnr' in model.metric_results + assert isinstance(model.metric_results['psnr'], float) + + # validation + with tempfile.TemporaryDirectory() as tmpdir: + model.opt['is_train'] = False + model.opt['val']['suffix'] = 'test' + model.opt['path']['visualization'] = tmpdir + model.opt['val']['pbar'] = True + model.nondist_validation(dataloader, 1, None, save_img=True) + # check metric_results + assert 'psnr' in model.metric_results + assert isinstance(model.metric_results['psnr'], float) + + # if opt['val']['suffix'] is None + model.opt['val']['suffix'] = None + model.opt['name'] = 'demo' + model.opt['path']['visualization'] = tmpdir + model.nondist_validation(dataloader, 1, None, save_img=True) + # check metric_results + assert 'psnr' in model.metric_results + assert isinstance(model.metric_results['psnr'], float) diff --git a/tests/test_stylegan2_clean_arch.py b/tests/test_stylegan2_clean_arch.py new file mode 100644 index 0000000..78bb920 --- /dev/null +++ b/tests/test_stylegan2_clean_arch.py @@ -0,0 +1,52 @@ +import torch + +from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean + + +def test_stylegan2generatorclean(): + """Test arch: StyleGAN2GeneratorClean.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = StyleGAN2GeneratorClean( + out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval() + style = torch.rand((1, 512), dtype=torch.float32).cuda() + output = net([style], input_is_latent=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with return_latents ----------------------- # + output = net([style], input_is_latent=True, return_latents=True) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 1 + # check latent + assert output[1][0].shape == (8, 512) + + # -------------------- with randomize_noise = False ----------------------- # + output = net([style], randomize_noise=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with truncation = 0.5 and mixing----------------------- # + output = net([style, style], truncation=0.5, truncation_latent=style) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # ------------------ test make_noise ----------------------- # + out = net.make_noise() + assert len(out) == 7 + assert out[0].shape == (1, 1, 4, 4) + assert out[1].shape == (1, 1, 8, 8) + assert out[2].shape == (1, 1, 8, 8) + assert out[3].shape == (1, 1, 16, 16) + assert out[4].shape == (1, 1, 16, 16) + assert out[5].shape == (1, 1, 32, 32) + assert out[6].shape == (1, 1, 32, 32) + + # ------------------ test get_latent ----------------------- # + out = net.get_latent(style) + assert out.shape == (1, 512) + + # ------------------ test mean_latent ----------------------- # + out = net.mean_latent(2) + assert out.shape == (1, 512) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..a963b32 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,43 @@ +import cv2 +from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +from gfpgan.archs.gfpganv1_arch import GFPGANv1 +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean +from gfpgan.utils import GFPGANer + + +def test_gfpganer(): + # initialize with the clean model + restorer = GFPGANer( + model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth', + upscale=2, + arch='clean', + channel_multiplier=2, + bg_upsampler=None) + # test attribute + assert isinstance(restorer.gfpgan, GFPGANv1Clean) + assert isinstance(restorer.face_helper, FaceRestoreHelper) + + # initialize with the original model + restorer = GFPGANer( + model_path='experiments/pretrained_models/GFPGANv1.pth', + upscale=2, + arch='original', + channel_multiplier=1, + bg_upsampler=None) + # test attribute + assert isinstance(restorer.gfpgan, GFPGANv1) + assert isinstance(restorer.face_helper, FaceRestoreHelper) + + # ------------------ test enhance ---------------- # + img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR) + result = restorer.enhance(img, has_aligned=False, paste_back=True) + assert result[0][0].shape == (512, 512, 3) + assert result[1][0].shape == (512, 512, 3) + assert result[2].shape == (1024, 1024, 3) + + # with has_aligned=True + result = restorer.enhance(img, has_aligned=True, paste_back=False) + assert result[0][0].shape == (512, 512, 3) + assert result[1][0].shape == (512, 512, 3) + assert result[2] is None