mirror of
https://github.com/TencentARC/GFPGAN.git
synced 2025-05-21 01:30:13 -07:00
support GFPGAN clean arch
This commit is contained in:
parent
cc3c881f85
commit
7f67e12999
@ -19,7 +19,8 @@ def restoration(gfpgan,
|
||||
has_aligned=False,
|
||||
only_center_face=True,
|
||||
suffix=None,
|
||||
paste_back=False):
|
||||
paste_back=False,
|
||||
device='cuda'):
|
||||
# read image
|
||||
img_name = os.path.basename(img_path)
|
||||
print(f'Processing {img_name} ...')
|
||||
@ -43,7 +44,7 @@ def restoration(gfpgan,
|
||||
# prepare data
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
@ -77,17 +78,18 @@ def restoration(gfpgan,
|
||||
|
||||
if __name__ == '__main__':
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--upscale_factor', type=int, default=1)
|
||||
parser.add_argument('--upscale_factor', type=int, default=2)
|
||||
parser.add_argument('--arch', type=str, default='clean')
|
||||
parser.add_argument('--channel', type=int, default=2)
|
||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth')
|
||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
|
||||
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
|
||||
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
|
||||
parser.add_argument('--only_center_face', action='store_true')
|
||||
parser.add_argument('--aligned', action='store_true')
|
||||
parser.add_argument('--paste_back', action='store_true')
|
||||
parser.add_argument('--paste_back', action='store_false')
|
||||
parser.add_argument('--save_root', type=str, default='results')
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -123,14 +125,17 @@ if __name__ == '__main__':
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
|
||||
gfpgan.to(device)
|
||||
checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
|
||||
gfpgan.load_state_dict(checkpoint['params_ema'])
|
||||
gfpgan.eval()
|
||||
gfpgan.load_state_dict(torch.load(args.model_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||
gfpgan.to(device).eval()
|
||||
|
||||
# initialize face helper
|
||||
face_helper = FaceRestoreHelper(
|
||||
args.upscale_factor, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
|
||||
args.upscale_factor,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
device=device)
|
||||
|
||||
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
|
||||
for img_path in img_list:
|
||||
@ -142,6 +147,7 @@ if __name__ == '__main__':
|
||||
has_aligned=args.aligned,
|
||||
only_center_face=args.only_center_face,
|
||||
suffix=args.suffix,
|
||||
paste_back=args.paste_back)
|
||||
paste_back=args.paste_back,
|
||||
device=device)
|
||||
|
||||
print(f'Results are in the [{args.save_root}] folder.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user