feature(rjy): add HAPPO algorithm #717

Merged
merged 28 commits on Jan 11, 2024

Changes from 20 commits

Commits (28)
af96dc8
model(rjy): add vac model for HAPPO
Aug 30, 2023
a596c27
test(rjy): polish havac and add test
Sep 15, 2023
d282845
polish(rjy): fix conflict
Sep 15, 2023
43bb9e8
polish(rjy): add hidden_state for ac
Sep 21, 2023
b04ea13
feature(rjy): change the havac to multiagent model
Oct 10, 2023
f5648d0
feature(rjy): add happo forward_learn
Oct 10, 2023
42f4027
Merge branch 'main' into rjy-happo-model
Oct 11, 2023
42faae6
feature(rjy): modify the happo_data
Oct 20, 2023
3319a55
test(rjy): add happo data test
Oct 20, 2023
e3fdb80
feature(rjy): add HAPPO policy
Oct 26, 2023
8d4791d
feature(rjy): try to fit mu-mujoco
Oct 30, 2023
850f831
polish(rjy): Change code to adapt to mujoco
Oct 31, 2023
8e281dc
fix(rjy): fix the distribution in ppo update
Oct 31, 2023
f828553
fix(rjy): fix the happo+mujoco
Nov 3, 2023
70da407
config(rjy): add walker+happo config
Nov 9, 2023
23d1ddb
polish(rjy): separate actors and critics
Dec 27, 2023
ca3daff
polish(rjy): polish according to comments
Dec 27, 2023
e7277b8
polish(rjy): fix the pipeline
Dec 28, 2023
910a8f4
Merge branch 'main' into rjy-happo-model
Dec 29, 2023
b03390b
polish(rjy): fix the style
Dec 29, 2023
d5ace8e
polish(rjy): polish according to comments
Dec 29, 2023
78bffa7
polish(rjy): fix style
Dec 29, 2023
84028d8
polish(rjy): fix style
Dec 29, 2023
b6e7239
polish(rjy): fix style
Dec 29, 2023
8fc9517
polish(rjy): seperate the happo model
Jan 5, 2024
48dcd94
fix(rjy): fix happo model style
Jan 5, 2024
a1bf76f
polish(rjy): polish happo policy comments
Jan 10, 2024
e7e9662
polish(rjy): polish happo comments
Jan 11, 2024
10 changes: 9 additions & 1 deletion ding/model/common/head.py
@@ -1103,7 +1103,7 @@ class ReparameterizationHead(nn.Module):
        ``__init__``, ``forward``.
    """

    default_sigma_type = ['fixed', 'independent', 'conditioned']
    default_sigma_type = ['fixed', 'independent', 'conditioned', 'happo']
    default_bound_type = ['tanh', None]

    def __init__(
@@ -1155,6 +1155,11 @@ def __init__(
            self.log_sigma_param = nn.Parameter(torch.zeros(1, output_size))
        elif self.sigma_type == 'conditioned':
            self.log_sigma_layer = nn.Linear(hidden_size, output_size)
        elif self.sigma_type == 'happo':
            self.sigma_x_coef = 1.
            self.sigma_y_coef = 0.5
            # The coefficients (sigma_x_coef, sigma_y_coef) follow the sigma parameterization used in the HAPPO paper.
            self.log_sigma_param = nn.Parameter(torch.ones(1, output_size) * self.sigma_x_coef)

    def forward(self, x: torch.Tensor) -> Dict:
        """
@@ -1190,6 +1195,9 @@ def forward(self, x: torch.Tensor) -> Dict:
        elif self.sigma_type == 'conditioned':
            log_sigma = self.log_sigma_layer(x)
            sigma = torch.exp(torch.clamp(log_sigma, -20, 2))
        elif self.sigma_type == 'happo':
            log_sigma = self.log_sigma_param + torch.zeros_like(mu)
            sigma = torch.sigmoid(log_sigma / self.sigma_x_coef) * self.sigma_y_coef
        return {'mu': mu, 'sigma': sigma}


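Note on the new 'happo' sigma type: instead of conditioning sigma on the observation (as 'conditioned' does), it keeps a single learnable log-sigma parameter per action dimension and squashes it with a sigmoid, so the standard deviation stays bounded in (0, sigma_y_coef). A minimal standalone sketch of that computation, assuming the default coefficients from the diff above; output_size and the batch of action means mu are placeholders for illustration only:

import torch
import torch.nn as nn

sigma_x_coef, sigma_y_coef = 1., 0.5  # defaults from the diff above
output_size = 6  # hypothetical action dimension
# learnable per-dimension parameter, initialized to sigma_x_coef as in __init__ above
log_sigma_param = nn.Parameter(torch.ones(1, output_size) * sigma_x_coef)

mu = torch.randn(4, output_size)  # placeholder action means for a batch of 4
# broadcast the parameter to mu's shape, then squash so sigma lies in (0, sigma_y_coef)
log_sigma = log_sigma_param + torch.zeros_like(mu)
sigma = torch.sigmoid(log_sigma / sigma_x_coef) * sigma_y_coef
print(sigma.shape)  # torch.Size([4, 6]); every entry is about 0.37 at initialization

At initialization every sigma entry is roughly 0.5 * sigmoid(1.0), about 0.37, and it can only move within (0, 0.5) as log_sigma_param is learned, giving bounded, state-independent exploration noise.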
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -27,3 +27,4 @@
from .bcq import BCQ
from .edac import EDAC
from .ebm import EBM, AutoregressiveEBM
from .havac import HAVAC
500 changes: 500 additions & 0 deletions ding/model/template/havac.py

Large diffs are not rendered by default.

102 changes: 102 additions & 0 deletions ding/model/template/tests/test_havac.py
@@ -0,0 +1,102 @@
import pytest
import torch
import random
from ding.torch_utils import is_differentiable
from ding.model.template import HAVAC


@pytest.mark.unittest
class TestHAVAC:

    def test_havac_rnn_actor(self):
        # discrete+rnn
        bs, T = 3, 8
        obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
        agent_num = 5
        data = {
            'obs': {
                'agent_state': torch.randn(T, bs, obs_dim),
                'global_state': torch.randn(T, bs, global_obs_dim),
                'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
            },
            'actor_prev_state': [None for _ in range(bs)],
        }
        model = HAVAC(
            agent_obs_shape=obs_dim,
            global_obs_shape=global_obs_dim,
            action_shape=action_dim,
            agent_num=agent_num,
            use_lstm=True,
        )
        agent_idx = random.randint(0, agent_num - 1)
        output = model(agent_idx, data, mode='compute_actor')
        assert set(output.keys()) == set(['logit', 'actor_next_state', 'actor_hidden_state'])
        assert output['logit'].shape == (T, bs, action_dim)
        assert len(output['actor_next_state']) == bs
        print(output['actor_next_state'][0]['h'].shape)
        loss = output['logit'].sum()
        is_differentiable(loss, model.agent_models[agent_idx].actor)

    def test_havac_rnn_critic(self):
        # discrete+rnn
        bs, T = 3, 8
        obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
        agent_num = 5
        data = {
            'obs': {
                'agent_state': torch.randn(T, bs, obs_dim),
                'global_state': torch.randn(T, bs, global_obs_dim),
                'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
            },
            'critic_prev_state': [None for _ in range(bs)],
        }
        model = HAVAC(
            agent_obs_shape=obs_dim,
            global_obs_shape=global_obs_dim,
            action_shape=action_dim,
            agent_num=agent_num,
            use_lstm=True,
        )
        agent_idx = random.randint(0, agent_num - 1)
        output = model(agent_idx, data, mode='compute_critic')
        assert set(output.keys()) == set(['value', 'critic_next_state', 'critic_hidden_state'])
        assert output['value'].shape == (T, bs)
        assert len(output['critic_next_state']) == bs
        print(output['critic_next_state'][0]['h'].shape)
        loss = output['value'].sum()
        is_differentiable(loss, model.agent_models[agent_idx].critic)

    def test_havac_rnn_actor_critic(self):
        # discrete+rnn
        bs, T = 3, 8
        obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
        agent_num = 5
        data = {
            'obs': {
                'agent_state': torch.randn(T, bs, obs_dim),
                'global_state': torch.randn(T, bs, global_obs_dim),
                'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
            },
            'actor_prev_state': [None for _ in range(bs)],
            'critic_prev_state': [None for _ in range(bs)],
        }
        model = HAVAC(
            agent_obs_shape=obs_dim,
            global_obs_shape=global_obs_dim,
            action_shape=action_dim,
            agent_num=agent_num,
            use_lstm=True,
        )
        agent_idx = random.randint(0, agent_num - 1)
        output = model(agent_idx, data, mode='compute_actor_critic')
        assert set(output.keys()) == set(
            ['logit', 'actor_next_state', 'actor_hidden_state', 'value', 'critic_next_state', 'critic_hidden_state']
        )
        assert output['logit'].shape == (T, bs, action_dim)
        assert output['value'].shape == (T, bs)
        loss = output['logit'].sum() + output['value'].sum()
        is_differentiable(loss, model.agent_models[agent_idx])


# test_havac_rnn_actor()
# test_havac_rnn_critic()
# test_havac_rnn_actor_critic()
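Since the tests are tagged with @pytest.mark.unittest, they should be runnable on their own with something like: pytest -sv ding/model/template/tests/test_havac.py -m unittest (assuming the usual DI-engine pytest marker configuration).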
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -55,3 +55,4 @@
# new-type policy
from .ppof import PPOFPolicy
from .prompt_pg import PromptPGPolicy
from .happo import HAPPOPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -51,6 +51,7 @@
from .edac import EDACPolicy
from .prompt_pg import PromptPGPolicy
from .plan_diffuser import PDPolicy
from .happo import HAPPOPolicy


class EpsCommandModePolicy(CommandModePolicy):
@@ -186,6 +187,11 @@ class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('happo_command')
class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
    pass


@POLICY_REGISTRY.register('ppo_stdim_command')
class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
    pass
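The 'happo_command' registration above is what lets a config select the new policy by name: DI-engine's serial entry typically resolves the command-mode variant by appending '_command' to the configured policy type, which is why this PR adds the registration. A hypothetical minimal create_config fragment illustrating this; the field values are placeholders, not the walker+happo config added elsewhere in this PR:

create_config = dict(
    env_manager=dict(type='base'),  # placeholder env manager choice
    policy=dict(type='happo'),  # resolved to HAPPOCommandModePolicy via POLICY_REGISTRY
)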