mirror of https://github.com/TencentARC/GFPGAN.git
synced 2025-05-21 01:30:13 -07:00

add GFPGAN clean arch

parent 7023b5cbdd
commit cc3c881f85
.gitignore (vendored): +3

@@ -1,4 +1,7 @@
 .vscode
+datasets/*
+experiments/*
+tb_logger/*
 
 # ignored files
 version.py
README.md

@@ -2,7 +2,7 @@
[](https://github.com/TencentARC/GFPGAN/releases)
[](https://github.com/TencentARC/GFPGAN/issues)
[](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
[](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
[](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)

[**Paper**](https://arxiv.org/abs/2101.04061) **|** [**Project Page**](https://xinntao.github.io/projects/gfpgan)    [English](README.md) **|** [简体中文](README_CN.md)
archs/gfpganv1_clean_arch.py (new file, 304 lines)

@@ -0,0 +1,304 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from .stylegan2_clean_arch import StyleGAN2GeneratorClean


class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels.
            Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)

        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to the generator.
            input_is_latent (bool): Whether the input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor.
                Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
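As an aside (not part of the committed file): the SFT step above is a per-level affine modulation of the decoder features by the encoder-derived conditions. A minimal sketch of the sft_half=False branch, with made-up tensor shapes:

import torch

out = torch.randn(1, 64, 32, 32)    # decoder feature (hypothetical shape)
scale = torch.randn(1, 64, 32, 32)  # conditions[i - 1], predicted by the U-Net
shift = torch.randn(1, 64, 32, 32)  # conditions[i]

out = out * scale + shift  # element-wise spatial feature transform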
class ResBlock(nn.Module):
    """Residual block with upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode, 'up' or 'down'.
            Default: 'down'.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2

    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out
class GFPGANv1Clean(nn.Module):
    """GFPGANv1 Clean version: a U-Net degradation-removal module plus a
    (optionally pretrained and fixed) StyleGAN2 decoder, bridged by SFT
    conditions."""

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        if fix_decoder:
            for name, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self,
                x,
                return_latents=False,
                save_feat_path=None,
                load_feat_path=None,
                return_rgb=True,
                randomize_noise=True):
        """Forward function.

        Args:
            x (Tensor): Input images of shape (b, 3, out_size, out_size).
            return_latents (bool): Whether to return style latents.
                Default: False.
            save_feat_path (str | None): Where to save the SFT conditions.
                Default: None.
            load_feat_path (str | None): Where to load saved SFT conditions
                from. Default: None.
            return_rgb (bool): Whether to also return intermediate rgb images.
                Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                None. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layer
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        if save_feat_path is not None:
            torch.save(conditions, save_feat_path)
        if load_feat_path is not None:
            conditions = torch.load(load_feat_path)
            conditions = [v.cuda() for v in conditions]

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
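A quick smoke test, not part of the commit: this instantiates GFPGANv1Clean with the same values the inference script below uses for the clean architecture, but with random weights and no pretrained decoder, so it only checks that the shapes line up.

import torch

model = GFPGANv1Clean(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=2,
    decoder_load_path=None,  # no pretrained StyleGAN2 weights in this sketch
    fix_decoder=False,
    num_mlp=8,
    input_is_latent=True,
    different_w=True,
    narrow=1,
    sft_half=True)
model.eval()

x = torch.randn(1, 3, 512, 512)  # a face crop at the network input size
with torch.no_grad():
    image, out_rgbs = model(x)
print(image.shape)  # torch.Size([1, 3, 512, 512])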
archs/stylegan2_clean_arch.py (new file, 378 lines)

@@ -0,0 +1,378 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY


class NormStyleCode(nn.Module):

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
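A quick numerical check, not part of the file: NormStyleCode rescales each style code to unit root-mean-square (the pixel-norm trick from StyleGAN):

import torch

x = torch.randn(2, 512)
y = x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
print(torch.mean(y**2, dim=1))  # ~tensor([1., 1.]): unit mean square per code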
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
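An aside, not part of the file: the demodulation step rescales each per-sample filter to unit L2 norm, which keeps activation magnitudes consistent without instance normalization. A small check with hypothetical sizes:

import torch

b, c_out, c_in, k = 2, 4, 8, 3
weight = torch.randn(1, c_out, c_in, k, k)
style = torch.randn(b, 1, c_in, 1, 1)  # stands in for self.modulation(style)

w = weight * style                                   # modulate
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)  # per (sample, out-channel) norm
w = w * demod.view(b, c_out, 1, 1, 1)                # demodulate
print(w.pow(2).sum([2, 3, 4]))  # ~1.0 everywhere: unit-norm filters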
class StyleConv(nn.Module):
    """Style conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # add bias
        out = out + self.bias
        # activation
        out = self.activate(out)
        return out
class ToRGB(nn.Module):
    """To RGB from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
            out = out + skip
        return out
class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether the input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor.
                Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
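Again as an aside, not part of the commit: a minimal sampling sketch for the registered generator with random (untrained) weights, just to show the calling convention:

import torch

g = StyleGAN2GeneratorClean(out_size=256, num_style_feat=512, num_mlp=8, channel_multiplier=2)
g.eval()

z = torch.randn(2, 512)  # style codes in Z space; the MLP maps them to W
with torch.no_grad():
    image, _ = g([z], input_is_latent=False, randomize_noise=True)
print(image.shape)  # torch.Size([2, 3, 256, 256])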
@@ -8,6 +8,7 @@ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 from torchvision.transforms.functional import normalize
 
 from archs.gfpganv1_arch import GFPGANv1
+from archs.gfpganv1_clean_arch import GFPGANv1Clean
 from basicsr.utils import img2tensor, imwrite, tensor2img
 
 
@@ -32,7 +33,7 @@ def restoration(gfpgan,
     else:
         face_helper.read_image(input_img)
     # get face landmarks for each face
-    face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
+    face_helper.get_face_landmarks_5(only_center_face=only_center_face)
     # align and warp each face
     save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
     face_helper.align_warp_face(save_crop_path)
@@ -79,24 +80,40 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
 
     parser.add_argument('--upscale_factor', type=int, default=1)
+    parser.add_argument('--arch', type=str, default='clean')
+    parser.add_argument('--channel', type=int, default=2)
     parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth')
     parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
     parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
     parser.add_argument('--only_center_face', action='store_true')
     parser.add_argument('--aligned', action='store_true')
     parser.add_argument('--paste_back', action='store_true')
+    parser.add_argument('--save_root', type=str, default='results')
 
     args = parser.parse_args()
     if args.test_path.endswith('/'):
         args.test_path = args.test_path[:-1]
-    save_root = 'results/'
-    os.makedirs(save_root, exist_ok=True)
+    os.makedirs(args.save_root, exist_ok=True)
 
     # initialize the GFP-GAN
+    if args.arch == 'clean':
+        gfpgan = GFPGANv1Clean(
+            out_size=512,
+            num_style_feat=512,
+            channel_multiplier=args.channel,
+            decoder_load_path=None,
+            fix_decoder=False,
+            # for stylegan decoder
+            num_mlp=8,
+            input_is_latent=True,
+            different_w=True,
+            narrow=1,
+            sft_half=True)
+    else:
         gfpgan = GFPGANv1(
             out_size=512,
             num_style_feat=512,
-            channel_multiplier=1,
+            channel_multiplier=args.channel,
             decoder_load_path=None,
             fix_decoder=True,
             # for stylegan decoder
@@ -121,10 +138,10 @@
             gfpgan,
             face_helper,
             img_path,
-            save_root,
+            args.save_root,
             has_aligned=args.aligned,
             only_center_face=args.only_center_face,
             suffix=args.suffix,
             paste_back=args.paste_back)
 
-    print('Results are in the <results> folder.')
+    print(f'Results are in the [{args.save_root}] folder.')