Commit
* Draft v2.0 for V2
* Polish models with more comments
* Polish policies with more comments
* Lint
* Lint
* Add developer doc for models.
* Add developer doc for policies.
* Remove policy manager V2 since it is not used and out-of-date
* Lint
* Lint
Showing 27 changed files with 3,797 additions and 84 deletions.
@@ -0,0 +1,9 @@ README
# Container Inventory Management

This example demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time. In this folder you can find:
* ``config.py``, which contains environment and policy configurations for the scenario;
* ``env_sampler.py``, which defines state, action and reward shaping in the ``CIMEnvSampler`` class;
* ``policies.py``, which defines the Q-net for DQN and the network components for Actor-Critic;
* ``callbacks.py``, which defines routines to be invoked at the end of training or evaluation episodes.

The scripts for running the learning workflows can be found under ``examples/rl/workflows``. See the ``README`` under ``examples/rl`` for details about the general applicability of these scripts. We recommend that you follow this example to write your own scenarios.
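
As a quick sketch of how these pieces fit together (assuming this folder is importable as a package named ``cim``, and that ``policy_func_dict`` maps policy names to creation functions as in other MARO examples):

# Hypothetical quick-start: build the environment sampler that this example exports.
from cim import get_env_sampler, policy_func_dict

env_sampler = get_env_sampler()  # a CIMEnvSampler wired to the "cim" scenario
print(list(policy_func_dict))    # the policy names that the workflow scripts will instantiate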
@@ -0,0 +1,8 @@ __init__.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .callbacks import post_collect, post_evaluate
from .env_sampler import agent2policy, get_env_sampler
from .policies import policy_func_dict

__all__ = ["agent2policy", "post_collect", "post_evaluate", "get_env_sampler", "policy_func_dict"]
@@ -0,0 +1,33 @@ callbacks.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
from os import makedirs
from os.path import dirname, join, realpath

log_dir = join(dirname(realpath(__file__)), "log", str(time.time()))
makedirs(log_dir, exist_ok=True)


def post_collect(trackers, ep, segment):
    # print the env metric from each rollout worker
    for tracker in trackers:
        print(f"env summary (episode {ep}, segment {segment}): {tracker['env_metric']}")

    # print the average env metric
    if len(trackers) > 1:
        metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
        avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
        print(f"average env summary (episode {ep}, segment {segment}): {avg_metric}")


def post_evaluate(trackers, ep):
    # print the env metric from each rollout worker
    for tracker in trackers:
        print(f"env summary (episode {ep}): {tracker['env_metric']}")

    # print the average env metric
    if len(trackers) > 1:
        metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
        avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
        print(f"average env summary (episode {ep}): {avg_metric}")
@@ -0,0 +1,125 @@ config.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from torch.optim import Adam, RMSprop

from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy


env_conf = {
    "scenario": "cim",
    "topology": "toy.4p_ssdd_l0.0",
    "durations": 560
}

port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]

state_shaping_conf = {
    "look_back": 7,
    "max_ports_downstream": 2
}

action_shaping_conf = {
    "action_space": [(i - 10) / 10 for i in range(21)],
    "finite_vessel_space": True,
    "has_early_discharge": True
}
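# Note: the 21 actions map linearly to fractions in [-1.0, 1.0]. Index 0 is -1.0 (load 100%
# of the port's available containers onto the vessel), index 10 is 0.0 (do nothing) and
# index 20 is +1.0 (discharge 100% of the vessel's containers); see env_sampler.py.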

reward_shaping_conf = {
    "time_window": 99,
    "fulfillment_factor": 1.0,
    "shortage_factor": 1.0,
    "time_decay": 0.97
}

# state dimension, derived from the state shaping configuration above
state_dim = (
    (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
    + len(vessel_attributes)
)
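# Worked out with the values above: (7 + 1) * (2 + 1) * 7 + 3 = 171 features per state.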

############################################## POLICIES ###############################################

algorithm = "ac"

# DQN settings
q_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64, 32],
    "output_dim": len(action_shaping_conf["action_space"]),
    "activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "skip_connection": False,
    "head": True,
    "dropout_p": 0.0
}

q_net_optim_conf = (RMSprop, {"lr": 0.05})

dqn_conf = {
    "reward_discount": .0,
    "update_target_every": 5,
    "num_epochs": 10,
    "soft_update_coef": 0.1,
    "double": False,
    "exploration_strategy": (epsilon_greedy, {"epsilon": 0.4}),
    "exploration_scheduling_options": [(
        "epsilon", MultiLinearExplorationScheduler, {
            "splits": [(2, 0.32)],
            "initial_value": 0.4,
            "last_ep": 5,
            "final_value": 0.0,
        }
    )],
    "replay_memory_capacity": 10000,
    "random_overwrite": False,
    "warmup": 100,
    "rollout_batch_size": 128,
    "train_batch_size": 32,
    # "prioritized_replay_kwargs": {
    #     "alpha": 0.6,
    #     "beta": 0.4,
    #     "beta_step": 0.001,
    #     "max_priority": 1e8
    # }
}
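# With the options above, epsilon is piecewise-linear over training episodes: from 0.4 down
# to 0.32 by episode 2, then on to 0.0 at episode 5 (assuming MultiLinearExplorationScheduler
# treats "splits" as (episode, value) anchor points between initial_value and final_value).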

# AC settings
actor_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64],
    "output_dim": len(action_shaping_conf["action_space"]),
    "activation": torch.nn.Tanh,
    "softmax": True,
    "batch_norm": False,
    "head": True
}

critic_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64],
    "output_dim": 1,
    "activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "head": True
}

actor_optim_conf = (Adam, {"lr": 0.001})
critic_optim_conf = (RMSprop, {"lr": 0.001})

ac_conf = {
    "reward_discount": .0,
    "grad_iters": 10,
    "critic_loss_cls": torch.nn.SmoothL1Loss,
    "min_logp": None,
    "critic_loss_coef": 0.1,
    "entropy_coef": 0.01,
    # "clip_ratio": 0.8,  # for PPO
    "lam": .0,
    "get_loss_on_rollout": False
}
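# Assuming "lam" is the lambda of generalized advantage estimation, lam = 0.0 reduces the
# advantage to the one-step TD error r_t + gamma * V(s_{t+1}) - V(s_t).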
@@ -0,0 +1,116 @@ env_sampler.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys

import numpy as np

from maro.rl.learning.env_sampler_v2 import AbsEnvSampler
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType

cim_path = os.path.dirname(os.path.realpath(__file__))
if cim_path not in sys.path:
    sys.path.insert(0, cim_path)

from config import (
    action_shaping_conf, algorithm, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
    vessel_attributes
)
from policies import policy_func_dict


class CIMEnvSampler(AbsEnvSampler):
    def get_state(self, tick=None):
        """
        The state vector includes shortage and remaining vessel space over the past k days (where k is the
        "look_back" value in ``state_shaping_conf``), as well as all downstream port features.
        """
        if tick is None:
            tick = self._env.tick
        vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
        port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
        state = np.concatenate([
            port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],
            vessel_snapshots[tick: vessel_idx: vessel_attributes]
        ])
        return {port_idx: state}
    def get_env_actions(self, action_by_agent):
        """
        The policy output is an integer in [0, 20], interpreted as an index into ``action_space`` in
        ``action_shaping_conf``. For example, action 5 corresponds to -0.5, which means loading 50% of the
        containers available at the current port onto the vessel, while action 18 corresponds to 0.8, which
        means discharging 80% of the containers on the vessel to the port. Note that action 10 corresponds
        to 0.0, which means doing nothing.
        """
        action_space = action_shaping_conf["action_space"]
        finite_vsl_space = action_shaping_conf["finite_vessel_space"]
        has_early_discharge = action_shaping_conf["has_early_discharge"]

        port_idx, action = list(action_by_agent.items()).pop()
        vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope
        vsl_snapshots = self._env.snapshot_list["vessels"]
        vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

        model_action = action["action"] if isinstance(action, dict) else action
        percent = abs(action_space[model_action])
        zero_action_idx = len(action_space) // 2  # index corresponding to value zero
        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            actual_action = min(round(percent * action_scope.load), vsl_space)
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            actual_action, action_type = 0, None
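        # Worked example for the discharge branch, with assumed numbers: if percent is 0.8,
        # action_scope.discharge is 100 and early_discharge is 20, then
        # plan_action = 0.8 * (100 + 20) - 20 = 76, so 76 containers are discharged.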

        return [Action(port_idx=port_idx, vessel_idx=vsl_idx, quantity=actual_action, action_type=action_type)]

    def get_reward(self, actions, tick):
        """
        The reward is defined as a linear combination of fulfillment and shortage measures. The fulfillment and
        shortage measures are the sums of fulfillment and shortage values over the next k days, respectively, each
        adjusted with exponential decay factors (using the "time_decay" value in ``reward_shaping_conf``) to put
        more emphasis on the near future. Here k is the "time_window" value in ``reward_shaping_conf``. The linear
        combination coefficients are given by "fulfillment_factor" and "shortage_factor" in ``reward_shaping_conf``.
        """
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

        # Get the ports that took actions at the given tick
        ports = [action.port_idx for action in actions]
        port_snapshots = self._env.snapshot_list["ports"]
        future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
        future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

        decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
        rewards = np.float32(
            reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
            - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
        )
        return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

    def post_step(self, state, action, env_action, reward, tick):
        """
        The environment sampler contains a "tracker" dict inherited from the ``AbsEnvSampler`` base class, which
        can be used to record any information one wishes to keep track of during a roll-out episode. Here we simply
        record the latest env metric without keeping the history for logging purposes.
        """
        self._tracker["env_metric"] = self._env.metrics


agent2policy = {agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list}


def get_env_sampler():
    return CIMEnvSampler(
        get_env=lambda: Env(**env_conf),
        get_policy_func_dict=policy_func_dict,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
        parallel_inference=False
    )
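
# With the 4-port toy topology and algorithm = "ac", agent2policy presumably evaluates to
# {0: "ac.0", 1: "ac.1", 2: "ac.2", 3: "ac.3"}; the workflow scripts under
# examples/rl/workflows call get_env_sampler() to obtain roll-out instances.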