-
Notifications
You must be signed in to change notification settings - Fork 8
/
meta_test.py
167 lines (139 loc) · 6.36 KB
/
meta_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from metalearner import MetaLearner
from sampler import BatchSampler
import config
import time
import copy
from multiprocessing import Process
import pickle
import random
import numpy as np
import tensorflow as tf
from utils import parse, config_all, parse_roadnet
import os
from copy import deepcopy as dp
from traffic import *
def main(args):
    '''
    Perform meta-testing for MAML, Metalight, Random, and Pretrained.

    Builds one experiment configuration per traffic file listed under
    ``args.traffic_group`` and runs ``_train`` on it. Runs happen inline when
    ``args.single_process`` is set. Otherwise each runs in its own
    ``multiprocessing.Process``, with at most ``args.num_process`` workers
    alive at the same time.

    Arguments:
        args: generated in utils.py:parse()
    '''
    # configuration: experiment, agent, traffic_env, path
    _dic_exp_conf, _dic_agent_conf, _dic_traffic_env_conf, _dic_path = config_all(args)
    traffic_file_list = _dic_traffic_env_conf["TRAFFIC_CATEGORY"][args.traffic_group]
    # Read once, before the loop, so it is defined even if traffic_file_list is empty.
    single_process = args.single_process
    process_list = []

    for traffic_file in traffic_file_list:
        # Each experiment gets its own deep copies, so runs never share
        # mutable configuration state.
        dic_exp_conf = dp(_dic_exp_conf)
        dic_agent_conf = dp(_dic_agent_conf)
        dic_traffic_env_conf = dp(_dic_traffic_env_conf)
        dic_path = dp(_dic_path)

        # Per-file metadata: entries 2 and 3 are the roadnet and flow file names.
        traffic_info = dic_traffic_env_conf["TRAFFIC_CATEGORY"]["traffic_info"][traffic_file]
        dic_traffic_env_conf['ROADNET_FILE'] = traffic_info[2]
        dic_traffic_env_conf['FLOW_FILE'] = traffic_info[3]

        # path: outputs go to a per-run directory named "<traffic_file>_<timestamp>".
        _time = time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))
        run_name = traffic_file + "_" + _time
        dic_path.update({
            "PATH_TO_MODEL": os.path.join(dic_path["PATH_TO_MODEL"], run_name),
            "PATH_TO_WORK_DIRECTORY": os.path.join(dic_path["PATH_TO_WORK_DIRECTORY"], run_name),
            "PATH_TO_GRADIENT": os.path.join(dic_path["PATH_TO_GRADIENT"], run_name, "gradient"),
            "PATH_TO_DATA": os.path.join(dic_path["PATH_TO_DATA"], traffic_file.split(".")[0]),
        })

        # traffic env
        dic_traffic_env_conf["TRAFFIC_FILE"] = traffic_file
        dic_traffic_env_conf["TRAFFIC_IN_TASKS"] = [traffic_file]

        # parse roadnet
        # NOTE(review): only the intersection keyed "intersection_1_1" is used;
        # presumably the roadnets here are single-intersection -- confirm.
        roadnet_path = os.path.join(dic_path['PATH_TO_DATA'], dic_traffic_env_conf['ROADNET_FILE'])
        intersection_info = parse_roadnet(roadnet_path)["intersection_1_1"]
        dic_traffic_env_conf["LANE_PHASE_INFO"] = intersection_info
        # num_lanes per direction: the number of start lanes divided by 4
        dic_traffic_env_conf["num_lanes"] = len(intersection_info["start_lane"]) // 4
        dic_traffic_env_conf["num_phases"] = len(intersection_info["phase"])

        dic_exp_conf.update({
            "TRAFFIC_FILE": traffic_file,
            "TRAFFIC_IN_TASKS": [traffic_file]})

        # The dicts above are fresh per-iteration copies that nothing touches
        # afterwards, so they can be handed to _train without copying again.
        if single_process:
            _train(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path)
        else:
            process_list.append(Process(
                target=_train,
                args=(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path)))

    if not single_process:
        # With a non-positive pool size, no worker would ever be started and
        # the wait loop below would spin forever; run at least one at a time.
        max_alive = max(1, args.num_process)
        list_cur_p = []
        for i, p in enumerate(process_list):
            print(i)
            p.start()
            list_cur_p.append(p)
            if len(list_cur_p) < max_alive:
                continue
            # Pool is full: poll once per second until some worker has exited.
            idle = check_all_workers_working(list_cur_p)
            while idle == -1:
                time.sleep(1)
                idle = check_all_workers_working(list_cur_p)
            # The worker has already exited, so join() returns immediately and
            # releases its process resources.
            list_cur_p.pop(idle).join()
        # Wait for the remaining workers to finish.
        for p in list_cur_p:
            p.join()
def check_all_workers_working(list_cur_p):
    '''
    Find a worker slot that has become free.

    Arguments:
        list_cur_p: sequence of objects exposing ``is_alive()``
            (e.g. ``multiprocessing.Process``)
    Returns:
        The index of the first entry whose ``is_alive()`` is false, or -1
        when every entry is still alive (or the sequence is empty).
    '''
    finished = (idx for idx, proc in enumerate(list_cur_p) if not proc.is_alive())
    return next(finished, -1)
def _train(dic_exp_conf, dic_agent_conf, dic_traffic_env_conf, dic_path, args=None):
    '''
    Perform meta-testing for MAML, Metalight, Random, and Pretrained on a
    single traffic file.

    Arguments:
        dic_exp_conf: dict, configuration of this experiment
        dic_agent_conf: dict, configuration of agent
        dic_traffic_env_conf: dict, configuration of traffic environment
        dic_path: dict, path of source files and output files
        args: parsed command-line namespace (utils.py:parse()); only
            ``fast_batch_size``, ``num_workers`` and ``algorithm`` are read.
            When omitted, the module-level ``args`` bound in the ``__main__``
            guard is used. That global exists only when this file is run as
            a script, and child processes started with the "spawn" start
            method do not inherit it.
    Raises:
        NameError: if ``args`` is not given and no module-level ``args`` exists.
    '''
    if args is None:
        args = globals().get('args')
        if args is None:
            raise NameError(
                "_train() needs the parsed CLI arguments: pass args=... or run "
                "this module as a script")

    random.seed(dic_agent_conf['SEED'])
    np.random.seed(dic_agent_conf['SEED'])
    tf.set_random_seed(dic_agent_conf['SEED'])

    sampler = BatchSampler(dic_exp_conf=dic_exp_conf,
                           dic_agent_conf=dic_agent_conf,
                           dic_traffic_env_conf=dic_traffic_env_conf,
                           dic_path=dic_path,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)
    policy = config.DIC_AGENTS[args.algorithm](
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=dic_traffic_env_conf,
        dic_path=dic_path
    )
    metalearner = MetaLearner(sampler, policy,
                              dic_agent_conf=dic_agent_conf,
                              dic_traffic_env_conf=dic_traffic_env_conf,
                              dic_path=dic_path
                              )

    # With PRE_TRAIN on and a model name other than 'random', start from the
    # parameters pickled at model/initial/common/<PRE_TRAIN_MODEL_NAME>.pkl.
    if dic_agent_conf['PRE_TRAIN'] and dic_agent_conf['PRE_TRAIN_MODEL_NAME'] != 'random':
        model_file = os.path.join('model', 'initial', "common",
                                  dic_agent_conf['PRE_TRAIN_MODEL_NAME'] + '.pkl')
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # model files from a trusted source.
        with open(model_file, 'rb') as f:
            params = pickle.load(f)
        # Both attributes are bound to the same object, not to copies.
        metalearner.meta_params = params
        metalearner.meta_target_params = params

    episodes = None
    for batch_id in range(dic_exp_conf['NUM_ROUNDS']):
        task = dic_exp_conf['TRAFFIC_FILE']
        if dic_agent_conf['MULTI_EPISODES']:
            # Feed the episodes returned by the previous round into this one.
            episodes = metalearner.sample_meta_test(task, batch_id, episodes)
        else:
            episodes = metalearner.sample_meta_test(task, batch_id)
if __name__ == '__main__':
    # `args` is intentionally bound at module scope: _train() reads the
    # global `args` (fast_batch_size, num_workers, algorithm), and forked
    # worker processes inherit it from here.
    args = parse()
    # Restrict CUDA device visibility to the GPU id(s) given on the command line.
    gpu_ids = args.visible_gpu
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
    main(args)