feature(gry): add MDQN algorithm #590

Merged
merged 41 commits into opendilab:main from ruoyuGao:ruoyugao on Mar 8, 2023
Changes from 23 commits (41 commits total)
ff7f56b
draft runnable version for mdqn and config file
ruoyuGao Feb 21, 2023
17821ab
fix style for mdqn
ruoyuGao Feb 21, 2023
8ef391c
fix style for mdqn
ruoyuGao Feb 22, 2023
d888b47
update action_gap part for mdqn
ruoyuGao Feb 23, 2023
9e17ae8
provide tau and alpha
ruoyuGao Feb 24, 2023
6c1164a
Merge remote-tracking branch 'origin' into ruoyugao
ruoyuGao Feb 24, 2023
58de257
draft runnable version for mdqn and config file
ruoyuGao Feb 21, 2023
93b1607
fix style for mdqn
ruoyuGao Feb 21, 2023
0c05a40
fix style for mdqn
ruoyuGao Feb 22, 2023
b8bf947
update action_gap part for mdqn
ruoyuGao Feb 23, 2023
feb534f
provide tau and alpha
ruoyuGao Feb 24, 2023
d89b953
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Feb 24, 2023
afeda48
add clipfrac to mdqn
ruoyuGao Feb 25, 2023
282ef80
add unit test for mdqn td error
ruoyuGao Feb 26, 2023
1509378
provide current exp parameter
ruoyuGao Feb 27, 2023
f0b0d3f
fix bug in mdqn td loss function and polish code
ruoyuGao Mar 2, 2023
5a56060
revert useless change in dqn
ruoyuGao Mar 2, 2023
b1929ce
update readme for mdqn
ruoyuGao Mar 2, 2023
b376319
delete wrongly named folder
ruoyuGao Mar 2, 2023
e43124c
rename asterix folder
ruoyuGao Mar 2, 2023
e2e7c3c
provide reasonable config for asterix
ruoyuGao Mar 2, 2023
3731b50
Merge branch 'opendilab:main' into ruoyugao
ruoyuGao Mar 2, 2023
a8f99dc
fix style and unit test
ruoyuGao Mar 2, 2023
47169d4
polish code under comment
ruoyuGao Mar 3, 2023
68fc21a
fix typo in dizoo asterix config
ruoyuGao Mar 3, 2023
a98f000
fix style
ruoyuGao Mar 3, 2023
501517d
fix style
ruoyuGao Mar 3, 2023
9f76f03
provide is_dynamic_seed for collector env
ruoyuGao Mar 5, 2023
b03b982
add unit test for mdqn in test_serial_entry with asterix
ruoyuGao Mar 5, 2023
bdcd0ae
change test for mdqn from asterix to cartpole because of platform tes…
ruoyuGao Mar 5, 2023
5e41c44
Merge branch 'main' into ruoyugao
ruoyuGao Mar 5, 2023
c3ee31e
change is_dynamic structure because of unit test failed at test entry
ruoyuGao Mar 5, 2023
f7b51f7
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Mar 5, 2023
fa0b929
add comment for is_dynamic_seed
ruoyuGao Mar 6, 2023
40aad3c
Merge branch 'main' into ruoyugao
ruoyuGao Mar 6, 2023
6987ec0
add enduro and spaceinvaders mdqn config file && polish comments
ruoyuGao Mar 7, 2023
9b6e20a
Merge branch 'opendilab:main' into ruoyugao
ruoyuGao Mar 7, 2023
4e7ab65
Merge branch 'ruoyugao' of https://github.com/ruoyuGao/DI-engine into…
ruoyuGao Mar 7, 2023
e09febd
polish code under comment
ruoyuGao Mar 7, 2023
cd7f178
Merge branch 'main' into ruoyugao
ruoyuGao Mar 7, 2023
1f72e24
Merge branch 'main' into ruoyugao
ruoyuGao Mar 7, 2023
1 change: 1 addition & 0 deletions README.md
@@ -247,6 +247,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 49 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 50 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 51 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py |
| 52 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/dqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u asterix_mdqn_config.py |
</details>


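The new README row lists `python3 -u asterix_mdqn_config.py` as the runnable demo. A minimal launch sketch, assuming the config module lands under `dizoo/atari/config/serial/asterix/` and exposes `main_config`/`create_config` like other dizoo entries (the exact path and names are not shown in this diff):

from ding.entry import serial_pipeline
# Hypothetical import path; adjust to wherever the asterix MDQN config actually lives.
from dizoo.atari.config.serial.asterix.asterix_mdqn_config import main_config, create_config

if __name__ == "__main__":
    # Same effect as `python3 -u asterix_mdqn_config.py` when the config file
    # calls serial_pipeline in its own __main__ guard.
    serial_pipeline((main_config, create_config), seed=0)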
2 changes: 1 addition & 1 deletion ding/policy/__init__.py
@@ -1,6 +1,6 @@
from .base_policy import Policy, CommandModePolicy, create_policy, get_policy_cls
from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch
from .dqn import DQNSTDIMPolicy, DQNPolicy
from .dqn import DQNSTDIMPolicy, DQNPolicy, MDQNPolicy
from .iqn import IQNPolicy
from .fqf import FQFPolicy
from .qrdqn import QRDQNPolicy
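With MDQNPolicy exported from `ding.policy`, the class can be imported directly. A small sketch, assuming the `default_config()` classmethod that DI-engine policies inherit from `Policy` (not part of this diff) deep-copies the class-level `config` dict shown in ding/policy/dqn.py below:

from ding.policy import MDQNPolicy

# Assumed helper: default_config() returns an EasyDict copy of MDQNPolicy.config.
cfg = MDQNPolicy.default_config()
print(cfg.entropy_tau, cfg.m_alpha, cfg.learn.target_update_freq)  # expected: 0.03 0.9 100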
7 changes: 6 additions & 1 deletion ding/policy/command_mode_policy_instance.py
@@ -2,7 +2,7 @@
from ding.rl_utils import get_epsilon_greedy_fn
from .base_policy import CommandModePolicy

from .dqn import DQNPolicy, DQNSTDIMPolicy
from .dqn import DQNPolicy, DQNSTDIMPolicy, MDQNPolicy
from .c51 import C51Policy
from .qrdqn import QRDQNPolicy
from .iqn import IQNPolicy
@@ -101,6 +101,11 @@ class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('mdqn_command')
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('dqn_command')
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
pass
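Registering `mdqn_command` lets the serial entry resolve the policy by name with epsilon-greedy command mode. Illustrative only — the env fields below are placeholders and not taken from this PR; with command-mode entries, DI-engine resolves `type='mdqn'` in a create_config to the `mdqn_command` registration above:

from easydict import EasyDict

# Sketch of a dizoo-style create_config section (hypothetical env settings).
create_config = EasyDict(dict(
    env=dict(type='atari', import_names=['dizoo.atari.envs.atari_env']),
    env_manager=dict(type='subprocess'),
    policy=dict(type='mdqn'),
))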
237 changes: 236 additions & 1 deletion ding/policy/dqn.py
@@ -4,7 +4,8 @@
import torch

from ding.torch_utils import Adam, to_device, ContrastiveLoss
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.rl_utils import q_nstep_td_data, m_q_1step_td_data,\
m_q_1step_td_error, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
@@ -694,3 +695,237 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
self._target_model.load_state_dict(state_dict['target_model'])
self._optimizer.load_state_dict(state_dict['optimizer'])
self._aux_optimizer.load_state_dict(state_dict['aux_optimizer'])


@POLICY_REGISTRY.register('mdqn')
class MDQNPolicy(DQNPolicy):
"""
Overview:
Policy class of the Munchausen DQN (MDQN) algorithm, which extends DQN by adding a scaled, clipped log-policy (Munchausen) bonus to the reward.
Paper link: https://arxiv.org/abs/2007.14430
Config:
== ==================== ======== ============== ======================================== =======================
ID Symbol Type Default Value Description Other(Shape)
== ==================== ======== ============== ======================================== =======================
1 ``type`` str mdqn | RL policy register name, refer to | This arg is optional,
| registry ``POLICY_REGISTRY`` | a placeholder
2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
| erent from modes
3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
| or off-policy
4 ``priority`` bool True | Whether use priority(PER) | Priority sample,
| update priority
5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
| ``_weight`` | to correct biased update. If True,
| priority must be True.
6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
| ``factor`` [0.95, 0.999] | gamma | reward env
7 ``nstep`` int 1, | N-step reward discount sum for target
[3, 5] | q_value estimation
8 | ``learn.update`` int 3 | How many updates(iterations) to train | This arg can vary
| ``per_collect`` | after collector's one collection. Only | from env to env. Bigger
| valid in serial training | value means more off-policy
9 | ``learn.multi`` bool False | whether to use multi gpu during
| ``_gpu``
10 | ``learn.batch_`` int 64 | The number of samples of an iteration
| ``size``
11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
| ``_rate``
12 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
| ``update_freq``
13 | ``learn.ignore_`` bool False | Whether to ignore done for target value | Enable it for some
| ``done`` | calculation. | fake termination env
14 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
| call of collector. | different envs
15 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
| ``_len``
16 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
| 'linear'].
17 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
| ``start``
18 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1]
| ``end``
19 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set
| ``decay`` | decay=10000 means
| the exploration rate
| decay from start
| value to end value
| during decay length.
20 | ``entropy_tau`` float 0.03 | the ratio of the entropy term in the TD loss
21 | ``m_alpha`` float 0.9 | the ratio of the Munchausen term in the TD loss
== ==================== ======== ============== ======================================== =======================
"""
config = dict(
type='mdqn',
# (bool) Whether use cuda in policy
cuda=False,
# (bool) Whether learning policy is the same as collecting data policy(on-policy)
on_policy=False,
# (bool) Whether to enable prioritized experience replay (PER)
priority=True,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=False,
# (float) Discount factor(gamma) for returns
discount_factor=0.97,
# (float) Entropy factor (tau) for Munchausen DQN
entropy_tau=0.03,
# (float) Discount factor (alpha) for Munchausen term
m_alpha=0.9,
# (int) N-step reward for target q_value estimation
nstep=1,
learn=dict(
# (bool) Whether to use multi gpu
multi_gpu=False,
# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=3,
# (int) How many samples in a training batch
batch_size=64,
# (float) The step size of gradient descent
learning_rate=0.001,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (int) Frequency of target network update.
target_update_freq=100,
# (bool) Whether ignore done(usually for max step termination env)
ignore_done=False,
),
# collect_mode config
collect=dict(
# (int) Only one of [n_sample, n_episode] should be set
n_sample=4,
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
eval=dict(),
# other config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# (str) Decay type. Support ['exp', 'linear'].
type='exp',
# (float) Epsilon start value
start=0.95,
# (float) Epsilon end value
end=0.1,
# (int) Decay length(env step)
decay=10000,
),
replay_buffer=dict(replay_buffer_size=10000, ),
),
)

def _init_learn(self) -> None:
"""
Overview:
Learn mode init method. Called by ``self.__init__``; it initializes the optimizer, algorithm arguments, and the \
main and target models.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
# Optimizer
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, eps=0.0003125)

self._gamma = self._cfg.discount_factor
self._nstep = self._cfg.nstep
self._entropy_tau = self._cfg.entropy_tau
self._m_alpha = self._cfg.m_alpha

# use model_wrapper for specialized demands of different modes
self._target_model = copy.deepcopy(self._model)
if 'target_update_freq' in self._cfg.learn:
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='assign',
update_kwargs={'freq': self._cfg.learn.target_update_freq}
)
elif 'target_theta' in self._cfg.learn:
self._target_model = model_wrap(
self._target_model,
wrapper_name='target',
update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
else:
raise RuntimeError("MDQN needs a target network, please either indicate target_update_freq or target_theta")
self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
self._learn_model.reset()
self._target_model.reset()

def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
Forward computation graph of learn mode(updating policy).
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which will be \
recorded in text log and tensorboard, values are python scalar or a list of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
- optional: ``value_gamma``, ``IS``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``, ``q_value``, ``target_q_value``, ``action_gap``, ``clip_frac``
"""
data = default_preprocess_learn(
data,
use_priority=self._priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
# Target q value
with torch.no_grad():
target_q_value_current = self._target_model.forward(data['obs'])['logit']
target_q_value = self._target_model.forward(data['next_obs'])['logit']

data_m = m_q_1step_td_data(
q_value, target_q_value_current, target_q_value, data['action'], data['reward'].squeeze(0), data['done'],
data['weight']
)

loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(
data_m, self._gamma, self._entropy_tau, self._m_alpha
)
# ====================
# Q-learning update
# ====================
self._optimizer.zero_grad()
loss.backward()
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()

# =============
# after update
# =============
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
'q_value': q_value.mean().item(),
'target_q_value': target_q_value.mean().item(),
'priority': td_error_per_sample.abs().tolist(),
'action_gap': action_gap.item(),
'clip_frac': clipfrac.mean().item(),
}

def _monitor_vars_learn(self) -> List[str]:
return ['cur_lr', 'total_loss', 'q_value', 'action_gap', 'clip_frac']
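For reference, the regression target that ``_forward_learn`` obtains from ``m_q_1step_td_error`` (added to ding/rl_utils/td.py below) can be written as follows. Notation follows the MDQN paper; the log-policy is computed from the target network, and this implementation clips the Munchausen bonus to [l_0, 1] with l_0 = -1 (the paper clips to [l_0, 0]):

\pi_{\bar\theta}(a \mid s) = \operatorname{softmax}\big(q_{\bar\theta}(s,\cdot)/\tau\big)_a

\hat q(s_t, a_t) = r_t
  + \alpha \operatorname{clip}\big(\tau \log \pi_{\bar\theta}(a_t \mid s_t),\, l_0,\, 1\big)
  + \gamma (1 - d_t) \sum_{a'} \pi_{\bar\theta}(a' \mid s_{t+1}) \big( q_{\bar\theta}(s_{t+1}, a') - \tau \log \pi_{\bar\theta}(a' \mid s_{t+1}) \big)

The loss is then the (importance-weighted) MSE between q_\theta(s_t, a_t) and \hat q(s_t, a_t).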
3 changes: 2 additions & 1 deletion ding/rl_utils/__init__.py
@@ -5,7 +5,8 @@
from .gae import gae_data, gae
from .a2c import a2c_data, a2c_error
from .coma import coma_data, coma_error
from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data, td_lambda_error,\
from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, \
q_1step_td_error, m_q_1step_td_data, m_q_1step_td_error, td_lambda_data, td_lambda_error,\
q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \
generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \
nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error,\
45 changes: 45 additions & 0 deletions ding/rl_utils/td.py
@@ -40,6 +40,51 @@ def q_1step_td_error(
return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean()


m_q_1step_td_data = namedtuple('m_q_1step_td_data', ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight'])


def m_q_1step_td_error(
data: namedtuple,
gamma: float,
tau: float,
alpha: float,
criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
) -> torch.Tensor:
q, target_q, next_q, act, reward, done, weight = data
lo = -1
assert len(act.shape) == 1, act.shape
assert len(reward.shape) == 1, reward.shape
batch_range = torch.arange(act.shape[0])
if weight is None:
weight = torch.ones_like(reward)
q_s_a = q[batch_range, act]
# calculate the Munchausen addon: alpha * clip(tau * log pi(a|s), lo, 1), computed from the target network
# replay_log_policy
target_v_s = target_q[batch_range].max(1)[0].unsqueeze(-1)
top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0]
action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean()
logsum = torch.logsumexp((target_q - target_v_s) / tau, 1).unsqueeze(-1)
log_pi = target_q - target_v_s - tau * logsum
act_get = act.unsqueeze(-1)
# log_pi holds tau * log pi(.|s); gather the entry of the taken action (tau_log_pi_a)
munchausen_addon = log_pi.gather(1, act_get)
# clip_frac: fraction of samples whose Munchausen addon falls outside [lo, 1]
clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lo)
clipfrac = torch.as_tensor(clipped).float()
munchausen_term = alpha * torch.clamp(munchausen_addon, min=lo, max=1)

# replay_next_log_policy
target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1)
logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1)
tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next
# numerically stable softmax over actions == replay_next_policy
pi_target = F.softmax((next_q - target_v_s_next) / tau, dim=1)
target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * (1 - done.unsqueeze(-1))).sum(1)).unsqueeze(-1)

target_q_s_a = reward.unsqueeze(-1) + munchausen_term + target_q_s_a
td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1)
return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac


q_v_1step_td_data = namedtuple('q_v_1step_td_data', ['q', 'v', 'act', 'reward', 'done', 'weight'])


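A quick shape-check sketch for the new TD helper, with random tensors standing in for a batch of 4 transitions over 6 actions (illustrative only, not a test from this PR):

import torch
from ding.rl_utils import m_q_1step_td_data, m_q_1step_td_error

B, A = 4, 6
data = m_q_1step_td_data(
    q=torch.randn(B, A),          # online q(s, .)
    target_q=torch.randn(B, A),   # target-network q(s, .)
    next_q=torch.randn(B, A),     # target-network q(s', .)
    act=torch.randint(0, A, (B, )),
    reward=torch.randn(B),
    done=torch.zeros(B),
    weight=None,
)
loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(data, 0.99, tau=0.03, alpha=0.9)
assert loss.shape == () and td_error_per_sample.shape == (B, )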