-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathTest_SHT_A.py
40 lines (35 loc) · 1.23 KB
/
Test_SHT_A.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# -*- coding: utf-8 -*-
"""
Created on Mon May 27 15:47:05 2019
@author: liuliang
"""
# =============================================================================
# third-party packages
# =============================================================================
import torch
# =============================================================================
# project-local packages
# =============================================================================
from model import LibraNet
from train_test import test_model
# Hyper-parameters handed unchanged to both LibraNet(...) and test_model(...);
# their interpretation lives in model.py / train_test.py.
parameters = {'TRAIN_SKIP': 100,
              'BUFFER_LENGTH': 10000,
              'ERROR_RANGE': 0.5,
              'GAMMA': 0.9,
              'batch_size': 128,
              'Interval_N': 57,
              'step_log': 0.1,
              'start_log': -2,
              'HV_NUMBER': 8,
              'ACTION_NUMBER': 9,
              'ERROR_SYSTEM': 0,
              # NOTE(review): presumably per-channel image means used for
              # input normalisation -- confirm against model.py/train_test.py.
              'means': [[108.25673428], [97.02240046], [93.37483706]]}

# Directory of the test split, passed straight to test_model.
test_path = 'data/Test/'
# Epoch index forwarded to test_model; this script performs no training.
epoch = 0

net = LibraNet(parameters)
# Deserialize the checkpoint onto the CPU: without map_location, tensors are
# restored onto the device they were saved from, so a checkpoint written on
# e.g. cuda:1 fails to load on a machine without that device.
# load_state_dict copies the weights into the (CPU) model and .cuda() below
# moves the whole network to the default GPU, so the final placement is the
# same as before. The checkpoint dict is a temporary and is freed right away.
net.load_state_dict(
    torch.load('trained_model/LibraNet_SHT_A.pth.tar',
               map_location='cpu')['state_dict'])
net.cuda()
# NOTE(review): net.eval() is not called here -- confirm that test_model
# switches the network to evaluation mode before inference.

print('Test SHT PART A Start!')
mae, mse = test_model(net, epoch, test_path, parameters)
print('mae=%.3f,mse=%.3f\n' % (mae, mse))
print('Test SHT PART A Finish!')