
[Feature] Add Stack transform #2567

Open
wants to merge 1 commit into main from Stack-Transform-0
Conversation

kurtamohler
Collaborator

Description

Adds a transform that stacks tensors and specs from different keys of a tensordict into a common key.
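To illustrate the idea, here is a minimal, hypothetical sketch of what such a stacking step does conceptually. NumPy stands in for torch, plain dicts stand in for tensordicts, and `stack_keys` is an invented helper name; the actual Stack transform operates on TensorDicts and their specs, and its API may differ.

```python
# Hypothetical sketch: values stored under several per-agent keys are
# stacked along a new leading dim under a single common key.
import numpy as np

def stack_keys(data, in_keys, out_key, dim=0):
    """Stack the arrays found at `in_keys` into one array at `out_key`."""
    values = [data.pop(k) for k in in_keys]  # remove the source keys
    data[out_key] = np.stack(values, axis=dim)
    return data

data = {
    "agent0_obs": np.zeros((3,)),
    "agent1_obs": np.ones((3,)),
}
out = stack_keys(data, ["agent0_obs", "agent1_obs"], "agents_obs")
print(out["agents_obs"].shape)  # (2, 3)
```

The real transform additionally has to stack the corresponding specs so that the environment's observation/action specs stay consistent with the data.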

Motivation and Context

Closes #2566

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.


pytorch-bot bot commented Nov 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2567

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 6 New Failures, 18 Unrelated Failures

As of commit f443812 with merge base 408cf7d (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 14, 2024
@kurtamohler
Copy link
Collaborator Author

Looks like there is a minor bug when I use this on UnityMLAgentsEnv and then do a rollout. I'll fix that and add a test.

@vmoens vmoens added the enhancement New feature or request label Nov 15, 2024
Contributor

@vmoens vmoens left a comment


Thanks for this, a long-awaited feature!
I just left a couple of comments on the default dim and the test set.

Review threads:

  • torchrl/envs/transforms/transforms.py (outdated, resolved)
  • test/test_transforms.py (resolved)
@kurtamohler kurtamohler force-pushed the Stack-Transform-0 branch 2 times, most recently from dc5cceb to 23f7e1b Compare November 19, 2024 05:16
Contributor

@vmoens vmoens left a comment


Thanks, this is superb!
I'd like to discuss the inverse transform: would it make sense, in the inverse, to get an entry (from the input_spec) and unbind it?
For example: you have a single action with a leading dim of 2, and map it to ("agent0", "action") and ("agent1", "action"). The spec seen from the outside is the stack of the 2 specs (as it is for entries processed in forward).
Would that make sense?
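A rough sketch of that inverse direction, again with NumPy and plain dicts standing in for torch and tensordicts, and `unstack_key` as an invented helper name: the stacked entry with a leading dim of 2 is split back into per-agent keys.

```python
# Hypothetical sketch of the proposed inverse: one stacked entry is
# split back into per-agent keys along its leading dim.
import numpy as np

def unstack_key(data, in_key, out_keys, dim=0):
    """Split the array at `in_key` into one entry per key in `out_keys`."""
    values = np.split(data.pop(in_key), len(out_keys), axis=dim)
    for key, value in zip(out_keys, values):
        data[key] = np.squeeze(value, axis=dim)  # drop the size-1 dim
    return data

data = {"action": np.array([[0.1, 0.2], [0.3, 0.4]])}  # leading dim of 2
out = unstack_key(data, "action", ["agent0_action", "agent1_action"])
print(out["agent0_action"])  # [0.1 0.2]
```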

```python
forward = _call

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
    values = torch.split(tensordict[self.in_keys_inv[0]], 1, dim=self.dim)
```
Contributor


Does this also work if the key isn't there?

Contributor


Maybe use unbind instead, so we don't have to call squeeze afterwards?

```python
values = torch.split(tensordict[self.in_keys_inv[0]], 1, dim=self.dim)
for value, out_key_inv in _zip_strict(values, self.out_keys_inv):
    tensordict.set(out_key_inv, value.squeeze(self.dim))
tensordict.exclude(self.in_keys_inv[0], inplace=True)
```
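The split-vs-unbind point can be seen with a small NumPy analogue (for illustration only; in torch the calls would be `torch.split` and `torch.unbind`): splitting into size-1 chunks keeps the singleton dim, whereas unbinding removes it directly.

```python
# Why unbind avoids the extra squeeze: split keeps a size-1 dim,
# unbind (here mimicked by iterating over axis 0) removes it.
import numpy as np

x = np.arange(6).reshape(2, 3)

# split into size-1 chunks: each still carries the singleton dim
chunks = np.split(x, 2, axis=0)
print(chunks[0].shape)  # (1, 3)

# "unbinding" along axis 0: the dim is already gone
parts = list(x)
print(parts[0].shape)  # (3,)
```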
Contributor


Maybe we don't want to do that in place? Is there a specific reason for it?

Labels
  • CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
  • enhancement: New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Transform that stacks data for agents with identical specs
3 participants