feature(cy): add dreamerV3 + MiniGrid code (#725)
* support flat obs, discrete action;
add minigrid config

* fix one bug

* fix eval bug

* modify minigrid wrapper

* modify something

* fix one thing

* polish & style check

* polish the code

* code polish

* polish code
Cloud-Pku authored Feb 1, 2024
1 parent 6a20ae3 commit 2405639
Showing 11 changed files with 367 additions and 46 deletions.
4 changes: 1 addition & 3 deletions ding/model/template/vac.py
@@ -366,7 +366,6 @@ class DREAMERVAC(nn.Module):

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
dyn_stoch=32,
dyn_deter=512,
@@ -391,9 +390,8 @@ def __init__(
- action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
"""
super(DREAMERVAC, self).__init__()
obs_shape: int = squeeze(obs_shape)
action_shape = squeeze(action_shape)
self.obs_shape, self.action_shape = obs_shape, action_shape
self.action_shape = action_shape

if dyn_discrete:
feat_size = dyn_stoch * dyn_discrete + dyn_deter
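Note: DREAMERVAC no longer stores the raw observation shape because its actor and critic heads consume the RSSM latent feature, whose width depends only on the dynamics hyperparameters. A minimal sketch of that width, using the defaults visible in the diff (dyn_discrete=32 is an assumed DreamerV3-style default, not taken from this file):

# Sketch: latent feature width that the DREAMERVAC heads are built on.
# dyn_stoch / dyn_deter follow the defaults shown above; dyn_discrete=32
# is an assumed DreamerV3-style default used only for illustration.
dyn_stoch, dyn_deter, dyn_discrete = 32, 512, 32
if dyn_discrete:
    # categorical latent: dyn_stoch variables with dyn_discrete classes each,
    # concatenated with the deterministic recurrent state
    feat_size = dyn_stoch * dyn_discrete + dyn_deter   # 32 * 32 + 512 = 1536
else:
    # Gaussian latent: stochastic sample plus deterministic state
    feat_size = dyn_stoch + dyn_deter                  # 32 + 512 = 544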
25 changes: 21 additions & 4 deletions ding/policy/mbpolicy/dreamer.py
@@ -234,8 +234,11 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N
latent[key][i] *= mask[i]
for i in range(len(action)):
action[i] *= mask[i]

data = data - 0.5
assert world_model.obs_type == 'vector' or world_model.obs_type == 'RGB', \
"action type must be vector or RGB"
# normalize RGB image input
if world_model.obs_type == 'RGB':
data = data - 0.5
embed = world_model.encoder(data)
latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
feat = world_model.dynamics.get_feat(latent)
@@ -247,11 +250,18 @@ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=N
action = action.detach()

state = (latent, action)
assert world_model.action_type == 'discrete' or world_model.action_type == 'continuous', \
"action type must be continuous or discrete"
if world_model.action_type == 'discrete':
action = torch.where(action == 1)[1]
output = {"action": action, "logprob": logprob, "state": state}

if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
if world_model.action_type == 'discrete':
for l in range(len(output)):
output[l]['action'] = output[l]['action'].squeeze(0)
return {i: d for i, d in zip(data_id, output)}

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
@@ -272,7 +282,7 @@ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple
# TODO(zp) random_collect just have action
#'logprob': model_output['logprob'],
'reward': timestep.reward,
'discount': timestep.info['discount'],
'discount': 1. - timestep.done, # timestep.info['discount'],
'done': timestep.done,
}
return transition
@@ -309,7 +319,9 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict
for i in range(len(action)):
action[i] *= mask[i]

data = data - 0.5
# normalize RGB image input
if world_model.obs_type == 'RGB':
data = data - 0.5
embed = world_model.encoder(data)
latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
feat = world_model.dynamics.get_feat(latent)
@@ -321,11 +333,16 @@ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict
action = action.detach()

state = (latent, action)
if world_model.action_type == 'discrete':
action = torch.where(action == 1)[1]
output = {"action": action, "logprob": logprob, "state": state}

if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
if world_model.action_type == 'discrete':
for l in range(len(output)):
output[l]['action'] = output[l]['action'].squeeze(0)
return {i: d for i, d in zip(data_id, output)}

def _monitor_vars_learn(self) -> List[str]:
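Note: the collect and eval paths now branch on the new obs_type and action_type fields: RGB observations keep the -0.5 shift while flat vectors are passed to the encoder unchanged, transition discounts fall back to 1 - done, and one-hot actor outputs are converted to integer indices before reaching a discrete-action environment. A rough illustration of that last step (the action tensor is made up):

import torch

# Rough illustration of the one-hot -> index conversion used above for
# discrete actions; the action values here are made up.
one_hot_action = torch.tensor([[0., 0., 1., 0.],
                               [1., 0., 0., 0.]])
index_action = torch.where(one_hot_action == 1)[1]   # -> tensor([2, 0])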
6 changes: 3 additions & 3 deletions ding/torch_utils/network/dreamer.py
@@ -178,7 +178,7 @@ def forward(self, features):
elif self._dist == "binary":
return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)))
elif self._dist == "twohot_symlog":
return TwoHotDistSymlog(logits=mean, device=self._device)
return TwoHotDistSymlog(logits=mean, low=-1., high=1., device=self._device)
raise NotImplementedError(self._dist)


@@ -475,8 +475,8 @@ def log_prob(self, x):
above = torch.clip(above, 0, len(self.buckets) - 1)
equal = (below == above)

dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
dist_to_below = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[below] - x))
dist_to_above = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[above] - x))
total = dist_to_below + dist_to_above
weight_below = dist_to_above / total
weight_above = dist_to_below / total
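Note: two small fixes land in this file. The two-hot symlog head is now constructed with explicit low/high bucket bounds, and the torch.where filler constant is built with torch.tensor(1).to(x) so it takes x's dtype and device instead of relying on scalar-argument behaviour that varies across PyTorch versions. A toy version of the latter (shapes are illustrative; on GPU x would be a CUDA tensor):

import torch

# Toy version of the device/dtype-safe torch.where pattern used above.
x = torch.randn(4)                                   # would be a CUDA tensor on GPU
equal = torch.tensor([True, False, True, False])
filler = torch.tensor(1).to(x)                       # matches x's dtype and device
dist = torch.where(equal, filler, torch.abs(x))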
78 changes: 53 additions & 25 deletions ding/world_model/dreamer.py
@@ -5,10 +5,10 @@

from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts
from ding.utils.data import default_collate
from ding.model import ConvEncoder
from ding.model import ConvEncoder, FCEncoder
from ding.world_model.base_world_model import WorldModel
from ding.world_model.model.networks import RSSM, ConvDecoder
from ding.torch_utils import to_device
from ding.torch_utils import to_device, one_hot
from ding.torch_utils.network.dreamer import DenseHead


@@ -37,6 +37,7 @@ class DREAMERWorldModel(WorldModel, nn.Module):
norm='LayerNorm',
grad_heads=['image', 'reward', 'discount'],
units=512,
image_dec_layers=2,
reward_layers=2,
discount_layers=2,
value_layers=2,
@@ -71,26 +72,33 @@ def __init__(self, cfg, env, tb_logger):
self._cfg.act = nn.modules.activation.SiLU # nn.SiLU
self._cfg.norm = nn.modules.normalization.LayerNorm # nn.LayerNorm
self.state_size = self._cfg.state_size
self.obs_type = self._cfg.obs_type
self.action_size = self._cfg.action_size
self.action_type = self._cfg.action_type
self.reward_size = self._cfg.reward_size
self.hidden_size = self._cfg.hidden_size
self.batch_size = self._cfg.batch_size
if self.obs_type == 'vector':
self.encoder = FCEncoder(self.state_size, self._cfg.encoder_hidden_size_list, activation=torch.nn.SiLU())
self.embed_size = self._cfg.encoder_hidden_size_list[-1]
elif self.obs_type == 'RGB':
self.encoder = ConvEncoder(
self.state_size,
hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128?
activation=torch.nn.SiLU(),
kernel_size=self._cfg.encoder_kernels,
layer_norm=True
)
self.embed_size = (
(self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
2 ** (len(self._cfg.encoder_kernels) - 1)
)

self.encoder = ConvEncoder(
self.state_size,
hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128?
activation=torch.nn.SiLU(),
kernel_size=self._cfg.encoder_kernels,
layer_norm=True
)
self.embed_size = (
(self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
2 ** (len(self._cfg.encoder_kernels) - 1)
)
self.dynamics = RSSM(
self._cfg.dyn_stoch,
self._cfg.dyn_deter,
self._cfg.dyn_hidden,
self._cfg.action_type,
self._cfg.dyn_input_layers,
self._cfg.dyn_output_layers,
self._cfg.dyn_rec_depth,
@@ -113,14 +121,28 @@ def __init__(self, cfg, env, tb_logger):
feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
else:
feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter
self.heads["image"] = ConvDecoder(
feat_size, # pytorch version
self._cfg.cnn_depth,
self._cfg.act,
self._cfg.norm,
self.state_size,
self._cfg.decoder_kernels,
)

if isinstance(self.state_size, int):
self.heads['image'] = DenseHead(
feat_size,
(self.state_size, ),
self._cfg.image_dec_layers,
self._cfg.units,
'SiLU', # self._cfg.act
'LN', # self._cfg.norm
dist='binary',
outscale=0.0,
device=self._cfg.device,
)
elif len(self.state_size) == 3:
self.heads["image"] = ConvDecoder(
feat_size, # pytorch version
self._cfg.cnn_depth,
self._cfg.act,
self._cfg.norm,
self.state_size,
self._cfg.decoder_kernels,
)
self.heads["reward"] = DenseHead(
feat_size, # dyn_stoch * dyn_discrete + dyn_deter
(255, ),
@@ -172,9 +194,15 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
data = {k: torch.stack(data[k], dim=1) for k in data} # -> {dict_key: Tensor([B, T, any_dims])}

data['discount'] = data.get('discount', 1.0 - data['done'].float())
data['discount'] *= 0.997
data['weight'] = data.get('weight', None)
data['image'] = data['obs'] - 0.5
if self.obs_type == 'RGB':
data['image'] = data['obs'] - 0.5
else:
data['image'] = data['obs']
if self.action_type == 'continuous':
data['action'] *= (1.0 / torch.clip(torch.abs(data['action']), min=1.0))
else:
data['action'] = one_hot(data['action'], self.action_size)
data = to_device(data, self._cfg.device)
if len(data['reward'].shape) == 2:
data['reward'] = data['reward'].unsqueeze(-1)
@@ -185,9 +213,9 @@ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):

self.requires_grad_(requires_grad=True)

image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
image = data['image'].reshape([-1] + list(data['image'].shape[2:]))
embed = self.encoder(image)
embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]])
embed = embed.reshape(list(data['image'].shape[:2]) + [embed.shape[-1]])

post, prior = self.dynamics.observe(embed, data["action"])
kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
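Note: the world model now chooses its encoder and decoder from obs_type — FCEncoder plus a DenseHead reconstruction head for flat vector observations, ConvEncoder plus ConvDecoder for RGB images — one-hot encodes discrete actions before they enter the RSSM, and merges batch/time dimensions with shape[2:] so the same training code handles both observation kinds. A small sketch of that reshape (observation sizes are assumptions, not values from the configs):

import torch

# Sketch of the obs-type-agnostic reshape used in `train` above: indexing
# shape[2:] instead of shape[-3:] merges batch and time for both flat
# (B, T, D) and image (B, T, C, H, W) observations. Sizes are assumed.
flat_obs = torch.randn(16, 64, 2835)         # e.g. a flattened MiniGrid observation
image_obs = torch.randn(16, 64, 3, 64, 64)   # an RGB observation
for obs in (flat_obs, image_obs):
    merged = obs.reshape([-1] + list(obs.shape[2:]))      # (B*T, ...)
    embed = merged.reshape(merged.shape[0], -1)           # stand-in for encoder output
    embed = embed.reshape(list(obs.shape[:2]) + [embed.shape[-1]])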
18 changes: 12 additions & 6 deletions ding/world_model/model/networks.py
@@ -1,11 +1,12 @@
import math
import numpy as np
from typing import Optional, Dict, Union, List

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd

from ding.utils import SequenceType
from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \
OneHotDist, ContDist, SymlogDist, DreamerLayerNorm

@@ -17,6 +18,7 @@ def __init__(
stoch=30,
deter=200,
hidden=200,
action_type=None,
layers_input=1,
layers_output=1,
rec_depth=1,
@@ -38,6 +40,7 @@ def __init__(
self._stoch = stoch
self._deter = deter
self._hidden = hidden
self._action_type = action_type
self._min_std = min_std
self._layers_input = layers_input
self._layers_output = layers_output
@@ -179,7 +182,8 @@ def get_dist(self, state, dtype=None):
def obs_step(self, prev_state, prev_action, embed, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
if self._action_type == 'continuous':
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
@@ -202,7 +206,8 @@ def obs_step(self, prev_state, prev_action, embed=None, sample=True):
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
# (batch, stoch, discrete_num)
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
if self._action_type == 'continuous':
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@@ -282,8 +287,9 @@ def kl_loss(self, post, prior, forward, free, lscale, rscale):
dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
dist(rhs) if self._discrete else dist(rhs)._dist,
)
loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
# free bits
loss_lhs = torch.mean(torch.clip(value_lhs, min=free))
loss_rhs = torch.mean(torch.clip(value_rhs, min=free))
loss = lscale * loss_lhs + rscale * loss_rhs

return loss, value, loss_lhs, loss_rhs
@@ -357,7 +363,7 @@ def calc_same_pad(self, k, s, d):
outpad = pad * 2 - val
return pad, outpad

def __call__(self, features, dtype=None):
def __call__(self, features):
x = self._linear_layer(features) # feature:[batch, time, stoch*discrete + deter]
x = x.reshape([-1, 4, 4, self._embed_size // 16])
x = x.permute(0, 3, 1, 2)
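Note: the kl_loss change applies the free-bits floor per element before averaging instead of clipping the already-averaged KL; the results differ whenever some elements sit under the floor while others are above it. A tiny numeric check:

import torch

# Tiny numeric check of the free-bits change above.
kl = torch.tensor([0.2, 3.0])
free = 1.0
old_style = torch.clip(torch.mean(kl), min=free)   # mean = 1.6 -> clipped to 1.6
new_style = torch.mean(torch.clip(kl, min=free))   # [1.0, 3.0] -> mean 2.0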
@@ -60,7 +60,9 @@
cuda=cuda,
model=dict(
state_size=(3, 64, 64), # has to be specified
obs_type='RGB',
action_size=1, # has to be specified
action_type='continuous',
reward_size=1,
batch_size=16,
),
@@ -60,7 +60,9 @@
cuda=cuda,
model=dict(
state_size=(3, 64, 64), # has to be specified
obs_type='RGB',
action_size=6, # has to be specified
action_type='continuous',
reward_size=1,
batch_size=16,
),
@@ -28,7 +28,6 @@
# it is better to put random_collect_size in policy.other
random_collect_size=2500,
model=dict(
obs_shape=(3, 64, 64),
action_shape=6,
actor_dist='normal',
),
@@ -60,7 +59,9 @@
cuda=cuda,
model=dict(
state_size=(3, 64, 64), # has to be specified
obs_type='RGB',
action_size=6, # has to be specified
action_type='continuous',
reward_size=1,
batch_size=16,
),
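Note: the three config hunks above add obs_type and action_type to existing image-based configs and drop the now-unused obs_shape from the policy model. The new MiniGrid config itself is among the 11 changed files but is not shown on this page; by analogy it would presumably use a flat vector observation with discrete actions, roughly like the hypothetical sketch below (every value is an illustrative assumption, not taken from the commit):

# Hypothetical sketch, not part of the shown diff: a MiniGrid-style
# world-model config by analogy with the configs above. All values are
# illustrative assumptions.
world_model = dict(
    cuda=True,
    model=dict(
        state_size=2835,          # flat observation size (assumed)
        obs_type='vector',
        action_size=7,            # MiniGrid action count (assumed)
        action_type='discrete',
        reward_size=1,
        batch_size=16,
    ),
)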