Is there a way to make more torch functions work with TensorDicts? For example, I'm missing `torch.repeat_interleave(my_tensor_dict, ...)`. How would I go about implementing this myself? That is, which file should I look in for examples of how other torch-compatible functionality is implemented?
Answered by vmoens, Jan 5, 2024
Thanks for the question,
We don't have recipes for this (yet!), but that's a great idea.
Here's how it generally works:
- `tensordict.apply(func)` only calls `func` on the leaf tensors, which does not fit the purpose. (If you want to understand why the next approach is better, think of a `TensorDictBase` subclass, call it `MyTensorDict`, with a dedicated `torch.repeat_interleave`. If we call a plain `apply`, we won't be treating this class correctly when it is nested within a regular `TensorDict`.)
- `TensorDict._fast_apply(func, call_on_nested=True)`. This will call your `func` on tensors and tensordicts, re…