Skip to content

Commit

Permalink
minor fix for qm9_pretrained_schnet.py (#7228)
Browse files Browse the repository at this point in the history
```
root@9135585df2bb:/workspace# python3 examples/qm9_pretrained_schnet.py
Traceback (most recent call last):
  File "examples/qm9_pretrained_schnet.py", line 17, in <module>
    dataset = QM9(osp.join())
TypeError: join() missing 1 required positional argument: 'a'
```

after this minor fix and pip installing ase and schnet==1.0.0
i get:
```
Traceback (most recent call last):
  File "examples/qm9_pretrained_schnet.py", line 22, in <module>
    model, datasets = SchNet.from_qm9_pretrained(path, dataset, target)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/models/schnet.py", line 252, in from_qm9_pretrained
    net.readout = 'mean' if mean is True else 'add'
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1685, in __setattr__
    raise TypeError("cannot assign '{}' as child module '{}' "
TypeError: cannot assign 'str' as child module 'readout' (torch.nn.Module or None expected)
```
if i set net.readout = None then it works and I get to the part of the
schnet that relies on torch cluster (which i will eventually make a part
of pyg-lib along with a bunch of other torch-* functionalities that are
needed to be moved to pyg-lib)

```
  File "examples/qm9_pretrained_schnet.py", line 32, in <module>
    pred = model(data.z, data.pos, data.batch)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1533, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/models/schnet.py", line 284, in forward
    edge_index, edge_weight = self.interaction_graph(pos, batch)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1533, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/models/schnet.py", line 352, in forward
    edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/pool/__init__.py", line 210, in radius_graph
    return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors,
AttributeError: 'NoneType' object has no attribute 'radius_graph'
```

note that w/ schnet >= 2.0:
```
Traceback (most recent call last):
  File "examples/qm9_pretrained_schnet.py", line 22, in <module>
    model, datasets = SchNet.from_qm9_pretrained(path, dataset, target)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/models/schnet.py", line 219, in from_qm9_pretrained
    state = torch.load(path, map_location='cpu')
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 817, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 1045, in _legacy_load
    result = unpickler.load()
  File "/usr/lib/python3.8/pickle.py", line 1212, in load
    dispatch[key[0]](self)
  File "/usr/lib/python3.8/pickle.py", line 1528, in load_global
    klass = self.find_class(module, name)
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 850, in find_class
    return super().find_class(mod_name, name)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/_graveyard/legacy_import_unpickler.py", line 24, in find_class
    return super().find_class(new_module, name)
  File "/usr/lib/python3.8/pickle.py", line 1579, in find_class
    __import__(module, level=0)
ModuleNotFoundError: No module named 'schnetpack.atomistic.model'
```
  • Loading branch information
puririshi98 authored Apr 25, 2023
1 parent c9fef62 commit 0c4ea3a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/qm9_pretrained_schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
args = parser.parse_args()

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')
dataset = QM9(osp.join())
dataset = QM9(path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down

0 comments on commit 0c4ea3a

Please sign in to comment.