[RLlib; Offline RL] - Enable buffering episodes. #47501
Conversation
…'SampleBatch ' data and with stateful modules. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…thod into algorithms. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…into offline-rl-enable-buffering-episodes
…AlgorithmConfig'. Furthermore added tests for 'OfflinePreLearner' and moved tests from 'OfflineData' over. Added further tests top 'test_offline_data.py'. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
…sode data. Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@@ -844,6 +846,8 @@ def validate(self) -> None:
     self._validate_input_settings()
     # Check evaluation specific settings.
     self._validate_evaluation_settings()
+    # Check offline specific settings (new API stack).
nice!!
 prelearner_module_synch_period: The period (number of batches converted)
     after which the `RLModule` held by the `PreLearner` should sync weights.
     The `PreLearner` is used to preprocess batches for the learners. The
     higher this value, the more off-policy the `PreLearner`'s module will be.
     Values too small will force the `PreLearner` to sync more frequently
     and thus might slow down the data pipeline. The default value chosen
     by the `OfflinePreLearner` is 10.
-dataset_num_iters_per_learner: Number of iterations to run in each learner
+dataset_num_iters_per_learner: Number of updates to run in each learner
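For orientation (not part of the diff): a minimal sketch of how these two settings might be configured, assuming they are exposed through `AlgorithmConfig.offline_data()` as in recent Ray versions; the algorithm choice and data path are placeholders:

```python
from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    .environment("CartPole-v1")
    .offline_data(
        input_="/tmp/cartpole/episodes",  # placeholder path
        # Sync the PreLearner's RLModule every 10 converted batches
        # (the OfflinePreLearner default mentioned in the docstring).
        prelearner_module_synch_period=10,
        # Number of batches pulled from the data iterator per learner update.
        dataset_num_iters_per_learner=5,
    )
)
```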
👍
I'll leave this up to you to decide: Would `dataset_num_batches_per_learner` be more accurate? Or would it add more confusion?
Or: `dataset_num_batches_per_learner_update`
🤔 maybe too long ...
I admit, `dataset_num_iters_per_learner` is not straight to the point here, but `dataset_num_updates_per_learner` is not better in my opinion. `dataset_num_batches_per_learner` describes better that this many different(!) batches are pulled per learner, but it does not point out that the setting is deeply tied to a `DataIterator` that iterates this many times. I am not sure yet, and as long as I am not sure, I will leave it as is ;)
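To make the semantics concrete, here is rough illustrative pseudocode (mine, not from the PR) of what `dataset_num_iters_per_learner` controls; `iterator` is assumed to be a Ray Data `DataIterator` and `learner` a stand-in object:

```python
def update_from_offline_data(iterator, learner, dataset_num_iters_per_learner=10):
    """Pull `dataset_num_iters_per_learner` batches and update once per batch."""
    for i, batch in enumerate(iterator.iter_batches(batch_size=256)):
        if i >= dataset_num_iters_per_learner:
            break
        # One update per pulled batch: the name "iters" refers to iterator
        # steps, which is why "batches" vs. "updates" is debated above.
        learner.update(batch)
```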
rllib/algorithms/dqn/dqn.py
Outdated
@@ -719,6 +719,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
     n_step=self.config.n_step,
     gamma=self.config.gamma,
     beta=self.config.replay_buffer_config.get("beta"),
+    sample_episodes=True,
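Purely illustrative: a sketch of how the flag changes what the buffer's `sample()` call returns. The buffer class and module path follow RLlib's new API stack, but treat the exact `sample()` keyword names as assumptions to check against your Ray version:

```python
from ray.rllib.utils.replay_buffers.prioritized_episode_buffer import (
    PrioritizedEpisodeReplayBuffer,
)

buffer = PrioritizedEpisodeReplayBuffer(capacity=50_000)
# ... episodes collected from env runners are added via buffer.add(...) ...
sampled = buffer.sample(
    num_items=32,          # requested timesteps
    n_step=3,              # n-step return horizon
    gamma=0.99,            # discount factor
    beta=0.4,              # importance-sampling exponent (prioritized replay)
    sample_episodes=True,  # return episodes instead of a flat batch dict
)
```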
cool!!
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
@@ -0,0 +1,232 @@
+import functools
Wow, thanks for adding all these tests.
Awesome PR @simonsays1980 !
Thanks for adding all these tests as well. Offline RL getting stronger by the day.
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Why are these changes needed?
Sampling exactly `train_batch_size_per_learner` when using offline data with old API stack `SampleBatch` records or new API stack `Episode` records is not possible yet, because each sample batch or episode could contain more than a single timestep. This PR proposes a way to enable sampling with exactly the requested batch size by using replay buffers. The user can define the replay buffer class to use and its `kwargs`. The `OfflinePreLearner` keeps a replay buffer that buffers episodes and samples from this buffer. The replay buffer serves multiple functions:
- It buffers `SampleBatch` or `Episode` data batches and ensures that exactly the requested batch size is sampled.
- It enables `n_step` sampling, if needed.
(A hedged configuration sketch follows below.)
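The parameter names `prelearner_buffer_class` and `prelearner_buffer_kwargs` below are my reading of "the user can define the replay buffer class to use and its kwargs"; verify them against `AlgorithmConfig.offline_data()` in your Ray version, and treat the data path as a placeholder:

```python
from ray.rllib.algorithms.marwil import MARWILConfig
from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
    EpisodeReplayBuffer,
)

config = (
    MARWILConfig()
    .environment("CartPole-v1")
    .offline_data(
        input_="/tmp/cartpole/episodes",  # placeholder offline dataset
        # The OfflinePreLearner buffers incoming episodes here and then
        # samples exactly the requested number of timesteps from it.
        prelearner_buffer_class=EpisodeReplayBuffer,
        prelearner_buffer_kwargs={"capacity": 10_000},
    )
    .training(train_batch_size_per_learner=1024)
)
```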
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.