-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
66 lines (49 loc) · 1.43 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import optim

from models import Generator
# Command-line interface: the image to run through the generator, plus an
# optional generator checkpoint to load.
parser = argparse.ArgumentParser(description="Arguments parser")
parser.add_argument(
    "--img_path",
    # The image is always opened, so an empty default could only fail later
    # with an opaque FileNotFoundError; make argparse report it instead.
    required=True,
    type=str,
    help="img path",
)
parser.add_argument(
    "--model_path",
    default="",
    type=str,
    help="path for the generator model",
)
args = parser.parse_args()
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
IMG_PATH = args.img_path
to_tensor = torchvision.transforms.ToTensor()
# Force 3-channel RGB so RGBA / palette / grayscale files do not yield a
# tensor with 4 or 1 channels (the generator presumably expects 3 input
# channels -- TODO confirm against models.Generator). The `with` block
# closes the file handle that Image.open keeps open.
with Image.open(IMG_PATH) as pil_img:
    img = np.array(pil_img.convert("RGB"))
# ToTensor maps an HxWxC uint8 array to a CxHxW float tensor in [0, 1];
# unsqueeze(0) adds the batch dimension.
img_tensor = to_tensor(img).unsqueeze(0).to(device)
# Build the generator on the target device and, if a checkpoint path was
# given, load its weights from the "model_state_dict" entry of the saved dict.
GENERATOR_CHECKPOINT = args.model_path
gen = Generator().to(device)
if GENERATOR_CHECKPOINT:
    if os.path.isfile(GENERATOR_CHECKPOINT):
        print(
            "Loading checkpoint {} of the generator...".format(GENERATOR_CHECKPOINT)
        )
        # Deserialize every tensor onto the CPU; load_state_dict then copies
        # the values into gen's parameters, which already live on `device`.
        checkpoint = torch.load(GENERATOR_CHECKPOINT, map_location="cpu")
        gen.load_state_dict(checkpoint["model_state_dict"])
        print("Generator correctly loaded.")
    else:
        # The bare `exit()` helper is meant for the interactive shell and
        # exits with status 0. Report the failure on stderr and exit with a
        # non-zero status so callers and shells can detect it.
        print("Generator checkpoint filepath incorrect.", file=sys.stderr)
        sys.exit(1)
else:
    # Without a checkpoint the generator keeps its random initial weights;
    # make that visible rather than silently producing noise.
    print(
        "Warning: no --model_path given; using an untrained generator.",
        file=sys.stderr,
    )
# Inference only: no_grad skips recording the autograd graph, which the
# result never needs, and so saves memory.
# NOTE(review): gen is never switched to eval(), so any dropout/batch-norm
# layers in Generator run in training mode here. That may be intentional
# (pix2pix-style models keep dropout at test time) -- confirm.
with torch.no_grad():
    # Assumes the generator yields exactly 3*256*256 values for the single
    # input image, e.g. shape (1, 3, 256, 256) -- TODO confirm against
    # models.Generator; inputs that are not 256x256 will presumably fail here.
    generated = gen(img_tensor).reshape(3, 256, 256)
print(generated.shape)
# CxHxW -> HxWxC, the channel-last layout imshow expects for RGB data.
generated = generated.permute(1, 2, 0).cpu().numpy()
plt.figure()
# NOTE(review): imshow clips float RGB data to [0, 1]; if Generator ends in
# tanh (range [-1, 1]) the output should be rescaled first -- confirm.
plt.imshow(generated)
plt.show()