-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_policy.py
58 lines (46 loc) · 2.45 KB
/
load_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import pickle, tensorflow as tf, tf_util, numpy as np
def load_policy(filename):
    """Load a pickled Gaussian policy and build a function that evaluates it.

    The file is unpickled into a dict that must hold a 'nonlin_type' entry
    plus a policy entry keyed 'GaussianPolicy'. The policy entry must have
    exactly the keys 'logstdevs_1_Da', 'hidden', 'obsnorm' and 'out'. From
    these parameters a TensorFlow graph is built: observation
    standardization, then the hidden affine layers (each followed by the
    configured nonlinearity), then a final affine output layer. Note that
    'logstdevs_1_Da' is required to be present but is not used here; only the
    output layer's affine result is returned.

    Args:
        filename: Path of the pickle file to read (opened in binary mode).

    Returns:
        The result of ``tf_util.function([obs_bo], a_ba)``, where ``obs_bo``
        is a float32 placeholder of shape [None, None] and ``a_ba`` is the
        output tensor of the network. Presumably this is a callable mapping a
        batch of observations to actions -- confirm against ``tf_util``.

    Raises:
        AssertionError: If the policy type is not 'GaussianPolicy' or the
            parameter dictionaries do not have the expected keys.
        IndexError: If the file holds no key other than 'nonlin_type'.
        NotImplementedError: If 'nonlin_type' is neither 'lrelu' nor 'tanh'
            and there is at least one hidden layer to apply it to.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only load policy
    # files that come from a trusted source.
    with open(filename, 'rb') as f:
        data = pickle.load(f)

    nonlin_type = data['nonlin_type']
    # The first key other than 'nonlin_type' names the policy type; any
    # further top-level keys are ignored.
    policy_type = [k for k in data.keys() if k != 'nonlin_type'][0]

    assert policy_type == 'GaussianPolicy', 'Policy type {} not supported'.format(policy_type)
    policy_params = data[policy_type]

    assert set(policy_params.keys()) == {'logstdevs_1_Da', 'hidden', 'obsnorm', 'out'}

    def build_policy(obs_bo):
        """Return the network's output tensor for the observation tensor."""

        def read_layer(layer):
            """Return the (W, b) arrays of an affine layer, cast to float32."""
            assert list(layer.keys()) == ['AffineLayer']
            assert sorted(layer['AffineLayer'].keys()) == ['W', 'b']
            affine = layer['AffineLayer']
            return affine['W'].astype(np.float32), affine['b'].astype(np.float32)

        def apply_nonlin(x):
            """Apply the nonlinearity selected by 'nonlin_type' to x."""
            if nonlin_type == 'lrelu':
                return tf_util.lrelu(x, leak=.01)  # openai/imitation nn.py:233
            elif nonlin_type == 'tanh':
                return tf.tanh(x)
            else:
                raise NotImplementedError(nonlin_type)

        def layer_order(name):
            """Sort key that orders a trailing integer suffix numerically.

            Plain string sorting would place a name ending in '10' before one
            ending in '2', wiring a net with ten or more numbered layers in
            the wrong order. The full name is kept as a final tie-breaker.
            """
            stem = name.rstrip('0123456789')
            digits = name[len(stem):]
            return (stem, int(digits) if digits else -1, name)

        # Build the policy. First, observation normalization.
        assert list(policy_params['obsnorm'].keys()) == ['Standardizer']
        standardizer = policy_params['obsnorm']['Standardizer']
        obsnorm_mean = standardizer['mean_1_D']
        obsnorm_meansq = standardizer['meansq_1_D']
        # Var[x] = E[x^2] - E[x]^2, clamped at zero before the square root so
        # rounding error cannot produce a negative argument.
        obsnorm_stdev = np.sqrt(np.maximum(0, obsnorm_meansq - np.square(obsnorm_mean)))
        print('obs', obsnorm_mean.shape, obsnorm_stdev.shape)
        # 1e-6 constant from Standardizer class in nn.py:409 in openai/imitation
        normedobs_bo = (obs_bo - obsnorm_mean) / (obsnorm_stdev + 1e-6)

        curr_activations_bd = normedobs_bo

        # Hidden layers next, applied in numeric order of their names.
        assert list(policy_params['hidden'].keys()) == ['FeedforwardNet']
        layer_params = policy_params['hidden']['FeedforwardNet']
        for layer_name in sorted(layer_params, key=layer_order):
            W, b = read_layer(layer_params[layer_name])
            curr_activations_bd = apply_nonlin(tf.matmul(curr_activations_bd, W) + b)

        # Output layer: affine only, no nonlinearity.
        W, b = read_layer(policy_params['out'])
        output_bo = tf.matmul(curr_activations_bd, W) + b
        return output_bo

    obs_bo = tf.placeholder(tf.float32, [None, None])
    a_ba = build_policy(obs_bo)
    policy_fn = tf_util.function([obs_bo], a_ba)
    return policy_fn