Rl config doc #467

Merged: 8 commits, Feb 24, 2022
1 change: 0 additions & 1 deletion examples/rl/cim/callbacks.py
@@ -5,7 +5,6 @@
 def post_collect(info_list: list, ep: int, segment: int) -> None:
     # print the env metric from each rollout worker
     for info in info_list:
-        print(info)
         print(f"env summary (episode {ep}, segment {segment}): {info['env_metric']}")

     # print the average env metric
11 changes: 10 additions & 1 deletion maro/rl/training/trainer.py
@@ -23,7 +23,16 @@ class TrainerParams:
             automatically determined according to GPU availability.
         replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory.
         batch_size (int, default=128): Training batch size.
-        data_parallelism (int, default=1): Degree of data parallelism.
+        data_parallelism (int, default=1): Degree of data parallelism. A value greater than 1 can be used when
+            a model is large and computing gradients with respect to a batch becomes expensive. In this case, the
+            batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set
+            of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets
+            updated only after collecting all the gradients from the remote nodes. Note that this value is the desired
+            parallelism and the actual parallelism in a distributed experiment may be smaller depending on the
+            availability of compute resources. For details on distributed deep learning and data parallelism, see
+            https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an abundance
+            of resources available on the internet.
+
     """
     device: str = None
     replay_memory_capacity: int = 10000
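
The synchronous data parallelism described in the data_parallelism docstring above can be pictured with a short sketch. This is only an illustration of the idea, not MARO's implementation: the sub-batch loop runs sequentially here in place of remote nodes, and the model, loss, and batch are hypothetical stand-ins.

import torch

# Illustrative sketch of synchronous data parallelism: the batch is split into sub-batches,
# a gradient is computed for each sub-batch (sequentially here, standing in for remote
# workers), and the model is updated only after all gradients are collected and averaged.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

batch_x, batch_y = torch.randn(128, 8), torch.randn(128, 1)
data_parallelism = 4  # desired degree of parallelism

grads_per_worker = []
for x, y in zip(batch_x.chunk(data_parallelism), batch_y.chunk(data_parallelism)):
    model.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    grads_per_worker.append([p.grad.clone() for p in model.parameters()])

# Synchronous update: averaged gradients are applied only after every "worker" has reported.
for param, *worker_grads in zip(model.parameters(), *grads_per_worker):
    param.grad = torch.stack(worker_grads).mean(dim=0)
optimizer.step()

Because each sub-batch loss is a mean and the sub-batches are equally sized, averaging the per-sub-batch gradients reproduces the full-batch gradient, which is what makes the synchronous scheme equivalent to ordinary single-node training on the whole batch.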
93 changes: 65 additions & 28 deletions maro/rl/workflows/config/template.yml
@@ -1,39 +1,76 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# This is a configuration template for running reinforcement learning workflows with MARO's CLI tools. The workflows
# are scenario-agnostic, meaning that this template can be applied to any scenario as long as the necessary components
# are provided (see examples/rl/README.md for details about these components). Your scenario should be placed in a
# folder and its path should be specified in the "scenario_path" field. All fields with a "null" value are optional
# and will be converted to None by the parser unless a non-null value is specified; commenting them out or leaving
# them blank is equivalent to using "null".


job: your_job_name
# Path to a directory that defines a business scenario and contains the necessary components to execute reinforcement
# learning workflows in single-threaded, multi-process and distributed modes.
scenario_path: "/path/to/your/scenario"
log_path: "/path/to/your/log/folder"
log_path: "/path/to/your/log/folder" # All logs are written to a single file for ease of viewing.
main:
num_episodes: 100
num_steps: -1
num_episodes: 100 # Number of episodes to run. Each episode is one cycle of roll-out and training.
# Number of environment steps to collect environment samples over. If null, samples are collected until the
# environments reach the terminal state, i.e., for a full episode. Otherwise, samples are collected until the
# specified number of steps or the terminal state is reached, whichever comes first.
num_steps: null
# This can be an integer or a list of integers. An integer indicates the interval at which policies are evaluated.
# A list indicates the episodes at the end of which policies are to be evaluated. Note that episode indexes are
# 1-based.
eval_schedule: 10
logging:
stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
file: DEBUG
logging: # log levels for the main loop
stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
file: DEBUG # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
rollout:
# Optional section to specify roll-out parallelism settings. If absent, a single environment instance will be created
# locally for training and evaluation.
parallelism:
sampling: 10
eval: 1 # defaults to 1 if not provided
min_env_samples: 8 # ignored if parallelism.training == 1
grace_factor: 0.2 # ignored if parallelism.training == 1
controller: # ignored if parallelism.training == 1
host: "127.0.0.1" # ignored if run in containerized environments
port: 20000
logging:
stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
file: DEBUG
sampling: 10 # Number of parallel roll-outs to collect training data from.
# Number of parallel roll-outs to evaluate policies on. If not specified, one roll-out worker is chosen to perform
# evaluation.
eval: null
# Minimum number of environment samples to collect from the parallel roll-outs per episode / segment before moving
# on to the training phase. The actual number of env samples collected may be more than this value if we allow a
# grace period (see the comment for rollout.parallelism.grace_factor for details), but never less. This value should
# not exceed rollout.parallelism.sampling.
min_env_samples: 8
# Factor that determines the additional wait time allowed after the required number of environment samples, as
# indicated by "min_env_samples", has been received. For example, if it took T seconds to receive "min_env_samples"
# environment samples, an additional T * grace_factor seconds will be spent trying to collect the remaining results.
grace_factor: 0.2
controller: # Parallel roll-out controller settings. Ignored if rollout.parallelism section is absent.
host: "127.0.0.1" # Controller's IP address. Ignored if run in containerized environments.
port: 20000 # Controller's network port for remote roll-out workers to connect to.
logging: # log levels for roll-out workers
stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
file: DEBUG # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
training:
mode: simple # simple, parallel
load_path: "/path/to/your/models" # (If not None) path to load previously saved trainer snapshots.
# Must be "simple" or "parallel". In simple mode, all underlying models are trained locally. In parallel mode,
# all trainers send gradient-related tasks to a proxy service where they get dispatched to a set of workers.
mode: simple
# Path to load previously saved trainer snapshots from. A policy trainer's snapshot includes the states of all
# the policies it manages as well as the states of auxiliary models (e.g., critics in the Actor-Critic paradigm).
# If the path corresponds to an existing directory, the program will look under the directory for snapshot files
# that match the trainer names specified in the scenario and attempt to load from them.
load_path: "/path/to/your/models"
# Optional section to specify model checkpointing settings.
checkpointing:
path: "/path/to/your/checkpoint/folder" # Path to save checkpoints and trainer snapshots.
interval: 10
proxy: # ignored under simple mode
host: "127.0.0.1" # ignored if run in containerized environments
frontend: 10000
backend: 10001
num_workers: 10
logging:
stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
file: DEBUG
# Directory to save trainer snapshots under. Snapshot files created at different episodes will be saved under
# separate folders named using episode numbers. For example, if a snapshot is created for a trainer named "dqn"
# at the end of episode 10, the file path would be "/path/to/your/checkpoint/folder/10/dqn.ckpt".
path: "/path/to/your/checkpoint/folder"
interval: 10 # Interval at which trained policies / models are persisted to disk.
proxy: # Proxy settings. Ignored if training.mode is "simple".
host: "127.0.0.1" # Proxy service host's IP address. Ignored if run in containerized environments.
frontend: 10000 # Proxy service's network port for trainers to send tasks to.
backend: 10001 # Proxy service's network port for remote workers to connect to.
num_workers: 10 # Number of workers to execute trainers' tasks.
logging: # log levels for training task workers
stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
file: DEBUG # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS
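
As a concrete illustration of the "main" section documented above, and in particular the two accepted forms of eval_schedule, a hypothetical configuration could look like this (all values are placeholders chosen for illustration):

main:
  num_episodes: 50
  num_steps: null              # collect full episodes each cycle
  eval_schedule: 10            # integer form: evaluate every 10 episodes
  logging:
    stdout: INFO
    file: DEBUG

or, using the list form to evaluate only at selected (1-based) episodes:

main:
  num_episodes: 50
  num_steps: null
  eval_schedule: [10, 30, 50]  # list form: evaluate at the end of episodes 10, 30 and 50
  logging:
    stdout: INFO
    file: DEBUG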
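
The interplay between rollout.parallelism.min_env_samples and rollout.parallelism.grace_factor can also be summarized with a short sketch. This is not MARO's implementation; it only illustrates the intended semantics as described in the comments above, and the result queue and function name are hypothetical.

import queue
import time

def collect_rollout_results(result_queue, parallelism, min_env_samples, grace_factor):
    """Hypothetical collector: each roll-out worker puts one result onto result_queue (a queue.Queue)."""
    results = []
    start = time.time()

    # Phase 1: block until the required minimum number of results has arrived.
    while len(results) < min_env_samples:
        results.append(result_queue.get())
    elapsed = time.time() - start

    # Phase 2: allow at most elapsed * grace_factor extra seconds for the stragglers.
    deadline = time.time() + elapsed * grace_factor
    while len(results) < parallelism and time.time() < deadline:
        try:
            results.append(result_queue.get(timeout=max(0.0, deadline - time.time())))
        except queue.Empty:
            break
    return results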