From eb0527f21920cdd3010c8f65b778d1404cfc516f Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Thu, 9 Sep 2021 13:35:35 +0800 Subject: [PATCH] Core & Business Engine code refinement (#392) * First version * Optimize imports * Add typehint * Lint check * Lint check --- maro/simulator/abs_core.py | 42 ++++++++-------- maro/simulator/core.py | 48 +++++++++---------- .../scenarios/abs_business_engine.py | 39 +++++++++------ 3 files changed, 72 insertions(+), 57 deletions(-) diff --git a/maro/simulator/abs_core.py b/maro/simulator/abs_core.py index 6a3b97bbb..cdbe0362f 100644 --- a/maro/simulator/abs_core.py +++ b/maro/simulator/abs_core.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from enum import IntEnum -from typing import List +from typing import List, Optional, Tuple from maro.backends.frame import SnapshotList from maro.event_buffer import EventBuffer @@ -46,26 +46,27 @@ def __init__( disable_finished_events: bool, options: dict ): - self._tick = start_tick - self._scenario = scenario - self._topology = topology - self._start_tick = start_tick - self._durations = durations - self._snapshot_resolution = snapshot_resolution - self._max_snapshots = max_snapshots - self._decision_mode = decision_mode - self._business_engine_cls = business_engine_cls - self._additional_options = options - - self._business_engine: AbsBusinessEngine = None - self._event_buffer: EventBuffer = None + self._tick: int = start_tick + self._scenario: str = scenario + self._topology: str = topology + self._start_tick: int = start_tick + self._durations: int = durations + self._snapshot_resolution: int = snapshot_resolution + self._max_snapshots: int = max_snapshots + self._decision_mode: DecisionMode = decision_mode + self._business_engine_cls: type = business_engine_cls + self._disable_finished_events: bool = disable_finished_events + self._additional_options: dict = options + + self._business_engine: Optional[AbsBusinessEngine] = None + self._event_buffer: Optional[EventBuffer] = None @property - def business_engine(self): + def business_engine(self) -> AbsBusinessEngine: return self._business_engine @abstractmethod - def step(self, action): + def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]: """Push the environment to next step with action. Args: @@ -77,12 +78,12 @@ def step(self, action): pass @abstractmethod - def dump(self): + def dump(self) -> None: """Dump environment for restore.""" pass @abstractmethod - def reset(self): + def reset(self) -> None: """Reset environment.""" pass @@ -111,6 +112,7 @@ def tick(self) -> int: pass @property + @abstractmethod def frame_index(self) -> int: """int: Frame index in snapshot list for current tick, USE this for snapshot querying.""" pass @@ -127,7 +129,7 @@ def snapshot_list(self) -> SnapshotList: """SnapshotList: Current snapshot list, a snapshot list contains all the snapshots of frame at each tick.""" pass - def set_seed(self, seed: int): + def set_seed(self, seed: int) -> None: """Set random seed used by simulator. NOTE: @@ -147,10 +149,12 @@ def metrics(self) -> dict: """ return {} + @abstractmethod def get_finished_events(self) -> list: """list: All events finished so far.""" pass + @abstractmethod def get_pending_events(self, tick: int) -> list: """list: Pending events at certain tick. diff --git a/maro/simulator/core.py b/maro/simulator/core.py index d7b992530..fa7de129f 100644 --- a/maro/simulator/core.py +++ b/maro/simulator/core.py @@ -4,17 +4,17 @@ from collections import Iterable from importlib import import_module from inspect import getmembers, isclass -from typing import List +from typing import Generator, List, Optional, Tuple from maro.backends.frame import FrameBase, SnapshotList from maro.data_lib.dump_csv_converter import DumpConverter -from maro.event_buffer import EventBuffer, EventState +from maro.event_buffer import ActualEvent, CascadeEvent, EventBuffer, EventState from maro.streamit import streamit from maro.utils.exception.simulator_exception import BusinessEngineNotFoundError from .abs_core import AbsEnv, DecisionMode from .scenarios.abs_business_engine import AbsBusinessEngine -from .utils import seed as sim_seed +from .utils import random from .utils.common import tick_to_frame_index @@ -47,17 +47,16 @@ def __init__( business_engine_cls: type = None, disable_finished_events: bool = False, record_finished_events: bool = False, record_file_path: str = None, - options: dict = {} - ): + options: Optional[dict] = None + ) -> None: super().__init__( scenario, topology, start_tick, durations, snapshot_resolution, max_snapshots, decision_mode, business_engine_cls, - disable_finished_events, options + disable_finished_events, options if options is not None else {} ) self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \ else business_engine_cls.__name__ - self._business_engine: AbsBusinessEngine = None self._event_buffer = EventBuffer(disable_finished_events, record_finished_events, record_file_path) @@ -72,12 +71,12 @@ def __init__( if "enable-dump-snapshot" in self._additional_options: parent_path = self._additional_options["enable-dump-snapshot"] - self._converter = DumpConverter(parent_path, self._business_engine._scenario_name) + self._converter = DumpConverter(parent_path, self._business_engine.scenario_name) self._converter.reset_folder_path() self._streamit_episode = 0 - def step(self, action): + def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]: """Push the environment to next step with action. Args: @@ -93,7 +92,7 @@ def step(self, action): return metrics, decision_event, _is_done - def dump(self): + def dump(self) -> None: """Dump environment for restore. NOTE: @@ -101,7 +100,7 @@ def dump(self): """ return - def reset(self, keep_seed: bool = False): + def reset(self, keep_seed: bool = False) -> None: """Reset environment. Args: @@ -114,10 +113,10 @@ def reset(self, keep_seed: bool = False): self._event_buffer.reset() - if ("enable-dump-snapshot" in self._additional_options) and (self._business_engine._frame is not None): + if "enable-dump-snapshot" in self._additional_options and self._business_engine.frame is not None: dump_folder = self._converter.get_new_snapshot_folder() - self._business_engine._frame.dump(dump_folder) + self._business_engine.frame.dump(dump_folder) self._converter.start_processing(self.configs) self._converter.dump_descsion_events(self._decision_events, self._start_tick, self._snapshot_resolution) self._business_engine.dump(dump_folder) @@ -173,7 +172,7 @@ def agent_idx_list(self) -> List[int]: """List[int]: Agent index list that related to this environment.""" return self._business_engine.get_agent_idx_list() - def set_seed(self, seed: int): + def set_seed(self, seed: int) -> None: """Set random seed used by simulator. NOTE: @@ -184,7 +183,7 @@ def set_seed(self, seed: int): """ if seed is not None: - sim_seed(seed) + random.seed(seed) @property def metrics(self) -> dict: @@ -196,11 +195,11 @@ def metrics(self) -> dict: return self._business_engine.get_metrics() - def get_finished_events(self): + def get_finished_events(self) -> List[ActualEvent]: """List[Event]: All events finished so far.""" return self._event_buffer.get_finished_events() - def get_pending_events(self, tick): + def get_pending_events(self, tick) -> List[ActualEvent]: """Pending events at certain tick. Args: @@ -208,7 +207,7 @@ def get_pending_events(self, tick): """ return self._event_buffer.get_pending_events(tick) - def _init_business_engine(self): + def _init_business_engine(self) -> None: """Initialize business engine object. NOTE: @@ -238,7 +237,7 @@ def _init_business_engine(self): if business_class is None: raise BusinessEngineNotFoundError() - self._business_engine = business_class( + self._business_engine: AbsBusinessEngine = business_class( event_buffer=self._event_buffer, topology=self._topology, start_tick=self._start_tick, @@ -248,10 +247,8 @@ def _init_business_engine(self): additional_options=self._additional_options ) - def _simulate(self): + def _simulate(self) -> Generator[Tuple[dict, List[object], bool], object, None]: """This is the generator to wrap each episode process.""" - is_end_tick = False - self._streamit_episode += 1 streamit.episode(self._streamit_episode) @@ -297,8 +294,10 @@ def _simulate(self): # NOTE: decision event always be a CascadeEvent # We just append the action into sub event of first pending cascade event. - pending_events[0].state = EventState.EXECUTING - pending_events[0].add_immediate_event(action_event, is_head=True) + event = pending_events[0] + assert isinstance(event, CascadeEvent) + event.state = EventState.EXECUTING + event.add_immediate_event(action_event, is_head=True) else: # For joint mode, we will assign actions from beginning to end. # Then mark others pending events to finished if not sequential action mode. @@ -314,6 +313,7 @@ def _simulate(self): pending_event.state = EventState.EXECUTING action_event = self._event_buffer.gen_action_event(self._tick, action) + assert isinstance(pending_event, CascadeEvent) pending_event.add_immediate_event(action_event, is_head=True) # Check the end tick of the simulation to decide if we should end the simulation. diff --git a/maro/simulator/scenarios/abs_business_engine.py b/maro/simulator/scenarios/abs_business_engine.py index 442427094..ed72365f9 100644 --- a/maro/simulator/scenarios/abs_business_engine.py +++ b/maro/simulator/scenarios/abs_business_engine.py @@ -4,6 +4,7 @@ import os from abc import ABC, abstractmethod from pathlib import Path +from typing import List, Optional from maro.backends.frame import FrameBase, SnapshotList from maro.event_buffer import EventBuffer @@ -31,7 +32,7 @@ class AbsBusinessEngine(ABC): max_tick (int): Max tick of this business engine. snapshot_resolution (int): Frequency to take a snapshot. max_snapshots(int): Max number of in-memory snapshots, default is None that means max number of snapshots. - addition_options (dict): Additional options for this business engine from outside. + additional_options (dict): Additional options for this business engine from outside. """ def __init__( @@ -39,15 +40,15 @@ def __init__( start_tick: int, max_tick: int, snapshot_resolution: int, max_snapshots: int, additional_options: dict = None ): - self._scenario_name = scenario_name - self._topology = topology - self._event_buffer = event_buffer - self._start_tick = start_tick - self._max_tick = max_tick - self._snapshot_resolution = snapshot_resolution - self._max_snapshots = max_snapshots - self._additional_options = additional_options - self._config_path = None + self._scenario_name: str = scenario_name + self._topology: str = topology + self._event_buffer: EventBuffer = event_buffer + self._start_tick: int = start_tick + self._max_tick: int = max_tick + self._snapshot_resolution: int = snapshot_resolution + self._max_snapshots: int = max_snapshots + self._additional_options: dict = additional_options + self._config_path: Optional[str] = None assert start_tick >= 0 assert max_tick > start_tick @@ -65,6 +66,15 @@ def snapshots(self) -> SnapshotList: """SnapshotList: Snapshot list of current frame, this is used to expose querying interface for outside.""" pass + @property + def scenario_name(self) -> str: + return self._scenario_name + + @abstractmethod + def get_agent_idx_list(self) -> List[int]: + """Get a list of agent index.""" + pass + def frame_index(self, tick: int) -> int: """Helper method for child class, used to get index of frame in snapshot list for specified tick. @@ -89,7 +99,7 @@ def calc_max_snapshots(self) -> int: return self._max_snapshots if self._max_snapshots is not None \ else total_frames(self._start_tick, self._max_tick, self._snapshot_resolution) - def update_config_root_path(self, business_engine_file_path: str): + def update_config_root_path(self, business_engine_file_path: str) -> None: """Helper method used to update the config path with business engine path if you follow the way to load configuration file as built-in scenarios. @@ -125,7 +135,7 @@ def __init__(self, *args, **kwargs): self._config_path = os.path.join(be_file_path, "topologies", self._topology) @abstractmethod - def step(self, tick: int): + def step(self, tick: int) -> None: """Method that is called at each tick, usually used to trigger business logic at current tick. Args: @@ -134,12 +144,13 @@ def step(self, tick: int): pass @property + @abstractmethod def configs(self) -> dict: """dict: Configurations of this business engine.""" pass @abstractmethod - def reset(self, keep_seed: bool = False): + def reset(self, keep_seed: bool = False) -> None: """Reset states business engine.""" pass @@ -183,7 +194,7 @@ def get_metrics(self) -> dict: """ return {} - def dump(self, folder: str): + def dump(self, folder: str) -> None: """Dump something from business engine. Args: