forked from ikostrikov/pytorch-a2c-ppo-acktr-gail
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path envs.py
executable file
·84 lines (65 loc) · 2.36 KB
/
envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import gym
import numpy as np
from gym.spaces.box import Box
from baselines import bench
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
try:
import dm_control2gym
except ImportError:
pass
try:
import roboschool
except ImportError:
pass
try:
import pybullet_envs
except ImportError:
pass
def make_env(env_id, seed, rank, log_dir, add_timestep):
    """Return a zero-argument factory that builds one seeded, wrapped env.

    Nothing is constructed until the returned thunk is called (presumably by
    a vectorized-env constructor -- confirm against callers).

    Args:
        env_id: gym environment id, or ``"dm.<domain>.<task>"`` to build a
            DeepMind Control Suite task through ``dm_control2gym``.
        seed: base random seed; the env is seeded with ``seed + rank``.
        rank: index of this env among parallel workers; offsets the seed and
            names the ``bench.Monitor`` log path inside ``log_dir``.
        log_dir: directory for ``bench.Monitor`` output, or ``None`` to skip
            monitoring.
        add_timestep: if true, append the elapsed step count to 1-D
            observations of envs whose wrapper chain includes ``TimeLimit``.

    Returns:
        A callable taking no arguments that creates and returns the env.
    """
    def _thunk():
        if env_id.startswith("dm."):
            # "dm.<domain>.<task>" selects a dm_control suite task. Requiring
            # the dot keeps other ids that merely start with "dm" on gym.make.
            _, domain, task = env_id.split('.')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        else:
            env = gym.make(env_id)

        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            # The gym.make instance was only needed to detect Atari; close it
            # before make_atari builds its own so it is not leaked.
            env.close()
            env = make_atari(env_id)
        env.seed(seed + rank)

        obs_shape = env.observation_space.shape
        if add_timestep and len(obs_shape) == 1 and 'TimeLimit' in str(env):
            env = AddTimestep(env)

        if log_dir is not None:
            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))

        if is_atari:
            env = wrap_deepmind(env)

        # Channel-last image observations (3-D, last axis of size 1 or 3) are
        # transposed to channel-first for PyTorch convolutions.
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in (1, 3):
            env = WrapPyTorch(env)

        return env

    return _thunk
class AddTimestep(gym.ObservationWrapper):
    """Append the wrapped env's elapsed step count to 1-D observations.

    ``observation()`` reads ``self.env._elapsed_steps``, an attribute of gym's
    ``TimeLimit`` wrapper. ``make_env`` only applies this wrapper when
    ``'TimeLimit'`` appears in ``str(env)``.
    """

    def __init__(self, env=None):
        super(AddTimestep, self).__init__(env)
        space = self.observation_space
        # Integer spaces cannot represent +inf, so cap the counter at the
        # dtype's maximum there instead.
        if np.issubdtype(space.dtype, np.integer):
            step_high = np.iinfo(space.dtype).max
        else:
            step_high = np.inf
        # Keep the wrapped space's per-dimension bounds and give the step
        # counter its own range [0, step_high]. Broadcasting low[0]/high[0]
        # to every dimension declared a space that did not contain the real
        # observations (e.g. any timestep above high[0]).
        self.observation_space = Box(
            np.append(space.low, 0),
            np.append(space.high, step_high),
            dtype=space.dtype)

    def observation(self, observation):
        """Return ``observation`` with the current step count appended."""
        return np.concatenate((observation, [self.env._elapsed_steps]))
class WrapPyTorch(gym.ObservationWrapper):
    """Move the channel axis of 3-D observations first: (d0, d1, C) -> (C, d0, d1).

    ``make_env`` applies this when the observation space is 3-D with 1 or 3
    entries on the last axis. It turns channel-last frames into the
    channel-first layout PyTorch convolutions expect.
    """

    # Axis permutation shared by the declared space and observation(), so the
    # declared shape always matches what observation() actually returns.
    _AXES = (2, 0, 1)

    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        space = self.observation_space
        # Permute the per-element bounds with the same axes as the data. The
        # previous shape [s[2], s[1], s[0]] declared (C, d1, d0), which
        # disagrees with transpose(2, 0, 1) output whenever d0 != d1.
        self.observation_space = Box(
            space.low.transpose(self._AXES),
            space.high.transpose(self._AXES),
            dtype=space.dtype)

    def observation(self, observation):
        """Return ``observation`` transposed to channel-first layout."""
        return observation.transpose(self._AXES)