Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(wyh): madqn algorithm #540

Merged
merged 16 commits into from
Nov 15, 2022
54 changes: 54 additions & 0 deletions ding/model/template/madqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch.nn as nn
from ding.utils import MODEL_REGISTRY
from .qmix import QMix


@MODEL_REGISTRY.register('madqn')
Weiyuhong-1998 marked this conversation as resolved.
Show resolved Hide resolved
class MADQN(nn.Module):

def __init__(
self,
agent_num: int,
obs_shape: int,
action_shape: int,
hidden_size_list: list,
global_obs_shape: int = None,
mixer: bool = False,
global_boost: bool = True,
lstm_type: str = 'gru',
dueling: bool = False
) -> None:
super(MADQN, self).__init__()
self.current = QMix(
agent_num=agent_num,
obs_shape=obs_shape,
action_shape=action_shape,
hidden_size_list=hidden_size_list,
global_obs_shape=global_obs_shape,
mixer=mixer,
lstm_type=lstm_type,
dueling=dueling
)
self.global_boost = global_boost
if self.global_boost:
boost_obs_shape = global_obs_shape
else:
boost_obs_shape = obs_shape
self.boost = QMix(
agent_num=agent_num,
obs_shape=boost_obs_shape,
action_shape=action_shape,
hidden_size_list=hidden_size_list,
global_obs_shape=global_obs_shape,
mixer=mixer,
lstm_type=lstm_type,
dueling=dueling
)

def forward(self, data: dict, boost: bool = False, single_step: bool = True) -> dict:
if boost:
if self.global_boost:
data['obs']['agent_state'] = data['obs']['global_state']
return self.boost(data, single_step=single_step)
Weiyuhong-1998 marked this conversation as resolved.
Show resolved Hide resolved
else:
return self.current(data, single_step=single_step)
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .decision_transformer import DTPolicy
from .pdqn import PDQNPolicy
from .sac import SQILSACPolicy
from .madqn import MADQNPolicy


class EpsCommandModePolicy(CommandModePolicy):
Expand Down Expand Up @@ -208,6 +209,11 @@ class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
pass


@POLICY_REGISTRY.register('madqn_command')
class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
pass


@POLICY_REGISTRY.register('ddpg_command')
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):

Expand Down
Loading