fix(wyh): add tests for rl_utils ppo and td (opendilab#89)
* fix(wyh): test rl_utils code

* fix(wyh): fix rl_utils ppo adv bug for (B, A) multi-agent batches

* fix(wyh): style

* fix(wyh): fix bug
Weiyuhong-1998 authored Oct 15, 2021
1 parent b1c7300 commit 6d09f79
Showing 2 changed files with 229 additions and 3 deletions.
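The headline change is the advantage shape in multi-agent PPO: test_mappo below drives ppo_error with logits of shape (B, A, N) and an advantage of shape (B, A), one entry per (batch, agent) pair. A minimal sketch of why that shape matters, in plain PyTorch rather than ding's implementation (names and the clip threshold are illustrative assumptions):

import torch

B, A, N = 4, 8, 32  # batch, agents, action dim -- mirrors test_mappo
logit_new = torch.randn(B, A, N)
logit_old = logit_new + torch.rand_like(logit_new) * 0.1
action = torch.randint(0, N, size=(B, A))
adv = torch.rand(B, A)  # one advantage per (batch, agent) pair

dist_new = torch.distributions.Categorical(logits=logit_new)
dist_old = torch.distributions.Categorical(logits=logit_old)
# log_prob reduces only the action dim, so the ratio keeps its (B, A) shape
ratio = torch.exp(dist_new.log_prob(action) - dist_old.log_prob(action))
clip_eps = 0.2  # illustrative clip threshold, not the library default
surr1 = ratio * adv
surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
policy_loss = -torch.min(surr1, surr2).mean()  # scalar after reducing over B and A

If adv were flattened to (B, ) here, the elementwise products above would broadcast incorrectly, which appears to be the shape issue the commit message refers to.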
48 changes: 48 additions & 0 deletions ding/rl_utils/tests/test_ppo.py
@@ -42,3 +42,51 @@ def test_ppo(use_value_clip, dual_clip, weight):
total_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)
assert isinstance(value_new.grad, torch.Tensor)


@pytest.mark.unittest
def test_mappo():
B, A, N = 4, 8, 32
logit_new = torch.randn(B, A, N).requires_grad_(True)
logit_old = logit_new + torch.rand_like(logit_new) * 0.1
action = torch.randint(0, N, size=(B, A))
value_new = torch.randn(B, A).requires_grad_(True)
value_old = value_new + torch.rand_like(value_new) * 0.1
adv = torch.rand(B, A)
return_ = torch.randn(B, A) * 2
data = ppo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, None)
loss, info = ppo_error(data)
assert all([l.shape == tuple() for l in loss])
assert all([np.isscalar(i) for i in info])
assert logit_new.grad is None
assert value_new.grad is None
total_loss = sum(loss)
total_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)
assert isinstance(value_new.grad, torch.Tensor)


@pytest.mark.unittest
@pytest.mark.parametrize('use_value_clip, dual_clip, weight', args)
def test_ppo_error_continuous(use_value_clip, dual_clip, weight):
B, N = 4, 6
mu_sigma_new = [torch.randn(B, N).requires_grad_(True), torch.randn(B, N).requires_grad_(True)]
mu_sigma_old = [
mu_sigma_new[0] + torch.rand_like(mu_sigma_new[0]) * 0.1,
mu_sigma_new[1] + torch.rand_like(mu_sigma_new[1]) * 0.1
]
action = torch.rand(B, N)
value_new = torch.randn(B).requires_grad_(True)
value_old = value_new + torch.rand_like(value_new) * 0.1
adv = torch.rand(B)
return_ = torch.randn(B) * 2
data = ppo_data(mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight)
loss, info = ppo_error_continuous(data, use_value_clip=use_value_clip, dual_clip=dual_clip)
assert all([l.shape == tuple() for l in loss])
assert all([np.isscalar(i) for i in info])
assert mu_sigma_new[0].grad is None
assert value_new.grad is None
total_loss = sum(loss)
total_loss.backward()
assert isinstance(mu_sigma_new[0].grad, torch.Tensor)
assert isinstance(value_new.grad, torch.Tensor)
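
test_ppo_error_continuous above feeds ppo_error_continuous a (mu, sigma) pair per policy instead of logits. As a rough sketch of the underlying ratio computation under a diagonal Gaussian (an assumption about the general recipe, not a copy of ding's code; the exp on a log-sigma is only there to keep the scale positive):

import torch
from torch.distributions import Independent, Normal

B, N = 4, 6
mu_new, log_sigma_new = torch.randn(B, N), torch.randn(B, N)
mu_old, log_sigma_old = torch.randn(B, N), torch.randn(B, N)
action = torch.rand(B, N)

# Independent(..., 1) sums per-dimension log-probs into one value per sample
dist_new = Independent(Normal(mu_new, log_sigma_new.exp()), 1)
dist_old = Independent(Normal(mu_old, log_sigma_old.exp()), 1)
ratio = torch.exp(dist_new.log_prob(action) - dist_old.log_prob(action))  # (B, )

From there the clipped surrogate is formed exactly as in the discrete case, with adv and value both shaped (B, ).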
184 changes: 181 additions & 3 deletions ding/rl_utils/tests/test_td.py
@@ -2,7 +2,9 @@
import torch
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data,\
td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data, \
dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, q_nstep_sql_td_error, \
iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error
from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale


@pytest.mark.unittest
@@ -23,7 +25,13 @@ def test_q_nstep_td():
assert q.grad is None
loss.backward()
assert isinstance(q.grad, torch.Tensor)
data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)
value_gamma = torch.tensor(0.9)
data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma)
loss.backward()
assert isinstance(q.grad, torch.Tensor)
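
Several of the new branches above exercise the optional value_gamma argument. A rough sketch of an n-step target with such an override, assuming (an interpretation, not ding's implementation) that value_gamma replaces the default gamma ** nstep discount on the bootstrap term:

import torch

def nstep_target(reward, next_value, done, gamma=0.95, value_gamma=None):
    # reward: (nstep, B); next_value and done: (B, )
    nstep = reward.shape[0]
    factors = gamma ** torch.arange(nstep, dtype=reward.dtype)  # gamma^0 .. gamma^{n-1}
    ret = (factors.unsqueeze(1) * reward).sum(0)  # discounted reward sum per sample
    bootstrap = value_gamma if value_gamma is not None else gamma ** nstep
    return ret + bootstrap * (1 - done) * next_value  # done treated as a 0/1 mask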


@pytest.mark.unittest
@@ -84,6 +92,13 @@ def test_dist_nstep_td():
assert dist.grad is None
loss.backward()
assert isinstance(dist.grad, torch.Tensor)
weight = torch.tensor([0.9])
value_gamma = torch.tensor(0.9)
data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
assert loss.shape == ()
loss.backward()
assert isinstance(dist.grad, torch.Tensor)


@pytest.mark.unittest
@@ -106,6 +121,29 @@ def test_q_nstep_td_with_rescale():
print(loss)


@pytest.mark.unittest
def test_qrdqn_nstep_td():
batch_size = 4
action_dim = 3
tau = 3
next_q = torch.randn(batch_size, action_dim, tau)
done = torch.randn(batch_size)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
for nstep in range(1, 10):
q = torch.randn(batch_size, action_dim, tau).requires_grad_(True)
reward = torch.rand(nstep, batch_size)
data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, tau, None)
loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
assert td_error_per_sample.shape == (batch_size, )
assert loss.shape == ()
assert q.grad is None
loss.backward()
assert isinstance(q.grad, torch.Tensor)
loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
assert td_error_per_sample.shape == (batch_size, )
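
test_qrdqn_nstep_td checks qrdqn_nstep_td_error with q of shape (batch_size, action_dim, tau). For reference, a compact sketch of the standard quantile-regression Huber loss this family of errors builds on (the general QR-DQN recipe as an assumption, not ding's exact code):

import torch
import torch.nn.functional as F

def quantile_huber_loss(q, target_q):
    # q, target_q: (B, tau) quantile values for the chosen action
    n_tau = q.shape[1]
    tau_hat = (torch.arange(n_tau, dtype=q.dtype) + 0.5) / n_tau  # quantile midpoints
    u = target_q.unsqueeze(1) - q.unsqueeze(2)  # (B, tau, tau') pairwise TD errors
    huber = F.smooth_l1_loss(
        q.unsqueeze(2).expand_as(u), target_q.unsqueeze(1).expand_as(u), reduction='none'
    )
    weight = (tau_hat.view(1, -1, 1) - (u.detach() < 0).float()).abs()  # |tau_hat - 1{u<0}|
    return (weight * huber).sum(1).mean()

The iqn_nstep_td_error case below follows the same pattern, except the quantile fractions are sampled and replayed (the replay_quantile tensor) instead of being fixed midpoints.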


@pytest.mark.unittest
def test_dist_1step_compatible():
batch_size = 4
@@ -153,6 +191,8 @@ def test_v_1step_td():
assert isinstance(v.grad, torch.Tensor)
data = v_1step_td_data(v, next_v, reward, None, None)
loss, td_error_per_sample = v_1step_td_error(data, 0.99)
loss.backward()
assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
@@ -168,5 +208,143 @@ def test_v_nstep_td():
assert v.grad is None
loss.backward()
assert isinstance(v.grad, torch.Tensor)
data = v_nstep_td_data(v, next_v, reward, done, None, 0.99)
loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
loss.backward()
assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
def test_q_nstep_sql_td():
batch_size = 4
action_dim = 3
next_q = torch.randn(batch_size, action_dim)
done = torch.randn(batch_size)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
for nstep in range(1, 10):
q = torch.randn(batch_size, action_dim).requires_grad_(True)
reward = torch.rand(nstep, batch_size)
data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
assert td_error_per_sample.shape == (batch_size, )
assert loss.shape == ()
assert q.grad is None
loss.backward()
assert isinstance(q.grad, torch.Tensor)
data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 0.5, nstep=nstep, cum_reward=True)
value_gamma = torch.tensor(0.9)
data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
data, 0.95, 0.5, nstep=nstep, cum_reward=True, value_gamma=value_gamma
)
loss.backward()
assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_iqn_nstep_td():
batch_size = 4
action_dim = 3
tau = 3
next_q = torch.randn(tau, batch_size, action_dim)
done = torch.randn(batch_size)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
for nstep in range(1, 10):
q = torch.randn(tau, batch_size, action_dim).requires_grad_(True)
replay_quantile = torch.randn([tau, batch_size, 1])
reward = torch.rand(nstep, batch_size)
data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
assert td_error_per_sample.shape == (batch_size, )
assert loss.shape == ()
assert q.grad is None
loss.backward()
assert isinstance(q.grad, torch.Tensor)
loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
assert td_error_per_sample.shape == (batch_size, )


@pytest.mark.unittest
def test_shape_fn_qntd():
batch_size = 4
action_dim = 3
next_q = torch.randn(batch_size, action_dim)
done = torch.randn(batch_size)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
for nstep in range(1, 10):
q = torch.randn(batch_size, action_dim).requires_grad_(True)
reward = torch.rand(nstep, batch_size)
data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
tmp = shape_fn_qntd([data, 0.95, 1], {})
assert tmp[0] == reward.shape[0]
assert tmp[1] == q.shape[0]
assert tmp[2] == q.shape[1]
tmp = shape_fn_qntd([], {'gamma': 0.95, 'nstep': 1, 'data': data})
assert tmp[0] == reward.shape[0]
assert tmp[1] == q.shape[0]
assert tmp[2] == q.shape[1]


@pytest.mark.unittest
def test_shape_fn_dntd():
batch_size = 4
action_dim = 3
n_atom = 51
v_min = -10.0
v_max = 10.0
nstep = 5
dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
done = torch.randn(batch_size)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
reward = torch.randn(nstep, batch_size)
data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
tmp = shape_fn_dntd([data, 0.9, v_min, v_max, n_atom, nstep], {})
assert tmp[0] == reward.shape[0]
assert tmp[1] == dist.shape[0]
assert tmp[2] == dist.shape[1]
assert tmp[3] == n_atom
tmp = shape_fn_dntd([], {'data': data, 'gamma': 0.9, 'v_min': v_min, 'v_max': v_max, 'n_atom': n_atom, 'nstep': 5})
assert tmp[0] == reward.shape[0]
assert tmp[1] == dist.shape[0]
assert tmp[2] == dist.shape[1]
assert tmp[3] == n_atom


@pytest.mark.unittest
def test_shape_fn_qntd_rescale():
batch_size = 4
action_dim = 3
next_q = torch.randn(batch_size, action_dim)
done = torch.randn(batch_size)
action = torch.randint(0, action_dim, size=(batch_size, ))
next_action = torch.randint(0, action_dim, size=(batch_size, ))
for nstep in range(1, 10):
q = torch.randn(batch_size, action_dim).requires_grad_(True)
reward = torch.rand(nstep, batch_size)
data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
tmp = shape_fn_qntd_rescale([data, 0.95, 1], {})
assert tmp[0] == reward.shape[0]
assert tmp[1] == q.shape[0]
assert tmp[2] == q.shape[1]
tmp = shape_fn_qntd_rescale([], {'gamma': 0.95, 'nstep': 1, 'data': data})
assert tmp[0] == reward.shape[0]
assert tmp[1] == q.shape[0]
assert tmp[2] == q.shape[1]


@pytest.mark.unittest
def test_fn_td_lambda():
T, B = 8, 4
value = torch.randn(T + 1, B).requires_grad_(True)
reward = torch.rand(T, B)
data = td_lambda_data(value, reward, None)
tmp = shape_fn_td_lambda([], {'data': data})
assert tmp == reward.shape[0]
tmp = shape_fn_td_lambda([data], {})
    assert tmp == reward.shape[0]
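
All of the new cases carry the repository's unittest marker, so, assuming a local DI-engine checkout with pytest installed, one plausible way to run just these two files is:

pytest -sv ding/rl_utils/tests/test_ppo.py ding/rl_utils/tests/test_td.py -m unittest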
