# gridworld_policy.py
import math
import random

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from gridworld_env import GridWorldEnv, Action, CellType

class SingleStatePolicy:
    """
    The SingleStatePolicy class represents a policy for a single state in a Markov Decision Process (MDP).

    Attributes:
    - up (float): The probability of taking the "up" action.
    - down (float): The probability of taking the "down" action.
    - left (float): The probability of taking the "left" action.
    - right (float): The probability of taking the "right" action.

    Methods:
    - to_list(): Converts the policy to a list of action probabilities.
    - update_from_list(values): Updates the policy from a list of action probabilities.
    - __str__(): Returns a string representation of the policy.
    - kl_divergence(policy_at_state_p, policy_at_state_q): Computes the Kullback-Leibler divergence between two policies.
    """

    def __init__(self, up=0.0, down=0.0, left=0.0, right=0.0) -> None:
        self.up = up
        self.down = down
        self.left = left
        self.right = right

    def to_list(self):
        """Convert the policy to a list of action probabilities."""
        return [self.up, self.down, self.left, self.right]

    def update_from_list(self, values):
        """Update the policy from a list of action probabilities."""
        self.up, self.down, self.left, self.right = values

    def __str__(self):
        return f"Policy(up={self.up}, down={self.down}, left={self.left}, right={self.right})"

    @staticmethod
    def kl_divergence(policy_at_state_p: 'SingleStatePolicy', policy_at_state_q: 'SingleStatePolicy') -> float:
        """Compute the Kullback-Leibler divergence D_KL(P || Q), in bits, between two single-state policies."""
        sum_v = 0.0
        for action_p, action_q in zip(policy_at_state_p.to_list(), policy_at_state_q.to_list()):
            # The small constants guard against division by zero and log(0)
            # when an action has zero probability under either policy.
            sum_v += action_p * math.log2(action_p / (action_q + 1e-100) + 1e-100)
        return sum_v
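
# Illustrative sketch (added helper, not part of the original module): shows how
# SingleStatePolicy.kl_divergence compares a biased policy against a uniform one.
# The probability values below are assumptions chosen for demonstration.
def _demo_kl_divergence() -> None:
    uniform = SingleStatePolicy(0.25, 0.25, 0.25, 0.25)
    biased = SingleStatePolicy(0.7, 0.1, 0.1, 0.1)
    # D_KL(biased || uniform) is positive; the divergence of a policy with itself is ~0.
    print(SingleStatePolicy.kl_divergence(biased, uniform))
    print(SingleStatePolicy.kl_divergence(uniform, uniform))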

class GridWorldPolicy:
    """
    GridWorldPolicy class represents a policy for a grid world environment.

    Attributes:
    - policy_grid (dict): A dictionary that maps states to individual state policies.
    - grid_world_env (GridWorldEnv): The grid world environment associated with the policy.

    Methods:
    - __init__(grid_world_env: GridWorldEnv, default_policy=(0.25, 0.25, 0.25, 0.25)):
        Initializes a new GridWorldPolicy instance.
    - get_policy(state: tuple) -> SingleStatePolicy:
        Returns the policy for a specific state.
    - action_probabilities(state: tuple) -> dict[Action, float]:
        Returns the probabilities of taking each action given a state.
    - interpolate(gamma):
        Interpolates policies based on gamma, creating a new GridWorldPolicy.
    - add_noise(noise_level):
        Returns a new GridWorldPolicy with random noise added to each state's policy.
    - visualize():
        Visualizes the policy grid on the GridWorld layout.
    - _draw_arrows(ax, state, policy):
        Helper method to draw policy arrows on the grid.
    """

    def __init__(self, grid_world_env: GridWorldEnv, default_policy=(0.25, 0.25, 0.25, 0.25)) -> None:
        self.policy_grid = {}
        self.grid_world_env = grid_world_env
        # Initialize policy grid with default policy for each state
        for x in range(self.grid_world_env.grid_size_x):
            for y in range(self.grid_world_env.grid_size_y):
                self.policy_grid[(x, y)] = SingleStatePolicy(*default_policy)

    def get_policy(self, state: tuple):
        """Get the policy for a specific state."""
        return self.policy_grid[state]

    def action_probabilities(self, state: tuple) -> dict[Action, float]:
        """Return the probabilities of taking each action given a state."""
        policy_at_state = self.get_policy(state)
        return {
            Action.UP: policy_at_state.up,
            Action.DOWN: policy_at_state.down,
            Action.LEFT: policy_at_state.left,
            Action.RIGHT: policy_at_state.right
        }

    def interpolate(self, gamma: float) -> 'GridWorldPolicy':
        """Interpolate policies based on gamma, creating a new GridWorldPolicy."""
        new_policy_grid = GridWorldPolicy(self.grid_world_env)
        num_actions = len(Action)  # Number of possible actions
        for state, policy in self.policy_grid.items():
            original_policy_list = policy.to_list()
            interpolated_policy_list = [gamma * p + (1 - gamma) * (1 / num_actions) for p in original_policy_list]
            new_policy_grid.policy_grid[state].update_from_list(interpolated_policy_list)
        return new_policy_grid
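
    # Worked example for interpolate() above (illustrative values): with 4 actions,
    # interpolating [0.5, 0.1, 0.2, 0.2] at gamma = 0.5 gives 0.5 * p + 0.5 * 0.25
    # per entry, i.e. [0.375, 0.175, 0.225, 0.225]; gamma = 1 returns the original
    # policy and gamma = 0 the uniform policy.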

    def add_noise(self, noise_level: float) -> 'GridWorldPolicy':
        """
        Generates a new GridWorldPolicy instance with random noise added to the policy of each state.

        Parameters:
        - noise_level (float): The level of noise, between 0 and 1, representing the maximum percentage change.

        Returns:
        - GridWorldPolicy: A new GridWorldPolicy instance with noisy policies.
        """
        # Create a new GridWorldPolicy instance
        new_policy = GridWorldPolicy(self.grid_world_env)
        # Iterate over each state in the policy grid
        for state, policy in self.policy_grid.items():
            # Get the current policy as a list
            current_policy_list = policy.to_list()
            noisy_policy_list = []
            # Add noise to each action's probability
            for prob in current_policy_list:
                # Add a small constant factor to avoid zero probability
                constant_noise = random.uniform(0, noise_level / len(current_policy_list))
                noise = random.uniform(-noise_level, noise_level) * prob
                noisy_prob = min(max(prob + noise + constant_noise, 0), 1)
                noisy_policy_list.append(noisy_prob)
            # Normalize to ensure the sum of probabilities is 1
            total = sum(noisy_policy_list)
            normalized_policy_list = [p / total for p in noisy_policy_list]
            # Update the new policy grid with the noisy policy
            new_policy.policy_grid[state] = SingleStatePolicy(*normalized_policy_list)
        return new_policy

    def visualize(self):
        """Visualize the policy grid on the GridWorld layout."""
        fig, ax = plt.subplots()
        # Define colors for each cell type
        cell_colors = {
            CellType.EMPTY: (192 / 255, 192 / 255, 192 / 255),  # Grey
            CellType.TARGET: (255 / 255, 215 / 255, 0 / 255),   # Yellow
            CellType.START: (0 / 255, 0 / 255, 255 / 255),      # Blue
            CellType.OBSTACLE: (0 / 255, 0 / 255, 0 / 255),     # Black
            CellType.TREASURE: (0 / 255, 255 / 255, 0 / 255),   # Green
            CellType.TRAP: (255 / 255, 0 / 255, 0 / 255)        # Red
        }
        # Define a mapping from cell type to integer
        cell_type_to_int = {
            CellType.EMPTY: 0,
            CellType.TARGET: 1,
            CellType.START: 2,
            CellType.OBSTACLE: 3,
            CellType.TREASURE: 4,
            CellType.TRAP: 5
        }
        # Create the colormap using the specified colors
        cmap = mcolors.ListedColormap([cell_colors[key] for key in cell_type_to_int.keys()])
        # Convert layout to numerical values for coloring
        data = np.zeros(self.grid_world_env.layout.shape, dtype=int)
        for i in range(self.grid_world_env.grid_size_x):
            for j in range(self.grid_world_env.grid_size_y):
                cell_type = self.grid_world_env.layout[i, j]
                data[i, j] = cell_type_to_int[cell_type]
        # Pin the color scale to the full cell-type range so colors stay aligned
        # even when some cell types are absent from the layout.
        ax.imshow(data, cmap=cmap, vmin=0, vmax=len(cell_type_to_int) - 1)
        # Display policy as arrows on the grid
        for state, policy in self.policy_grid.items():
            self._draw_arrows(ax, state, policy)
        plt.show()

    def _draw_arrows(self, ax, state, policy):
        """Helper method to draw policy arrows on the grid, skipping obstacles and targets."""
        i, j = state
        cell_type = self.grid_world_env.layout[i, j]
        # Skip drawing arrows for obstacles and targets
        if cell_type in [CellType.OBSTACLE, CellType.TARGET]:
            return
        arrow_scale = 0.3      # Scale factor for the arrow size
        head_width = 0.1       # Width of the arrow head
        head_length = 0.1      # Length of the arrow head
        arrow_color = 'white'  # Color of the arrow
        # Draw arrows based on the policy probabilities. imshow plots in (column, row)
        # order, so state (i, j) is drawn at (j, i), and "up" points towards smaller
        # row indices (negative dy).
        if policy.up > 0:
            ax.arrow(j, i, 0, -arrow_scale * policy.up, head_width=head_width, head_length=head_length,
                     fc=arrow_color, ec=arrow_color)
        if policy.down > 0:
            ax.arrow(j, i, 0, arrow_scale * policy.down, head_width=head_width, head_length=head_length,
                     fc=arrow_color, ec=arrow_color)
        if policy.left > 0:
            ax.arrow(j, i, -arrow_scale * policy.left, 0, head_width=head_width, head_length=head_length,
                     fc=arrow_color, ec=arrow_color)
        if policy.right > 0:
            ax.arrow(j, i, arrow_scale * policy.right, 0, head_width=head_width, head_length=head_length,
                     fc=arrow_color, ec=arrow_color)
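
# Illustrative helper (added sketch, not part of the original module): contrasts
# interpolate() and add_noise() on a fresh policy grid. Can be invoked from the
# example usage below if desired; the helper name is hypothetical.
def _demo_policy_perturbations(env: GridWorldEnv) -> None:
    base = GridWorldPolicy(env)
    # interpolate() blends each state's policy with the uniform distribution;
    # gamma = 1 keeps the original policy, gamma = 0 yields the uniform policy.
    smoothed = base.interpolate(gamma=0.5)
    # add_noise() perturbs each probability and renormalizes so each state's
    # action probabilities still sum to 1.
    noisy = base.add_noise(noise_level=0.3)
    for state in list(base.policy_grid)[:3]:  # print a few states for brevity
        print(state, smoothed.get_policy(state), noisy.get_policy(state))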

# Example usage
if __name__ == "__main__":
    env = GridWorldEnv('layout.txt')
    policy_grid = GridWorldPolicy(env)
    # Example of setting a specific policy for a state
    policy_grid.policy_grid[(1, 1)].update_from_list([0.5, 0.1, 0.2, 0.2])
    # Visualize or use the policy grid as needed
    policy_grid.visualize()
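
    # Additional illustrative usage (a sketch, not part of the original example):
    # smooth the policy towards uniform, perturb it with noise, and compare the two
    # at the (1, 1) state from above via KL divergence.
    smoothed = policy_grid.interpolate(gamma=0.8)
    noisy = policy_grid.add_noise(noise_level=0.2)
    divergence = SingleStatePolicy.kl_divergence(smoothed.get_policy((1, 1)), noisy.get_policy((1, 1)))
    print(f"KL divergence at (1, 1): {divergence:.4f} bits")
    # Sample an action at a state from the action-probability dict.
    probs = policy_grid.action_probabilities((1, 1))
    sampled_action = random.choices(list(probs.keys()), weights=list(probs.values()), k=1)[0]
    print(f"Sampled action at (1, 1): {sampled_action}")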