forked from ikostrikov/pytorch-a2c-ppo-acktr-gail
-
Notifications
You must be signed in to change notification settings - Fork 0
/
enjoy.py
103 lines (81 loc) · 3.31 KB
/
enjoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import os
import types
import numpy as np
import torch
from torch.autograd import Variable
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize
from envs import make_env
# Command-line options for replaying a trained agent.
parser = argparse.ArgumentParser(description='RL')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
# Must match the value used in training: it sets the size of the stacked
# observation fed to the loaded policy.
parser.add_argument('--num-stack', type=int, default=4,
                    help='number of frames to stack (default: 4)')
# Not read by this script; kept so existing command lines continue to parse.
parser.add_argument('--log-interval', type=int, default=10,
                    help='log interval, one log per n updates (default: 10)')
parser.add_argument('--env-name', default='PongNoFrameskip-v4',
                    help='environment to run (default: PongNoFrameskip-v4)')
# The checkpoint is read from <load-dir>/<env-name>.pt.
parser.add_argument('--load-dir', default='./trained_models/',
                    help='directory to load the trained agent from '
                         '(default: ./trained_models/)')
args = parser.parse_args()
# Build a single copy of the environment, wrapped in the vectorized-env API
# the policy code expects.
env = make_env(args.env_name, args.seed, 0, None)
env = DummyVecEnv([env])

# The checkpoint is a pickled (actor_critic, ob_rms) pair.
# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoint files from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses a pickled module like this one; on those versions pass
# weights_only=False (trusted files only).
# map_location keeps every tensor on the CPU, where this script runs the
# policy; without it a checkpoint holding CUDA tensors fails to load on a
# CPU-only machine.
actor_critic, ob_rms = torch.load(
    os.path.join(args.load_dir, args.env_name + ".pt"),
    map_location=lambda storage, loc: storage)

if len(env.observation_space.shape) == 1:
    # Vector observations: presumably normalized during training, so replay
    # the saved running statistics. ret=False presumably turns off return
    # normalization, which playback does not need.
    env = VecNormalize(env, ret=False)
    env.ob_rms = ob_rms

    # An ugly hack to remove updates: replace the wrapper's observation
    # filter with one that only applies the saved statistics and never
    # updates them.
    def _obfilt(self, obs):
        if self.ob_rms:
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon),
                          -self.clipob, self.clipob)
            return obs
        else:
            return obs
    env._obfilt = types.MethodType(_obfilt, env)
    render_func = env.venv.envs[0].render
else:
    render_func = env.envs[0].render
# Per-step inputs to the policy for a batch of exactly one environment.
# masks follows the main loop's convention (0.0 after an episode ends,
# 1.0 otherwise); starting at zero presumably marks a fresh episode.
masks = torch.zeros(1, 1)
states = torch.zeros(1, actor_critic.state_size)

# Frames are stacked along the first observation axis, so that axis is
# --num-stack times larger than a single observation.
obs_shape = env.observation_space.shape
obs_shape = (obs_shape[0] * args.num_stack,) + tuple(obs_shape[1:])
current_obs = torch.zeros(1, *obs_shape)
def update_current_obs(obs, *, stacked_obs=None, shape_dim0=None, num_stack=None):
    """Push the newest observation into the frame stack, dropping the oldest.

    The stack holds ``num_stack`` observations concatenated along dim 1, with
    the newest one in the last ``shape_dim0`` entries.

    Args:
        obs: numpy array for the latest step. It is converted to float32 and
            written into the last ``shape_dim0`` slots along dim 1. Assumed
            to match one stack slot in shape (TODO confirm for all envs).
        stacked_obs: tensor updated in place; defaults to the module-level
            ``current_obs``.
        shape_dim0: size of one observation along dim 1; defaults to
            ``env.observation_space.shape[0]``.
        num_stack: number of stacked observations; defaults to
            ``args.num_stack``.
    """
    if stacked_obs is None:
        stacked_obs = current_obs
    if shape_dim0 is None:
        shape_dim0 = env.observation_space.shape[0]
    if num_stack is None:
        num_stack = args.num_stack

    obs = torch.from_numpy(obs).float()
    if num_stack > 1:
        # Shift older observations toward the front. Source and destination
        # are partially overlapping views of the same tensor, so clone the
        # source first: current PyTorch raises a RuntimeError for such an
        # in-place copy, and older versions give no ordering guarantee.
        stacked_obs[:, :-shape_dim0] = stacked_obs[:, shape_dim0:].clone()
    stacked_obs[:, -shape_dim0:] = obs
# Open the viewer, start the first episode and seed the frame stack with it.
render_func('human')
obs = env.reset()
update_current_obs(obs)

# PyBullet environments: find the body whose base name is "torso" so the
# camera can follow it. torsoId stays -1 when no such body exists; if
# several match, the last one wins.
if 'Bullet' in args.env_name:
    import pybullet as p
    torsoId = -1
    for i in range(p.getNumBodies()):
        # getNumBodies() is a count; the unique id of the i-th body comes
        # from getBodyUniqueId() rather than being assumed equal to i.
        body_id = p.getBodyUniqueId(i)
        if p.getBodyInfo(body_id)[0].decode() == "torso":
            torsoId = body_id
# Camera placement used to follow the torso in PyBullet environments.
CAMERA_DISTANCE = 5
CAMERA_YAW = 0
CAMERA_PITCH = -20
is_bullet = 'Bullet' in args.env_name

# Play the trained policy forever, rendering every step.
while True:
    # NOTE: `volatile` is the pre-0.4 PyTorch way to disable autograd. On
    # PyTorch >= 0.4 it has no effect (a warning is emitted) and
    # `torch.no_grad()` is the replacement; it is kept here to match the
    # API this file was written against.
    value, action, _, states = actor_critic.act(Variable(current_obs, volatile=True),
                                                Variable(states, volatile=True),
                                                Variable(masks, volatile=True),
                                                deterministic=True)
    # Keep the underlying tensor so the next step starts from plain data.
    states = states.data
    cpu_actions = action.data.squeeze(1).cpu().numpy()

    # Observe reward and next obs
    obs, reward, done, _ = env.step(cpu_actions)

    # done is presumably a length-1 array (single env), so its truthiness is
    # that env's flag.
    masks.fill_(0.0 if done else 1.0)

    # Zero the frame stack after an episode ends. For 4-D stacks the (1, 1)
    # mask is reshaped so it broadcasts over the three observation axes.
    if current_obs.dim() == 4:
        current_obs *= masks.unsqueeze(2).unsqueeze(2)
    else:
        current_obs *= masks
    update_current_obs(obs)

    if is_bullet and torsoId > -1:
        humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
        p.resetDebugVisualizerCamera(CAMERA_DISTANCE, CAMERA_YAW, CAMERA_PITCH, humanPos)

    render_func('human')