env(lisong): add beergame supply chain optimization env (#512)
* feat(lisong): add beergame env
* fix(lisong): reset state and modify reward
* feat(lisong): modify episode_collect and reward_shaping
* fix(lisong): fix codestyle
* fix(lisong): remove useless code
* fix(lisong): fix bugs
* feature(lisong): add link in env table
* feat(lisong): add plotting figure
* fix(lisong): fix figure_path bug
* fix(lisong): fix log
* polish(lisong): polish code
Showing 14 changed files with 1,356 additions and 2 deletions.
File: dizoo/beergame/config/beergame_onppo_config.py (new file, 70 lines added)
from easydict import EasyDict

beergame_ppo_config = dict(
    exp_name='beergame_ppo_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=200,
        role=0,  # 0-3: retailer, warehouse, distributor, manufacturer
        agent_type='bs',
        # type of co-players: 'bs' - base stock policy, 'Strm' - Sterman formula modeling typical human behavior
        demandDistribution=0,
        # demand distribution, default=0: 0 = uniform, 1 = normal distribution,
        # 2 = the sequence 4, 4, 4, 4, 8, ..., 3 = basket data, 4 = forecast data
    ),
    policy=dict(
        cuda=True,
        recompute_adv=True,
        action_space='discrete',
        model=dict(
            obs_shape=50,  # stateDim * multPerdInpt = 5 * 10
            action_shape=5,  # the order quantity relative to the arriving order
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            actor_head_hidden_size=128,
            critic_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=10,
            batch_size=320,
            learning_rate=3e-4,
            entropy_weight=0.001,
            adv_norm=True,
            value_norm=True,
            # For on-policy PPO, recomputing the advantage needs the `done` key to split trajectories,
            # which normally requires ignore_done=False. Since `traj_flag` is stored in the data as a
            # backup for `done`, ignore_done=True can be used here.
            ignore_done=True,
        ),
        collect=dict(
            n_episode=8,
            discount_factor=0.99,
            gae_lambda=0.95,
            collector=dict(
                get_train_sample=True,
                reward_shaping=True,  # whether to use the total reward to reshape the per-step reward
            ),
        ),
        eval=dict(evaluator=dict(eval_freq=500, )),
    ),
)
beergame_ppo_config = EasyDict(beergame_ppo_config)
main_config = beergame_ppo_config
beergame_ppo_create_config = dict(
    env=dict(
        type='beergame',
        import_names=['dizoo.beergame.envs.beergame_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
    collector=dict(type='episode', ),
)
beergame_ppo_create_config = EasyDict(beergame_ppo_create_config)
create_config = beergame_ppo_create_config

if __name__ == "__main__":
    # or run: `ding -m serial_onpolicy -c beergame_onppo_config.py -s 0`
    from ding.entry import serial_pipeline_onpolicy
    serial_pipeline_onpolicy([main_config, create_config], seed=0)
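The config above is what `serial_pipeline_onpolicy` consumes. As a minimal sketch (not part of this commit), the role and co-player type could be overridden programmatically before launching; the experiment name used here is hypothetical, while `role` and `agent_type` follow the comments in the config above.

# Hypothetical usage sketch, not part of this commit: override a few env fields
# from the config above and launch on-policy PPO training for the manufacturer role.
from ding.entry import serial_pipeline_onpolicy
from dizoo.beergame.config.beergame_onppo_config import main_config, create_config

main_config.exp_name = 'beergame_ppo_manufacturer_seed0'  # hypothetical experiment name
main_config.env.role = 3  # 3 = manufacturer (see the role comment in the config)
main_config.env.agent_type = 'Strm'  # co-players follow the Sterman formula instead of base stock
serial_pipeline_onpolicy([main_config, create_config], seed=0)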
File: evaluation entry script (new file, 42 lines added)
import os
import torch
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.worker import InteractionSerialEvaluator
from ding.envs import BaseEnvManager
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.beergame.config.beergame_onppo_config import beergame_ppo_config, beergame_ppo_create_config
from ding.envs import get_vec_env_setting
from functools import partial


def main(cfg, seed=0):
    env_fn = None
    cfg, create_cfg = beergame_ppo_config, beergame_ppo_create_config
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num

    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    cfg.env.manager.auto_reset = False
    evaluator_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    model = VAC(**cfg.policy.model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    policy = PPOPolicy(cfg.policy, model=model)
    # set the path to save figures
    cfg.policy.eval.evaluator.figure_path = './'
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    # load the trained model: replace 'model path' with the path to your checkpoint
    model.load_state_dict(torch.load('model path', map_location='cpu')["model"])
    evaluator.eval(None, -1, -1)


if __name__ == "__main__":
    beergame_ppo_config.exp_name = 'beergame_evaluate'
    main(beergame_ppo_config)
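The script above keeps the literal placeholder 'model path' for the checkpoint. A minimal sketch of how the weights might be rebuilt and loaded, assuming DI-engine's usual checkpoint layout `./<exp_name>/ckpt/`; the exact file name (e.g. `ckpt_best.pth.tar`) depends on the training run and is an assumption here.

# Hypothetical sketch, not part of this commit: rebuild the VAC network from the config
# and load a checkpoint; the path below is an assumed example, adjust it to your run.
import torch
from ding.model import VAC
from dizoo.beergame.config.beergame_onppo_config import beergame_ppo_config

model = VAC(**beergame_ppo_config.policy.model)
ckpt = torch.load('./beergame_ppo_seed0/ckpt/ckpt_best.pth.tar', map_location='cpu')  # assumed path
model.load_state_dict(ckpt['model'])  # the checkpoint stores the network weights under the 'model' key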
File: beer game Agent class (new file, 152 lines added)
# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL.
import argparse
import numpy as np


# Here we define the agent class for the BeerGame.
class Agent(object):
    # Initializes the agent with initial values for IL and OO, and saves self.agentNum to identify the agent.
    def __init__(
            self, agentNum: int, IL: int, AO: int, AS: int, c_h: float, c_p: float, eta: int, compuType: str,
            config: argparse.Namespace
    ) -> None:
        self.agentNum = agentNum
        self.IL = IL  # inventory level of the agent - changes during the game
        self.OO = 0  # open order of the agent - changes during the game
        self.ASInitial = AS  # the initial arriving shipment
        self.ILInitial = IL  # IL at which each game starts
        self.AOInitial = AO  # AO at which each game starts
        self.config = config  # an instance of config is stored inside the class
        self.curState = []  # the current state of the game
        self.nextState = []
        self.curReward = 0  # the reward observed at the current step
        self.cumReward = 0  # cumulative reward; reset at the beginning of each episode
        self.totRew = 0  # total reward of all players, attributed to the current player
        self.c_h = c_h  # holding cost
        self.c_p = c_p  # backorder cost
        self.eta = eta  # the total cost regularizer
        self.AS = np.zeros((1, 1))  # arrived shipment
        self.AO = np.zeros((1, 1))  # arrived order
        self.action = 0  # the action at time t
        self.compType = compuType
        # self.compTypeTrain = compuType  # rnd -> random / srdqn -> srdqn / Strm -> formula-Rong2008 / bs -> optimal policy if exists
        # self.compTypeTest = compuType  # rnd -> random / srdqn -> srdqn / Strm -> formula-Rong2008 / bs -> optimal policy if exists
        self.alpha_b = self.config.alpha_b[self.agentNum]  # parameters for the formula
        self.betta_b = self.config.betta_b[self.agentNum]  # parameters for the formula
        if self.config.demandDistribution == 0:
            self.a_b = np.mean((self.config.demandUp, self.config.demandLow))  # parameters for the formula
            self.b_b = np.mean((self.config.demandUp, self.config.demandLow)) * (
                np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
                np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
            )  # parameters for the formula
        elif self.config.demandDistribution == 1 or self.config.demandDistribution == 3 or self.config.demandDistribution == 4:
            self.a_b = self.config.demandMu  # parameters for the formula
            self.b_b = self.config.demandMu * (
                np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
                np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
            )  # parameters for the formula
        elif self.config.demandDistribution == 2:
            self.a_b = 8  # parameters for the formula
            self.b_b = (3 / 4.) * 8 * (
                np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
                np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
            )  # parameters for the formula
        elif self.config.demandDistribution == 3:
            # note: demandDistribution == 3 is already handled by the branch above, so this branch is unreachable
            self.a_b = 10  # parameters for the formula
            self.b_b = 7 * (
                np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
                np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
            )  # parameters for the formula
        else:
            raise Exception('The demand distribution is not defined or it is not a valid type.')

        self.hist = []  # used for plotting - keeps the history for only one game
        self.hist2 = []  # used for animation
        self.srdqnBaseStock = []  # holds the base stock levels that srdqn has come up with; added on Nov 8, 2017
        self.T = 0
        self.bsBaseStock = 0
        self.init_bsBaseStock = 0
        self.nextObservation = []

        if self.compType == 'srdqn':
            # sets the initial input of the network
            self.currentState = np.stack(
                [self.curState for _ in range(self.config.multPerdInpt)], axis=0
            )  # multPerdInpt observations stacked; each row is an observation

    # reset player information
    def resetPlayer(self, T: int):
        self.IL = self.ILInitial
        self.OO = 0
        self.AS = np.squeeze(
            np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10))
        )  # arrived shipment
        self.AO = np.squeeze(
            np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10))
        )  # arrived order
        if self.agentNum != 0:
            for i in range(self.config.leadRecOrderUp_aux[self.agentNum - 1]):
                self.AO[i] = self.AOInitial[self.agentNum - 1]
        for i in range(self.config.leadRecItemUp[self.agentNum]):
            self.AS[i] = self.ASInitial
        self.curReward = 0  # the reward observed at the current step
        self.cumReward = 0  # cumulative reward; reset at the beginning of each episode
        self.action = []
        self.hist = []
        self.hist2 = []
        self.srdqnBaseStock = []  # holds the base stock levels that srdqn has come up with; added on Nov 8, 2017
        self.T = T
        self.curObservation = self.getCurState(1)  # get the current state of the game
        self.nextObservation = []
        if self.compType == 'srdqn':
            self.currentState = np.stack([self.curObservation for _ in range(self.config.multPerdInpt)], axis=0)

    # updates the IL and OO at time t, after receiving "rec" number of items
    def recieveItems(self, time: int) -> None:
        self.IL = self.IL + self.AS[time]  # inventory level update
        self.OO = self.OO - self.AS[time]  # in-transit inventory (open order) update

    # find the action value associated with the action list
    def actionValue(self, curTime: int) -> int:
        if self.config.fixedAction:
            a = self.config.actionList[np.argmax(self.action)]
        else:
            # "d + x" rule
            if self.compType == 'srdqn':
                a = max(0, self.config.actionList[np.argmax(self.action)] * self.config.action_step + self.AO[curTime])
            elif self.compType == 'rnd':
                a = max(0, self.config.actionList[np.argmax(self.action)] + self.AO[curTime])
            else:
                a = max(0, self.config.actionListOpt[np.argmax(self.action)])

        return a

    # getReward returns the reward at the current state
    def getReward(self) -> None:
        # cost (holding + backorder) for one time unit
        self.curReward = (self.c_p * max(0, -self.IL) + self.c_h * max(0, self.IL)) / 200.  # self.config.Ttest
        self.curReward = -self.curReward  # make the reward negative, because it is a cost

        # discounted sum of the rewards of the agent
        self.cumReward = self.config.gamma * self.cumReward + self.curReward

    # returns a np.array of the current state of the agent
    def getCurState(self, t: int) -> np.ndarray:
        if self.config.ifUseASAO:
            if self.config.if_use_AS_t_plus_1:
                curState = np.array(
                    [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t], self.AO[t]]
                )
            else:
                curState = np.array(
                    [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t - 1], self.AO[t]]
                )
        else:
            curState = np.array([-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO])

        if self.config.ifUseActionInD:
            a = self.config.actionList[np.argmax(self.action)]
            curState = np.concatenate((curState, np.array([a])))

        return curState
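actionValue implements the "d + x" ordering rule for the learning agent: the order placed equals the incoming order (demand) d plus a learned offset x, clipped at zero. A small standalone illustration of that arithmetic is sketched below; the concrete action list, action step and incoming order are made-up example values, not numbers from this commit.

# Standalone numeric illustration of the "d + x" rule used in Agent.actionValue
# for the 'srdqn' agent type; all concrete numbers here are made up for the example.
import numpy as np

actionList = [-2, -1, 0, 1, 2]  # hypothetical action offsets x
action_step = 1  # hypothetical multiplier applied to the chosen offset
action = np.array([0., 0., 0., 1., 0.])  # one-hot action picking offset x = +1
AO_t = 4  # incoming order (demand) d at the current time step

order = max(0, actionList[np.argmax(action)] * action_step + AO_t)
print(order)  # 5: the agent orders the incoming demand plus the learned offset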
File: package __init__.py for the beergame envs (new file, 2 lines added)
from .clBeergame import clBeerGame
from .beergame_core import BeerGame