
[RLlib] RLModule API: SelfSupervisedLossAPI for RLModules that bring their own loss (algo independent). #47581

Merged

Conversation

@sven1977 (Contributor) commented Sep 10, 2024

RLModule API: SelfSupervisedLossAPI for RLModules that bring their own loss (algo independent).

  • The Learner now checks whether any RLModule (within the MultiRLModule) implements this API and, if so, calls the module's own compute_self_supervised_loss() method instead of the Learner's compute_loss_for_module() method.
  • Updated the curiosity RLModule to implement this API.
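The dispatch described above can be sketched with plain Python. This is a minimal, hypothetical stand-in, not RLlib source code: the class and method names mirror the PR's API, but the bodies are simplified placeholders.

```python
class SelfSupervisedLossAPI:
    """Marker API: an RLModule implementing this brings its own loss."""

    def compute_self_supervised_loss(
        self, *, learner, module_id, config, batch, fwd_out
    ):
        raise NotImplementedError


class Learner:
    def __init__(self, modules):
        # Stand-in for the MultiRLModule: maps ModuleID -> RLModule.
        self.module = modules

    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        # Placeholder for the algorithm-specific loss (e.g. PPO, DQN).
        return sum(fwd_out)

    def compute_losses(self, *, fwd_out, batch, config=None):
        losses = {}
        for module_id, mod in self.module.items():
            if isinstance(mod, SelfSupervisedLossAPI):
                # Module brings its own (algo-independent) loss.
                losses[module_id] = mod.compute_self_supervised_loss(
                    learner=self, module_id=module_id, config=config,
                    batch=batch, fwd_out=fwd_out[module_id],
                )
            else:
                losses[module_id] = self.compute_loss_for_module(
                    module_id=module_id, config=config,
                    batch=batch, fwd_out=fwd_out[module_id],
                )
        return losses


class CuriosityModule(SelfSupervisedLossAPI):
    def compute_self_supervised_loss(
        self, *, learner, module_id, config, batch, fwd_out
    ):
        return max(fwd_out)  # dummy self-supervised loss


class PolicyModule:
    pass


learner = Learner({"policy": PolicyModule(), "icm": CuriosityModule()})
losses = learner.compute_losses(
    fwd_out={"policy": [1.0, 2.0], "icm": [3.0, 4.0]}, batch=None
)
# The "icm" module's own loss is used; "policy" falls back to the Learner's.
```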

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <svenmika1977@gmail.com>
@simonsays1980 (Collaborator) left a comment:

LGTM.


@abc.abstractmethod
def _update_module_kl_coeff(
    self,
    *,
    module_id: ModuleID,
    config: PPOConfig,
    kl_loss: float,
) -> None:
    """Dynamically update the KL loss coefficients of each module with.
@simonsays1980 (Collaborator) commented:
"module with"?

@sven1977 (Contributor, Author) replied:
fixed

learner_config_dict={
    # Intrinsic reward coefficient.
    "intrinsic_reward_coeff": 0.05,
    # Forward loss weight (vs inverse dynamics loss). Total ICM loss is:
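The comment in the snippet is truncated, but the standard ICM formulation (Pathak et al., 2017) combines the two terms as a convex combination. A hedged sketch, assuming a `forward_loss_weight` key like the one shown above (the helper name is hypothetical):

```python
def icm_total_loss(forward_loss, inverse_loss, forward_loss_weight=0.2):
    """Hypothetical helper: total ICM loss as a convex combination.

    total = w * L_forward + (1 - w) * L_inverse
    """
    w = forward_loss_weight
    return w * forward_loss + (1.0 - w) * inverse_loss
```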
@simonsays1980 (Collaborator) commented:
Very nice comment!



class DQNTorchLearnerWithCuriosity(DQNRainbowTorchLearner):
    def build(self) -> None:
@simonsays1980 (Collaborator) commented:
Dumb question: Can't we just override AlgorithmConfig.build_learner_pipeline()?

@sven1977 (Contributor, Author) replied:
We could, but again, this should be done inside the Learner, imo.

But you have a good point: how can we make this even easier for the user? Maybe offer a better way to customize the Learner pipeline? Currently, users can only prepend connector pieces to the beginning; RLlib then adds the default pieces to the end. But here, we need a (custom) connector piece to go all the way to the end, which is not possible with the config.learner_connector property.
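The ordering limitation being discussed can be illustrated with a toy sketch (all names hypothetical, not RLlib's actual connector API): user pieces are prepended and defaults appended, so there is no hook to place a custom piece after the defaults.

```python
def build_learner_pipeline(user_pieces, default_pieces):
    # Sketch of the current behavior: user pieces go first,
    # the framework's default pieces are appended afterwards.
    return list(user_pieces) + list(default_pieces)


pipeline = build_learner_pipeline(
    ["IntrinsicRewardPiece"],          # user-supplied, always prepended
    ["AddObservations", "AddStates"],  # framework defaults, appended
)
# "IntrinsicRewardPiece" runs first; there is no slot to run it last.
```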

],
dim=0,
)
obs = tree.map_structure(
@simonsays1980 (Collaborator) commented:
Very nice!

    *,
    learner: "TorchLearner",
    module_id: ModuleID,
    config: "AlgorithmConfig",
    batch: Dict[str, Any],
    fwd_out: Dict[str, Any],
) -> Dict[str, Any]:
-   module = learner.module[module_id]
+   module = learner.module[module_id].unwrapped()
@simonsays1980 (Collaborator) commented:
I guess we need this for DDP?

@sven1977 (Contributor, Author) replied:
correct
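The point of this exchange can be illustrated with a toy stand-in (this is not torch's DistributedDataParallel, just a sketch of the wrapping pattern): a DDP-style wrapper exposes only the standard call path, so custom methods like compute_self_supervised_loss have to be reached through .unwrapped().

```python
class DDPLikeWrapper:
    """Toy stand-in for a DDP wrapper: exposes only forward()."""

    def __init__(self, module):
        self._module = module

    def forward(self, x):
        return self._module.forward(x)

    def unwrapped(self):
        return self._module


class CuriosityModule:
    def forward(self, x):
        return 2 * x

    def compute_self_supervised_loss(self, batch):
        return sum(batch)


wrapped = DDPLikeWrapper(CuriosityModule())
# wrapped.compute_self_supervised_loss(...) would raise AttributeError,
# so the custom method is reached through the unwrapped module:
loss = wrapped.unwrapped().compute_self_supervised_loss([1.0, 2.0])
```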

-@staticmethod
-def compute_loss_for_module(
+@override(SelfSupervisedLossAPI)
+def compute_self_supervised_loss(
@simonsays1980 (Collaborator) commented:
What somehow irritates me is that we are putting the loss function into the module, but still build a special Learner to handle this. Instead, we could directly override the Learner's compute_loss_for_module, couldn't we?

@sven1977 (Contributor, Author) replied:
See my answer to your comment above.

…odule_api_self_supervised_loss

# Conflicts:
#	rllib/core/rl_module/apis/__init__.py

Signed-off-by: sven1977 <svenmika1977@gmail.com>
@sven1977 sven1977 enabled auto-merge (squash) September 11, 2024 14:02
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Sep 11, 2024
@sven1977 sven1977 enabled auto-merge (squash) September 11, 2024 15:43
@sven1977 sven1977 merged commit f422376 into ray-project:master Sep 11, 2024
7 of 8 checks passed
@sven1977 sven1977 deleted the rl_module_api_self_supervised_loss branch September 12, 2024 06:38
ujjawal-khare pushed a commit to ujjawal-khare-27/ray that referenced this pull request Oct 15, 2024
…g their own loss (algo independent). (ray-project#47581)

Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
Labels
go add ONLY when ready to merge, run all tests