From 156d17fd6b180197e5f2c7d34a30227f77dac2ad Mon Sep 17 00:00:00 2001 From: Jinyu Wang Date: Fri, 13 Aug 2021 18:27:29 +0800 Subject: [PATCH 1/8] update the reset interface of Env and BE --- docs/source/apidoc/maro.utils.rst | 4 ++-- maro/data_lib/cim/cim_data_container.py | 4 ++-- maro/data_lib/cim/cim_data_container_helpers.py | 4 ++-- maro/data_lib/cim/cim_data_generator.py | 2 +- maro/data_lib/item_meta.py | 2 +- maro/simulator/core.py | 10 +++++++--- maro/simulator/scenarios/abs_business_engine.py | 2 +- maro/simulator/scenarios/cim/business_engine.py | 4 ++-- maro/simulator/scenarios/citi_bike/business_engine.py | 2 +- .../scenarios/vm_scheduling/business_engine.py | 2 +- .../{data_lib_exeption.py => data_lib_exception.py} | 0 11 files changed, 20 insertions(+), 16 deletions(-) rename maro/utils/exception/{data_lib_exeption.py => data_lib_exception.py} (100%) diff --git a/docs/source/apidoc/maro.utils.rst b/docs/source/apidoc/maro.utils.rst index 6e413a39a..edefd6fa4 100644 --- a/docs/source/apidoc/maro.utils.rst +++ b/docs/source/apidoc/maro.utils.rst @@ -45,10 +45,10 @@ maro.utils.exception.communication\_exception :undoc-members: :show-inheritance: -maro.utils.exception.data\_lib\_exeption +maro.utils.exception.data\_lib\_exception -------------------------------------------------------------------------------- -.. automodule:: maro.utils.exception.data_lib_exeption +.. 
automodule:: maro.utils.exception.data_lib_exception :members: :undoc-members: :show-inheritance: diff --git a/maro/data_lib/cim/cim_data_container.py b/maro/data_lib/cim/cim_data_container.py index f5399238d..ef5ef38e9 100644 --- a/maro/data_lib/cim/cim_data_container.py +++ b/maro/data_lib/cim/cim_data_container.py @@ -239,9 +239,9 @@ def port_mapping(self) -> Dict[str, int]: """Dict[str, int]: Name to index mapping for ports.""" return self._data_collection.port_mapping - def reset(self): + def reset(self, keep_seed): """Reset data container internal state.""" - self._is_need_reset_seed = True + self._is_need_reset_seed = keep_seed def _reset_seed(self): """Reset internal seed for generate reproduceable data""" diff --git a/maro/data_lib/cim/cim_data_container_helpers.py b/maro/data_lib/cim/cim_data_container_helpers.py index a980a683e..e83fde110 100644 --- a/maro/data_lib/cim/cim_data_container_helpers.py +++ b/maro/data_lib/cim/cim_data_container_helpers.py @@ -41,9 +41,9 @@ def _init_data_container(self): # Real Data Mode: read data from input data files, no need for any config.yml. 
self._data_cntr = data_from_files(data_folder=self._config_path) - def reset(self): + def reset(self, keep_seed): """Reset data container internal state""" - self._data_cntr.reset() + self._data_cntr.reset(keep_seed) def __getattr__(self, name): return getattr(self._data_cntr, name) diff --git a/maro/data_lib/cim/cim_data_generator.py b/maro/data_lib/cim/cim_data_generator.py index 2d817866f..0a2b121dc 100644 --- a/maro/data_lib/cim/cim_data_generator.py +++ b/maro/data_lib/cim/cim_data_generator.py @@ -7,7 +7,7 @@ from yaml import safe_load from maro.simulator.utils import seed -from maro.utils.exception.data_lib_exeption import CimGeneratorInvalidParkingDuration +from maro.utils.exception.data_lib_exception import CimGeneratorInvalidParkingDuration from .entities import CimSyntheticDataCollection, OrderGenerateMode, Stop from .global_order_proportion import GlobalOrderProportion diff --git a/maro/data_lib/item_meta.py b/maro/data_lib/item_meta.py index 25120b2c6..02a6a8f82 100644 --- a/maro/data_lib/item_meta.py +++ b/maro/data_lib/item_meta.py @@ -12,7 +12,7 @@ from yaml import SafeDumper, SafeLoader, YAMLObject, safe_dump, safe_load from maro.data_lib.common import dtype_pack_map -from maro.utils.exception.data_lib_exeption import MetaTimestampNotExist +from maro.utils.exception.data_lib_exception import MetaTimestampNotExist class EntityAttr(YAMLObject): diff --git a/maro/simulator/core.py b/maro/simulator/core.py index 554a9570e..accbe13b8 100644 --- a/maro/simulator/core.py +++ b/maro/simulator/core.py @@ -101,8 +101,12 @@ def dump(self): """ return - def reset(self): - """Reset environment.""" + def reset(self, keep_seed: bool = False): + """Reset environment. + + Args: + keep_seed (bool): Reset the random seed to generate the same data sequence or not. Defaults to False. 
+ """ self._tick = self._start_tick self._simulate_generator.close() @@ -120,7 +124,7 @@ def reset(self): self._decision_events.clear() - self._business_engine.reset() + self._business_engine.reset(keep_seed) @property def configs(self) -> dict: diff --git a/maro/simulator/scenarios/abs_business_engine.py b/maro/simulator/scenarios/abs_business_engine.py index 5aeb77aa8..51ead0d65 100644 --- a/maro/simulator/scenarios/abs_business_engine.py +++ b/maro/simulator/scenarios/abs_business_engine.py @@ -139,7 +139,7 @@ def configs(self) -> dict: pass @abstractmethod - def reset(self): + def reset(self, keep_seed): """Reset states business engine.""" pass diff --git a/maro/simulator/scenarios/cim/business_engine.py b/maro/simulator/scenarios/cim/business_engine.py index bb10986bc..56d63e6ee 100644 --- a/maro/simulator/scenarios/cim/business_engine.py +++ b/maro/simulator/scenarios/cim/business_engine.py @@ -196,7 +196,7 @@ def post_step(self, tick: int): return tick + 1 == self._max_tick - def reset(self): + def reset(self, keep_seed): """Reset the business engine, it will reset frame value.""" self._snapshots.reset() @@ -205,7 +205,7 @@ def reset(self): self._reset_nodes() - self._data_cntr.reset() + self._data_cntr.reset(keep_seed) # Insert departure event again. 
self._load_departure_events() diff --git a/maro/simulator/scenarios/citi_bike/business_engine.py b/maro/simulator/scenarios/citi_bike/business_engine.py index cc3da8051..b301340a0 100644 --- a/maro/simulator/scenarios/citi_bike/business_engine.py +++ b/maro/simulator/scenarios/citi_bike/business_engine.py @@ -146,7 +146,7 @@ def get_event_payload_detail(self) -> dict: CitiBikeEvents.DeliverBike.name: BikeTransferPayload.summary_key } - def reset(self): + def reset(self, keep_seed): """Reset internal states for episode.""" self._total_trips = 0 self._total_operate_num = 0 diff --git a/maro/simulator/scenarios/vm_scheduling/business_engine.py b/maro/simulator/scenarios/vm_scheduling/business_engine.py index 0b64be1cd..1507185b1 100644 --- a/maro/simulator/scenarios/vm_scheduling/business_engine.py +++ b/maro/simulator/scenarios/vm_scheduling/business_engine.py @@ -405,7 +405,7 @@ def _init_pms(self, pm_dict: dict): return start_pm_id - def reset(self): + def reset(self, keep_seed): """Reset internal states for episode.""" self._init_metrics() diff --git a/maro/utils/exception/data_lib_exeption.py b/maro/utils/exception/data_lib_exception.py similarity index 100% rename from maro/utils/exception/data_lib_exeption.py rename to maro/utils/exception/data_lib_exception.py From 1e3835d8bc2dfb2228e3d52c4b8efe2bba207490 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Wed, 18 Aug 2021 16:22:52 +0800 Subject: [PATCH 2/8] Try to fix reset routes generation seed issue --- maro/data_lib/cim/cim_data_container.py | 4 ++-- maro/data_lib/cim/cim_data_container_helpers.py | 15 ++++++++++----- maro/data_lib/cim/cim_data_dump.py | 2 +- maro/data_lib/cim/cim_data_generator.py | 6 ++++-- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/maro/data_lib/cim/cim_data_container.py b/maro/data_lib/cim/cim_data_container.py index ef5ef38e9..f5399238d 100644 --- a/maro/data_lib/cim/cim_data_container.py +++ b/maro/data_lib/cim/cim_data_container.py @@ -239,9 +239,9 @@ def 
port_mapping(self) -> Dict[str, int]: """Dict[str, int]: Name to index mapping for ports.""" return self._data_collection.port_mapping - def reset(self, keep_seed): + def reset(self): """Reset data container internal state.""" - self._is_need_reset_seed = keep_seed + self._is_need_reset_seed = True def _reset_seed(self): """Reset internal seed for generate reproduceable data""" diff --git a/maro/data_lib/cim/cim_data_container_helpers.py b/maro/data_lib/cim/cim_data_container_helpers.py index e83fde110..c9af3a296 100644 --- a/maro/data_lib/cim/cim_data_container_helpers.py +++ b/maro/data_lib/cim/cim_data_container_helpers.py @@ -10,6 +10,7 @@ from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer from .cim_data_generator import CimDataGenerator from .cim_data_loader import load_from_folder, load_real_data_from_folder +from .utils import route_init_rand class CimDataContainerWrapper: @@ -28,14 +29,14 @@ def __init__(self, config_path: str, max_tick: int, topology: str): self._init_data_container() - def _init_data_container(self): + def _init_data_container(self, topology_seed: int = None): if not os.path.exists(self._config_path): raise FileNotFoundError # Synthetic Data Mode: config.yml must exist. config_path = os.path.join(self._config_path, "config.yml") if os.path.exists(config_path): self._data_cntr = data_from_generator( - config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick + config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick, topology_seed=topology_seed ) else: # Real Data Mode: read data from input data files, no need for any config.yml. 
@@ -43,7 +44,10 @@ def _init_data_container(self): def reset(self, keep_seed): """Reset data container internal state""" - self._data_cntr.reset(keep_seed) + if not keep_seed: + self._init_data_container(route_init_rand.randint(0, 4096 - 1)) + else: + self._data_cntr.reset() def __getattr__(self, name): return getattr(self._data_cntr, name) @@ -68,20 +72,21 @@ def data_from_dumps(dumps_folder: str) -> CimSyntheticDataContainer: return CimSyntheticDataContainer(data_collection) -def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0) -> CimSyntheticDataContainer: +def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0, topology_seed: int = None) -> CimSyntheticDataContainer: """Collect data from data generator with configurations. Args: config_path(str): Path of configuration file (yaml). max_tick (int): Max tick to generate data. start_tick(int): Start tick to generate data. + topology_seed(int): Random seed for generating routes. 'None' means using the seed in the configuration file. Returns: CimSyntheticDataContainer: Data container used to provide cim data related interfaces. 
""" edg = CimDataGenerator() - data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick) + data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick, topology_seed=topology_seed) return CimSyntheticDataContainer(data_collection) diff --git a/maro/data_lib/cim/cim_data_dump.py b/maro/data_lib/cim/cim_data_dump.py index 9c8c06275..f99a0a335 100644 --- a/maro/data_lib/cim/cim_data_dump.py +++ b/maro/data_lib/cim/cim_data_dump.py @@ -247,7 +247,7 @@ def dump_from_config(config_file: str, output_folder: str, max_tick: int): generator = CimDataGenerator() - data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0) + data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0, topology_seed=None) dump_util = CimDataDumpUtil(data_collection) diff --git a/maro/data_lib/cim/cim_data_generator.py b/maro/data_lib/cim/cim_data_generator.py index 0a2b121dc..468fb1f11 100644 --- a/maro/data_lib/cim/cim_data_generator.py +++ b/maro/data_lib/cim/cim_data_generator.py @@ -29,13 +29,14 @@ def __init__(self): self._routes_parser = RoutesParser() self._global_order_proportion = GlobalOrderProportion() - def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0) -> CimSyntheticDataCollection: + def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0, topology_seed: int = None) -> CimSyntheticDataCollection: """Generate data with specified configurations. Args: config_file(str): File of configuration (yaml). max_tick(int): Max tick to generate. start_tick(int): Start tick to generate. + topology_seed(int): Random seed for generating routes. 'None' means using the seed in the configuration file. Returns: CimSyntheticDataCollection: Data collection contains all cim data. 
@@ -45,7 +46,8 @@ def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0) -> CimS with open(config_file, "r") as fp: conf: dict = safe_load(fp) - topology_seed = conf["seed"] + if topology_seed is None: + topology_seed = conf["seed"] # set seed to generate data seed(topology_seed) From caa4a6d7cec9b17dc3f8f6a0f49315652279e31a Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 19 Aug 2021 12:26:06 +0800 Subject: [PATCH 3/8] Refine random related logics. --- maro/data_lib/cim/cim_data_container.py | 17 +++----- .../cim/cim_data_container_helpers.py | 7 ++-- maro/data_lib/cim/cim_data_generator.py | 9 ++-- maro/data_lib/cim/global_order_proportion.py | 5 ++- maro/data_lib/cim/port_buffer_tick_wrapper.py | 5 ++- maro/data_lib/cim/utils.py | 21 ++++------ maro/simulator/utils/sim_random.py | 41 ++++++++++++++----- 7 files changed, 58 insertions(+), 47 deletions(-) diff --git a/maro/data_lib/cim/cim_data_container.py b/maro/data_lib/cim/cim_data_container.py index f5399238d..face972e1 100644 --- a/maro/data_lib/cim/cim_data_container.py +++ b/maro/data_lib/cim/cim_data_container.py @@ -12,13 +12,14 @@ ) from .port_buffer_tick_wrapper import PortBufferTickWrapper from .utils import ( - apply_noise, buffer_tick_rand, get_buffer_tick_seed, get_order_num_seed, list_sum_normalize, order_num_rand + apply_noise, list_sum_normalize, BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY ) from .vessel_future_stops_prediction import VesselFutureStopsPrediction from .vessel_past_stops_wrapper import VesselPastStopsWrapper from .vessel_reachable_stops_wrapper import VesselReachableStopsWrapper from .vessel_sailing_plan_wrapper import VesselSailingPlanWrapper from .vessel_stop_wrapper import VesselStopsWrapper +from ...simulator.utils import random class CimBaseDataContainer(ABC): @@ -60,9 +61,6 @@ def __init__(self, data_collection: CimBaseDataCollection): self._vessel_plan_wrapper = VesselSailingPlanWrapper(self._data_collection) self._reachable_stops_wrapper = 
VesselReachableStopsWrapper(self._data_collection) - # keep the seed so we can reproduce the sequence after reset - self._buffer_tick_seed: int = get_buffer_tick_seed() - # flag to tell if we need to reset seed, we need this flag as outside may set the seed after env.reset self._is_need_reset_seed = False @@ -245,7 +243,7 @@ def reset(self): def _reset_seed(self): """Reset internal seed for generate reproduceable data""" - buffer_tick_rand.seed(self._buffer_tick_seed) + random.reset_seed(BUFFER_TICK_RAND_KEY) @abstractmethod def get_orders(self, tick: int, total_empty_container: int) -> List[Order]: @@ -272,9 +270,6 @@ class CimSyntheticDataContainer(CimBaseDataContainer): def __init__(self, data_collection: CimSyntheticDataCollection): super().__init__(data_collection) - # keep the seed so we can reproduce the sequence after reset - self._order_num_seed: int = get_order_num_seed() - # TODO: get_events which composed with arrive, departure and order def get_orders(self, tick: int, total_empty_container: int) -> List[Order]: @@ -303,7 +298,7 @@ def get_orders(self, tick: int, total_empty_container: int) -> List[Order]: def _reset_seed(self): """Reset internal seed for generate reproduceable data""" super()._reset_seed() - order_num_rand.seed(self._order_num_seed) + random.reset_seed(ORDER_NUM_RAND_KEY) def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]: """Generate order for specified tick. 
@@ -339,7 +334,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]: for port_idx in range(self.port_number): source_dist: NoisedItem = self.ports[port_idx].source_proportion - noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, order_num_rand) + noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, random[ORDER_NUM_RAND_KEY]) noised_source_order_dist.append(noised_source_order_number) @@ -356,7 +351,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]: # apply noise and normalize noised_targets_dist = list_sum_normalize( - [apply_noise(target.base, target.noise, order_num_rand) for target in targets_dist]) + [apply_noise(target.base, target.noise, random[ORDER_NUM_RAND_KEY]) for target in targets_dist]) # order for current ports cur_port_order_num = ceil(orders_to_gen * noised_source_order_dist[port_idx]) diff --git a/maro/data_lib/cim/cim_data_container_helpers.py b/maro/data_lib/cim/cim_data_container_helpers.py index c9af3a296..803e9ebf9 100644 --- a/maro/data_lib/cim/cim_data_container_helpers.py +++ b/maro/data_lib/cim/cim_data_container_helpers.py @@ -5,12 +5,11 @@ import urllib.parse from maro.cli.data_pipeline.utils import StaticParameter -from maro.simulator.utils import seed - +from maro.simulator.utils import seed, random from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer from .cim_data_generator import CimDataGenerator from .cim_data_loader import load_from_folder, load_real_data_from_folder -from .utils import route_init_rand +from .utils import ROUTE_INIT_RAND_KEY, DATA_CONTAINER_INIT_SEED_LIMIT class CimDataContainerWrapper: @@ -45,7 +44,7 @@ def _init_data_container(self, topology_seed: int = None): def reset(self, keep_seed): """Reset data container internal state""" if not keep_seed: - self._init_data_container(route_init_rand.randint(0, 4096 - 1)) + 
self._init_data_container(random[ROUTE_INIT_RAND_KEY].randint(0, DATA_CONTAINER_INIT_SEED_LIMIT - 1)) else: self._data_cntr.reset() diff --git a/maro/data_lib/cim/cim_data_generator.py b/maro/data_lib/cim/cim_data_generator.py index 468fb1f11..a0e331c67 100644 --- a/maro/data_lib/cim/cim_data_generator.py +++ b/maro/data_lib/cim/cim_data_generator.py @@ -6,14 +6,13 @@ from yaml import safe_load -from maro.simulator.utils import seed +from maro.simulator.utils import seed, random from maro.utils.exception.data_lib_exception import CimGeneratorInvalidParkingDuration - from .entities import CimSyntheticDataCollection, OrderGenerateMode, Stop from .global_order_proportion import GlobalOrderProportion from .port_parser import PortsParser from .route_parser import RoutesParser -from .utils import apply_noise, route_init_rand +from .utils import apply_noise, ROUTE_INIT_RAND_KEY from .vessel_parser import VesselsParser CIM_GENERATOR_VERSION = 0x000001 @@ -148,7 +147,7 @@ def _extend_route( port_idx = port_mapping[cur_route_point.port_name] # apply noise to parking duration - parking_duration = ceil(apply_noise(duration, duration_noise, route_init_rand)) + parking_duration = ceil(apply_noise(duration, duration_noise, random[ROUTE_INIT_RAND_KEY])) if parking_duration <= 0: raise CimGeneratorInvalidParkingDuration() @@ -167,7 +166,7 @@ def _extend_route( distance_to_next_port = cur_route_point.distance_to_next_port # apply noise to speed - noised_speed = apply_noise(speed, speed_noise, route_init_rand) + noised_speed = apply_noise(speed, speed_noise, random[ROUTE_INIT_RAND_KEY]) sailing_duration = ceil(distance_to_next_port / noised_speed) # next tick diff --git a/maro/data_lib/cim/global_order_proportion.py b/maro/data_lib/cim/global_order_proportion.py index cb4351061..9839fa4c3 100644 --- a/maro/data_lib/cim/global_order_proportion.py +++ b/maro/data_lib/cim/global_order_proportion.py @@ -6,7 +6,8 @@ import numpy as np -from .utils import apply_noise, clip, order_init_rand 
+from .utils import apply_noise, clip, ORDER_INIT_RAND_KEY +from ...simulator.utils import random class GlobalOrderProportion: @@ -59,7 +60,7 @@ def parse(self, conf: dict, total_container: int, max_tick: int, start_tick: int # apply noise if the distribution not zero if orders != 0: if noise != 0: - orders = apply_noise(orders, noise, order_init_rand) + orders = apply_noise(orders, noise, random[ORDER_INIT_RAND_KEY]) # clip and gen order orders = floor(clip(0, 1, orders) * total_container) diff --git a/maro/data_lib/cim/port_buffer_tick_wrapper.py b/maro/data_lib/cim/port_buffer_tick_wrapper.py index 98ec23093..c57a1c2d7 100644 --- a/maro/data_lib/cim/port_buffer_tick_wrapper.py +++ b/maro/data_lib/cim/port_buffer_tick_wrapper.py @@ -4,7 +4,8 @@ from math import ceil from .entities import CimBaseDataCollection, NoisedItem, PortSetting -from .utils import apply_noise, buffer_tick_rand +from .utils import apply_noise, BUFFER_TICK_RAND_KEY +from ...simulator.utils import random class PortBufferTickWrapper: @@ -29,4 +30,4 @@ def __getitem__(self, key): buffer_setting: NoisedItem = self._attribute_func(port) - return ceil(apply_noise(buffer_setting.base, buffer_setting.noise, buffer_tick_rand)) + return ceil(apply_noise(buffer_setting.base, buffer_setting.noise, random[BUFFER_TICK_RAND_KEY])) diff --git a/maro/data_lib/cim/utils.py b/maro/data_lib/cim/utils.py index eb56e1420..a6c639d8e 100644 --- a/maro/data_lib/cim/utils.py +++ b/maro/data_lib/cim/utils.py @@ -1,23 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from random import Random from typing import List, Union -from maro.simulator.utils.sim_random import SimRandom, random +from maro.simulator.utils.sim_random import SimRandom # we keep 4 random generator to make the result is reproduceable with same seed(s), no matter if agent passed actions -route_init_rand = random["route_init"] -order_init_rand = random["order_init"] -buffer_tick_rand = random["buffer_time"] -order_num_rand = random["order_number"] +ROUTE_INIT_RAND_KEY = "route_init" +ORDER_INIT_RAND_KEY = "order_init" +BUFFER_TICK_RAND_KEY = "buffer_time" +ORDER_NUM_RAND_KEY = "order_number" - -def get_buffer_tick_seed(): - return random.get_seed("buffer_time") - - -def get_order_num_seed(): - return random.get_seed("order_number") +DATA_CONTAINER_INIT_SEED_LIMIT = 4096 def clip(min_val: Union[int, float], max_val: Union[int, float], value: Union[int, float]) -> Union[int, float]: @@ -34,7 +29,7 @@ def clip(min_val: Union[int, float], max_val: Union[int, float], value: Union[in return max(min_val, min(max_val, value)) -def apply_noise(value: Union[int, float], noise: Union[int, float], rand: SimRandom) -> float: +def apply_noise(value: Union[int, float], noise: Union[int, float], rand: Random) -> float: """Apply noise with specified random generator Args: diff --git a/maro/simulator/utils/sim_random.py b/maro/simulator/utils/sim_random.py index 3916169d5..7609a681d 100644 --- a/maro/simulator/utils/sim_random.py +++ b/maro/simulator/utils/sim_random.py @@ -31,7 +31,6 @@ def __init__(self): self._rand_instances: Dict[str, Random] = OrderedDict() self._seed_dict: Dict[str, int] = {} self._seed = int(time.time()) - self._index = 0 def seed(self, seed_num: int): """Set seed for simulator random objects. 
@@ -46,27 +45,28 @@ def seed(self, seed_num: int): self._seed = seed_num - self._index = 0 - for key, rand in self._rand_instances.items(): + for index, (key, rand) in enumerate(self._rand_instances.items()): # we set seed for each random instance with 1 offset - seed = seed_num + self._index + seed = seed_num + index rand.seed(seed) self._seed_dict[key] = seed - self._index += 1 - - def __getitem__(self, key): + def _create_instance(self, key: str) -> None: assert type(key) is str if key not in self._rand_instances: + self._seed_dict[key] = self._seed + len(self._rand_instances) r = Random() - r.seed(self._seed + self._index) + r.seed(self._seed_dict[key]) + self._rand_instances[key] = r - self._index += 1 + def __getitem__(self, key): + assert type(key) is str - self._rand_instances[key] = r + if key not in self._rand_instances: + self._create_instance(key) return self._rand_instances[key] @@ -88,6 +88,27 @@ def get_seed(self, key: str = None) -> int: return self._seed + def reset_seed(self, key: str = None) -> None: + """Reset seed of current random generator. + + NOTE: + This will reset the seed to the value that specified by user (or default). + + Args: + key(str): Key of item to get. 
+ """ + if key is not None: + if key not in self._seed_dict: + self._create_instance(key) + rand = self._rand_instances[key] + rand.seed(self._seed_dict[key]) + + def reset_all_seeds(self) -> None: + """Reset seed of all random generators + """ + for key in self._rand_instances: + self.reset_seed(key) + random = SimRandom() """Random utility for simulator, same with original random module.""" From 0d3886da275e56feec956e3637f124d8c60fd014 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 19 Aug 2021 14:48:18 +0800 Subject: [PATCH 4/8] Minor refinement --- maro/data_lib/cim/cim_data_container.py | 7 +++---- maro/data_lib/cim/cim_data_container_helpers.py | 10 ++++++---- maro/data_lib/cim/cim_data_generator.py | 11 +++++++---- maro/data_lib/cim/global_order_proportion.py | 4 ++-- maro/data_lib/cim/port_buffer_tick_wrapper.py | 6 ++++-- maro/data_lib/cim/utils.py | 2 -- maro/simulator/utils/sim_random.py | 13 +++++++------ 7 files changed, 29 insertions(+), 24 deletions(-) diff --git a/maro/data_lib/cim/cim_data_container.py b/maro/data_lib/cim/cim_data_container.py index face972e1..ef4979eb7 100644 --- a/maro/data_lib/cim/cim_data_container.py +++ b/maro/data_lib/cim/cim_data_container.py @@ -6,20 +6,19 @@ from math import ceil from typing import Dict, List +from maro.simulator.utils import random + from .entities import ( CimBaseDataCollection, CimRealDataCollection, CimSyntheticDataCollection, NoisedItem, Order, OrderGenerateMode, PortSetting, VesselSetting ) from .port_buffer_tick_wrapper import PortBufferTickWrapper -from .utils import ( - apply_noise, list_sum_normalize, BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY -) +from .utils import apply_noise, list_sum_normalize, BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY from .vessel_future_stops_prediction import VesselFutureStopsPrediction from .vessel_past_stops_wrapper import VesselPastStopsWrapper from .vessel_reachable_stops_wrapper import VesselReachableStopsWrapper from .vessel_sailing_plan_wrapper import 
VesselSailingPlanWrapper from .vessel_stop_wrapper import VesselStopsWrapper -from ...simulator.utils import random class CimBaseDataContainer(ABC): diff --git a/maro/data_lib/cim/cim_data_container_helpers.py b/maro/data_lib/cim/cim_data_container_helpers.py index 803e9ebf9..e07eb26a1 100644 --- a/maro/data_lib/cim/cim_data_container_helpers.py +++ b/maro/data_lib/cim/cim_data_container_helpers.py @@ -5,11 +5,11 @@ import urllib.parse from maro.cli.data_pipeline.utils import StaticParameter -from maro.simulator.utils import seed, random +from maro.simulator.utils import random, seed from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer from .cim_data_generator import CimDataGenerator from .cim_data_loader import load_from_folder, load_real_data_from_folder -from .utils import ROUTE_INIT_RAND_KEY, DATA_CONTAINER_INIT_SEED_LIMIT +from .utils import DATA_CONTAINER_INIT_SEED_LIMIT, ROUTE_INIT_RAND_KEY class CimDataContainerWrapper: @@ -35,7 +35,8 @@ def _init_data_container(self, topology_seed: int = None): config_path = os.path.join(self._config_path, "config.yml") if os.path.exists(config_path): self._data_cntr = data_from_generator( - config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick, topology_seed=topology_seed + config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick, + topology_seed=topology_seed ) else: # Real Data Mode: read data from input data files, no need for any config.yml. @@ -71,7 +72,8 @@ def data_from_dumps(dumps_folder: str) -> CimSyntheticDataContainer: return CimSyntheticDataContainer(data_collection) -def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0, topology_seed: int = None) -> CimSyntheticDataContainer: +def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0, + topology_seed: int = None) -> CimSyntheticDataContainer: """Collect data from data generator with configurations. 
Args: diff --git a/maro/data_lib/cim/cim_data_generator.py b/maro/data_lib/cim/cim_data_generator.py index a0e331c67..83a94953b 100644 --- a/maro/data_lib/cim/cim_data_generator.py +++ b/maro/data_lib/cim/cim_data_generator.py @@ -6,13 +6,14 @@ from yaml import safe_load -from maro.simulator.utils import seed, random +from maro.simulator.utils import random, seed from maro.utils.exception.data_lib_exception import CimGeneratorInvalidParkingDuration + from .entities import CimSyntheticDataCollection, OrderGenerateMode, Stop from .global_order_proportion import GlobalOrderProportion from .port_parser import PortsParser from .route_parser import RoutesParser -from .utils import apply_noise, ROUTE_INIT_RAND_KEY +from .utils import ROUTE_INIT_RAND_KEY, apply_noise from .vessel_parser import VesselsParser CIM_GENERATOR_VERSION = 0x000001 @@ -28,14 +29,16 @@ def __init__(self): self._routes_parser = RoutesParser() self._global_order_proportion = GlobalOrderProportion() - def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0, topology_seed: int = None) -> CimSyntheticDataCollection: + def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0, + topology_seed: int = None) -> CimSyntheticDataCollection: """Generate data with specified configurations. Args: config_file(str): File of configuration (yaml). max_tick(int): Max tick to generate. start_tick(int): Start tick to generate. - topology_seed(int): Random seed for generating routes. 'None' means using the seed in the configuration file. + topology_seed(int): Random seed of the business engine. \ + 'None' means using the seed in the configuration file. Returns: CimSyntheticDataCollection: Data collection contains all cim data. 
diff --git a/maro/data_lib/cim/global_order_proportion.py b/maro/data_lib/cim/global_order_proportion.py index 9839fa4c3..802545e3b 100644 --- a/maro/data_lib/cim/global_order_proportion.py +++ b/maro/data_lib/cim/global_order_proportion.py @@ -6,8 +6,8 @@ import numpy as np -from .utils import apply_noise, clip, ORDER_INIT_RAND_KEY -from ...simulator.utils import random +from maro.simulator.utils import random +from .utils import ORDER_INIT_RAND_KEY, apply_noise, clip class GlobalOrderProportion: diff --git a/maro/data_lib/cim/port_buffer_tick_wrapper.py b/maro/data_lib/cim/port_buffer_tick_wrapper.py index c57a1c2d7..990892873 100644 --- a/maro/data_lib/cim/port_buffer_tick_wrapper.py +++ b/maro/data_lib/cim/port_buffer_tick_wrapper.py @@ -3,9 +3,11 @@ from math import ceil +from maro.simulator.utils import random + from .entities import CimBaseDataCollection, NoisedItem, PortSetting -from .utils import apply_noise, BUFFER_TICK_RAND_KEY -from ...simulator.utils import random +from .utils import BUFFER_TICK_RAND_KEY, apply_noise + class PortBufferTickWrapper: diff --git a/maro/data_lib/cim/utils.py b/maro/data_lib/cim/utils.py index a6c639d8e..7b28d6100 100644 --- a/maro/data_lib/cim/utils.py +++ b/maro/data_lib/cim/utils.py @@ -4,8 +4,6 @@ from random import Random from typing import List, Union -from maro.simulator.utils.sim_random import SimRandom - # we keep 4 random generator to make the result is reproduceable with same seed(s), no matter if agent passed actions ROUTE_INIT_RAND_KEY = "route_init" ORDER_INIT_RAND_KEY = "order_init" diff --git a/maro/simulator/utils/sim_random.py b/maro/simulator/utils/sim_random.py index 7609a681d..f766a6f97 100644 --- a/maro/simulator/utils/sim_random.py +++ b/maro/simulator/utils/sim_random.py @@ -88,7 +88,7 @@ def get_seed(self, key: str = None) -> int: return self._seed - def reset_seed(self, key: str = None) -> None: + def reset_seed(self, key: str) -> None: """Reset seed of current random generator. 
NOTE: @@ -97,11 +97,12 @@ def reset_seed(self, key: str = None) -> None: Args: key(str): Key of item to get. """ - if key is not None: - if key not in self._seed_dict: - self._create_instance(key) - rand = self._rand_instances[key] - rand.seed(self._seed_dict[key]) + assert type(key) is str + + if key not in self._seed_dict: + self._create_instance(key) + rand = self._rand_instances[key] + rand.seed(self._seed_dict[key]) def reset_all_seeds(self) -> None: """Reset seed of all random generators From cd4a156e87359fd31d851ebe40def785b8329302 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 19 Aug 2021 15:00:30 +0800 Subject: [PATCH 5/8] Test check --- maro/data_lib/cim/cim_data_container.py | 2 +- maro/data_lib/cim/cim_data_generator.py | 7 +++++-- maro/data_lib/cim/global_order_proportion.py | 1 + maro/data_lib/cim/port_buffer_tick_wrapper.py | 1 - maro/simulator/scenarios/abs_business_engine.py | 2 +- maro/simulator/scenarios/cim/business_engine.py | 2 +- maro/simulator/scenarios/citi_bike/business_engine.py | 2 +- maro/simulator/scenarios/vm_scheduling/business_engine.py | 2 +- maro/simulator/utils/sim_random.py | 2 +- tests/dummy/dummy_business_engine.py | 2 +- 10 files changed, 13 insertions(+), 10 deletions(-) diff --git a/maro/data_lib/cim/cim_data_container.py b/maro/data_lib/cim/cim_data_container.py index ef4979eb7..701e766ad 100644 --- a/maro/data_lib/cim/cim_data_container.py +++ b/maro/data_lib/cim/cim_data_container.py @@ -13,7 +13,7 @@ PortSetting, VesselSetting ) from .port_buffer_tick_wrapper import PortBufferTickWrapper -from .utils import apply_noise, list_sum_normalize, BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY +from .utils import BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY, apply_noise, list_sum_normalize from .vessel_future_stops_prediction import VesselFutureStopsPrediction from .vessel_past_stops_wrapper import VesselPastStopsWrapper from .vessel_reachable_stops_wrapper import VesselReachableStopsWrapper diff --git 
a/maro/data_lib/cim/cim_data_generator.py b/maro/data_lib/cim/cim_data_generator.py index 83a94953b..056a5eca9 100644 --- a/maro/data_lib/cim/cim_data_generator.py +++ b/maro/data_lib/cim/cim_data_generator.py @@ -29,8 +29,11 @@ def __init__(self): self._routes_parser = RoutesParser() self._global_order_proportion = GlobalOrderProportion() - def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0, - topology_seed: int = None) -> CimSyntheticDataCollection: + def gen_data( + self, config_file: str, max_tick: int, + start_tick: int = 0, + topology_seed: int = None + ) -> CimSyntheticDataCollection: """Generate data with specified configurations. Args: diff --git a/maro/data_lib/cim/global_order_proportion.py b/maro/data_lib/cim/global_order_proportion.py index 802545e3b..63de50483 100644 --- a/maro/data_lib/cim/global_order_proportion.py +++ b/maro/data_lib/cim/global_order_proportion.py @@ -7,6 +7,7 @@ import numpy as np from maro.simulator.utils import random + from .utils import ORDER_INIT_RAND_KEY, apply_noise, clip diff --git a/maro/data_lib/cim/port_buffer_tick_wrapper.py b/maro/data_lib/cim/port_buffer_tick_wrapper.py index 990892873..12a2a3863 100644 --- a/maro/data_lib/cim/port_buffer_tick_wrapper.py +++ b/maro/data_lib/cim/port_buffer_tick_wrapper.py @@ -9,7 +9,6 @@ from .utils import BUFFER_TICK_RAND_KEY, apply_noise - class PortBufferTickWrapper: """Used to generate buffer ticks when empty/full become available. 
diff --git a/maro/simulator/scenarios/abs_business_engine.py b/maro/simulator/scenarios/abs_business_engine.py index 51ead0d65..442427094 100644 --- a/maro/simulator/scenarios/abs_business_engine.py +++ b/maro/simulator/scenarios/abs_business_engine.py @@ -139,7 +139,7 @@ def configs(self) -> dict: pass @abstractmethod - def reset(self, keep_seed): + def reset(self, keep_seed: bool = False): """Reset states business engine.""" pass diff --git a/maro/simulator/scenarios/cim/business_engine.py b/maro/simulator/scenarios/cim/business_engine.py index 56d63e6ee..14100ce27 100644 --- a/maro/simulator/scenarios/cim/business_engine.py +++ b/maro/simulator/scenarios/cim/business_engine.py @@ -196,7 +196,7 @@ def post_step(self, tick: int): return tick + 1 == self._max_tick - def reset(self, keep_seed): + def reset(self, keep_seed: bool = False): """Reset the business engine, it will reset frame value.""" self._snapshots.reset() diff --git a/maro/simulator/scenarios/citi_bike/business_engine.py b/maro/simulator/scenarios/citi_bike/business_engine.py index b301340a0..e0308ef8c 100644 --- a/maro/simulator/scenarios/citi_bike/business_engine.py +++ b/maro/simulator/scenarios/citi_bike/business_engine.py @@ -146,7 +146,7 @@ def get_event_payload_detail(self) -> dict: CitiBikeEvents.DeliverBike.name: BikeTransferPayload.summary_key } - def reset(self, keep_seed): + def reset(self, keep_seed: bool = False): """Reset internal states for episode.""" self._total_trips = 0 self._total_operate_num = 0 diff --git a/maro/simulator/scenarios/vm_scheduling/business_engine.py b/maro/simulator/scenarios/vm_scheduling/business_engine.py index 1507185b1..807ea8985 100644 --- a/maro/simulator/scenarios/vm_scheduling/business_engine.py +++ b/maro/simulator/scenarios/vm_scheduling/business_engine.py @@ -405,7 +405,7 @@ def _init_pms(self, pm_dict: dict): return start_pm_id - def reset(self, keep_seed): + def reset(self, keep_seed: bool = False): """Reset internal states for episode.""" 
self._init_metrics() diff --git a/maro/simulator/utils/sim_random.py b/maro/simulator/utils/sim_random.py index f766a6f97..7b7b1b0a7 100644 --- a/maro/simulator/utils/sim_random.py +++ b/maro/simulator/utils/sim_random.py @@ -98,7 +98,7 @@ def reset_seed(self, key: str) -> None: key(str): Key of item to get. """ assert type(key) is str - + if key not in self._seed_dict: self._create_instance(key) rand = self._rand_instances[key] diff --git a/tests/dummy/dummy_business_engine.py b/tests/dummy/dummy_business_engine.py index fca3e9c5e..fd635863b 100644 --- a/tests/dummy/dummy_business_engine.py +++ b/tests/dummy/dummy_business_engine.py @@ -45,7 +45,7 @@ def post_step(self, tick:int): return tick+1 == self._max_tick - def reset(self): + def reset(self, keep_seed: bool = False): self._frame.reset() self._frame.snapshots.reset() From ef91b637b02739c9984f846eaac7b1d9ba9fa333 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 19 Aug 2021 15:05:48 +0800 Subject: [PATCH 6/8] Minor --- maro/data_lib/cim/cim_data_container_helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/maro/data_lib/cim/cim_data_container_helpers.py b/maro/data_lib/cim/cim_data_container_helpers.py index e07eb26a1..4430e157e 100644 --- a/maro/data_lib/cim/cim_data_container_helpers.py +++ b/maro/data_lib/cim/cim_data_container_helpers.py @@ -6,6 +6,7 @@ from maro.cli.data_pipeline.utils import StaticParameter from maro.simulator.utils import random, seed + from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer from .cim_data_generator import CimDataGenerator from .cim_data_loader import load_from_folder, load_real_data_from_folder From 34bb16c63e865d54c47c737e97b7507475ce0339 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 19 Aug 2021 15:15:16 +0800 Subject: [PATCH 7/8] Remove unused functions so far --- maro/simulator/utils/sim_random.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git 
a/maro/simulator/utils/sim_random.py b/maro/simulator/utils/sim_random.py index 7b7b1b0a7..40e6b0206 100644 --- a/maro/simulator/utils/sim_random.py +++ b/maro/simulator/utils/sim_random.py @@ -70,24 +70,6 @@ def __getitem__(self, key): return self._rand_instances[key] - def get_seed(self, key: str = None) -> int: - """Get seed of current random generator. - - NOTE: - This will only return the seed of first random object that specified by user (or default). - - Args: - key(str): Key of item to get. - - Returns: - int: If key is None return seed for 1st instance (same as what passed to seed function), - else return seed for specified generator. - """ - if key is not None: - return self._seed_dict.get(key, None) - - return self._seed - def reset_seed(self, key: str) -> None: """Reset seed of current random generator. @@ -104,12 +86,6 @@ def reset_seed(self, key: str) -> None: rand = self._rand_instances[key] rand.seed(self._seed_dict[key]) - def reset_all_seeds(self) -> None: - """Reset seed of all random generators - """ - for key in self._rand_instances: - self.reset_seed(key) - random = SimRandom() """Random utility for simulator, same with original random module.""" From dc2d6079713ebbed1ec15ecbe78d1d6369165a78 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 19 Aug 2021 15:26:30 +0800 Subject: [PATCH 8/8] Minor --- maro/data_lib/cim/cim_data_container_helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/maro/data_lib/cim/cim_data_container_helpers.py b/maro/data_lib/cim/cim_data_container_helpers.py index 4430e157e..2be9ee671 100644 --- a/maro/data_lib/cim/cim_data_container_helpers.py +++ b/maro/data_lib/cim/cim_data_container_helpers.py @@ -81,7 +81,8 @@ def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0, config_path(str): Path of configuration file (yaml). max_tick (int): Max tick to generate data. start_tick(int): Start tick to generate data. - topology_seed(int): Random seed for generating routes. 
'None' means using the seed in the configuration file. + topology_seed(int): Random seed of the business engine. \ + 'None' means using the seed in the configuration file. Returns: CimSyntheticDataContainer: Data container used to provide cim data related interfaces.