Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reset random seed bug #387

Merged
merged 8 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/apidoc/maro.utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ maro.utils.exception.communication\_exception
:undoc-members:
:show-inheritance:

maro.utils.exception.data\_lib\_exeption
maro.utils.exception.data\_lib\_exception
--------------------------------------------------------------------------------

.. automodule:: maro.utils.exception.data_lib_exeption
.. automodule:: maro.utils.exception.data_lib_exception
:members:
:undoc-members:
:show-inheritance:
Expand Down
20 changes: 7 additions & 13 deletions maro/data_lib/cim/cim_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from math import ceil
from typing import Dict, List

from maro.simulator.utils import random

from .entities import (
CimBaseDataCollection, CimRealDataCollection, CimSyntheticDataCollection, NoisedItem, Order, OrderGenerateMode,
PortSetting, VesselSetting
)
from .port_buffer_tick_wrapper import PortBufferTickWrapper
from .utils import (
apply_noise, buffer_tick_rand, get_buffer_tick_seed, get_order_num_seed, list_sum_normalize, order_num_rand
)
from .utils import BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY, apply_noise, list_sum_normalize
from .vessel_future_stops_prediction import VesselFutureStopsPrediction
from .vessel_past_stops_wrapper import VesselPastStopsWrapper
from .vessel_reachable_stops_wrapper import VesselReachableStopsWrapper
Expand Down Expand Up @@ -60,9 +60,6 @@ def __init__(self, data_collection: CimBaseDataCollection):
self._vessel_plan_wrapper = VesselSailingPlanWrapper(self._data_collection)
self._reachable_stops_wrapper = VesselReachableStopsWrapper(self._data_collection)

# keep the seed so we can reproduce the sequence after reset
self._buffer_tick_seed: int = get_buffer_tick_seed()

# flag to tell if we need to reset the seed; we need this flag because outside code may set the seed after env.reset
self._is_need_reset_seed = False

Expand Down Expand Up @@ -245,7 +242,7 @@ def reset(self):

def _reset_seed(self):
"""Reset the internal seed so that reproducible data is generated."""
buffer_tick_rand.seed(self._buffer_tick_seed)
random.reset_seed(BUFFER_TICK_RAND_KEY)

@abstractmethod
def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
Expand All @@ -272,9 +269,6 @@ class CimSyntheticDataContainer(CimBaseDataContainer):
def __init__(self, data_collection: CimSyntheticDataCollection):
super().__init__(data_collection)

# keep the seed so we can reproduce the sequence after reset
self._order_num_seed: int = get_order_num_seed()

# TODO: get_events, which is composed of arrival, departure and order events

def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
Expand Down Expand Up @@ -303,7 +297,7 @@ def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
def _reset_seed(self):
"""Reset the internal seed so that reproducible data is generated."""
super()._reset_seed()
order_num_rand.seed(self._order_num_seed)
random.reset_seed(ORDER_NUM_RAND_KEY)

def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
"""Generate order for specified tick.
Expand Down Expand Up @@ -339,7 +333,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
for port_idx in range(self.port_number):
source_dist: NoisedItem = self.ports[port_idx].source_proportion

noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, order_num_rand)
noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, random[ORDER_NUM_RAND_KEY])

noised_source_order_dist.append(noised_source_order_number)

Expand All @@ -356,7 +350,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:

# apply noise and normalize
noised_targets_dist = list_sum_normalize(
[apply_noise(target.base, target.noise, order_num_rand) for target in targets_dist])
[apply_noise(target.base, target.noise, random[ORDER_NUM_RAND_KEY]) for target in targets_dist])

# order for current ports
cur_port_order_num = ceil(orders_to_gen * noised_source_order_dist[port_idx])
Expand Down
22 changes: 15 additions & 7 deletions maro/data_lib/cim/cim_data_container_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import urllib.parse

from maro.cli.data_pipeline.utils import StaticParameter
from maro.simulator.utils import seed
from maro.simulator.utils import random, seed

from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer
from .cim_data_generator import CimDataGenerator
from .cim_data_loader import load_from_folder, load_real_data_from_folder
from .utils import DATA_CONTAINER_INIT_SEED_LIMIT, ROUTE_INIT_RAND_KEY


class CimDataContainerWrapper:
Expand All @@ -28,22 +29,26 @@ def __init__(self, config_path: str, max_tick: int, topology: str):

self._init_data_container()

def _init_data_container(self):
def _init_data_container(self, topology_seed: int = None):
if not os.path.exists(self._config_path):
raise FileNotFoundError
# Synthetic Data Mode: config.yml must exist.
config_path = os.path.join(self._config_path, "config.yml")
if os.path.exists(config_path):
self._data_cntr = data_from_generator(
config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick
config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick,
topology_seed=topology_seed
)
else:
# Real Data Mode: read data from input data files, no need for any config.yml.
self._data_cntr = data_from_files(data_folder=self._config_path)

def reset(self):
def reset(self, keep_seed):
"""Reset data container internal state"""
self._data_cntr.reset()
if not keep_seed:
self._init_data_container(random[ROUTE_INIT_RAND_KEY].randint(0, DATA_CONTAINER_INIT_SEED_LIMIT - 1))
else:
self._data_cntr.reset()

def __getattr__(self, name):
return getattr(self._data_cntr, name)
Expand All @@ -68,20 +73,23 @@ def data_from_dumps(dumps_folder: str) -> CimSyntheticDataContainer:
return CimSyntheticDataContainer(data_collection)


def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0) -> CimSyntheticDataContainer:
def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0,
topology_seed: int = None) -> CimSyntheticDataContainer:
"""Collect data from data generator with configurations.

Args:
config_path(str): Path of configuration file (yaml).
max_tick (int): Max tick to generate data.
start_tick(int): Start tick to generate data.
topology_seed(int): Random seed of the business engine. \
'None' means using the seed in the configuration file.

Returns:
CimSyntheticDataContainer: Data container used to provide cim data related interfaces.
"""
edg = CimDataGenerator()

data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick)
data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick, topology_seed=topology_seed)

return CimSyntheticDataContainer(data_collection)

Expand Down
2 changes: 1 addition & 1 deletion maro/data_lib/cim/cim_data_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def dump_from_config(config_file: str, output_folder: str, max_tick: int):

generator = CimDataGenerator()

data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0)
data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0, topology_seed=None)

dump_util = CimDataDumpUtil(data_collection)

Expand Down
21 changes: 14 additions & 7 deletions maro/data_lib/cim/cim_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from yaml import safe_load

from maro.simulator.utils import seed
from maro.utils.exception.data_lib_exeption import CimGeneratorInvalidParkingDuration
from maro.simulator.utils import random, seed
from maro.utils.exception.data_lib_exception import CimGeneratorInvalidParkingDuration

from .entities import CimSyntheticDataCollection, OrderGenerateMode, Stop
from .global_order_proportion import GlobalOrderProportion
from .port_parser import PortsParser
from .route_parser import RoutesParser
from .utils import apply_noise, route_init_rand
from .utils import ROUTE_INIT_RAND_KEY, apply_noise
from .vessel_parser import VesselsParser

CIM_GENERATOR_VERSION = 0x000001
Expand All @@ -29,13 +29,19 @@ def __init__(self):
self._routes_parser = RoutesParser()
self._global_order_proportion = GlobalOrderProportion()

def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0) -> CimSyntheticDataCollection:
def gen_data(
self, config_file: str, max_tick: int,
start_tick: int = 0,
topology_seed: int = None
) -> CimSyntheticDataCollection:
"""Generate data with specified configurations.

Args:
config_file(str): File of configuration (yaml).
max_tick(int): Max tick to generate.
start_tick(int): Start tick to generate.
topology_seed(int): Random seed of the business engine. \
'None' means using the seed in the configuration file.

Returns:
CimSyntheticDataCollection: Data collection contains all cim data.
Expand All @@ -45,7 +51,8 @@ def gen_data(self, config_file: str, max_tick: int, start_tick: int = 0) -> CimS
with open(config_file, "r") as fp:
conf: dict = safe_load(fp)

topology_seed = conf["seed"]
if topology_seed is None:
topology_seed = conf["seed"]

# set seed to generate data
seed(topology_seed)
Expand Down Expand Up @@ -146,7 +153,7 @@ def _extend_route(
port_idx = port_mapping[cur_route_point.port_name]

# apply noise to parking duration
parking_duration = ceil(apply_noise(duration, duration_noise, route_init_rand))
parking_duration = ceil(apply_noise(duration, duration_noise, random[ROUTE_INIT_RAND_KEY]))

if parking_duration <= 0:
raise CimGeneratorInvalidParkingDuration()
Expand All @@ -165,7 +172,7 @@ def _extend_route(
distance_to_next_port = cur_route_point.distance_to_next_port

# apply noise to speed
noised_speed = apply_noise(speed, speed_noise, route_init_rand)
noised_speed = apply_noise(speed, speed_noise, random[ROUTE_INIT_RAND_KEY])
sailing_duration = ceil(distance_to_next_port / noised_speed)

# next tick
Expand Down
6 changes: 4 additions & 2 deletions maro/data_lib/cim/global_order_proportion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import numpy as np

from .utils import apply_noise, clip, order_init_rand
from maro.simulator.utils import random

from .utils import ORDER_INIT_RAND_KEY, apply_noise, clip


class GlobalOrderProportion:
Expand Down Expand Up @@ -59,7 +61,7 @@ def parse(self, conf: dict, total_container: int, max_tick: int, start_tick: int
# apply noise if the distribution not zero
if orders != 0:
if noise != 0:
orders = apply_noise(orders, noise, order_init_rand)
orders = apply_noise(orders, noise, random[ORDER_INIT_RAND_KEY])

# clip and gen order
orders = floor(clip(0, 1, orders) * total_container)
Expand Down
6 changes: 4 additions & 2 deletions maro/data_lib/cim/port_buffer_tick_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from math import ceil

from maro.simulator.utils import random

from .entities import CimBaseDataCollection, NoisedItem, PortSetting
from .utils import apply_noise, buffer_tick_rand
from .utils import BUFFER_TICK_RAND_KEY, apply_noise


class PortBufferTickWrapper:
Expand All @@ -29,4 +31,4 @@ def __getitem__(self, key):

buffer_setting: NoisedItem = self._attribute_func(port)

return ceil(apply_noise(buffer_setting.base, buffer_setting.noise, buffer_tick_rand))
return ceil(apply_noise(buffer_setting.base, buffer_setting.noise, random[BUFFER_TICK_RAND_KEY]))
21 changes: 7 additions & 14 deletions maro/data_lib/cim/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from random import Random
from typing import List, Union

from maro.simulator.utils.sim_random import SimRandom, random

# We keep 4 random generators so that the result is reproducible with the same seed(s), regardless of whether the agent passed actions.
route_init_rand = random["route_init"]
order_init_rand = random["order_init"]
buffer_tick_rand = random["buffer_time"]
order_num_rand = random["order_number"]


def get_buffer_tick_seed():
return random.get_seed("buffer_time")

ROUTE_INIT_RAND_KEY = "route_init"
ORDER_INIT_RAND_KEY = "order_init"
BUFFER_TICK_RAND_KEY = "buffer_time"
ORDER_NUM_RAND_KEY = "order_number"

def get_order_num_seed():
return random.get_seed("order_number")
DATA_CONTAINER_INIT_SEED_LIMIT = 4096


def clip(min_val: Union[int, float], max_val: Union[int, float], value: Union[int, float]) -> Union[int, float]:
Expand All @@ -34,7 +27,7 @@ def clip(min_val: Union[int, float], max_val: Union[int, float], value: Union[in
return max(min_val, min(max_val, value))


def apply_noise(value: Union[int, float], noise: Union[int, float], rand: SimRandom) -> float:
def apply_noise(value: Union[int, float], noise: Union[int, float], rand: Random) -> float:
"""Apply noise with specified random generator

Args:
Expand Down
2 changes: 1 addition & 1 deletion maro/data_lib/item_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from yaml import SafeDumper, SafeLoader, YAMLObject, safe_dump, safe_load

from maro.data_lib.common import dtype_pack_map
from maro.utils.exception.data_lib_exeption import MetaTimestampNotExist
from maro.utils.exception.data_lib_exception import MetaTimestampNotExist


class EntityAttr(YAMLObject):
Expand Down
10 changes: 7 additions & 3 deletions maro/simulator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ def dump(self):
"""
return

def reset(self):
"""Reset environment."""
def reset(self, keep_seed: bool = False):
"""Reset environment.

Args:
keep_seed (bool): Whether to reset the random seed so that the same data sequence is generated. Defaults to False.
"""
self._tick = self._start_tick

self._simulate_generator.close()
Expand All @@ -120,7 +124,7 @@ def reset(self):

self._decision_events.clear()

self._business_engine.reset()
self._business_engine.reset(keep_seed)

@property
def configs(self) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion maro/simulator/scenarios/abs_business_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def configs(self) -> dict:
pass

@abstractmethod
def reset(self):
def reset(self, keep_seed: bool = False):
"""Reset the states of the business engine."""
pass

Expand Down
4 changes: 2 additions & 2 deletions maro/simulator/scenarios/cim/business_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def post_step(self, tick: int):

return tick + 1 == self._max_tick

def reset(self):
def reset(self, keep_seed: bool = False):
"""Reset the business engine, it will reset frame value."""

self._snapshots.reset()
Expand All @@ -205,7 +205,7 @@ def reset(self):

self._reset_nodes()

self._data_cntr.reset()
self._data_cntr.reset(keep_seed)

# Insert departure event again.
self._load_departure_events()
Expand Down
2 changes: 1 addition & 1 deletion maro/simulator/scenarios/citi_bike/business_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_event_payload_detail(self) -> dict:
CitiBikeEvents.DeliverBike.name: BikeTransferPayload.summary_key
}

def reset(self):
def reset(self, keep_seed: bool = False):
"""Reset internal states for episode."""
self._total_trips = 0
self._total_operate_num = 0
Expand Down
2 changes: 1 addition & 1 deletion maro/simulator/scenarios/vm_scheduling/business_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _init_pms(self, pm_dict: dict):

return start_pm_id

def reset(self):
def reset(self, keep_seed: bool = False):
"""Reset internal states for episode."""
self._init_metrics()

Expand Down
Loading