-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
32 lines (21 loc) · 933 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from utils import *
from model import Model
#loader = HandwritingDataset(data_dir="./data", split="val", scale_factor=1, seq_length=300)
#lstm_model = Model(seq_length=300)
#lstm_model.load_state_dict(torch.load("./saves/model_11.pth"))
#stroke = lstm_model.sample(100)
# idx = np.random.choice(len(loader))
# stroke, sentence = loader.strokes[idx], loader.labels[idx]
#
# print(sentence)
# draw_strokes(stroke)
# Load the newest "model_<epoch>.pth" checkpoint from SAVE_DIR, sample a stroke
# sequence from it and plot the result.

SAVE_DIR = "./saves"
SAMPLE_LENGTH = 800  # passed to Model.sample
DRAW_FACTOR = 5      # passed to draw_strokes as `factor`


def _latest_epoch(save_dir=SAVE_DIR):
    """Return the highest epoch among ``model_<epoch>.pth`` files in *save_dir*.

    Only names matching ``model_<digits>.pth`` exactly are considered, so other
    files (``model_best.pth``, ``model_3.pt``, editor backups, ...) can neither
    crash the integer parse nor select a checkpoint path that does not exist.

    Raises:
        FileNotFoundError: if *save_dir* contains no matching checkpoint.
    """
    import os
    import re

    pattern = re.compile(r"model_(\d+)\.pth")
    epochs = [
        int(match.group(1))
        for match in map(pattern.fullmatch, os.listdir(save_dir))
        if match
    ]
    if not epochs:
        raise FileNotFoundError(
            f"no model_<epoch>.pth checkpoints found in {save_dir!r}"
        )
    return max(epochs)


use_cuda = torch.cuda.is_available()
lstm_model = Model(seq_length=300, bidirectional=False)
load_epoch = _latest_epoch(SAVE_DIR)
if use_cuda:
    lstm_model.cuda()
# map_location loads tensors onto the device we actually run on, regardless of
# which device (or GPU index) the checkpoint was saved from.
lstm_model.load_state_dict(
    torch.load(
        "{}/model_{}.pth".format(SAVE_DIR, load_epoch),
        map_location="cuda" if use_cuda else "cpu",
    )
)
# NOTE(review): if Model contains dropout/batch-norm, lstm_model.eval() may be
# wanted before sampling — Model's definition is not visible here; confirm.
stroke = lstm_model.sample(SAMPLE_LENGTH)
draw_strokes(stroke, factor=DRAW_FACTOR)