From 2b72b8b3f8e805aab6db0a312414f703c3631947 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Tue, 16 Jul 2024 22:14:12 +0800 Subject: [PATCH 1/5] Add IQL algo --- ding/entry/serial_entry_offline.py | 1 + ding/model/template/qvac.py | 596 ++++++++++++ ding/policy/command_mode_policy_instance.py | 4 + ding/policy/iql.py | 848 ++++++++++++++++++ ding/utils/data/dataset.py | 32 + .../config/halfcheetah_medium_iql_config.py | 53 ++ dizoo/d4rl/entry/d4rl_iql_main.py | 21 + 7 files changed, 1555 insertions(+) create mode 100644 ding/model/template/qvac.py create mode 100644 ding/policy/iql.py create mode 100644 dizoo/d4rl/config/halfcheetah_medium_iql_config.py create mode 100644 dizoo/d4rl/entry/d4rl_iql_main.py diff --git a/ding/entry/serial_entry_offline.py b/ding/entry/serial_entry_offline.py index b92b5c7dda..f24e5404fb 100755 --- a/ding/entry/serial_entry_offline.py +++ b/ding/entry/serial_entry_offline.py @@ -62,6 +62,7 @@ def serial_pipeline_offline( sampler=sampler, collate_fn=lambda x: x, pin_memory=cfg.policy.cuda, + drop_last=True, ) # Env, Policy try: diff --git a/ding/model/template/qvac.py b/ding/model/template/qvac.py new file mode 100644 index 0000000000..9a84558813 --- /dev/null +++ b/ding/model/template/qvac.py @@ -0,0 +1,596 @@ +from typing import Union, Dict, Optional +from easydict import EasyDict +import numpy as np +import torch +import torch.nn as nn + +from ding.utils import SequenceType, squeeze, MODEL_REGISTRY +from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \ + FCEncoder, ConvEncoder + + +@MODEL_REGISTRY.register('continuous_qvac') +class ContinuousQVAC(nn.Module): + """ + Overview: + The neural network and computation graph of algorithms related to Actor-Critic that have both Q-value and V-value critic, such as \ + IQL. This model now supports continuous and hybrid action space. The ContinuousQVAC is composed of \ + four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \ + extract the feature from various observation. Heads are used to predict corresponding Q-value and V-value or action logit. \ + In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \ + and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders. + Interfaces: + ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` + """ + mode = ['compute_actor', 'compute_critic'] + + def __init__( + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType, EasyDict], + action_space: str, + twin_critic: bool = False, + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 1, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + encoder_hidden_size_list: Optional[SequenceType] = None, + share_encoder: Optional[bool] = False, + ) -> None: + """ + Overview: + Initailize the ContinuousQVAC Model according to input arguments. + Arguments: + - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ). + - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \ + EasyDict({'action_type_shape': 3, 'action_args_shape': 4}). + - action_space (:obj:`str`): The type of action space, including [``regression``, ``reparameterization``, \ + ``hybrid``], ``regression`` is used for DDPG/TD3, ``reparameterization`` is used for SAC and \ + ``hybrid`` for PADDPG. + - twin_critic (:obj:`bool`): Whether to use twin critic, one of tricks in TD3. + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head. + - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute action. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head. + - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute Q-value. + - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \ + after each FC layer, if ``None`` then default set to ``nn.ReLU()``. + - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \ + see ``ding.torch_utils.network`` for more details. + - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ + the last element must match ``head_hidden_size``, this argument is only used in image observation. + - share_encoder (:obj:`Optional[bool]`): Whether to share encoder between actor and critic. + """ + super(ContinuousQVAC, self).__init__() + obs_shape: int = squeeze(obs_shape) + action_shape = squeeze(action_shape) + self.action_shape = action_shape + self.action_space = action_space + assert self.action_space in ['regression', 'reparameterization', 'hybrid'], self.action_space + + # encoder + self.share_encoder = share_encoder + if np.isscalar(obs_shape) or len(obs_shape) == 1: + assert not self.share_encoder, "Vector observation doesn't need share encoder." + assert encoder_hidden_size_list is None, "Vector obs encoder only uses one layer nn.Linear" + # Because there is already a layer nn.Linear in the head, so we use nn.Identity here to keep + # compatible with the image observation and avoid adding an extra layer nn.Linear. + self.actor_encoder = nn.Identity() + self.critic_encoder = nn.Identity() + encoder_output_size = obs_shape + elif len(obs_shape) == 3: + + def setup_conv_encoder(): + kernel_size = [3 for _ in range(len(encoder_hidden_size_list))] + stride = [2] + [1 for _ in range(len(encoder_hidden_size_list) - 1)] + return ConvEncoder( + obs_shape, + encoder_hidden_size_list, + activation=activation, + norm_type=norm_type, + kernel_size=kernel_size, + stride=stride + ) + + if self.share_encoder: + encoder = setup_conv_encoder() + self.actor_encoder = self.critic_encoder = encoder + else: + self.actor_encoder = setup_conv_encoder() + self.critic_encoder = setup_conv_encoder() + encoder_output_size = self.actor_encoder.output_size + else: + raise RuntimeError("not support observation shape: {}".format(obs_shape)) + # head + if self.action_space == 'regression': # DDPG, TD3 + self.actor_head = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + RegressionHead( + actor_head_hidden_size, + action_shape, + actor_head_layer_num, + final_tanh=True, + activation=activation, + norm_type=norm_type + ) + ) + elif self.action_space == 'reparameterization': # SAC + self.actor_head = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + ReparameterizationHead( + actor_head_hidden_size, + action_shape, + actor_head_layer_num, + sigma_type='conditioned', + activation=activation, + norm_type=norm_type + ) + ) + elif self.action_space == 'hybrid': # PADDPG + # hybrid action space: action_type(discrete) + action_args(continuous), + # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])} + action_shape.action_args_shape = squeeze(action_shape.action_args_shape) + action_shape.action_type_shape = squeeze(action_shape.action_type_shape) + actor_action_args = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + RegressionHead( + actor_head_hidden_size, + action_shape.action_args_shape, + actor_head_layer_num, + final_tanh=True, + activation=activation, + norm_type=norm_type + ) + ) + actor_action_type = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + DiscreteHead( + actor_head_hidden_size, + action_shape.action_type_shape, + actor_head_layer_num, + activation=activation, + norm_type=norm_type, + ) + ) + self.actor_head = nn.ModuleList([actor_action_type, actor_action_args]) + + self.twin_critic = twin_critic + if self.action_space == 'hybrid': + critic_q_input_size = encoder_output_size + action_shape.action_type_shape + action_shape.action_args_shape + critic_v_input_size = encoder_output_size + else: + critic_q_input_size = encoder_output_size + action_shape + critic_v_input_size = encoder_output_size + if self.twin_critic: + self.critic_q_head = nn.ModuleList() + self.critic_v_head = nn.ModuleList() + for _ in range(2): + self.critic_q_head.append( + nn.Sequential( + nn.Linear(critic_q_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + ) + self.critic_v_head.append( + nn.Sequential( + nn.Linear(critic_v_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + ) + else: + self.critic_q_head = nn.Sequential( + nn.Linear(critic_q_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + self.critic_v_head = nn.Sequential( + nn.Linear(critic_v_input_size, critic_head_hidden_size), activation, + RegressionHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + final_tanh=False, + activation=activation, + norm_type=norm_type + ) + ) + + # Convenient for calling some apis (e.g. self.critic.parameters()), + # but may cause misunderstanding when `print(self)` + self.actor = nn.ModuleList([self.actor_encoder, self.actor_head]) + self.critic = nn.ModuleList([self.critic_encoder, self.critic_q_head, self.critic_v_head]) + + def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]: + """ + Overview: + QVAC forward computation graph, input observation tensor to predict Q-value or action logit. Different \ + ``mode`` will forward with different network modules to get different outputs and save computation. + Arguments: + - inputs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The input data for forward computation \ + graph, for ``compute_actor``, it is the observation tensor, for ``compute_critic``, it is the \ + dict data including obs and action tensor. + - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class. + Returns: + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph, whose \ + key-values vary in different forward modes. + Examples (Actor): + >>> # Regression mode + >>> model = ContinuousQVAC(64, 6, 'regression') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['action'].shape == torch.Size([4, 6]) + >>> # Reparameterization Mode + >>> model = ContinuousQVAC(64, 6, 'reparameterization') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6]) # mu + >>> actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma + + Examples (Critic): + >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)} + >>> model = ContinuousQVAC(obs_shape=(8, ),action_shape=1, action_space='regression') + >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, ) # q value + """ + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(inputs) + + def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + """ + Overview: + QVAC forward computation graph for actor part, input observation tensor to predict action or action logit. + Arguments: + - x (:obj:`torch.Tensor`): The input observation tensor data. + Returns: + - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output dict varying \ + from action_space: ``regression``, ``reparameterization``, ``hybrid``. + ReturnsKeys (regression): + - action (:obj:`torch.Tensor`): Continuous action with same size as ``action_shape``, usually in DDPG/TD3. + ReturnsKeys (reparameterization): + - logit (:obj:`Dict[str, torch.Tensor]`): The predictd reparameterization action logit, usually in SAC. \ + It is a list containing two tensors: ``mu`` and ``sigma``. The former is the mean of the gaussian \ + distribution, the latter is the standard deviation of the gaussian distribution. + ReturnsKeys (hybrid): + - logit (:obj:`torch.Tensor`): The predicted discrete action type logit, it will be the same dimension \ + as ``action_type_shape``, i.e., all the possible discrete action types. + - action_args (:obj:`torch.Tensor`): Continuous action arguments with same size as ``action_args_shape``. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``. + - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``. + - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``. + - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size. + - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \ + ``action_shape.action_type_shape``. + - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \ + ``action_shape.action_args_shape``. + Examples: + >>> # Regression mode + >>> model = ContinuousQVAC(64, 6, 'regression') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['action'].shape == torch.Size([4, 6]) + >>> # Reparameterization Mode + >>> model = ContinuousQVAC(64, 6, 'reparameterization') + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6]) # mu + >>> actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma + """ + obs = self.actor_encoder(obs) + if self.action_space == 'regression': + x = self.actor_head(obs) + return {'action': x['pred']} + elif self.action_space == 'reparameterization': + x = self.actor_head(obs) + return {'logit': [x['mu'], x['sigma']]} + elif self.action_space == 'hybrid': + logit = self.actor_head[0](obs) + action_args = self.actor_head[1](obs) + return {'logit': logit['logit'], 'action_args': action_args['pred']} + + def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Overview: + QVAC forward computation graph for critic part, input observation and action tensor to predict Q-value. + Arguments: + - inputs (:obj:`Dict[str, torch.Tensor]`): The dict of input data, including ``obs`` and ``action`` \ + tensor, also contains ``logit`` and ``action_args`` tensor in hybrid action_space. + ArgumentsKeys: + - obs: (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data. + - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``. + - logit (:obj:`torch.Tensor`): Discrete action logit, only in hybrid action_space. + - action_args (:obj:`torch.Tensor`): Continuous action arguments, only in hybrid action_space. + Returns: + - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC's forward computation graph for critic, \ + including ``q_value``. + ReturnKeys: + - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``. + - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \ + ``action_shape.action_type_shape``. + - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \ + ``action_shape.action_args_shape``. + - action (:obj:`torch.Tensor`): :math:`(B, N4)`, where B is batch size and N4 is ``action_shape``. + - q_value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size. + + Examples: + >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)} + >>> model = ContinuousQVAC(obs_shape=(8, ),action_shape=1, action_space='regression') + >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, ) # q value + """ + + obs, action = inputs['obs'], inputs['action'] + obs = self.critic_encoder(obs) + assert len(obs.shape) == 2 + if self.action_space == 'hybrid': + action_type_logit = inputs['logit'] + action_type_logit = torch.softmax(action_type_logit, dim=-1) + action_args = action['action_args'] + if len(action_args.shape) == 1: + action_args = action_args.unsqueeze(1) + x = torch.cat([obs, action_type_logit, action_args], dim=1) + else: + if len(action.shape) == 1: # (B, ) -> (B, 1) + action = action.unsqueeze(1) + x = torch.cat([obs, action], dim=1) + if self.twin_critic: + x = [m(x)['pred'] for m in self.critic_q_head] + y = [m(obs)['pred'] for m in self.critic_v_head] + else: + x = self.critic_q_head(x)['pred'] + y = self.critic_v_head(obs)['pred'] + return {'q_value': x, 'v_value': y} + + +@MODEL_REGISTRY.register('discrete_qvac') +class DiscreteQVAC(nn.Module): + """ + Overview: + The neural network and computation graph of algorithms related to discrete action Actor-Critic that have both Q-value and V-value critic, \ + such as Discrete action IQL. This model now supports only discrete action space. The DiscreteQVAC is composed of \ + four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \ + extract the feature from various observation. Heads are used to predict corresponding Q-value or action logit. \ + In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \ + and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders. + Interfaces: + ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` + """ + mode = ['compute_actor', 'compute_critic'] + + def __init__( + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType], + twin_critic: bool = False, + actor_head_hidden_size: int = 64, + actor_head_layer_num: int = 1, + critic_head_hidden_size: int = 64, + critic_head_layer_num: int = 1, + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + encoder_hidden_size_list: SequenceType = None, + share_encoder: Optional[bool] = False, + ) -> None: + """ + Overview: + Initailize the DiscreteQVAC Model according to input arguments. + Arguments: + - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ). + - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ). + - twin_critic (:obj:`bool`): Whether to use twin critic. + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head. + - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute action. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head. + - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute Q-value. + - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \ + after each FC layer, if ``None`` then default set to ``nn.ReLU()``. + - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \ + see ``ding.torch_utils.network`` for more details. + - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ + the last element must match ``head_hidden_size``, this argument is only used in image observation. + - share_encoder (:obj:`Optional[bool]`): Whether to share encoder between actor and critic. + """ + super(DiscreteQVAC, self).__init__() + obs_shape: int = squeeze(obs_shape) + action_shape: int = squeeze(action_shape) + # encoder + self.share_encoder = share_encoder + if np.isscalar(obs_shape) or len(obs_shape) == 1: + assert not self.share_encoder, "Vector observation doesn't need share encoder." + assert encoder_hidden_size_list is None, "Vector obs encoder only uses one layer nn.Linear" + # Because there is already a layer nn.Linear in the head, so we use nn.Identity here to keep + # compatible with the image observation and avoid adding an extra layer nn.Linear. + self.actor_encoder = nn.Identity() + self.critic_encoder = nn.Identity() + encoder_output_size = obs_shape + elif len(obs_shape) == 3: + + def setup_conv_encoder(): + kernel_size = [3 for _ in range(len(encoder_hidden_size_list))] + stride = [2] + [1 for _ in range(len(encoder_hidden_size_list) - 1)] + return ConvEncoder( + obs_shape, + encoder_hidden_size_list, + activation=activation, + norm_type=norm_type, + kernel_size=kernel_size, + stride=stride + ) + + if self.share_encoder: + encoder = setup_conv_encoder() + self.actor_encoder = self.critic_encoder = encoder + else: + self.actor_encoder = setup_conv_encoder() + self.critic_encoder = setup_conv_encoder() + encoder_output_size = self.actor_encoder.output_size + else: + raise RuntimeError("not support observation shape: {}".format(obs_shape)) + + # head + self.actor_head = nn.Sequential( + nn.Linear(encoder_output_size, actor_head_hidden_size), activation, + DiscreteHead( + actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type + ) + ) + + self.twin_critic = twin_critic + if self.twin_critic: + self.critic_q_head = nn.ModuleList() + self.critic_v_head = nn.ModuleList() + for _ in range(2): + self.critic_q_head.append( + nn.Sequential( + nn.Linear(encoder_output_size, critic_head_hidden_size), activation, + DiscreteHead( + critic_head_hidden_size, + action_shape, + critic_head_layer_num, + activation=activation, + norm_type=norm_type + ) + ) + ) + self.critic_v_head.append( + nn.Sequential( + nn.Linear(encoder_output_size, critic_head_hidden_size), activation, + DiscreteHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + activation=activation, + norm_type=norm_type + ) + ) + ) + + else: + self.critic_q_head = nn.Sequential( + nn.Linear(encoder_output_size, critic_head_hidden_size), activation, + DiscreteHead( + critic_head_hidden_size, + action_shape, + critic_head_layer_num, + activation=activation, + norm_type=norm_type + ) + ) + self.critic_v_head = nn.Sequential( + nn.Linear(encoder_output_size, critic_head_hidden_size), activation, + DiscreteHead( + critic_head_hidden_size, + 1, + critic_head_layer_num, + activation=activation, + norm_type=norm_type + ) + ) + # Convenient for calling some apis (e.g. self.critic.parameters()), + # but may cause misunderstanding when `print(self)` + self.actor = nn.ModuleList([self.actor_encoder, self.actor_head]) + self.critic = nn.ModuleList([self.critic_encoder, self.critic_q_head, self.critic_v_head]) + + def forward(self, inputs: torch.Tensor, mode: str) -> Dict[str, torch.Tensor]: + """ + Overview: + QVAC forward computation graph, input observation tensor to predict Q-value or action logit. Different \ + ``mode`` will forward with different network modules to get different outputs and save computation. + Arguments: + - inputs (:obj:`torch.Tensor`): The input observation tensor data. + - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class. + Returns: + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph, whose \ + key-values vary in different forward modes. + Examples (Actor): + >>> model = DiscreteQVAC(64, 6) + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['logit'].shape == torch.Size([4, 6]) + + Examples(Critic): + >>> model = DiscreteQVAC(64, 6, twin_critic=False) + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_critic') + >>> assert actor_outputs['q_value'].shape == torch.Size([4, 6]) + """ + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(inputs) + + def compute_actor(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Overview: + QVAC forward computation graph for actor part, input observation tensor to predict action or action logit. + Arguments: + - inputs (:obj:`torch.Tensor`): The input observation tensor data. + Returns: + - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph for actor, \ + including discrete action ``logit``. + ReturnsKeys: + - logit (:obj:`torch.Tensor`): The predicted discrete action type logit, it will be the same dimension \ + as ``action_shape``, i.e., all the possible discrete action choices. + Shapes: + - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``. + - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \ + ``action_shape``. + Examples: + >>> model = DiscreteQVAC(64, 6) + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_actor') + >>> assert actor_outputs['logit'].shape == torch.Size([4, 6]) + """ + x = self.actor_encoder(inputs) + x = self.actor_head(x) + return {'logit': x['logit']} + + def compute_critic(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Overview: + QVAC forward computation graph for critic part, input observation to predict Q-value for each possible \ + discrete action choices. + Arguments: + - inputs (:obj:`torch.Tensor`): The input observation tensor data. + Returns: + - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph for critic, \ + including ``q_value`` for each possible discrete action choices. + ReturnKeys: + - q_value (:obj:`torch.Tensor`): The predicted Q-value for each possible discrete action choices, it will \ + be the same dimension as ``action_shape`` and used to calculate the loss. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``. + - q_value (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``. + Examples: + >>> model = DiscreteQVAC(64, 6, twin_critic=False) + >>> obs = torch.randn(4, 64) + >>> actor_outputs = model(obs,'compute_critic') + >>> assert actor_outputs['q_value'].shape == torch.Size([4, 6]) + """ + inputs = self.critic_encoder(inputs) + if self.twin_critic: + x = [m(inputs)['logit'] for m in self.critic_q_head] + y = [m(inputs)['logit'] for m in self.critic_v_head] + else: + x = self.critic_q_head(inputs)['logit'] + y = self.critic_v_head(inputs)['logit'] + return {'q_value': x, 'v_value': y} diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 2e817ead4b..f779cdc316 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -43,6 +43,7 @@ from .d4pg import D4PGPolicy from .cql import CQLPolicy, DiscreteCQLPolicy +from .iql import IQLPolicy from .dt import DTPolicy from .pdqn import PDQNPolicy from .madqn import MADQNPolicy @@ -320,6 +321,9 @@ class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy): class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy): pass +@POLICY_REGISTRY.register('iql_command') +class CQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy): + pass @POLICY_REGISTRY.register('discrete_cql_command') class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy): diff --git a/ding/policy/iql.py b/ding/policy/iql.py new file mode 100644 index 0000000000..2fd5fc4772 --- /dev/null +++ b/ding/policy/iql.py @@ -0,0 +1,848 @@ +from typing import List, Dict, Any, Tuple, Union +import copy +from collections import namedtuple +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions import Normal, Independent, TransformedDistribution +from torch.distributions.transforms import TanhTransform + +from ding.torch_utils import Adam, to_device +from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ + qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_collate, default_decollate +from .base_policy import Policy +from .common_utils import default_preprocess_learn + + +def asymmetric_l2_loss(u, tau): + return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) + + +@POLICY_REGISTRY.register('iql') +class IQLPolicy(Policy): + """ + Overview: + Policy class of Implicit Q-Learning (IQL) algorithm for continuous control. Paper link: https://arxiv.org/abs/2110.06169. + + Config: + == ==================== ======== ============= ================================= ======================= + ID Symbol Type Default Value Description Other(Shape) + == ==================== ======== ============= ================================= ======================= + 1 ``type`` str cql | RL policy register name, refer | this arg is optional, + | to registry ``POLICY_REGISTRY`` | a placeholder + 2 ``cuda`` bool True | Whether to use cuda for network | + 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for + | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ + | | buffer when training starts. | TD3. + 4 | ``model.policy_`` int 256 | Linear layer size for policy | + | ``embedding_size`` | network. | + 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q | + | ``embedding_size`` | network. | + 6 | ``model.value_`` int 256 | Linear layer size for value | Defalut to None when + | ``embedding_size`` | network. | model.value_network + | | | is False. + 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3, when + | ``_rate_q`` | network. | model.value_network + | | | is True. + 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3, when + | ``_rate_policy`` | network. | model.value_network + | | | is True. + 9 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to None when + | ``_rate_value`` | network. | model.value_network + | | | is False. + 10 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- + | | coefficient. | zation for auto + | | | `alpha`, when + | | | auto_alpha is True + 11 | ``learn.repara_`` bool True | Determine whether to use | + | ``meterization`` | reparameterization trick. | + 12 | ``learn.`` bool False | Determine whether to use | Temperature parameter + | ``auto_alpha`` | auto temperature parameter | determines the + | | `alpha`. | relative importance + | | | of the entropy term + | | | against the reward. + 13 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only + | ``ignore_done`` | done flag. | in halfcheetah env. + 14 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation + | ``target_theta`` | target network. | factor in polyak aver + | | | aging for target + | | | networks. + == ==================== ======== ============= ================================= ======================= + """ + + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='iql', + # (bool) Whether to use cuda for policy. + cuda=False, + # (bool) on_policy: Determine whether on-policy or off-policy. + # on-policy setting influences the behaviour of buffer. + on_policy=False, + # (bool) priority: Determine whether to use priority in buffer sample. + priority=False, + # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + # (int) Number of training samples(randomly collected) in replay buffer when training starts. + random_collect_size=10000, + model=dict( + # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation. + # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one . + # Default to True. + twin_critic=True, + # (str type) action_space: Use reparameterization trick for continous action + action_space='reparameterization', + # (int) Hidden size for actor network head. + actor_head_hidden_size=256, + # (int) Hidden size for critic network head. + critic_head_hidden_size=256, + ), + # learn_mode config + learn=dict( + # (int) How many updates (iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + update_per_collect=1, + # (int) Minibatch size for gradient descent. + batch_size=256, + # (float) learning_rate_q: Learning rate for soft q network. + learning_rate_q=3e-4, + # (float) learning_rate_policy: Learning rate for policy network. + learning_rate_policy=3e-4, + # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``. + learning_rate_alpha=3e-4, + # (float) target_theta: Used for soft update of the target network, + # aka. Interpolation factor in polyak averaging for target networks. + target_theta=0.005, + # (float) discount factor for the discounted sum of rewards, aka. gamma. + discount_factor=0.99, + # (float) alpha: Entropy regularization coefficient. + # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. + # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`. + # Default to 0.2. + alpha=0.2, + # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha` . + # Temperature parameter determines the relative importance of the entropy term against the reward. + # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. + # Default to False. + # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`. + auto_alpha=True, + # (bool) log_space: Determine whether to use auto `\alpha` in log space. + log_space=True, + # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. + ignore_done=False, + # (float) Weight uniform initialization range in the last output layer. + init_w=3e-3, + # (int) The numbers of action sample each at every state s from a uniform-at-random. + num_actions=10, + # (bool) Whether use lagrange multiplier in q value loss. + with_lagrange=False, + # (float) The threshold for difference in Q-values. + lagrange_thresh=-1, + # (float) Loss weight for conservative item. + min_q_weight=1.0, + # (float) coefficient for the asymmetric loss, range from [0.5, 1.0], default to 0.70. + tau=0.7, + # (float) temperature coefficient for Advantage Weighted Regression loss, default to 1.0. + beta=1.0, + ), + eval=dict(), # for compatibility + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ + + return 'continuous_qvac', ['ding.model.template.qvac'] + + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange, \ + main and target model. Especially, the ``auto_alpha`` mechanism for balancing max entropy \ + target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + self._twin_critic = self._cfg.model.twin_critic + self._num_actions = self._cfg.learn.num_actions + + self._min_q_version = 3 + self._min_q_weight = self._cfg.learn.min_q_weight + self._with_lagrange = self._cfg.learn.with_lagrange and (self._lagrange_thresh > 0) + self._lagrange_thresh = self._cfg.learn.lagrange_thresh + if self._with_lagrange: + self.target_action_gap = self._lagrange_thresh + self.log_alpha_prime = torch.tensor(0.).to(self._device).requires_grad_() + self.alpha_prime_optimizer = Adam( + [self.log_alpha_prime], + lr=self._cfg.learn.learning_rate_q, + ) + + # Weight Init + init_w = self._cfg.learn.init_w + self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w) + self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w) + self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w) + self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w) + if self._twin_critic: + self._model.critic_q_head[0][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_q_head[0][-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_q_head[1][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_q_head[1][-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_v_head[0][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_v_head[0][-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_v_head[1][-1].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_v_head[1][-1].last.bias.data.uniform_(-init_w, init_w) + else: + self._model.critic_q_head[2].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_q_head[-1].last.bias.data.uniform_(-init_w, init_w) + self._model.critic_v_head[2].last.weight.data.uniform_(-init_w, init_w) + self._model.critic_v_head[-1].last.bias.data.uniform_(-init_w, init_w) + + # Optimizers + self._optimizer_q = Adam( + self._model.critic.parameters(), + lr=self._cfg.learn.learning_rate_q, + ) + self._optimizer_policy = Adam( + self._model.actor.parameters(), + lr=self._cfg.learn.learning_rate_policy, + ) + + # Algorithm config + self._gamma = self._cfg.learn.discount_factor + + self._learn_model = model_wrap(self._model, wrapper_name='base') + self._learn_model.reset() + + self._forward_learn_cnt = 0 + + self._tau = self._cfg.learn.tau + self._beta = self._cfg.learn.beta + self._policy_start_training_counter=300000 + + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the offline dataset and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ + loss_dict = {} + data = default_preprocess_learn( + data, + use_priority=self._priority, + use_priority_IS_weight=self._cfg.priority_IS_weight, + ignore_done=self._cfg.learn.ignore_done, + use_nstep=False + ) + if len(data.get('action').shape) == 1: + data['action'] = data['action'].reshape(-1, 1) + + if self._cuda: + data = to_device(data, self._device) + + self._learn_model.train() + obs = data['obs'] + next_obs = data['next_obs'] + reward = data['reward'] + done = data['done'] + + # 1. predict q and v value + value = self._learn_model.forward(data, mode='compute_critic') + q_value, v_value= value['q_value'], value['v_value'] + + # 2. predict target value + with torch.no_grad(): + (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit'] + + # dist = Independent(Normal(mu, sigma), 1) + # pred = dist.rsample() + # next_action = torch.tanh(pred) + # y = 1 - next_action.pow(2) + 1e-6 + # next_log_prob = dist.log_prob(pred).unsqueeze(-1) + # next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True) + + next_obs_dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1)]) + next_action = next_obs_dist.rsample() + next_log_prob = next_obs_dist.log_prob(next_action) + + next_data = {'obs': next_obs, 'action': next_action} + next_value = self._learn_model.forward(next_data, mode='compute_critic') + next_q_value, next_v_value = next_value['q_value'], next_value['v_value'] + + # the value of a policy according to the maximum entropy objective + if self._twin_critic: + next_q_value = torch.min(next_q_value[0], next_q_value[1]) + + # 3. compute v loss + if self._twin_critic: + q_value_min = torch.min(q_value[0], q_value[1]).detach() + v_loss_0 = asymmetric_l2_loss(q_value_min - v_value[0], self._tau) + v_loss_1 = asymmetric_l2_loss(q_value_min - v_value[1], self._tau) + v_loss = (v_loss_0 + v_loss_1) / 2 + else: + advantage = q_value.detach() - v_value + v_loss = asymmetric_l2_loss(advantage, self._tau) + + # 4. compute q loss + if self._twin_critic: + next_v_value = torch.min(next_v_value[0], next_v_value[1]) + q_data0 = v_1step_td_data(q_value[0], next_v_value, reward, done, data['weight']) + loss_dict['critic_q_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma) + q_data1 = v_1step_td_data(q_value[1], next_v_value, reward, done, data['weight']) + loss_dict['twin_critic_q_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma) + q_loss = (loss_dict['critic_q_loss'] + loss_dict['twin_critic_q_loss']) / 2 + else: + q_data = v_1step_td_data(q_value, next_v_value, reward, done, data['weight']) + loss_dict['critic_q_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma) + q_loss = loss_dict['critic_q_loss'] + + # 5. update q and v network + self._optimizer_q.zero_grad() + v_loss.backward() + q_loss.backward() + self._optimizer_q.step() + + # 6. evaluate to get action distribution + (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit'] + # dist = Independent(Normal(mu, sigma), 1) + # pred = dist.rsample() + # action = torch.tanh(pred) + # y = 1 - action.pow(2) + 1e-6 + # log_prob = dist.log_prob(pred).unsqueeze(-1) + # log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) + + dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1)]) + action = dist.rsample() + log_prob = dist.log_prob(action) + + eval_data = {'obs': obs, 'action': action} + new_value = self._learn_model.forward(eval_data, mode='compute_critic') + new_q_value, new_v_value = new_value['q_value'], new_value['v_value'] + if self._twin_critic: + new_q_value = torch.min(new_q_value[0], new_q_value[1]) + new_v_value = torch.min(new_v_value[0], new_v_value[1]) + new_advantage = new_q_value - new_v_value + + # 8. compute policy loss + policy_loss = (- self._beta * log_prob * torch.exp(new_advantage.detach()).clamp(max=1000.0)).mean() + self._policy_start_training_counter -= 1 + if self._policy_start_training_counter > 0: + policy_loss = policy_loss * 0.0 + + loss_dict['policy_loss'] = policy_loss + + # 9. update policy network + self._optimizer_policy.zero_grad() + loss_dict['policy_loss'].backward() + self._optimizer_policy.step() + + loss_dict['total_loss'] = sum(loss_dict.values()) + + # ============= + # after update + # ============= + self._forward_learn_cnt += 1 + + return { + 'cur_lr_q': self._optimizer_q.defaults['lr'], + 'cur_lr_p': self._optimizer_policy.defaults['lr'], + 'priority': q_loss.abs().tolist(), + 'q_loss': q_loss.detach().mean().item(), + 'v_loss': v_loss.detach().mean().item(), + 'log_prob': log_prob.detach().mean().item(), + 'next_q_value': next_q_value.detach().mean().item(), + 'next_v_value': next_v_value.detach().mean().item(), + 'policy_loss': policy_loss.detach().mean().item(), + 'total_loss': loss_dict['total_loss'].detach().item(), + 'advantage_max': new_advantage.max().detach().item(), + 'new_q_value': new_q_value.detach().mean().item(), + 'new_v_value': new_v_value.detach().mean().item(), + } + + def _get_policy_actions(self, data: Dict, num_actions: int = 10, epsilon: float = 1e-6) -> List: + # evaluate to get action distribution + obs = data['obs'] + obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1]) + (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit'] + dist = Independent(Normal(mu, sigma), 1) + pred = dist.rsample() + action = torch.tanh(pred) + + # evaluate action log prob depending on Jacobi determinant. + y = 1 - action.pow(2) + epsilon + log_prob = dist.log_prob(pred).unsqueeze(-1) + log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) + + return action, log_prob.view(-1, num_actions, 1) + + def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor: + new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value'] + if self._twin_critic: + new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value] + else: + new_q_value = new_q_value.view(-1, self._num_actions, 1) + if self._twin_critic and not keep: + new_q_value = torch.min(new_q_value[0], new_q_value[1]) + return new_q_value + + def _get_v_value(self, data: Dict, keep: bool = True) -> torch.Tensor: + new_v_value = self._learn_model.forward(data, mode='compute_critic')['v_value'] + if self._twin_critic: + new_v_value = [value.view(-1, self._num_actions, 1) for value in new_v_value] + else: + new_v_value = new_v_value.view(-1, self._num_actions, 1) + if self._twin_critic and not keep: + new_v_value = torch.min(new_v_value[0], new_v_value[1]) + return new_v_value + + def _init_collect(self) -> None: + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \ + collect_model other algorithm-specific arguments such as unroll_len. \ + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + """ + self._unroll_len = self._cfg.collect.unroll_len + self._collect_model = model_wrap(self._model, wrapper_name='base') + self._collect_model.reset() + + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + ``logit`` in SAC means the mu and sigma of Gaussioan distribution. Here we use this name for consistency. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._collect_model.eval() + with torch.no_grad(): + (mu, sigma) = self._collect_model.forward(data, mode='compute_actor')['logit'] + dist = Independent(Normal(mu, sigma), 1) + action = torch.tanh(dist.rsample()) + output = {'logit': (mu, sigma), 'action': action} + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ + Overview: + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For continuous SAC, it contains obs, next_obs, action, reward, done. The logit \ + will be also added when ``collector_logit`` is True. + Arguments: + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For continuous SAC, it contains the action and the logit (mu and sigma) of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + """ + if self._cfg.collect.collector_logit: + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'logit': policy_output['logit'], + 'action': policy_output['action'], + 'reward': timestep.reward, + 'done': timestep.done, + } + else: + transition = { + 'obs': obs, + 'next_obs': timestep.obs, + 'action': policy_output['action'], + 'reward': timestep.reward, + 'done': timestep.done, + } + return transition + + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In continuous SAC, a train sample is a processed transition \ + (unroll_len=1). + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training. + """ + return get_train_sample(transitions, self._unroll_len) + + def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For SAC, it contains the \ + eval model, which is equipped with ``base`` model wrapper to ensure compability. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ + self._eval_model = model_wrap(self._model, wrapper_name='base') + self._eval_model.reset() + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + ``logit`` in SAC means the mu and sigma of Gaussioan distribution. Here we use this name for consistency. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._eval_model.eval() + with torch.no_grad(): + (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit'] + action = torch.tanh(mu) # deterministic_eval + output = {'action': action} + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ + twin_critic = ['twin_critic_loss'] if self._twin_critic else [] + return [ + 'cur_lr_q', + 'cur_lr_p', + 'value_loss' + 'policy_loss', + 'q_loss', + 'v_loss', + 'policy_loss', + 'log_prob', + 'total_loss', + 'advantage_max', + 'next_q_value', + 'next_v_value', + 'new_q_value', + 'new_v_value', + ] + twin_critic + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'optimizer_q': self._optimizer_q.state_dict(), + 'optimizer_policy': self._optimizer_policy.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._optimizer_q.load_state_dict(state_dict['optimizer_q']) + self._optimizer_policy.load_state_dict(state_dict['optimizer_policy']) + +@POLICY_REGISTRY.register('discrete_iql') +class DiscreteIQLPolicy(Policy): + """ + Overview: + Policy class of discrete Implicit Q-Learning (IQL) algorithm in discrete action space environments. + Paper link: https://arxiv.org/abs/2110.06169. + """ + + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='discrete_cql', + # (bool) Whether to use cuda for policy. + cuda=False, + # (bool) Whether the RL algorithm is on-policy or off-policy. + on_policy=False, + # (bool) Whether use priority(priority sample, IS weight, update priority) + priority=False, + # (float) Reward's future discount factor, aka. gamma. + discount_factor=0.97, + # (int) N-step reward for target q_value estimation + nstep=1, + # learn_mode config + learn=dict( + # (int) How many updates (iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + update_per_collect=1, + # (int) Minibatch size for one gradient descent. + batch_size=64, + # (float) Learning rate for soft q network. + learning_rate=0.001, + # (int) Frequence of target network update. + target_update_freq=100, + # (bool) Whether ignore done(usually for max step termination env). + ignore_done=False, + # (float) Loss weight for conservative item. + min_q_weight=1.0, + ), + eval=dict(), # for compatibility + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ + return 'discrete_qvac', ['ding.model.template.qvac'] + + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \ + contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, main and \ + target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ + self._min_q_weight = self._cfg.learn.min_q_weight + self._priority = self._cfg.priority + # Optimizer + self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate) + + self._gamma = self._cfg.discount_factor + self._nstep = self._cfg.nstep + + # use wrapper instead of plugin + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.learn.target_update_freq} + ) + self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') + self._learn_model.reset() + self._target_model.reset() + + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the offline dataset and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For DiscreteCQL, each element in list is a dict containing at least the following keys: ``obs``, \ + ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \ + and ``value_gamma`` for nstep return computation. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ + data = default_preprocess_learn( + data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True + ) + if self._cuda: + data = to_device(data, self._device) + if data['action'].dim() == 2 and data['action'].shape[-1] == 1: + data['action'] = data['action'].squeeze(-1) + # ==================== + # Q-learning forward + # ==================== + self._learn_model.train() + self._target_model.train() + # Current q value (main model) + ret = self._learn_model.forward(data['obs']) + q_value, tau = ret['q'], ret['tau'] + # Target q value + with torch.no_grad(): + target_q_value = self._target_model.forward(data['next_obs'])['q'] + # Max q value action (main model) + target_q_action = self._learn_model.forward(data['next_obs'])['action'] + + # add CQL + # 1. chose action and compute q in dataset. + # 2. compute value loss(negative_sampling - dataset_expec) + replay_action_one_hot = F.one_hot(data['action'], self._cfg.model.action_shape) + replay_chosen_q = (q_value.mean(-1) * replay_action_one_hot).sum(dim=1) + + dataset_expec = replay_chosen_q.mean() + + negative_sampling = torch.logsumexp(q_value.mean(-1), dim=1).mean() + + min_q_loss = negative_sampling - dataset_expec + + data_n = qrdqn_nstep_td_data( + q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], tau, data['weight'] + ) + value_gamma = data.get('value_gamma') + loss, td_error_per_sample = qrdqn_nstep_td_error( + data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma + ) + + loss += self._min_q_weight * min_q_loss + + # ==================== + # Q-learning update + # ==================== + self._optimizer.zero_grad() + loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + self._optimizer.step() + + # ============= + # after update + # ============= + self._target_model.update(self._learn_model.state_dict()) + return { + 'cur_lr': self._optimizer.defaults['lr'], + 'total_loss': loss.item(), + 'priority': td_error_per_sample.abs().tolist(), + 'q_target': target_q_value.mean().item(), + 'q_value': q_value.mean().item(), + # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard. + # '[histogram]action_distribution': data['action'], + } + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ + return ['cur_lr', 'total_loss', 'q_target', 'q_value'] diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index f29dd3335a..bd24e2eaf8 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -114,6 +114,38 @@ def __init__(self, cfg: dict) -> None: except (KeyError, AttributeError): # do not normalize pass + if hasattr(cfg.env, "reward_norm"): + if cfg.env.reward_norm == "normalize": + dataset['rewards'] = (dataset['rewards'] - dataset['rewards'].mean()) / dataset['rewards'].std() + elif cfg.env.reward_norm == "iql_antmaze": + dataset['rewards'] = dataset['rewards'] - 1.0 + elif cfg.env.reward_norm == "iql_locomotion": + + def return_range(dataset, max_episode_steps): + returns, lengths = [], [] + ep_ret, ep_len = 0.0, 0 + for r, d in zip(dataset["rewards"], dataset["terminals"]): + ep_ret += float(r) + ep_len += 1 + if d or ep_len == max_episode_steps: + returns.append(ep_ret) + lengths.append(ep_len) + ep_ret, ep_len = 0.0, 0 + # returns.append(ep_ret) # incomplete trajectory + lengths.append(ep_len) # but still keep track of number of steps + assert sum(lengths) == len(dataset["rewards"]) + return min(returns), max(returns) + + min_ret, max_ret = return_range(dataset, 1000) + dataset['rewards'] /= max_ret - min_ret + dataset['rewards'] *= 1000 + elif cfg.env.reward_norm == "cql_antmaze": + dataset['rewards'] = (dataset['rewards'] - 0.5) * 4.0 + elif cfg.env.reward_norm == "antmaze": + dataset['rewards'] = (dataset['rewards'] - 0.25) * 2.0 + else: + raise NotImplementedError + self._data = [] self._load_d4rl(dataset) diff --git a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py new file mode 100644 index 0000000000..cb6c36c541 --- /dev/null +++ b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py @@ -0,0 +1,53 @@ +# You can conduct Experiments on D4RL with this config file through the following command: +# cd ../entry && python d4rl_iql_main.py +from easydict import EasyDict + +main_config = dict( + exp_name="halfcheetah_medium_iql_seed0", + env=dict( + env_id='halfcheetah-medium-v2', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + reward_norm="iql_locomotion", + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=17, + action_shape=6, + ), + learn=dict( + data_path=None, + train_epoch=30000, + batch_size=4096, + learning_rate_q=3e-4, + learning_rate_policy=1e-4, + beta=1.0, + tau=0.7, + ), + collect=dict(data_type='d4rl', ), + eval=dict(evaluator=dict(eval_freq=5000, )), + other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + type='d4rl', + import_names=['dizoo.d4rl.envs.d4rl_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='iql', + import_names=['ding.policy.iql'], + ), + replay_buffer=dict(type='naive', ), +) +create_config = EasyDict(create_config) +create_config = create_config diff --git a/dizoo/d4rl/entry/d4rl_iql_main.py b/dizoo/d4rl/entry/d4rl_iql_main.py new file mode 100644 index 0000000000..ded097ee42 --- /dev/null +++ b/dizoo/d4rl/entry/d4rl_iql_main.py @@ -0,0 +1,21 @@ +from ding.entry import serial_pipeline_offline +from ding.config import read_config +from pathlib import Path + + +def train(args): + # launch from anywhere + config = Path(__file__).absolute().parent.parent / 'config' / args.config + config = read_config(str(config)) + config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) + serial_pipeline_offline(config, seed=args.seed) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--seed', '-s', type=int, default=10) + parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_iql_config.py') + args = parser.parse_args() + train(args) From e43074a0e9f777f93f1f02beeb463db991a3e240 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Mon, 29 Jul 2024 18:05:17 +0800 Subject: [PATCH 2/5] Polish IQL Algorithm --- ding/model/template/qvac.py | 235 +---------------- ding/policy/command_mode_policy_instance.py | 2 +- ding/policy/iql.py | 241 ++---------------- .../config/halfcheetah_medium_iql_config.py | 3 +- 4 files changed, 24 insertions(+), 457 deletions(-) diff --git a/ding/model/template/qvac.py b/ding/model/template/qvac.py index 9a84558813..4b8c470f83 100644 --- a/ding/model/template/qvac.py +++ b/ding/model/template/qvac.py @@ -34,7 +34,7 @@ def __init__( actor_head_layer_num: int = 1, critic_head_hidden_size: int = 64, critic_head_layer_num: int = 1, - activation: Optional[nn.Module] = nn.ReLU(), + activation: Optional[nn.Module] = nn.SiLU(), #nn.ReLU(), norm_type: Optional[str] = None, encoder_hidden_size_list: Optional[SequenceType] = None, share_encoder: Optional[bool] = False, @@ -361,236 +361,3 @@ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten x = self.critic_q_head(x)['pred'] y = self.critic_v_head(obs)['pred'] return {'q_value': x, 'v_value': y} - - -@MODEL_REGISTRY.register('discrete_qvac') -class DiscreteQVAC(nn.Module): - """ - Overview: - The neural network and computation graph of algorithms related to discrete action Actor-Critic that have both Q-value and V-value critic, \ - such as Discrete action IQL. This model now supports only discrete action space. The DiscreteQVAC is composed of \ - four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \ - extract the feature from various observation. Heads are used to predict corresponding Q-value or action logit. \ - In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \ - and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders. - Interfaces: - ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` - """ - mode = ['compute_actor', 'compute_critic'] - - def __init__( - self, - obs_shape: Union[int, SequenceType], - action_shape: Union[int, SequenceType], - twin_critic: bool = False, - actor_head_hidden_size: int = 64, - actor_head_layer_num: int = 1, - critic_head_hidden_size: int = 64, - critic_head_layer_num: int = 1, - activation: Optional[nn.Module] = nn.ReLU(), - norm_type: Optional[str] = None, - encoder_hidden_size_list: SequenceType = None, - share_encoder: Optional[bool] = False, - ) -> None: - """ - Overview: - Initailize the DiscreteQVAC Model according to input arguments. - Arguments: - - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ). - - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ). - - twin_critic (:obj:`bool`): Whether to use twin critic. - - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head. - - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute action. - - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head. - - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute Q-value. - - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \ - after each FC layer, if ``None`` then default set to ``nn.ReLU()``. - - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \ - see ``ding.torch_utils.network`` for more details. - - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ - the last element must match ``head_hidden_size``, this argument is only used in image observation. - - share_encoder (:obj:`Optional[bool]`): Whether to share encoder between actor and critic. - """ - super(DiscreteQVAC, self).__init__() - obs_shape: int = squeeze(obs_shape) - action_shape: int = squeeze(action_shape) - # encoder - self.share_encoder = share_encoder - if np.isscalar(obs_shape) or len(obs_shape) == 1: - assert not self.share_encoder, "Vector observation doesn't need share encoder." - assert encoder_hidden_size_list is None, "Vector obs encoder only uses one layer nn.Linear" - # Because there is already a layer nn.Linear in the head, so we use nn.Identity here to keep - # compatible with the image observation and avoid adding an extra layer nn.Linear. - self.actor_encoder = nn.Identity() - self.critic_encoder = nn.Identity() - encoder_output_size = obs_shape - elif len(obs_shape) == 3: - - def setup_conv_encoder(): - kernel_size = [3 for _ in range(len(encoder_hidden_size_list))] - stride = [2] + [1 for _ in range(len(encoder_hidden_size_list) - 1)] - return ConvEncoder( - obs_shape, - encoder_hidden_size_list, - activation=activation, - norm_type=norm_type, - kernel_size=kernel_size, - stride=stride - ) - - if self.share_encoder: - encoder = setup_conv_encoder() - self.actor_encoder = self.critic_encoder = encoder - else: - self.actor_encoder = setup_conv_encoder() - self.critic_encoder = setup_conv_encoder() - encoder_output_size = self.actor_encoder.output_size - else: - raise RuntimeError("not support observation shape: {}".format(obs_shape)) - - # head - self.actor_head = nn.Sequential( - nn.Linear(encoder_output_size, actor_head_hidden_size), activation, - DiscreteHead( - actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type - ) - ) - - self.twin_critic = twin_critic - if self.twin_critic: - self.critic_q_head = nn.ModuleList() - self.critic_v_head = nn.ModuleList() - for _ in range(2): - self.critic_q_head.append( - nn.Sequential( - nn.Linear(encoder_output_size, critic_head_hidden_size), activation, - DiscreteHead( - critic_head_hidden_size, - action_shape, - critic_head_layer_num, - activation=activation, - norm_type=norm_type - ) - ) - ) - self.critic_v_head.append( - nn.Sequential( - nn.Linear(encoder_output_size, critic_head_hidden_size), activation, - DiscreteHead( - critic_head_hidden_size, - 1, - critic_head_layer_num, - activation=activation, - norm_type=norm_type - ) - ) - ) - - else: - self.critic_q_head = nn.Sequential( - nn.Linear(encoder_output_size, critic_head_hidden_size), activation, - DiscreteHead( - critic_head_hidden_size, - action_shape, - critic_head_layer_num, - activation=activation, - norm_type=norm_type - ) - ) - self.critic_v_head = nn.Sequential( - nn.Linear(encoder_output_size, critic_head_hidden_size), activation, - DiscreteHead( - critic_head_hidden_size, - 1, - critic_head_layer_num, - activation=activation, - norm_type=norm_type - ) - ) - # Convenient for calling some apis (e.g. self.critic.parameters()), - # but may cause misunderstanding when `print(self)` - self.actor = nn.ModuleList([self.actor_encoder, self.actor_head]) - self.critic = nn.ModuleList([self.critic_encoder, self.critic_q_head, self.critic_v_head]) - - def forward(self, inputs: torch.Tensor, mode: str) -> Dict[str, torch.Tensor]: - """ - Overview: - QVAC forward computation graph, input observation tensor to predict Q-value or action logit. Different \ - ``mode`` will forward with different network modules to get different outputs and save computation. - Arguments: - - inputs (:obj:`torch.Tensor`): The input observation tensor data. - - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class. - Returns: - - output (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph, whose \ - key-values vary in different forward modes. - Examples (Actor): - >>> model = DiscreteQVAC(64, 6) - >>> obs = torch.randn(4, 64) - >>> actor_outputs = model(obs,'compute_actor') - >>> assert actor_outputs['logit'].shape == torch.Size([4, 6]) - - Examples(Critic): - >>> model = DiscreteQVAC(64, 6, twin_critic=False) - >>> obs = torch.randn(4, 64) - >>> actor_outputs = model(obs,'compute_critic') - >>> assert actor_outputs['q_value'].shape == torch.Size([4, 6]) - """ - assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) - return getattr(self, mode)(inputs) - - def compute_actor(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]: - """ - Overview: - QVAC forward computation graph for actor part, input observation tensor to predict action or action logit. - Arguments: - - inputs (:obj:`torch.Tensor`): The input observation tensor data. - Returns: - - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph for actor, \ - including discrete action ``logit``. - ReturnsKeys: - - logit (:obj:`torch.Tensor`): The predicted discrete action type logit, it will be the same dimension \ - as ``action_shape``, i.e., all the possible discrete action choices. - Shapes: - - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``. - - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \ - ``action_shape``. - Examples: - >>> model = DiscreteQVAC(64, 6) - >>> obs = torch.randn(4, 64) - >>> actor_outputs = model(obs,'compute_actor') - >>> assert actor_outputs['logit'].shape == torch.Size([4, 6]) - """ - x = self.actor_encoder(inputs) - x = self.actor_head(x) - return {'logit': x['logit']} - - def compute_critic(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]: - """ - Overview: - QVAC forward computation graph for critic part, input observation to predict Q-value for each possible \ - discrete action choices. - Arguments: - - inputs (:obj:`torch.Tensor`): The input observation tensor data. - Returns: - - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph for critic, \ - including ``q_value`` for each possible discrete action choices. - ReturnKeys: - - q_value (:obj:`torch.Tensor`): The predicted Q-value for each possible discrete action choices, it will \ - be the same dimension as ``action_shape`` and used to calculate the loss. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``. - - q_value (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``. - Examples: - >>> model = DiscreteQVAC(64, 6, twin_critic=False) - >>> obs = torch.randn(4, 64) - >>> actor_outputs = model(obs,'compute_critic') - >>> assert actor_outputs['q_value'].shape == torch.Size([4, 6]) - """ - inputs = self.critic_encoder(inputs) - if self.twin_critic: - x = [m(inputs)['logit'] for m in self.critic_q_head] - y = [m(inputs)['logit'] for m in self.critic_v_head] - else: - x = self.critic_q_head(inputs)['logit'] - y = self.critic_v_head(inputs)['logit'] - return {'q_value': x, 'v_value': y} diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index f779cdc316..0a41b111c1 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -322,7 +322,7 @@ class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy): pass @POLICY_REGISTRY.register('iql_command') -class CQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy): +class IQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy): pass @POLICY_REGISTRY.register('discrete_cql_command') diff --git a/ding/policy/iql.py b/ding/policy/iql.py index 2fd5fc4772..62a66a3564 100644 --- a/ding/policy/iql.py +++ b/ding/policy/iql.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F from torch.distributions import Normal, Independent, TransformedDistribution -from torch.distributions.transforms import TanhTransform +from torch.distributions.transforms import TanhTransform, AffineTransform from ding.torch_utils import Adam, to_device from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ @@ -31,7 +31,7 @@ class IQLPolicy(Policy): == ==================== ======== ============= ================================= ======================= ID Symbol Type Default Value Description Other(Shape) == ==================== ======== ============= ================================= ======================= - 1 ``type`` str cql | RL policy register name, refer | this arg is optional, + 1 ``type`` str iql | RL policy register name, refer | this arg is optional, | to registry ``POLICY_REGISTRY`` | a placeholder 2 ``cuda`` bool True | Whether to use cuda for network | 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for @@ -95,9 +95,11 @@ class IQLPolicy(Policy): # (str type) action_space: Use reparameterization trick for continous action action_space='reparameterization', # (int) Hidden size for actor network head. - actor_head_hidden_size=256, + actor_head_hidden_size=512, + actor_head_layer_num=3, # (int) Hidden size for critic network head. - critic_head_hidden_size=256, + critic_head_hidden_size=512, + critic_head_layer_num=2, ), # learn_mode config learn=dict( @@ -208,8 +210,8 @@ def _init_learn(self) -> None: init_w = self._cfg.learn.init_w self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w) self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w) - self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w) - self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w) + # self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w) + # self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w) if self._twin_critic: self._model.critic_q_head[0][-1].last.weight.data.uniform_(-init_w, init_w) self._model.critic_q_head[0][-1].last.bias.data.uniform_(-init_w, init_w) @@ -245,7 +247,7 @@ def _init_learn(self) -> None: self._tau = self._cfg.learn.tau self._beta = self._cfg.learn.beta - self._policy_start_training_counter=300000 + self._policy_start_training_counter=10000 #300000 def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: """ @@ -259,7 +261,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ dimension by some utility functions such as ``default_preprocess_learn``. \ - For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + For IQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. Returns: - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ @@ -300,14 +302,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: with torch.no_grad(): (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit'] - # dist = Independent(Normal(mu, sigma), 1) - # pred = dist.rsample() - # next_action = torch.tanh(pred) - # y = 1 - next_action.pow(2) + 1e-6 - # next_log_prob = dist.log_prob(pred).unsqueeze(-1) - # next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True) - - next_obs_dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1)]) + next_obs_dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1), AffineTransform(loc=0.0, scale=1.05)]) next_action = next_obs_dist.rsample() next_log_prob = next_obs_dist.log_prob(next_action) @@ -350,15 +345,9 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: # 6. evaluate to get action distribution (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit'] - # dist = Independent(Normal(mu, sigma), 1) - # pred = dist.rsample() - # action = torch.tanh(pred) - # y = 1 - action.pow(2) + 1e-6 - # log_prob = dist.log_prob(pred).unsqueeze(-1) - # log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) - - dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1)]) - action = dist.rsample() + + dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1), AffineTransform(loc=0.0, scale=1.05)]) + action = data['action'] log_prob = dist.log_prob(action) eval_data = {'obs': obs, 'action': action} @@ -370,16 +359,15 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: new_advantage = new_q_value - new_v_value # 8. compute policy loss - policy_loss = (- self._beta * log_prob * torch.exp(new_advantage.detach()).clamp(max=1000.0)).mean() + policy_loss = (- log_prob * torch.exp(new_advantage.detach()/self._beta).clamp(max=20.0)).mean() self._policy_start_training_counter -= 1 - if self._policy_start_training_counter > 0: - policy_loss = policy_loss * 0.0 loss_dict['policy_loss'] = policy_loss # 9. update policy network self._optimizer_policy.zero_grad() - loss_dict['policy_loss'].backward() + policy_loss.backward() + policy_grad_norm = torch.nn.utils.clip_grad_norm_(self._model.actor.parameters(), 1) self._optimizer_policy.step() loss_dict['total_loss'] = sum(loss_dict.values()) @@ -403,6 +391,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: 'advantage_max': new_advantage.max().detach().item(), 'new_q_value': new_q_value.detach().mean().item(), 'new_v_value': new_v_value.detach().mean().item(), + 'policy_grad_norm': policy_grad_norm, } def _get_policy_actions(self, data: Dict, num_actions: int = 10, epsilon: float = 1e-6) -> List: @@ -594,7 +583,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: self._eval_model.eval() with torch.no_grad(): (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit'] - action = torch.tanh(mu) # deterministic_eval + action = torch.tanh(mu)/1.05 # deterministic_eval output = {'action': action} if self._cuda: output = to_device(output, 'cpu') @@ -625,6 +614,7 @@ def _monitor_vars_learn(self) -> List[str]: 'next_v_value', 'new_q_value', 'new_v_value', + 'policy_grad_norm', ] + twin_critic def _state_dict_learn(self) -> Dict[str, Any]: @@ -655,194 +645,3 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: self._learn_model.load_state_dict(state_dict['model']) self._optimizer_q.load_state_dict(state_dict['optimizer_q']) self._optimizer_policy.load_state_dict(state_dict['optimizer_policy']) - -@POLICY_REGISTRY.register('discrete_iql') -class DiscreteIQLPolicy(Policy): - """ - Overview: - Policy class of discrete Implicit Q-Learning (IQL) algorithm in discrete action space environments. - Paper link: https://arxiv.org/abs/2110.06169. - """ - - config = dict( - # (str) RL policy register name (refer to function "POLICY_REGISTRY"). - type='discrete_cql', - # (bool) Whether to use cuda for policy. - cuda=False, - # (bool) Whether the RL algorithm is on-policy or off-policy. - on_policy=False, - # (bool) Whether use priority(priority sample, IS weight, update priority) - priority=False, - # (float) Reward's future discount factor, aka. gamma. - discount_factor=0.97, - # (int) N-step reward for target q_value estimation - nstep=1, - # learn_mode config - learn=dict( - # (int) How many updates (iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - update_per_collect=1, - # (int) Minibatch size for one gradient descent. - batch_size=64, - # (float) Learning rate for soft q network. - learning_rate=0.001, - # (int) Frequence of target network update. - target_update_freq=100, - # (bool) Whether ignore done(usually for max step termination env). - ignore_done=False, - # (float) Loss weight for conservative item. - min_q_weight=1.0, - ), - eval=dict(), # for compatibility - ) - - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ - automatically call this method to get the default model setting and create model. - - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. - """ - return 'discrete_qvac', ['ding.model.template.qvac'] - - def _init_learn(self) -> None: - """ - Overview: - Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \ - contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, main and \ - target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. - - .. note:: - For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ - and ``_load_state_dict_learn`` methods. - - .. note:: - For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. - - .. note:: - If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ - with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. - """ - self._min_q_weight = self._cfg.learn.min_q_weight - self._priority = self._cfg.priority - # Optimizer - self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate) - - self._gamma = self._cfg.discount_factor - self._nstep = self._cfg.nstep - - # use wrapper instead of plugin - self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.learn.target_update_freq} - ) - self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') - self._learn_model.reset() - self._target_model.reset() - - def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: - """ - Overview: - Policy forward function of learn mode (training policy and updating parameters). Forward means \ - that the policy inputs some training batch data from the offline dataset and then returns the output \ - result, including various training information such as loss, action, priority. - Arguments: - - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ - training samples. For each element in list, the key of the dict is the name of data items and the \ - value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ - combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ - dimension by some utility functions such as ``default_preprocess_learn``. \ - For DiscreteCQL, each element in list is a dict containing at least the following keys: ``obs``, \ - ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \ - and ``value_gamma`` for nstep return computation. - Returns: - - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ - recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ - detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. - - .. note:: - The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ - For the data type that not supported, the main reason is that the corresponding model does not support it. \ - You can implement you own model rather than use the default model. For more information, please raise an \ - issue in GitHub repo and we will continue to follow up. - """ - data = default_preprocess_learn( - data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True - ) - if self._cuda: - data = to_device(data, self._device) - if data['action'].dim() == 2 and data['action'].shape[-1] == 1: - data['action'] = data['action'].squeeze(-1) - # ==================== - # Q-learning forward - # ==================== - self._learn_model.train() - self._target_model.train() - # Current q value (main model) - ret = self._learn_model.forward(data['obs']) - q_value, tau = ret['q'], ret['tau'] - # Target q value - with torch.no_grad(): - target_q_value = self._target_model.forward(data['next_obs'])['q'] - # Max q value action (main model) - target_q_action = self._learn_model.forward(data['next_obs'])['action'] - - # add CQL - # 1. chose action and compute q in dataset. - # 2. compute value loss(negative_sampling - dataset_expec) - replay_action_one_hot = F.one_hot(data['action'], self._cfg.model.action_shape) - replay_chosen_q = (q_value.mean(-1) * replay_action_one_hot).sum(dim=1) - - dataset_expec = replay_chosen_q.mean() - - negative_sampling = torch.logsumexp(q_value.mean(-1), dim=1).mean() - - min_q_loss = negative_sampling - dataset_expec - - data_n = qrdqn_nstep_td_data( - q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], tau, data['weight'] - ) - value_gamma = data.get('value_gamma') - loss, td_error_per_sample = qrdqn_nstep_td_error( - data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma - ) - - loss += self._min_q_weight * min_q_loss - - # ==================== - # Q-learning update - # ==================== - self._optimizer.zero_grad() - loss.backward() - if self._cfg.multi_gpu: - self.sync_gradients(self._learn_model) - self._optimizer.step() - - # ============= - # after update - # ============= - self._target_model.update(self._learn_model.state_dict()) - return { - 'cur_lr': self._optimizer.defaults['lr'], - 'total_loss': loss.item(), - 'priority': td_error_per_sample.abs().tolist(), - 'q_target': target_q_value.mean().item(), - 'q_value': q_value.mean().item(), - # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard. - # '[histogram]action_distribution': data['action'], - } - - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ - as text logger, tensorboard logger, will use these keys to save the corresponding data. - Returns: - - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. - """ - return ['cur_lr', 'total_loss', 'q_target', 'q_value'] diff --git a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py index cb6c36c541..545ecf970b 100644 --- a/dizoo/d4rl/config/halfcheetah_medium_iql_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_iql_config.py @@ -18,6 +18,7 @@ model=dict( obs_shape=17, action_shape=6, + ), learn=dict( data_path=None, @@ -25,7 +26,7 @@ batch_size=4096, learning_rate_q=3e-4, learning_rate_policy=1e-4, - beta=1.0, + beta=0.05, tau=0.7, ), collect=dict(data_type='d4rl', ), From 56ec93efde5fbeab448064eeb222db3bee35f02a Mon Sep 17 00:00:00 2001 From: zjowowen Date: Mon, 29 Jul 2024 18:11:02 +0800 Subject: [PATCH 3/5] Polish IQL Algorithm --- ding/entry/serial_entry_offline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ding/entry/serial_entry_offline.py b/ding/entry/serial_entry_offline.py index f24e5404fb..b92b5c7dda 100755 --- a/ding/entry/serial_entry_offline.py +++ b/ding/entry/serial_entry_offline.py @@ -62,7 +62,6 @@ def serial_pipeline_offline( sampler=sampler, collate_fn=lambda x: x, pin_memory=cfg.policy.cuda, - drop_last=True, ) # Env, Policy try: From 61d372261bb9f1851f38f5239e996a8ec2927c87 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Mon, 29 Jul 2024 18:14:56 +0800 Subject: [PATCH 4/5] Polish IQL Algorithm --- ding/model/template/qgpo.py_backup | 365 ++++++++++++++++++++ ding/model/template/qvac.py | 2 +- ding/policy/command_mode_policy_instance.py | 2 + ding/policy/iql.py | 21 +- ding/policy/qgpo.py_backup | 172 +++++++++ 5 files changed, 554 insertions(+), 8 deletions(-) create mode 100644 ding/model/template/qgpo.py_backup create mode 100644 ding/policy/qgpo.py_backup diff --git a/ding/model/template/qgpo.py_backup b/ding/model/template/qgpo.py_backup new file mode 100644 index 0000000000..3e0136af5c --- /dev/null +++ b/ding/model/template/qgpo.py_backup @@ -0,0 +1,365 @@ +############################################################# +# This QGPO model is a modification implementation from https://github.com/ChenDRAG/CEP-energy-guided-diffusion +############################################################# + +from easydict import EasyDict +import functools + +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from ding.torch_utils import MLP +from ding.torch_utils.diffusion_SDE import dpm_solver_pytorch +from ding.model.common.encoder import GaussianFourierProjectionTimeEncoder +from ding.torch_utils.network.res_block import TemporalSpatialResBlock + + +def marginal_prob_std(t, device): + """ + Overview: + Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$. + """ + t = torch.tensor(t, device=device) + beta_1 = 20.0 + beta_0 = 0.1 + log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 + alpha_t = torch.exp(log_mean_coeff) + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + return alpha_t, std + + +class TwinQ(nn.Module): + + def __init__(self, action_dim, state_dim): + super().__init__() + self.q1 = MLP( + in_channels=state_dim + action_dim, + hidden_channels=256, + out_channels=1, + activation=nn.ReLU(), + layer_num=4, + output_activation=False + ) + self.q2 = MLP( + in_channels=state_dim + action_dim, + hidden_channels=256, + out_channels=1, + activation=nn.ReLU(), + layer_num=4, + output_activation=False + ) + + def both(self, action, condition=None): + as_ = torch.cat([action, condition], -1) if condition is not None else action + return self.q1(as_), self.q2(as_) + + def forward(self, action, condition=None): + return torch.min(*self.both(action, condition)) + + +class GuidanceQt(nn.Module): + + def __init__(self, action_dim, state_dim, time_embed_dim=32): + super().__init__() + self.qt = MLP( + in_channels=action_dim + time_embed_dim + state_dim, + hidden_channels=256, + out_channels=1, + activation=torch.nn.SiLU(), + layer_num=4, + output_activation=False + ) + self.embed = nn.Sequential( + GaussianFourierProjectionTimeEncoder(embed_dim=time_embed_dim), nn.Linear(time_embed_dim, time_embed_dim) + ) + + def forward(self, action, t, condition=None): + embed = self.embed(t) + ats = torch.cat([action, embed, condition], -1) if condition is not None else torch.cat([action, embed], -1) + return self.qt(ats) + + +class Critic_Guide(nn.Module): + + def __init__(self, adim, sdim) -> None: + super().__init__() + # is sdim is 0 means unconditional guidance + self.conditional_sampling = False if sdim == 0 else True + self.q0 = None + self.qt = None + + def forward(self, a, condition=None): + return self.q0(a, condition) + + def calculate_guidance(self, a, t, condition=None): + raise NotImplementedError + + def calculateQ(self, a, condition=None): + return self(a, condition) + + def update_q0(self, data): + raise NotImplementedError + + def update_qt(self, data): + raise NotImplementedError + + +class QGPOCritic(Critic_Guide): + + def __init__(self, device, cfg, adim, sdim) -> None: + super().__init__(adim, sdim) + # is sdim is 0 means unconditional guidance + assert sdim > 0 + # only apply to conditional sampling here + self.device = device + self.cfg = cfg + self.q0 = TwinQ(adim, sdim).to(self.device) + self.q0_target = copy.deepcopy(self.q0).requires_grad_(False).to(self.device) + self.qt = GuidanceQt(adim, sdim).to(self.device) + self.qt_update_momentum = 0.005 + self.q_optimizer = torch.optim.Adam(self.q0.parameters(), lr=3e-4) + self.qt_optimizer = torch.optim.Adam(self.qt.parameters(), lr=3e-4) + self.discount = 0.99 + + self.alpha = self.cfg.alpha + self.guidance_scale = 1.0 + + def calculate_guidance(self, a, t, condition=None): + with torch.enable_grad(): + a.requires_grad_(True) + Q_t = self.qt(a, t, condition) + guidance = self.guidance_scale * torch.autograd.grad(torch.sum(Q_t), a)[0] + return guidance.detach() + + def update_q0(self, data): + s = data["s"] + a = data["a"] + r = data["r"] + s_ = data["s_"] + d = data["d"] + + fake_a = data['fake_a'] + fake_a_ = data['fake_a_'] + with torch.no_grad(): + softmax = nn.Softmax(dim=1) + next_energy = self.q0_target(fake_a_, torch.stack([s_] * fake_a_.shape[1], + axis=1)).detach().squeeze() # + next_v = torch.sum(softmax(self.cfg.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True) + + # Update Q function + targets = r + (1. - d.float()) * self.discount * next_v.detach() + qs = self.q0.both(a, s) + q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs) + self.q_optimizer.zero_grad(set_to_none=True) + q_loss.backward() + self.q_optimizer.step() + + # Update target + for param, target_param in zip(self.q0.parameters(), self.q0_target.parameters()): + target_param.data.copy_( + self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data + ) + + return q_loss.detach().cpu().numpy() + + def update_qt(self, data): + # input many s anction , + s = data['s'] + a = data['a'] + fake_a = data['fake_a'] + energy = self.q0_target(fake_a, torch.stack([s] * fake_a.shape[1], axis=1)).detach().squeeze() + + self.all_mean = torch.mean(energy, dim=-1).detach().cpu().squeeze().numpy() + self.all_std = torch.std(energy, dim=-1).detach().cpu().squeeze().numpy() + + # CEP guidance method, as proposed in the paper + logsoftmax = nn.LogSoftmax(dim=1) + softmax = nn.Softmax(dim=1) + + x0_data_energy = energy * self.alpha + # random_t = torch.rand((fake_a.shape[0], fake_a.shape[1]), device=s.device) * (1. - 1e-3) + 1e-3 + random_t = torch.rand((fake_a.shape[0], ), device=self.device) * (1. - 1e-3) + 1e-3 + random_t = torch.stack([random_t] * fake_a.shape[1], dim=1) + z = torch.randn_like(fake_a) + alpha_t, std = marginal_prob_std(random_t, device=self.device) + perturbed_fake_a = fake_a * alpha_t[..., None] + z * std[..., None] + xt_model_energy = self.qt(perturbed_fake_a, random_t, torch.stack([s] * fake_a.shape[1], axis=1)).squeeze() + p_label = softmax(x0_data_energy) + self.debug_used = torch.flatten(p_label).detach().cpu().numpy() + # + loss = -torch.mean(torch.sum(p_label * logsoftmax(xt_model_energy), axis=-1)) + + self.qt_optimizer.zero_grad(set_to_none=True) + loss.backward() + self.qt_optimizer.step() + + return loss.detach().cpu().numpy() + + +class ScoreBase(nn.Module): + + def __init__(self, device, cfg, input_dim, output_dim, marginal_prob_std, embed_dim=32): + super().__init__() + self.cfg = cfg + self.output_dim = output_dim + self.embed = nn.Sequential( + GaussianFourierProjectionTimeEncoder(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim) + ) + self.device = device + self.noise_schedule = dpm_solver_pytorch.NoiseScheduleVP(schedule='linear') + self.dpm_solver = dpm_solver_pytorch.DPM_Solver( + self.forward_dmp_wrapper_fn, self.noise_schedule, predict_x0=True + ) + # self.dpm_solver = dpm_solver_pytorch.DPM_Solver(self.forward_dmp_wrapper_fn, self.noise_schedule) + self.marginal_prob_std = marginal_prob_std + self.q = [] + self.q.append(QGPOCritic(device, cfg.qgpo_critic, adim=output_dim, sdim=input_dim - output_dim)) + + def forward_dmp_wrapper_fn(self, x, t): + score = self(x, t) + result = -(score + self.q[0].calculate_guidance(x, t, self.condition)) * self.marginal_prob_std( + t, device=self.device + )[1][..., None] + return result + + def dpm_wrapper_sample(self, dim, batch_size, **kwargs): + with torch.no_grad(): + init_x = torch.randn(batch_size, dim, device=self.device) + return self.dpm_solver.sample(init_x, **kwargs).cpu().numpy() + + def calculateQ(self, s, a, t=None): + if s is None: + if self.condition.shape[0] == a.shape[0]: + s = self.condition + elif self.condition.shape[0] == 1: + s = torch.cat([self.condition] * a.shape[0]) + else: + assert False + return self.q[0](a, s) + + def forward(self, x, t, condition=None): + raise NotImplementedError + + def select_actions(self, states, diffusion_steps=15): + self.eval() + multiple_input = True + with torch.no_grad(): + states = torch.FloatTensor(states).to(self.device) + if states.dim == 1: + states = states.unsqueeze(0) + multiple_input = False + num_states = states.shape[0] + self.condition = states + results = self.dpm_wrapper_sample( + self.output_dim, batch_size=states.shape[0], steps=diffusion_steps, order=2 + ) + actions = results.reshape(num_states, self.output_dim).copy() # + self.condition = None + out_actions = [actions[i] for i in range(actions.shape[0])] if multiple_input else actions[0] + self.train() + return out_actions + + def sample(self, states, sample_per_state=16, diffusion_steps=15): + self.eval() + num_states = states.shape[0] + with torch.no_grad(): + states = torch.FloatTensor(states).to(self.device) + states = torch.repeat_interleave(states, sample_per_state, dim=0) + self.condition = states + results = self.dpm_wrapper_sample( + self.output_dim, batch_size=states.shape[0], steps=diffusion_steps, order=2 + ) + actions = results[:, :].reshape(num_states, sample_per_state, self.output_dim).copy() + self.condition = None + self.train() + return actions + + +class ScoreNet(ScoreBase): + + def __init__(self, device, cfg, input_dim, output_dim, marginal_prob_std, embed_dim=32): + super().__init__(device, cfg, input_dim, output_dim, marginal_prob_std, embed_dim) + # The swish activation function + self.device = device + self.cfg = cfg + self.act = lambda x: x * torch.sigmoid(x) + self.pre_sort_condition = nn.Sequential(nn.Linear(input_dim - output_dim, 32), torch.nn.SiLU()) + self.sort_t = nn.Sequential( + nn.Linear(64, 128), + torch.nn.SiLU(), + nn.Linear(128, 128), + ) + self.down_block1 = TemporalSpatialResBlock(output_dim, 512) + self.down_block2 = TemporalSpatialResBlock(512, 256) + self.down_block3 = TemporalSpatialResBlock(256, 128) + self.middle1 = TemporalSpatialResBlock(128, 128) + self.up_block3 = TemporalSpatialResBlock(256, 256) + self.up_block2 = TemporalSpatialResBlock(512, 512) + self.last = nn.Linear(1024, output_dim) + + def forward(self, x, t, condition=None): + embed = self.embed(t) + + if condition is not None: + embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1) + else: + if self.condition.shape[0] == x.shape[0]: + condition = self.condition + elif self.condition.shape[0] == 1: + condition = torch.cat([self.condition] * x.shape[0]) + else: + assert False + embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1) + embed = self.sort_t(embed) + d1 = self.down_block1(x, embed) + d2 = self.down_block2(d1, embed) + d3 = self.down_block3(d2, embed) + u3 = self.middle1(d3, embed) + u2 = self.up_block3(torch.cat([d3, u3], dim=-1), embed) + u1 = self.up_block2(torch.cat([d2, u2], dim=-1), embed) + u0 = torch.cat([d1, u1], dim=-1) + h = self.last(u0) + self.h = h + # Normalize output + return h / self.marginal_prob_std(t, device=self.device)[1][..., None] + + +class QGPO(nn.Module): + + def __init__(self, cfg: EasyDict) -> None: + super(QGPO, self).__init__() + self.cfg = cfg + self.device = cfg.device + self.obs_dim = cfg.obs_dim + self.action_dim = cfg.action_dim + + #marginal_prob_std_fn = functools.partial(marginal_prob_std, device=self.device) + + self.score_model = ScoreNet( + device=self.device, + cfg=cfg.score_net, + input_dim=self.obs_dim + self.action_dim, + output_dim=self.action_dim, + marginal_prob_std=marginal_prob_std, + ) + + def loss_fn(self, x, marginal_prob_std, eps=1e-3): + """ + Overview: + The loss function for training score-based generative models. + Arguments: + model: A PyTorch model instance that represents a \ + time-dependent score-based model. + x: A mini-batch of training data. + marginal_prob_std: A function that gives the standard deviation of \ + the perturbation kernel. + eps: A tolerance value for numerical stability. + """ + random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps + z = torch.randn_like(x) + alpha_t, std = marginal_prob_std(random_t, device=x.device) + perturbed_x = x * alpha_t[:, None] + z * std[:, None] + score = self.score_model(perturbed_x, random_t) + loss = torch.mean(torch.sum((score * std[:, None] + z) ** 2, dim=(1, ))) + return loss diff --git a/ding/model/template/qvac.py b/ding/model/template/qvac.py index 4b8c470f83..50cdca8e7d 100644 --- a/ding/model/template/qvac.py +++ b/ding/model/template/qvac.py @@ -34,7 +34,7 @@ def __init__( actor_head_layer_num: int = 1, critic_head_hidden_size: int = 64, critic_head_layer_num: int = 1, - activation: Optional[nn.Module] = nn.SiLU(), #nn.ReLU(), + activation: Optional[nn.Module] = nn.SiLU(), #nn.ReLU(), norm_type: Optional[str] = None, encoder_hidden_size_list: Optional[SequenceType] = None, share_encoder: Optional[bool] = False, diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 0a41b111c1..1ed500dbbd 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -321,10 +321,12 @@ class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy): class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy): pass + @POLICY_REGISTRY.register('iql_command') class IQLCommandModePolicy(IQLPolicy, DummyCommandModePolicy): pass + @POLICY_REGISTRY.register('discrete_cql_command') class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy): pass diff --git a/ding/policy/iql.py b/ding/policy/iql.py index 62a66a3564..adc8483891 100644 --- a/ding/policy/iql.py +++ b/ding/policy/iql.py @@ -18,7 +18,7 @@ def asymmetric_l2_loss(u, tau): - return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) + return torch.mean(torch.abs(tau - (u < 0).float()) * u ** 2) @POLICY_REGISTRY.register('iql') @@ -247,7 +247,7 @@ def _init_learn(self) -> None: self._tau = self._cfg.learn.tau self._beta = self._cfg.learn.beta - self._policy_start_training_counter=10000 #300000 + self._policy_start_training_counter = 10000 #300000 def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: """ @@ -296,13 +296,17 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: # 1. predict q and v value value = self._learn_model.forward(data, mode='compute_critic') - q_value, v_value= value['q_value'], value['v_value'] + q_value, v_value = value['q_value'], value['v_value'] # 2. predict target value with torch.no_grad(): (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit'] - next_obs_dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1), AffineTransform(loc=0.0, scale=1.05)]) + next_obs_dist = TransformedDistribution( + Independent(Normal(mu, sigma), 1), + transforms=[TanhTransform(cache_size=1), + AffineTransform(loc=0.0, scale=1.05)] + ) next_action = next_obs_dist.rsample() next_log_prob = next_obs_dist.log_prob(next_action) @@ -346,7 +350,10 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: # 6. evaluate to get action distribution (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit'] - dist= TransformedDistribution(Independent(Normal(mu, sigma), 1), transforms=[TanhTransform(cache_size=1), AffineTransform(loc=0.0, scale=1.05)]) + dist = TransformedDistribution( + Independent(Normal(mu, sigma), 1), + transforms=[TanhTransform(cache_size=1), AffineTransform(loc=0.0, scale=1.05)] + ) action = data['action'] log_prob = dist.log_prob(action) @@ -359,7 +366,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: new_advantage = new_q_value - new_v_value # 8. compute policy loss - policy_loss = (- log_prob * torch.exp(new_advantage.detach()/self._beta).clamp(max=20.0)).mean() + policy_loss = (-log_prob * torch.exp(new_advantage.detach() / self._beta).clamp(max=20.0)).mean() self._policy_start_training_counter -= 1 loss_dict['policy_loss'] = policy_loss @@ -583,7 +590,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: self._eval_model.eval() with torch.no_grad(): (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit'] - action = torch.tanh(mu)/1.05 # deterministic_eval + action = torch.tanh(mu) / 1.05 # deterministic_eval output = {'action': action} if self._cuda: output = to_device(output, 'cpu') diff --git a/ding/policy/qgpo.py_backup b/ding/policy/qgpo.py_backup new file mode 100644 index 0000000000..d8ea1dc3ad --- /dev/null +++ b/ding/policy/qgpo.py_backup @@ -0,0 +1,172 @@ +############################################################# +# This QGPO model is a modification implementation from https://github.com/ChenDRAG/CEP-energy-guided-diffusion +############################################################# + +from typing import List, Dict, Any +import functools +import torch +import numpy as np +from ding.torch_utils import to_device +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_collate, default_decollate +from .base_policy import Policy + +from ding.model.template.qgpo import marginal_prob_std + + +@POLICY_REGISTRY.register('qgpo') +class QGPOPolicy(Policy): + """ + Overview: + Policy class of QGPO algorithm + Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning + https://arxiv.org/abs/2304.12824 + """ + + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='qgpo', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool type) on_policy: Determine whether on-policy or off-policy. + # on-policy setting influences the behaviour of buffer. + # Default False in QGPO. + on_policy=False, + multi_agent=False, + model=dict( + score_net=dict( + qgpo_critic=dict( + # (float) The scale of the energy guidance when training qt. + # \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a)) + alpha=3, + # (float) The scale of the energy guidance when training q0. + # \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a') + # \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a)) + q_alpha=1, + ), + ), + device='cuda', + # obs_dim + # action_dim + ), + learn=dict( + # learning rate for behavior model training + learning_rate=1e-4, + # batch size during the training of behavior model + batch_size=4096, + # batch size during the training of q value + batch_size_q=256, + # number of fake action support + M=16, + # number of diffusion time steps + diffusion_steps=15, + # training iterations when behavior model is fixed + behavior_policy_stop_training_iter=600000, + # training iterations when energy-guided policy begin training + energy_guided_policy_begin_training_iter=600000, + # training iterations when q value stop training, default None means no limit + q_value_stop_training_iter=1100000, + ), + eval=dict( + # energy guidance scale for policy in evaluation + # \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a)) + guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], + ), + ) + + def _init_learn(self) -> None: + self.cuda = self._cfg.cuda + if self.cuda: + self.margin_prob_std_fn = functools.partial(marginal_prob_std, device=self._device) + self.behavior_model_optimizer = torch.optim.Adam( + self._model.score_model.parameters(), lr=self._cfg.learn.learning_rate + ) + self.behavior_policy_stop_training_iter = self._cfg.learn.behavior_policy_stop_training_iter if hasattr( + self._cfg.learn, 'behavior_policy_stop_training_iter' + ) else np.inf + self.energy_guided_policy_begin_training_iter = self._cfg.learn.energy_guided_policy_begin_training_iter if \ + hasattr(self._cfg.learn, 'energy_guided_policy_begin_training_iter') else 0 + self.q_value_stop_training_iter = self._cfg.learn.q_value_stop_training_iter if hasattr( + self._cfg.learn, 'q_value_stop_training_iter' + ) and self._cfg.learn.q_value_stop_training_iter >= 0 else np.inf + + def _forward_learn(self, data: dict) -> Dict[str, Any]: + if self.cuda: + data = {k: d.to(self._device) for k, d in data.items()} + else: + data = {k: d for k, d in data.items()} + s = data['s'] + a = data['a'] + + # training behavior model + if self.behavior_policy_stop_training_iter > 0: + self._model.score_model.condition = s + behavior_model_training_loss = self._model.loss_fn(a, self.margin_prob_std_fn) + self.behavior_model_optimizer.zero_grad() + behavior_model_training_loss.backward() + self.behavior_model_optimizer.step() + self._model.score_model.condition = None + self.behavior_policy_stop_training_iter -= 1 + behavior_model_training_loss = behavior_model_training_loss.detach().cpu().numpy() + else: + behavior_model_training_loss = 0 + + # training Q function + self.energy_guided_policy_begin_training_iter -= 1 + self.q_value_stop_training_iter -= 1 + if self.energy_guided_policy_begin_training_iter < 0: + if self.q_value_stop_training_iter > 0: + q0_loss = self._model.score_model.q[0].update_q0(data) + else: + q0_loss = 0 + qt_loss = self._model.score_model.q[0].update_qt(data) + else: + q0_loss = 0 + qt_loss = 0 + + total_loss = behavior_model_training_loss + q0_loss + qt_loss + + return dict( + total_loss=total_loss, + behavior_model_training_loss=behavior_model_training_loss, + q0_loss=q0_loss, + qt_loss=qt_loss, + ) + + def _init_collect(self) -> None: + pass + + def _forward_collect(self) -> None: + pass + + def _init_eval(self) -> None: + self.guidance_scale = self._cfg.eval.guidance_scale + self.diffusion_steps = self._cfg.eval.diffusion_steps + + def _forward_eval(self, data: dict) -> dict: + data_id = list(data.keys()) + data = default_collate(list(data.values())) + states = data + actions = self._model.score_model.select_actions(states, diffusion_steps=self.diffusion_steps) + output = actions + + return {i: {"action": d} for i, d in zip(data_id, output)} + + def _get_train_sample(self) -> None: + pass + + def _process_transition(self) -> None: + pass + + def _state_dict_learn(self) -> Dict[str, Any]: + return { + 'model': self._model.state_dict(), + 'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + self._model.load_state_dict(state_dict['model']) + self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer']) + + def _monitor_vars_learn(self) -> List[str]: + return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss'] From 63363f481b652a66556b2df23d4533a6438beffc Mon Sep 17 00:00:00 2001 From: zjowowen Date: Mon, 29 Jul 2024 18:16:58 +0800 Subject: [PATCH 5/5] Polish IQL Algorithm --- ding/model/template/qgpo.py_backup | 365 ----------------------------- ding/policy/qgpo.py_backup | 172 -------------- 2 files changed, 537 deletions(-) delete mode 100644 ding/model/template/qgpo.py_backup delete mode 100644 ding/policy/qgpo.py_backup diff --git a/ding/model/template/qgpo.py_backup b/ding/model/template/qgpo.py_backup deleted file mode 100644 index 3e0136af5c..0000000000 --- a/ding/model/template/qgpo.py_backup +++ /dev/null @@ -1,365 +0,0 @@ -############################################################# -# This QGPO model is a modification implementation from https://github.com/ChenDRAG/CEP-energy-guided-diffusion -############################################################# - -from easydict import EasyDict -import functools - -import torch -import torch.nn as nn -import torch.nn.functional as F -import copy -from ding.torch_utils import MLP -from ding.torch_utils.diffusion_SDE import dpm_solver_pytorch -from ding.model.common.encoder import GaussianFourierProjectionTimeEncoder -from ding.torch_utils.network.res_block import TemporalSpatialResBlock - - -def marginal_prob_std(t, device): - """ - Overview: - Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$. - """ - t = torch.tensor(t, device=device) - beta_1 = 20.0 - beta_0 = 0.1 - log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 - alpha_t = torch.exp(log_mean_coeff) - std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) - return alpha_t, std - - -class TwinQ(nn.Module): - - def __init__(self, action_dim, state_dim): - super().__init__() - self.q1 = MLP( - in_channels=state_dim + action_dim, - hidden_channels=256, - out_channels=1, - activation=nn.ReLU(), - layer_num=4, - output_activation=False - ) - self.q2 = MLP( - in_channels=state_dim + action_dim, - hidden_channels=256, - out_channels=1, - activation=nn.ReLU(), - layer_num=4, - output_activation=False - ) - - def both(self, action, condition=None): - as_ = torch.cat([action, condition], -1) if condition is not None else action - return self.q1(as_), self.q2(as_) - - def forward(self, action, condition=None): - return torch.min(*self.both(action, condition)) - - -class GuidanceQt(nn.Module): - - def __init__(self, action_dim, state_dim, time_embed_dim=32): - super().__init__() - self.qt = MLP( - in_channels=action_dim + time_embed_dim + state_dim, - hidden_channels=256, - out_channels=1, - activation=torch.nn.SiLU(), - layer_num=4, - output_activation=False - ) - self.embed = nn.Sequential( - GaussianFourierProjectionTimeEncoder(embed_dim=time_embed_dim), nn.Linear(time_embed_dim, time_embed_dim) - ) - - def forward(self, action, t, condition=None): - embed = self.embed(t) - ats = torch.cat([action, embed, condition], -1) if condition is not None else torch.cat([action, embed], -1) - return self.qt(ats) - - -class Critic_Guide(nn.Module): - - def __init__(self, adim, sdim) -> None: - super().__init__() - # is sdim is 0 means unconditional guidance - self.conditional_sampling = False if sdim == 0 else True - self.q0 = None - self.qt = None - - def forward(self, a, condition=None): - return self.q0(a, condition) - - def calculate_guidance(self, a, t, condition=None): - raise NotImplementedError - - def calculateQ(self, a, condition=None): - return self(a, condition) - - def update_q0(self, data): - raise NotImplementedError - - def update_qt(self, data): - raise NotImplementedError - - -class QGPOCritic(Critic_Guide): - - def __init__(self, device, cfg, adim, sdim) -> None: - super().__init__(adim, sdim) - # is sdim is 0 means unconditional guidance - assert sdim > 0 - # only apply to conditional sampling here - self.device = device - self.cfg = cfg - self.q0 = TwinQ(adim, sdim).to(self.device) - self.q0_target = copy.deepcopy(self.q0).requires_grad_(False).to(self.device) - self.qt = GuidanceQt(adim, sdim).to(self.device) - self.qt_update_momentum = 0.005 - self.q_optimizer = torch.optim.Adam(self.q0.parameters(), lr=3e-4) - self.qt_optimizer = torch.optim.Adam(self.qt.parameters(), lr=3e-4) - self.discount = 0.99 - - self.alpha = self.cfg.alpha - self.guidance_scale = 1.0 - - def calculate_guidance(self, a, t, condition=None): - with torch.enable_grad(): - a.requires_grad_(True) - Q_t = self.qt(a, t, condition) - guidance = self.guidance_scale * torch.autograd.grad(torch.sum(Q_t), a)[0] - return guidance.detach() - - def update_q0(self, data): - s = data["s"] - a = data["a"] - r = data["r"] - s_ = data["s_"] - d = data["d"] - - fake_a = data['fake_a'] - fake_a_ = data['fake_a_'] - with torch.no_grad(): - softmax = nn.Softmax(dim=1) - next_energy = self.q0_target(fake_a_, torch.stack([s_] * fake_a_.shape[1], - axis=1)).detach().squeeze() # - next_v = torch.sum(softmax(self.cfg.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True) - - # Update Q function - targets = r + (1. - d.float()) * self.discount * next_v.detach() - qs = self.q0.both(a, s) - q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs) - self.q_optimizer.zero_grad(set_to_none=True) - q_loss.backward() - self.q_optimizer.step() - - # Update target - for param, target_param in zip(self.q0.parameters(), self.q0_target.parameters()): - target_param.data.copy_( - self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data - ) - - return q_loss.detach().cpu().numpy() - - def update_qt(self, data): - # input many s anction , - s = data['s'] - a = data['a'] - fake_a = data['fake_a'] - energy = self.q0_target(fake_a, torch.stack([s] * fake_a.shape[1], axis=1)).detach().squeeze() - - self.all_mean = torch.mean(energy, dim=-1).detach().cpu().squeeze().numpy() - self.all_std = torch.std(energy, dim=-1).detach().cpu().squeeze().numpy() - - # CEP guidance method, as proposed in the paper - logsoftmax = nn.LogSoftmax(dim=1) - softmax = nn.Softmax(dim=1) - - x0_data_energy = energy * self.alpha - # random_t = torch.rand((fake_a.shape[0], fake_a.shape[1]), device=s.device) * (1. - 1e-3) + 1e-3 - random_t = torch.rand((fake_a.shape[0], ), device=self.device) * (1. - 1e-3) + 1e-3 - random_t = torch.stack([random_t] * fake_a.shape[1], dim=1) - z = torch.randn_like(fake_a) - alpha_t, std = marginal_prob_std(random_t, device=self.device) - perturbed_fake_a = fake_a * alpha_t[..., None] + z * std[..., None] - xt_model_energy = self.qt(perturbed_fake_a, random_t, torch.stack([s] * fake_a.shape[1], axis=1)).squeeze() - p_label = softmax(x0_data_energy) - self.debug_used = torch.flatten(p_label).detach().cpu().numpy() - # - loss = -torch.mean(torch.sum(p_label * logsoftmax(xt_model_energy), axis=-1)) - - self.qt_optimizer.zero_grad(set_to_none=True) - loss.backward() - self.qt_optimizer.step() - - return loss.detach().cpu().numpy() - - -class ScoreBase(nn.Module): - - def __init__(self, device, cfg, input_dim, output_dim, marginal_prob_std, embed_dim=32): - super().__init__() - self.cfg = cfg - self.output_dim = output_dim - self.embed = nn.Sequential( - GaussianFourierProjectionTimeEncoder(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim) - ) - self.device = device - self.noise_schedule = dpm_solver_pytorch.NoiseScheduleVP(schedule='linear') - self.dpm_solver = dpm_solver_pytorch.DPM_Solver( - self.forward_dmp_wrapper_fn, self.noise_schedule, predict_x0=True - ) - # self.dpm_solver = dpm_solver_pytorch.DPM_Solver(self.forward_dmp_wrapper_fn, self.noise_schedule) - self.marginal_prob_std = marginal_prob_std - self.q = [] - self.q.append(QGPOCritic(device, cfg.qgpo_critic, adim=output_dim, sdim=input_dim - output_dim)) - - def forward_dmp_wrapper_fn(self, x, t): - score = self(x, t) - result = -(score + self.q[0].calculate_guidance(x, t, self.condition)) * self.marginal_prob_std( - t, device=self.device - )[1][..., None] - return result - - def dpm_wrapper_sample(self, dim, batch_size, **kwargs): - with torch.no_grad(): - init_x = torch.randn(batch_size, dim, device=self.device) - return self.dpm_solver.sample(init_x, **kwargs).cpu().numpy() - - def calculateQ(self, s, a, t=None): - if s is None: - if self.condition.shape[0] == a.shape[0]: - s = self.condition - elif self.condition.shape[0] == 1: - s = torch.cat([self.condition] * a.shape[0]) - else: - assert False - return self.q[0](a, s) - - def forward(self, x, t, condition=None): - raise NotImplementedError - - def select_actions(self, states, diffusion_steps=15): - self.eval() - multiple_input = True - with torch.no_grad(): - states = torch.FloatTensor(states).to(self.device) - if states.dim == 1: - states = states.unsqueeze(0) - multiple_input = False - num_states = states.shape[0] - self.condition = states - results = self.dpm_wrapper_sample( - self.output_dim, batch_size=states.shape[0], steps=diffusion_steps, order=2 - ) - actions = results.reshape(num_states, self.output_dim).copy() # - self.condition = None - out_actions = [actions[i] for i in range(actions.shape[0])] if multiple_input else actions[0] - self.train() - return out_actions - - def sample(self, states, sample_per_state=16, diffusion_steps=15): - self.eval() - num_states = states.shape[0] - with torch.no_grad(): - states = torch.FloatTensor(states).to(self.device) - states = torch.repeat_interleave(states, sample_per_state, dim=0) - self.condition = states - results = self.dpm_wrapper_sample( - self.output_dim, batch_size=states.shape[0], steps=diffusion_steps, order=2 - ) - actions = results[:, :].reshape(num_states, sample_per_state, self.output_dim).copy() - self.condition = None - self.train() - return actions - - -class ScoreNet(ScoreBase): - - def __init__(self, device, cfg, input_dim, output_dim, marginal_prob_std, embed_dim=32): - super().__init__(device, cfg, input_dim, output_dim, marginal_prob_std, embed_dim) - # The swish activation function - self.device = device - self.cfg = cfg - self.act = lambda x: x * torch.sigmoid(x) - self.pre_sort_condition = nn.Sequential(nn.Linear(input_dim - output_dim, 32), torch.nn.SiLU()) - self.sort_t = nn.Sequential( - nn.Linear(64, 128), - torch.nn.SiLU(), - nn.Linear(128, 128), - ) - self.down_block1 = TemporalSpatialResBlock(output_dim, 512) - self.down_block2 = TemporalSpatialResBlock(512, 256) - self.down_block3 = TemporalSpatialResBlock(256, 128) - self.middle1 = TemporalSpatialResBlock(128, 128) - self.up_block3 = TemporalSpatialResBlock(256, 256) - self.up_block2 = TemporalSpatialResBlock(512, 512) - self.last = nn.Linear(1024, output_dim) - - def forward(self, x, t, condition=None): - embed = self.embed(t) - - if condition is not None: - embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1) - else: - if self.condition.shape[0] == x.shape[0]: - condition = self.condition - elif self.condition.shape[0] == 1: - condition = torch.cat([self.condition] * x.shape[0]) - else: - assert False - embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1) - embed = self.sort_t(embed) - d1 = self.down_block1(x, embed) - d2 = self.down_block2(d1, embed) - d3 = self.down_block3(d2, embed) - u3 = self.middle1(d3, embed) - u2 = self.up_block3(torch.cat([d3, u3], dim=-1), embed) - u1 = self.up_block2(torch.cat([d2, u2], dim=-1), embed) - u0 = torch.cat([d1, u1], dim=-1) - h = self.last(u0) - self.h = h - # Normalize output - return h / self.marginal_prob_std(t, device=self.device)[1][..., None] - - -class QGPO(nn.Module): - - def __init__(self, cfg: EasyDict) -> None: - super(QGPO, self).__init__() - self.cfg = cfg - self.device = cfg.device - self.obs_dim = cfg.obs_dim - self.action_dim = cfg.action_dim - - #marginal_prob_std_fn = functools.partial(marginal_prob_std, device=self.device) - - self.score_model = ScoreNet( - device=self.device, - cfg=cfg.score_net, - input_dim=self.obs_dim + self.action_dim, - output_dim=self.action_dim, - marginal_prob_std=marginal_prob_std, - ) - - def loss_fn(self, x, marginal_prob_std, eps=1e-3): - """ - Overview: - The loss function for training score-based generative models. - Arguments: - model: A PyTorch model instance that represents a \ - time-dependent score-based model. - x: A mini-batch of training data. - marginal_prob_std: A function that gives the standard deviation of \ - the perturbation kernel. - eps: A tolerance value for numerical stability. - """ - random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps - z = torch.randn_like(x) - alpha_t, std = marginal_prob_std(random_t, device=x.device) - perturbed_x = x * alpha_t[:, None] + z * std[:, None] - score = self.score_model(perturbed_x, random_t) - loss = torch.mean(torch.sum((score * std[:, None] + z) ** 2, dim=(1, ))) - return loss diff --git a/ding/policy/qgpo.py_backup b/ding/policy/qgpo.py_backup deleted file mode 100644 index d8ea1dc3ad..0000000000 --- a/ding/policy/qgpo.py_backup +++ /dev/null @@ -1,172 +0,0 @@ -############################################################# -# This QGPO model is a modification implementation from https://github.com/ChenDRAG/CEP-energy-guided-diffusion -############################################################# - -from typing import List, Dict, Any -import functools -import torch -import numpy as np -from ding.torch_utils import to_device -from ding.utils import POLICY_REGISTRY -from ding.utils.data import default_collate, default_decollate -from .base_policy import Policy - -from ding.model.template.qgpo import marginal_prob_std - - -@POLICY_REGISTRY.register('qgpo') -class QGPOPolicy(Policy): - """ - Overview: - Policy class of QGPO algorithm - Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning - https://arxiv.org/abs/2304.12824 - """ - - config = dict( - # (str) RL policy register name (refer to function "POLICY_REGISTRY"). - type='qgpo', - # (bool) Whether to use cuda for network. - cuda=False, - # (bool type) on_policy: Determine whether on-policy or off-policy. - # on-policy setting influences the behaviour of buffer. - # Default False in QGPO. - on_policy=False, - multi_agent=False, - model=dict( - score_net=dict( - qgpo_critic=dict( - # (float) The scale of the energy guidance when training qt. - # \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a)) - alpha=3, - # (float) The scale of the energy guidance when training q0. - # \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a') - # \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a)) - q_alpha=1, - ), - ), - device='cuda', - # obs_dim - # action_dim - ), - learn=dict( - # learning rate for behavior model training - learning_rate=1e-4, - # batch size during the training of behavior model - batch_size=4096, - # batch size during the training of q value - batch_size_q=256, - # number of fake action support - M=16, - # number of diffusion time steps - diffusion_steps=15, - # training iterations when behavior model is fixed - behavior_policy_stop_training_iter=600000, - # training iterations when energy-guided policy begin training - energy_guided_policy_begin_training_iter=600000, - # training iterations when q value stop training, default None means no limit - q_value_stop_training_iter=1100000, - ), - eval=dict( - # energy guidance scale for policy in evaluation - # \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a)) - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - ) - - def _init_learn(self) -> None: - self.cuda = self._cfg.cuda - if self.cuda: - self.margin_prob_std_fn = functools.partial(marginal_prob_std, device=self._device) - self.behavior_model_optimizer = torch.optim.Adam( - self._model.score_model.parameters(), lr=self._cfg.learn.learning_rate - ) - self.behavior_policy_stop_training_iter = self._cfg.learn.behavior_policy_stop_training_iter if hasattr( - self._cfg.learn, 'behavior_policy_stop_training_iter' - ) else np.inf - self.energy_guided_policy_begin_training_iter = self._cfg.learn.energy_guided_policy_begin_training_iter if \ - hasattr(self._cfg.learn, 'energy_guided_policy_begin_training_iter') else 0 - self.q_value_stop_training_iter = self._cfg.learn.q_value_stop_training_iter if hasattr( - self._cfg.learn, 'q_value_stop_training_iter' - ) and self._cfg.learn.q_value_stop_training_iter >= 0 else np.inf - - def _forward_learn(self, data: dict) -> Dict[str, Any]: - if self.cuda: - data = {k: d.to(self._device) for k, d in data.items()} - else: - data = {k: d for k, d in data.items()} - s = data['s'] - a = data['a'] - - # training behavior model - if self.behavior_policy_stop_training_iter > 0: - self._model.score_model.condition = s - behavior_model_training_loss = self._model.loss_fn(a, self.margin_prob_std_fn) - self.behavior_model_optimizer.zero_grad() - behavior_model_training_loss.backward() - self.behavior_model_optimizer.step() - self._model.score_model.condition = None - self.behavior_policy_stop_training_iter -= 1 - behavior_model_training_loss = behavior_model_training_loss.detach().cpu().numpy() - else: - behavior_model_training_loss = 0 - - # training Q function - self.energy_guided_policy_begin_training_iter -= 1 - self.q_value_stop_training_iter -= 1 - if self.energy_guided_policy_begin_training_iter < 0: - if self.q_value_stop_training_iter > 0: - q0_loss = self._model.score_model.q[0].update_q0(data) - else: - q0_loss = 0 - qt_loss = self._model.score_model.q[0].update_qt(data) - else: - q0_loss = 0 - qt_loss = 0 - - total_loss = behavior_model_training_loss + q0_loss + qt_loss - - return dict( - total_loss=total_loss, - behavior_model_training_loss=behavior_model_training_loss, - q0_loss=q0_loss, - qt_loss=qt_loss, - ) - - def _init_collect(self) -> None: - pass - - def _forward_collect(self) -> None: - pass - - def _init_eval(self) -> None: - self.guidance_scale = self._cfg.eval.guidance_scale - self.diffusion_steps = self._cfg.eval.diffusion_steps - - def _forward_eval(self, data: dict) -> dict: - data_id = list(data.keys()) - data = default_collate(list(data.values())) - states = data - actions = self._model.score_model.select_actions(states, diffusion_steps=self.diffusion_steps) - output = actions - - return {i: {"action": d} for i, d in zip(data_id, output)} - - def _get_train_sample(self) -> None: - pass - - def _process_transition(self) -> None: - pass - - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._model.state_dict(), - 'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._model.load_state_dict(state_dict['model']) - self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer']) - - def _monitor_vars_learn(self) -> List[str]: - return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss']