Skip to content

Commit

Permalink
add different Dist unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangpaipai committed May 19, 2023
1 parent 0aad131 commit 1b3f02a
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 8 deletions.
10 changes: 5 additions & 5 deletions ding/world_model/model/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
@pytest.mark.unittest
@pytest.mark.parametrize('shape, dist', args)
def test_DenseHead(shape, dist):
in_dim, layer_num, units, time, B = 1536, 2, 512, 16, 64
in_dim, layer_num, units, B, time = 1536, 2, 512, 16, 64
head = DenseHead(in_dim, shape, layer_num, units, dist=dist)
x = torch.randn(time, B, in_dim)
a = torch.randn(time, B, 1)
x = torch.randn(B, time, in_dim)
a = torch.randn(B, time, 1)
y = head(x)
assert y.mode().shape == (time, B, 1)
assert y.log_prob(a).shape == (time, B)
assert y.mode().shape == (B, time, 1)
assert y.log_prob(a).shape == (B, time)
52 changes: 51 additions & 1 deletion ding/world_model/tests/test_world_model_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pytest
from easydict import EasyDict
from ding.world_model.utils import get_rollout_length_scheduler
import torch
from torch import distributions as torchd
from itertools import product
from ding.world_model.utils import get_rollout_length_scheduler, SampleDist, OneHotDist, TwoHotDistSymlog, SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init


@pytest.mark.unittest
Expand All @@ -17,3 +20,50 @@ def test_get_rollout_length_scheduler():
assert scheduler(19999) == 1
assert scheduler(150000) == 25
assert scheduler(1500000) == 25


B, time = 16, 64
mean = torch.randn(B, time, 255)
std = 1.0
a = torch.randn(B, time, 1) # or torch.randn(B, time, 255)
sample_shape = torch.Size([])


@pytest.mark.unittest
def test_ContDist():
dist_origin = torchd.normal.Normal(mean, std)
dist = torchd.independent.Independent(dist_origin, 1)
dist_new = ContDist(dist)
assert dist_new.mode().shape == (B, time, 255)
assert dist_new.log_prob(a).shape == (B, time)
assert dist_origin.log_prob(a).shape == (B, time, 255)
assert dist_new.sample().shape == (B, time, 255)


@pytest.mark.unittest
def test_UnnormalizedHuber():
dist_origin = UnnormalizedHuber(mean, std)
dist = torchd.independent.Independent(dist_origin, 1)
dist_new = ContDist(dist)
assert dist_new.mode().shape == (B, time, 255)
assert dist_new.log_prob(a).shape == (B, time)
assert dist_origin.log_prob(a).shape == (B, time, 255)
assert dist_new.sample().shape == (B, time, 255)


@pytest.mark.unittest
def test_Bernoulli():
dist_origin = torchd.bernoulli.Bernoulli(logits=mean)
dist = torchd.independent.Independent(dist_origin, 1)
dist_new = Bernoulli(dist)
assert dist_new.mode().shape == (B, time, 255)
assert dist_new.log_prob(a).shape == (B, time, 255)
# to do
# assert dist_new.sample().shape == (B, time, 255)


@pytest.mark.unittest
def test_TwoHotDistSymlog():
dist = TwoHotDistSymlog(logits=mean)
assert dist.mode().shape == (B, time, 1)
assert dist.log_prob(a).shape == (B, time)
4 changes: 2 additions & 2 deletions ding/world_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'):

def mean(self):
print("mean called")
_mode = self.probs * self.buckets
return symexp(torch.sum(_mode, dim=-1, keepdim=True))
_mean = self.probs * self.buckets
return symexp(torch.sum(_mean, dim=-1, keepdim=True))

Check warning on line 106 in ding/world_model/utils.py

View check run for this annotation

Codecov / codecov/patch

ding/world_model/utils.py#L104-L106

Added lines #L104 - L106 were not covered by tests

def mode(self):
_mode = self.probs * self.buckets
Expand Down

0 comments on commit 1b3f02a

Please sign in to comment.