-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Config files * Done * Minor bugfix * Add autoflake * Update isort exclude; add pre-commit to requirements * Check only isort * Minor * Format * Test passed * Run pre-commit * Minor bugfix in rl_component_bundle * Pass mypy * Fix a bug in RL notebook * A minor bug fix * Add upper bound for numpy version in test
- Loading branch information
Showing
50 changed files
with
1,495 additions
and
1,000 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from .rl_component_bundle import CIMBundle as rl_component_bundle_cls | ||
from .rl_component_bundle import rl_component_bundle | ||
|
||
__all__ = [ | ||
"rl_component_bundle_cls", | ||
"rl_component_bundle", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,77 +1,48 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from functools import partial | ||
from typing import Any, Callable, Dict, Optional | ||
|
||
from maro.rl.policy import AbsPolicy | ||
from maro.rl.rl_component.rl_component_bundle import RLComponentBundle | ||
from maro.rl.rollout import AbsEnvSampler | ||
from maro.rl.training import AbsTrainer | ||
from maro.simulator import Env | ||
|
||
from .algorithms.ac import get_ac, get_ac_policy | ||
from .algorithms.dqn import get_dqn, get_dqn_policy | ||
from .algorithms.maddpg import get_maddpg, get_maddpg_policy | ||
from .algorithms.ppo import get_ppo, get_ppo_policy | ||
from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim | ||
from examples.cim.rl.config import action_num, algorithm, env_conf, reward_shaping_conf, state_dim | ||
from examples.cim.rl.env_sampler import CIMEnvSampler | ||
|
||
|
||
class CIMBundle(RLComponentBundle): | ||
def get_env_config(self) -> dict: | ||
return env_conf | ||
|
||
def get_test_env_config(self) -> Optional[dict]: | ||
return None | ||
|
||
def get_env_sampler(self) -> AbsEnvSampler: | ||
return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"]) | ||
|
||
def get_agent2policy(self) -> Dict[Any, str]: | ||
return {agent: f"{algorithm}_{agent}.policy" for agent in self.env.agent_idx_list} | ||
|
||
def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]: | ||
if algorithm == "ac": | ||
policy_creator = { | ||
f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy") | ||
for i in range(num_agents) | ||
} | ||
elif algorithm == "ppo": | ||
policy_creator = { | ||
f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy") | ||
for i in range(num_agents) | ||
} | ||
elif algorithm == "dqn": | ||
policy_creator = { | ||
f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy") | ||
for i in range(num_agents) | ||
} | ||
elif algorithm == "discrete_maddpg": | ||
policy_creator = { | ||
f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy") | ||
for i in range(num_agents) | ||
} | ||
else: | ||
raise ValueError(f"Unsupported algorithm: {algorithm}") | ||
|
||
return policy_creator | ||
|
||
def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]: | ||
if algorithm == "ac": | ||
trainer_creator = { | ||
f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") for i in range(num_agents) | ||
} | ||
elif algorithm == "ppo": | ||
trainer_creator = { | ||
f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") for i in range(num_agents) | ||
} | ||
elif algorithm == "dqn": | ||
trainer_creator = {f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") for i in range(num_agents)} | ||
elif algorithm == "discrete_maddpg": | ||
trainer_creator = { | ||
f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents) | ||
} | ||
else: | ||
raise ValueError(f"Unsupported algorithm: {algorithm}") | ||
|
||
return trainer_creator | ||
# Environments | ||
learn_env = Env(**env_conf) | ||
test_env = learn_env | ||
|
||
# Agent, policy, and trainers | ||
num_agents = len(learn_env.agent_idx_list) | ||
agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} | ||
if algorithm == "ac": | ||
policies = [get_ac_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] | ||
trainers = [get_ac(state_dim, f"{algorithm}_{i}") for i in range(num_agents)] | ||
elif algorithm == "ppo": | ||
policies = [get_ppo_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] | ||
trainers = [get_ppo(state_dim, f"{algorithm}_{i}") for i in range(num_agents)] | ||
elif algorithm == "dqn": | ||
policies = [get_dqn_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] | ||
trainers = [get_dqn(f"{algorithm}_{i}") for i in range(num_agents)] | ||
elif algorithm == "discrete_maddpg": | ||
policies = [get_maddpg_policy(state_dim, action_num, f"{algorithm}_{i}.policy") for i in range(num_agents)] | ||
trainers = [get_maddpg(state_dim, [1], f"{algorithm}_{i}") for i in range(num_agents)] | ||
else: | ||
raise ValueError(f"Unsupported algorithm: {algorithm}") | ||
|
||
# Build RLComponentBundle | ||
rl_component_bundle = RLComponentBundle( | ||
env_sampler=CIMEnvSampler( | ||
learn_env=learn_env, | ||
test_env=test_env, | ||
policies=policies, | ||
agent2policy=agent2policy, | ||
reward_eval_delay=reward_shaping_conf["time_window"], | ||
), | ||
agent2policy=agent2policy, | ||
policies=policies, | ||
trainers=trainers, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
# Example RL config file for CIM scenario. | ||
# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. | ||
|
||
# Run this workflow by executing one of the following commands: | ||
# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml | ||
# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml | ||
|
||
job: cim_rl_workflow | ||
scenario_path: "examples/cim/rl" | ||
log_path: "log/rl_job/cim.txt" | ||
main: | ||
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training. | ||
num_steps: null | ||
eval_schedule: 5 | ||
logging: | ||
stdout: INFO | ||
file: DEBUG | ||
rollout: | ||
parallelism: | ||
sampling: 3 | ||
eval: null | ||
min_env_samples: 3 | ||
grace_factor: 0.2 | ||
controller: | ||
host: "127.0.0.1" | ||
port: 20000 | ||
logging: | ||
stdout: INFO | ||
file: DEBUG | ||
training: | ||
mode: parallel | ||
load_path: null | ||
load_episode: null | ||
checkpointing: | ||
path: "checkpoint/rl_job/cim" | ||
interval: 5 | ||
proxy: | ||
host: "127.0.0.1" | ||
frontend: 10000 | ||
backend: 10001 | ||
num_workers: 2 | ||
logging: | ||
stdout: INFO | ||
file: DEBUG |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from .rl_component_bundle import VMBundle as rl_component_bundle_cls | ||
from .rl_component_bundle import rl_component_bundle | ||
|
||
__all__ = [ | ||
"rl_component_bundle_cls", | ||
"rl_component_bundle", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.