forked from weipu-zhang/STORM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
env_wrapper.py
118 lines (95 loc) · 3.59 KB
/
env_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import functools
from collections import deque

import cv2
import gymnasium
import numpy as np
class LifeLossInfo(gymnasium.Wrapper):
    """Add a boolean ``info["life_loss"]`` flag to every step and reset.

    The flag is ``True`` on a step where ``info["lives"]`` is lower than the
    value seen on the previous step (or on reset), and ``False`` otherwise.
    The wrapped env must report ``info["lives"]`` from both ``step`` and
    ``reset``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Lives count seen on the previous step/reset; None until reset().
        self.lives_info = None

    def step(self, action):
        """Step the wrapped env and set ``info["life_loss"]``.

        Returns the wrapped env's ``(observation, reward, terminated,
        truncated, info)`` tuple unchanged except for the added flag.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        current_lives_info = info["lives"]
        info["life_loss"] = bool(
            self.lives_info is not None and current_lives_info < self.lives_info
        )
        # Refresh the baseline on every step, not only on a loss: otherwise a
        # bonus life leaves a stale, lower baseline and the next loss back to
        # that value would go undetected.
        self.lives_info = current_lives_info
        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped env, record the starting lives, and clear the flag."""
        observation, info = self.env.reset(**kwargs)
        self.lives_info = info["lives"]
        info["life_loss"] = False
        return observation, info
class SeedEnvWrapper(gymnasium.Wrapper):
    """Seed the wrapped env and its action space with a fixed seed.

    The action space is seeded at construction. The env itself is seeded on
    the first ``reset()`` only.

    Passing an integer seed to ``reset`` restarts the env's PRNG. Re-passing
    the same seed on every reset (the previous behaviour) made every episode
    replay the same random stream. Gymnasium recommends seeding once, right
    after construction. The run remains reproducible, because later resets
    continue the PRNG stream started by that seed.

    Later resets forward their kwargs untouched, so an explicit ``seed=``
    from the caller is honoured there.
    """

    def __init__(self, env, seed):
        super().__init__(env)
        self.seed = seed
        # True once the fixed seed has been handed to the wrapped env's reset.
        self._seed_applied = False
        self.env.action_space.seed(seed)

    def reset(self, **kwargs):
        """Reset the wrapped env, injecting the fixed seed on the first call only."""
        if not self._seed_applied:
            kwargs["seed"] = self.seed
            self._seed_applied = True
        obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, action):
        """Forward ``step`` to the wrapped env unchanged."""
        return self.env.step(action)
class MaxLast2FrameSkipWrapper(gymnasium.Wrapper):
    """Repeat each action ``skip`` times and max-pool the last two frames.

    Rewards from the repeated inner steps are summed. The inner loop stops
    early when the episode terminates or is truncated. The returned
    ``terminated``, ``truncated`` and ``info`` come from the last inner step.

    Args:
        env: Environment to wrap.
        skip: Number of inner ``env.step`` calls per outer step; must be >= 1.

    Raises:
        ValueError: if ``skip`` is less than 1 (the step loop would otherwise
            run zero times and leave its results undefined).
    """

    def __init__(self, env, skip=4):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip!r}")
        self.skip = skip
        # Holds at most the two most recent raw frames of the current step.
        self.obs_buffer = deque(maxlen=2)

    def reset(self, **kwargs):
        """Forward ``reset`` to the wrapped env unchanged."""
        obs, info = self.env.reset(**kwargs)
        return obs, info

    def step(self, action):
        """Apply ``action`` up to ``skip`` times.

        Returns a max-pooled observation together with the summed reward.
        """
        total_reward = 0
        self.obs_buffer = deque(maxlen=2)
        for _ in range(self.skip):
            obs, reward, done, truncated, info = self.env.step(action)
            self.obs_buffer.append(obs)
            total_reward += reward
            if done or truncated:
                break
        if len(self.obs_buffer) == 2:
            # Element-wise max of the last two frames. Same result and dtype
            # as np.max(np.stack(...), axis=0), without the temporary stack.
            obs = np.maximum(self.obs_buffer[0], self.obs_buffer[1])
        # With a single buffered frame, `obs` already is that frame.
        return obs, total_reward, done, truncated, info
def build_single_env(env_name, image_size):
    """Create one Gymnasium env wrapped in ``AtariPreprocessing``.

    The base env uses the full action set and ``frameskip=1``. Frame handling
    is therefore left entirely to ``AtariPreprocessing`` (presumably its own
    ``frame_skip`` setting; confirm against the Gymnasium docs).
    Observations are resized to ``image_size`` and kept in colour
    (``grayscale_obs=False``). All other preprocessing options keep the
    wrapper's defaults.

    Args:
        env_name: Gymnasium environment id, e.g. ``"ALE/Pong-v5"``.
        image_size: Value passed as ``AtariPreprocessing(screen_size=...)``.

    Returns:
        The preprocessed environment.
    """
    from gymnasium.wrappers import AtariPreprocessing

    base_env = gymnasium.make(env_name, full_action_space=True, frameskip=1)
    preprocessed = AtariPreprocessing(
        base_env,
        screen_size=image_size,
        grayscale_obs=False,
    )
    return preprocessed
def build_vec_env(env_list, image_size, num_envs):
    """Build an ``AsyncVectorEnv`` that splits ``num_envs`` evenly across games.

    Each id in ``env_list`` gets ``num_envs // len(env_list)`` consecutive
    sub-environments, in list order. Each sub-environment is created by
    ``build_single_env(env_id, image_size)``.

    Args:
        env_list: Gymnasium env ids; must be non-empty.
        image_size: Forwarded to ``build_single_env``.
        num_envs: Total number of sub-environments. Must be positive and a
            multiple of ``len(env_list)``.

    Returns:
        ``(vec_env, vec_env_names)``, where ``vec_env_names[i]`` is the env id
        of sub-environment ``i``.

    Raises:
        ValueError: if ``env_list`` is empty, ``num_envs`` is not positive, or
            ``num_envs`` is not divisible by ``len(env_list)``.
    """
    # Explicit exceptions instead of `assert`: asserts are stripped under -O,
    # and an empty env_list would otherwise surface as ZeroDivisionError.
    if not env_list:
        raise ValueError("env_list must contain at least one environment id")
    if num_envs <= 0:
        raise ValueError(f"num_envs must be positive, got {num_envs}")
    if num_envs % len(env_list) != 0:
        raise ValueError(
            f"num_envs ({num_envs}) must be a multiple of "
            f"len(env_list) ({len(env_list)})"
        )
    envs_per_name = num_envs // len(env_list)

    env_fns = []
    vec_env_names = []
    for env_name in env_list:
        # functools.partial binds env_name/image_size now, avoiding the
        # late-binding closure pitfall a bare lambda in this loop would hit:
        # https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7
        env_fns += [
            functools.partial(build_single_env, env_name, image_size)
            for _ in range(envs_per_name)
        ]
        vec_env_names += [env_name] * envs_per_name
    vec_env = gymnasium.vector.AsyncVectorEnv(env_fns=env_fns)
    return vec_env, vec_env_names
if __name__ == "__main__":
    # Smoke test: run 4 games x 2 copies with random actions. The first copy
    # of each game is shown in its own OpenCV window. Press "q" or Esc to quit.
    vec_env, vec_env_names = build_vec_env(
        ["ALE/Pong-v5", "ALE/IceHockey-v5", "ALE/Breakout-v5", "ALE/Tennis-v5"],
        64,
        num_envs=8,
    )
    # Window title -> index of the first sub-env running that game,
    # e.g. "ALE/IceHockey-v5" -> "IceHockey".
    preview_index = {}
    for index, env_id in enumerate(vec_env_names):
        title = env_id.split("/")[-1].split("-")[0]
        preview_index.setdefault(title, index)
    try:
        current_obs, _ = vec_env.reset()
        while True:
            action = vec_env.action_space.sample()
            obs, reward, terminated, truncated, info = vec_env.step(action)
            # An episode ends on termination OR truncation. The vector env
            # returns per-env boolean arrays, so combine them element-wise
            # (a Python `or` is wrong for arrays).
            done = np.logical_or(terminated, truncated)
            if done.any():
                print("---------")
                print(reward)
                print(info["episode_frame_number"])
            for title, index in preview_index.items():
                # Frames are RGB (colour kept by grayscale_obs=False), while
                # cv2.imshow interprets 3-channel images as BGR.
                frame_bgr = cv2.cvtColor(current_obs[index], cv2.COLOR_RGB2BGR)
                cv2.imshow(title, frame_bgr)
            key = cv2.waitKey(40) & 0xFF
            if key in (ord("q"), 27):  # 27 == Esc
                break
            current_obs = obs
    finally:
        # Shut down the worker processes and windows even on Ctrl+C.
        vec_env.close()
        cv2.destroyAllWindows()