Refine rl component bundle #549

Merged: 20 commits, Dec 27, 2022

Changes from all commits
4 changes: 2 additions & 2 deletions examples/cim/rl/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .rl_component_bundle import CIMBundle as rl_component_bundle_cls
from .rl_component_bundle import rl_component_bundle

__all__ = [
    "rl_component_bundle_cls",
    "rl_component_bundle",
]
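Both __init__.py changes in this PR follow the same pattern: the package stops exporting a bundle class (rl_component_bundle_cls) and instead exports a ready-built rl_component_bundle instance. A minimal sketch of how downstream code would pick it up, assuming MARO is installed and the repository root is on the import path:

# Minimal sketch: consuming the instance exported by the refactored __init__.py.
# Assumes the MARO repository root is on PYTHONPATH.
from examples.cim.rl import rl_component_bundle

print(type(rl_component_bundle).__name__)  # expected: RLComponentBundle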
2 changes: 1 addition & 1 deletion examples/cim/rl/algorithms/ac.py
@@ -54,9 +54,9 @@ def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyG
def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
    return ActorCriticTrainer(
        name=name,
        reward_discount=0.0,
        params=ActorCriticParams(
            get_v_critic_net_func=lambda: MyCriticNet(state_dim),
            reward_discount=0.0,
            grad_iters=10,
            critic_loss_cls=torch.nn.SmoothL1Loss,
            min_logp=None,
6 changes: 3 additions & 3 deletions examples/cim/rl/algorithms/dqn.py
@@ -55,14 +55,14 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
def get_dqn(name: str) -> DQNTrainer:
    return DQNTrainer(
        name=name,
        reward_discount=0.0,
        replay_memory_capacity=10000,
        batch_size=32,
        params=DQNParams(
            reward_discount=0.0,
            update_target_every=5,
            num_epochs=10,
            soft_update_coef=0.1,
            double=False,
            replay_memory_capacity=10000,
            random_overwrite=False,
            batch_size=32,
        ),
    )
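In this file and the other algorithm factories touched by this PR (ac.py, maddpg.py, ppo.py here, plus the vm_scheduling variants further down), reward_discount, and for DQN also replay_memory_capacity, batch_size, and data_parallelism, appears both as a trainer keyword argument and inside the corresponding *Params object, which indicates the setting is being relocated between the two. A hedged sketch of get_dqn after the change, assuming the trainer constructor is the new home; class names and values are taken from the diff, and the import path is the maro.rl.training.algorithms package these example files are assumed to use:

# Sketch of get_dqn after this refactoring (direction inferred from the diff, not confirmed here):
# reward_discount, replay_memory_capacity, and batch_size sit on DQNTrainer, the rest stays in DQNParams.
from maro.rl.training.algorithms import DQNParams, DQNTrainer


def get_dqn(name: str) -> DQNTrainer:
    return DQNTrainer(
        name=name,
        reward_discount=0.0,
        replay_memory_capacity=10000,
        batch_size=32,
        params=DQNParams(
            update_target_every=5,
            num_epochs=10,
            soft_update_coef=0.1,
            double=False,
            random_overwrite=False,
        ),
    )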
2 changes: 1 addition & 1 deletion examples/cim/rl/algorithms/maddpg.py
@@ -62,8 +62,8 @@ def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePol
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
    return DiscreteMADDPGTrainer(
        name=name,
        reward_discount=0.0,
        params=DiscreteMADDPGParams(
            reward_discount=0.0,
            num_epoch=10,
            get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
            shared_critic=False,
3 changes: 1 addition & 2 deletions examples/cim/rl/algorithms/ppo.py
@@ -16,12 +16,11 @@ def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicy
def get_ppo(state_dim: int, name: str) -> PPOTrainer:
    return PPOTrainer(
        name=name,
        reward_discount=0.0,
        params=PPOParams(
            get_v_critic_net_func=lambda: MyCriticNet(state_dim),
            reward_discount=0.0,
            grad_iters=10,
            critic_loss_cls=torch.nn.SmoothL1Loss,
            min_logp=None,
            lam=0.0,
            clip_ratio=0.1,
        ),
5 changes: 0 additions & 5 deletions examples/cim/rl/config.py
@@ -7,11 +7,6 @@
"durations": 560,
}

if env_conf["topology"].startswith("toy"):
num_agents = int(env_conf["topology"].split(".")[1][0])
else:
num_agents = int(env_conf["topology"].split(".")[1][:2])

port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]

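The removed block derived the agent count by parsing the topology name. Judging from the rewritten rl_component_bundle.py below, the count is now taken from the environment itself, which works for any topology. A short sketch of the replacement, using only names that appear elsewhere in this diff:

# Sketch: deriving num_agents from the live environment instead of the topology string.
# env_conf is the dict defined earlier in examples/cim/rl/config.py.
from maro.simulator import Env

learn_env = Env(**env_conf)
num_agents = len(learn_env.agent_idx_list)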
103 changes: 37 additions & 66 deletions examples/cim/rl/rl_component_bundle.py
@@ -1,77 +1,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from functools import partial
from typing import Any, Callable, Dict, Optional

from maro.rl.policy import AbsPolicy
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.rl.rollout import AbsEnvSampler
from maro.rl.training import AbsTrainer
from maro.simulator import Env

from .algorithms.ac import get_ac, get_ac_policy
from .algorithms.dqn import get_dqn, get_dqn_policy
from .algorithms.maddpg import get_maddpg, get_maddpg_policy
from .algorithms.ppo import get_ppo, get_ppo_policy
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim
from examples.cim.rl.config import action_num, algorithm, env_conf, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler


class CIMBundle(RLComponentBundle):
    def get_env_config(self) -> dict:
        return env_conf

    def get_test_env_config(self) -> Optional[dict]:
        return None

    def get_env_sampler(self) -> AbsEnvSampler:
        return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"])

    def get_agent2policy(self) -> Dict[Any, str]:
        return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list}

    def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]:
        if algorithm == "ac":
            policy_creator = {
                f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
                for i in range(num_agents)
            }
        elif algorithm == "ppo":
            policy_creator = {
                f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
                for i in range(num_agents)
            }
        elif algorithm == "dqn":
            policy_creator = {
                f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
                for i in range(num_agents)
            }
        elif algorithm == "discrete_maddpg":
            policy_creator = {
                f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy")
                for i in range(num_agents)
            }
        else:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        return policy_creator

    def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]:
        if algorithm == "ac":
            trainer_creator = {
                f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
            }
        elif algorithm == "ppo":
            trainer_creator = {
                f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") for i in range(num_agents)
            }
        elif algorithm == "dqn":
            trainer_creator = {f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") for i in range(num_agents)}
        elif algorithm == "discrete_maddpg":
            trainer_creator = {
                f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)
            }
        else:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        return trainer_creator
# Environments
learn_env = Env(**env_conf)
test_env = learn_env

# Agent, policy, and trainers
num_agents = len(learn_env.agent_idx_list)
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list}
if algorithm == "ac":
    policies = [get_ac_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
    trainers = [get_ac(state_dim, f"{algorithm}_{i}") for i in range(num_agents)]
elif algorithm == "ppo":
    policies = [get_ppo_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
    trainers = [get_ppo(state_dim, f"{algorithm}_{i}") for i in range(num_agents)]
elif algorithm == "dqn":
    policies = [get_dqn_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
    trainers = [get_dqn(f"{algorithm}_{i}") for i in range(num_agents)]
elif algorithm == "discrete_maddpg":
    policies = [get_maddpg_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)]
    trainers = [get_maddpg(state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)]
else:
    raise ValueError(f"Unsupported algorithm: {algorithm}")

# Build RLComponentBundle
rl_component_bundle = RLComponentBundle(
    env_sampler=CIMEnvSampler(
        learn_env=learn_env,
        test_env=test_env,
        policies=policies,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
    ),
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
)
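For anyone porting their own scenario to the refactored interface, the new pattern is: create the Env, build per-agent policies and trainers, map each agent to a policy name, and hand everything to RLComponentBundle. Below is a trimmed sketch fixed to the ppo branch only; every name is taken from this diff, but it is an illustration rather than the committed file:

# Trimmed sketch of the new bundle construction pattern, ppo only.
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
from maro.simulator import Env

from examples.cim.rl.algorithms.ppo import get_ppo, get_ppo_policy
from examples.cim.rl.config import action_num, env_conf, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler

learn_env = Env(**env_conf)
num_agents = len(learn_env.agent_idx_list)

agent2policy = {agent: f"ppo_{agent}.policy" for agent in learn_env.agent_idx_list}
policies = [get_ppo_policy(state_dim, action_num, f"ppo_{i}.policy") for i in range(num_agents)]
trainers = [get_ppo(state_dim, f"ppo_{i}") for i in range(num_agents)]

rl_component_bundle = RLComponentBundle(
    env_sampler=CIMEnvSampler(
        learn_env=learn_env,
        test_env=learn_env,
        policies=policies,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
    ),
    agent2policy=agent2policy,
    policies=policies,
    trainers=trainers,
)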
47 changes: 47 additions & 0 deletions examples/rl/cim_distributed.yml
@@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Example RL config file for CIM scenario.
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.

# Run this workflow by executing one of the following commands:
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml

job: cim_rl_workflow
scenario_path: "examples/cim/rl"
log_path: "log/rl_job/cim.txt"
main:
  num_episodes: 30  # Number of episodes to run. Each episode is one cycle of roll-out and training.
  num_steps: null
  eval_schedule: 5
  logging:
    stdout: INFO
    file: DEBUG
rollout:
  parallelism:
    sampling: 3
    eval: null
    min_env_samples: 3
    grace_factor: 0.2
  controller:
    host: "127.0.0.1"
    port: 20000
  logging:
    stdout: INFO
    file: DEBUG
training:
  mode: parallel
  load_path: null
  load_episode: null
  checkpointing:
    path: "checkpoint/rl_job/cim"
    interval: 5
  proxy:
    host: "127.0.0.1"
    frontend: 10000
    backend: 10001
  num_workers: 2
  logging:
    stdout: INFO
    file: DEBUG
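A small optional sanity check of the new config before launching the workflow; it assumes PyYAML is installed and only touches keys that are unambiguous in the file above:

# Optional sanity check of the workflow config; key names come from the file above.
import yaml

with open("examples/rl/cim_distributed.yml") as fp:
    cfg = yaml.safe_load(fp)

print(cfg["job"])                   # cim_rl_workflow
print(cfg["training"]["mode"])      # parallel
print(cfg["main"]["num_episodes"])  # 30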
4 changes: 2 additions & 2 deletions examples/vm_scheduling/rl/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .rl_component_bundle import VMBundle as rl_component_bundle_cls
from .rl_component_bundle import rl_component_bundle

__all__ = [
    "rl_component_bundle_cls",
    "rl_component_bundle",
]
2 changes: 1 addition & 1 deletion examples/vm_scheduling/rl/algorithms/ac.py
@@ -61,9 +61,9 @@ def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str)
def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
    return ActorCriticTrainer(
        name=name,
        reward_discount=0.9,
        params=ActorCriticParams(
            get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features),
            reward_discount=0.9,
            grad_iters=100,
            critic_loss_cls=torch.nn.MSELoss,
            min_logp=-20,
8 changes: 4 additions & 4 deletions examples/vm_scheduling/rl/algorithms/dqn.py
@@ -77,15 +77,15 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
def get_dqn(name: str) -> DQNTrainer:
    return DQNTrainer(
        name=name,
        reward_discount=0.9,
        replay_memory_capacity=10000,
        batch_size=32,
        data_parallelism=2,
        params=DQNParams(
            reward_discount=0.9,
            update_target_every=5,
            num_epochs=100,
            soft_update_coef=0.1,
            double=False,
            replay_memory_capacity=10000,
            random_overwrite=False,
            batch_size=32,
            data_parallelism=2,
        ),
    )
26 changes: 22 additions & 4 deletions examples/vm_scheduling/rl/env_sampler.py
@@ -5,12 +5,13 @@
from collections import defaultdict
from os import makedirs
from os.path import dirname, join, realpath
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Type, Union

import numpy as np
from matplotlib import pyplot as plt

from maro.rl.rollout import AbsEnvSampler, CacheElement
from maro.rl.policy import AbsPolicy
from maro.rl.rollout import AbsAgentWrapper, AbsEnvSampler, CacheElement, SimpleAgentWrapper
from maro.simulator import Env
from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionEvent, PostponeAction

@@ -30,8 +31,25 @@


class VMEnvSampler(AbsEnvSampler):
    def __init__(self, learn_env: Env, test_env: Env) -> None:
        super(VMEnvSampler, self).__init__(learn_env, test_env)
    def __init__(
        self,
        learn_env: Env,
        test_env: Env,
        policies: List[AbsPolicy],
        agent2policy: Dict[Any, str],
        trainable_policies: List[str] = None,
        agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
        reward_eval_delay: int = None,
    ) -> None:
        super(VMEnvSampler, self).__init__(
            learn_env,
            test_env,
            policies,
            agent2policy,
            trainable_policies,
            agent_wrapper_cls,
            reward_eval_delay,
        )

        self._learn_env.set_seed(seed)
        self._test_env.set_seed(test_seed)
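The widened __init__ mirrors the extended sampler interface used elsewhere in this PR: the sampler now receives its policies and agent-to-policy mapping directly rather than through creator callbacks on a bundle class. A hedged construction sketch follows; vm_env_conf, vm_policies, and vm_agent2policy are hypothetical placeholders for whatever the VM bundle assembles, and the import path for VMEnvSampler is assumed from the file location above:

# Sketch: constructing VMEnvSampler with the extended signature from this diff.
# vm_env_conf, vm_policies, and vm_agent2policy are hypothetical placeholders.
from maro.simulator import Env

from examples.vm_scheduling.rl.env_sampler import VMEnvSampler

learn_env = Env(**vm_env_conf)
test_env = Env(**vm_env_conf)

env_sampler = VMEnvSampler(
    learn_env=learn_env,
    test_env=test_env,
    policies=vm_policies,          # List[AbsPolicy]
    agent2policy=vm_agent2policy,  # Dict[Any, str]
    # trainable_policies, agent_wrapper_cls, and reward_eval_delay keep the defaults shown above
)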