Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(zc): add EDAC and modify config of td3bc #639

Merged
merged 27 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions ding/example/edac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import gym
from ditk import logging
from ding.model import Q_ensemble
from ding.policy import EDACPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import create_dataset
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
from ding.utils import set_pkg_seed
from dizoo.d4rl.envs import D4RLEnv
from dizoo.d4rl.config.halfcheetah_medium_edac_config import main_config,create_config


def main():
    """
    Overview:
        Entry point of the offline EDAC example: compile the config, build the evaluator environments, \
        the dataset, the Q-ensemble model and the EDAC policy, then register the middleware on ``task`` \
        and run it.
    """
    # If you don't have offline data, you need to prepare it first and set the data_path in config.
    # For demonstration, you can also train an RL policy (e.g. SAC) and collect some data with it.
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)

    with task.start(async_mode=False, ctx=OfflineRLContext()):
        # One environment factory per evaluator environment requested by the config.
        env_factories = [lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)]
        eval_envs = BaseEnvManagerV2(env_fn=env_factories, cfg=cfg.env.manager)

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        offline_data = create_dataset(cfg)
        q_ensemble = Q_ensemble(**cfg.policy.model)
        edac_policy = EDACPolicy(cfg.policy, model=q_ensemble)

        # Middleware is registered in exactly this order before the task is started.
        task.use(interaction_evaluator(cfg, edac_policy.eval_mode, eval_envs))
        task.use(offline_data_fetcher(cfg, offline_data))
        task.use(trainer(cfg, edac_policy.learn_mode))
        task.use(CkptSaver(edac_policy, cfg.exp_name, train_freq=100))
        task.use(offline_logger())
        task.run()


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion ding/model/common/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, \
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
from .utils import create_model
74 changes: 73 additions & 1 deletion ding/model/common/head.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn.functional as F
from torch.distributions import Normal, Independent

from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt
from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt, conv1d_block
from ding.rl_utils import beta_function_map
from ding.utils import lists_to_dicts, SequenceType

Expand Down Expand Up @@ -1316,6 +1316,77 @@ def forward(self, x: torch.Tensor) -> Dict:
return lists_to_dicts([m(x) for m in self.pred])


class EnsembleHead(nn.Module):
    """
    Overview:
        The ``EnsembleHead`` is used to output the action Q-value for a Q-ensemble. The members of the \
        ensemble are stacked along the channel dimension and evaluated together by a stack of \
        ``conv1d_block`` layers with ``kernel_size=1`` and ``groups=ensemble_num``.
        Input is a (:obj:`torch.Tensor`) of shape ``(B, N * ensemble_num, 1)`` and the head returns a \
        (:obj:`Dict`) containing the output ``pred``.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            input_size: int,
            output_size: int,
            hidden_size: int,
            layer_nun: int,
            ensemble_num: int,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Build the grouped 1x1 convolution stack of the ensemble head.
        Arguments:
            - input_size (:obj:`int`): Per-member input feature size ``N``.
            - output_size (:obj:`int`): Per-member output size ``M``.
            - hidden_size (:obj:`int`): Per-member width of every hidden layer.
            - layer_nun (:obj:`int`): Number of hidden layers placed before the output layer. The \
                (misspelled) name is kept as is so that existing keyword callers keep working.
            - ensemble_num (:obj:`int`): Number of ensemble members, also used as ``groups`` of each layer.
            - activation (:obj:`Optional[nn.Module]`): Activation passed to every hidden layer.
            - norm_type (:obj:`Optional[str]`): Normalization type passed to every hidden layer.
        """
        super(EnsembleHead, self).__init__()
        # ``d`` tracks the per-member width of the features entering the next layer.
        d = input_size
        layers = []
        for _ in range(layer_nun):
            layers.append(
                conv1d_block(
                    d * ensemble_num,
                    hidden_size * ensemble_num,
                    kernel_size=1,
                    stride=1,
                    groups=ensemble_num,
                    activation=activation,
                    norm_type=norm_type
                )
            )
            d = hidden_size
        # The output layer has neither activation nor normalization. Its input width is ``d`` rather
        # than a hard-coded ``hidden_size`` so that ``layer_nun == 0`` still matches ``input_size``.
        layers.append(
            conv1d_block(
                d * ensemble_num,
                output_size * ensemble_num,
                kernel_size=1,
                stride=1,
                groups=ensemble_num,
                activation=None,
                norm_type=None
            )
        )
        self.pred = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Run the stacked ensemble input through the grouped layers of ``EnsembleHead`` and return the \
            prediction dictionary.
        Arguments:
            - x (:obj:`torch.Tensor`): Tensor containing the input embedding of every ensemble member.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keyword ``pred`` (:obj:`torch.Tensor`).
        Shapes:
            - x: :math:`(B, N * ensemble\_num, 1)`, where ``B = batch_size`` and ``N = input_size``.
            - pred: :math:`(B, M * ensemble\_num)`, where ``M = output_size``.
        Examples:
            >>> head = EnsembleHead(64, 64, 128, 2, 10)
            >>> inputs = torch.randn(4, 64 * 10, 1)
            >>> outputs = head(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['pred'].shape == torch.Size([4, 64 * 10])
        """
        # Only drop the trailing length-1 dimension. A bare ``squeeze()`` would also remove the batch
        # dimension when ``B == 1`` (or the channel dimension when ``M * ensemble_num == 1``).
        x = self.pred(x).squeeze(-1)
        return {'pred': x}


def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Distribution:
if isinstance(logits, (list, tuple)):
return Independent(Normal(*logits), 1)
Expand All @@ -1341,4 +1412,5 @@ def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Di
'popart': PopArtVHead,
# multi
'multi': MultiHead,
'ensemble': EnsembleHead,
}
9 changes: 8 additions & 1 deletion ding/model/common/tests/test_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pytest

from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead
from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead, EnsembleHead
from ding.torch_utils import is_differentiable

B = 4
Expand Down Expand Up @@ -84,3 +84,10 @@ def test_stochastic_dueling(self):
assert isinstance(sigma.grad, torch.Tensor)
assert outputs['q_value'].shape == (B, 1)
assert outputs['v_value'].shape == (B, 1)

def test_ensemble(self):
    # Positional arguments: input_size=embedding_dim, output_size=action_shape,
    # hidden_size=3, layer_nun=3, ensemble_num=3.
    inputs = torch.randn(B, embedding_dim * 3, 1)
    model = EnsembleHead(embedding_dim, action_shape, 3, 3, 3)
    outputs = model(inputs)['pred']
    self.output_check(model, outputs)
    # ``EnsembleHead.forward`` squeezes the trailing length-1 dimension, so ``pred`` is 2-dim.
    assert outputs.shape == (B, action_shape * 3)
4 changes: 2 additions & 2 deletions ding/model/template/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# general
from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ
from .qac import QAC, DiscreteQAC
from .qac import QAC, DiscreteQAC, Q_ensemble
from .pdqn import PDQN
from .vac import VAC
from .bc import DiscreteBC, ContinuousBC
Expand All @@ -22,4 +22,4 @@
from .madqn import MADQN
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
180 changes: 180 additions & 0 deletions ding/model/template/edac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from typing import Union, Optional, Dict
from easydict import EasyDict

import torch
import torch.nn as nn
from ding.model.common import ReparameterizationHead, EnsembleHead
from ding.utils import SequenceType, squeeze

from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('edac')
class Q_ensemble(nn.Module):
    r"""
    Overview:
        The QAC network with an ensemble of Q-networks, which is used in EDAC.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            ensemble_num: int = 2,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Initialize the EDAC model according to the input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ). \
                NOTE(review): the annotation also allows ``EasyDict``, but the value is added to \
                ``obs_shape`` below, so dict-style (hybrid) action shapes look unsupported - confirm.
            - ensemble_num (:obj:`int`): Number of Q-networks in the ensemble.
            - actor_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to actor head.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the actor head.
            - critic_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to critic head.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value \
                output for critic head.
            - activation (:obj:`Optional[nn.Module]`): The activation function used in the actor and passed \
                to both heads.
            - norm_type (:obj:`Optional[str]`): The type of normalization passed to both heads, \
                see ``ding.torch_utils.network`` for more details.
            - kwargs: Extra keyword arguments are accepted and ignored.
        """
        super(Q_ensemble, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.ensemble_num = ensemble_num
        self.actor = nn.Sequential(
            nn.Linear(obs_shape, actor_head_hidden_size), activation,
            ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type='conditioned',
                activation=activation,
                norm_type=norm_type
            )
        )

        critic_input_size = obs_shape + action_shape
        # The ``nn.Sequential`` wrapper is kept on purpose: removing it would rename the critic
        # parameters in ``state_dict`` (``critic.0.*`` -> ``critic.*``).
        self.critic = nn.Sequential(
            EnsembleHead(
                critic_input_size,
                1,
                critic_head_hidden_size,
                critic_head_layer_num,
                self.ensemble_num,
                activation=activation,
                norm_type=norm_type
            )
        )

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of the EDAC model. The ``mode`` argument selects which \
            computation graph is run: ``compute_actor`` or ``compute_critic``.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`torch.Tensor`): Observation data.
            Returns:
                - output (:obj:`Dict`): Output dict data, including the ``logit`` key.
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.
        .. note::
            For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            The forward computation graph of compute_actor mode, uses the observation tensor to produce \
            the reparameterization output of the actor.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \
                i.e. ``(B, obs_shape)``.
        Returns:
            - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output dict.
        ReturnsKeys:
            - logit (:obj:`List[torch.Tensor]`): ``[mu, sigma]`` taken from the ``mu`` and ``sigma`` \
                entries of the actor head output.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
            - logit[0] (:obj:`torch.Tensor`): :math:`(B, N1)`, mu, N1 corresponds to ``action_shape``.
            - logit[1] (:obj:`torch.Tensor`): :math:`(B, N1)`, sigma, N1 corresponds to ``action_shape``.
        Examples:
            >>> model = Q_ensemble(64, 64)
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs, 'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
        """
        x = self.actor(obs)
        return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The forward computation graph of compute_critic mode, uses observation and action tensor to \
            produce the Q-value of every ensemble member.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and \
                ``action`` tensor.
        Returns:
            - outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, i.e. ``q_value``.
        ArgumentsKeys:
            - obs: (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
            - action (:obj:`torch.Tensor`): Continuous action with same size as ``action_shape``.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): Q value tensor, one row per ensemble member.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(Ensemble\_num, B, N1)`, where B is batch \
                size and N1 is ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)` or :math:`(Ensemble\_num, B, N2)`, where B is \
                batch size and N2 is ``action_shape``.
            - q_value (:obj:`torch.Tensor`): :math:`(Ensemble\_num, B)`, where B is batch size.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = Q_ensemble(obs_shape=(8, ), action_shape=1)
            >>> q_value = model(inputs, mode='compute_critic')['q_value']
            >>> assert q_value.shape == torch.Size([2, 4])  # (ensemble_num, B)
        """

        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        if len(obs.shape) < 3:
            # Shared input: (B, D) -> (B, ensemble_num * D, 1), every member gets the same features.
            x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
        else:
            # Per-member input: (ensemble_num, B, D) -> (B, ensemble_num, D) -> (B, ensemble_num * D, 1)
            x = x.transpose(0, 1)
            x = x.reshape(obs.shape[1], -1, 1)
        batch_size = x.shape[0]
        # The critic head has ``output_size=1``, so ``pred`` holds ``B * ensemble_num`` values.
        x = self.critic(x)['pred']
        # Restore an explicit (B, ensemble_num) layout before transposing, so the result stays 2-dim
        # even if the head collapsed a size-1 batch or ensemble dimension.
        x = x.reshape(batch_size, self.ensemble_num).permute(1, 0)
        return {'q_value': x}
Loading