polish(lwq): polish VAE #404

Merged: 2 commits, Jun 29, 2022
122 changes: 73 additions & 49 deletions ding/model/template/vae.py
@@ -8,33 +8,14 @@
from ding.utils.type_helper import Tensor


class BaseVAE(nn.Module):

def __init__(self) -> None:
super(BaseVAE, self).__init__()

def encode(self, input: Tensor) -> List[Tensor]:
raise NotImplementedError

def decode(self, input: Tensor, obs_encoding: Optional[Tensor]) -> Any:
raise NotImplementedError

def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
raise RuntimeWarning()

def generate(self, x: Tensor, **kwargs) -> Tensor:
raise NotImplementedError

@abstractmethod
def forward(self, *inputs: Tensor) -> Tensor:
pass

@abstractmethod
def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
pass


class VanillaVAE(BaseVAE):
class VanillaVAE(nn.Module):
"""
Overview:
Implementation of Vanilla variational autoencoder for action reconstruction.
Interfaces:
``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``, \
``forward``, ``loss_function``.
"""

def __init__(
self,
@@ -59,32 +40,39 @@ def __init__(
self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)

# Build Decoder
# self.condition_obs = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[-1]), nn.ReLU())
self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU())
self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU())
# TODO(pu): tanh
self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())
# self.decode_reconst_action_head = nn.Linear(hidden_dims[0], self.action_shape)

# residual prediction
self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)

self.obs_encoding = None

def encode(self, input) -> Dict[str, Any]:
def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder
:return: (Tensor) List of latent codes
Overview:
Encodes the input by passing through the encoder network and returns the latent codes.
Arguments:
- input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`) and \
``action`` (:obj:`torch.Tensor`), representing the observation and agent's action respectively.
Returns:
- outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \
``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \
representing latent codes.
Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
- action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
- mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
- log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
- obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
"""
action_encoding = self.encode_action_head(input['action'])
obs_encoding = self.encode_obs_head(input['obs'])
# obs_encoding = self.condition_obs(input['obs']) # TODO(pu): using a different network
# input = torch.cat([obs_encoding, action_encoding], dim=-1)
# input = obs_encoding + action_encoding # TODO(pu): what about add, cat?
input = obs_encoding * action_encoding
input = obs_encoding * action_encoding # TODO(pu): what about add, cat?
result = self.encode_common(input)

# Split the result into mu and var components
@@ -111,8 +99,6 @@ def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
- obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``
"""
action_decoding = self.decode_action_head(torch.tanh(z)) # NOTE: tanh, here z is not bounded
# action_decoding = self.decode_action_head(z) # NOTE: tanh, here z is not bounded
# action_obs_decoding = action_decoding + obs_encoding # TODO(pu): what about add, cat?
action_obs_decoding = action_decoding * obs_encoding
action_obs_decoding_tmp = self.decode_common(action_obs_decoding)

@@ -132,16 +118,15 @@ def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
Returns:
- outputs (:obj:`Dict`): Dict containing keywords ``reconstruction_action`` and ``prediction_residual``.
ReturnsKeys:
- reconstruction_action (:obj:`torch.Tensor`): reconstruction_action.
- predition_residual (:obj:`torch.Tensor`): predition_residual.
- reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the VAE.
- prediction_residual (:obj:`torch.Tensor`): the observation residual predicted by the VAE.
Shapes:
- z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
- obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``
"""
obs_encoding = self.encode_obs_head(obs)
# TODO(pu): here z is already bounded, since it is produced by the TD3 policy and has passed through tanh
action_decoding = self.decode_action_head(z)
# action_obs_decoding = action_decoding + obs_encoding # TODO(pu): what about add, cat?
action_obs_decoding = action_decoding * obs_encoding
action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp)
@@ -165,7 +150,28 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor, **kwargs) -> dict:
def forward(self, input: Dict[str, Tensor], **kwargs) -> dict:
"""
Overview:
Encode the input, reparameterize ``mu`` and ``log_var`` to sample the latent ``z``, then decode it together with ``obs_encoding``.
Arguments:
- input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`) \
and ``action`` (:obj:`torch.Tensor`), representing the observation \
and agent's action respectively.
Returns:
- outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
(:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
Shapes:
- recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
- prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, \
where B is batch size and O is ``observation dim``.
- mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
- log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
- z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
"""

encode_output = self.encode(input)
z = self.reparameterize(encode_output['mu'], encode_output['log_var'])
decode_output = self.decode(z, encode_output['obs_encoding'])
@@ -178,13 +184,31 @@ def forward(self, input: Tensor, **kwargs) -> dict:
'z': z
}

def loss_function(self, args, **kwargs) -> dict:
def loss_function(self, args: Dict[str, Tensor], **kwargs) -> dict:
"""
Computes the VAE loss function.
KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
:param args:
:param kwargs:
:return:
Overview:
Computes the VAE loss function.
KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
Arguments:
- args (:obj:`Dict`): Dict containing keywords ``recons_action`` (:obj:`torch.Tensor`), \
``prediction_residual`` (:obj:`torch.Tensor`), ``original_action`` (:obj:`torch.Tensor`), \
``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and \
``true_residual`` (:obj:`torch.Tensor`).
- kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` (:obj:`torch.Tensor`) \
and ``predict_weight`` (:obj:`torch.Tensor`).
Returns:
- outputs (:obj:`Dict`): Dict containing keywords ``loss`` (:obj:`torch.Tensor`), \
``reconstruction_loss`` (:obj:`torch.Tensor`), ``kld_loss`` (:obj:`torch.Tensor`) \
and ``predict_loss`` (:obj:`torch.Tensor`).
Shapes:
- recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size \
and A is ``action dim``.
- prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size \
and O is ``observation dim``.
- original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
- mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
- log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
- true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
"""
recons_action = args['recons_action']
prediction_residual = args['prediction_residual']
# ... (remainder of loss_function diff truncated)
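
For reference, the sketch below exercises the polished VanillaVAE end to end. It is illustrative only: the constructor signature (action_shape, obs_shape, latent_size, hidden_dims) is assumed from the attribute names in the diff, the loss weights are arbitrary, and because the tail of loss_function is collapsed here, the final line shows only one common reduction of the docstring's KL formula, not necessarily the PR's exact implementation.

import torch
from ding.model.template.vae import VanillaVAE

B, O, A, L = 32, 8, 2, 6  # batch size, obs dim, action dim, latent size
vae = VanillaVAE(action_shape=A, obs_shape=O, latent_size=L, hidden_dims=[256, 256])

# forward expects a dict with 'obs' and 'action', per the docstring above
batch = {'obs': torch.randn(B, O), 'action': torch.randn(B, A).clamp(-1, 1)}
out = vae(batch)  # keys: recons_action, prediction_residual, input, mu, log_var, z

# true_residual is typically next_obs - obs; random values stand in here
true_residual = torch.randn(B, O)
loss_out = vae.loss_function(
    {
        'recons_action': out['recons_action'],
        'prediction_residual': out['prediction_residual'],
        'original_action': batch['action'],
        'mu': out['mu'],
        'log_var': out['log_var'],
        'true_residual': true_residual,
    },
    kld_weight=0.01,  # illustrative weight
    predict_weight=0.01,  # illustrative weight
)  # keys: loss, reconstruction_loss, kld_loss, predict_loss

# One standard reduction of KL(N(mu, sigma), N(0, 1)) from the docstring:
kld = torch.mean(-0.5 * torch.sum(1 + out['log_var'] - out['mu'] ** 2 - out['log_var'].exp(), dim=1))

For latents produced by an external policy (already squashed by tanh, as the TODO above notes), decode_with_obs(z, obs) decodes directly without re-encoding the action, returning the reconstructed action and the predicted observation residual.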