feature(zjow): add wandb logger features; fix relative bugs for wandb online logger #579

Merged: 63 commits, merged Mar 16, 2023
Changes from 44 commits

Commits
e571f50
td3 fix
zjowowen Nov 4, 2022
a614e3f
Merge branch 'opendilab:main' into benchmark-2
zjowowen Dec 19, 2022
9060c53
Add benchmark config file.
zjowowen Dec 19, 2022
731a2ad
Merge branch 'opendilab:main' into benchmark-2
zjowowen Jan 11, 2023
82a4944
add main
zjowowen Jan 15, 2023
ad616ff
fix
zjowowen Jan 15, 2023
f1aba9c
fix
zjowowen Jan 15, 2023
448daa1
add feature to wandb;fix bugs
zjowowen Feb 10, 2023
1e18f25
merge main
zjowowen Feb 10, 2023
8de9b9e
format code
zjowowen Feb 10, 2023
f36bec8
remove files.
zjowowen Feb 10, 2023
e5ec188
polish code
zjowowen Feb 10, 2023
46f64e6
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Feb 22, 2023
e520359
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Feb 24, 2023
6a9a565
fix td3 policy
zjowowen Feb 24, 2023
0222c04
Add td3
zjowowen Feb 28, 2023
929776b
Add td3 env
zjowowen Feb 28, 2023
4fba3b9
Add td3 env
zjowowen Feb 28, 2023
0257ae9
polish code
zjowowen Feb 28, 2023
cccd585
polish code
zjowowen Feb 28, 2023
d7f272e
polish code
zjowowen Feb 28, 2023
902f9b0
polish code
zjowowen Feb 28, 2023
17ba3a6
polish code
zjowowen Feb 28, 2023
21dcc8b
polish code
zjowowen Feb 28, 2023
bb0df37
polish code
zjowowen Feb 28, 2023
d01558d
polish code
zjowowen Feb 28, 2023
60f47b6
polish code
zjowowen Feb 28, 2023
511d71e
polish code
zjowowen Feb 28, 2023
6a9fd45
polish code
zjowowen Feb 28, 2023
d5573e9
polish code
zjowowen Feb 28, 2023
b7c2011
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 1, 2023
0a167f1
Merge branch 'opendilab:main' into benchmark-3
zjowowen Mar 2, 2023
3906543
fix data type error for mujoco
zjowowen Mar 2, 2023
e665493
polish code
zjowowen Mar 2, 2023
88f5181
polish code
zjowowen Mar 2, 2023
693a4cb
Add features
zjowowen Mar 2, 2023
e6bd0c5
fix base env manager readyimage
zjowowen Mar 3, 2023
cdb9928
polish code
zjowowen Mar 3, 2023
3015a92
remove NoReturn
zjowowen Mar 3, 2023
6e7041b
remove NoReturn
zjowowen Mar 3, 2023
c97a8d4
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 6, 2023
fe415b2
format code
zjowowen Mar 7, 2023
8f808b2
merge from main
zjowowen Mar 7, 2023
3432754
format code
zjowowen Mar 7, 2023
3f6ef3d
polish code
zjowowen Mar 7, 2023
535fd77
polish code
zjowowen Mar 7, 2023
4271610
fix logger
zjowowen Mar 7, 2023
ba0979b
format code
zjowowen Mar 7, 2023
3c19c2c
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 7, 2023
82826e2
format code
zjowowen Mar 7, 2023
da0dd12
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 7, 2023
bb35f90
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 10, 2023
5340658
change api for ckpt; polish code
zjowowen Mar 10, 2023
2d3f6c8
polish code
zjowowen Mar 13, 2023
2e8292c
merge from main
zjowowen Mar 13, 2023
2f883d7
format code
zjowowen Mar 13, 2023
3c15c84
polish code
zjowowen Mar 13, 2023
6ce1421
fix load bug
zjowowen Mar 13, 2023
eac9434
fix bug
zjowowen Mar 13, 2023
6fda31b
fix dtype error
zjowowen Mar 14, 2023
6b9def4
polish code
zjowowen Mar 15, 2023
6f49d0a
polish code
zjowowen Mar 15, 2023
4c69cb0
Polish code
zjowowen Mar 16, 2023
1 change: 1 addition & 0 deletions ding/bonus/__init__.py
@@ -1 +1,2 @@
from .ppof import PPOF
from .td3 import TD3
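With this export, the new TD3 agent can be used from ding.bonus in the same way as PPOF. A minimal sketch, assuming the added TD3 class (defined in ding/bonus/td3.py, which is not shown in this diff) mirrors the PPOF constructor and train() interface; the env, exp_name, seed and step keywords below are assumptions for illustration only:

from ding.bonus import TD3

# Assumed interface: mirrors PPOF as updated in this PR; not confirmed by the diff shown here.
agent = TD3(env='hopper', exp_name='hopper_td3_demo', seed=0)
result = agent.train(step=int(1e6))  # placeholder step budget
# If TD3.train() also returns a TrainingReturn (as PPOF.train() now does),
# the wandb run URL is available as:
print(result.wandb_url)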
238 changes: 155 additions & 83 deletions ding/bonus/config.py
@@ -4,92 +4,157 @@
from ding.envs import BaseEnv, DingEnvWrapper
from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
EvalEpisodeReturnEnv, TransposeWrapper, TimeLimitWrapper
from ding.policy import PPOFPolicy
from ding.policy import PPOFPolicy, TD3Policy


def get_instance_config(env: str) -> EasyDict:
cfg = PPOFPolicy.default_config()
if env == 'lunarlander_discrete':
cfg.n_sample = 400
elif env == 'lunarlander_continuous':
cfg.action_space = 'continuous'
cfg.n_sample = 400
elif env == 'bipedalwalker':
cfg.learning_rate = 1e-3
cfg.action_space = 'continuous'
cfg.n_sample = 1024
elif env == 'rocket_landing':
cfg.n_sample = 2048
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'drone_fly':
cfg.action_space = 'continuous'
cfg.adv_norm = False
cfg.epoch_per_collect = 5
cfg.learning_rate = 5e-5
cfg.n_sample = 640
elif env == 'hybrid_moving':
cfg.action_space = 'hybrid'
cfg.n_sample = 3200
cfg.entropy_weight = 0.03
cfg.batch_size = 320
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[256, 128, 64, 64],
sigma_type='fixed',
fixed_sigma_value=0.3,
bound_type='tanh',
)
elif env == 'evogym_carrier':
cfg.action_space = 'continuous'
cfg.n_sample = 2048
cfg.batch_size = 256
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-3
elif env == 'mario':
cfg.n_sample = 256
cfg.batch_size = 64
cfg.epoch_per_collect = 2
cfg.learning_rate = 1e-3
cfg.model = dict(
encoder_hidden_size_list=[64, 64, 128],
critic_head_hidden_size=128,
actor_head_hidden_size=128,
)
elif env == 'di_sheep':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
cfg.adv_norm = False
cfg.entropy_weight = 0.001
elif env == 'procgen_bigfish':
cfg.n_sample = 16384
cfg.batch_size = 16384
cfg.epoch_per_collect = 10
cfg.learning_rate = 5e-4
cfg.model = dict(
encoder_hidden_size_list=[64, 128, 256],
critic_head_hidden_size=256,
actor_head_hidden_size=256,
)
elif env in ['atari_qbert', 'atari_kangaroo', 'atari_bowling']:
cfg.n_sample = 1024
cfg.batch_size = 128
cfg.epoch_per_collect = 10
cfg.learning_rate = 0.0001
cfg.model = dict(
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
def get_instance_config(env: str, algorithm: str) -> EasyDict:
if algorithm == 'PPO':
cfg = PPOFPolicy.default_config()
if env == 'lunarlander_discrete':
cfg.n_sample = 400
elif env == 'lunarlander_continuous':
cfg.action_space = 'continuous'
cfg.n_sample = 400
elif env == 'bipedalwalker':
cfg.learning_rate = 1e-3
cfg.action_space = 'continuous'
cfg.n_sample = 1024
elif env == 'rocket_landing':
cfg.n_sample = 2048
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
)
elif env == 'drone_fly':
cfg.action_space = 'continuous'
cfg.adv_norm = False
cfg.epoch_per_collect = 5
cfg.learning_rate = 5e-5
cfg.n_sample = 640
elif env == 'hybrid_moving':
cfg.action_space = 'hybrid'
cfg.n_sample = 3200
cfg.entropy_weight = 0.03
cfg.batch_size = 320
cfg.adv_norm = False
cfg.model = dict(
encoder_hidden_size_list=[256, 128, 64, 64],
sigma_type='fixed',
fixed_sigma_value=0.3,
bound_type='tanh',
)
elif env == 'evogym_carrier':
cfg.action_space = 'continuous'
cfg.n_sample = 2048
cfg.batch_size = 256
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-3
elif env == 'mario':
cfg.n_sample = 256
cfg.batch_size = 64
cfg.epoch_per_collect = 2
cfg.learning_rate = 1e-3
cfg.model = dict(
encoder_hidden_size_list=[64, 64, 128],
critic_head_hidden_size=128,
actor_head_hidden_size=128,
)
elif env == 'di_sheep':
cfg.n_sample = 3200
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
cfg.adv_norm = False
cfg.entropy_weight = 0.001
elif env == 'procgen_bigfish':
cfg.n_sample = 16384
cfg.batch_size = 16384
cfg.epoch_per_collect = 10
cfg.learning_rate = 5e-4
cfg.model = dict(
encoder_hidden_size_list=[64, 128, 256],
critic_head_hidden_size=256,
actor_head_hidden_size=256,
)
elif env in ['atari_qbert', 'atari_kangaroo', 'atari_bowling']:
cfg.n_sample = 1024
cfg.batch_size = 128
cfg.epoch_per_collect = 10
cfg.learning_rate = 0.0001
cfg.model = dict(
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
)
else:
raise KeyError("not supported env type: {}".format(env))
elif algorithm == 'TD3':
cfg = TD3Policy.default_config()
if env == 'hopper':
cfg.update(
dict(
exp_name='hopper_td3',
seed=0,
env=dict(
env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=6000,
),
policy=dict(
cuda=True,
random_collect_size=25000,
model=dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
action_space='regression',
),
logger=dict(
gradient_logger=True,
video_logger=True,
plot_logger=True,
action_logger=True,
return_logger=False
),
learn=dict(
update_per_collect=1,
batch_size=256,
learning_rate_actor=1e-3,
learning_rate_critic=1e-3,
ignore_done=False,
target_theta=0.005,
discount_factor=0.99,
actor_update_freq=2,
noise=True,
noise_sigma=0.2,
noise_range=dict(
min=-0.5,
max=0.5,
),
),
collect=dict(
n_sample=1,
unroll_len=1,
noise_sigma=0.1,
),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
)
)
)
else:
raise KeyError("not supported env type: {}".format(env))
else:
raise KeyError("not supported env type: {}".format(env))
raise KeyError("not supported algorithm type: {}".format(algorithm))

return cfg
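For orientation, a short sketch of how the reworked helper can be called with the new algorithm argument; the expected values in the comments are read off the branches above, and nested attribute access assumes EasyDict's usual dict-to-attribute conversion:

from ding.bonus.config import get_instance_config

ppo_cfg = get_instance_config('lunarlander_discrete', algorithm='PPO')
print(ppo_cfg.n_sample)                    # 400, per the PPO branch above

td3_cfg = get_instance_config('hopper', algorithm='TD3')
print(td3_cfg.policy.learn.batch_size)     # 256, per the hopper TD3 config
print(td3_cfg.policy.logger.video_logger)  # True, one of the new wandb logger switches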


@@ -163,6 +228,13 @@ def get_instance_env(env: str) -> BaseEnv:
},
seed_api=False,
)
elif env == 'hopper':
from dizoo.mujoco.envs import MujocoEnv
cfg = EasyDict(
env_id='Hopper-v3',
env_wrapper='mujoco_default',
)
return DingEnvWrapper(cfg=cfg)
elif env in ['atari_qbert', 'atari_kangaroo', 'atari_bowling']:
from dizoo.atari.envs.atari_env import AtariEnv
atari_env_list = {
55 changes: 42 additions & 13 deletions ding/bonus/ppof.py
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Optional, Union
from ditk import logging
from easydict import EasyDict
@@ -16,6 +17,11 @@
from .config import get_instance_config, get_instance_env, get_hybrid_shape


@dataclass
class TrainingReturn:
wandb_url: str

class PPOF:
supported_env_list = [
# common
@@ -48,8 +54,11 @@ def __init__(
if isinstance(env, str):
assert env in PPOF.supported_env_list, "Please use supported envs: {}".format(PPOF.supported_env_list)
self.env = get_instance_env(env)
assert cfg is None, 'It should be default env tuned config'
self.cfg = get_instance_config(env)
if cfg is None:
# 'It should be default env tuned config'
self.cfg = get_instance_config(env, algorithm="PPO")
else:
self.cfg = cfg
elif isinstance(env, BaseEnv):
self.cfg = cfg
raise NotImplementedError
@@ -76,6 +85,10 @@ def __init__(
)
self.policy = PPOFPolicy(self.cfg, model=model)

def load_policy(self, policy_state_dict, config):
self.policy.load_state_dict(policy_state_dict)
self.policy._cfg = config

def train(
self,
step: int = int(1e7),
@@ -85,24 +98,34 @@ def train(
n_iter_save_ckpt: int = 1000,
context: Optional[str] = None,
debug: bool = False
) -> None:
) -> TrainingReturn:
if debug:
logging.getLogger().setLevel(logging.DEBUG)
logging.debug(self.policy._model)
# define env and policy
collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')

wandb_url_return = []
with task.start(ctx=OnlineRLContext()):
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(ppof_adv_estimator(self.policy))
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
task.use(CkptSaver(self.policy, save_dir=self.exp_name, train_freq=n_iter_save_ckpt))
task.use(wandb_online_logger(self.exp_name, metric_list=self.policy.monitor_vars(), anonymous=True))
task.use(
wandb_online_logger(
self.exp_name,
metric_list=self.policy.monitor_vars(),
anonymous=True,
project_name=self.exp_name,
wandb_url_return=wandb_url_return
)
)
task.use(termination_checker(max_env_step=step))
task.run()

return TrainingReturn(wandb_url=wandb_url_return[0])

def deploy(self, ckpt_path: str = None, enable_save_replay: bool = False, debug: bool = False) -> None:
if debug:
logging.getLogger().setLevel(logging.DEBUG)
@@ -145,7 +168,7 @@ def collect_data(
if n_episode is not None:
raise NotImplementedError
# define env and policy
env = self._setup_env_manager(env_num, context, debug)
env = self._setup_env_manager(env_num, context, debug, 'collector')
if ckpt_path is None:
ckpt_path = os.path.join(self.exp_name, 'ckpt/eval.pth.tar')
if save_data_path is None:
@@ -165,23 +188,29 @@ def batch_evaluate(
def batch_evaluate(
self,
env_num: int = 4,
ckpt_path: Optional[str] = None,
n_evaluator_episode: int = 4,
context: Optional[str] = None,
debug: bool = False
debug: bool = False,
render: bool = False,
replay_video_path: str = None,
) -> None:
if debug:
logging.getLogger().setLevel(logging.DEBUG)
# define env and policy
env = self._setup_env_manager(env_num, context, debug, 'evaluator')
if ckpt_path is None:
ckpt_path = os.path.join(self.exp_name, 'ckpt/eval.pth.tar')
state_dict = torch.load(ckpt_path, map_location='cpu')
self.policy.load_state_dict(state_dict)

# main execution task
with task.start(ctx=OnlineRLContext()):
task.use(interaction_evaluator_ttorch(self.seed, self.policy, env, n_evaluator_episode))
task.use(
interaction_evaluator_ttorch(
self.seed,
self.policy,
env,
n_evaluator_episode,
render=render,
replay_video_path=replay_video_path
)
)
task.run(max_step=1)

def _setup_env_manager(
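Taken together, the ppof.py changes adjust the public API in three ways: train() now returns a TrainingReturn carrying the wandb run URL, load_policy() restores a policy from a state dict plus config, and batch_evaluate() gains render and replay_video_path options that are forwarded to the evaluator. A rough usage sketch under those assumptions; the constructor keywords, paths and step counts are placeholders, not taken from this diff:

import torch
from ding.bonus import PPOF

# Constructor keywords are assumed; only the supported env names appear in this diff.
agent = PPOF(env='lunarlander_discrete', exp_name='lunarlander_demo', seed=0)

result = agent.train(step=int(1e6))  # train() previously returned None
print('wandb run:', result.wandb_url)

# New in this PR: evaluation with rendering and a saved replay video.
agent.batch_evaluate(env_num=4, n_evaluator_episode=4, render=True,
                     replay_video_path='./lunarlander_demo/replay')

# Also new: restore a policy directly from a state dict, e.g. before deploy().
state_dict = torch.load('./lunarlander_demo/ckpt/eval.pth.tar', map_location='cpu')
agent.load_policy(state_dict, config=agent.cfg)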