Skip to content

Commit

Permalink
fix(nyz): update ptz to latest version
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Mar 1, 2023
1 parent 203be4b commit da2e590
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
3 changes: 2 additions & 1 deletion ding/envs/env_manager/subprocess_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import platform
import time
import copy
import gymnasium
import gym
import traceback
import torch
Expand Down Expand Up @@ -104,7 +105,7 @@ def _create_state(self) -> None:
self._reset_param = {i: {} for i in range(self.env_num)}
if self._shared_memory:
obs_space = self._observation_space
if isinstance(obs_space, gym.spaces.Dict):
if isinstance(obs_space, (gym.spaces.Dict, gymnasium.spaces.Dict)):
# For multi_agent case, such as multiagent_mujoco and petting_zoo mpe.
# Now only for the case that each agent in the team have the same obs structure
# and corresponding shape.
Expand Down
23 changes: 11 additions & 12 deletions dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, List, Union, Optional, Dict
import gym
import gymnasium as gym
import numpy as np
import pettingzoo
from functools import reduce
Expand All @@ -8,6 +8,9 @@
from ding.torch_utils import to_ndarray, to_list
from ding.envs.common.common_function import affine_transform
from ding.utils import ENV_REGISTRY, import_module
from pettingzoo.utils.conversions import parallel_wrapper_fn
from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env
from pettingzoo.mpe.simple_spread.simple_spread import Scenario


@ENV_REGISTRY.register('petting_zoo')
Expand Down Expand Up @@ -52,13 +55,14 @@ def reset(self) -> np.ndarray:
# if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
# np_seed = 100 * np.random.randint(1, 1000)
# self._env.seed(self._seed + np_seed)
if hasattr(self, '_seed'):
self._env.seed(self._seed)
if self._replay_path is not None:
self._env = gym.wrappers.Monitor(
self._env, self._replay_path, video_callable=lambda episode_id: True, force=True
)
obs = self._env.reset()
if hasattr(self, '_seed'):
obs = self._env.reset(seed=self._seed)
else:
obs = self._env.reset()
if not self._init_flag:
self._agents = self._env.agents

Expand All @@ -69,7 +73,7 @@ def reset(self) -> np.ndarray:
elif isinstance(single_agent_obs_space, gym.spaces.Discrete):
self._action_dim = (single_agent_obs_space.n, )
else:
raise Exception('Only support `Box` or `Discrte` obs space for single agent.')
raise Exception('Only support `Box` or `Discrete` obs space for single agent.')

# only for env 'simple_spread_v2', n_agent = 5
# now only for the case that each agent in the team have the same obs structure and corresponding shape.
Expand Down Expand Up @@ -173,7 +177,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high
)

obs, rew, done, info = self._env.step(action)
obs, rew, done, trunc, info = self._env.step(action)
obs_n = self._process_obs(obs)
rew_n = np.array([sum([rew[agent] for agent in self._agents])])
# collide_sum = 0
Expand Down Expand Up @@ -308,18 +312,13 @@ def reward_space(self) -> gym.spaces.Space:
return self._reward_space


from pettingzoo.utils.conversions import parallel_wrapper_fn
from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env
from pettingzoo.mpe.scenarios.simple_spread import Scenario


class simple_spread_raw_env(SimpleEnv):

def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False):
assert 0. <= local_ratio <= 1., "local_ratio is a proportion. Must be between 0 and 1."
scenario = Scenario()
world = scenario.make_world(N)
super().__init__(scenario, world, max_cycles, continuous_actions, local_ratio)
super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio)
self.metadata['name'] = "simple_spread_v2"

def _execute_world_step(self):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
'mpire>=2.3.5',
'pynng',
'redis',
'pettingzoo==1.12.0',
'pettingzoo',
'DI-treetensor>=0.3.0',
'DI-toolkit>=0.0.2',
'hbutils>=0.5.0',
Expand Down

0 comments on commit da2e590

Please sign in to comment.