# test_mdp.py
# Forked from asmith26/h-DQN.
import numpy as np
import matplotlib.pyplot as plt
from envs.mdp import StochasticMDPEnv
# Apply matplotlib's 'ggplot' style sheet to every figure drawn by this script.
plt.style.use('ggplot')
class Agent:
    """Hand-coded policy used to exercise the MDP environment.

    The agent returns action ``1`` while it has not yet observed state 6,
    and action ``0`` from the moment state 6 is observed onwards.
    NOTE(review): the meaning of actions 0/1 is defined by the environment,
    which is not visible here -- presumably 1 moves towards state 6 and 0
    moves away from it; confirm against ``envs.mdp``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # True once state 6 has been passed to ``select_move``. Callers
        # reset this to False at the start of every episode.
        self.seen_6 = False

    def select_move(self, state):
        """Return the action for ``state`` and record a visit to state 6.

        Args:
            state: The current state identifier (compared against 6).

        Returns:
            ``1`` if ``state < 6`` and state 6 has not been seen yet,
            otherwise ``0``.
        """
        if state == 6:
            self.seen_6 = True
        if state < 6 and not self.seen_6:
            return 1
        else:
            return 0

    def update(self, state, action, reward):
        """Accept a transition; this scripted agent does not learn from it."""
        pass
def _run_episode(env, agent, counts):
    """Play one episode and tally every visited state into ``counts`` in place.

    ``counts`` is a 1-D array in which state ``s`` is tallied at index
    ``s - 1``. ``env.step`` is expected to return a
    ``(next_state, reward, done)`` triple.
    """
    state = env.reset()
    agent.seen_6 = False  # the "seen state 6" flag is per-episode
    counts[state - 1] += 1
    done = False
    while not done:
        action = agent.select_move(state)
        state, _reward, done = env.step(action)
        counts[state - 1] += 1


def _plot_visit_frequencies(frequencies):
    """Draw one panel per state showing its visit frequency for each block.

    ``frequencies`` has shape ``(n_blocks, n_states)``; panels are laid out
    three per row and titled ``S1``, ``S2``, ...
    """
    n_blocks, n_states = frequencies.shape
    n_cols = 3
    n_rows = -(-n_states // n_cols)  # ceiling division
    block_axis = list(range(1, n_blocks + 1))
    for column in range(n_states):
        plt.subplot(n_rows, n_cols, column + 1)
        plt.plot(block_axis, frequencies[:, column])
        plt.xlabel("Episodes (*1000)")
        plt.ylim(-0.01, 2.0)
        plt.xlim(1, n_blocks)
        plt.title("S%d" % (column + 1))
        plt.grid(True)


def main(n_blocks=12, episodes_per_block=1000, n_states=6):
    """Run the scripted agent on the stochastic MDP and plot state visits.

    Episodes are grouped into ``n_blocks`` blocks of ``episodes_per_block``
    episodes. For every block the number of visits to each state is counted,
    divided by ``episodes_per_block``, printed, and plotted per state.

    Args:
        n_blocks: Number of episode blocks (rows of the visit table).
        episodes_per_block: Episodes played in each block; also the divisor
            used to turn visit counts into per-episode frequencies.
        n_states: Number of distinct states (columns of the visit table).
    """
    np.set_printoptions(precision=2)
    env = StochasticMDPEnv()
    agent = Agent()
    visits = np.zeros((n_blocks, n_states))
    for block in range(n_blocks):
        block_counts = visits[block]  # row view: updates write into ``visits``
        for _ in range(episodes_per_block):
            _run_episode(env, agent, block_counts)

    # Printed once, after all blocks have been simulated.
    frequencies = visits / episodes_per_block
    print(frequencies)

    _plot_visit_frequencies(frequencies)
    plt.show()
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()