-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrun_dqn.py
133 lines (113 loc) · 3.87 KB
/
run_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Driver file for running DQN agent on Atari
"""
import argparse
import logging
import os
import random

import gym
import numpy as np
import torch

from configs.dqn_config import Config
from learn import OptimizerSpec, dqn_learn
from models.deep_dqn import DQN
from models.linear_models import *
import utils
from utils.gym_atari_wrappers import get_env, get_wrapper_by_name
from utils.schedule import LinearSchedule
from utils.tf_wrapper import PixelBonus
# do logging
# logger = logging.getLogger(__name__)
# logger.addHandler(logging.StreamHandler(sys.stderr))
def print_key_pairs(v, title="Parameters"):
    """
    Log key-value pairs for user-specified args as an aligned table.

    ---> borrowed from avast's benchmarks.utils

    :param v: a dict, or an iterable of (key, value) pairs, to log
    :param title: heading logged above the table
    :return: None
    """
    # Accept either a dict or an already-materialized sequence of pairs.
    items = v.items() if isinstance(v, dict) else v
    logging.info("\n" + "-" * 40)
    logging.info(title)
    logging.info("-" * 40)
    for key, value in items:
        logging.info("{:<20}: {:<10}".format(key, str(value)))
    logging.info("-" * 40)
def update_tf_wrapper_args(args, tf_flags):
    """
    Apply command-line overrides to the tensorflow wrapper's default FLAGS.

    Each entry in ``args.wrapper_args`` must look like ``name=value``;
    values are coerced to str/float/int depending on the flag name.
    Boolean arguments are not supported.

    :param args: parsed argparse namespace with a ``wrapper_args`` attribute
                 (a list of "key=value" strings, or None)
    :param tf_flags: flags object exposing ``update(key, value)``; mutated
    :return: the (mutated) tf_flags object
    """
    to_parse = args.wrapper_args
    if to_parse:
        for kwarg in to_parse:
            # Split on the first '=' only so values (e.g. paths) may
            # themselves contain '=' without raising a ValueError.
            keyname, val = kwarg.split('=', 1)
            if keyname in ('ckpt_path', 'data_path', 'samples_path', 'summary_path'):
                # Ensure the target directory exists; exist_ok avoids the
                # check-then-create race of the exists()/makedirs() pair.
                os.makedirs(val, exist_ok=True)
                tf_flags.update(keyname, val)
            elif keyname in ('data', 'model'):
                tf_flags.update(keyname, val)
            elif keyname == 'mmc_beta':
                tf_flags.update(keyname, float(val))
            else:
                # Remaining flags are assumed to be integer-valued.
                tf_flags.update(keyname, int(val))
    return tf_flags
def main(config, env, cli_args=None):
    """
    Run DQN on Atari.

    :param config: experiment configuration (learning rate, momentum,
        epsilon, max_timesteps, ...)
    :param env: (wrapped) gym environment to train on
    :param cli_args: parsed argparse namespace carrying ``wrapper_args``;
        defaults to the module-level ``args`` set in the ``__main__`` block
    :return: None
    """
    # Previously this read the global `args` directly, which raised a
    # NameError when main() was imported and called from another module.
    if cli_args is None:
        cli_args = args
    FLAGS = update_tf_wrapper_args(cli_args, utils.tf_wrapper.FLAGS)

    def stopping_criterion(env, t):
        # t := number of steps of the wrapped env, which differs from the
        # number of steps in the underlying env.
        return get_wrapper_by_name(env, "Monitor").get_total_steps() >= \
            config.max_timesteps

    # RMSprop is used here; an Adam variant was previously tried and is
    # kept in the repo history for reference.
    # optimizer_spec = OptimizerSpec(
    #     constructor=torch.optim.Adam,
    #     kwargs=dict(lr=config.learning_rate, eps=config.epsilon),
    # )
    optimizer_spec = OptimizerSpec(
        constructor=torch.optim.RMSprop,
        kwargs=dict(lr=config.learning_rate, momentum=config.momentum,
                    eps=config.epsilon),
    )
    # Linear exploration schedule over the first 1M steps down to 0.1
    # (presumably annealing epsilon from 1.0 — see utils.schedule).
    exploration_schedule = LinearSchedule(1000000, 0.1)
    dqn_learn(
        env=env, q_func=DQN, optimizer_spec=optimizer_spec,
        density=PixelBonus, cnn_kwargs=FLAGS, config=config,
        exploration=exploration_schedule, stopping_criterion=stopping_criterion,
    )
if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-W", "--wrapper_args", nargs='+',
                           help='args to add onto tensorflow wrapper')
    args = argparser.parse_args()

    # Experiment configuration (env name, output paths, hyperparameters).
    config_file = Config()

    # Seed every RNG in play for reproducibility.
    seed = 1234
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Build the environment: the Atari wrapper stack for deep models,
    # otherwise the raw gym environment.
    if config_file.deep:
        env = get_env(config_file.env_name, seed)
    else:
        env = gym.make(config_file.env_name)

    # Create the output directory if needed; exist_ok avoids the
    # check-then-create race of the exists()/makedirs() pair.
    os.makedirs(config_file.output_path, exist_ok=True)

    # Route logging to the configured file.
    logging.basicConfig(filename=config_file.log_path, level=logging.INFO)
    # print all argument variables
    # print_key_pairs(args.__dict__.items(), title='Command line args')
    main(config_file, env)