feature(zc): add bcq algorithm #640

Merged 11 commits on May 30, 2023
42 changes: 42 additions & 0 deletions ding/example/bcq.py
@@ -0,0 +1,42 @@
import gym
from ditk import logging
from ding.model import BCQ
from ding.policy import BCQPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import create_dataset
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
from ding.utils import set_pkg_seed
from dizoo.d4rl.envs import D4RLEnv
from dizoo.d4rl.config.halfcheetah_medium_bcq_config import main_config, create_config


def main():
# If you don't have offline data, you need to prepare it first and set the data_path in the config.
# For demonstration, you can also train an RL policy (e.g. SAC) first and collect some data with it.
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

dataset = create_dataset(cfg)
model = BCQ(**cfg.policy.model)
policy = BCQPolicy(cfg.policy, model=model)

# Assemble the offline RL pipeline: evaluation, offline data fetching, training, checkpoint saving and logging middleware.
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(offline_data_fetcher(cfg, dataset))
task.use(trainer(cfg, policy.learn_mode))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=10000000))
task.use(offline_logger())
task.run()


if __name__ == "__main__":
main()
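For context, a hypothetical way to point this example at an existing offline dataset before launching it; the config keys and the run command below are assumptions about the surrounding DI-engine setup, not part of this diff:

# Assumed keys in dizoo/d4rl/config/halfcheetah_medium_bcq_config.py (hypothetical):
# main_config.policy.collect.data_type = 'hdf5'
# main_config.policy.collect.data_path = './my_offline_dataset.hdf5'
# Then launch the pipeline:
# python ding/example/bcq.py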
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -23,4 +23,5 @@
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .bcq import BCQ
from .edac import QACEnsemble
132 changes: 132 additions & 0 deletions ding/model/template/bcq.py
@@ -0,0 +1,132 @@
from typing import Union, Dict, Optional, List
from easydict import EasyDict
import numpy as np
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead
from .vae import VanillaVAE


@MODEL_REGISTRY.register('bcq')
class BCQ(nn.Module):

mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval']

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
actor_head_hidden_size: List[int] = [400, 300],
critic_head_hidden_size: List[int] = [400, 300],
activation: Optional[nn.Module] = nn.ReLU(),
vae_hidden_dims: List[int] = [750, 750],
phi: float = 0.05
) -> None:
"""
Overview:
Initialize the BCQ neural network, i.e. the twin Q networks, the perturbation actor and the action VAE.
Arguments:
- obs_shape (:obj:`int`): the dimension of the observation space
- action_shape (:obj:`int`): the dimension of the action space
- actor_head_hidden_size (:obj:`list`): the list of hidden sizes of the actor network
- critic_head_hidden_size (:obj:`list`): the list of hidden sizes of each critic network
- activation (:obj:`nn.Module`): the activation function used in the network, defaults to nn.ReLU().
- vae_hidden_dims (:obj:`list`): the list of hidden sizes of the VAE
- phi (:obj:`float`): the perturbation scale that bounds how far the actor can shift a sampled action, defaults to 0.05
"""
super(BCQ, self).__init__()
Collaborator review comment: add note for arguments.

obs_shape: int = squeeze(obs_shape)
action_shape = squeeze(action_shape)
self.action_shape = action_shape
self.input_size = obs_shape
self.phi = phi

critic_input_size = self.input_size + action_shape
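# Twin Q networks: two independent MLPs, each mapping the concatenated (obs, action) to a scalar Q value.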
self.critic = nn.ModuleList()
for _ in range(2):
net = []
d = critic_input_size
for dim in critic_head_hidden_size:
net.append(nn.Linear(d, dim))
net.append(activation)
d = dim
net.append(nn.Linear(d, 1))
self.critic.append(nn.Sequential(*net))

# Perturbation actor: an MLP over (obs, action) that outputs a small residual to be added to a VAE-sampled action.
net = []
d = critic_input_size
for dim in actor_head_hidden_size:
net.append(nn.Linear(d, dim))
net.append(activation)
d = dim
net.append(nn.Linear(d, action_shape))  # output one perturbation value per action dimension
self.actor = nn.Sequential(*net)

# Conditional VAE that models the dataset's action distribution; it is used to sample candidate actions in compute_eval.
self.vae = VanillaVAE(action_shape, obs_shape, action_shape * 2, vae_hidden_dims)

def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch.Tensor]:
"""
Overview:
The unique execution (forward) method of BCQ, where one can indicate different modes to implement \
different computation graphs, including ``compute_actor``, ``compute_critic``, ``compute_vae`` and ``compute_eval``.
Mode compute_actor:
Arguments:
- inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
Returns:
- output (:obj:`Dict`): Output dict data, including action tensor.
Mode compute_critic:
Arguments:
- inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
Returns:
- output (:obj:`Dict`): Output dict data, including the q_value output (one tensor per critic).
Mode compute_vae:
Arguments:
- inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
Returns:
- outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
(:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
Mode compute_eval:
Arguments:
- inputs (:obj:`Dict`): Input dict data, including obs tensor.
Returns:
- output (:obj:`Dict`): Output dict data, including action tensor.


.. note::
For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
"""
assert mode in self.mode, "unsupported forward mode: {} (supported: {})".format(mode, self.mode)
return getattr(self, mode)(inputs)

def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
obs, action = inputs['obs'], inputs['action']
if len(action.shape) == 1: # (B, ) -> (B, 1)
action = action.unsqueeze(1)
x = torch.cat([obs, action], dim=-1)
x = [m(x).squeeze() for m in self.critic]
return {'q_value': x}

def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
x = torch.cat([inputs['obs'], inputs['action']], -1)
# phi bounds the perturbation; actions are assumed to be normalized to [-1, 1].
perturbation = self.phi * torch.tanh(self.actor(x))
action = (perturbation + inputs['action']).clamp(-1, 1)
return {'action': action}

def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.vae.forward(inputs)

def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
obs = inputs['obs']
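# Repeat each observation 100 times so that 100 candidate actions can be sampled and scored per state.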
obs_rep = obs.clone().unsqueeze(0).repeat_interleave(100, dim=0)
z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5)
sample_action = self.vae.decode_with_obs(z, obs_rep)['reconstruction_action']
action = self.compute_actor({'obs': obs_rep, 'action': sample_action})['action']
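# Score every candidate with the first Q network and keep, for each state, the candidate with the highest Q value.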
q = self.compute_critic({'obs': obs_rep, 'action': action})['q_value'][0]
idx = q.argmax(dim=0).unsqueeze(0).unsqueeze(-1)
idx = idx.repeat_interleave(action.shape[-1], dim=-1)
action = action.gather(0, idx).squeeze()
return {'action': action}
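As a complement to the forward docstring above, a minimal usage sketch of the forward modes; the observation and action dimensions below are illustrative assumptions, not values taken from this PR:

import torch
from ding.model import BCQ

# Hypothetical shapes: 17-dim observations, 6-dim continuous actions (roughly HalfCheetah-sized).
model = BCQ(obs_shape=17, action_shape=6)
obs = torch.randn(4, 17)
action = torch.rand(4, 6) * 2 - 1

# Twin Q-values: a list with one tensor per critic, each of shape (4,).
q1, q2 = model({'obs': obs, 'action': action}, mode='compute_critic')['q_value']

# Perturbed action: the input action shifted by at most phi and clamped to [-1, 1].
perturbed = model({'obs': obs, 'action': action}, mode='compute_actor')['action']

# VAE outputs used by the generative (behavior-cloning) loss.
vae_out = model({'obs': obs, 'action': action}, mode='compute_vae')

# Evaluation action: the best of 100 VAE-sampled, perturbed candidates per state, shape (4, 6).
eval_action = model({'obs': obs}, mode='compute_eval')['action']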
2 changes: 2 additions & 0 deletions ding/policy/__init__.py
@@ -49,5 +49,7 @@

from .pc import ProcedureCloningBFSPolicy

from .bcq import BCQPolicy

# new-type policy
from .ppof import PPOFPolicy