
[RLlib] Remove 2nd Learner ConnectorV2 pass from PPO (add new GAE Connector piece). Fix: "State-connector" would use seq_len=20. #47401

Merged

Conversation


@sven1977 sven1977 commented Aug 29, 2024

This PR contains a bug fix:
The AddStatesFromEpisodesToBatch ConnectorV2 (only in the Learner pipeline) would always use max_seq_len=20 and ignore the user's configured max_seq_len in the model_config_dict.

Remove 2nd Learner ConnectorV2 pass from PPO (add new GAE Connector piece).

  • Simplifies the PPO Learner update step by removing the need for two Learner connector passes.
  • A new GeneralAdvantageEstimation ConnectorV2 piece (appended by PPOLearner to the end of the pipeline) now performs the value-function forward pass (inside the connector!), computes the advantages directly from these results, and adds them back to the train batch (to be used by forward_train and compute_loss_for_module); see the sketch after this list.
  • Simplifies/cleans up PPOLearner and other related utility functions.
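
To illustrate what such a piece does, here is a rough, hypothetical sketch (illustrative class and argument names, deliberately simplified single-trajectory batch handling; the actual GeneralAdvantageEstimation ConnectorV2 added by this PR differs in its details):

import numpy as np


class GAEConnectorSketch:
    """Toy learner-connector piece that adds advantages to the train batch.

    NOT the actual RLlib ConnectorV2 API; this only illustrates running the
    value-function forward pass inside the connector and writing advantages /
    value targets back into the batch.
    """

    def __init__(self, gamma=0.99, lambda_=0.95):
        self.gamma = gamma
        self.lambda_ = lambda_

    def __call__(self, *, rl_module, batch):
        # Value-function forward pass happens here, inside the connector
        # (assumes the module exposes some value-prediction helper).
        vf_preds = np.asarray(rl_module.compute_values(batch), dtype=np.float32)

        rewards = np.asarray(batch["rewards"], dtype=np.float32)
        # For simplicity, bootstrap with 0.0 at the end of the trajectory.
        vf_next = np.append(vf_preds[1:], 0.0)

        # GAE(lambda): one-step TD errors, then discounted backward accumulation.
        deltas = rewards + self.gamma * vf_next - vf_preds
        advantages = np.zeros_like(deltas)
        last = 0.0
        for t in reversed(range(len(deltas))):
            last = deltas[t] + self.gamma * self.lambda_ * last
            advantages[t] = last

        # Written back into the train batch so that forward_train and
        # compute_loss_for_module can consume them directly.
        batch["advantages"] = advantages
        batch["value_targets"] = advantages + vf_preds
        return batch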

A very quick performance comparison between master and this PR on CartPole-v1 (5 seeds) hints at a ~10% performance increase. But keep in mind that the particular setup (how many workers, how many learners, speed of env/sampling vs. model complexity, etc.) matters greatly:

master
+-----------------------------+------------+-----------------+--------+------------------+------------------------+------------------------+------------------------+
| Trial name                  | status     | loc             |   iter |   total time (s) |   num_env_steps_sample |   num_episodes_lifetim |   num_env_steps_traine |
|                             |            |                 |        |                  |             d_lifetime |                      e |             d_lifetime |
|-----------------------------+------------+-----------------+--------+------------------+------------------------+------------------------+------------------------|
| PPO_CartPole-v1_f0a12_00000 | TERMINATED | 127.0.0.1:66641 |     50 |          34.7085 |                 200000 |                   1375 |                 200000 |
| PPO_CartPole-v1_f0a12_00001 | TERMINATED | 127.0.0.1:66642 |     41 |          28.7064 |                 164000 |                   1262 |                 164000 |
| PPO_CartPole-v1_f0a12_00002 | TERMINATED | 127.0.0.1:66643 |     33 |          23.1132 |                 132000 |                   1072 |                 132000 |
| PPO_CartPole-v1_f0a12_00003 | TERMINATED | 127.0.0.1:66644 |     48 |          34.1065 |                 192000 |                   1157 |                 192000 |
| PPO_CartPole-v1_f0a12_00004 | TERMINATED | 127.0.0.1:66757 |     44 |          31.3689 |                 176000 |                   1284 |                 176000 |
+-----------------------------+------------+-----------------+--------+------------------+------------------------+------------------------+------------------------+
5760 env_ts/sec

This PR:
+-----------------------------+------------+-----------------+--------+------------------+------------------------+------------------------+------------------------+
| Trial name                  | status     | loc             |   iter |   total time (s) |   num_env_steps_sample |   num_episodes_lifetim |   num_env_steps_traine |
|                             |            |                 |        |                  |             d_lifetime |                      e |             d_lifetime |
|-----------------------------+------------+-----------------+--------+------------------+------------------------+------------------------+------------------------|
| PPO_CartPole-v1_8c285_00000 | TERMINATED | 127.0.0.1:68696 |     45 |          29.585  |                 180000 |                   1050 |                 181140 |
| PPO_CartPole-v1_8c285_00001 | TERMINATED | 127.0.0.1:68697 |     33 |          21.3128 |                 132000 |                   1061 |                 133127 |
| PPO_CartPole-v1_8c285_00002 | TERMINATED | 127.0.0.1:68698 |     33 |          21.527  |                 132000 |                   1002 |                 133068 |
| PPO_CartPole-v1_8c285_00003 | TERMINATED | 127.0.0.1:68699 |     50 |          32.4258 |                 200000 |                   1274 |                 201374 |
| PPO_CartPole-v1_8c285_00004 | TERMINATED | 127.0.0.1:68809 |     36 |          23.8069 |                 144000 |                   1130 |                 145202 |
+-----------------------------+------------+-----------------+--------+------------------+------------------------+------------------------+------------------------+
6253 env_ts/sec


Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@simonsays1980 simonsays1980 left a comment


LGTM. Some small nits here and there.

as_learner_connector=True, max_seq_len=self.model.get("max_seq_len")
)
)
pipeline.append(AddStatesFromEpisodesToBatch(as_learner_connector=True))
Collaborator

Why don't we need the max_seq_len anymore ... now used internally by getting it from the config?

Contributor Author

Great question. We might have different RLModules with different max_seq_len, so we now get it directly from the individual RLModule's config.model_config_dict inside the connector's __call__. Safer.
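
In other words, roughly (a minimal, hypothetical sketch with illustrative names; the real connector also does the actual sequence chunking and state handling):

class AddStatesSketch:
    def __call__(self, *, rl_module, batch, **kwargs):
        for module_id, sa_module in rl_module.items():
            # Each (single-agent) RLModule may configure its own max_seq_len,
            # so read it here instead of fixing one value at construction time.
            max_seq_len = sa_module.config.model_config_dict["max_seq_len"]
            # ... chunk this module's sequences in `batch` into pieces of at
            # most `max_seq_len` timesteps and add the initial-state inputs ...
        return batch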

Defines the exponential weight used between actually measured rewards
vs value function estimates over multiple time steps. Specifically,
`lambda_` balances short-term, low-variance estimates with longer-term,
high-variance returns. A `lambda_` or 0.0 makes the GAE rely only on
Collaborator

Small nit: "or" -> "of" ;)

Contributor Author

fixed
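
For reference, the standard GAE definition that this `lambda_` docstring describes (from the GAE paper; written here in LaTeX, not quoted from the RLlib docs):

\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
A_t^{\mathrm{GAE}(\gamma,\lambda)} = \sum_{l \ge 0} (\gamma\lambda)^l \, \delta_{t+l}

With \lambda = 0 this collapses to the one-step TD error (low variance, more bias); with \lambda = 1 it becomes the full discounted return minus the value baseline (low bias, higher variance).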

if self.config.add_default_connectors_to_learner_pipeline:
    self._learner_connector.prepend(AddOneTsToEpisodesAndTruncate())
    self._learner_connector.append(
        GeneralAdvantageEstimation(
Collaborator

Awesome! A connector now! We can now also move this into the connector pipeline of the OfflinePreLearner and remove the MARWILOfflinePreLearner - it's no longer needed.

Collaborator

Will implement this as soon as this is merged :)

"your model config dict. You can set this dict and/or override "
"keys in it via `config.training(model={'max_seq_len': x})`."
)
max_seq_len = sa_module.config.model_config_dict["max_seq_len"]
Collaborator

Why not use _get_max_seq_len here, too?

Contributor Author

I don't know :D Let me check whether this can be beautified ...

`column` -> [item, item, ...]
2) If `single_agent_episode`'s `agent_id` and `module_id` properties are None
(`single_agent_episode` is not part of a multi-agent episode), will append
`item_to_add` to a list under a `([episodeID],)` key under `column`:
Collaborator

Can we use `<>` brackets, to not mislead the user into thinking this is a list containing `episodeID`?

Contributor Author

Great point, lemme fix ...

Contributor Author

fixed

"""Calls the given function with each (module_id, module).

Args:
func: The function to call with each (module_id, module) tuple.
return_dict: Whether to return a dict mapping ModuleID to the individual
module's return values of calling `func`. If False (default), return
a list.

Returns:
The lsit of return values of all calls to
Collaborator

Small nit: "lsit" -> "list" ;)

Contributor Author

:) fixed

]
sub_chunk = np.pad(sub_chunk, pad_width=padding_shape, mode="constant")
# Simple case: `item_list` contains individual floats.
check(
Collaborator

Great example!

else:
    current_t += t
    item_list.appendleft(item[t:])
# `item` is a single item (no batch axis): Append and continue with next item.
Collaborator

Could it happen that item has no batch axis, but a time axis (for example, a time axis of length max_seq_len)?

Contributor Author

Yeah, this is confusing. If item has a "batch" axis, then that axis (axis=0) is assumed to be the time axis (of a coherent trajectory). So item is either a single item (without batch/time axes) OR a sequence of items (with a time axis, which we call the "batch" axis here).

You are right, maybe we should rename this in this function.
On the other hand, that would take away the generality of our batch utility function, which is used in both actual-batch and time contexts.
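
A tiny standalone illustration of that convention (hypothetical helper, not the actual batch utility):

import numpy as np

def split_items_by_seq_len(items, max_seq_len):
    """Toy illustration: axis 0, if present, is treated as the time axis."""
    chunks = []
    for item in items:
        arr = np.asarray(item)
        if arr.ndim == 0:
            # Single item without batch/time axis: becomes its own 1-step chunk.
            chunks.append(arr[None])
        else:
            # Leading axis is the time axis of one coherent trajectory: split it
            # into sub-sequences of at most `max_seq_len` timesteps.
            for start in range(0, arr.shape[0], max_seq_len):
                chunks.append(arr[start:start + max_seq_len])
    return chunks

# A scalar item and a 5-step trajectory, split with max_seq_len=2:
print(split_items_by_seq_len([1.0, np.arange(5)], max_seq_len=2))
# -> [array([1.]), array([0, 1]), array([2, 3]), array([4])]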

@sven1977 sven1977 changed the title from "[RLlib] Remove 2nd Learner ConnectorV2 pass from PPO (add new GAE Connector piece)." to "[RLlib] Remove 2nd Learner ConnectorV2 pass from PPO (add new GAE Connector piece). Fix: "State-connector" would use seq_len=20." Aug 29, 2024
@sven1977 sven1977 enabled auto-merge (squash) August 29, 2024 15:47
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Aug 29, 2024
@github-actions github-actions bot disabled auto-merge August 29, 2024 18:39
@sven1977 sven1977 enabled auto-merge (squash) August 30, 2024 16:58
@sven1977 sven1977 merged commit d1f21a5 into ray-project:master Aug 30, 2024
6 checks passed
@sven1977 sven1977 deleted the ppo_remove_extra_learner_connector_pass branch September 4, 2024 08:57
Labels
go add ONLY when ready to merge, run all tests
2 participants