Skip to content

Commit

Permalink
Core & Business Engine code refinement (#392)
Browse files Browse the repository at this point in the history
* First version

* Optimize imports

* Add typehint

* Lint check

* Lint check
  • Loading branch information
lihuoran authored Sep 9, 2021
1 parent d02f4fe commit eb0527f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 57 deletions.
42 changes: 23 additions & 19 deletions maro/simulator/abs_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from enum import IntEnum
from typing import List
from typing import List, Optional, Tuple

from maro.backends.frame import SnapshotList
from maro.event_buffer import EventBuffer
Expand Down Expand Up @@ -46,26 +46,27 @@ def __init__(
disable_finished_events: bool,
options: dict
):
self._tick = start_tick
self._scenario = scenario
self._topology = topology
self._start_tick = start_tick
self._durations = durations
self._snapshot_resolution = snapshot_resolution
self._max_snapshots = max_snapshots
self._decision_mode = decision_mode
self._business_engine_cls = business_engine_cls
self._additional_options = options

self._business_engine: AbsBusinessEngine = None
self._event_buffer: EventBuffer = None
self._tick: int = start_tick
self._scenario: str = scenario
self._topology: str = topology
self._start_tick: int = start_tick
self._durations: int = durations
self._snapshot_resolution: int = snapshot_resolution
self._max_snapshots: int = max_snapshots
self._decision_mode: DecisionMode = decision_mode
self._business_engine_cls: type = business_engine_cls
self._disable_finished_events: bool = disable_finished_events
self._additional_options: dict = options

self._business_engine: Optional[AbsBusinessEngine] = None
self._event_buffer: Optional[EventBuffer] = None

@property
def business_engine(self):
def business_engine(self) -> AbsBusinessEngine:
return self._business_engine

@abstractmethod
def step(self, action):
def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
"""Push the environment to next step with action.
Args:
Expand All @@ -77,12 +78,12 @@ def step(self, action):
pass

@abstractmethod
def dump(self):
def dump(self) -> None:
"""Dump environment for restore."""
pass

@abstractmethod
def reset(self):
def reset(self) -> None:
"""Reset environment."""
pass

Expand Down Expand Up @@ -111,6 +112,7 @@ def tick(self) -> int:
pass

@property
@abstractmethod
def frame_index(self) -> int:
"""int: Frame index in snapshot list for current tick, USE this for snapshot querying."""
pass
Expand All @@ -127,7 +129,7 @@ def snapshot_list(self) -> SnapshotList:
"""SnapshotList: Current snapshot list, a snapshot list contains all the snapshots of frame at each tick."""
pass

def set_seed(self, seed: int):
def set_seed(self, seed: int) -> None:
"""Set random seed used by simulator.
NOTE:
Expand All @@ -147,10 +149,12 @@ def metrics(self) -> dict:
"""
return {}

@abstractmethod
def get_finished_events(self) -> list:
"""list: All events finished so far."""
pass

@abstractmethod
def get_pending_events(self, tick: int) -> list:
"""list: Pending events at certain tick.
Expand Down
48 changes: 24 additions & 24 deletions maro/simulator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from collections import Iterable
from importlib import import_module
from inspect import getmembers, isclass
from typing import List
from typing import Generator, List, Optional, Tuple

from maro.backends.frame import FrameBase, SnapshotList
from maro.data_lib.dump_csv_converter import DumpConverter
from maro.event_buffer import EventBuffer, EventState
from maro.event_buffer import ActualEvent, CascadeEvent, EventBuffer, EventState
from maro.streamit import streamit
from maro.utils.exception.simulator_exception import BusinessEngineNotFoundError

from .abs_core import AbsEnv, DecisionMode
from .scenarios.abs_business_engine import AbsBusinessEngine
from .utils import seed as sim_seed
from .utils import random
from .utils.common import tick_to_frame_index


Expand Down Expand Up @@ -47,17 +47,16 @@ def __init__(
business_engine_cls: type = None, disable_finished_events: bool = False,
record_finished_events: bool = False,
record_file_path: str = None,
options: dict = {}
):
options: Optional[dict] = None
) -> None:
super().__init__(
scenario, topology, start_tick, durations,
snapshot_resolution, max_snapshots, decision_mode, business_engine_cls,
disable_finished_events, options
disable_finished_events, options if options is not None else {}
)

self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
else business_engine_cls.__name__
self._business_engine: AbsBusinessEngine = None

self._event_buffer = EventBuffer(disable_finished_events, record_finished_events, record_file_path)

Expand All @@ -72,12 +71,12 @@ def __init__(

if "enable-dump-snapshot" in self._additional_options:
parent_path = self._additional_options["enable-dump-snapshot"]
self._converter = DumpConverter(parent_path, self._business_engine._scenario_name)
self._converter = DumpConverter(parent_path, self._business_engine.scenario_name)
self._converter.reset_folder_path()

self._streamit_episode = 0

def step(self, action):
def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
"""Push the environment to next step with action.
Args:
Expand All @@ -93,15 +92,15 @@ def step(self, action):

return metrics, decision_event, _is_done

def dump(self):
def dump(self) -> None:
"""Dump environment for restore.
NOTE:
Not implemented.
"""
return

def reset(self, keep_seed: bool = False):
def reset(self, keep_seed: bool = False) -> None:
"""Reset environment.
Args:
Expand All @@ -114,10 +113,10 @@ def reset(self, keep_seed: bool = False):

self._event_buffer.reset()

if ("enable-dump-snapshot" in self._additional_options) and (self._business_engine._frame is not None):
if "enable-dump-snapshot" in self._additional_options and self._business_engine.frame is not None:
dump_folder = self._converter.get_new_snapshot_folder()

self._business_engine._frame.dump(dump_folder)
self._business_engine.frame.dump(dump_folder)
self._converter.start_processing(self.configs)
self._converter.dump_descsion_events(self._decision_events, self._start_tick, self._snapshot_resolution)
self._business_engine.dump(dump_folder)
Expand Down Expand Up @@ -173,7 +172,7 @@ def agent_idx_list(self) -> List[int]:
"""List[int]: Agent index list that related to this environment."""
return self._business_engine.get_agent_idx_list()

def set_seed(self, seed: int):
def set_seed(self, seed: int) -> None:
"""Set random seed used by simulator.
NOTE:
Expand All @@ -184,7 +183,7 @@ def set_seed(self, seed: int):
"""

if seed is not None:
sim_seed(seed)
random.seed(seed)

@property
def metrics(self) -> dict:
Expand All @@ -196,19 +195,19 @@ def metrics(self) -> dict:

return self._business_engine.get_metrics()

def get_finished_events(self):
def get_finished_events(self) -> List[ActualEvent]:
"""List[Event]: All events finished so far."""
return self._event_buffer.get_finished_events()

def get_pending_events(self, tick):
def get_pending_events(self, tick) -> List[ActualEvent]:
"""Pending events at certain tick.
Args:
tick (int): Specified tick to query.
"""
return self._event_buffer.get_pending_events(tick)

def _init_business_engine(self):
def _init_business_engine(self) -> None:
"""Initialize business engine object.
NOTE:
Expand Down Expand Up @@ -238,7 +237,7 @@ def _init_business_engine(self):
if business_class is None:
raise BusinessEngineNotFoundError()

self._business_engine = business_class(
self._business_engine: AbsBusinessEngine = business_class(
event_buffer=self._event_buffer,
topology=self._topology,
start_tick=self._start_tick,
Expand All @@ -248,10 +247,8 @@ def _init_business_engine(self):
additional_options=self._additional_options
)

def _simulate(self):
def _simulate(self) -> Generator[Tuple[dict, List[object], bool], object, None]:
"""This is the generator to wrap each episode process."""
is_end_tick = False

self._streamit_episode += 1

streamit.episode(self._streamit_episode)
Expand Down Expand Up @@ -297,8 +294,10 @@ def _simulate(self):

# NOTE: decision event always be a CascadeEvent
# We just append the action into sub event of first pending cascade event.
pending_events[0].state = EventState.EXECUTING
pending_events[0].add_immediate_event(action_event, is_head=True)
event = pending_events[0]
assert isinstance(event, CascadeEvent)
event.state = EventState.EXECUTING
event.add_immediate_event(action_event, is_head=True)
else:
# For joint mode, we will assign actions from beginning to end.
# Then mark others pending events to finished if not sequential action mode.
Expand All @@ -314,6 +313,7 @@ def _simulate(self):
pending_event.state = EventState.EXECUTING
action_event = self._event_buffer.gen_action_event(self._tick, action)

assert isinstance(pending_event, CascadeEvent)
pending_event.add_immediate_event(action_event, is_head=True)

# Check the end tick of the simulation to decide if we should end the simulation.
Expand Down
39 changes: 25 additions & 14 deletions maro/simulator/scenarios/abs_business_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional

from maro.backends.frame import FrameBase, SnapshotList
from maro.event_buffer import EventBuffer
Expand Down Expand Up @@ -31,23 +32,23 @@ class AbsBusinessEngine(ABC):
max_tick (int): Max tick of this business engine.
snapshot_resolution (int): Frequency to take a snapshot.
max_snapshots(int): Max number of in-memory snapshots, default is None that means max number of snapshots.
addition_options (dict): Additional options for this business engine from outside.
additional_options (dict): Additional options for this business engine from outside.
"""

def __init__(
self, scenario_name: str, event_buffer: EventBuffer, topology: str,
start_tick: int, max_tick: int, snapshot_resolution: int, max_snapshots: int,
additional_options: dict = None
):
self._scenario_name = scenario_name
self._topology = topology
self._event_buffer = event_buffer
self._start_tick = start_tick
self._max_tick = max_tick
self._snapshot_resolution = snapshot_resolution
self._max_snapshots = max_snapshots
self._additional_options = additional_options
self._config_path = None
self._scenario_name: str = scenario_name
self._topology: str = topology
self._event_buffer: EventBuffer = event_buffer
self._start_tick: int = start_tick
self._max_tick: int = max_tick
self._snapshot_resolution: int = snapshot_resolution
self._max_snapshots: int = max_snapshots
self._additional_options: dict = additional_options
self._config_path: Optional[str] = None

assert start_tick >= 0
assert max_tick > start_tick
Expand All @@ -65,6 +66,15 @@ def snapshots(self) -> SnapshotList:
"""SnapshotList: Snapshot list of current frame, this is used to expose querying interface for outside."""
pass

@property
def scenario_name(self) -> str:
return self._scenario_name

@abstractmethod
def get_agent_idx_list(self) -> List[int]:
"""Get a list of agent index."""
pass

def frame_index(self, tick: int) -> int:
"""Helper method for child class, used to get index of frame in snapshot list for specified tick.
Expand All @@ -89,7 +99,7 @@ def calc_max_snapshots(self) -> int:
return self._max_snapshots if self._max_snapshots is not None \
else total_frames(self._start_tick, self._max_tick, self._snapshot_resolution)

def update_config_root_path(self, business_engine_file_path: str):
def update_config_root_path(self, business_engine_file_path: str) -> None:
"""Helper method used to update the config path with business engine path if you
follow the way to load configuration file as built-in scenarios.
Expand Down Expand Up @@ -125,7 +135,7 @@ def __init__(self, *args, **kwargs):
self._config_path = os.path.join(be_file_path, "topologies", self._topology)

@abstractmethod
def step(self, tick: int):
def step(self, tick: int) -> None:
"""Method that is called at each tick, usually used to trigger business logic at current tick.
Args:
Expand All @@ -134,12 +144,13 @@ def step(self, tick: int):
pass

@property
@abstractmethod
def configs(self) -> dict:
"""dict: Configurations of this business engine."""
pass

@abstractmethod
def reset(self, keep_seed: bool = False):
def reset(self, keep_seed: bool = False) -> None:
"""Reset states business engine."""
pass

Expand Down Expand Up @@ -183,7 +194,7 @@ def get_metrics(self) -> dict:
"""
return {}

def dump(self, folder: str):
def dump(self, folder: str) -> None:
"""Dump something from business engine.
Args:
Expand Down

0 comments on commit eb0527f

Please sign in to comment.