Skip to content

Commit

Permalink
feature(zjow): polish ppof agent code for opendilab huggingface (#730)
Browse files Browse the repository at this point in the history
* polish ppof code
  • Loading branch information
zjowowen authored Sep 21, 2023
1 parent 6a26e98 commit 08a6c52
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 127 deletions.
136 changes: 65 additions & 71 deletions ding/bonus/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,38 @@
from ding.policy import PPOFPolicy


def get_instance_config(env: str, algorithm: str) -> EasyDict:
def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
if algorithm == 'PPOF':
cfg = PPOFPolicy.default_config()
if env == 'lunarlander_discrete':
if env_id == 'LunarLander-v2':
cfg.n_sample = 512
cfg.value_norm = 'popart'
cfg.entropy_weight = 1e-3
elif env == 'lunarlander_continuous':
elif env_id == 'LunarLanderContinuous-v2':
cfg.action_space = 'continuous'
cfg.n_sample = 400
elif env == 'bipedalwalker':
elif env_id == 'BipedalWalker-v3':
cfg.learning_rate = 1e-3
cfg.action_space = 'continuous'
cfg.n_sample = 1024
elif env == 'acrobot':
elif env_id == 'acrobot':
cfg.learning_rate = 1e-4
cfg.n_sample = 400
elif env == 'rocket_landing':
elif env_id == 'rocket_landing':
cfg.n_sample = 2048
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'drone_fly':
elif env_id == 'drone_fly':
cfg.action_space = 'continuous'
cfg.adv_norm = False
cfg.epoch_per_collect = 5
cfg.learning_rate = 5e-5
cfg.n_sample = 640
elif env == 'hybrid_moving':
elif env_id == 'hybrid_moving':
cfg.action_space = 'hybrid'
cfg.n_sample = 3200
cfg.entropy_weight = 0.03
Expand All @@ -50,13 +50,13 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
fixed_sigma_value=0.3,
bound_type='tanh',
)
elif env == 'evogym_carrier':
elif env_id == 'evogym_carrier':
cfg.action_space = 'continuous'
cfg.n_sample = 2048
cfg.batch_size = 256
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-3
elif env == 'mario':
elif env_id == 'mario':
cfg.n_sample = 256
cfg.batch_size = 64
cfg.epoch_per_collect = 2
Expand All @@ -66,14 +66,14 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=128,
actor_head_hidden_size=128,
)
elif env == 'di_sheep':
elif env_id == 'di_sheep':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
cfg.adv_norm = False
cfg.entropy_weight = 0.001
elif env == 'procgen_bigfish':
elif env_id == 'procgen_bigfish':
cfg.n_sample = 16384
cfg.batch_size = 16384
cfg.epoch_per_collect = 10
Expand All @@ -83,7 +83,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=256,
actor_head_hidden_size=256,
)
elif env in ['atari_qbert', 'atari_kangaroo', 'atari_bowling']:
elif env_id in ['KangarooNoFrameskip-v4', 'BowlingNoFrameskip-v4']:
cfg.n_sample = 1024
cfg.batch_size = 128
cfg.epoch_per_collect = 10
Expand All @@ -94,7 +94,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
elif env == 'PongNoFrameskip':
elif env_id == 'PongNoFrameskip-v4':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
Expand All @@ -104,7 +104,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'SpaceInvadersNoFrameskip':
elif env_id == 'SpaceInvadersNoFrameskip-v4':
cfg.n_sample = 320
cfg.batch_size = 320
cfg.epoch_per_collect = 1
Expand All @@ -116,7 +116,7 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'QbertNoFrameskip':
elif env_id == 'QbertNoFrameskip-v4':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
Expand All @@ -127,13 +127,13 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'minigrid_fourroom':
elif env_id == 'minigrid_fourroom':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.learning_rate = 3e-4
cfg.epoch_per_collect = 10
cfg.entropy_weight = 0.001
elif env == 'metadrive':
elif env_id == 'metadrive':
cfg.learning_rate = 3e-4
cfg.action_space = 'continuous'
cfg.entropy_weight = 0.001
Expand All @@ -146,49 +146,61 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
elif env in ['hopper']:
elif env_id == 'Hopper-v3':
cfg.action_space = "continuous"
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
elif env_id == 'HalfCheetah-v3':
cfg.action_space = "continuous"
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
elif env_id == 'Walker2d-v3':
cfg.action_space = "continuous"
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
else:
raise KeyError("not supported env type: {}".format(env))
raise KeyError("not supported env type: {}".format(env_id))
else:
raise KeyError("not supported algorithm type: {}".format(algorithm))

return cfg


def get_instance_env(env: str) -> BaseEnv:
if env == 'lunarlander_discrete':
def get_instance_env(env_id: str) -> BaseEnv:
if env_id == 'LunarLander-v2':
return DingEnvWrapper(gym.make('LunarLander-v2'))
elif env == 'lunarlander_continuous':
return DingEnvWrapper(gym.make('LunarLander-v2', continuous=True))
elif env == 'bipedalwalker':
elif env_id == 'LunarLanderContinuous-v2':
return DingEnvWrapper(gym.make('LunarLanderContinuous-v2', continuous=True))
elif env_id == 'BipedalWalker-v3':
return DingEnvWrapper(gym.make('BipedalWalker-v3'), cfg={'act_scale': True, 'rew_clip': True})
elif env == 'pendulum':
elif env_id == 'Pendulum-v1':
return DingEnvWrapper(gym.make('Pendulum-v1'), cfg={'act_scale': True})
elif env == 'acrobot':
elif env_id == 'acrobot':
return DingEnvWrapper(gym.make('Acrobot-v1'))
elif env == 'rocket_landing':
elif env_id == 'rocket_landing':
from dizoo.rocket.envs import RocketEnv
cfg = EasyDict({
'task': 'landing',
'max_steps': 800,
})
return RocketEnv(cfg)
elif env == 'drone_fly':
elif env_id == 'drone_fly':
from dizoo.gym_pybullet_drones.envs import GymPybulletDronesEnv
cfg = EasyDict({
'env_id': 'flythrugate-aviary-v0',
'action_type': 'VEL',
})
return GymPybulletDronesEnv(cfg)
elif env == 'hybrid_moving':
elif env_id == 'hybrid_moving':
import gym_hybrid
return DingEnvWrapper(gym.make('Moving-v0'))
elif env == 'evogym_carrier':
elif env_id == 'evogym_carrier':
import evogym.envs
from evogym import sample_robot, WorldObject
path = os.path.join(os.path.dirname(__file__), '../../dizoo/evogym/envs/world_data/carry_bot.json')
Expand All @@ -203,7 +215,7 @@ def get_instance_env(env: str) -> BaseEnv:
]
}
)
elif env == 'mario':
elif env_id == 'mario':
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
return DingEnvWrapper(
Expand All @@ -219,10 +231,10 @@ def get_instance_env(env: str) -> BaseEnv:
]
}
)
elif env == 'di_sheep':
elif env_id == 'di_sheep':
from sheep_env import SheepEnv
return DingEnvWrapper(SheepEnv(level=9))
elif env == 'procgen_bigfish':
elif env_id == 'procgen_bigfish':
return DingEnvWrapper(
gym.make('procgen:procgen-bigfish-v0', start_level=0, num_levels=1),
cfg={
Expand All @@ -234,66 +246,48 @@ def get_instance_env(env: str) -> BaseEnv:
},
seed_api=False,
)
elif env == 'hopper':
elif env_id == 'Hopper-v3':
cfg = EasyDict(
env_id='Hopper-v3',
env_wrapper='mujoco_default',
act_scale=True,
rew_clip=True,
)
return DingEnvWrapper(gym.make('Hopper-v3'), cfg=cfg)
elif env == 'HalfCheetah':
elif env_id == 'HalfCheetah-v3':
cfg = EasyDict(
env_id='HalfCheetah-v3',
env_wrapper='mujoco_default',
act_scale=True,
rew_clip=True,
)
return DingEnvWrapper(gym.make('HalfCheetah-v3'), cfg=cfg)
elif env == 'Walker2d':
elif env_id == 'Walker2d-v3':
cfg = EasyDict(
env_id='Walker2d-v3',
env_wrapper='mujoco_default',
act_scale=True,
rew_clip=True,
)
return DingEnvWrapper(gym.make('Walker2d-v3'), cfg=cfg)
elif env == "SpaceInvadersNoFrameskip":
cfg = EasyDict({
'env_id': "SpaceInvadersNoFrameskip-v4",
'env_wrapper': 'atari_default',
})
return DingEnvWrapper(gym.make("SpaceInvadersNoFrameskip-v4"), cfg=cfg)
elif env == "PongNoFrameskip":
cfg = EasyDict({
'env_id': "PongNoFrameskip-v4",
'env_wrapper': 'atari_default',
})
return DingEnvWrapper(gym.make("PongNoFrameskip-v4"), cfg=cfg)
elif env == "QbertNoFrameskip":
cfg = EasyDict({
'env_id': "QbertNoFrameskip-v4",
'env_wrapper': 'atari_default',
})
return DingEnvWrapper(gym.make("QbertNoFrameskip-v4"), cfg=cfg)
elif env in ['atari_qbert', 'atari_kangaroo', 'atari_bowling', 'atari_breakout', 'atari_spaceinvader',
'atari_gopher']:
from dizoo.atari.envs.atari_env import AtariEnv
atari_env_list = {
'atari_qbert': 'QbertNoFrameskip-v4',
'atari_kangaroo': 'KangarooNoFrameskip-v4',
'atari_bowling': 'BowlingNoFrameskip-v4',
'atari_breakout': 'BreakoutNoFrameskip-v4',
'atari_spaceinvader': 'SpaceInvadersNoFrameskip-v4',
'atari_gopher': 'GopherNoFrameskip-v4'
}

elif env_id in [
'BowlingNoFrameskip-v4',
'BreakoutNoFrameskip-v4',
'GopherNoFrameskip-v4'
'KangarooNoFrameskip-v4',
'PongNoFrameskip-v4',
'QbertNoFrameskip-v4',
'SpaceInvadersNoFrameskip-v4',
]:

cfg = EasyDict({
'env_id': atari_env_list[env],
'env_id': env_id,
'env_wrapper': 'atari_default',
})
ding_env_atari = DingEnvWrapper(gym.make(atari_env_list[env]), cfg=cfg)
ding_env_atari = DingEnvWrapper(gym.make(env_id), cfg=cfg)
return ding_env_atari
elif env == 'minigrid_fourroom':
elif env_id == 'minigrid_fourroom':
import gymnasium
return DingEnvWrapper(
gymnasium.make('MiniGrid-FourRooms-v0'),
Expand All @@ -306,7 +300,7 @@ def get_instance_env(env: str) -> BaseEnv:
]
}
)
elif env == 'metadrive':
elif env_id == 'metadrive':
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
cfg = dict(
Expand All @@ -319,7 +313,7 @@ def get_instance_env(env: str) -> BaseEnv:
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
else:
raise KeyError("not supported env type: {}".format(env))
raise KeyError("not supported env type: {}".format(env_id))


def get_hybrid_shape(action_space) -> EasyDict:
Expand Down
Loading

0 comments on commit 08a6c52

Please sign in to comment.