feature(gry): add MDQN algorithm #590

Merged Mar 8, 2023 · 41 commits

Commits
ff7f56b
draft runnable version for mdqn and config file
ruoyuGao Feb 21, 2023
17821ab
fix style for mdqn
ruoyuGao Feb 21, 2023
8ef391c
fix style for mdqn
ruoyuGao Feb 22, 2023
d888b47
update action_gap part for mdqn
ruoyuGao Feb 23, 2023
9e17ae8
provide tau and alpha
ruoyuGao Feb 24, 2023
6c1164a
Merge remote-tracking branch 'origin' into ruoyugao
ruoyuGao Feb 24, 2023
58de257
draft runnable version for mdqn and config file
ruoyuGao Feb 21, 2023
93b1607
fix style for mdqn
ruoyuGao Feb 21, 2023
0c05a40
fix style for mdqn
ruoyuGao Feb 22, 2023
b8bf947
update action_gap part for mdqn
ruoyuGao Feb 23, 2023
feb534f
provide tau and alpha
ruoyuGao Feb 24, 2023
d89b953
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Feb 24, 2023
afeda48
add clipfrac to mdqn
ruoyuGao Feb 25, 2023
282ef80
add unit test for mdqn td error
ruoyuGao Feb 26, 2023
1509378
provide current exp parameter
ruoyuGao Feb 27, 2023
f0b0d3f
fix bug in mdqn td loss function and polish code
ruoyuGao Mar 2, 2023
5a56060
revert useless change in dqn
ruoyuGao Mar 2, 2023
b1929ce
update readme for mdqn
ruoyuGao Mar 2, 2023
b376319
delete wrongly named folder
ruoyuGao Mar 2, 2023
e43124c
rename asterix folder
ruoyuGao Mar 2, 2023
e2e7c3c
provide reasonable config for asterix
ruoyuGao Mar 2, 2023
3731b50
Merge branch 'opendilab:main' into ruoyugao
ruoyuGao Mar 2, 2023
a8f99dc
fix style and unit test
ruoyuGao Mar 2, 2023
47169d4
polish code under comment
ruoyuGao Mar 3, 2023
68fc21a
fix typo in dizoo asterix config
ruoyuGao Mar 3, 2023
a98f000
fix style
ruoyuGao Mar 3, 2023
501517d
fix style
ruoyuGao Mar 3, 2023
9f76f03
provide is_dynamic_seed for collector env
ruoyuGao Mar 5, 2023
b03b982
add unit test for mdqn in test_serial_entry with asterix
ruoyuGao Mar 5, 2023
bdcd0ae
change test for mdqn from asterix to cartpole because of platform tes…
ruoyuGao Mar 5, 2023
5e41c44
Merge branch 'main' into ruoyugao
ruoyuGao Mar 5, 2023
c3ee31e
change is_dynamic structure because of unit test failed at test entry
ruoyuGao Mar 5, 2023
f7b51f7
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Mar 5, 2023
fa0b929
add comment for is_dynamic_seed
ruoyuGao Mar 6, 2023
40aad3c
Merge branch 'main' into ruoyugao
ruoyuGao Mar 6, 2023
6987ec0
add enduro and spaceinvaders mdqn config file && polish comments
ruoyuGao Mar 7, 2023
9b6e20a
Merge branch 'opendilab:main' into ruoyugao
ruoyuGao Mar 7, 2023
4e7ab65
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Mar 7, 2023
e09febd
polish code under comment
ruoyuGao Mar 7, 2023
cd7f178
Merge branch 'main' into ruoyugao
ruoyuGao Mar 7, 2023
1f72e24
Merge branch 'main' into ruoyugao
ruoyuGao Mar 7, 2023
65 changes: 33 additions & 32 deletions README.md

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion ding/entry/serial_entry.py
@@ -22,6 +22,7 @@ def serial_pipeline(
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
dynamic_seed: Optional[bool] = True,
) -> 'Policy': # noqa
"""
Overview:
@@ -36,6 +37,7 @@ def serial_pipeline(
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
- dynamic_seed (:obj:`Optional[bool]`): Whether to use a dynamic seed for the collector environment.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
@@ -53,7 +55,7 @@
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
collector_env.seed(cfg.seed)
collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
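A minimal usage sketch of the new ``dynamic_seed`` argument (illustration only, mirroring the unit test added below; it assumes DI-engine and the dizoo cartpole configs from this PR are installed):

```python
from copy import deepcopy

from ding.entry import serial_pipeline
from dizoo.classic_control.cartpole.config.cartpole_mdqn_config import (
    cartpole_mdqn_config, cartpole_mdqn_create_config
)

# dynamic_seed=False makes every collector episode reuse the same base seed,
# which keeps short smoke/CI runs reproducible.
config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
policy = serial_pipeline(config, seed=0, max_train_iter=1, dynamic_seed=False)
```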
15 changes: 15 additions & 0 deletions ding/entry/tests/test_serial_entry.py
@@ -52,6 +52,7 @@
from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
from dizoo.classic_control.pendulum.config.pendulum_bdq_config import pendulum_bdq_config, pendulum_bdq_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_mdqn_config import cartpole_mdqn_config, cartpole_mdqn_create_config


@pytest.mark.platformtest
@@ -68,6 +69,20 @@ def test_dqn():
os.popen('rm -rf cartpole_dqn_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_mdqn():
config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'cartpole_mdqn_unittest'
try:
serial_pipeline(config, seed=0, max_train_iter=1, dynamic_seed=False)
except Exception:
assert False, "pipeline fail"
finally:
os.popen('rm -rf cartpole_mdqn_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_bdq():
12 changes: 12 additions & 0 deletions ding/entry/tests/test_serial_entry_algo.py
@@ -45,6 +45,7 @@
from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config # noqa
from dizoo.classic_control.cartpole.config import cartpole_mdqn_config, cartpole_mdqn_create_config

with open("./algo_record.log", "w+") as f:
f.write("ALGO TEST STARTS\n")
@@ -405,6 +406,17 @@ def test_wqmix():
f.write("28. wqmix\n")


@pytest.mark.algotest
def test_mdqn():
config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
try:
serial_pipeline(config, seed=0)
except Exception:
assert False, "pipeline fail"
with open("./algo_record.log", "a+") as f:
f.write("29. mdqn\n")


# @pytest.mark.algotest
def test_td3_bc():
# train expert
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -1,6 +1,7 @@
from .base_policy import Policy, CommandModePolicy, create_policy, get_policy_cls
from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch
from .dqn import DQNSTDIMPolicy, DQNPolicy
from .mdqn import MDQNPolicy
from .iqn import IQNPolicy
from .fqf import FQFPolicy
from .qrdqn import QRDQNPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -3,6 +3,7 @@
from .base_policy import CommandModePolicy

from .dqn import DQNPolicy, DQNSTDIMPolicy
from .mdqn import MDQNPolicy
from .c51 import C51Policy
from .qrdqn import QRDQNPolicy
from .iqn import IQNPolicy
@@ -101,6 +102,11 @@ class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('mdqn_command')
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('dqn_command')
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
pass
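For reference, a hedged sketch of a create-config that selects the newly registered policy; the field values are illustrative, not copied from this PR's dizoo configs:

```python
from easydict import EasyDict

example_create_config = EasyDict(dict(
    env=dict(type='cartpole'),
    env_manager=dict(type='base'),
    # 'mdqn' resolves to MDQNPolicy via POLICY_REGISTRY; in serial training the
    # 'mdqn_command' wrapper registered above supplies the eps-greedy command logic.
    policy=dict(type='mdqn'),
))
```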
243 changes: 243 additions & 0 deletions ding/policy/mdqn.py
@@ -0,0 +1,243 @@
from typing import List, Dict, Any
import copy
import torch

from ding.torch_utils import Adam, to_device
from ding.rl_utils import m_q_1step_td_data, m_q_1step_td_error
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY

from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('mdqn')
class MDQNPolicy(DQNPolicy):
"""
Overview:
Policy class of the Munchausen DQN (MDQN) algorithm, which augments the DQN target with a scaled log-policy (Munchausen) term and entropy regularization.
Paper link: https://arxiv.org/abs/2007.14430
Config:
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str mdqn | RL policy register name, refer to | This arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
| erent from modes
3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
| or off-policy
4 ``priority`` bool False | Whether to use priority (PER) | Priority sample,
| update priority
5 | ``priority_IS`` bool False | Whether to use Importance Sampling
| ``_weight`` | Weight to correct biased update. If
| True, priority must be True.
6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
| ``factor`` [0.95, 0.999] | gamma | reward env
7 ``nstep`` int 1, | N-step reward discount sum for target
[3, 5] | q_value estimation
8 | ``learn.update`` int 1 | How many updates(iterations) to train | This arg can vary
| ``per_collect`` | after collector's one collection. Only | across envs. Bigger
| valid in serial training | value means more off-policy
9 | ``learn.multi`` bool False | whether to use multi gpu during
| ``_gpu``
10 | ``learn.batch_`` int 32 | The number of samples of an iteration
| ``size``
11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
| ``_rate``
12 | ``learn.target_`` int 2000 | Frequency of target network update. | Hard(assign) update
| ``update_freq``
13 | ``learn.ignore_`` bool False | Whether to ignore done for target | Enable it for some
| ``done`` | value calculation. | fake termination env
14 ``collect.n_sample`` int 4 | The number of training samples per | It varies across
| call of the collector. | different envs
15 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
| ``_len``
16 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
| 'linear'].
17 | ``other.eps.`` float 0.01 | start value of exploration rate | [0,1]
| ``start``
18 | ``other.eps.`` float 0.001 | end value of exploration rate | [0,1]
| ``end``
19 | ``other.eps.`` int 250000 | decay length of exploration | greater than 0. setting
| ``decay`` | decay=250000 means
| the exploration rate
| decays from the start
| value to the end value
| over the decay length.
20 | ``entropy_tau`` float 0.003 | the ratio of entropy in the TD loss
21 | ``alpha`` float 0.9 | the ratio of the Munchausen term in
| the TD loss
== ==================== ======== ============== ======================================== =======================
"""
config = dict(
type='mdqn',
# (bool) Whether to use cuda in policy
cuda=False,
# (bool) Whether the learning policy is the same as the data-collecting policy (on-policy)
on_policy=False,
# (bool) Whether to enable prioritized experience replay (PER)
priority=False,
# (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=False,
# (float) Discount factor(gamma) for returns
discount_factor=0.97,
# (float) Entropy factor (tau) for Munchausen DQN
entropy_tau=0.03,
# (float) Discount factor (alpha) for Munchausen term
m_alpha=0.9,
# (int) The number of step for calculating target q_value
nstep=1,
learn=dict(
# (bool) Whether to use multi gpu
multi_gpu=False,
# How many updates (iterations) to train after one collection by the collector.
# A bigger "update_per_collect" means the training is more off-policy.
# collect data -> update policy -> collect data -> ...
update_per_collect=3,
# (int) How many samples in a training batch
batch_size=64,
# (float) The step size of gradient descent
learning_rate=0.001,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (int) Frequency of target network update.
target_update_freq=100,
# (bool) Whether to ignore done (usually for max-step termination envs)
ignore_done=False,
),
# collect_mode config
collect=dict(
# (int) Only one of [n_sample, n_episode] should be set
n_sample=4,
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
eval=dict(),
# other config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# (str) Decay type. Support ['exp', 'linear'].
type='exp',
# (float) Epsilon start value
start=0.95,
# (float) Epsilon end value
end=0.1,
# (int) Decay length(env step)
decay=10000,
),
replay_buffer=dict(replay_buffer_size=10000, ),
),
)

def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
and target model.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
# Optimizer
# set eps to be consistent with the original paper implementation
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, eps=0.0003125)

self._gamma = self._cfg.discount_factor
self._nstep = self._cfg.nstep
self._entropy_tau = self._cfg.entropy_tau
self._m_alpha = self._cfg.m_alpha

# use model_wrapper for specialized demands of different modes
self._target_model = copy.deepcopy(self._model)
if 'target_update_freq' in self._cfg.learn:
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='assign',
update_kwargs={'freq': self._cfg.learn.target_update_freq}
)
elif 'target_theta' in self._cfg.learn:
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
else:
raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta")
self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._learn_model.reset()
self._target_model.reset()

def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
Forward computation graph of learn mode(updating policy).
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which will be \
recorded in text log and tensorboard, values are python scalars or a list of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
- optional: ``value_gamma``, ``IS``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``, ``action_gap``, ``clip_frac``
"""
data = default_preprocess_learn(
data,
use_priority=self._priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
# Target q value
with torch.no_grad():
target_q_value_current = self._target_model.forward(data['obs'])['logit']
target_q_value = self._target_model.forward(data['next_obs'])['logit']

data_m = m_q_1step_td_data(
q_value, target_q_value_current, target_q_value, data['action'], data['reward'].squeeze(0), data['done'],
data['weight']
)

loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(
data_m, self._gamma, self._entropy_tau, self._m_alpha
)
# ====================
# Q-learning update
# ====================
self._optimizer.zero_grad()
loss.backward()
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()

# =============
# after update
# =============
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'q_value': q_value.mean().item(),
'target_q_value': target_q_value.mean().item(),
'priority': td_error_per_sample.abs().tolist(),
'action_gap': action_gap.item(),
'clip_frac': clipfrac.mean().item(),
}

def _monitor_vars_learn(self) -> List[str]:
return ['cur_lr', 'total_loss', 'q_value', 'action_gap', 'clip_frac']
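To make the roles of ``entropy_tau`` (tau) and ``m_alpha`` (alpha) concrete, here is a schematic sketch of the Munchausen 1-step TD target from the paper (https://arxiv.org/abs/2007.14430). It is an illustration only, not the library's ``m_q_1step_td_error``, which additionally returns the action gap and clip fraction logged above:

```python
import torch.nn.functional as F


def munchausen_td_target_sketch(q_t, q_t_target, q_tp1_target, action, reward, done,
                                gamma=0.97, tau=0.03, alpha=0.9, clip_low=-1.0):
    # Softmax policies of the target network at s and s' (temperature tau).
    log_pi_t = F.log_softmax(q_t_target / tau, dim=-1)      # shape (B, A)
    log_pi_tp1 = F.log_softmax(q_tp1_target / tau, dim=-1)  # shape (B, A)

    # Munchausen bonus: alpha * clip(tau * log pi(a|s), clip_low, 0).
    log_pi_a = log_pi_t.gather(1, action.unsqueeze(1)).squeeze(1)
    munchausen = (tau * log_pi_a).clamp(min=clip_low, max=0.0)

    # Soft (entropy-regularized) value of the next state under the target network.
    soft_v_tp1 = (log_pi_tp1.exp() * (q_tp1_target - tau * log_pi_tp1)).sum(dim=-1)

    # Munchausen-DQN target and the per-sample TD error.
    target = reward + alpha * munchausen + gamma * (1.0 - done.float()) * soft_v_tp1
    td_error = q_t.gather(1, action.unsqueeze(1)).squeeze(1) - target
    return target, td_error
```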
3 changes: 2 additions & 1 deletion ding/rl_utils/__init__.py
@@ -5,7 +5,8 @@
from .gae import gae_data, gae
from .a2c import a2c_data, a2c_error
from .coma import coma_data, coma_error
from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data, td_lambda_error,\
from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, \
q_1step_td_error, m_q_1step_td_data, m_q_1step_td_error, td_lambda_data, td_lambda_error,\
q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \
generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \
nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\