feature(cy): add BDQ algorithm #558

Merged (12 commits) on Jan 3, 2023
1 change: 1 addition & 0 deletions README.md
@@ -240,6 +240,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 48 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 49 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 50 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 51 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/bdq.py) | python3 -u hopper_bdq_config.py |
</details>


15 changes: 15 additions & 0 deletions ding/entry/tests/test_serial_entry.py
@@ -51,6 +51,7 @@
from dizoo.gym_hybrid.config.gym_hybrid_ddpg_config import gym_hybrid_ddpg_config, gym_hybrid_ddpg_create_config
from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
from dizoo.classic_control.pendulum.config.pendulum_bdq_config import pendulum_bdq_config, pendulum_bdq_create_config # noqa


@pytest.mark.platformtest
@@ -67,6 +68,20 @@ def test_dqn():
os.popen('rm -rf cartpole_dqn_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_bdq():
config = [deepcopy(pendulum_bdq_config), deepcopy(pendulum_bdq_create_config)]
config[0].policy.learn.update_per_collect = 1
config[0].exp_name = 'pendulum_bdq_unittest'
try:
serial_pipeline(config, seed=0, max_train_iter=1)
except Exception:
assert False, "pipeline fail"
finally:
os.popen('rm -rf pendulum_bdq_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_ddpg():
2 changes: 1 addition & 1 deletion ding/model/common/__init__.py
@@ -1,5 +1,5 @@
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, head_cls_map, \
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
from .utils import create_model
110 changes: 110 additions & 0 deletions ding/model/common/head.py
@@ -174,6 +174,116 @@ def forward(self, x: torch.Tensor) -> Dict:
return {'logit': q, 'distribution': dist}


class BranchingHead(nn.Module):

def __init__(
self,
hidden_size: int,
num_branches: int = 0,
action_bins_per_branch: int = 2,
layer_num: int = 1,
a_layer_num: Optional[int] = None,
v_layer_num: Optional[int] = None,
norm_type: Optional[str] = None,
activation: Optional[nn.Module] = nn.ReLU(),
noise: Optional[bool] = False,
) -> None:
"""
Overview:
Init the ``BranchingHead`` layers according to the provided arguments. \
This head achieves a linear increase of the number of network outputs \
with the number of degrees of freedom by allowing a level of independence \
for each individual action dimension.
Therefore, this head is suitable for high-dimensional action spaces.
Arguments:
- hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``BranchingHead``.
- num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension.
- action_bins_per_branch (:obj:`int`): The number of action bins in each dimension.
- layer_num (:obj:`int`): The number of layers used in the network to compute Advantage and Value output.
- a_layer_num (:obj:`int`): The number of layers used in the network to compute Advantage output.
- v_layer_num (:obj:`int`): The number of layers used in the network to compute Value output.
- norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
for more details. Default ``None``.
- activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
Default ``nn.ReLU()``.
- noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in the Q networks' MLP. \
Default ``False``.
"""
super(BranchingHead, self).__init__()
if a_layer_num is None:
a_layer_num = layer_num
if v_layer_num is None:
v_layer_num = layer_num
self.num_branches = num_branches
self.action_bins_per_branch = action_bins_per_branch

layer = NoiseLinearLayer if noise else nn.Linear
block = noise_block if noise else fc_block
# value network
self.V = nn.Sequential(
MLP(
hidden_size,
hidden_size,
hidden_size,
v_layer_num,
layer_fn=layer,
activation=activation,
norm_type=norm_type
), block(hidden_size, 1)
)
# action branching network
action_output_dim = action_bins_per_branch
self.branches = nn.ModuleList(
[
nn.Sequential(
MLP(
hidden_size,
hidden_size,
hidden_size,
a_layer_num,
layer_fn=layer,
activation=activation,
norm_type=norm_type
), block(hidden_size, action_output_dim)
) for _ in range(self.num_branches)
]
)

def forward(self, x: torch.Tensor) -> Dict:
"""
Overview:
Use encoded embedding tensor to run MLP with ``BranchingHead`` and return the prediction dictionary.
Arguments:
- x (:obj:`torch.Tensor`): Tensor containing input embedding.
Returns:
- outputs (:obj:`Dict`): Dict containing keyword ``logit`` (:obj:`torch.Tensor`).
Shapes:
- x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
- logit: :math:`(B, M, N)`, where ``M = num_branches`` and ``N = action_bins_per_branch``.

Examples:
>>> head = BranchingHead(64, 5, 2)
>>> inputs = torch.randn(4, 64)
>>> outputs = head(inputs)
>>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
"""
value_out = self.V(x)
value_out = torch.unsqueeze(value_out, 1)
action_out = []
for b in self.branches:
action_out.append(b(x))
action_scores = torch.stack(action_out, 1)
# Per the paper, this mean-reduction aggregation, Q = V + (A - mean(A)), performs better than both
# the naive alternative (Q = V + A) and the local maximum reduction method (Q = V + max(A)).
action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True)
logits = value_out + action_scores
return {'logit': logits}


class RainbowHead(nn.Module):
"""
Overview:
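For review convenience, here is a minimal standalone sketch of the aggregation that `BranchingHead.forward` performs, using plain `torch` tensors with the shapes from the docstring example (`B=4`, `M=5` branches, `N=2` bins). The tensor names mirror the method above; this is illustrative only, and the head itself is the source of truth.

```python
import torch

B, M, N = 4, 5, 2                      # batch, num_branches, action_bins_per_branch
value_out = torch.randn(B, 1, 1)       # state value V(s), broadcast over branches and bins
action_scores = torch.randn(B, M, N)   # per-branch advantages A(s, a_d)

# Mean reduction per branch: Q_d = V + (A_d - mean_bins(A_d)), as in forward() above.
logits = value_out + (action_scores - action_scores.mean(2, keepdim=True))
assert logits.shape == (B, M, N)

# Within each branch the Q-values average to V, keeping all branches on a common scale.
assert torch.allclose(logits.mean(2), value_out.squeeze(2).expand(B, M))
```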
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
@@ -1,5 +1,5 @@
# general
from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN
from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ
from .qac import QAC, DiscreteQAC
from .pdqn import PDQN
from .vac import VAC
97 changes: 96 additions & 1 deletion ding/model/template/q_learning.py
@@ -5,7 +5,7 @@
from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
QuantileHead, FQFHead, QRDQNHead, DistributionHead
QuantileHead, FQFHead, QRDQNHead, DistributionHead, BranchingHead
from ding.torch_utils.network.gtrxl import GTrXL


@@ -98,6 +98,101 @@ def forward(self, x: torch.Tensor) -> Dict:
return x


@MODEL_REGISTRY.register('bdq')
class BDQ(nn.Module):

def __init__(
self,
obs_shape: Union[int, SequenceType],
num_branches: int = 0,
action_bins_per_branch: int = 2,
layer_num: int = 3,
a_layer_num: Optional[int] = None,
v_layer_num: Optional[int] = None,
encoder_hidden_size_list: SequenceType = [128, 128, 64],
head_hidden_size: Optional[int] = None,
norm_type: Optional[str] = None,
activation: Optional[nn.Module] = nn.ReLU(),
) -> None:
"""
Overview:
Init the BDQ (encoder + head) Model according to the input arguments, \
following the paper Action Branching Architectures for Deep Reinforcement Learning, \
<https://arxiv.org/pdf/1711.08946>.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
- num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension, \
such as 6 in mujoco's halfcheetah environment.
- action_bins_per_branch (:obj:`int`): The number of action bins in each dimension.
- layer_num (:obj:`int`): The number of layers used in the network to compute Advantage and Value output.
- a_layer_num (:obj:`int`): The number of layers used in the network to compute Advantage output.
- v_layer_num (:obj:`int`): The number of layers used in the network to compute Value output.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
the last element must match ``head_hidden_size``.
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
``ding.torch_utils.fc_block`` for more details.
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks. \
Default ``nn.ReLU()``.
"""
super(BDQ, self).__init__()
# For compatibility: 1, (1, ), [4, 32, 32]
obs_shape, num_branches = squeeze(obs_shape), squeeze(num_branches)
if head_hidden_size is None:
head_hidden_size = encoder_hidden_size_list[-1]

# backbone
# FC Encoder
if isinstance(obs_shape, int) or len(obs_shape) == 1:
self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
# Conv Encoder
elif len(obs_shape) == 3:
self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
else:
raise RuntimeError(
"not support obs_shape for pre-defined encoder: {}, please customize your own DQN".format(obs_shape)
)

self.num_branches = num_branches
self.action_bins_per_branch = action_bins_per_branch

# head
self.head = BranchingHead(
head_hidden_size,
num_branches=self.num_branches,
action_bins_per_branch=self.action_bins_per_branch,
layer_num=layer_num,
a_layer_num=a_layer_num,
v_layer_num=v_layer_num,
activation=activation,
norm_type=norm_type
)

def forward(self, x: torch.Tensor) -> Dict:
r"""
Overview:
BDQ forward computation graph, input observation tensor to predict q_value.
Arguments:
- x (:obj:`torch.Tensor`): Observation inputs
Returns:
- outputs (:obj:`Dict`): BDQ forward outputs, such as q_value.
ReturnsKeys:
- logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
Shapes:
- x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
- logit (:obj:`torch.FloatTensor`): :math:`(B, M, P)`, where B is batch size, M is \
``num_branches`` and P is ``action_bins_per_branch``
Examples:
>>> model = BDQ(8, 5, 2) # arguments: 'obs_shape', 'num_branches' and 'action_bins_per_branch'.
>>> inputs = torch.randn(4, 8)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
"""
x = self.encoder(x) / (self.num_branches + 1) # corresponds to the "Gradient Rescaling" in the paper
x = self.head(x)
return x


@MODEL_REGISTRY.register('c51dqn')
class C51DQN(nn.Module):

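A note on consuming the branched output: the `(B, num_branches, action_bins_per_branch)` logits reduce to one discrete choice per action dimension via an independent argmax over each branch. The sketch below assumes this greedy rule plus a uniform bin-to-continuous mapping; how bins actually map to environment actions is handled outside this file, so treat the last two lines as an illustrative assumption.

```python
import torch
from ding.model.template import BDQ

# Constructor arguments as in the docstring example: obs_shape, num_branches, action_bins_per_branch.
model = BDQ(8, 5, 2)
obs = torch.randn(4, 8)
logit = model(obs)['logit']             # (4, 5, 2)

# Greedy discrete action: one independent argmax per branch.
discrete_action = logit.argmax(dim=-1)  # (4, 5), bin indices in [0, action_bins_per_branch)

# Illustrative assumption: uniform discretization of a [-1, 1] continuous range per dimension.
bins = logit.shape[-1]
continuous_action = 2.0 * discrete_action.float() / (bins - 1) - 1.0
```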
21 changes: 20 additions & 1 deletion ding/model/template/tests/test_q_learning.py
@@ -1,7 +1,7 @@
import pytest
from itertools import product
import torch
from ding.model.template import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN
from ding.model.template import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ
from ding.torch_utils import is_differentiable

T, B = 3, 4
@@ -40,6 +40,25 @@ def test_dqn(self, obs_shape, act_shape):
assert outputs['logit'][i].shape == (B, s)
self.output_check(model, outputs['logit'])

@pytest.mark.parametrize('obs_shape, act_shape', args)
def test_bdq(self, obs_shape, act_shape):
if isinstance(obs_shape, int):
inputs = torch.randn(B, obs_shape)
else:
inputs = torch.randn(B, *obs_shape)
if not isinstance(act_shape, int) and len(act_shape) > 1:
return
num_branches = act_shape
for action_bins_per_branch in range(1, 10):
model = BDQ(obs_shape, num_branches, action_bins_per_branch)
outputs = model(inputs)
assert isinstance(outputs, dict)
if isinstance(act_shape, int):
assert outputs['logit'].shape == (B, act_shape, action_bins_per_branch)
else:
assert outputs['logit'].shape == (B, *act_shape, action_bins_per_branch)
self.output_check(model, outputs['logit'])

@pytest.mark.parametrize('obs_shape, act_shape', args)
def test_rainbowdqn(self, obs_shape, act_shape):
if isinstance(obs_shape, int):
2 changes: 2 additions & 0 deletions ding/policy/__init__.py
@@ -43,3 +43,5 @@

from .bc import BehaviourCloningPolicy
from .ibc import IBCPolicy

from .bdq import BDQPolicy
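For orientation (not part of this diff): a hypothetical, minimal config fragment sketching how a user config selects this policy. It assumes the policy registers under the name `bdq`, mirroring the `@MODEL_REGISTRY.register('bdq')` entry above, and follows DI-engine's usual `EasyDict` layout; the real `dizoo/classic_control/pendulum/config/pendulum_bdq_config.py` imported by the test above is authoritative.

```python
from easydict import EasyDict

# Hypothetical sketch only; field values (e.g. 11 bins) are placeholders.
pendulum_bdq_sketch = EasyDict(
    dict(
        exp_name='pendulum_bdq_sketch',
        policy=dict(
            type='bdq',  # assumed registry name, mirroring the model registration
            model=dict(
                obs_shape=3,                # Pendulum observation is 3-dimensional
                num_branches=1,             # Pendulum has a single action dimension
                action_bins_per_branch=11,  # discretize the torque into 11 bins
            ),
        ),
    )
)
```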