feature(zc): add EDAC and modify config of td3bc #639

Merged
merged 27 commits on Apr 25, 2023
Changes from 19 commits
1 change: 1 addition & 0 deletions README.md
@@ -248,6 +248,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 50 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 51 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 52 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 53 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
</details>


42 changes: 42 additions & 0 deletions ding/example/edac.py
@@ -0,0 +1,42 @@
import gym
from ditk import logging
from ding.model import QACEnsemble
from ding.policy import EDACPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import create_dataset
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
from ding.utils import set_pkg_seed
from dizoo.d4rl.envs import D4RLEnv
from dizoo.d4rl.config.halfcheetah_medium_edac_config import main_config, create_config


def main():
# If you don't have offline data, you need to prepare it first and set the data_path in config.
# For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

dataset = create_dataset(cfg)
model = QACEnsemble(**cfg.policy.model)
policy = EDACPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(offline_data_fetcher(cfg, dataset))
task.use(trainer(cfg, policy.learn_mode))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=1e4))
task.use(offline_logger())
task.run()


if __name__ == "__main__":
main()
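For context on what `create_dataset(cfg)` is expected to load for the `halfcheetah_medium_edac_config`, below is a rough standalone sketch (not part of this PR, assuming the `d4rl` package and the `halfcheetah-medium-v2` task are available); DI-engine's own dataset wrapper may differ in detail.

# Standalone sketch of the underlying d4rl offline dataset (assumption: d4rl installed).
import gym
import d4rl  # noqa: importing d4rl registers the offline MuJoCo tasks in gym

env = gym.make('halfcheetah-medium-v2')
data = d4rl.qlearning_dataset(env)  # dict of numpy arrays
print(data['observations'].shape, data['actions'].shape, data['rewards'].shape)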
2 changes: 1 addition & 1 deletion ding/model/common/__init__.py
100644 → 100755
@@ -1,5 +1,5 @@
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
from .utils import create_model
76 changes: 75 additions & 1 deletion ding/model/common/head.py
100644 → 100755
@@ -6,7 +6,7 @@
import torch.nn.functional as F
from torch.distributions import Normal, Independent

from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt
from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt, conv1d_block
from ding.rl_utils import beta_function_map
from ding.utils import lists_to_dicts, SequenceType

@@ -1316,6 +1316,79 @@ def forward(self, x: torch.Tensor) -> Dict:
return lists_to_dicts([m(x) for m in self.pred])


class EnsembleHead(nn.Module):
"""
Overview:
The ``EnsembleHead`` is used to output the action Q-values of a Q-ensemble. \
The input is a (:obj:`torch.Tensor`) of shape ``(B, N * ensemble_num, 1)`` and the output is a (:obj:`Dict`) \
containing the key ``pred``.
Interfaces:
``__init__``, ``forward``.
"""

def __init__(
self,
input_size: int,
output_size: int,
hidden_size: int,
layer_num: int,
ensemble_num: int,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
) -> None:
super(EnsembleHead, self).__init__()
d = input_size
layers = []
for _ in range(layer_num):
layers.append(
conv1d_block(
d * ensemble_num,
hidden_size * ensemble_num,
kernel_size=1,
stride=1,
groups=ensemble_num,
activation=activation,
norm_type=norm_type
)
)
d = hidden_size

# Adding an activation to the last layer makes training fail
layers.append(
conv1d_block(
hidden_size * ensemble_num,
output_size * ensemble_num,
kernel_size=1,
stride=1,
groups=ensemble_num,
activation=None,
norm_type=None
)
)
self.pred = nn.Sequential(*layers)

def forward(self, x: torch.Tensor) -> Dict:
"""
Overview:
Use encoded embedding tensor to run MLP with ``EnsembleHead`` and return the prediction dictionary.
Arguments:
- x (:obj:`torch.Tensor`): Tensor containing input embedding.
Returns:
- outputs (:obj:`Dict`): Dict containing keyword ``pred`` (:obj:`torch.Tensor`).
Shapes:
- x: :math:`(B, N * ensemble_num, 1)`, where ``B = batch_size`` and ``N = input_size``.
- pred: :math:`(B, M * ensemble_num)` after the trailing ``squeeze()``, where ``M = output_size``.
Examples:
>>> head = EnsembleHead(64, 64, 128, 2, 3)
>>> inputs = torch.randn(4, 64 * 3, 1)
>>> outputs = head(inputs)
>>> assert isinstance(outputs, dict)
>>> assert outputs['pred'].shape == torch.Size([4, 64 * 3])
"""
x = self.pred(x).squeeze()
Review comment (Member): if we use squeeze here, we can get the output with the shape (B, M * ensemble_num, 1)

return {'pred': x}


def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Distribution:
if isinstance(logits, (list, tuple)):
return Independent(Normal(*logits), 1)
@@ -1341,4 +1414,5 @@ def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Di
'popart': PopArtVHead,
# multi
'multi': MultiHead,
'ensemble': EnsembleHead,
}
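As background for the new ``EnsembleHead``: a standalone sketch in plain PyTorch (not part of this PR) showing why a 1x1 ``nn.Conv1d`` with ``groups=ensemble_num`` behaves like ``ensemble_num`` independent linear layers applied to the stacked member inputs, which is what lets the Q-ensemble share one op while keeping its members independent.

import torch
import torch.nn as nn

# Grouped 1x1 Conv1d over stacked member inputs (made-up sizes).
ensemble_num, in_dim, out_dim, batch = 3, 8, 4, 5
conv = nn.Conv1d(in_dim * ensemble_num, out_dim * ensemble_num, kernel_size=1, groups=ensemble_num)

x = torch.randn(batch, in_dim * ensemble_num, 1)  # member inputs stacked along the channel dim
y = conv(x)                                       # (batch, out_dim * ensemble_num, 1)

# Reproduce the same result member by member with plain linear algebra.
for k in range(ensemble_num):
    w = conv.weight[k * out_dim:(k + 1) * out_dim, :, 0]  # (out_dim, in_dim), weights of member k
    b = conv.bias[k * out_dim:(k + 1) * out_dim]
    xk = x[:, k * in_dim:(k + 1) * in_dim, 0]             # (batch, in_dim), input of member k
    yk = xk @ w.t() + b
    assert torch.allclose(y[:, k * out_dim:(k + 1) * out_dim, 0], yk, atol=1e-6)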
9 changes: 8 additions & 1 deletion ding/model/common/tests/test_head.py
@@ -2,7 +2,7 @@
import numpy as np
import pytest

from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead
from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead, EnsembleHead
from ding.torch_utils import is_differentiable

B = 4
@@ -84,3 +84,10 @@ def test_stochastic_dueling(self):
assert isinstance(sigma.grad, torch.Tensor)
assert outputs['q_value'].shape == (B, 1)
assert outputs['v_value'].shape == (B, 1)

def test_ensemble(self):
inputs = torch.randn(B, embedding_dim * 3, 1)
model = EnsembleHead(embedding_dim, action_shape, 3, 3, 3)
outputs = model(inputs)['pred']
self.output_check(model, outputs)
assert outputs.shape == (B, action_shape * 3, 1)
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
100644 → 100755
@@ -23,3 +23,4 @@
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .edac import QACEnsemble
179 changes: 179 additions & 0 deletions ding/model/template/edac.py
@@ -0,0 +1,179 @@
from typing import Union, Optional, Dict
from easydict import EasyDict

import torch
import torch.nn as nn
from ding.model.common import ReparameterizationHead, EnsembleHead
from ding.utils import SequenceType, squeeze

from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('edac')
class QACEnsemble(nn.Module):
r"""
Overview:
The QAC network with ensemble, which is used in EDAC.
Interfaces:
``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
"""
mode = ['compute_actor', 'compute_critic']

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
ensemble_num: int = 2,
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
critic_head_hidden_size: int = 64,
critic_head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
**kwargs
) -> None:
"""
Overview:
Initialize the EDAC model according to the input arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
- ensemble_num (:obj:`int`): Q-net number.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
- actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
for actor head.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
- critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
for critic head.
- activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization to use after each network layer (FC, Conv), \
see ``ding.torch_utils.network`` for more details.
"""
super(QACEnsemble, self).__init__()
obs_shape: int = squeeze(obs_shape)
action_shape = squeeze(action_shape)
self.action_shape = action_shape
self.ensemble_num = ensemble_num
self.actor = nn.Sequential(
nn.Linear(obs_shape, actor_head_hidden_size), activation,
ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type='conditioned',
activation=activation,
norm_type=norm_type
)
)

critic_input_size = obs_shape + action_shape
self.critic = EnsembleHead(
critic_input_size,
1,
critic_head_hidden_size,
critic_head_layer_num,
self.ensemble_num,
activation=activation,
norm_type=norm_type
)


def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
"""
Overview:
The unique execution (forward) method of the EDAC model; one can indicate different modes to implement \
different computation graph, including ``compute_actor`` and ``compute_critic`` in EDAC.
Mode compute_actor:
Arguments:
- inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
Returns:
- output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space.
Mode compute_critic:
Arguments:
- inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
Returns:
- output (:obj:`Dict`): Output dict data, including q_value tensor.
.. note::
For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
"""
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(inputs)

def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
"""
Overview:
The forward computation graph of compute_actor mode, which uses the observation tensor to produce actor output,
such as ``action``, ``logit`` and so on.
Arguments:
- obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \
i.e. ``(B, obs_shape)``.
Returns:
- outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output varying \
from action_space: ``reparameterization``.
ReturnsKeys (either):
- logit (:obj:`Dict[str, torch.Tensor]`): Reparameterization logit, usually in SAC.
- mu (:obj:`torch.Tensor`): Mean of the parameterized Gaussian distribution.
- sigma (:obj:`torch.Tensor`): Standard deviation of the parameterized Gaussian distribution.
Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
- action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
- logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
- logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size.
- logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
``action_shape.action_type_shape``.
- action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
``action_shape.action_args_shape``.
Examples:
>>> model = QACEnsemble(64, 64,)
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs,'compute_actor')
>>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
>>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
"""
x = self.actor(obs)
return {'logit': [x['mu'], x['sigma']]}

def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Overview:
The forward computation graph of compute_critic mode, which uses observation and action tensors to produce critic
output, such as ``q_value``.
Arguments:
- inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and ``action`` tensor.
Returns:
- outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``.
ArgumentsKeys:
- obs: (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
- action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``.
ReturnKeys:
- q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(Ensemble_num, B, N1)`, where B is batch size and N1 is ``obs_shape``.
- action (:obj:`torch.Tensor`): :math:`(B, N2)` or :math:`(Ensemble_num, B, N2)`, where B is batch size and N2 is ``action_shape``.
- q_value (:obj:`torch.Tensor`): :math:`(Ensemble_num, B)`, where B is batch size.
Examples:
>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> model = QACEnsemble(obs_shape=(8, ), action_shape=1)
>>> model(inputs, mode='compute_critic')['q_value'].shape  # (ensemble_num, B)
... torch.Size([2, 4])
"""

obs, action = inputs['obs'], inputs['action']
if len(action.shape) == 1: # (B, ) -> (B, 1)
action = action.unsqueeze(1)
x = torch.cat([obs, action], dim=-1)
if len(obs.shape) < 3:
# [batch_size,dim] -> [batch_size,Ensemble_num * dim,1]
x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
else:
# [Ensemble_num,batch_size,dim] -> [batch_size,Ensemble_num,dim] -> [batch_size,Ensemble_num * dim, 1]
x = x.transpose(0, 1)
batch_size = obs.shape[1]
x = x.reshape(batch_size, -1, 1)
# critic pred: [batch_size, 1 * Ensemble_num] (trailing dim squeezed inside EnsembleHead)
x = self.critic(x)['pred']
# [batch_size,1*Ensemble_num] -> [Ensemble_num,batch_size]
x = x.permute(1, 0)
return {'q_value': x}
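To make the tensor reshaping in ``compute_critic`` easier to follow, here is a standalone shape walkthrough with made-up sizes (a sketch, not part of this PR), covering the 2-dim input branch:

import torch

# Made-up sizes for the walkthrough.
ensemble_num, batch, obs_dim, act_dim = 3, 4, 8, 2
obs = torch.randn(batch, obs_dim)
action = torch.randn(batch, act_dim)

x = torch.cat([obs, action], dim=-1)          # (batch, obs_dim + act_dim)
x = x.repeat(1, ensemble_num).unsqueeze(-1)   # (batch, ensemble_num * (obs_dim + act_dim), 1)
assert x.shape == (batch, ensemble_num * (obs_dim + act_dim), 1)

# EnsembleHead maps this to per-member Q predictions; after its squeeze the
# result is (batch, ensemble_num), which permute(1, 0) turns into (ensemble_num, batch).
pred = torch.randn(batch, ensemble_num)       # stand-in for self.critic(x)['pred']
q_value = pred.permute(1, 0)
assert q_value.shape == (ensemble_num, batch)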
2 changes: 1 addition & 1 deletion ding/model/template/qac.py
100644 → 100755
@@ -559,4 +559,4 @@ def compute_critic(self, inputs: Dict) -> Dict:
x = [m(inputs['obs'])['logit'] for m in self.critic]
else:
x = self.critic(inputs['obs'])['logit']
return {'q_value': x}
return {'q_value': x}
3 changes: 3 additions & 0 deletions ding/policy/__init__.py
100644 → 100755
@@ -18,6 +18,7 @@
from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
from .sac import SACPolicy, SACDiscretePolicy, SQILSACPolicy
from .cql import CQLPolicy, CQLDiscretePolicy
from .edac import EDACPolicy
from .impala import IMPALAPolicy
from .ngu import NGUPolicy
from .r2d2 import R2D2Policy
@@ -48,5 +49,7 @@

from .pc import ProcedureCloningBFSPolicy

from .edac import EDACPolicy

# new-type policy
from .ppof import PPOFPolicy