-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
57 lines (42 loc) · 1.86 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import random
import gymnasium as gym
import os
from stable_baselines3 import DQN
import sumo_rl
import csv
import argparse
from generator import TrafficGenerator
def _csv_path_for(csv_dir, model_path):
    """Return the CSV path inside *csv_dir* named after *model_path*'s file stem."""
    # basename() first: without it "models/run1.zip" would map to
    # "<csv_dir>/models/run1.csv" (a subdirectory that may not exist), and an
    # absolute model path would make os.path.join discard csv_dir entirely.
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    return os.path.join(csv_dir, f"{model_name}.csv")


def _evaluate_model(model_path, env, csv_dir, max_steps):
    """Run one saved DQN model on *env* for up to *max_steps* steps, logging info dicts.

    Each step's ``info`` dict is appended as a row to the model's CSV file; the
    header row is written only when the file is empty.
    """
    model = DQN.load(model_path, env=env)
    obs, info = env.reset()
    with open(_csv_path_for(csv_dir, model_path), "a", newline="") as csv_file:
        # The column set is fixed from the info dict returned by reset().
        writer = csv.DictWriter(csv_file, fieldnames=list(info.keys()))
        if csv_file.tell() == 0:
            writer.writeheader()
        for _ in range(max_steps):
            # NOTE(review): with use_gui=False and no render_mode this is
            # presumably a no-op — confirm and consider dropping it.
            env.render()
            # deterministic=True: DQN.predict otherwise picks a random action with
            # probability `exploration_rate`, injecting noise into the evaluation.
            action, _ = model.predict(obs, deterministic=True)
            obs, _reward, terminated, truncated, info = env.step(action)
            writer.writerow(info)
            # Stop at episode end rather than stepping a finished environment.
            if terminated or truncated:
                break


def main(models_path, csv_file_path, sound_file=None):
    """Evaluate each saved DQN model on one freshly generated SUMO scenario.

    A route file is generated once via ``TrafficGenerator`` (with a random seed)
    and the same environment is reused for every model. Each model is run for
    up to one simulated day (86400 s), and its per-step ``info`` dicts are
    appended to ``<csv_file_path>/<model file stem>.csv``.

    Args:
        models_path: Iterable of paths to models saved with ``DQN.save``.
        csv_file_path: Output directory for the CSV files; created if missing.
        sound_file: Currently unused; accepted for interface compatibility.
    """
    num_seconds = 86400  # one simulated day
    # Agent steps in that day; assumes a 5 s step length (sumo-rl's default
    # delta_time) — TODO confirm if delta_time is ever changed.
    iterations = num_seconds // 5
    os.makedirs(csv_file_path, exist_ok=True)

    # NOTE(review): the meaning of the positional TrafficGenerator arguments
    # lives in generator.py; presumably this writes the "test.rou.xml" route
    # file used below — confirm there.
    traffic_generator = TrafficGenerator(num_seconds, 20000, "test", "")
    traffic_generator.generate_routefile(random.randint(0, 100000), 0, "bimodal", False)

    env = gym.make(
        "sumo-rl-v0",
        net_file="environment.net.xml",
        route_file="test.rou.xml",
        use_gui=False,
        num_seconds=num_seconds,
    )
    try:
        for model_path in models_path:
            _evaluate_model(model_path, env, csv_file_path, iterations)
    finally:
        # Release the environment's resources even if a model fails to load/run.
        env.close()
if __name__ == "__main__":
    # Command-line entry point: parse model paths and output directory, then
    # evaluate every model via main().
    parser = argparse.ArgumentParser(description="Run traffic signal control with RL models.")
    parser.add_argument('--models', nargs='+', required=True, help="List of model paths")
    parser.add_argument('--output_dir', type=str, required=True, help="Directory for output CSV files")
    # main() takes a third sound-file argument; without this option `args.sound`
    # does not exist and the script crashes with AttributeError before running.
    parser.add_argument('--sound', type=str, default=None, help="Optional sound file path")
    args = parser.parse_args()
    main(args.models, args.output_dir, args.sound)