[Feature] Add Stack transform #2567 (base: main)
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2567. Note: links to docs will display an error until the docs builds have completed.

❗ There is 1 currently active SEV; if your PR is affected, please view it.

❌ As of commit f443812 with merge base 408cf7d: 6 new failures and 18 unrelated failures (flaky, or already broken on trunk). 👉 Rebase onto the `viable/strict` branch to avoid the broken-trunk failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Looks like there is a minor bug if I try to use this on …
Thanks for this, long-awaited feature!
Just left a couple of comments on the default dim and the test set.
Force-pushed from dc5cceb to 23f7e1b.
Force-pushed from 23f7e1b to f443812.
Thanks, this is superb!
I'd like to discuss the inverse transform: would it make sense, in the inverse, to get an entry (from the `input_spec`) and unbind it?
For example: you have a single action with a leading dim of 2 and map it to `("agent0", "action")` and `("agent1", "action")`. The spec seen from the outside is the stack of the 2 specs (as it is for entries processed in `forward`).
Would that make sense?
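A minimal sketch of what that inverse behavior could look like, written by hand with `tensordict` primitives; the key names, batch size, and stacking dim here are assumptions for illustration, not the PR's API:

```python
import torch
from tensordict import TensorDict

# Externally, a single stacked "action" entry with a leading dim of 2.
td = TensorDict({"action": torch.randn(2, 4)}, batch_size=[])
out_keys_inv = [("agent0", "action"), ("agent1", "action")]

# The inverse unbinds along the stacking dim and writes per-agent entries.
for value, key in zip(td["action"].unbind(0), out_keys_inv):
    td.set(key, value)
td = td.exclude("action")  # out-of-place removal of the stacked entry
assert td["agent0", "action"].shape == (4,)
```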
```python
forward = _call

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
    values = torch.split(tensordict[self.in_keys_inv[0]], 1, dim=self.dim)
```
does this work too if the key isn't there?
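For context, this is generic `tensordict` behavior rather than anything PR-specific: indexing a missing key raises, while `get` with a default does not, so the line above would fail on a tensordict that lacks the key:

```python
from tensordict import TensorDict

td = TensorDict({}, batch_size=[])
try:
    td["action"]  # indexing a missing key raises a KeyError
except KeyError:
    pass
assert td.get("action", default=None) is None  # get() can return a default instead
```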
Maybe use `unbind` instead, to avoid calling `squeeze` afterwards?
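A quick standalone illustration of the suggestion (plain `torch`, nothing PR-specific): `unbind` yields the already-squeezed slices that `split(..., 1)` only gives after a `squeeze`:

```python
import torch

x = torch.randn(2, 4)
# split(x, 1, dim=0) returns two (1, 4) views; each needs a squeeze
a, b = (v.squeeze(0) for v in torch.split(x, 1, dim=0))
# unbind(dim=0) returns two (4,) views directly
a2, b2 = torch.unbind(x, dim=0)
assert torch.equal(a, a2) and torch.equal(b, b2)
```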
```python
values = torch.split(tensordict[self.in_keys_inv[0]], 1, dim=self.dim)
for value, out_key_inv in _zip_strict(values, self.out_keys_inv):
    tensordict.set(out_key_inv, value.squeeze(self.dim))
tensordict.exclude(self.in_keys_inv[0], inplace=True)
```
maybe we don't want to do that inplace? Is there a specific reason to use that?
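For reference, the two behaviors side by side (generic `tensordict` semantics): with `inplace=True` the input tensordict itself loses the key, while the default returns a filtered copy and leaves the input intact:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])

td2 = td.exclude("a")  # out-of-place: td still has "a"
assert "a" in td.keys() and "a" not in td2.keys()

td.exclude("a", inplace=True)  # inplace: td itself is mutated
assert "a" not in td.keys()
```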
Description
Adds a transform that stacks tensors and specs from different keys of a tensordict into a common key.
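To make the description concrete, here is the forward operation the transform performs, written out by hand with `torch.stack`; the key names and the stacking dim are assumptions for illustration (see the inverse-transform discussion above for the reverse direction):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {("agent0", "obs"): torch.randn(4), ("agent1", "obs"): torch.randn(4)},
    batch_size=[],
)
# Forward: stack the per-agent entries into one common key with a new leading dim.
td.set("obs", torch.stack([td["agent0", "obs"], td["agent1", "obs"]], dim=0))
assert td["obs"].shape == (2, 4)
```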
Motivation and Context
Closes #2566.