-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlinear_schedule.py
107 lines (88 loc) · 2.87 KB
/
linear_schedule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import sys

import numpy as np

# Extend the import path *before* importing project modules; the original
# appended it only after the import it was meant to enable.
sys.path.append('..')
from utils.test_env import EnvTest
class LinearSchedule(object):
    """Linearly anneal ``epsilon`` from ``eps_begin`` to ``eps_end``.

    ``epsilon`` starts at ``eps_begin``, moves linearly towards ``eps_end``
    as ``t`` goes from 0 to ``nsteps``, and stays at ``eps_end`` once
    ``t >= nsteps``.
    """

    def __init__(self, eps_begin, eps_end, nsteps):
        """
        Args:
            eps_begin: initial exploration
            eps_end: end exploration
            nsteps: number of steps between the two values of eps
        """
        self.epsilon = eps_begin
        self.eps_begin = eps_begin
        self.eps_end = eps_end
        self.nsteps = nsteps

    def update(self, t):
        """
        Updates epsilon for time step ``t``.

        Args:
            t: (int) nth frames; values < 0 are treated as 0 and values
                >= nsteps pin epsilon to eps_end.
        """
        # Test the end of the schedule first: this makes nsteps == 0 mean
        # "use eps_end immediately" instead of raising ZeroDivisionError.
        if t >= self.nsteps:
            self.epsilon = self.eps_end
            return
        # float() keeps true division even for integer inputs.
        decay_per_step = (self.eps_end - self.eps_begin) / float(self.nsteps)
        # Clamp negative t so epsilon never extrapolates past eps_begin.
        self.epsilon = self.eps_begin + decay_per_step * max(t, 0)
class LinearExploration(LinearSchedule):
    """Epsilon-greedy action selection whose epsilon follows a linear schedule."""

    def __init__(self, env, eps_begin, eps_end, nsteps):
        """
        Args:
            env: environment; only its ``num_actions`` attribute is read here
            eps_begin: initial exploration
            eps_end: end exploration
            nsteps: number of steps between the two values of eps
        """
        self.env = env
        super(LinearExploration, self).__init__(eps_begin, eps_end, nsteps)

    def get_action(self, best_action):
        """
        Returns a random action with prob epsilon, otherwise return the best_action

        Args:
            best_action: (int) best action according some policy
        Returns:
            an action: either best_action or an integer drawn uniformly from
            [0, env.num_actions)
        """
        # np.random.random() samples from [0, 1); a strict "<" makes the
        # exploration probability exactly epsilon (and exactly 0 when
        # epsilon == 0, which "<=" did not guarantee).
        if np.random.random() < self.epsilon:
            return np.random.randint(0, self.env.num_actions)
        return best_action
def test1():
    """Fully exploratory strategy (epsilon = 1) must sometimes return an
    action different from the supplied best action (0)."""
    environment = EnvTest((5, 5, 1))
    strategy = LinearExploration(environment, 1, 0, 10)
    # Draw all ten actions up front so the RNG is consumed exactly as before.
    sampled = [strategy.get_action(0) for _ in range(10)]
    found_diff = any(act != 0 and act is not None for act in sampled)
    assert found_diff, "Test 1 failed."
    print("Test1: ok")
def test2():
    """Halfway through a 1 -> 0 schedule over 10 steps, epsilon must be 0.5."""
    env = EnvTest((5, 5, 1))
    exp_strat = LinearExploration(env, 1, 0, 10)
    exp_strat.update(5)
    # epsilon comes from float arithmetic; compare with a tolerance so the
    # test does not depend on exact rounding of the interpolation.
    assert np.isclose(exp_strat.epsilon, 0.5), "Test 2 failed"
    print("Test2: ok")
def test3():
    """Past the end of the schedule, epsilon must be pinned to eps_end."""
    strategy = LinearExploration(EnvTest((5, 5, 1)), 1, 0.5, 10)
    strategy.update(20)
    assert strategy.epsilon == 0.5, "Test 3 failed"
    print("Test3: ok")
def your_test():
    """Placeholder for user-written checks; currently does nothing."""
if __name__ == "__main__":
    # Run the module's self-checks in order.
    for check in (test1, test2, test3, your_test):
        check()