max_ent_irl.py
import numpy as np

# Helpers defined elsewhere in this repo.
from value_iteration import value_iteration, best_policy, trans_mat

def expected_svf(trans_probs, trajs, policy):
    """Expected state visitation frequencies under a deterministic policy.

    Propagates the empirical start-state distribution of the demonstrations
    forward through the transition model for as many steps as the
    demonstrations are long, then sums over time.
    """
    n_states, n_actions, _ = trans_probs.shape
    n_t = len(trajs[0])
    mu = np.zeros((n_states, n_t))
    # Empirical distribution over start states.
    for traj in trajs:
        mu[traj[0][0], 0] += 1
    mu[:, 0] = mu[:, 0] / len(trajs)
    # Forward pass: propagate state occupancy under the greedy policy.
    for t in range(1, n_t):
        for s in range(n_states):
            mu[s, t] = sum(mu[pre_s, t - 1] * trans_probs[pre_s, policy[pre_s], s]
                           for pre_s in range(n_states))
    return np.sum(mu, 1)
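
# The recursion above is the forward pass of MaxEnt IRL (Ziebart et al.,
# 2008) specialized to a deterministic policy:
#
#   mu_0(s) = (1/|D|) * sum_tau 1[s_0^tau = s]
#   mu_t(s) = sum_{s'} mu_{t-1}(s') * P(s | s', policy(s'))
#   svf(s)  = sum_t mu_t(s)
#
# A minimal sketch of the more general variant that propagates a full
# stochastic policy pi(a | s) instead of a single greedy action (here pi
# would be an (n_states, n_actions) array; this is an assumption, not
# part of this file):
#
#   for t in range(1, n_t):
#       mu[:, t] = np.einsum('i,ia,iaj->j', mu[:, t - 1], pi, trans_probs)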

def max_ent_irl(feature_matrix, trans_probs, trajs,
                gamma=0.9, n_epoch=20, alpha=0.5):
    """Maximum entropy IRL: recover a state reward from demonstrations."""
    n_states, d_states = feature_matrix.shape
    _, n_actions, _ = trans_probs.shape

    # Empirical feature expectations of the demonstrations.
    feature_exp = np.zeros(d_states)
    for episode in trajs:
        for step in episode:
            feature_exp += feature_matrix[step[0], :]
    feature_exp = feature_exp / len(trajs)

    # Gradient ascent on the reward weights theta.
    theta = np.random.uniform(size=(d_states,))
    for _ in range(n_epoch):
        r = feature_matrix.dot(theta)
        v = value_iteration(trans_probs, r, gamma)
        pi = best_policy(trans_probs, v)
        exp_svf = expected_svf(trans_probs, trajs, pi)
        grad = feature_exp - feature_matrix.T.dot(exp_svf)
        theta += alpha * grad
    return feature_matrix.dot(theta)
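
# Each epoch performs one step of gradient ascent on the MaxEnt
# log-likelihood. The gradient is the gap between the empirical feature
# expectations of the demonstrations and the feature expectations under
# the current reward estimate:
#
#   grad L(theta) = f_empirical - F^T * expected_svf
#   theta        <- theta + alpha * grad
#
# which is exactly `feature_exp - feature_matrix.T.dot(exp_svf)` above;
# at the optimum the two feature expectations match.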

def feature_matrix(env):
    # One-hot (tabular) state features: one indicator feature per state.
    return np.eye(env.nS)

def generate_demos(env, policy, n_trajs=100, len_traj=5):
    """Roll out `policy` in `env` to collect fixed-length demonstrations.

    Each step is a (state, action, next_state) tuple; episodes that
    terminate early are padded with (state, 0, state) self-loops.
    """
    trajs = []
    for _ in range(n_trajs):
        episode = []
        env.reset()
        for i in range(len_traj):
            cur_s = env.s
            state, reward, done, _ = env.step(policy[cur_s])
            episode.append((cur_s, policy[cur_s], state))
            if done:
                for _ in range(i + 1, len_traj):
                    episode.append((state, 0, state))
                break
        trajs.append(episode)
    return trajs

if __name__ == '__main__':
    from envs import gridworld

    # Build the gridworld MDP, solve it, and collect demonstrations
    # from the resulting optimal policy.
    grid = gridworld.GridworldEnv()
    trans_probs, reward = trans_mat(grid)
    U = value_iteration(trans_probs, reward)
    pi = best_policy(trans_probs, U)
    trajs = generate_demos(grid, pi)
    res = max_ent_irl(feature_matrix(grid), trans_probs, trajs)
    print(res)

    import matplotlib.pyplot as plt

    def to_mat(res, shape):
        # Reshape the flat per-state reward vector into the grid layout.
        dst = np.zeros(shape)
        for i, v in enumerate(res):
            dst[i // shape[1], i % shape[1]] = v
        return dst

    plt.matshow(to_mat(res, grid.shape))
    plt.show()
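
# Note: the recovered reward is identified only up to scale and shift,
# so normalizing before plotting can make visual comparison with the
# true reward easier. A minimal sketch (not part of the original script):
#
#   res_norm = (res - res.min()) / (res.max() - res.min() + 1e-12)
#   plt.matshow(to_mat(res_norm, grid.shape))
#   plt.show()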