-
Notifications
You must be signed in to change notification settings - Fork 1
/
graph_handler.py
83 lines (68 loc) · 3.12 KB
/
graph_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gzip
import json
from json import encoder
import os
import tensorflow as tf
from evaluator import Evaluation, AccuracyEvaluation
from my.utils import short_floats
import pickle
class GraphHandler(object):
    """Session-level bookkeeping for a TF1 model.

    Responsibilities: variable initialization, checkpoint save/restore,
    TensorBoard summary writing, and dumping evaluation results/answers
    to disk.
    """

    def __init__(self, config, model):
        self.config = config
        self.model = model
        # One Saver, reused by save(). The original built a fresh Saver on
        # every save() call, which adds duplicate save/restore ops to the
        # graph each time while this instance went unused.
        self.saver = tf.train.Saver(max_to_keep=config.max_to_keep)
        self.writer = None  # FileWriter; created in initialize() for 'train' mode
        self.save_path = os.path.join(config.save_dir, config.model_name)

    def initialize(self, sess):
        """Initialize all variables, optionally restore a checkpoint, and
        (in train mode) open a summary FileWriter on config.log_dir."""
        sess.run(tf.global_variables_initializer())
        if self.config.load:
            self._load(sess)
        if self.config.mode == 'train':
            self.writer = tf.summary.FileWriter(self.config.log_dir, graph=tf.get_default_graph())

    def save(self, sess, global_step=None):
        """Write a checkpoint to self.save_path (suffixed with global_step).

        Reuses the Saver constructed in __init__ instead of building a new
        one per call (original bug: the graph grew with every save).
        """
        self.saver.save(sess, self.save_path, global_step=global_step)

    def _load(self, sess):
        """Restore variables into ``sess``.

        Checkpoint resolution order: config.load_path if set, else an
        explicit config.load_step, else the latest checkpoint recorded in
        config.save_dir.
        """
        config = self.config
        # Map variable names (with the ':0' output suffix stripped) to
        # variables, the form tf.train.Saver expects for a name remap.
        vars_ = {var.name.split(":")[0]: var for var in tf.global_variables()}
        # NOTE(review): EMA-weight remapping is disabled — even with
        # config.load_ema set, raw (non-averaged) weights are restored.
        # The original kept an unused `ema = self.model.var_ema` binding
        # here; confirm whether the remap below should be re-enabled.
        # if config.load_ema:
        #     ema = self.model.var_ema
        #     for var in tf.trainable_variables():
        #         del vars_[var.name.split(":")[0]]
        #         vars_[ema.average_name(var)] = var
        saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep)
        if config.load_path:
            save_path = config.load_path
        elif config.load_step > 0:
            save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
        else:
            save_dir = config.save_dir
            checkpoint = tf.train.get_checkpoint_state(save_dir)
            assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
            save_path = checkpoint.model_checkpoint_path
        print("Loading saved model from {}".format(save_path))
        saver.restore(sess, save_path)

    def add_summary(self, summary, global_step):
        """Forward one summary proto to the FileWriter (train mode only)."""
        self.writer.add_summary(summary, global_step)

    def add_summaries(self, summaries, global_step):
        """Forward a batch of summary protos at the same global step."""
        for summary in summaries:
            self.add_summary(summary, global_step)

    def dump_eval(self, e, precision=2, path=None):
        """Serialize an Evaluation to disk.

        With config.dump_pickle, writes a gzip-compressed pickle (.pklz);
        otherwise JSON with floats rounded to ``precision`` digits.
        Default filename: "<data_type>-<zero-padded step>.<ext>" under
        config.eval_dir.

        Raises:
            TypeError: if ``e`` is not an Evaluation. (Was an assert,
                which disappears under ``python -O``.)
        """
        if not isinstance(e, Evaluation):
            raise TypeError("expected an Evaluation, got {}".format(type(e).__name__))
        if self.config.dump_pickle:
            path = path or os.path.join(self.config.eval_dir, "{}-{}.pklz".format(e.data_type, str(e.global_step).zfill(6)))
            with gzip.open(path, 'wb', compresslevel=3) as fh:
                pickle.dump(e.dict, fh)
        else:
            path = path or os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6)))
            with open(path, 'w') as fh:
                json.dump(short_floats(e.dict, precision), fh)

    def dump_answer(self, e, path=None):
        """Write the rows of ``e.ans`` as CSV under config.answer_dir.

        Raises:
            TypeError: if ``e`` is not an Evaluation.
        """
        import csv  # local import kept, matching the original module layout
        if not isinstance(e, Evaluation):
            raise TypeError("expected an Evaluation, got {}".format(type(e).__name__))
        path = path or os.path.join(self.config.answer_dir, "{}-{}.csv".format(e.data_type, str(e.global_step).zfill(6)))
        # `with` closes the file; the original's explicit fh.close() inside
        # the block was redundant and is removed.
        with open(path, 'w', newline='') as fh:
            csv.writer(fh).writerows(e.ans)