【 HELP ❗️ 】Error about PyG2.5.0 : NameError: name 'OptPairTensor' is not defined #8968

Closed
StefanIsSmart opened this issue Feb 26, 2024 · 9 comments · Fixed by #8973

@StefanIsSmart

StefanIsSmart commented Feb 26, 2024

🐛 Describe the bug

When I run the code, the following error is raised:

Traceback (most recent call last):
  File "/script/ComENet.py", line 824, in <listcomp>
    SimpleInteractionBlock(
  File "/script/ComENet.py", line 151, in __init__
    self.conv1 = EdgeGraphConv(hidden_channels, hidden_channels)
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/site-packages/torch_geometric/nn/conv/graph_conv.py", line 55, in __init__
    super().__init__(aggr=aggr, **kwargs)
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 177, in __init__
    signature=self._get_propagate_signature(),
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 936, in _get_propagate_signature
    param_dict = self.inspector.get_params_from_method_call(
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/site-packages/torch_geometric/inspector.py", line 382, in get_params_from_method_call
    type=self.eval_type(type_repr),
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/site-packages/torch_geometric/inspector.py", line 44, in eval_type
    return eval_type(value, self._globals)
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/site-packages/torch_geometric/inspector.py", line 422, in eval_type
    return typing._eval_type(value, _globals, None)  # type: ignore
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/typing.py", line 270, in _eval_type
    return t._evaluate(globalns, localns)
  File "/export/disk3/why/software/Miniforge3/envs/PyG250/lib/python3.8/typing.py", line 518, in _evaluate
    eval(self.__forward_code__, globalns, localns),
  File "<string>", line 1, in <module>
NameError: name 'OptPairTensor' is not defined

The code runs fine with PyG 2.2.0, Python 3.7, PyTorch 1.13 and cudatoolkit 11.3.

The SimpleInteractionBlock is defined as follows:

from torch_cluster import radius_graph
from torch_geometric.nn import GraphConv
from torch_geometric.nn import inits
from torch_geometric.nn.norm import GraphNorm
from torch_scatter import scatter, scatter_min

from torch.nn import Embedding

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

import math
from math import sqrt

try:
    import sympy as sym
except ImportError:
    sym = None

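# swish, Linear and TwoLayerLinear are not shown here; they are assumed to be
# defined elsewhere in ComENet.py (the file named in the traceback).
class SimpleInteractionBlock(torch.nn.Module):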
    def __init__(
            self,
            hidden_channels,
            middle_channels,
            num_radial,
            num_spherical,
            num_layers,
            output_channels,
            act=swish,
            norm=None
    ):
        super(SimpleInteractionBlock, self).__init__()
        self.act = act

        self.conv1 = EdgeGraphConv(hidden_channels, hidden_channels)

        self.conv2 = EdgeGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.lin_cat = Linear(2 * hidden_channels, hidden_channels)

        self.norm = norm

        if self.norm == 'layer':
            self.norm_layer = nn.LayerNorm(hidden_channels)
        elif self.norm =='batch':
            self.norm_layer = nn.BatchNorm1d(hidden_channels)
        elif self.norm == 'graph':
            self.norm_layer = GraphNorm(hidden_channels)
        else:
            pass

        # Transformations of Bessel and spherical basis representations.
        self.lin_feature1 = TwoLayerLinear(num_radial * num_spherical ** 2, middle_channels, hidden_channels)
        self.lin_feature2 = TwoLayerLinear(num_radial * num_spherical, middle_channels, hidden_channels)

        # Dense transformations of input messages.
        self.lin = Linear(hidden_channels, hidden_channels)
        self.lins = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.lins.append(Linear(hidden_channels, hidden_channels))
        self.final = Linear(hidden_channels, output_channels)

        self.reset_parameters()

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

        # norm=None means no normalization layer was created in __init__
        if self.norm is not None:
            self.norm_layer.reset_parameters()

        self.lin_feature1.reset_parameters()
        self.lin_feature2.reset_parameters()

        self.lin.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

        self.lin_cat.reset_parameters()

        for lin in self.lins:
            lin.reset_parameters()

        self.final.reset_parameters()

    def forward(self, x, feature1, feature2, edge_index, batch):
        x = self.act(self.lin(x))

        feature1 = self.lin_feature1(feature1)
        h1 = self.conv1(x, edge_index, feature1)
        h1 = self.lin1(h1)
        h1 = self.act(h1)

        feature2 = self.lin_feature2(feature2)
        h2 = self.conv2(x, edge_index, feature2)
        h2 = self.lin2(h2)
        h2 = self.act(h2)

        h = self.lin_cat(torch.cat([h1, h2], 1))

        h = h + x
        for lin in self.lins:
            h = self.act(lin(h)) + h

        if self.norm == 'graph':
            # GraphNorm needs the batch vector to normalize each graph separately
            h = self.norm_layer(h, batch)
        elif self.norm is not None:
            h = self.norm_layer(h)

        h = self.final(h)
        return h



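class EdgeGraphConv(GraphConv):
    # GraphConv variant whose messages scale each neighbor feature x_j
    # elementwise by the per-edge features passed as edge_weight
    # (feature1 / feature2 in SimpleInteractionBlock.forward).
    def message(self, x_j, edge_weight) -> Tensor:
        return x_j if edge_weight is None else edge_weight * x_j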
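EdgeGraphConv only overrides message, so each neighbor feature x_j is multiplied elementwise by that edge's feature vector before GraphConv aggregates the messages (if we read GraphConv correctly, its own message reshapes edge_weight to one scalar per edge, which is presumably why the override exists). The plain-Python sketch below shows only that message-plus-sum step, assuming GraphConv's default aggr='add' and the default source-to-target flow (edge_index row 0 = source j, row 1 = target i). It leaves out GraphConv's linear layers and root term, and edge_weighted_sum is a hypothetical name, not a PyG function.

```python
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def edge_weighted_sum(
    x: Sequence[Vector],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    edge_attr: Optional[Sequence[Vector]] = None,
) -> List[Vector]:
    """Sum-aggregate the messages edge_attr[e] * x[j] at each target node i."""
    dim = len(x[0])
    out = [[0.0] * dim for _ in x]
    src, dst = edge_index  # row 0 = source node j, row 1 = target node i
    for e, (j, i) in enumerate(zip(src, dst)):
        # Mirrors EdgeGraphConv.message: x_j unchanged without edge features,
        # otherwise scaled elementwise by this edge's feature vector.
        if edge_attr is None:
            msg = list(x[j])
        else:
            msg = [w * v for w, v in zip(edge_attr[e], x[j])]
        # aggr='add': accumulate the message at target node i.
        out[i] = [a + m for a, m in zip(out[i], msg)]
    return out
```

For example, edge_weighted_sum([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], ([0, 2], [1, 1]), [[1.0, 0.5], [2.0, 0.0]]) returns [[0.0, 0.0], [11.0, 1.0], [0.0, 0.0]]: node 1 receives [1.0, 1.0] from node 0 and [10.0, 0.0] from node 2.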

Versions

Collecting environment information...
PyTorch version: 2.0.0
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.7 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: version 3.24.0
Libc version: glibc-2.23

Python version: 3.8.18 | packaged by conda-forge | (default, Dec 23 2023, 17:21:28)  [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-4.15.0-142-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: TITAN RTX
GPU 1: TITAN RTX
GPU 2: TITAN RTX
GPU 3: TITAN RTX
GPU 4: TITAN RTX
GPU 5: TITAN RTX
GPU 6: TITAN RTX
GPU 7: TITAN RTX

Nvidia driver version: 450.51.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                40
On-line CPU(s) list:   0-39
Thread(s) per core:    2
Core(s) per socket:    10
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 79
Model name:            Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz
Stepping:              1
CPU MHz:               2405.052
CPU max MHz:           3100.0000
CPU min MHz:           1200.0000
BogoMIPS:              4403.61
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              25600K
NUMA node0 CPU(s):     0-9,20-29
NUMA node1 CPU(s):     10-19,30-39
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.0.0
[pip3] torch-cluster==1.6.3
[pip3] torch_geometric==2.5.0
[pip3] torch-scatter==2.1.2
[pip3] torch-sparse==0.6.18
[pip3] triton==2.0.0
[conda] blas                      2.121                       mkl    conda-forge
[conda] blas-devel                3.9.0            21_linux64_mkl    conda-forge
[conda] cudatoolkit               11.7.0              hd8887f6_10    nvidia
[conda] libblas                   3.9.0            21_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            21_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            21_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            21_linux64_mkl    conda-forge
[conda] mkl                       2024.0.0         ha957f24_49657    conda-forge
[conda] mkl-devel                 2024.0.0         ha770c72_49657    conda-forge
[conda] mkl-include               2024.0.0         ha957f24_49657    conda-forge
[conda] numpy                     1.24.4           py38h59b608b_0    conda-forge
[conda] pyg                       2.5.0           py38_torch_2.0.0_cu117    pyg
[conda] pytorch                   2.0.0           py3.8_cuda11.7_cudnn8.5.0_0    pytorch
[conda] pytorch-cluster           1.6.3           py38_torch_2.0.0_cu117    pyg
[conda] pytorch-cuda              11.7                 h778d358_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] pytorch-scatter           2.1.2           py38_torch_2.0.0_cu117    pyg
[conda] pytorch-sparse            0.6.18          py38_torch_2.0.0_cu117    pyg
[conda] torchtriton               2.0.0                      py38    pytorch
@StefanIsSmart
Author

Could there be a conflict between torch_geometric's typing.py and Python's built-in typing module?

@StefanIsSmart StefanIsSmart changed the title NameError: name 'OptPairTensor' is not defined Error about PyG2.5.0 : NameError: name 'OptPairTensor' is not defined Feb 26, 2024
@StefanIsSmart
Author

I tried Python 3.10 instead of 3.8, but the error still occurs.
Could you give me a hand?

@StefanIsSmart
Author

Could anyone help me, please?
I don't know how to fix it.

@StefanIsSmart StefanIsSmart changed the title Error about PyG2.5.0 : NameError: name 'OptPairTensor' is not defined 【 HELP ❗️ 】Error about PyG2.5.0 : NameError: name 'OptPairTensor' is not defined Feb 26, 2024
@rusty1s
Member

rusty1s commented Feb 26, 2024

You can temporarily fix this by adding

from torch_geometric.typing import OptPairTensor, OptTensor

I will look into it.

@StefanIsSmart
Author

StefanIsSmart commented Feb 26, 2024

Thank you very much for your reply!!!
I will try it.

@StefanIsSmart
Author

It works, but I don't know why I need to add from torch_geometric.typing import OptPairTensor, OptTensor by hand.

But now another problem arises, the same as #8959.

@StefanIsSmart
Author

The second error (the one also reported in #8959) has been solved!
Maybe the processed data generated by the old version is incompatible with the new version of PyG.
So I regenerated the processed data and the error is gone!
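Regenerating the processed data usually comes down to deleting the dataset's processed directory so that PyG runs process() again on the next load. Below is a minimal stdlib sketch, assuming the default layout where Dataset.processed_dir is root/processed (a custom dataset may override this); clear_processed_dir is a hypothetical helper, not part of PyG.

```python
import os
import shutil


def clear_processed_dir(root: str) -> bool:
    """Delete root/processed so the dataset is re-processed on next load.

    Assumes the default PyG layout (processed_dir == os.path.join(root, "processed")).
    Returns True if a directory was removed, False if there was nothing to remove.
    """
    processed = os.path.join(root, "processed")
    if not os.path.isdir(processed):
        return False
    shutil.rmtree(processed)
    return True
```

After calling it, re-instantiating the dataset with the same root should rebuild the processed files with the currently installed PyG version; the raw files are left untouched.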

@StefanIsSmart
Author

Now the only remaining question is why I need to add from torch_geometric.typing import OptPairTensor, OptTensor by hand.

@rusty1s
Member

rusty1s commented Feb 26, 2024

Because it's a bug in the code :) Will be fixed in #8973.
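As far as we can tell from the last frames of the traceback, PyG 2.5.0 parses the propagate call in GraphConv.forward and evaluates the parameter type annotations (such as OptPairTensor) with typing._eval_type(value, _globals, None), where _globals appears to be the namespace of the module that defines the subclass (here ComENet.py). That module never imports OptPairTensor, so evaluation fails, and importing the aliases by hand makes the names resolvable. The stdlib-only sketch below models that reading; demo_lib, demo_user, OptPair, Conv, EdgeConv, make_module and resolve are hypothetical stand-ins rather than PyG code, and we have not checked this against the actual change in #8973.

```python
import sys
import types
import typing


def make_module(name: str, source: str) -> types.ModuleType:
    """Build an in-memory module from source and register it in sys.modules."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module


# Hypothetical "library" module: defines an alias and uses it in an annotation
# written as a string, loosely like OptPairTensor in GraphConv.forward.
lib = make_module("demo_lib", """
from typing import Optional, Tuple
OptPair = Tuple[int, Optional[int]]

class Conv:
    def forward(self, x: "OptPair") -> None:
        pass
""")

# Hypothetical "user" module: subclasses Conv but never imports OptPair.
user = make_module("demo_user", """
from demo_lib import Conv

class EdgeConv(Conv):
    def message(self, x_j):
        return x_j
""")


def resolve(cls: type, method: str) -> dict:
    """Evaluate `method`'s annotations in the globals of the module defining `cls`."""
    globalns = vars(sys.modules[cls.__module__])
    return typing.get_type_hints(getattr(cls, method), globalns=globalns)


print(resolve(lib.Conv, "forward"))  # works: OptPair lives in demo_lib

try:
    # Same inherited method, but resolved in demo_user's namespace.
    resolve(user.EdgeConv, "forward")
except NameError as err:
    print("NameError:", err)

# Analogue of the manual workaround: make the alias visible in demo_user.
user.OptPair = lib.OptPair
print(resolve(user.EdgeConv, "forward"))
```

In this model, resolving against the module where the method is actually defined (demo_lib) would avoid the error without any manual import; whether #8973 takes that route is an assumption on our part.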

rusty1s added a commit that referenced this issue Feb 26, 2024
Fixes #8968

Accidentally merged
c22bb89.

Follow-up.