Rl v3 example update #461

Merged (8 commits) on Jan 25, 2022
47 changes: 28 additions & 19 deletions examples/rl/cim/algorithms/ac.py
@@ -6,9 +6,9 @@
import torch
from torch.optim import Adam, RMSprop

from maro.rl_v3.model import DiscretePolicyNet, FullyConnected, VNet
from maro.rl_v3.policy import DiscretePolicyGradient
from maro.rl_v3.training.algorithms import DiscreteActorCritic, DiscreteActorCriticParams
from maro.rl.model import DiscretePolicyNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteActorCritic, DiscreteActorCriticParams


actor_net_conf = {
@@ -26,15 +26,15 @@
"batch_norm": True,
"head": True
}
actor_optim_conf = (Adam, {"lr": 0.001})
critic_optim_conf = (RMSprop, {"lr": 0.001})
actor_learning_rate = 0.001
critic_learning_rate = 0.001


class MyActorNet(DiscretePolicyNet):
def __init__(self, state_dim: int, action_num: int) -> None:
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
self._actor_optim = actor_optim_conf[0](self._actor.parameters(), **actor_optim_conf[1])
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)

def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
return self._actor(states)
@@ -45,55 +45,65 @@ def freeze(self) -> None:
def unfreeze(self) -> None:
self.unfreeze_all_parameters()

def step(self, loss: torch.Tensor) -> None:
self._optim.zero_grad()
loss.backward()
self._optim.step()

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
self._actor_optim.zero_grad()
self._optim.zero_grad()
loss.backward()
return {name: param.grad for name, param in self.named_parameters()}

def apply_gradients(self, grad: dict) -> None:
for name, param in self.named_parameters():
param.grad = grad[name]
self._actor_optim.step()
self._optim.step()

def get_net_state(self) -> dict:
return {
"network": self.state_dict(),
"actor_optim": self._actor_optim.state_dict()
"optim": self._optim.state_dict()
}

def set_net_state(self, net_state: dict) -> None:
self.load_state_dict(net_state["network"])
self._actor_optim.load_state_dict(net_state["actor_optim"])
self._optim.load_state_dict(net_state["optim"])


class MyCriticNet(VNet):
def __init__(self, state_dim: int) -> None:
super(MyCriticNet, self).__init__(state_dim=state_dim)
self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf)
self._critic_optim = critic_optim_conf[0](self._critic.parameters(), **critic_optim_conf[1])
self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate)

def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
return self._critic(states).squeeze(-1)

def step(self, loss: torch.Tensor) -> None:
self._optim.zero_grad()
loss.backward()
self._optim.step()

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
self._critic_optim.zero_grad()
self._optim.zero_grad()
loss.backward()
return {name: param.grad for name, param in self.named_parameters()}

def apply_gradients(self, grad: dict) -> None:
for name, param in self.named_parameters():
param.grad = grad[name]
self._critic_optim.step()
self._optim.step()

def get_net_state(self) -> dict:
return {
"network": self.state_dict(),
"critic_optim": self._critic_optim.state_dict()
"optim": self._optim.state_dict()
}

def set_net_state(self, net_state: dict) -> None:
self.load_state_dict(net_state["network"])
self._critic_optim.load_state_dict(net_state["critic_optim"])
self._optim.load_state_dict(net_state["optim"])

def freeze(self) -> None:
self.freeze_all_parameters()
@@ -102,11 +112,11 @@ def unfreeze(self) -> None:
self.unfreeze_all_parameters()


def get_discrete_policy_gradient(name: str, *, state_dim: int, action_num: int) -> DiscretePolicyGradient:
def get_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))


def get_ac(name: str, *, state_dim: int) -> DiscreteActorCritic:
def get_ac(state_dim: int, name: str) -> DiscreteActorCritic:
return DiscreteActorCritic(
name=name,
params=DiscreteActorCriticParams(
@@ -116,8 +126,7 @@ def get_ac(name: str, *, state_dim: int) -> DiscreteActorCritic:
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
lam=.0,
data_parallelism=2
lam=.0
)
)

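Note: the renamed factories above now take the policy/trainer name as the last positional argument, so they pair naturally with functools.partial when building name-keyed creator dicts. The wiring below is an illustrative sketch only, not part of this PR; the import path, dimension values, and agent ids are assumptions.

from functools import partial

from examples.rl.cim.algorithms.ac import get_ac, get_policy  # assumed import path

STATE_DIM = 50    # placeholder; the real value is computed in examples/rl/cim/config.py
ACTION_NUM = 21   # placeholder
AGENT_IDS = [0, 1, 2, 3]  # illustrative agent ids

# Policy names follow the f"{algorithm}_{agent}.policy" convention used in env_sampler.py;
# each creator receives the name and forwards it to the factory.
policy_creator = {
    f"ac_{agent}.policy": partial(get_policy, STATE_DIM, ACTION_NUM)
    for agent in AGENT_IDS
}
trainer_creator = {
    f"ac_{agent}": partial(get_ac, STATE_DIM)
    for agent in AGENT_IDS
}
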
23 changes: 14 additions & 9 deletions examples/rl/cim/algorithms/dqn.py
@@ -1,15 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
from typing import Dict

import torch
from torch.optim import RMSprop

from maro.rl_v3.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl_v3.model import DiscreteQNet, FullyConnected
from maro.rl_v3.policy import ValueBasedPolicy
from maro.rl_v3.training.algorithms import DQN, DQNParams
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQN, DQNParams


q_net_conf = {
@@ -21,18 +22,23 @@
"head": True,
"dropout_p": 0.0
}
q_net_optim_conf = (RMSprop, {"lr": 0.05})
learning_rate = 0.05


class MyQNet(DiscreteQNet):
def __init__(self, state_dim: int, action_num: int) -> None:
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
self._optim = q_net_optim_conf[0](self._fc.parameters(), **q_net_optim_conf[1])
self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)

def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
return self._fc(states)

def step(self, loss: torch.Tensor) -> None:
self._optim.zero_grad()
loss.backward()
self._optim.step()

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
self._optim.zero_grad()
loss.backward()
@@ -58,7 +64,7 @@ def unfreeze(self) -> None:
self.unfreeze_all_parameters()


def get_value_based_policy(name: str, *, state_dim: int, action_num: int) -> ValueBasedPolicy:
def get_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim, action_num),
@@ -87,7 +93,6 @@ def get_dqn(name: str) -> DQN:
double=False,
replay_memory_capacity=10000,
random_overwrite=False,
batch_size=32,
data_parallelism=2
batch_size=32
)
)
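
Note: the new step() method performs a complete local update, while the existing get_gradients()/apply_gradients() pair splits the same update so gradients can be computed in one process and applied in another. The snippet below illustrates that pattern with plain torch and has no MARO dependency; it is a sketch of the idea, not code from this PR.

import torch
from torch import nn
from torch.optim import RMSprop

net = nn.Linear(4, 2)
optim = RMSprop(net.parameters(), lr=0.05)

def step(loss: torch.Tensor) -> None:
    # Full local update, mirroring MyQNet.step().
    optim.zero_grad()
    loss.backward()
    optim.step()

def get_gradients(loss: torch.Tensor) -> dict:
    # Compute gradients without applying them (e.g., on a gradient worker).
    optim.zero_grad()
    loss.backward()
    return {name: p.grad for name, p in net.named_parameters()}

def apply_gradients(grad: dict) -> None:
    # Apply externally computed gradients (e.g., on the parameter owner).
    for name, p in net.named_parameters():
        p.grad = grad[name]
    optim.step()

step(net(torch.randn(8, 4)).pow(2).mean())                            # local update
apply_gradients(get_gradients(net(torch.randn(8, 4)).pow(2).mean()))  # split update
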
46 changes: 28 additions & 18 deletions examples/rl/cim/algorithms/maddpg.py
@@ -7,9 +7,9 @@
import torch
from torch.optim import Adam, RMSprop

from maro.rl_v3.model import DiscretePolicyNet, FullyConnected, MultiQNet
from maro.rl_v3.policy import DiscretePolicyGradient
from maro.rl_v3.training.algorithms import DiscreteMADDPG, DiscreteMADDPGParams
from maro.rl.model import DiscretePolicyNet, FullyConnected, MultiQNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteMADDPG, DiscreteMADDPGParams


actor_net_conf = {
@@ -27,16 +27,16 @@
"batch_norm": True,
"head": True
}
actor_optim_conf = (Adam, {"lr": 0.001})
critic_optim_conf = (RMSprop, {"lr": 0.001})
actor_learning_rate = 0.001
critic_learning_rate = 0.001


# #####################################################################################################################
class MyActorNet(DiscretePolicyNet):
def __init__(self, state_dim: int, action_num: int) -> None:
super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf)
self._actor_optim = actor_optim_conf[0](self._actor.parameters(), **actor_optim_conf[1])
self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate)

def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
return self._actor(states)
@@ -47,55 +47,65 @@ def freeze(self) -> None:
def unfreeze(self) -> None:
self.unfreeze_all_parameters()

def step(self, loss: torch.Tensor) -> None:
self._optim.zero_grad()
loss.backward()
self._optim.step()

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
self._actor_optim.zero_grad()
self._optim.zero_grad()
loss.backward()
return {name: param.grad for name, param in self.named_parameters()}

def apply_gradients(self, grad: dict) -> None:
for name, param in self.named_parameters():
param.grad = grad[name]
self._actor_optim.step()
self._optim.step()

def get_net_state(self) -> dict:
return {
"network": self.state_dict(),
"actor_optim": self._actor_optim.state_dict()
"optim": self._optim.state_dict()
}

def set_net_state(self, net_state: dict) -> None:
self.load_state_dict(net_state["network"])
self._actor_optim.load_state_dict(net_state["actor_optim"])
self._optim.load_state_dict(net_state["optim"])


class MyMultiCriticNet(MultiQNet):
def __init__(self, state_dim: int, action_dims: List[int]) -> None:
super(MyMultiCriticNet, self).__init__(state_dim=state_dim, action_dims=action_dims)
self._critic = FullyConnected(input_dim=state_dim + sum(action_dims), **critic_net_conf)
self._critic_optim = critic_optim_conf[0](self._critic.parameters(), **critic_optim_conf[1])
self._optim = RMSprop(self._critic.parameters(), critic_learning_rate)

def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
return self._critic(torch.cat([states] + actions, dim=1)).squeeze(-1)

def step(self, loss: torch.Tensor) -> None:
self._optim.zero_grad()
loss.backward()
self._optim.step()

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
self._critic_optim.zero_grad()
self._optim.zero_grad()
loss.backward()
return {name: param.grad for name, param in self.named_parameters()}

def apply_gradients(self, grad: dict) -> None:
for name, param in self.named_parameters():
param.grad = grad[name]
self._critic_optim.step()
self._optim.step()

def get_net_state(self) -> dict:
return {
"network": self.state_dict(),
"critic_optim": self._critic_optim.state_dict()
"optim": self._optim.state_dict()
}

def set_net_state(self, net_state: dict) -> None:
self.load_state_dict(net_state["network"])
self._critic_optim.load_state_dict(net_state["critic_optim"])
self._optim.load_state_dict(net_state["optim"])

def freeze(self) -> None:
self.freeze_all_parameters()
@@ -108,18 +118,18 @@ def get_multi_critic_net(state_dim: int, action_dims: List[int]) -> MyMultiCriti
return MyMultiCriticNet(state_dim, action_dims)


def get_discrete_policy_gradient(name: str, *, state_dim: int, action_num: int) -> DiscretePolicyGradient:
def get_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient:
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))


def get_maddpg(name: str, *, state_dim: int, action_dims: List[int]) -> DiscreteMADDPG:
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPG:
return DiscreteMADDPG(
name=name,
params=DiscreteMADDPGParams(
device="cpu",
reward_discount=.0,
num_epoch=10,
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
# shared_critic=True,
shared_critic=False
)
)
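
Note: with shared_critic=False, each agent's critic is created separately via the partial bound to get_q_critic_net_func above, rather than a single shared instance. The sketch below shows one way the updated factories could be combined; the import path, dimension values, and agent count are assumptions, not part of this PR.

from examples.rl.cim.algorithms.maddpg import get_maddpg, get_policy  # assumed import path

STATE_DIM = 50           # placeholder; real values come from examples/rl/cim/config.py
ACTION_DIMS = [21, 21]   # one discrete action space size per agent (illustrative)

trainer = get_maddpg(STATE_DIM, ACTION_DIMS, name="discrete_maddpg")
policies = [
    get_policy(STATE_DIM, action_num, name=f"discrete_maddpg_{i}.policy")
    for i, action_num in enumerate(ACTION_DIMS)
]
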
2 changes: 1 addition & 1 deletion examples/rl/cim/config.py
@@ -34,4 +34,4 @@
+ len(vessel_attributes)
)

algorithm = "dqn" # ac, dqn or discrete_maddpg
algorithm = "ac" # ac, dqn or discrete_maddpg
14 changes: 6 additions & 8 deletions examples/rl/cim/env_sampler.py
@@ -5,8 +5,8 @@

import numpy as np

from maro.rl_v3.policy import RLPolicy
from maro.rl_v3.rollout import AbsEnvSampler, CacheElement, SimpleAgentWrapper
from maro.rl.policy import RLPolicy
from maro.rl.rollout import AbsEnvSampler, CacheElement, SimpleAgentWrapper
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent

@@ -57,7 +57,7 @@ def _translate_to_env_action(self, action_dict: Dict[Any, np.ndarray], event: De

return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}

def _get_reward(self, env_action_dict: Dict[Any, object], tick: int) -> Dict[Any, float]:
def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]:
start_tick = tick + 1
ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

@@ -78,14 +78,12 @@ def _post_step(self, cache_element: CacheElement, reward: Dict[Any, float]) -> N
self._tracker["env_metric"] = self._env.metrics


agent2policy = {agent: f"{algorithm}_{agent}.{agent}" for agent in Env(**env_conf).agent_idx_list}
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in Env(**env_conf).agent_idx_list}


def env_sampler_creator(policy_creator: Dict[str, Callable[[str], RLPolicy]]) -> CIMEnvSampler:
return CIMEnvSampler(
get_env_func=lambda: Env(**env_conf),
get_env=lambda: Env(**env_conf),
policy_creator=policy_creator,
agent2policy=agent2policy,
agent_wrapper_cls=SimpleAgentWrapper,
device="cpu"
agent2policy=agent2policy
)
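
Note: env_sampler_creator only needs a name-keyed policy_creator whose keys match the values in agent2policy; the env factory and agent-to-policy map are already baked in above. A hedged usage sketch with placeholder dimensions and an assumed import path:

from functools import partial

from examples.rl.cim.algorithms.ac import get_policy  # assumed import path
from examples.rl.cim.env_sampler import agent2policy, env_sampler_creator

STATE_DIM = 50    # placeholders; the real values are derived in examples/rl/cim/config.py
ACTION_NUM = 21

policy_creator = {
    policy_name: partial(get_policy, STATE_DIM, ACTION_NUM)
    for policy_name in agent2policy.values()
}
env_sampler = env_sampler_creator(policy_creator)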