gridworld.py
import sys
from os.path import dirname, join, realpath

# Make the repository's utils directory importable so `env.GridWorld`
# resolves when this script is run from its own subdirectory.
dir_path = dirname(dirname(realpath(__file__)))
sys.path.insert(1, join(dir_path, 'utils'))

import numpy as np
from env import GridWorld

def iterative_policy_evaluation(env: GridWorld, gamma: float, theta: float) -> np.ndarray:
    '''
    In-place (asynchronous) iterative policy evaluation for the equiprobable
    random policy. Sweeps over the state space, applying the Bellman
    expectation backup
        V(s) <- sum_a pi(a|s) * sum_{s',r} p(s', r | s, a) * [r + gamma * V(s')]
    until the largest per-sweep change falls below theta.
    '''
    value_function = np.zeros((env.height, env.width))
    while True:
        delta = 0
        for state in env.state_space:
            # Terminal states keep a value of zero.
            if env.terminated(state):
                continue
            x, y = state[0], state[1]
            old_value = value_function[x, y]
            # Probe every action from this state to record (s', a, r).
            next_history = []
            for action in env.action_space:
                env.state = np.copy(state)
                next_state, reward, _ = env.step(action)
                next_history.append((next_state, action, reward))
            # Bellman backup: transitions are deterministic, so p(s'|s,a) = 1
            # and only the per-action policy weight remains.
            value = 0
            for next_state, action, reward in next_history:
                value += env.transition_probs[action] * (reward +
                    gamma * value_function[next_state[0], next_state[1]])
            value_function[x, y] = value
            delta = max(delta, abs(old_value - value_function[x, y]))
        if delta < theta:
            break
    return value_function

if __name__ == '__main__':
    # Classic 4x4 gridworld with terminal states in two opposite corners.
    height = width = 4
    terminal_states = [(0, 0), (height - 1, width - 1)]
    env = GridWorld(height, width, terminal_states=terminal_states)
    gamma = 1.0   # undiscounted episodic task
    theta = 1e-5  # convergence threshold
    value_function = iterative_policy_evaluation(env, gamma, theta)
    value_function = np.around(np.reshape(
        value_function, (height, width)), decimals=1)
    print(value_function)
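
The `env.GridWorld` class comes from the repository's utils directory and is not shown on this page. For readers who want to run the script standalone, here is a minimal, hypothetical stand-in that matches the interface the script relies on (height, width, state_space, action_space, transition_probs, state, terminated, step). It assumes the classic Sutton & Barto 4x4 gridworld dynamics: four deterministic moves, a reward of -1 on every transition, and the agent staying in place when a move would leave the grid; the real class may differ in these details.

import numpy as np


class GridWorld:
    '''Hypothetical minimal stand-in for utils/env.py (interface inferred
    from gridworld.py above; dynamics assumed, not confirmed).'''

    # Four deterministic moves: up, down, left, right.
    MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, height, width, terminal_states=None):
        self.height = height
        self.width = width
        self.terminal_states = [tuple(s) for s in (terminal_states or [])]
        self.state_space = [(x, y) for x in range(height) for y in range(width)]
        self.action_space = list(self.MOVES.keys())
        # Assumed semantics: per-action weight of the equiprobable policy.
        self.transition_probs = {a: 1.0 / len(self.action_space)
                                 for a in self.action_space}
        self.state = np.array(self.state_space[0])

    def terminated(self, state):
        return tuple(state) in self.terminal_states

    def step(self, action):
        dx, dy = self.MOVES[action]
        x, y = int(self.state[0]) + dx, int(self.state[1]) + dy
        # Moves off the grid leave the agent where it was (assumed dynamics).
        if not (0 <= x < self.height and 0 <= y < self.width):
            x, y = int(self.state[0]), int(self.state[1])
        self.state = np.array([x, y])
        # Reward of -1 on every step, as in Sutton & Barto, Example 4.1.
        return self.state, -1.0, self.terminated(self.state)

Under these assumptions the script should converge to roughly the values of Sutton & Barto's Example 4.1: 0 in the two terminal corners and -22 in the corners farthest from them.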