[RLlib] Add APPO/IMPALA multi-agent StatelessCartPole learning tests to CI (+ fix some bugs related to this). #47245
Conversation
@@ -228,7 +228,7 @@ def __call__(
        # Also, let module-to-env pipeline know that we had added a single timestep
        # time rank to the data (to remove it again).
        if not self._as_learner_connector:
-           for column, column_data in data.copy().items():
+           for column in data.keys():
simplify
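For context, a minimal, self-contained illustration of why the defensive data.copy() can be dropped: as long as the loop only reassigns values of existing columns, iterating over the keys directly is safe because the key set never changes (the dict and the per-column transformation below are made up for the example).

data = {"obs": [1, 2, 3], "actions": [0, 1, 0]}
for column in data.keys():
    # Stand-in for the connector's per-column work: reassigning an existing key
    # while iterating is fine; only adding/removing keys would require a copy.
    data[column] = [x * 2 for x in data[column]]
print(data)  # {'obs': [2, 4, 6], 'actions': [0, 2, 0]}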
    item_list, T=self.max_seq_len
)
# Multi-agent case AND RLModule is not stateful -> Do not zero-pad
# for this model.
Bug fix: For multi-agent setups where some RLModules are NOT stateful, we should NOT zero-pad anything for those (stateless) modules.
Does this actually work already when using it on full-length episodes coming from OfflineData?
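To make the zero-padding fix above concrete, here is a hedged, framework-free sketch of the intended behavior; the helper name and its arguments are invented for illustration and do not mirror RLlib's actual connector code.

import numpy as np

def maybe_zero_pad(item_list, is_stateful, max_seq_len):
    # Only stateful RLModules get their sequence chunks right-padded with zeros;
    # for stateless modules in a multi-agent setup, items are returned unchanged.
    arr = np.stack(item_list)
    if not is_stateful:
        return arr
    pad_len = max_seq_len - arr.shape[0]
    pad_width = [(0, pad_len)] + [(0, 0)] * (arr.ndim - 1)
    return np.pad(arr, pad_width)

chunk = [np.ones(4), np.ones(4), np.ones(4)]
print(maybe_zero_pad(chunk, is_stateful=True, max_seq_len=5).shape)   # (5, 4)
print(maybe_zero_pad(chunk, is_stateful=False, max_seq_len=5).shape)  # (3, 4)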
normalized_sa_obs = self._filters[sa_episode.agent_id](
    sa_obs, update=self._update_stats
)
try:
Make the error message better that shows up when the multi_agent=True c'tor arg is forgotten.
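One possible shape of such an improved error, reusing the names from the diff above (this is only a sketch; the exact exception type and wording used in the PR are assumptions):

try:
    normalized_sa_obs = self._filters[sa_episode.agent_id](
        sa_obs, update=self._update_stats
    )
except KeyError as e:
    # Point the user at the most likely cause instead of a bare KeyError.
    raise KeyError(
        f"No filter found for agent ID `{sa_episode.agent_id}`! In multi-agent "
        "setups, make sure the `multi_agent=True` c'tor arg was passed to the "
        "filter connector."
    ) from e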
    lambda p, s: s if Columns.STATE_OUT in p else np.squeeze(s, axis=0),
    data,
)
def _remove_single_ts(item, eps_id, aid, mid):
Bug fix: For mixed MultiRLModules, where some RLModules are NOT stateful, the old code would crash.
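A hedged, self-contained sketch of the idea behind _remove_single_ts: only squeeze away the extra time axis for modules that actually had one added (the stateful ones). The signature and the stateful_module_ids argument below are invented for illustration; the real helper receives the episode/agent/module IDs as shown in the diff.

import numpy as np

def _remove_single_ts(item, module_id, stateful_module_ids):
    # Stateful modules had a single-timestep time rank added earlier; remove it.
    if module_id in stateful_module_ids:
        return np.squeeze(item, axis=0)
    # Stateless modules never got a time rank, so squeezing axis 0 would crash;
    # leave their items untouched.
    return item

stateful_item = np.arange(4).reshape(1, 4)   # time rank of 1 was added
stateless_item = np.arange(4)                # no time rank was ever added
print(_remove_single_ts(stateful_item, "lstm_policy", {"lstm_policy"}).shape)  # (4,)
print(_remove_single_ts(stateless_item, "mlp_policy", {"lstm_policy"}).shape)  # (4,)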
    clear_on_reduce=True,
)
# Log all timesteps (env, agent, modules) based on given episodes/batch.
self._log_steps_trained_metrics(batch)
simplify
@@ -1581,49 +1566,26 @@ def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None:
def _get_clip_function() -> Callable:
    """Returns the gradient clipping function to use, given the framework."""

-def _log_steps_trained_metrics(self, episodes, batch, shared_data):
+def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
simplify
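For illustration, a hedged sketch of what a batch-only step-count logger can look like after the simplification. The metric key names and exact import paths below are assumptions based on recent RLlib versions; the real method lives on the Learner and may differ.

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

def log_steps_trained_metrics(metrics: MetricsLogger, batch: MultiAgentBatch) -> None:
    # Per-module number of trained steps in this batch.
    for module_id, module_batch in batch.policy_batches.items():
        metrics.log_value(
            (module_id, "num_module_steps_trained"), len(module_batch), reduce="sum"
        )
    # Env steps are counted once for the whole multi-agent batch.
    metrics.log_value("num_env_steps_trained", batch.env_steps(), reduce="sum")

metrics = MetricsLogger()
batch = MultiAgentBatch({"p0": SampleBatch({"obs": np.zeros((5, 4))})}, env_steps=5)
log_steps_trained_metrics(metrics, batch)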
LGTM. We're completing the stack more and more :)
py_test(
    name = "learning_tests_stateless_cartpole_appo",
    main = "tuned_examples/appo/stateless_cartpole_appo.py",
    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
Is a "gpu" tag missing here while num_gpus=1, or do we simply want to test a remote Learner here?
Correct, here we test the simple case of 1 (remote) Learner on 1 CPU.
    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
    size = "large",
    srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"],
    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
Same here.
    tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
    size = "large",
    srcs = ["tuned_examples/impala/stateless_cartpole_impala.py"],
    args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
And here. I guess this brings us num_learners=1, doesn't it?
This actually tries to put the 1 (remote) Learner on 1 GPU.
Sorry, you are right in that these command line options are very confusing:
On a CPU machine:
--num-gpus=1 -> 1 (remote) Learner (on CPU!)
--num-gpus=2 -> 2 (remote) Learners (on CPUs!)
On a GPU machine:
--num-gpus=1 -> 1 (remote) Learner (on GPU)
--num-gpus=2 -> 2 (remote) Learners (on GPUs)
We should probably rename these args.
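A hedged sketch of how this effectively maps onto the new-API-stack learner config (the .learners() settings exist on AlgorithmConfig in recent Ray versions, but the plumbing in the tuned-example scripts is simplified here and should be treated as an assumption):

from ray.rllib.algorithms.appo import APPOConfig

num_gpus_arg = 1        # value of the --num-gpus command line flag
gpus_available = False  # e.g. a CPU-only CI machine

config = (
    APPOConfig()
    .environment("CartPole-v1")
    .learners(
        # One remote Learner per requested "GPU" ...
        num_learners=num_gpus_arg,
        # ... but on a CPU machine, each Learner simply runs on a CPU.
        num_gpus_per_learner=1 if gpus_available else 0,
    )
)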
@@ -101,10 +102,23 @@ def __call__(
        # batch: - - - - - - - T B0- - - - - R Bx- - - - R Bx
        # mask : t t t t t t t t f t t t t t t f t t t t t f

+       # TODO (sven): Same situation as in TODO below, but for multi-agent episode.
+       # Maybe add a dedicated connector piece for this task?
+       # We extend the MultiAgentEpisode's ID by a running number here to make sure
Ah tricky. This kind of trick needs to also go into the connector docs. This can solve problems, but we need to know how.
Yup, it's getting to a point where the default pipelines do become quite complex. We should spend some time soon to maybe simplify these again or make the ConnectorV2 helper methods even better, e.g. self.foreach_batch_item_change_in_place.
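A tiny illustration of the ID-extension trick referenced in the TODO above; the ID string and the chunking are made up, and the exact attribute RLlib modifies is not shown here.

base_id = "ma_episode_abc123"
# Appending a running index keeps otherwise identical chunks of the same
# MultiAgentEpisode distinguishable when they are batched downstream.
chunk_ids = [f"{base_id}_{i}" for i in range(3)]
print(chunk_ids)  # ['ma_episode_abc123_0', 'ma_episode_abc123_1', 'ma_episode_abc123_2']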
@@ -1294,24 +1295,8 @@ def _update_from_batch_or_episodes(
        if not self.should_module_be_updated(module_id, batch):
            del batch.policy_batches[module_id]

-       # Log all timesteps (env, agent, modules) based on given episodes.
Finally, this goes away haha.
We probably need to remove this also from learn_from_iterator.
Great catch. Will check ...
done
"multi_agent_cartpole", | ||
lambda _: MultiAgentCartPole({"num_agents": args.num_agents}), | ||
) | ||
register_env("multi_agent_cartpole", lambda cfg: MultiAgentCartPole(config=cfg)) |
For DQN and SAC, we have not enabled stateful modules yet. What do we need for that? The buffers need to collect time sequences, correct?
Yes, this is the huge advantage of the "episodes-until-the-last-second" design :) Everything now behaves the same, and we can simply pass a list of episodes (from offline data) into any Learner, and its Learner connector pipelines behave exactly the same.
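Regarding the replay-buffer question above, here is a hedged, framework-free sketch of what "the buffers need to collect time sequences" could mean for stateful DQN/SAC: store whole episodes and sample contiguous length-T slices instead of independent single timesteps (all names below are made up for illustration).

import random

def sample_time_sequence(episodes, T):
    # Pick an episode long enough, then a random contiguous slice of length T.
    ep = random.choice([e for e in episodes if len(e) >= T])
    start = random.randrange(len(ep) - T + 1)
    return ep[start:start + T]

episodes = [list(range(10)), list(range(7))]
print(sample_time_sequence(episodes, T=4))  # e.g. [3, 4, 5, 6]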
Add APPO/IMPALA multi-agent StatelessCartPole learning tests to CI (+ fix some bugs related to this).
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.