feature(crb): update multi discrete policy(dqn, ppo, rainbow) #51

Merged 6 commits on Sep 17, 2021
89 changes: 70 additions & 19 deletions dizoo/common/policy/md_dqn.py
@@ -3,43 +3,94 @@
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error
from ding.policy import DQNPolicy
from ding.utils import POLICY_REGISTRY
from ding.policy.common_utils import default_preprocess_learn
from ding.torch_utils import to_device


@POLICY_REGISTRY.register('md_dqn')
class MultiDiscreteDQNPolicy(DQNPolicy):
r"""
Overview:
Policy class of Multi-discrete action space DQN algorithm.
"""

def _forward_learn(self, data: dict) -> Dict[str, Any]:
reward = data['reward']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
assert reward.shape == (self._cfg.learn.batch_size, self._nstep), reward.shape
reward = reward.permute(1, 0).contiguous()
q_value = self._armor.forward(data['obs'])['logit']
# target_q_value = self._armor.target_forward(data['next_obs'])['logit']
target = self._armor.forward(data['next_obs'])
target_q_value = target['logit']
next_act = target['action']
if isinstance(q_value, torch.Tensor):
td_data = q_nstep_td_data( # 'q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight'
q_value, target_q_value, data['action'][0], next_act, reward, data['done'], data['weight']
)
loss, td_error_per_sample = q_nstep_td_error(td_data, self._gamma, nstep=self._nstep)
else:
"""
Overview:
Forward computation of learn mode(updating policy). It supports both single and multi-discrete action \
space. It depends on whether the ``q_value`` is a list.
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
np.ndarray or dict/list combinations.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
recorded in text log and tensorboard, values are python scalar or a list of scalars.
ArgumentsKeys:
- necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
- optional: ``value_gamma``, ``IS``
ReturnsKeys:
- necessary: ``cur_lr``, ``total_loss``, ``priority``
- optional: ``action_distribution``
"""
data = default_preprocess_learn(
data,
use_priority=self._priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# Q-learning forward
# ====================
self._learn_model.train()
self._target_model.train()
# Current q value (main model)
q_value = self._learn_model.forward(data['obs'])['logit']
# Target q value
with torch.no_grad():
target_q_value = self._target_model.forward(data['next_obs'])['logit']
# Max q value action (main model)
target_q_action = self._learn_model.forward(data['next_obs'])['action']

value_gamma = data.get('value_gamma')
if isinstance(q_value, list):
tl_num = len(q_value)
loss, td_error_per_sample = [], []
for i in range(tl_num):
td_data = q_nstep_td_data(
q_value[i], target_q_value[i], data['action'][i], next_act[i], reward, data['done'], data['weight']
q_value[i], target_q_value[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
data['weight']
)
loss_, td_error_per_sample_ = q_nstep_td_error(
td_data, self._gamma, nstep=self._nstep, value_gamma=value_gamma
)
loss_, td_error_per_sample_ = q_nstep_td_error(td_data, self._gamma, nstep=self._nstep)
loss.append(loss_)
td_error_per_sample.append(td_error_per_sample_.abs())
loss = sum(loss) / (len(loss) + 1e-8)
td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
else:
data_n = q_nstep_td_data(
q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
)
loss, td_error_per_sample = q_nstep_td_error(
data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
)

# ====================
# Q-learning update
# ====================
self._optimizer.zero_grad()
loss.backward()
if self._cfg.learn.multi_gpu:
self.sync_gradients(self._learn_model)
self._optimizer.step()
self._armor.target_update(self._armor.state_dict()['model'])

# =============
# after update
# =============
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
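
The core of the new `_forward_learn` above is the list branch: when the model's ``logit`` is a list (one Q-value tensor per action dimension), an n-step TD loss is computed for each branch and the branch losses are averaged, with the averaged absolute TD errors backing the ``priority`` entry named in ``ReturnsKeys``. Below is a minimal plain-PyTorch sketch of that idea, written without DI-engine's `q_nstep_td_data`/`q_nstep_td_error` helpers; the function name, `returns`, and `gamma_n` are illustrative assumptions, not part of the PR.

```python
import torch
import torch.nn.functional as F


def multi_discrete_td_loss(q_values, target_q_values, actions, next_actions, returns, gamma_n, done):
    """Average one TD loss per action branch of a multi-discrete action space.

    q_values / target_q_values: lists with one (B, A_i) tensor per branch.
    actions / next_actions: lists with one (B,) index tensor per branch
        (next_actions assumed to come from the main net, double-DQN style, as in the diff above).
    returns: (B,) n-step discounted return; gamma_n: gamma ** nstep; done: (B,) float mask.
    """
    losses, td_errors = [], []
    for q, tq, a, na in zip(q_values, target_q_values, actions, next_actions):
        q_a = q.gather(1, a.unsqueeze(1)).squeeze(1)  # Q(s, a_i) for the action taken by this branch
        with torch.no_grad():
            next_q = tq.gather(1, na.unsqueeze(1)).squeeze(1)  # target net evaluates the chosen next action
            target = returns + gamma_n * (1.0 - done) * next_q
        losses.append(F.mse_loss(q_a, target))
        td_errors.append((q_a - target).abs().detach())  # per-branch |TD error|, e.g. for priorities
    # average over branches, as the policy above does for both loss and priority
    return sum(losses) / len(losses), sum(td_errors) / len(td_errors)
```

For a toy action space with branches of size 3 and 6 and batch size 4, `q_values` would be `[torch.randn(4, 3), torch.randn(4, 6)]` and `actions` `[torch.randint(3, (4,)), torch.randint(6, (4,))]`.
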
57 changes: 0 additions & 57 deletions dizoo/common/policy/md_ppo.py

This file was deleted.

65 changes: 46 additions & 19 deletions dizoo/common/policy/md_rainbow_dqn.py
@@ -1,61 +1,88 @@
from typing import Dict, Any
import torch
from ding.torch_utils import to_device
from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, dist_1step_td_data, dist_1step_td_error
from ding.policy import RainbowDQNPolicy
from ding.utils import POLICY_REGISTRY
from ding.policy.common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('md_rainbow_dqn')
class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
r"""
Overview:
Multi-discrete action space Rainbow DQN algorithms.
"""

def _forward_learn(self, data: dict) -> Dict[str, Any]:
"""
Overview:
Forward and backward function of learn mode, acquire the data and calculate the loss and\
Forward and backward function of learn mode, acquire the data and calculate the loss and \
optimize learner model

Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'next_obs', 'reward', 'action']

Returns:
- info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
- info_dict (:obj:`Dict[str, Any]`): Including cur_lr, total_loss and priority
- cur_lr (:obj:`float`): current learning rate
- total_loss (:obj:`float`): the calculated loss
- priority (:obj:`list`): the priority of samples
"""
data = default_preprocess_learn(
data,
use_priority=self._priority,
use_priority_IS_weight=self._cfg.priority_IS_weight,
ignore_done=self._cfg.learn.ignore_done,
use_nstep=True
)
if self._cuda:
data = to_device(data, self._device)
# ====================
# Rainbow forward
# ====================
reward = data['reward']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
assert reward.shape == (self._cfg.learn.batch_size, self._nstep), reward.shape
reward = reward.permute(1, 0).contiguous()
# reset noise of noisenet for both main armor and target armor
self._reset_noise(self._armor.model)
self._reset_noise(self._armor.target_model)
q_dist = self._armor.forward(data['obs'])['distribution']
self._learn_model.train()
self._target_model.train()
# reset noise of noisenet for both main model and target model
self._reset_noise(self._learn_model)
self._reset_noise(self._target_model)
q_dist = self._learn_model.forward(data['obs'])['distribution']
with torch.no_grad():
target_q_dist = self._armor.target_forward(data['next_obs'])['distribution']
self._reset_noise(self._armor.model)
target_q_action = self._armor.forward(data['next_obs'])['action']
target_q_dist = self._target_model.forward(data['next_obs'])['distribution']
self._reset_noise(self._learn_model)
target_q_action = self._learn_model.forward(data['next_obs'])['action']

value_gamma = data.get('value_gamma', None)
if isinstance(q_dist, torch.Tensor):
td_data = dist_nstep_td_data(
q_dist, target_q_dist, data['action'], target_q_action, reward, data['done'], data['weight']
q_dist, target_q_dist, data['action'], target_q_action, data['reward'], data['done'], data['weight']
)
loss, td_error_per_sample = dist_nstep_td_error(
td_data, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep
td_data,
self._gamma,
self._v_min,
self._v_max,
self._n_atom,
nstep=self._nstep,
value_gamma=value_gamma
)
else:
tl_num = len(q_dist)
losses = []
td_error_per_samples = []
for i in range(tl_num):
td_data = dist_nstep_td_data(
q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], reward, data['done'],
q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
data['weight']
)
td_loss, td_error_per_sample = dist_nstep_td_error(
td_data, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep
td_data,
self._gamma,
self._v_min,
self._v_max,
self._n_atom,
nstep=self._nstep,
value_gamma=value_gamma
)
losses.append(td_loss)
td_error_per_samples.append(td_error_per_sample)
@@ -70,7 +97,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
# =============
# after update
# =============
self._armor.target_update(self._armor.state_dict()['model'])
self._target_model.update(self._learn_model.state_dict())
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': loss.item(),
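
Both updated policies rely on the same contract from the model side: for a multi-discrete action space the network's ``logit`` (or ``distribution``) is a list with one entry per action dimension, which is what the `isinstance(q_value, list)` / `isinstance(q_dist, torch.Tensor)` checks key on. A hypothetical head sketching that contract is shown below; `MultiDiscreteQHead` and its arguments are made up for illustration and are not DI-engine's actual model code.

```python
import torch
import torch.nn as nn


class MultiDiscreteQHead(nn.Module):
    """One linear branch per action dimension; returns a list of per-branch Q-value tensors."""

    def __init__(self, hidden: int, action_shape: list):
        super().__init__()
        self.branches = nn.ModuleList(nn.Linear(hidden, n) for n in action_shape)

    def forward(self, feature: torch.Tensor) -> dict:
        logit = [branch(feature) for branch in self.branches]  # one (B, A_i) tensor per branch
        action = [q.argmax(dim=-1) for q in logit]  # greedy action per branch
        return {'logit': logit, 'action': action}


# toy usage: batch of 4 features, action space with branches of size 3 and 6
head = MultiDiscreteQHead(hidden=64, action_shape=[3, 6])
out = head(torch.randn(4, 64))
assert len(out['logit']) == 2 and out['logit'][0].shape == (4, 3)
```

This per-branch output is also why both learn loops index the batch actions branch by branch as `data['action'][i]`.
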