[RLlib; Offline RL] Add docstrings to 'MARWIL'. (#47157)
simonsays1980 authored Sep 26, 2024
1 parent 63233ec commit d8d9f6b
Showing 5 changed files with 108 additions and 37 deletions.
1 change: 1 addition & 0 deletions doc/BUILD
@@ -467,6 +467,7 @@ doctest(
"source/rllib/rllib-sample-collection.rst",
],
),
data = ["//rllib:cartpole-v1_large"],
tags = ["team:rllib"],
)

9 changes: 8 additions & 1 deletion rllib/BUILD
@@ -57,6 +57,12 @@
load("//bazel:python.bzl", "py_test_module_list")
load("//bazel:python.bzl", "doctest")

filegroup(
    name = "cartpole-v1_large",
    data = glob(["tests/data/cartpole/cartpole-v1_large/*.parquet"]),
    visibility = ["//visibility:public"],
)

doctest(
files = glob(
["**/*.py"],
@@ -112,7 +118,8 @@ doctest(
]
),
tags = ["team:rllib"],
size = "enormous"
data = glob(["tests/data/cartpole/cartpole-v1_large/*.parquet"]),
size = "enormous",
)

# --------------------------------------------------------------------
129 changes: 93 additions & 36 deletions rllib/algorithms/marwil/marwil.py
@@ -47,41 +47,92 @@
class MARWILConfig(AlgorithmConfig):
"""Defines a configuration class from which a MARWIL Algorithm can be built.
.. testcode::
Example:
>>> from ray.rllib.algorithms.marwil import MARWILConfig
>>> # Run this from the ray directory root.
>>> config = MARWILConfig() # doctest: +SKIP
>>> config = config.training(beta=1.0, lr=0.00001, gamma=0.99) # doctest: +SKIP
>>> config = config.offline_data( # doctest: +SKIP
... input_=["./rllib/tests/data/cartpole/large.json"])
>>> print(config.to_dict()) # doctest: +SKIP
...
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> algo = config.build() # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
Example:
>>> from ray.rllib.algorithms.marwil import MARWILConfig
>>> from ray import tune
>>> config = MARWILConfig()
>>> # Print out some default values.
>>> print(config.beta) # doctest: +SKIP
>>> # Update the config object.
>>> config.training(lr=tune.grid_search( # doctest: +SKIP
... [0.001, 0.0001]), beta=0.75)
>>> # Set the config object's data path.
>>> # Run this from the ray directory root.
>>> config.offline_data( # doctest: +SKIP
... input_=["./rllib/tests/data/cartpole/large.json"])
>>> # Set the config object's env, used for evaluation.
>>> config.environment(env="CartPole-v1") # doctest: +SKIP
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.Tuner( # doctest: +SKIP
... "MARWIL",
... param_space=config.to_dict(),
... ).fit()
from pathlib import Path
from ray.rllib.algorithms.marwil import MARWILConfig
# Get the base path (to ray/rllib)
base_path = Path(__file__).parents[2]
# Get the path to the data in rllib folder.
data_path = base_path / "tests/data/cartpole/cartpole-v1_large"
config = MARWILConfig()
# Enable the new API stack.
config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
# Define the environment for which to learn a policy
# from offline data.
config.environment("CartPole-v1")
# Set the training parameters.
config.training(
    beta=1.0,
    lr=1e-5,
    gamma=0.99,
    # We must define a train batch size for each
    # learner (here 1 local learner).
    train_batch_size_per_learner=2000,
)
# Define the data source for offline data.
config.offline_data(
    input_=[data_path.as_posix()],
    # Run exactly one update per training iteration.
    dataset_num_iters_per_learner=1,
)
# Build an `Algorithm` object from the config and run 1 training
# iteration.
algo = config.build()
algo.train()
.. testcode::
from pathlib import Path
from ray.rllib.algorithms.marwil import MARWILConfig
from ray import train, tune
# Get the base path (to ray/rllib)
base_path = Path(__file__).parents[2]
# Get the path to the data in rllib folder.
data_path = base_path / "tests/data/cartpole/cartpole-v1_large"
config = MARWILConfig()
# Enable the new API stack.
config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
# Print out some default values.
print(f"beta: {config.beta}")
# Update the config object.
config.training(
    lr=tune.grid_search([1e-3, 1e-4]),
    beta=0.75,
    # We must define a train batch size for each
    # learner (here 1 local learner).
    train_batch_size_per_learner=2000,
)
# Set the config's data path.
config.offline_data(
    input_=[data_path.as_posix()],
    # Set the number of updates to be run per learner
    # per training step.
    dataset_num_iters_per_learner=1,
)
# Set the config's environment for evaluation.
config.environment(env="CartPole-v1")
# Set up a tuner to run the experiment.
tuner = tune.Tuner(
    "MARWIL",
    param_space=config,
    run_config=train.RunConfig(
        stop={"training_iteration": 1},
    ),
)
# Run the experiment.
tuner.fit()
"""

def __init__(self, algo_class=None):
@@ -162,11 +213,12 @@ def training(
see bc.py algorithm in this same directory.
bc_logstd_coeff: A coefficient to encourage higher action distribution
entropy for exploration.
moving_average_sqd_adv_norm_update_rate: The rate for updating the
squared moving average advantage norm (c^2). A higher rate leads
to faster updates of this moving average.
moving_average_sqd_adv_norm_start: Starting value for the
squared moving average advantage norm (c^2).
vf_coeff: Balancing value estimation loss and policy optimization loss.
moving_average_sqd_adv_norm_update_rate: Update rate for the
squared moving average advantage norm (c^2).
grad_clip: If specified, clip the global norm of gradients by this amount.
Returns:
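
Editor's note: for context on the `beta`, `moving_average_sqd_adv_norm_update_rate`, and `moving_average_sqd_adv_norm_start` arguments documented in the hunk above, here is a minimal, hypothetical sketch (names and exact update rule assumed, not taken from this commit) of how MARWIL-style exponential advantage weighting typically uses them:

    import numpy as np

    def marwil_weights(advantages, sqd_adv_norm, beta=1.0, update_rate=1e-8):
        # Moving average of the squared advantages (c^2); `sqd_adv_norm` would
        # start at `moving_average_sqd_adv_norm_start` and is nudged toward the
        # batch mean at `update_rate` (assumed update rule).
        sqd_adv_norm += update_rate * (np.mean(advantages ** 2) - sqd_adv_norm)
        # Exponential weights exp(beta * A / c); beta = 0.0 degenerates to plain
        # behavior cloning, while larger beta trusts the advantages more.
        weights = np.exp(beta * advantages / np.sqrt(sqd_adv_norm))
        return weights, sqd_adv_norm
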
@@ -458,6 +510,11 @@ class (multi-/single-learner setup) and evaluation on
return self.metrics.reduce()

def _training_step_old_api_stack(self) -> ResultDict:
"""Implements training step for the old stack.
Note that there is no hybrid stack anymore. If you need to use `RLModule`s,
use the new API stack.
"""
# Collect SampleBatches from sample workers.
with self._timers[SAMPLE_TIMER]:
train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
5 changes: 5 additions & 0 deletions rllib/algorithms/marwil/torch/marwil_torch_learner.py
@@ -18,6 +18,11 @@


class MARWILTorchLearner(MARWILLearner, TorchLearner):
"""Implements torch-specific MARWIL loss on top of MARWILLearner.
This class implements the MARWIL loss under `self.compute_loss_for_module()`.
"""

def compute_loss_for_module(
self,
*,
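
Editor's note: as a rough illustration of what a torch `compute_loss_for_module()` for MARWIL combines, a sketch under assumed tensor names (not the body added by this commit): the advantage-weighted imitation term plus the value-function term scaled by `vf_coeff`:

    import torch

    def marwil_total_loss(action_logp, weights, vf_preds, value_targets, vf_coeff=1.0):
        # Advantage-weighted imitation: increase the log-likelihood of actions
        # carrying large exponential weights (weights are treated as constants).
        policy_loss = -torch.mean(weights.detach() * action_logp)
        # Value-function regression toward the observed returns.
        vf_loss = torch.mean((vf_preds - value_targets) ** 2)
        return policy_loss + vf_coeff * vf_loss
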
1 change: 1 addition & 0 deletions rllib/tuned_examples/bc/pendulum_bc.py
@@ -50,6 +50,7 @@
.offline_data(
input_=[data_path],
input_read_method_kwargs={"override_num_blocks": max(args.num_gpus, 1)},
dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None,
)
.training(
# To increase learning speed with multiple learners,
