Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/opendilab/DI-engine into po…
Browse files Browse the repository at this point in the history
…lish-network-comments
  • Loading branch information
puyuan1996 committed Dec 28, 2023
2 parents adf3f18 + ac9fa76 commit df2fe3e
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 17 deletions.
4 changes: 2 additions & 2 deletions ding/entry/application_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def eval(
load_path: Optional[str] = None,
replay_path: Optional[str] = None,
) -> float:
r"""
"""
Overview:
Pure evaluation entry.
Pure policy evaluation entry. Evaluate mean episode return and save replay videos.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
Expand Down
4 changes: 2 additions & 2 deletions ding/envs/env_manager/subprocess_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):

config = dict(
episode_num=float("inf"),
max_retry=5,
max_retry=1,
step_timeout=None,
auto_reset=True,
retry_type='reset',
Expand Down Expand Up @@ -678,7 +678,7 @@ def wait(rest_conn: list, wait_num: int, timeout: Optional[float] = None) -> Tup
class SyncSubprocessEnvManager(AsyncSubprocessEnvManager):
config = dict(
episode_num=float("inf"),
max_retry=5,
max_retry=1,
step_timeout=None,
auto_reset=True,
reset_timeout=None,
Expand Down
12 changes: 12 additions & 0 deletions dizoo/petting_zoo/entry/ptz_simple_spread_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config
from ding.entry import eval


def main():
    """
    Overview:
        Evaluate a saved checkpoint with the PettingZoo simple_spread MAPPO config and
        pass a replay directory to the evaluation entry so that videos can be saved.
    """
    # Keyword arguments are grouped so the single ``eval`` call stays short.
    eval_kwargs = dict(
        seed=0,
        load_path='./ckpt_best.pth.tar',
        replay_path='./replay_videos',
    )
    eval((main_config, create_config), **eval_kwargs)


if __name__ == "__main__":
    main()
88 changes: 75 additions & 13 deletions dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,52 @@
from pettingzoo.mpe.simple_spread.simple_spread import Scenario


class PTZRecordVideo(gym.wrappers.RecordVideo):
    """
    Overview:
        ``RecordVideo`` variant whose ``step`` understands the per-agent ``terminateds`` /
        ``truncateds`` dicts returned by a PettingZoo parallel environment. An episode is
        treated as finished as soon as any agent reports termination or truncation.
    """

    @staticmethod
    def _any_flag(flags):
        """
        Overview:
            Collapse a done signal into a single ``bool``.
        Arguments:
            - flags: Either a per-agent ``dict`` of booleans or a plain boolean-like value.
        Returns:
            - done (:obj:`bool`): ``True`` if any dict value is truthy, or if a non-dict \
                value is itself truthy.
        """
        if isinstance(flags, dict):
            # A non-empty dict is always truthy, so its values must be inspected instead.
            return any(flags.values())
        return bool(flags)

    def step(self, action):
        """Steps through the environment using action, recording observations if :attr:`self.recording`."""
        # Adapted from the ``RecordVideo.step`` of gymnasium==0.27.1.
        (
            observations,
            rewards,
            terminateds,
            truncateds,
            infos,
        ) = self.env.step(action)

        # Skip all bookkeeping once the current episode has already been marked as finished.
        if not (self.terminated is True or self.truncated is True):
            # increment steps and episodes
            self.step_id += 1

            # Reduce the (possibly per-agent) done signals to plain booleans once, so that the
            # episode counter and the recorder shutdown below agree on when the episode ends.
            if self.is_vector_env:
                terminated = bool(terminateds[0])
                truncated = bool(truncateds[0])
            else:
                terminated = self._any_flag(terminateds)
                truncated = self._any_flag(truncateds)
            episode_over = terminated or truncated

            if episode_over:
                self.episode_id += 1
                # Store booleans (not the raw dicts) so the ``is True`` guard above works.
                self.terminated = terminated
                self.truncated = truncated

            if self.recording:
                assert self.video_recorder is not None
                self.video_recorder.capture_frame()
                self.recorded_frames += 1
                if self.video_length > 0:
                    # Fixed-length recording: stop once enough frames have been captured.
                    if self.recorded_frames > self.video_length:
                        self.close_video_recorder()
                elif episode_over:
                    # Episode-length recording: stop when the episode finishes.
                    self.close_video_recorder()

            elif self._video_enabled():
                self.start_video_recorder()

        # The raw per-agent outputs are returned untouched to the caller.
        return observations, rewards, terminateds, truncateds, infos


@ENV_REGISTRY.register('petting_zoo')
class PettingZooEnv(BaseEnv):
# Now only supports simple_spread_v2.
Expand Down Expand Up @@ -51,19 +97,7 @@ def reset(self) -> np.ndarray:
self._env = parallel_env(
N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles
)
# dynamic seed reduces training speed greatly
# if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
# np_seed = 100 * np.random.randint(1, 1000)
# self._env.seed(self._seed + np_seed)
if self._replay_path is not None:
self._env = gym.wrappers.Monitor(
self._env, self._replay_path, video_callable=lambda episode_id: True, force=True
)
if hasattr(self, '_seed'):
obs = self._env.reset(seed=self._seed)
else:
obs = self._env.reset()
if not self._init_flag:
self._env.reset()
self._agents = self._env.agents

self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents})
Expand Down Expand Up @@ -144,7 +178,14 @@ def reset(self) -> np.ndarray:
for agent in self._agents
}
)
if self._replay_path is not None:
self._env.render_mode = 'rgb_array'
self._env = PTZRecordVideo(self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True)
self._init_flag = True
if hasattr(self, '_seed'):
obs = self._env.reset(seed=self._seed)
else:
obs = self._env.reset()
# self._eval_episode_return = {agent: 0. for agent in self._agents}
self._eval_episode_return = 0.
self._step_count = 0
Expand Down Expand Up @@ -320,6 +361,7 @@ def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False
scenario = Scenario()
world = scenario.make_world(N)
super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio)
self.render_mode = 'rgb_array'
self.metadata['name'] = "simple_spread_v2"

def _execute_world_step(self):
Expand Down Expand Up @@ -355,3 +397,23 @@ def _execute_world_step(self):
reward = agent_reward

self.rewards[agent.name] = reward

def render(self):
    """
    Overview:
        Draw the current scene and, in ``'rgb_array'`` mode, return it as an image array.
    Returns:
        - frame: In ``'rgb_array'`` mode, the screen pixels with their first two axes \
            swapped; ``None`` in every other mode. When ``render_mode`` is ``None`` a \
            warning is logged and nothing is drawn.
    """
    if self.render_mode is None:
        gym.logger.warn("You are calling render method without specifying any render mode.")
        return

    # Imported lazily so pygame is only needed when rendering is actually requested.
    import pygame

    self.enable_render(self.render_mode)
    self.draw()

    # Copy the pixel data out of the surface view into a standalone array.
    pixels = np.array(pygame.surfarray.pixels3d(self.screen))

    if self.render_mode == "human":
        pygame.display.flip()

    if self.render_mode == "rgb_array":
        # Swap the first two axes -- presumably (width, height, c) -> (height, width, c); confirm
        # against pygame's surfarray layout.
        return np.transpose(pixels, axes=(1, 0, 2))
    return None

0 comments on commit df2fe3e

Please sign in to comment.