
feature(zp): add dreamerv3 algorithm #652

Merged: 17 commits, Aug 7, 2023
Changes shown from 5 commits
70 changes: 64 additions & 6 deletions ding/model/common/encoder.py
@@ -35,6 +35,7 @@ def __init__(
kernel_size: SequenceType = [8, 4, 3],
stride: SequenceType = [4, 2, 1],
padding: Optional[SequenceType] = None,
layer_norm: Optional[bool] = False,
norm_type: Optional[str] = None
) -> None:
"""
@@ -50,6 +51,7 @@ def __init__(
- stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
- padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
See ``nn.Conv2d`` for more details. Default is ``None``.
- layer_norm (:obj:`bool`): Whether to use ``DreamerLayerNorm``.
- norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResBlock`` \
for more details. Default is ``None``.
"""
@@ -63,17 +65,35 @@ def __init__(
layers = []
input_size = obs_shape[0] # in_channel
for i in range(len(kernel_size)):
layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
layers.append(self.act)
if layer_norm:
layers.append(
Conv2dSame(
in_channels=input_size,
out_channels=hidden_size_list[i],
kernel_size=(kernel_size[i], kernel_size[i]),
stride=(2, 2),
bias=False,
)
)
layers.append(DreamerLayerNorm(hidden_size_list[i]))
layers.append(self.act)
else:
layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
layers.append(self.act)
input_size = hidden_size_list[i]
assert len(set(hidden_size_list[3:-1])) <= 1, "Please indicate the same hidden size for res block parts"
for i in range(3, len(self.hidden_size_list) - 1):
layers.append(ResBlock(self.hidden_size_list[i], activation=self.act, norm_type=norm_type))
if len(self.hidden_size_list) >= len(kernel_size) + 2:
assert self.hidden_size_list[len(kernel_size) - 1] == self.hidden_size_list[
len(kernel_size)], "Please indicate the same hidden size between conv and res block"
assert len(
set(hidden_size_list[len(kernel_size):-1])
) <= 1, "Please indicate the same hidden size for res block parts"
for i in range(len(kernel_size), len(self.hidden_size_list) - 1):
layers.append(ResBlock(self.hidden_size_list[i - 1], activation=self.act, norm_type=norm_type))
layers.append(Flatten())
self.main = nn.Sequential(*layers)

flatten_size = self._get_flatten_size()
self.output_size = hidden_size_list[-1]
self.output_size = hidden_size_list[-1]  # exposed for use outside this module
self.mid = nn.Linear(flatten_size, hidden_size_list[-1])

def _get_flatten_size(self) -> int:
@@ -306,3 +326,41 @@ def forward(self, x):
if self.final_relu:
x = torch.relu(x)
return x


class Conv2dSame(torch.nn.Conv2d):

def calc_same_pad(self, i, k, s, d):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
Collaborator: Add data types for ``i``, ``k``, ``s``, ``d``, and add a function annotation/docstring.
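A sketch of how the suggested annotations could look, following the docstring style used elsewhere in this file (hypothetical, not part of the diff):

```python
def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
    """
    Overview:
        Compute the total "same" padding along one spatial dimension so that the output size is ceil(i / s).
    Arguments:
        - i (:obj:`int`): Input size along this dimension.
        - k (:obj:`int`): Kernel size.
        - s (:obj:`int`): Stride.
        - d (:obj:`int`): Dilation.
    """
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
```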


def forward(self, x):
ih, iw = x.size()[-2:]
pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

ret = F.conv2d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
return ret
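A quick numeric check of the padding formula, as a sketch assuming the Conv2dSame class defined above is in scope:

```python
import torch

# For i=64, k=4, s=2, d=1: pad = max((ceil(64/2) - 1) * 2 + (4 - 1) * 1 + 1 - 64, 0) = 2,
# split as one pixel on each side, so the output spatial size is ceil(64 / 2) = 32.
conv = Conv2dSame(in_channels=3, out_channels=8, kernel_size=(4, 4), stride=(2, 2), bias=False)
out = conv(torch.randn(1, 3, 64, 64))
assert out.shape == (1, 8, 32, 32)
```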


class DreamerLayerNorm(nn.Module):

def __init__(self, ch, eps=1e-03):
super(DreamerLayerNorm, self).__init__()
Collaborator: Add data types (annotations for ``ch`` and ``eps``).

self.norm = torch.nn.LayerNorm(ch, eps=eps)

def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
return x
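DreamerLayerNorm applies LayerNorm over the channel dimension of an NCHW feature map by permuting to NHWC and back. A minimal usage sketch, assuming the class defined above:

```python
import torch

norm = DreamerLayerNorm(ch=32)
feat = torch.randn(8, 32, 16, 16)  # (batch, channel, height, width)
out = norm(feat)                   # normalized over the channel axis at each spatial location
assert out.shape == feat.shape
```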
18 changes: 18 additions & 0 deletions ding/model/common/tests/test_encoder.py
@@ -24,6 +24,20 @@ def test_conv_encoder(self):
self.output_check(model, outputs)
assert outputs.shape == (B, 128)

def test_dreamer_conv_encoder(self):
inputs = torch.randn(B, C, H, W)
model = ConvEncoder(
(C, H, W),
hidden_size_list=[32, 64, 128, 256, 128],
activation=torch.nn.SiLU(),
kernel_size=[4, 4, 4, 4],
layer_norm=True
)
print(model)
outputs = model(inputs)
self.output_check(model, outputs)
assert outputs.shape == (B, 128)

def test_fc_encoder(self):
inputs = torch.randn(B, 32)
hidden_size_list = [128 for _ in range(3)]
@@ -47,3 +61,7 @@ def test_impalaconv_encoder(self):
outputs = model(inputs)
self.output_check(model, outputs)
assert outputs.shape == (B, 256)


a = TestEncoder()
a.test_dreamer_conv_encoder()
72 changes: 72 additions & 0 deletions ding/world_model/model/networks.py
@@ -0,0 +1,72 @@
import math
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd
Collaborator: Please do not use ``torchd`` as a short module name; it is not a standard coding convention. Prefer ``from torch.distributions import XXXXX`` and import only what you need.
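A minimal sketch of the import style the reviewer suggests, with names taken from the usages below; the Bernoulli alias is an assumption to avoid clashing with ding's own Bernoulli wrapper:

```python
# Import only what is used, instead of aliasing the whole module as `torchd`.
from torch.distributions import Independent, Normal
from torch.distributions import Bernoulli as TorchBernoulli

# The normal head would then read:
#   return ContDist(Independent(Normal(mean, std), len(self._shape)))
```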


from ding.world_model.utils import weight_init, uniform_weight_init, ContDist, Bernoulli, TwoHotDistSymlog, UnnormalizedHuber
from ding.torch_utils import MLP, fc_block


class DenseHead(nn.Module):

def __init__(
self,
inp_dim, # config.dyn_stoch * config.dyn_discrete + config.dyn_deter
shape, # (255,)
layer_num,
units, # 512
act='SiLU',
norm='LN',
dist='normal',
std=1.0,
outscale=1.0,
):
super(DenseHead, self).__init__()
self._shape = (shape, ) if isinstance(shape, int) else shape
if len(self._shape) == 0:
self._shape = (1, )
self._layer_num = layer_num
self._units = units
self._act = getattr(torch.nn, act)()
self._norm = norm
self._dist = dist
self._std = std

self.mlp = MLP(
inp_dim,
self._units,
self._units,
self._layer_num,
layer_fn=nn.Linear,
activation=self._act,
norm_type=self._norm
)
self.mlp.apply(weight_init)

self.mean_layer = nn.Linear(self._units, np.prod(self._shape))
self.mean_layer.apply(uniform_weight_init(outscale))

if self._std == "learned":
self.std_layer = nn.Linear(self._units, np.prod(self._shape))
self.std_layer.apply(uniform_weight_init(outscale))

def forward(self, features, dtype=None):
x = features
out = self.mlp(x) # (batch, time, _units=512)
mean = self.mean_layer(out) # (batch, time, 255)
if self._std == "learned":
std = self.std_layer(out)

else:
std = self._std
if self._dist == "normal":
return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape)))

if self._dist == "huber":
return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape)))

if self._dist == "binary":
return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)))

if self._dist == "twohot_symlog":
return TwoHotDistSymlog(logits=mean)
raise NotImplementedError(self._dist)

23 changes: 23 additions & 0 deletions ding/world_model/model/tests/test_networks.py
@@ -0,0 +1,23 @@
import pytest
import torch
from itertools import product
from ding.world_model.model.networks import DenseHead

# arguments
shape = [255, (255, ), ()]
# TODO: also cover the remaining distributions
# dist = ['normal', 'huber', 'binary', 'twohot_symlog']
dist = ['twohot_symlog']
args = list(product(*[shape, dist]))


@pytest.mark.unittest
@pytest.mark.parametrize('shape, dist', args)
def test_DenseHead(shape, dist):
in_dim, layer_num, units, B, time = 1536, 2, 512, 16, 64
head = DenseHead(in_dim, shape, layer_num, units, dist=dist)
x = torch.randn(B, time, in_dim)
a = torch.randn(B, time, 1)
y = head(x)
assert y.mode().shape == (B, time, 1)
assert y.log_prob(a).shape == (B, time)
52 changes: 51 additions & 1 deletion ding/world_model/tests/test_world_model_utils.py
@@ -1,6 +1,9 @@
import pytest
from easydict import EasyDict
from ding.world_model.utils import get_rollout_length_scheduler
import torch
from torch import distributions as torchd
from itertools import product
from ding.world_model.utils import get_rollout_length_scheduler, SampleDist, OneHotDist, TwoHotDistSymlog, SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init


@pytest.mark.unittest
@@ -17,3 +20,50 @@ def test_get_rollout_length_scheduler():
assert scheduler(19999) == 1
assert scheduler(150000) == 25
assert scheduler(1500000) == 25


B, time = 16, 64
mean = torch.randn(B, time, 255)
std = 1.0
a = torch.randn(B, time, 1) # or torch.randn(B, time, 255)
sample_shape = torch.Size([])


@pytest.mark.unittest
def test_ContDist():
dist_origin = torchd.normal.Normal(mean, std)
dist = torchd.independent.Independent(dist_origin, 1)
dist_new = ContDist(dist)
assert dist_new.mode().shape == (B, time, 255)
assert dist_new.log_prob(a).shape == (B, time)
assert dist_origin.log_prob(a).shape == (B, time, 255)
assert dist_new.sample().shape == (B, time, 255)


@pytest.mark.unittest
def test_UnnormalizedHuber():
dist_origin = UnnormalizedHuber(mean, std)
dist = torchd.independent.Independent(dist_origin, 1)
dist_new = ContDist(dist)
assert dist_new.mode().shape == (B, time, 255)
assert dist_new.log_prob(a).shape == (B, time)
assert dist_origin.log_prob(a).shape == (B, time, 255)
assert dist_new.sample().shape == (B, time, 255)


@pytest.mark.unittest
def test_Bernoulli():
dist_origin = torchd.bernoulli.Bernoulli(logits=mean)
dist = torchd.independent.Independent(dist_origin, 1)
dist_new = Bernoulli(dist)
assert dist_new.mode().shape == (B, time, 255)
assert dist_new.log_prob(a).shape == (B, time, 255)
# TODO: also check the sample shape
# assert dist_new.sample().shape == (B, time, 255)


@pytest.mark.unittest
def test_TwoHotDistSymlog():
dist = TwoHotDistSymlog(logits=mean)
assert dist.mode().shape == (B, time, 1)
assert dist.log_prob(a).shape == (B, time)