Skip to content

Commit

Permalink
[Feature] Better detection of non-tensor data (#685)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 26, 2024
1 parent 9601868 commit 551331d
Show file tree
Hide file tree
Showing 13 changed files with 513 additions and 139 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-rl-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
repository: pytorch/tensordict
gpu-arch-type: cuda
gpu-arch-version: ${{ matrix.cuda_arch_version }}
timeout: 120
script: |
# Set env vars from matrix
export PYTHON_VERSION=${{ matrix.python_version }}
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/tensorclass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,4 @@ Here is an example:

tensorclass
NonTensorData
NonTensorStack
3 changes: 2 additions & 1 deletion tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tensordict.memmap import MemoryMappedTensor
from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership
from tensordict.persistent import PersistentTensorDict
from tensordict.tensorclass import NonTensorData, tensorclass
from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass
from tensordict.utils import (
assert_allclose_td,
is_batchedtensor,
Expand All @@ -43,6 +43,7 @@
"TensorDict",
"TensorDictBase",
"merge_tensordicts",
"NonTensorStack",
"set_transfer_ownership",
"pad_sequence",
"is_memmap",
Expand Down
Loading

0 comments on commit 551331d

Please sign in to comment.