-
Notifications
You must be signed in to change notification settings - Fork 14
/
plot_results_dist.py
115 lines (79 loc) · 3.18 KB
/
plot_results_dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import argparse
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import experiment_descriptor as ed
import misc
import util.io
# Base directory for the file paths built below: cached results are looked up
# under <root>/results/<exp dir>/ and raw experiment outputs under
# <root>/experiments/<exp dir>/0/ (see get_dist).
# NOTE(review): misc.get_root() is project-local; presumably it returns the
# repository root -- confirm against misc.
root = misc.get_root()
def get_dist(exp_desc, average):
"""
Get the average distance from observed data in every round.
"""
if average == 'mean':
fname = 'dist_obs'
avg_f = np.mean
elif average == 'median':
fname = 'dist_obs_median'
avg_f = np.median
else:
raise ValueError('unknown average: {0}'.format(average))
res_file = os.path.join(root, 'results', exp_desc.get_dir(), fname)
if os.path.exists(res_file + '.pkl'):
avg_dist = util.io.load(res_file)
else:
exp_dir = os.path.join(root, 'experiments', exp_desc.get_dir(), '0')
_, obs_xs = util.io.load(os.path.join(exp_dir, 'gt'))
results = util.io.load(os.path.join(exp_dir, 'results'))
if isinstance(exp_desc.inf, ed.PostProp_Descriptor):
_, _, _, all_xs = results
elif isinstance(exp_desc.inf, ed.SNPE_MDN_Descriptor):
_, _, all_xs, _ = results
elif isinstance(exp_desc.inf, ed.SNL_Descriptor):
_, all_xs, _ = results
else:
raise TypeError('unsupported experiment descriptor')