Feature(PyG data): support nested tensor's dtype and device transformation #174

Merged: 1 commit, May 20, 2024
77 changes: 74 additions & 3 deletions dptb/data/use_data.ipynb
@@ -4,10 +4,81 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda/envs/deeptb/lib/python3.8/site-packages/torch/nested/__init__.py:47: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:175.)\n",
" nt = torch._nested_tensor_from_tensor_list(new_data, dtype, None, device, pin_memory)\n"
]
}
],
"source": [
"from dptb.utils.torch_geometric import Data\n",
"import torch\n",
"\n",
"data = Data(x=torch.randn(10,3), edge_index=torch.randint(0, 10, (2,10)), fe=torch.nested.nested_tensor([torch.randn(10,3), torch.randn(10,3)]))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"nested_tensor([\n",
" tensor([[-0.8980, 1.5517, 0.6173],\n",
" [-0.9263, 1.7326, 0.8377],\n",
" [-1.7981, -1.9792, 0.5199],\n",
" [ 0.2150, -1.4216, -0.1682],\n",
" [-0.2299, 0.2367, -0.6701],\n",
" [-0.5820, 0.8229, -0.4034],\n",
" [ 0.2771, 0.2464, 2.1399],\n",
" [-1.8328, 0.0133, -0.9239],\n",
" [-0.8021, -0.2262, -0.2930],\n",
" [ 1.7474, -1.1398, -1.2048]], device='cuda:0'),\n",
" tensor([[ 0.2287, -0.2875, -1.0089],\n",
" [ 1.4052, -0.2078, -0.4727],\n",
" [-0.8960, 1.9116, -0.2225],\n",
" [ 0.1758, 1.4902, 0.6408],\n",
" [-1.6969, -0.8203, -1.1533],\n",
" [-0.9147, 0.5500, 1.5237],\n",
" [-0.5706, -0.0517, -0.6109],\n",
" [-0.8387, -0.1820, -1.1708],\n",
" [-2.4404, 1.1044, -1.0515],\n",
" [ 0.6899, 0.8555, -0.9393]], device='cuda:0')\n",
"], device='cuda:0')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 8, 7, 9, 3, 8, 5, 3, 7, 0],\n",
" [2, 8, 5, 0, 2, 0, 9, 2, 0, 4]], device='cuda:0')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from build import dataset_from_config\n",
"from dptb.utils.config import Config"
"data[\"edge_index\"]"
]
},
{
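The notebook change above builds a Data object whose fe attribute is a nested tensor and then shows every attribute, including edge_index, living on cuda:0. Below is a minimal standalone sketch of that workflow; the actual device-transfer call is not visible in the excerpt, so the data.to("cuda") line is an assumption based on the patched to() method shown further down.

import torch
from dptb.utils.torch_geometric import Data

# Build a Data object with regular tensors plus a nested-tensor attribute,
# mirroring the notebook cell above.
data = Data(
    x=torch.randn(10, 3),
    edge_index=torch.randint(0, 10, (2, 10)),
    fe=torch.nested.nested_tensor([torch.randn(10, 3), torch.randn(10, 3)]),
)

if torch.cuda.is_available():
    data = data.to("cuda")            # assumed call; moves the nested 'fe' along with the rest
    print(data["edge_index"].device)  # cuda:0
    print(data["fe"].is_nested)       # True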
22 changes: 21 additions & 1 deletion dptb/utils/torch_geometric/data.py
@@ -142,6 +142,14 @@ def keys(self):
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
        return keys

    @property
    def nested_keys(self):
        keys = self.keys
        keys = [key for key in keys if torch.is_tensor(self[key])]
        keys = [key for key in keys if self[key].is_nested]

        return keys

    def __len__(self):
        r"""Returns the number of all present attributes."""
@@ -286,8 +294,17 @@ def apply(self, func, *keys):
        :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to
        all present attributes.
        """
        nested_keys = self.nested_keys
        if len(nested_keys) > 0:
            for key, item in self(*nested_keys):
                self[key] = self.__apply__(item, lambda x: list(x.unbind()))
        for key, item in self(*keys):
            self[key] = self.__apply__(item, func)

        if len(nested_keys) > 0:
            for key, item in self(*nested_keys):
                self[key] = torch.nested.as_nested_tensor(item)

        return self

    def contiguous(self, *keys):
@@ -301,7 +318,10 @@ def to(self, device, *keys, **kwargs):
        :obj:`*keys`.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.to(device, **kwargs), *keys)

        self.apply(lambda x: x.to(device, **kwargs), *keys)

        return self

    def cpu(self, *keys):
        r"""Copies all attributes :obj:`*keys` to CPU memory.
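Presumably because the prototype nested-tensor API does not support every Tensor method that apply() may forward, the patched apply() takes a detour for the attributes listed by nested_keys: it unbinds each nested tensor into a list of regular tensors, applies func tensor by tensor, and re-assembles the result with torch.nested.as_nested_tensor. A standalone sketch of that round-trip, using a dtype cast as the example func (the shapes here are illustrative, not taken from the diff):

import torch

# A nested tensor with ragged first dimensions (shapes chosen for illustration).
nt = torch.nested.nested_tensor([torch.randn(4, 3), torch.randn(7, 3)])

# 1) Unbind into ordinary tensors, as apply() does via lambda x: list(x.unbind()).
pieces = list(nt.unbind())

# 2) Apply the per-tensor function; a dtype cast stands in for func here.
pieces = [t.to(torch.float64) for t in pieces]

# 3) Re-nest the results, as apply() does with torch.nested.as_nested_tensor.
nt64 = torch.nested.as_nested_tensor(pieces)

print(nt64.is_nested, nt64.dtype)  # True torch.float64

The same detour is what makes data.to(device) and dtype conversions work uniformly for both regular and nested attributes after this change.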