1
0
mirror of https://github.com/TencentARC/GFPGAN.git synced 2025-05-15 14:50:11 -07:00

update utils and unittest

This commit is contained in:
Xintao 2021-11-28 23:09:38 +08:00
parent be73d6d9a4
commit 37237da798
15 changed files with 750 additions and 26 deletions

View File

@ -2,10 +2,9 @@ import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torch.hub import download_url_to_file, get_dir
from torchvision.transforms.functional import normalize
from urllib.parse import urlparse
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
@ -70,7 +69,8 @@ class GFPGANer():
device=self.device)
if model_path.startswith('https://'):
model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
@ -128,25 +128,3 @@ class GFPGANer():
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file

View File

@ -17,7 +17,7 @@ line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = gfpgan
known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
@ -25,3 +25,9 @@ default_section = THIRDPARTY
skip = .git,./docs/build
count =
quiet-level = 3
[aliases]
test=pytest
[tool:pytest]
addopts=tests/

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1 @@
00000000.png (512,512,3) 1

BIN
tests/data/gt/00000000.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 429 KiB

Binary file not shown.

View File

@ -0,0 +1,24 @@
name: UnitTest
type: FFHQDegradationDataset
dataroot_gt: tests/data/gt
io_backend:
type: disk
use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
# color jitter and gray
color_jitter_prob: 1
color_jitter_shift: 20
color_jitter_pt_prob: 1
gray_prob: 1

View File

@ -0,0 +1,140 @@
num_gpu: 1
manual_seed: 0
is_train: True
dist: False
# network structures
network_g:
type: GFPGANv1
out_size: 512
num_style_feat: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
decoder_load_path: ~
fix_decoder: true
num_mlp: 8
lr_mlp: 0.01
input_is_latent: true
different_w: true
narrow: 0.5
sft_half: true
network_d:
type: StyleGAN2Discriminator
out_size: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
network_d_left_eye:
type: FacialComponentDiscriminator
network_d_right_eye:
type: FacialComponentDiscriminator
network_d_mouth:
type: FacialComponentDiscriminator
network_identity:
type: ResNetArcFace
block: IRBlock
layers: [2, 2, 2, 2]
use_se: False
# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: ~
pretrain_network_d: ~
pretrain_network_d_left_eye: ~
pretrain_network_d_right_eye: ~
pretrain_network_d_mouth: ~
pretrain_network_identity: ~
# resume
resume_state: ~
ignore_resume_networks: ['network_identity']
# training settings
train:
optim_g:
type: Adam
lr: !!float 2e-3
optim_d:
type: Adam
lr: !!float 2e-3
optim_component:
type: Adam
lr: !!float 2e-3
scheduler:
type: MultiStepLR
milestones: [600000, 700000]
gamma: 0.5
total_iter: 800000
warmup_iter: -1 # no warm up
# losses
# pixel loss
pixel_opt:
type: L1Loss
loss_weight: !!float 1e-1
reduction: mean
# L1 loss used in pyramid loss, component style loss and identity loss
L1_opt:
type: L1Loss
loss_weight: 1
reduction: mean
# image pyramid loss
pyramid_loss_weight: 1
remove_pyramid_loss: 50000
# perceptual loss (content and style losses)
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
perceptual_weight: !!float 1
style_weight: 50
range_norm: true
criterion: l1
# gan loss
gan_opt:
type: GANLoss
gan_type: wgan_softplus
loss_weight: !!float 1e-1
# r1 regularization for discriminator
r1_reg_weight: 10
# facial component loss
gan_component_opt:
type: GANLoss
gan_type: vanilla
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1
comp_style_weight: 200
# identity loss
identity_weight: 10
net_d_iters: 1
net_d_init_iters: 0
net_d_reg_every: 1
# validation settings
val:
val_freq: !!float 5e3
save_img: True
use_pbar: True
metrics:
psnr: # metric name
type: calculate_psnr
crop_border: 0
test_y_channel: false

View File

@ -0,0 +1,49 @@
import torch
from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
def test_resnetarcface():
"""Test arch: ResNetArcFace."""
# model init and forward (gpu)
if torch.cuda.is_available():
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
output = net(img)
assert output.shape == (1, 512)
# -------------------- without SE block ----------------------- #
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
output = net(img)
assert output.shape == (1, 512)
def test_basicblock():
"""Test the BasicBlock in arcface_arch"""
block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 3, 12, 12)
# ----------------- use the downsmaple module--------------- #
downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 3, 6, 6)
def test_bottleneck():
"""Test the Bottleneck in arcface_arch"""
block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 4, 12, 12)
# ----------------- use the downsmaple module--------------- #
downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
output = block(img)
assert output.shape == (1, 4, 6, 6)

View File

@ -0,0 +1,96 @@
import pytest
import yaml
from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset
def test_ffhq_degradation_dataset():
with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
opt = yaml.load(f, Loader=yaml.FullLoader)
dataset = FFHQDegradationDataset(opt)
assert dataset.io_backend_opt['type'] == 'disk' # io backend
assert len(dataset) == 1 # whether to read correct meta info
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
assert dataset.color_jitter_prob == 1
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == 'tests/data/gt/00000000.png'
# ------------------ test with probability = 0 -------------------- #
opt['color_jitter_prob'] = 0
opt['color_jitter_pt_prob'] = 0
opt['gray_prob'] = 0
opt['io_backend'] = dict(type='disk')
dataset = FFHQDegradationDataset(opt)
assert dataset.io_backend_opt['type'] == 'disk' # io backend
assert len(dataset) == 1 # whether to read correct meta info
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
assert dataset.color_jitter_prob == 0
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == 'tests/data/gt/00000000.png'
# ------------------ test lmdb backend -------------------- #
opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
opt['io_backend'] = dict(type='lmdb')
dataset = FFHQDegradationDataset(opt)
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
assert len(dataset) == 1 # whether to read correct meta info
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
assert dataset.color_jitter_prob == 0
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == '00000000'
# ------------------ test with crop_components -------------------- #
opt['crop_components'] = True
opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
opt['eye_enlarge_ratio'] = 1.4
opt['gt_gray'] = True
opt['io_backend'] = dict(type='lmdb')
dataset = FFHQDegradationDataset(opt)
assert dataset.crop_components is True
# test __getitem__
result = dataset.__getitem__(0)
# check returned keys
expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
assert set(expected_keys).issubset(set(result.keys()))
# check shape and contents
assert result['gt'].shape == (3, 512, 512)
assert result['lq'].shape == (3, 512, 512)
assert result['gt_path'] == '00000000'
assert result['loc_left_eye'].shape == (4, )
assert result['loc_right_eye'].shape == (4, )
assert result['loc_mouth'].shape == (4, )
# ------------------ lmdb backend should have paths ends with lmdb -------------------- #
with pytest.raises(ValueError):
opt['dataroot_gt'] = 'tests/data/gt'
opt['io_backend'] = dict(type='lmdb')
dataset = FFHQDegradationDataset(opt)

203
tests/test_gfpgan_arch.py Normal file
View File

@ -0,0 +1,203 @@
import torch
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
def test_stylegan2generatorsft():
"""Test arch: StyleGAN2GeneratorSFT."""
# model init and forward (gpu)
if torch.cuda.is_available():
net = StyleGAN2GeneratorSFT(
out_size=32,
num_style_feat=512,
num_mlp=8,
channel_multiplier=1,
resample_kernel=(1, 3, 3, 1),
lr_mlp=0.01,
narrow=1,
sft_half=False).cuda().eval()
style = torch.rand((1, 512), dtype=torch.float32).cuda()
condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
output = net([style], conditions)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
# -------------------- with return_latents ----------------------- #
output = net([style], conditions, return_latents=True)
assert output[0].shape == (1, 3, 32, 32)
assert len(output[1]) == 1
# check latent
assert output[1][0].shape == (8, 512)
# -------------------- with randomize_noise = False ----------------------- #
output = net([style], conditions, randomize_noise=False)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
# -------------------- with truncation = 0.5 and mixing----------------------- #
output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
def test_gfpganv1():
"""Test arch: GFPGANv1."""
# model init and forward (gpu)
if torch.cuda.is_available():
net = GFPGANv1(
out_size=32,
num_style_feat=512,
channel_multiplier=1,
resample_kernel=(1, 3, 3, 1),
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
lr_mlp=0.01,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=True).cuda().eval()
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
output = net(img)
assert output[0].shape == (1, 3, 32, 32)
assert len(output[1]) == 3
# check out_rgbs for intermediate loss
assert output[1][0].shape == (1, 3, 8, 8)
assert output[1][1].shape == (1, 3, 16, 16)
assert output[1][2].shape == (1, 3, 32, 32)
# -------------------- with different_w = True ----------------------- #
net = GFPGANv1(
out_size=32,
num_style_feat=512,
channel_multiplier=1,
resample_kernel=(1, 3, 3, 1),
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
lr_mlp=0.01,
input_is_latent=False,
different_w=True,
narrow=1,
sft_half=True).cuda().eval()
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
output = net(img)
assert output[0].shape == (1, 3, 32, 32)
assert len(output[1]) == 3
# check out_rgbs for intermediate loss
assert output[1][0].shape == (1, 3, 8, 8)
assert output[1][1].shape == (1, 3, 16, 16)
assert output[1][2].shape == (1, 3, 32, 32)
def test_facialcomponentdiscriminator():
"""Test arch: FacialComponentDiscriminator."""
# model init and forward (gpu)
if torch.cuda.is_available():
net = FacialComponentDiscriminator().cuda().eval()
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
output = net(img)
assert len(output) == 2
assert output[0].shape == (1, 1, 8, 8)
assert output[1] is None
# -------------------- return intermediate features ----------------------- #
output = net(img, return_feats=True)
assert len(output) == 2
assert output[0].shape == (1, 1, 8, 8)
assert len(output[1]) == 2
assert output[1][0].shape == (1, 128, 16, 16)
assert output[1][1].shape == (1, 256, 8, 8)
def test_stylegan2generatorcsft():
"""Test arch: StyleGAN2GeneratorCSFT."""
# model init and forward (gpu)
if torch.cuda.is_available():
net = StyleGAN2GeneratorCSFT(
out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
style = torch.rand((1, 512), dtype=torch.float32).cuda()
condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
output = net([style], conditions)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
# -------------------- with return_latents ----------------------- #
output = net([style], conditions, return_latents=True)
assert output[0].shape == (1, 3, 32, 32)
assert len(output[1]) == 1
# check latent
assert output[1][0].shape == (8, 512)
# -------------------- with randomize_noise = False ----------------------- #
output = net([style], conditions, randomize_noise=False)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
# -------------------- with truncation = 0.5 and mixing----------------------- #
output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
def test_gfpganv1clean():
"""Test arch: GFPGANv1Clean."""
# model init and forward (gpu)
if torch.cuda.is_available():
net = GFPGANv1Clean(
out_size=32,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=True).cuda().eval()
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
output = net(img)
assert output[0].shape == (1, 3, 32, 32)
assert len(output[1]) == 3
# check out_rgbs for intermediate loss
assert output[1][0].shape == (1, 3, 8, 8)
assert output[1][1].shape == (1, 3, 16, 16)
assert output[1][2].shape == (1, 3, 32, 32)
# -------------------- with different_w = True ----------------------- #
net = GFPGANv1Clean(
out_size=32,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=False,
different_w=True,
narrow=1,
sft_half=True).cuda().eval()
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
output = net(img)
assert output[0].shape == (1, 3, 32, 32)
assert len(output[1]) == 3
# check out_rgbs for intermediate loss
assert output[1][0].shape == (1, 3, 8, 8)
assert output[1][1].shape == (1, 3, 16, 16)
assert output[1][2].shape == (1, 3, 32, 32)

132
tests/test_gfpgan_model.py Normal file
View File

@ -0,0 +1,132 @@
import tempfile
import torch
import yaml
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
from basicsr.data.paired_image_dataset import PairedImageDataset
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
from gfpgan.archs.arcface_arch import ResNetArcFace
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
from gfpgan.models.gfpgan_model import GFPGANModel
def test_gfpgan_model():
with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
opt = yaml.load(f, Loader=yaml.FullLoader)
# build model
model = GFPGANModel(opt)
# test attributes
assert model.__class__.__name__ == 'GFPGANModel'
assert isinstance(model.net_g, GFPGANv1) # generator
assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
# facial component discriminators
assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
# identity network
assert isinstance(model.network_identity, ResNetArcFace)
# losses
assert isinstance(model.cri_pix, L1Loss)
assert isinstance(model.cri_perceptual, PerceptualLoss)
assert isinstance(model.cri_gan, GANLoss)
assert isinstance(model.cri_l1, L1Loss)
# optimizer
assert isinstance(model.optimizers[0], torch.optim.Adam)
assert isinstance(model.optimizers[1], torch.optim.Adam)
# prepare data
gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
loc_mouth = torch.rand((1, 4), dtype=torch.float32)
data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
model.feed_data(data)
# check data shape
assert model.lq.shape == (1, 3, 512, 512)
assert model.gt.shape == (1, 3, 512, 512)
assert model.loc_left_eyes.shape == (1, 4)
assert model.loc_right_eyes.shape == (1, 4)
assert model.loc_mouths.shape == (1, 4)
# ----------------- test optimize_parameters -------------------- #
model.feed_data(data)
model.optimize_parameters(1)
assert model.output.shape == (1, 3, 512, 512)
assert isinstance(model.log_dict, dict)
# check returned keys
expected_keys = [
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
'l_d_right_eye', 'l_d_mouth'
]
assert set(expected_keys).issubset(set(model.log_dict.keys()))
# ----------------- remove pyramid_loss_weight-------------------- #
model.feed_data(data)
model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
assert model.output.shape == (1, 3, 512, 512)
assert isinstance(model.log_dict, dict)
# check returned keys
expected_keys = [
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
'l_d_right_eye', 'l_d_mouth'
]
assert set(expected_keys).issubset(set(model.log_dict.keys()))
# ----------------- test save -------------------- #
with tempfile.TemporaryDirectory() as tmpdir:
model.opt['path']['models'] = tmpdir
model.opt['path']['training_states'] = tmpdir
model.save(0, 1)
# ----------------- test the test function -------------------- #
model.test()
assert model.output.shape == (1, 3, 512, 512)
# delete net_g_ema
model.__delattr__('net_g_ema')
model.test()
assert model.output.shape == (1, 3, 512, 512)
assert model.net_g.training is True # should back to training mode after testing
# ----------------- test nondist_validation -------------------- #
# construct dataloader
dataset_opt = dict(
name='Demo',
dataroot_gt='tests/data/gt',
dataroot_lq='tests/data/gt',
io_backend=dict(type='disk'),
scale=4,
phase='val')
dataset = PairedImageDataset(dataset_opt)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
assert model.is_train is True
with tempfile.TemporaryDirectory() as tmpdir:
model.opt['path']['visualization'] = tmpdir
model.nondist_validation(dataloader, 1, None, save_img=True)
assert model.is_train is True
# check metric_results
assert 'psnr' in model.metric_results
assert isinstance(model.metric_results['psnr'], float)
# validation
with tempfile.TemporaryDirectory() as tmpdir:
model.opt['is_train'] = False
model.opt['val']['suffix'] = 'test'
model.opt['path']['visualization'] = tmpdir
model.opt['val']['pbar'] = True
model.nondist_validation(dataloader, 1, None, save_img=True)
# check metric_results
assert 'psnr' in model.metric_results
assert isinstance(model.metric_results['psnr'], float)
# if opt['val']['suffix'] is None
model.opt['val']['suffix'] = None
model.opt['name'] = 'demo'
model.opt['path']['visualization'] = tmpdir
model.nondist_validation(dataloader, 1, None, save_img=True)
# check metric_results
assert 'psnr' in model.metric_results
assert isinstance(model.metric_results['psnr'], float)

View File

@ -0,0 +1,52 @@
import torch
from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean
def test_stylegan2generatorclean():
"""Test arch: StyleGAN2GeneratorClean."""
# model init and forward (gpu)
if torch.cuda.is_available():
net = StyleGAN2GeneratorClean(
out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
style = torch.rand((1, 512), dtype=torch.float32).cuda()
output = net([style], input_is_latent=False)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
# -------------------- with return_latents ----------------------- #
output = net([style], input_is_latent=True, return_latents=True)
assert output[0].shape == (1, 3, 32, 32)
assert len(output[1]) == 1
# check latent
assert output[1][0].shape == (8, 512)
# -------------------- with randomize_noise = False ----------------------- #
output = net([style], randomize_noise=False)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
# -------------------- with truncation = 0.5 and mixing----------------------- #
output = net([style, style], truncation=0.5, truncation_latent=style)
assert output[0].shape == (1, 3, 32, 32)
assert output[1] is None
# ------------------ test make_noise ----------------------- #
out = net.make_noise()
assert len(out) == 7
assert out[0].shape == (1, 1, 4, 4)
assert out[1].shape == (1, 1, 8, 8)
assert out[2].shape == (1, 1, 8, 8)
assert out[3].shape == (1, 1, 16, 16)
assert out[4].shape == (1, 1, 16, 16)
assert out[5].shape == (1, 1, 32, 32)
assert out[6].shape == (1, 1, 32, 32)
# ------------------ test get_latent ----------------------- #
out = net.get_latent(style)
assert out.shape == (1, 512)
# ------------------ test mean_latent ----------------------- #
out = net.mean_latent(2)
assert out.shape == (1, 512)

43
tests/test_utils.py Normal file
View File

@ -0,0 +1,43 @@
import cv2
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
from gfpgan.utils import GFPGANer
def test_gfpganer():
# initialize with the clean model
restorer = GFPGANer(
model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=None)
# test attribute
assert isinstance(restorer.gfpgan, GFPGANv1Clean)
assert isinstance(restorer.face_helper, FaceRestoreHelper)
# initialize with the original model
restorer = GFPGANer(
model_path='experiments/pretrained_models/GFPGANv1.pth',
upscale=2,
arch='original',
channel_multiplier=1,
bg_upsampler=None)
# test attribute
assert isinstance(restorer.gfpgan, GFPGANv1)
assert isinstance(restorer.face_helper, FaceRestoreHelper)
# ------------------ test enhance ---------------- #
img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
result = restorer.enhance(img, has_aligned=False, paste_back=True)
assert result[0][0].shape == (512, 512, 3)
assert result[1][0].shape == (512, 512, 3)
assert result[2].shape == (1024, 1024, 3)
# with has_aligned=True
result = restorer.enhance(img, has_aligned=True, paste_back=False)
assert result[0][0].shape == (512, 512, 3)
assert result[1][0].shape == (512, 512, 3)
assert result[2] is None