-
Notifications
You must be signed in to change notification settings - Fork 0
/
metric.py
77 lines (67 loc) · 2.7 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
import scipy.io as sio
import os
from model.metric import *
import torch
import argparse
from utils import Params
def parse_args(argv=None):
    """Parse the command-line options of the metric-evaluation script.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls are unchanged.

    Returns:
        argparse.Namespace with:
            config (str): experiment config name (default ``'13'``); the
                script uses it to build ``./experiments/base/<config>``.
            tag (str): data tag (default ``'test'``).
                NOTE(review): not read anywhere in the visible script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='13', help='config filename')
    parser.add_argument('--tag', type=str, default='test', help='data tag')
    return parser.parse_args(argv)
args = parse_args()
path_root = './experiments/base/{}'.format(args.config)

# Query indices whose results are stored as <path_root>/data/qry_<i>.mat ...
QUERY_IDS = [0, 1, 2, 3, 5, 6, 7, 8, 11, 13, 14, 15]
# ... and those stored as <path_root>/data/unknown_qry_<i>.mat.
UNKNOWN_QUERY_IDS = [4, 9, 10, 12]


def _load_split(root, prefix, indices):
    """Load the 'recons' and 'inputs' arrays from <root>/data/<prefix><i>.mat.

    Args:
        root: experiment directory containing a ``data`` subfolder.
        prefix: filename prefix, e.g. ``'qry_'`` or ``'unknown_qry_'``.
        indices: iterable of integer file indices, loaded in order.

    Returns:
        (recons_list, inputs_list): two lists of arrays, one entry per index,
        in the same order as ``indices``.
    """
    recons_list, inputs_list = [], []
    for idx in indices:
        data = sio.loadmat('{}/data/{}{}.mat'.format(root, prefix, idx))
        recons_list.append(data['recons'])
        inputs_list.append(data['inputs'])
    return recons_list, inputs_list


known_recons, known_inputs = _load_split(path_root, 'qry_', QUERY_IDS)
unknown_recons, unknown_inputs = _load_split(path_root, 'unknown_qry_', UNKNOWN_QUERY_IDS)

# Stack along axis 0: the qry_ files come first, then the unknown_qry_ files.
all_recons = np.concatenate(known_recons + unknown_recons, axis=0)
all_inputs = np.concatenate(known_inputs + unknown_inputs, axis=0)

recons_torch = torch.Tensor(all_recons)
inputs_torch = torch.Tensor(all_inputs)

# mse/vpt/dst/vpd come from the star import of model.metric; their output
# shapes are not visible here. NOTE(review): the reductions below assume mse
# returns a tensor of rank >= 4 and dst a 2-D array, each with one row per
# sequence on axis 0 - confirm against model/metric.py.
mse_total = mse(recons_torch, inputs_torch).mean([1, 2, 3]).cpu().detach().numpy()
vpt_total = vpt(recons_torch, inputs_torch).cpu().detach().numpy()
dst_total = dst(all_recons, all_inputs).mean(1)
vpd_total = vpd(all_recons, all_inputs)

# Report the mean and std over sequences for each metric. To persist the
# results, append the same lines to '<path_root>/metric.txt'.
for label, per_seq in (('mse', mse_total), ('vpt', vpt_total),
                       ('dst', dst_total), ('vpd', vpd_total)):
    print('{} for seq avg = {}'.format(label, per_seq.mean()))
    print('{} for seq std = {}'.format(label, per_seq.std()))