-
Notifications
You must be signed in to change notification settings - Fork 2
/
VecMonitor.py
36 lines (29 loc) · 1.12 KB
/
VecMonitor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn, VecEnvWrapper
import numpy as np
from copy import deepcopy
class VecMonitor(VecEnvWrapper):
    """Vectorized-env wrapper that accumulates per-agent rewards over each episode.

    Every step, the rewards returned by the wrapped env are added into a
    ``(num_envs, agent_count)`` accumulator. For each sub-environment that
    reports ``done``, the returned info entry is a dict holding:

    - ``"episode_rewards"``: a copy of that env's accumulated per-agent rewards,
    - ``"real_rewards"``: a copy (as a NumPy array) of the ``"real_rewards"``
      value from the wrapped env's own info dict (required on done),

    and that env's accumulator row is then zeroed. Sub-environments that are
    not done get an empty dict.

    NOTE(review): the wrapped env's own info entries are NOT forwarded; only the
    keys above are returned. Any other info the underlying VecEnv provides
    (e.g. an auto-reset terminal observation, if present) is dropped. Confirm
    that no downstream consumer needs it.
    """

    def __init__(self, env):
        """Wrap ``env`` and read the per-env agent count.

        :param env: the vectorized environment to wrap. Its sub-environments
            must expose a ``num_players`` attribute; only the first
            sub-environment's value is used.
        """
        super().__init__(env)
        self.agent_count = self.venv.get_attr("num_players")[0]
        # Allocate the accumulator up front so step_wait() does not raise
        # AttributeError if it is reached before reset().
        self.episode_rewards = np.zeros((self.venv.num_envs, self.agent_count))

    def reset(self):
        """Reset every sub-environment and clear all reward accumulators.

        :return: the observations returned by the wrapped env's ``reset()``.
        """
        obs = self.venv.reset()
        self.episode_rewards = np.zeros((self.venv.num_envs, self.agent_count))
        return obs

    def step_async(self, act):
        """Forward the batched actions to the wrapped environment unchanged."""
        self.venv.step_async(act)

    def step_wait(self):
        """Collect the step results, accumulate rewards and report finished episodes.

        :return: ``(obs, rew, done, result_info)``. ``obs``, ``rew`` and ``done``
            are passed through unchanged from the wrapped env. ``result_info``
            has one dict per sub-environment, as described in the class docstring.
        :raises KeyError: if a done sub-environment's info lacks ``"real_rewards"``.
        """
        obs, rew, done, info = self.venv.step_wait()
        # Assumes ``rew`` broadcasts against (num_envs, agent_count), e.g. it is
        # already (num_envs, agent_count). TODO confirm against the wrapped env.
        self.episode_rewards += rew

        result_info = []
        for index, is_done in enumerate(done):
            if not is_done:
                # An empty dict (previously an empty list) honours the VecEnv
                # "list of dicts" info contract. It is still falsy, so callers
                # that test ``if info[i]`` behave exactly as before.
                result_info.append({})
                continue
            result_info.append({
                # np.array copies its input, so the stored value is independent
                # of the wrapped env's info object.
                "real_rewards": np.array(info[index]["real_rewards"]),
                # Snapshot the row before it is zeroed below.
                "episode_rewards": self.episode_rewards[index].copy(),
            })
            self.episode_rewards[index, :] = 0
        return obs, rew, done, result_info