forked from adbrebs/taxi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ext_test.py
87 lines (69 loc) · 3.62 KB
/
ext_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python
import logging
import os
import csv
from blocks.model import Model
from blocks.extensions import SimpleExtension
logger = logging.getLogger(__name__)
class RunOnTest(SimpleExtension):
def __init__(self, model_name, model, stream, **kwargs):
super(RunOnTest, self).__init__(**kwargs)
self.model_name = model_name
cg = Model(model.predict(**stream.inputs()))
self.inputs = cg.inputs
self.outputs = model.predict.outputs
req_vars_test = model.predict.inputs + ['trip_id']
self.test_stream = stream.test(req_vars_test)
self.function = cg.get_theano_function()
self.best_dvc = None
self.best_tvc = None
def do(self, which_callback, *args):
iter_no = self.main_loop.log.status['iterations_done']
if 'valid_destination_cost' in self.main_loop.log.current_row:
dvc = self.main_loop.log.current_row['valid_destination_cost']
elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
dvc = self.main_loop.log.current_row['valid_model_cost_cost']
elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
dvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
else:
raise RuntimeError("Unknown model type")
if 'valid_time_cost' in self.main_loop.log.current_row:
tvc = self.main_loop.log.current_row['valid_time_cost']
elif 'valid_model_cost_cost' in self.main_loop.log.current_row:
tvc = self.main_loop.log.current_row['valid_model_cost_cost']
elif 'valid_model_valid_cost_cost' in self.main_loop.log.current_row:
tvc = self.main_loop.log.current_row['valid_model_valid_cost_cost']
else:
raise RuntimeError("Unknown model type")
output_dvc = (self.best_dvc is None or dvc < self.best_dvc) and 'destination' in self.outputs
output_tvc = (self.best_tvc is None or tvc < self.best_tvc) and 'duration' in self.outputs
if not output_dvc and not output_tvc:
return
if output_dvc:
self.best_dvc = dvc
dest_outname = 'test-dest-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, dvc)
dest_outfile = open(os.path.join('output', dest_outname), 'w')
dest_outcsv = csv.writer(dest_outfile)
dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
logger.info("Generating output for test set: %s" % dest_outname)
if output_tvc:
self.best_tvc = tvc
time_outname = 'test-time-%s-it%09d-cost%.3f.csv' % (self.model_name, iter_no, tvc)
time_outfile = open(os.path.join('output', time_outname), 'w')
time_outcsv = csv.writer(time_outfile)
time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
logger.info("Generating output for test set: %s" % time_outname)
for d in self.test_stream.get_epoch_iterator(as_dict=True):
input_values = [d[k.name] for k in self.inputs]
output_values = self.function(*input_values)
for i in range(d['trip_id'].shape[0]):
if output_dvc:
destination = output_values[self.outputs.index('destination')]
dest_outcsv.writerow([d['trip_id'][i], destination[i, 0], destination[i, 1]])
if output_tvc:
duration = output_values[self.outputs.index('duration')]
time_outcsv.writerow([d['trip_id'][i], int(round(duration[i]))])
if output_dvc:
dest_outfile.close()
if output_tvc:
time_outfile.close()