-
Notifications
You must be signed in to change notification settings - Fork 3
/
confusion-report
executable file
·169 lines (142 loc) · 5.17 KB
/
confusion-report
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#! /usr/bin/python3
# usage: confusion-report [-m basemap] true-positions.csv regions/ predictions/ > report.csv
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'lib')))
import argparse
import collections
import csv
import datetime
import glob
import multiprocessing
import time
import ageo
import numpy as np
# Reference point for the elapsed-time prefix printed by progress().
_time_0 = time.monotonic()


def progress(message, *args):
    """Write a progress line to stderr, prefixed with the elapsed time.

    The prefix is the time elapsed since this module was loaded,
    rendered as a ``datetime.timedelta``.  `message` is a ``str.format``
    template whose placeholders are filled from `args`.
    """
    elapsed = datetime.timedelta(seconds=time.monotonic() - _time_0)
    # The prefix and the message share a single template, so the
    # message's automatic fields are numbered after the elapsed time.
    template = "{}: " + message + "\n"
    sys.stderr.write(template.format(elapsed, *args))
def warning(message, *args):
    """Write a warning line to stderr.

    `message` is a ``str.format`` template filled from `args`.  The
    line is prefixed with a tab and ``***`` so that it stands out from
    the output of progress().
    """
    template = "\t*** " + message + "\n"
    sys.stderr.write(template.format(*args))
# Ground-truth record for one host, as read from the positions CSV.
TruePosition = collections.namedtuple("TruePosition",
                                      ("lat", "lon", "ipv4", "asn", "cc"))

# Maps integer host id -> TruePosition; filled by load_true_positions().
positions = {}


def load_true_positions(fname):
    """Load ground-truth host positions from the CSV file `fname`.

    The file must have a header row.  The id column may be named either
    ``id`` or ``pid``; the columns ``latitude``, ``longitude``,
    ``address_v4``, ``asn_v4`` and ``country_code`` are also read.  The
    country code is lower-cased.

    Every row that passes the coordinate sanity checks is stored in the
    module-level `positions` dict, keyed by the integer id.  Rows that
    fail a check are reported with warning() and skipped.

    Raises RuntimeError if the header has neither an ``id`` nor a
    ``pid`` column (this includes a completely empty file).
    """
    # newline="" is what the csv module asks for: without it, newlines
    # embedded in quoted fields are not passed through unchanged.
    with open(fname, newline="") as fp:
        rd = csv.DictReader(fp)
        # fieldnames is None for an empty file; treat that the same as
        # a header without an id column instead of failing on `in`.
        fieldnames = rd.fieldnames or ()
        if 'id' in fieldnames:
            idf = 'id'
        elif 'pid' in fieldnames:
            idf = 'pid'
        else:
            raise RuntimeError("{}: can't find id column among {!r}"
                               .format(fname, rd.fieldnames))

        for row in rd:
            pos = TruePosition(
                float(row['latitude']), float(row['longitude']),
                row['address_v4'], row['asn_v4'],
                row['country_code'].lower())

            # Sanity checks.  NaN coordinates fail every comparison and
            # are therefore reported as off the globe.
            # NOTE(review): latitude 90 and longitude +/-180 are
            # rejected by these half-open ranges -- confirm intended.
            if not (-90 <= pos.lat < 90) or not (-180 < pos.lon < 180):
                warning("{} ({}): position off globe: {}, {}",
                        row[idf], pos.ipv4, pos.lat, pos.lon)
            # Within one degree of (0, 0) on both axes.
            elif (-1 < pos.lat < 1) and (-1 < pos.lon < 1):
                warning("{} ({}): null island: {}, {}",
                        row[idf], pos.ipv4, pos.lat, pos.lon)
            else:
                positions[int(row[idf])] = pos
# FIXME: hardcoded tag set and naming convention matching the
# hardcoding in 'calibrate'.
# Maps the tag part of a prediction file name to its
# (algorithm, calibration set) labels.
_TAG_LABELS = {
    'cbg-m-a': ('CBG', 'Combined'),
    'cbg-m-1': ('CBG', 'Separate'),
    'oct-m-a': ('Octant', 'Combined'),
    'oct-m-1': ('Octant', 'Separate'),
    'spo-m-a': ('Spotter (uniform)', 'Combined'),
    'spo-m-1': ('Spotter (uniform)', 'Separate'),
    'spo-g-a': ('Spotter (gaussian)', 'Combined'),
    'spo-g-1': ('Spotter (gaussian)', 'Separate'),
}


def decode_filename(fname):
    """Decode a prediction file name into ``(id, algorithm, cal_set)``.

    The directory and extension are dropped; what remains is split at
    its last ``-`` into a tag (left) and a numeric id (right).

    Raises KeyError if the tag is not one of the known tags, and
    ValueError if the id part is not an integer.
    """
    stem, _ext = os.path.splitext(os.path.basename(fname))
    dash = stem.rfind('-')
    # Look the tag up before converting the id, so that an unknown tag
    # is reported in preference to a malformed id.
    calg, cset = _TAG_LABELS[stem[:dash]]
    return int(stem[dash + 1:]), calg, cset
# (label, region) pairs, sorted; filled by load_regions().
regions = []


def load_regions(rgndir):
    """Load every ``*.h5`` file in the directory `rgndir` as a region.

    Each file is read with ``ageo.Location.load`` and
    ``compute_probability_matrix_now()`` is called on the result.  The
    region is appended to the module-level `regions` list as a
    ``(label, region)`` pair, where the label is the file name without
    directory and extension.  The list is sorted once all files have
    been loaded.
    """
    pattern = os.path.join(rgndir, "*.h5")
    for path in glob.glob(pattern):
        stem, _ext = os.path.splitext(os.path.basename(path))
        region = ageo.Location.load(path)
        region.compute_probability_matrix_now()
        regions.append((stem, region))
    regions.sort()
# Optional base map; stays None unless load_base_map() is called.
basemap = None


def load_base_map(mapfile):
    """Construct an ``ageo.Map`` from `mapfile` and store it in `basemap`."""
    global basemap
    loaded = ageo.Map(mapfile)
    basemap = loaded
def area_proportion_each_region(fname):
    """Compute how a prediction's area is distributed over the regions.

    `fname` is a prediction file; its id is taken from the file name
    with decode_filename().  This function is what main() maps over the
    prediction files in the worker pool.

    Returns ``(fname, None)`` if the id has no entry in `positions` or
    in `region_containing`.  Otherwise returns ``(fname, proportions)``
    where `proportions` is an array with one entry per element of
    `regions`: the area of the prediction's intersection with that
    region, divided by the sum of all those areas.
    """
    tid = decode_filename(fname)[0]
    if tid not in positions or tid not in region_containing:
        return fname, None
    tpos = positions[tid]

    proportions = np.zeros((len(regions),))
    prediction = ageo.Location.load(fname)

    # The prediction is clipped to the base map only when the true
    # position itself lies on the base map.
    if basemap is not None and basemap.contains_point(tpos.lon, tpos.lat):
        prediction = prediction.intersection(basemap)

    for idx, (_label, rgn) in enumerate(regions):
        proportions[idx] = prediction.intersection(rgn).area

    # NOTE(review): if the prediction overlaps no region the sum is 0
    # and this division yields NaN entries -- confirm that is wanted.
    proportions /= proportions.sum()
    return fname, proportions
def get_region_containing(tpos):
    """Return the index of the first region that contains `tpos`.

    Regions are tried in the order of the module-level `regions` list,
    using each region's ``contains_point(lon, lat)``.  Returns None if
    no region contains the point.
    """
    hits = (idx
            for idx, (_label, rgn) in enumerate(regions)
            if rgn.contains_point(tpos.lon, tpos.lat))
    return next(hits, None)
# Maps host id -> index into `regions` of the region containing that
# host's true position, or None when no region contains it; filled by
# compute_region_containing().
region_containing = {}


def compute_region_containing():
    """Fill `region_containing` for every host in `positions`.

    Hosts are processed in ascending id order; each one is looked up
    with get_region_containing().
    """
    for pid in sorted(positions):
        region_containing[pid] = get_region_containing(positions[pid])
def main():
    """Command-line entry point: write the confusion report to stdout.

    Loads the true positions, the regions and (optionally) the base
    map, works out which region contains each true position, and then
    processes every ``*.h5`` file in the prediction directory with
    area_proportion_each_region() in a multiprocessing pool.

    The report is CSV.  The header is ``algorithm, cal_set, id,
    true_rgn`` followed by one column per region label; each prediction
    with a known true position contributes one row.  Progress messages
    go to stderr.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("positions",
                    help="CSV file of true host positions")
    ap.add_argument("regions",
                    help="directory of *.h5 region files")
    ap.add_argument("prediction_dir",
                    help="directory of *.h5 prediction files")
    ap.add_argument("-m", dest="basemap",
                    help="base map file")
    args = ap.parse_args()

    progress("loading positions")
    load_true_positions(args.positions)
    progress("loading regions")
    load_regions(args.regions)
    if args.basemap:
        progress("loading base map")
        load_base_map(args.basemap)
    progress("computing true containment")
    compute_region_containing()

    predictions = sorted(
        glob.glob(os.path.join(args.prediction_dir, "*.h5")))
    wr = csv.writer(sys.stdout)
    wr.writerow(["algorithm", "cal_set", "id", "true_rgn"] +
                [label for label, _ in regions])

    # The pool is created only after every module-level table has been
    # filled in, because the worker function reads those globals.
    # NOTE(review): this presumably relies on the 'fork' start method
    # so that workers inherit the loaded data -- confirm.
    with multiprocessing.Pool() as pool:
        results = pool.imap_unordered(area_proportion_each_region,
                                      predictions)
        for fname, prop in results:
            tid, calg, cset = decode_filename(fname)
            if prop is None:
                progress("{}: {}/{} (no true position)", tid, calg, cset)
                continue
            # Build the row as a plain list.  Pushing the labels, id,
            # region index (possibly None) and the float proportions
            # through np.concatenate would coerce them all to a single
            # NumPy dtype, which csv.writer does not need.
            wr.writerow([calg, cset, tid, region_containing[tid]] +
                        prop.tolist())
            progress("{}: {}/{}", tid, calg, cset)


# Guarded so that a process which re-imports this module (as
# multiprocessing can do) does not start a second run.
if __name__ == "__main__":
    main()