Skip to content

Commit

Permalink
[Type Hints] utils.train_test_split_edges (#5737)
Browse files Browse the repository at this point in the history
I have a question regarding jit script support for
utils.train_test_split_edges. I am getting an issue related to circular
imports. My full test is also failing on local (Removing this as
TorchScript does not support Data). I would appreciate it if anyone has
any suggestions to fix this issue.
```
Traceback (most recent call last):

File "/opt/hostedtoolcache/Python/3.8.14/x64/lib/python3.8/site-packages/sphinx/config.py", line 347, in eval_config_file

exec(code, namespace)

File "/home/runner/work/pytorch_geometric/pytorch_geometric/docs/source/conf.py", line 3, in <module>

import torch_geometric

File "/home/runner/work/pytorch_geometric/pytorch_geometric/torch_geometric/__init__.py", line 4, in <module>

import torch_geometric.data

File "/home/runner/work/pytorch_geometric/pytorch_geometric/torch_geometric/data/__init__.py", line 1, in <module>

from .data import Data

File "/home/runner/work/pytorch_geometric/pytorch_geometric/torch_geometric/data/data.py", line 22, in <module>

from torch_geometric.data.feature_store import (

File "/home/runner/work/pytorch_geometric/pytorch_geometric/torch_geometric/data/feature_store.py", line 33, in <module>

from torch_geometric.utils.mixin import CastMixin

File "/home/runner/work/pytorch_geometric/pytorch_geometric/torch_geometric/utils/__init__.py", line 37, in <module>

from .train_test_split_edges import train_test_split_edges

File "/home/runner/work/pytorch_geometric/pytorch_geometric/torch_geometric/utils/train_test_split_edges.py", line 6, in <module>

from torch_geometric.data import Data

ImportError: cannot import name 'Data' from partially initialized module 'torch_geometric.data' (most likely due to a circular import) (/home/runner/work/pytorch_geometric/pytorch_geometric/torch_geometric/data/__init__.py)

```
```

RuntimeError:
E           Arguments for call are not valid.
E           The following variants are available:
E
E             aten::__contains__.int_list(int[] l, int item) -> (bool):
E             Expected a value of type 'List[int]' for argument 'l' but instead found type 'Tensor (inferred)'.
E             Inferred the value for argument 'l' to be of type 'Tensor' because it was not annotated with an explicit type.
E
E             aten::__contains__.str_list(str[] l, str item) -> (bool):
E             Expected a value of type 'List[str]' for argument 'l' but instead found type 'Tensor (inferred)'.
E             Inferred the value for argument 'l' to be of type 'Tensor' because it was not annotated with an explicit type.
E
E             aten::__contains__.str(Dict(str, t) dict, str key) -> (bool):
E             Could not match type Tensor (inferred) to Dict[str, t] in argument 'dict': Cannot match a dict to Tensor (inferred).
E
E             aten::__contains__.int(Dict(int, t) dict, int key) -> (bool):
E             Could not match type Tensor (inferred) to Dict[int, t] in argument 'dict': Cannot match a dict to Tensor (inferred).
E
E             aten::__contains__.bool(Dict(bool, t) dict, bool key) -> (bool):
E             Could not match type Tensor (inferred) to Dict[bool, t] in argument 'dict': Cannot match a dict to Tensor (inferred).
E
E             aten::__contains__.float(Dict(float, t) dict, float key) -> (bool):
E             Could not match type Tensor (inferred) to Dict[float, t] in argument 'dict': Cannot match a dict to Tensor (inferred).
E
E             aten::__contains__.complex(Dict(complex, t) dict, complex key) -> (bool):
E             Could not match type Tensor (inferred) to Dict[complex, t] in argument 'dict': Cannot match a dict to Tensor (inferred).
E
E             aten::__contains__.Tensor(Dict(Tensor, t) dict, Tensor key) -> (bool):
E             Could not match type Tensor (inferred) to Dict[Tensor, t] in argument 'dict': Cannot match a dict to Tensor (inferred).
E
E             aten::__contains__.float_list(float[] l, float item) -> (bool):
E             Expected a value of type 'List[float]' for argument 'l' but instead found type 'Tensor (inferred)'.
E             Inferred the value for argument 'l' to be of type 'Tensor' because it was not annotated with an explicit type.
E
E             __contains__(str self, str key) -> (bool):
E             Expected a value of type 'str' for argument 'self' but instead found type 'Tensor (inferred)'.
E             Inferred the value for argument 'self' to be of type 'Tensor' because it was not annotated with an explicit type.
E
E           The original call is:
E             File "/Users/shwetajacob/Documents/pyg/pytorch_geometric/torch_geometric/deprecation.py", line 38
E               """
E
E               assert 'batch' not in data  # No batch-mode.
E                      ~~~~~~~~~~~~~~~~~~~ <--- HERE
E
E               num_nodes = data.num_nodes

/opt/homebrew/Caskroom/miniforge/base/envs/pyg_dev/lib/python3.9/site-packages/torch/jit/_script.py:1343: RuntimeError
```

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
shhs29 and rusty1s authored Oct 17, 2022
1 parent ee09efe commit 18b0f0d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
8 changes: 6 additions & 2 deletions torch_geometric/utils/train_test_split_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import torch

import torch_geometric
from torch_geometric.deprecation import deprecated
from torch_geometric.utils import to_undirected


@deprecated("use 'transforms.RandomLinkSplit' instead")
def train_test_split_edges(data, val_ratio: float = 0.05,
test_ratio: float = 0.1):
def train_test_split_edges(
data: 'torch_geometric.data.Data',
val_ratio: float = 0.05,
test_ratio: float = 0.1,
) -> 'torch_geometric.data.Data':
r"""Splits the edges of a :class:`torch_geometric.data.Data` object
into positive and negative train/val/test edges.
As such, it will replace the :obj:`edge_index` attribute with
Expand Down

0 comments on commit 18b0f0d

Please sign in to comment.