Skip to content

Commit

Permalink
polish code under comment
Browse files Browse the repository at this point in the history
  • Loading branch information
ruoyuGao committed Mar 7, 2023
1 parent 4e7ab65 commit e09febd
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 16 deletions.
9 changes: 3 additions & 6 deletions ding/entry/serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def serial_pipeline(
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
is_dynamic_seed: Optional[bool] = None,
dynamic_seed: Optional[bool] = True,
) -> 'Policy': # noqa
"""
Overview:
Expand All @@ -37,7 +37,7 @@ def serial_pipeline(
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
- is_dynamic_seed(:obj:`Optional[bool]`): set dynamic seed for collector.
- dynamic_seed (:obj:`Optional[bool]`): Whether to use a dynamic seed when seeding the collector environment.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
Expand All @@ -55,10 +55,7 @@ def serial_pipeline(
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
if is_dynamic_seed is None:
collector_env.seed(cfg.seed)
else:
collector_env.seed(cfg.seed, dynamic_seed=is_dynamic_seed)
collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
Expand Down
2 changes: 1 addition & 1 deletion ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_mdqn():
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'cartpole_mdqn_unittest'
try:
serial_pipeline(config, seed=0, max_train_iter=1, is_dynamic_seed=False)
serial_pipeline(config, seed=0, max_train_iter=1, dynamic_seed=False)
except Exception:
assert False, "pipeline fail"
finally:
Expand Down
12 changes: 12 additions & 0 deletions ding/entry/tests/test_serial_entry_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config # noqa
from dizoo.classic_control.cartpole.config import cartpole_mdqn_config, cartpole_mdqn_create_config

with open("./algo_record.log", "w+") as f:
f.write("ALGO TEST STARTS\n")
Expand Down Expand Up @@ -405,6 +406,17 @@ def test_wqmix():
f.write("28. wqmix\n")


@pytest.mark.algotest
def test_mdqn():
    """
    Overview:
        Algorithm test for MDQN on CartPole: run ``serial_pipeline`` with seed 0 on deep copies
        of the cartpole MDQN configs, fail with "pipeline fail" if it raises, and afterwards
        append the "29. mdqn" entry to ``./algo_record.log``.
    """
    # Copy the shared config objects so this test does not mutate them for other tests.
    main_cfg = deepcopy(cartpole_mdqn_config)
    create_cfg = deepcopy(cartpole_mdqn_create_config)
    config = [main_cfg, create_cfg]
    try:
        serial_pipeline(config, seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Record the finished run in the shared algorithm log.
    with open("./algo_record.log", "a+") as record_file:
        record_file.write("29. mdqn\n")


# @pytest.mark.algotest
def test_td3_bc():
# train expert
Expand Down
1 change: 1 addition & 0 deletions ding/policy/mdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def _init_learn(self) -> None:
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
# Optimizer
# set eps to be consistent with the original paper's implementation
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, eps=0.0003125)

self._gamma = self._cfg.discount_factor
Expand Down
6 changes: 3 additions & 3 deletions ding/rl_utils/td.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def m_q_1step_td_error(
criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
) -> torch.Tensor:
q, target_q, next_q, act, reward, done, weight = data
lo = -1
lower_bound = -1
assert len(act.shape) == 1, act.shape
assert len(reward.shape) == 1, reward.shape
batch_range = torch.arange(act.shape[0])
Expand All @@ -68,7 +68,7 @@ def m_q_1step_td_error(
# same to the last second tau_log_pi_a
munchausen_addon = log_pi.gather(1, act_get)

muchausen_term = alpha * torch.clamp(munchausen_addon, min=lo, max=1)
muchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1)

# replay_next_log_policy
target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1)
Expand All @@ -86,7 +86,7 @@ def m_q_1step_td_error(
top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0]
action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean()

clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lo)
clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound)
clipfrac = torch.as_tensor(clipped).float()

return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac
Expand Down
3 changes: 1 addition & 2 deletions dizoo/atari/config/serial/asterix/asterix_mdqn_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
env_id='Asterix-v0',
#'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make.
frame_stack=4,
manager=dict(shared_memory=True, ),
),
policy=dict(
cuda=True,
Expand Down Expand Up @@ -61,4 +60,4 @@
if __name__ == '__main__':
# or you can enter ding -m serial -c asterix_mdqn_config.py -s 0
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7), is_dynamic_seed=False)
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7), dynamic_seed=False)
2 changes: 1 addition & 1 deletion dizoo/atari/config/serial/enduro/enduro_mdqn_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@
if __name__ == '__main__':
# or you can enter ding -m serial -c enduro_mdqn_config.py -s 0
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7), is_dynamic_seed=False)
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7), dynamic_seed=False)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
env_id='SpaceInvaders-v0',
#'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make.
frame_stack=4,
manager=dict(shared_memory=True, ),
),
policy=dict(
cuda=True,
Expand Down Expand Up @@ -61,4 +60,4 @@
if __name__ == '__main__':
# or you can enter ding -m serial -c spaceinvaders_mdqn_config.py -s 0
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e7), is_dynamic_seed=False)
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e7), dynamic_seed=False)
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@
if __name__ == "__main__":
# or you can enter `ding -m serial -c cartpole_mdqn_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0, is_dynamic_seed=False)
serial_pipeline((main_config, create_config), seed=0, dynamic_seed=False)

0 comments on commit e09febd

Please sign in to comment.