
Refine explore strategy, add prioritized sampling support; add DDQN example; add DQN test #590

Merged (9 commits) on May 23, 2023
Changes from 8 commits
examples/cim/rl/algorithms/dqn.py (72 changes: 52 additions & 20 deletions)

@@ -1,10 +1,11 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from typing import Optional, Tuple

 import torch
 from torch.optim import RMSprop

-from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
+from maro.rl.exploration import EpsilonGreedy
 from maro.rl.model import DiscreteQNet, FullyConnected
 from maro.rl.policy import ValueBasedPolicy
 from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -23,32 +24,62 @@


 class MyQNet(DiscreteQNet):
-    def __init__(self, state_dim: int, action_num: int) -> None:
+    def __init__(
+        self,
+        state_dim: int,
+        action_num: int,
+        dueling_param: Optional[Tuple[dict, dict]] = None,
+    ) -> None:
         super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
-        self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
-        self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)
+
+        self._use_dueling = dueling_param is not None
+        self._fc = FullyConnected(input_dim=state_dim, output_dim=0 if self._use_dueling else action_num, **q_net_conf)
+        if self._use_dueling:
+            q_kwargs, v_kwargs = dueling_param
+            self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs)
+            self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs)
+
+        self._optim = RMSprop(self.parameters(), lr=learning_rate)

     def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
-        return self._fc(states)
+        logits = self._fc(states)
+        if self._use_dueling:
+            q = self._q(logits)
+            v = self._v(logits)
+            logits = q - q.mean(dim=1, keepdim=True) + v
+        return logits


 def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
+    q_kwargs = {
+        "hidden_dims": [128],
+        "activation": torch.nn.LeakyReLU,
+        "output_activation": torch.nn.LeakyReLU,
+        "softmax": False,
+        "batch_norm": True,
+        "skip_connection": False,
+        "head": True,
+        "dropout_p": 0.0,
+    }
+    v_kwargs = {
+        "hidden_dims": [128],
+        "activation": torch.nn.LeakyReLU,
+        "output_activation": None,
+        "softmax": False,
+        "batch_norm": True,
+        "skip_connection": False,
+        "head": True,
+        "dropout_p": 0.0,
+    }
+
     return ValueBasedPolicy(
         name=name,
-        q_net=MyQNet(state_dim, action_num),
-        exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
-        exploration_scheduling_options=[
-            (
-                "epsilon",
-                MultiLinearExplorationScheduler,
-                {
-                    "splits": [(2, 0.32)],
-                    "initial_value": 0.4,
-                    "last_ep": 5,
-                    "final_value": 0.0,
-                },
-            ),
-        ],
+        q_net=MyQNet(
+            state_dim,
+            action_num,
+            dueling_param=(q_kwargs, v_kwargs),
+        ),
+        explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
         warmup=100,
     )
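Note on the dueling head above: _get_q_values_for_all_actions now combines a per-action advantage stream and a scalar value stream as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). The snippet below is a minimal standalone sketch of that aggregation on plain tensors; the toy shapes are assumptions for illustration only, not part of this PR.

import torch

batch_size, action_num = 4, 3
advantage = torch.randn(batch_size, action_num)  # per-action stream (output of self._q above)
value = torch.randn(batch_size, 1)               # state-value stream (output of self._v above)

# Q(s, a) = V(s) + A(s, a) - mean_a A(s, a); subtracting the per-state mean keeps the
# value/advantage decomposition identifiable, matching the forward pass above.
q_values = advantage - advantage.mean(dim=1, keepdim=True) + value
assert q_values.shape == (batch_size, action_num)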

@@ -64,6 +95,7 @@ def get_dqn(name: str) -> DQNTrainer:
             num_epochs=10,
             soft_update_coef=0.1,
             double=False,
-            random_overwrite=False,
+            alpha=1.0,
+            beta=1.0,
         ),
     )
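Note on the new alpha and beta fields in DQNParams: these are the usual prioritized-experience-replay exponents, where alpha controls how strongly TD-error priorities skew sampling and beta controls the importance-sampling correction applied to the sampled transitions. MARO's replay-memory code is not part of this diff, so the sketch below is a generic illustration of that weighting scheme under those assumptions, not the library's implementation.

import numpy as np

def per_sample(td_errors, batch_size, alpha=1.0, beta=1.0, eps=1e-6):
    # Sampling probabilities proportional to |TD error| ** alpha.
    priorities = (np.abs(td_errors) + eps) ** alpha
    probs = priorities / priorities.sum()
    indices = np.random.choice(len(td_errors), size=batch_size, p=probs)
    # Importance-sampling weights (N * P(i)) ** (-beta), normalized for stable updates.
    weights = (len(td_errors) * probs[indices]) ** (-beta)
    return indices, weights / weights.max()

idx, w = per_sample(np.random.randn(1000), batch_size=32, alpha=1.0, beta=1.0)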
examples/cim/rl/config.py (2 changes: 1 addition & 1 deletion)

@@ -35,4 +35,4 @@

 action_num = len(action_shaping_conf["action_space"])

-algorithm = "ppo"  # ac, ppo, dqn or discrete_maddpg
+algorithm = "dqn"  # ac, ppo, dqn or discrete_maddpg
examples/vm_scheduling/rl/algorithms/dqn.py (16 changes: 2 additions & 14 deletions)

@@ -6,7 +6,7 @@
 from torch.optim import SGD
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

-from maro.rl.exploration import MultiLinearExplorationScheduler
+from maro.rl.exploration import EpsilonGreedy
 from maro.rl.model import DiscreteQNet, FullyConnected
 from maro.rl.policy import ValueBasedPolicy
 from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
     return ValueBasedPolicy(
         name=name,
         q_net=MyQNet(state_dim, action_num, num_features),
-        exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
-        exploration_scheduling_options=[
-            (
-                "epsilon",
-                MultiLinearExplorationScheduler,
-                {
-                    "splits": [(100, 0.32)],
-                    "initial_value": 0.4,
-                    "last_ep": 400,
-                    "final_value": 0.0,
-                },
-            ),
-        ],
+        explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
         warmup=100,
     )
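Note that the plain EpsilonGreedy strategy used here does not reproduce the action masking that the removed MaskedEpsGreedy, judging by its name, provided. For reference, epsilon-greedy selection itself reduces to the sketch below; the function name and signature are illustrative assumptions, not MARO's API.

import numpy as np

def epsilon_greedy_action(q_values, epsilon, num_actions):
    # With probability epsilon, explore with a uniformly random action;
    # otherwise exploit the action with the highest estimated Q-value.
    if np.random.rand() < epsilon:
        return int(np.random.randint(num_actions))
    return int(np.argmax(q_values))

action = epsilon_greedy_action(np.array([0.1, 0.5, 0.2]), epsilon=0.4, num_actions=3)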

maro/rl/exploration/__init__.py (12 changes: 4 additions & 8 deletions)

@@ -1,14 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler
-from .strategies import epsilon_greedy, gaussian_noise, uniform_noise
+from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration

 __all__ = [
-    "AbsExplorationScheduler",
-    "LinearExplorationScheduler",
-    "MultiLinearExplorationScheduler",
-    "epsilon_greedy",
-    "gaussian_noise",
-    "uniform_noise",
+    "ExploreStrategy",
+    "EpsilonGreedy",
+    "LinearExploration",
 ]
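The scheduler classes removed from these exports (and deleted wholesale in scheduling.py below) used to anneal epsilon across episodes; the newly exported LinearExploration strategy appears to take over that role. A rough sketch of linear annealing follows; the parameter names are assumptions, not LinearExploration's actual constructor.

def linear_epsilon(step, start=0.4, end=0.0, total_steps=10_000):
    # Linearly interpolate epsilon from `start` to `end` over `total_steps` calls.
    frac = min(step / total_steps, 1.0)
    return start + frac * (end - start)

assert abs(linear_epsilon(0) - 0.4) < 1e-9
assert abs(linear_epsilon(10_000) - 0.0) < 1e-9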
maro/rl/exploration/scheduling.py (127 changes: 0 additions & 127 deletions)

This file was deleted.
