Mirror of https://github.com/TencentARC/GFPGAN.git

commit 37237da798 (parent be73d6d9a4)
update utils and unittest
gfpgan/utils.py

@@ -2,10 +2,9 @@ import cv2
 import os
 import torch
 from basicsr.utils import img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from torch.hub import download_url_to_file, get_dir
 from torchvision.transforms.functional import normalize
-from urllib.parse import urlparse

 from gfpgan.archs.gfpganv1_arch import GFPGANv1
 from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
@@ -70,7 +69,8 @@ class GFPGANer():
             device=self.device)

         if model_path.startswith('https://'):
-            model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
+            model_path = load_file_from_url(
+                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
         loadnet = torch.load(model_path)
         if 'params_ema' in loadnet:
             keyname = 'params_ema'
@@ -128,25 +128,3 @@ class GFPGANer():
             return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
         else:
             return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
-
-
-def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Load file form http url, will download models if necessary.
-
-    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
-    """
-    if model_dir is None:
-        hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    if file_name is not None:
-        filename = file_name
-    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
-    if not os.path.exists(cached_file):
-        print(f'Downloading: "{url}" to {cached_file}\n')
-        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
-    return cached_file
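Taken together, the three hunks above drop the repo-local download helper: GFPGANer now resolves an https model_path through basicsr's load_file_from_url and caches the weight under ROOT_DIR/gfpgan/weights. A minimal usage sketch of the updated entry point follows; the weight URL is a placeholder, not a file shipped with this commit.

import cv2

from gfpgan.utils import GFPGANer

# An 'https://' model_path is downloaded once by basicsr's load_file_from_url
# and cached under <ROOT_DIR>/gfpgan/weights; a local path is used as-is.
restorer = GFPGANer(
    model_path='https://example.com/GFPGANv1.pth',  # placeholder URL
    upscale=2,
    arch='original',
    channel_multiplier=1,
    bg_upsampler=None)

img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
cropped_faces, restored_faces, restored_img = restorer.enhance(img, has_aligned=False, paste_back=True)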
setup.cfg

@@ -17,7 +17,7 @@ line_length = 120
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = gfpgan
-known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
+known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

@@ -25,3 +25,9 @@ default_section = THIRDPARTY
 skip = .git,./docs/build
 count =
 quiet-level = 3
+
+[aliases]
+test=pytest
+
+[tool:pytest]
+addopts=tests/
BIN  tests/data/ffhq_gt.lmdb/data.mdb  Normal file
BIN  tests/data/ffhq_gt.lmdb/lock.mdb  Normal file
1  tests/data/ffhq_gt.lmdb/meta_info.txt  Normal file
@@ -0,0 +1 @@
00000000.png (512,512,3) 1
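The single meta_info.txt entry follows basicsr's lmdb convention: image name, (h, w, c) shape, and the compression level used when the pixels were written into data.mdb. A rough sketch of how such a one-image fixture could be regenerated from tests/data/gt, assuming basicsr's make_lmdb_from_imgs helper; the call is illustrative and not part of this commit.

from basicsr.utils.lmdb_util import make_lmdb_from_imgs

# Keys are stored without the extension, which is why the lmdb test below expects
# gt_path == '00000000' while meta_info.txt records '00000000.png (512,512,3) 1'.
make_lmdb_from_imgs(
    data_path='tests/data/gt',
    lmdb_path='tests/data/ffhq_gt.lmdb',
    img_path_list=['00000000.png'],
    keys=['00000000'])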
BIN  tests/data/gt/00000000.png  Normal file (429 KiB)
BIN  tests/data/test_eye_mouth_landmarks.pth  Normal file
24  tests/data/test_ffhq_degradation_dataset.yml  Normal file
@@ -0,0 +1,24 @@
name: UnitTest
type: FFHQDegradationDataset
dataroot_gt: tests/data/gt
io_backend:
  type: disk

use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512

blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]

# color jitter and gray
color_jitter_prob: 1
color_jitter_shift: 20
color_jitter_pt_prob: 1
gray_prob: 1
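The keys above configure the synthetic degradation applied to each 512x512 ground-truth face: blur with a randomly sampled iso/aniso kernel, random downsampling, Gaussian noise, JPEG compression, and a resize back to out_size, plus optional color jitter and grayscale conversion. The sketch below is a simplified, self-contained approximation of that order using plain OpenCV and NumPy; it mirrors the config values but is not the FFHQDegradationDataset implementation itself.

import cv2
import numpy as np


def degrade(img, blur_kernel_size=41, blur_sigma=(0.1, 10), downsample_range=(0.8, 8),
            noise_range=(0, 20), jpeg_range=(60, 100)):
    """Blur -> downsample -> noise -> JPEG -> resize back, roughly as configured above."""
    h, w = img.shape[:2]
    # blur (an isotropic Gaussian stands in for the iso/aniso mixed kernels)
    sigma = np.random.uniform(*blur_sigma)
    img = cv2.GaussianBlur(img, (blur_kernel_size, blur_kernel_size), sigma)
    # random downsampling
    scale = np.random.uniform(*downsample_range)
    img = cv2.resize(img, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
    # additive Gaussian noise
    noise_sigma = np.random.uniform(*noise_range)
    img = np.clip(img.astype(np.float32) + np.random.normal(0, noise_sigma, img.shape), 0, 255).astype(np.uint8)
    # JPEG compression
    quality = int(np.random.uniform(*jpeg_range))
    _, buf = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, quality])
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    # resize back to the training resolution
    return cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)

The real dataset additionally applies the color jitter and grayscale conversion controlled by the last four keys.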
140  tests/data/test_gfpgan_model.yml  Normal file
@@ -0,0 +1,140 @@
num_gpu: 1
manual_seed: 0
is_train: True
dist: False

# network structures
network_g:
  type: GFPGANv1
  out_size: 512
  num_style_feat: 512
  channel_multiplier: 1
  resample_kernel: [1, 3, 3, 1]
  decoder_load_path: ~
  fix_decoder: true
  num_mlp: 8
  lr_mlp: 0.01
  input_is_latent: true
  different_w: true
  narrow: 0.5
  sft_half: true

network_d:
  type: StyleGAN2Discriminator
  out_size: 512
  channel_multiplier: 1
  resample_kernel: [1, 3, 3, 1]

network_d_left_eye:
  type: FacialComponentDiscriminator

network_d_right_eye:
  type: FacialComponentDiscriminator

network_d_mouth:
  type: FacialComponentDiscriminator

network_identity:
  type: ResNetArcFace
  block: IRBlock
  layers: [2, 2, 2, 2]
  use_se: False

# path
path:
  pretrain_network_g: ~
  param_key_g: params_ema
  strict_load_g: ~
  pretrain_network_d: ~
  pretrain_network_d_left_eye: ~
  pretrain_network_d_right_eye: ~
  pretrain_network_d_mouth: ~
  pretrain_network_identity: ~
  # resume
  resume_state: ~
  ignore_resume_networks: ['network_identity']

# training settings
train:
  optim_g:
    type: Adam
    lr: !!float 2e-3
  optim_d:
    type: Adam
    lr: !!float 2e-3
  optim_component:
    type: Adam
    lr: !!float 2e-3

  scheduler:
    type: MultiStepLR
    milestones: [600000, 700000]
    gamma: 0.5

  total_iter: 800000
  warmup_iter: -1  # no warm up

  # losses
  # pixel loss
  pixel_opt:
    type: L1Loss
    loss_weight: !!float 1e-1
    reduction: mean
  # L1 loss used in pyramid loss, component style loss and identity loss
  L1_opt:
    type: L1Loss
    loss_weight: 1
    reduction: mean

  # image pyramid loss
  pyramid_loss_weight: 1
  remove_pyramid_loss: 50000
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1
    style_weight: 50
    range_norm: true
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: wgan_softplus
    loss_weight: !!float 1e-1
  # r1 regularization for discriminator
  r1_reg_weight: 10
  # facial component loss
  gan_component_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1
  comp_style_weight: 200
  # identity loss
  identity_weight: 10

  net_d_iters: 1
  net_d_init_iters: 0
  net_d_reg_every: 1

# validation settings
val:
  val_freq: !!float 5e3
  save_img: True
  use_pbar: True

  metrics:
    psnr: # metric name
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false
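For orientation: the type: fields in this options file are resolved against basicsr's registries when GFPGANModel builds its networks and losses, which is why test_gfpgan_model.py below only has to hand the parsed dict to the model. A small sketch of that resolution step, assuming basicsr's build_network helper and the registry decorators on the gfpgan architectures; shown for illustration only.

import yaml
from basicsr.archs import build_network

import gfpgan.archs  # noqa: F401  (importing registers the GFPGAN architectures)

with open('tests/data/test_gfpgan_model.yml') as f:
    opt = yaml.load(f, Loader=yaml.FullLoader)

# 'type: GFPGANv1' is looked up in the architecture registry; the remaining
# keys of network_g become the constructor arguments.
net_g = build_network(opt['network_g'])
print(net_g.__class__.__name__)  # GFPGANv1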
49  tests/test_arcface_arch.py  Normal file
@@ -0,0 +1,49 @@
import torch

from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace


def test_resnetarcface():
    """Test arch: ResNetArcFace."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
        img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
        output = net(img)
        assert output.shape == (1, 512)

        # -------------------- without SE block ----------------------- #
        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
        output = net(img)
        assert output.shape == (1, 512)


def test_basicblock():
    """Test the BasicBlock in arcface_arch"""
    block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 3, 12, 12)

    # ----------------- use the downsample module --------------- #
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
    block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 3, 6, 6)


def test_bottleneck():
    """Test the Bottleneck in arcface_arch"""
    block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 4, 12, 12)

    # ----------------- use the downsample module --------------- #
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
    block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 4, 6, 6)
96  tests/test_ffhq_degradation_dataset.py  Normal file
@@ -0,0 +1,96 @@
import pytest
import yaml

from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset


def test_ffhq_degradation_dataset():

    with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    dataset = FFHQDegradationDataset(opt)
    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
    assert len(dataset) == 1  # whether to read correct meta info
    assert dataset.kernel_list == ['iso', 'aniso']  # correct initialization of the degradation configurations
    assert dataset.color_jitter_prob == 1

    # test __getitem__
    result = dataset.__getitem__(0)
    # check returned keys
    expected_keys = ['gt', 'lq', 'gt_path']
    assert set(expected_keys).issubset(set(result.keys()))
    # check shape and contents
    assert result['gt'].shape == (3, 512, 512)
    assert result['lq'].shape == (3, 512, 512)
    assert result['gt_path'] == 'tests/data/gt/00000000.png'

    # ------------------ test with probability = 0 -------------------- #
    opt['color_jitter_prob'] = 0
    opt['color_jitter_pt_prob'] = 0
    opt['gray_prob'] = 0
    opt['io_backend'] = dict(type='disk')
    dataset = FFHQDegradationDataset(opt)
    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
    assert len(dataset) == 1  # whether to read correct meta info
    assert dataset.kernel_list == ['iso', 'aniso']  # correct initialization of the degradation configurations
    assert dataset.color_jitter_prob == 0

    # test __getitem__
    result = dataset.__getitem__(0)
    # check returned keys
    expected_keys = ['gt', 'lq', 'gt_path']
    assert set(expected_keys).issubset(set(result.keys()))
    # check shape and contents
    assert result['gt'].shape == (3, 512, 512)
    assert result['lq'].shape == (3, 512, 512)
    assert result['gt_path'] == 'tests/data/gt/00000000.png'

    # ------------------ test lmdb backend -------------------- #
    opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
    opt['io_backend'] = dict(type='lmdb')

    dataset = FFHQDegradationDataset(opt)
    assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
    assert len(dataset) == 1  # whether to read correct meta info
    assert dataset.kernel_list == ['iso', 'aniso']  # correct initialization of the degradation configurations
    assert dataset.color_jitter_prob == 0

    # test __getitem__
    result = dataset.__getitem__(0)
    # check returned keys
    expected_keys = ['gt', 'lq', 'gt_path']
    assert set(expected_keys).issubset(set(result.keys()))
    # check shape and contents
    assert result['gt'].shape == (3, 512, 512)
    assert result['lq'].shape == (3, 512, 512)
    assert result['gt_path'] == '00000000'

    # ------------------ test with crop_components -------------------- #
    opt['crop_components'] = True
    opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
    opt['eye_enlarge_ratio'] = 1.4
    opt['gt_gray'] = True
    opt['io_backend'] = dict(type='lmdb')

    dataset = FFHQDegradationDataset(opt)
    assert dataset.crop_components is True

    # test __getitem__
    result = dataset.__getitem__(0)
    # check returned keys
    expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
    assert set(expected_keys).issubset(set(result.keys()))
    # check shape and contents
    assert result['gt'].shape == (3, 512, 512)
    assert result['lq'].shape == (3, 512, 512)
    assert result['gt_path'] == '00000000'
    assert result['loc_left_eye'].shape == (4, )
    assert result['loc_right_eye'].shape == (4, )
    assert result['loc_mouth'].shape == (4, )

    # ------------------ lmdb backend should have paths ending with lmdb -------------------- #
    with pytest.raises(ValueError):
        opt['dataroot_gt'] = 'tests/data/gt'
        opt['io_backend'] = dict(type='lmdb')
        dataset = FFHQDegradationDataset(opt)
203  tests/test_gfpgan_arch.py  Normal file
@@ -0,0 +1,203 @@
import torch

from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT


def test_stylegan2generatorsft():
    """Test arch: StyleGAN2GeneratorSFT."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = StyleGAN2GeneratorSFT(
            out_size=32,
            num_style_feat=512,
            num_mlp=8,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            lr_mlp=0.01,
            narrow=1,
            sft_half=False).cuda().eval()
        style = torch.rand((1, 512), dtype=torch.float32).cuda()
        condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
        condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
        condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
        conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
        output = net([style], conditions)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with return_latents ----------------------- #
        output = net([style], conditions, return_latents=True)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 1
        # check latent
        assert output[1][0].shape == (8, 512)

        # -------------------- with randomize_noise = False ----------------------- #
        output = net([style], conditions, randomize_noise=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with truncation = 0.5 and mixing ----------------------- #
        output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None


def test_gfpganv1():
    """Test arch: GFPGANv1."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = GFPGANv1(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=True).cuda().eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        output = net(img)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 3
        # check out_rgbs for intermediate loss
        assert output[1][0].shape == (1, 3, 8, 8)
        assert output[1][1].shape == (1, 3, 16, 16)
        assert output[1][2].shape == (1, 3, 32, 32)

        # -------------------- with different_w = True ----------------------- #
        net = GFPGANv1(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=True,
            narrow=1,
            sft_half=True).cuda().eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        output = net(img)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 3
        # check out_rgbs for intermediate loss
        assert output[1][0].shape == (1, 3, 8, 8)
        assert output[1][1].shape == (1, 3, 16, 16)
        assert output[1][2].shape == (1, 3, 32, 32)


def test_facialcomponentdiscriminator():
    """Test arch: FacialComponentDiscriminator."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = FacialComponentDiscriminator().cuda().eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        output = net(img)
        assert len(output) == 2
        assert output[0].shape == (1, 1, 8, 8)
        assert output[1] is None

        # -------------------- return intermediate features ----------------------- #
        output = net(img, return_feats=True)
        assert len(output) == 2
        assert output[0].shape == (1, 1, 8, 8)
        assert len(output[1]) == 2
        assert output[1][0].shape == (1, 128, 16, 16)
        assert output[1][1].shape == (1, 256, 8, 8)


def test_stylegan2generatorcsft():
    """Test arch: StyleGAN2GeneratorCSFT."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = StyleGAN2GeneratorCSFT(
            out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
        style = torch.rand((1, 512), dtype=torch.float32).cuda()
        condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
        condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
        condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
        conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
        output = net([style], conditions)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with return_latents ----------------------- #
        output = net([style], conditions, return_latents=True)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 1
        # check latent
        assert output[1][0].shape == (8, 512)

        # -------------------- with randomize_noise = False ----------------------- #
        output = net([style], conditions, randomize_noise=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with truncation = 0.5 and mixing ----------------------- #
        output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None


def test_gfpganv1clean():
    """Test arch: GFPGANv1Clean."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = GFPGANv1Clean(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=True).cuda().eval()

        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        output = net(img)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 3
        # check out_rgbs for intermediate loss
        assert output[1][0].shape == (1, 3, 8, 8)
        assert output[1][1].shape == (1, 3, 16, 16)
        assert output[1][2].shape == (1, 3, 32, 32)

        # -------------------- with different_w = True ----------------------- #
        net = GFPGANv1Clean(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=True,
            narrow=1,
            sft_half=True).cuda().eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        output = net(img)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 3
        # check out_rgbs for intermediate loss
        assert output[1][0].shape == (1, 3, 8, 8)
        assert output[1][1].shape == (1, 3, 16, 16)
        assert output[1][2].shape == (1, 3, 32, 32)
132  tests/test_gfpgan_model.py  Normal file
@@ -0,0 +1,132 @@
import tempfile
import torch
import yaml
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
from basicsr.data.paired_image_dataset import PairedImageDataset
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss

from gfpgan.archs.arcface_arch import ResNetArcFace
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
from gfpgan.models.gfpgan_model import GFPGANModel


def test_gfpgan_model():
    with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    # build model
    model = GFPGANModel(opt)
    # test attributes
    assert model.__class__.__name__ == 'GFPGANModel'
    assert isinstance(model.net_g, GFPGANv1)  # generator
    assert isinstance(model.net_d, StyleGAN2Discriminator)  # discriminator
    # facial component discriminators
    assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
    # identity network
    assert isinstance(model.network_identity, ResNetArcFace)
    # losses
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    assert isinstance(model.cri_l1, L1Loss)
    # optimizer
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert isinstance(model.optimizers[1], torch.optim.Adam)

    # prepare data
    gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_mouth = torch.rand((1, 4), dtype=torch.float32)
    data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
    model.feed_data(data)
    # check data shape
    assert model.lq.shape == (1, 3, 512, 512)
    assert model.gt.shape == (1, 3, 512, 512)
    assert model.loc_left_eyes.shape == (1, 4)
    assert model.loc_right_eyes.shape == (1, 4)
    assert model.loc_mouths.shape == (1, 4)

    # ----------------- test optimize_parameters -------------------- #
    model.feed_data(data)
    model.optimize_parameters(1)
    assert model.output.shape == (1, 3, 512, 512)
    assert isinstance(model.log_dict, dict)
    # check returned keys
    expected_keys = [
        'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
        'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
        'l_d_right_eye', 'l_d_mouth'
    ]
    assert set(expected_keys).issubset(set(model.log_dict.keys()))

    # ----------------- remove pyramid_loss_weight -------------------- #
    model.feed_data(data)
    model.optimize_parameters(100000)  # larger than remove_pyramid_loss = 50000
    assert model.output.shape == (1, 3, 512, 512)
    assert isinstance(model.log_dict, dict)
    # check returned keys
    expected_keys = [
        'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
        'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
        'l_d_right_eye', 'l_d_mouth'
    ]
    assert set(expected_keys).issubset(set(model.log_dict.keys()))

    # ----------------- test save -------------------- #
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['models'] = tmpdir
        model.opt['path']['training_states'] = tmpdir
        model.save(0, 1)

    # ----------------- test the test function -------------------- #
    model.test()
    assert model.output.shape == (1, 3, 512, 512)
    # delete net_g_ema
    model.__delattr__('net_g_ema')
    model.test()
    assert model.output.shape == (1, 3, 512, 512)
    assert model.net_g.training is True  # should be back to training mode after testing

    # ----------------- test nondist_validation -------------------- #
    # construct dataloader
    dataset_opt = dict(
        name='Demo',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/gt',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    assert model.is_train is True
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['visualization'] = tmpdir
        model.nondist_validation(dataloader, 1, None, save_img=True)
        assert model.is_train is True
        # check metric_results
        assert 'psnr' in model.metric_results
        assert isinstance(model.metric_results['psnr'], float)

        # validation
        with tempfile.TemporaryDirectory() as tmpdir:
            model.opt['is_train'] = False
            model.opt['val']['suffix'] = 'test'
            model.opt['path']['visualization'] = tmpdir
            model.opt['val']['pbar'] = True
            model.nondist_validation(dataloader, 1, None, save_img=True)
            # check metric_results
            assert 'psnr' in model.metric_results
            assert isinstance(model.metric_results['psnr'], float)

            # if opt['val']['suffix'] is None
            model.opt['val']['suffix'] = None
            model.opt['name'] = 'demo'
            model.opt['path']['visualization'] = tmpdir
            model.nondist_validation(dataloader, 1, None, save_img=True)
            # check metric_results
            assert 'psnr' in model.metric_results
            assert isinstance(model.metric_results['psnr'], float)
52  tests/test_stylegan2_clean_arch.py  Normal file
@@ -0,0 +1,52 @@
import torch

from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean


def test_stylegan2generatorclean():
    """Test arch: StyleGAN2GeneratorClean."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = StyleGAN2GeneratorClean(
            out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
        style = torch.rand((1, 512), dtype=torch.float32).cuda()
        output = net([style], input_is_latent=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with return_latents ----------------------- #
        output = net([style], input_is_latent=True, return_latents=True)
        assert output[0].shape == (1, 3, 32, 32)
        assert len(output[1]) == 1
        # check latent
        assert output[1][0].shape == (8, 512)

        # -------------------- with randomize_noise = False ----------------------- #
        output = net([style], randomize_noise=False)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # -------------------- with truncation = 0.5 and mixing ----------------------- #
        output = net([style, style], truncation=0.5, truncation_latent=style)
        assert output[0].shape == (1, 3, 32, 32)
        assert output[1] is None

        # ------------------ test make_noise ----------------------- #
        out = net.make_noise()
        assert len(out) == 7
        assert out[0].shape == (1, 1, 4, 4)
        assert out[1].shape == (1, 1, 8, 8)
        assert out[2].shape == (1, 1, 8, 8)
        assert out[3].shape == (1, 1, 16, 16)
        assert out[4].shape == (1, 1, 16, 16)
        assert out[5].shape == (1, 1, 32, 32)
        assert out[6].shape == (1, 1, 32, 32)

        # ------------------ test get_latent ----------------------- #
        out = net.get_latent(style)
        assert out.shape == (1, 512)

        # ------------------ test mean_latent ----------------------- #
        out = net.mean_latent(2)
        assert out.shape == (1, 512)
43  tests/test_utils.py  Normal file
@@ -0,0 +1,43 @@
import cv2
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
from gfpgan.utils import GFPGANer


def test_gfpganer():
    # initialize with the clean model
    restorer = GFPGANer(
        model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None)
    # test attribute
    assert isinstance(restorer.gfpgan, GFPGANv1Clean)
    assert isinstance(restorer.face_helper, FaceRestoreHelper)

    # initialize with the original model
    restorer = GFPGANer(
        model_path='experiments/pretrained_models/GFPGANv1.pth',
        upscale=2,
        arch='original',
        channel_multiplier=1,
        bg_upsampler=None)
    # test attribute
    assert isinstance(restorer.gfpgan, GFPGANv1)
    assert isinstance(restorer.face_helper, FaceRestoreHelper)

    # ------------------ test enhance ---------------- #
    img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
    result = restorer.enhance(img, has_aligned=False, paste_back=True)
    assert result[0][0].shape == (512, 512, 3)
    assert result[1][0].shape == (512, 512, 3)
    assert result[2].shape == (1024, 1024, 3)

    # with has_aligned=True
    result = restorer.enhance(img, has_aligned=True, paste_back=False)
    assert result[0][0].shape == (512, 512, 3)
    assert result[1][0].shape == (512, 512, 3)
    assert result[2] is None