
feature(wyh): madqn algorithm #540

Merged: 16 commits, Nov 15, 2022
15 changes: 15 additions & 0 deletions ding/entry/tests/test_serial_entry.py
@@ -37,6 +37,7 @@
from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_madqn_config, ptz_simple_spread_madqn_create_config # noqa
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main
from dizoo.league_demo.league_demo_ppo_main import main as league_main
@@ -378,6 +379,20 @@ def test_wqmix():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_madqn():
config = [deepcopy(ptz_simple_spread_madqn_config), deepcopy(ptz_simple_spread_madqn_create_config)]
config[0].policy.cuda = False
config[0].policy.learn.update_per_collect = 1
try:
serial_pipeline(config, seed=0, max_train_iter=1)
except Exception:
assert False, "pipeline fail"
finally:
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_qtran():
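The test_madqn smoke test above caps training at a single iteration; dropping the cap yields a full run through the same entry point. A sketch using the config pair imported at the top of this file:

from copy import deepcopy

from ding.entry import serial_pipeline
from dizoo.petting_zoo.config import ptz_simple_spread_madqn_config, ptz_simple_spread_madqn_create_config

# Sketch: a full MADQN training run on simple_spread. The smoke test above
# passes max_train_iter=1; a real run omits it and trains until the config's
# stop condition is reached.
config = [deepcopy(ptz_simple_spread_madqn_config), deepcopy(ptz_simple_spread_madqn_create_config)]
serial_pipeline(config, seed=0)
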
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -18,5 +18,6 @@
from .ngu import NGU
from .qac_dist import QACDIST
from .maqac import MAQAC, ContinuousMAQAC
from .madqn import MADQN
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
54 changes: 54 additions & 0 deletions ding/model/template/madqn.py
@@ -0,0 +1,54 @@
import torch.nn as nn
from ding.utils import MODEL_REGISTRY
from .qmix import QMix


@MODEL_REGISTRY.register('madqn')
class MADQN(nn.Module):

def __init__(
self,
agent_num: int,
obs_shape: int,
action_shape: int,
hidden_size_list: list,
global_obs_shape: int = None,
mixer: bool = False,
global_cooperation: bool = True,
lstm_type: str = 'gru',
dueling: bool = False
) -> None:
super(MADQN, self).__init__()
self.current = QMix(
agent_num=agent_num,
obs_shape=obs_shape,
action_shape=action_shape,
hidden_size_list=hidden_size_list,
global_obs_shape=global_obs_shape,
mixer=mixer,
lstm_type=lstm_type,
dueling=dueling
)
self.global_cooperation = global_cooperation
if self.global_cooperation:
cooperation_obs_shape = global_obs_shape
else:
cooperation_obs_shape = obs_shape
self.cooperation = QMix(
agent_num=agent_num,
obs_shape=cooperation_obs_shape,
action_shape=action_shape,
hidden_size_list=hidden_size_list,
global_obs_shape=global_obs_shape,
mixer=mixer,
lstm_type=lstm_type,
dueling=dueling
)

def forward(self, data: dict, cooperation: bool = False, single_step: bool = True) -> dict:
if cooperation:
if self.global_cooperation:
data['obs']['agent_state'] = data['obs']['global_state']
return self.cooperation(data, single_step=single_step)
else:
return self.current(data, single_step=single_step)
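
Note the routing above: the cooperation head's input width follows global_cooperation. With the default True it is built on global_obs_shape, and forward substitutes the global state for the per-agent observation in place; with False both heads consume the local obs_shape. A small sketch, with illustrative shapes, making the two QMix heads explicit:

from ding.model.template import MADQN

# Sketch (illustrative shapes): with global_cooperation=False the cooperation
# head is built on the local obs_shape rather than global_obs_shape, and
# mixer=True enables the QMix mixing network in both heads.
model = MADQN(
    agent_num=4,
    obs_shape=32,
    action_shape=9,
    hidden_size_list=[64, 64],
    global_obs_shape=128,
    mixer=True,
    global_cooperation=False,
)
print(model.current)      # QMix head used when cooperation=False
print(model.cooperation)  # QMix head used when cooperation=True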
30 changes: 30 additions & 0 deletions ding/model/template/tests/test_madqn.py
@@ -0,0 +1,30 @@
import pytest
import torch
from ding.torch_utils import is_differentiable
from ding.model.template import MADQN


@pytest.mark.unittest
def test_madqn():
agent_num, bs, T = 4, 3, 8
obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
embedding_dim = 64
madqn_model = MADQN(
agent_num=agent_num,
obs_shape=obs_dim,
action_shape=action_dim,
hidden_size_list=[embedding_dim, embedding_dim],
global_obs_shape=global_obs_dim
)
data = {
'obs': {
'agent_state': torch.randn(T, bs, agent_num, obs_dim),
'global_state': torch.randn(T, bs, agent_num, global_obs_dim),
'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim))
},
'prev_state': [[None for _ in range(agent_num)] for _ in range(bs)],
'action': torch.randint(0, action_dim, size=(T, bs, agent_num))
}
output = madqn_model(data, cooperation=True, single_step=False)
assert output['total_q'].shape == (T, bs)
assert len(output['next_state']) == bs and all([len(n) == agent_num for n in output['next_state']])
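
The unit test exercises only the cooperation=True path. A minimal sketch of the other branch under the same shape assumptions; note that a cooperation=True call overwrites data['obs']['agent_state'] in place, so each head should get a fresh data dict:

import torch
from ding.model.template import MADQN

# Sketch: cooperation=False routes through the `current` QMix head, which
# consumes the local agent_state (same illustrative shapes as the test above).
agent_num, bs, T = 4, 3, 8
obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
model = MADQN(
    agent_num=agent_num,
    obs_shape=obs_dim,
    action_shape=action_dim,
    hidden_size_list=[64, 64],
    global_obs_shape=global_obs_dim,
)
data = {
    'obs': {
        'agent_state': torch.randn(T, bs, agent_num, obs_dim),
        'global_state': torch.randn(T, bs, agent_num, global_obs_dim),
        'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim)),
    },
    'prev_state': [[None for _ in range(agent_num)] for _ in range(bs)],
    'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
}
output = model(data, cooperation=False, single_step=False)
assert output['total_q'].shape == (T, bs)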
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
@@ -43,6 +43,7 @@
from .decision_transformer import DTPolicy
from .pdqn import PDQNPolicy
from .sac import SQILSACPolicy
from .madqn import MADQNPolicy


class EpsCommandModePolicy(CommandModePolicy):
@@ -208,6 +209,11 @@ class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('madqn_command')
class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('ddpg_command')
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):

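Registering the command-mode variant as 'madqn_command' lets a serial-pipeline config select the policy through the usual type field. A hedged sketch of the create-config fragment, assuming DI-engine's standard convention; the actual ptz_simple_spread_madqn_create_config shipped with this PR is authoritative:

from easydict import EasyDict

# Sketch of a create-config fragment (assumed layout, not copied from the PR);
# serial_pipeline resolves type='madqn' to MADQNPolicy and, in command mode,
# to the 'madqn_command' registration above.
create_config = EasyDict(dict(
    env=dict(
        type='petting_zoo',
        import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='madqn'),
))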