Inconsistent behaviour and "AttributeError: _old_init" when using Pytorch Lightning with the ogb library. #14050

Closed
schlyah opened this issue Aug 5, 2022 · 15 comments · Fixed by #14117
Labels: bug, data handling, help wanted

@schlyah

schlyah commented Aug 5, 2022

🐛 Bug

The simple example from the documentation page https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/mnist-hello-world.html works fine. However, when I add the import statement from ogb.nodeproppred import *, the code behaves inconsistently: sometimes it works and sometimes it raises an "AttributeError: _old_init" exception:

 Traceback (most recent call last):
  File "plexample.py", line 54, in <module>
    trainer.fit(mnist_model, train_loader)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 700, in fit
    self._call_and_handle_interrupt(
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 654, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in _run_train
    self.fit_loop.run()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 195, in run
    self.on_run_start(*args, **kwargs)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 210, in on_run_start
    self.trainer.reset_train_dataloader(self.trainer.lightning_module)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1811, in reset_train_dataloader
    self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 430, in _request_dataloader
    dataloader = source.dataloader()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/contextlib.py", line 126, in __exit__
    next(self.gen)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py", line 527, in _replace_init_method
    del cls._old_init
AttributeError: _old_init

To Reproduce

The documentation example code with the added import statement:

import os
import pandas as pd
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import argparse
from ogb.nodeproppred import *

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


if __name__ == '__main__':
    # Init our model
    mnist_model = MNISTModel()

    # Init DataLoader from MNIST Dataset
    train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

    # Initialize a trainer
    trainer = Trainer(
        accelerator="auto",
        devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
        max_epochs=3,
        callbacks=[TQDMProgressBar(refresh_rate=20)],
    )

    # Train the model ⚡
    trainer.fit(mnist_model, train_loader)

Environment

  • PyTorch Lightning Version: 1.7.0
  • PyTorch Version: 1.12.0
  • Python version: 3.9
  • OS: Linux
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration:
  • How you installed PyTorch: pip

cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj

@schlyah added the "needs triage" label Aug 5, 2022
@carmocca
Contributor

carmocca commented Aug 5, 2022

This is super weird for 2 reasons:

  1. I'm not familiar with this ogb library. Locally it's not giving me this error, but the script hangs if I add the import.
  2. It works fine if I add the import to our bug report model. Can you try to reproduce it using that model?

@carmocca added the "bug", "help wanted", and "data handling" labels and removed the "needs triage" label Aug 5, 2022
@carmocca added this to the pl:1.7.x milestone Aug 5, 2022
@krshrimali
Contributor

@carmocca - If it helps, I've observed this recently in the Flash CI as well: https://github.com/Lightning-AI/lightning-flash/runs/7668034174?check_suite_focus=true. Didn't get a chance to take a closer look because of bad health, but wanted to dig deeper on Monday. I earlier thought it might be something with Flash, but will have to check with previous PL versions once.

@schlyah
Author

schlyah commented Aug 5, 2022

I've run the bug report model with the import 5 times. I got the exception twice and it worked fine the other 3 times.

@carmocca
Contributor

carmocca commented Aug 6, 2022

Thanks @krshrimali. The dependencies for that job do not include ogb so I guess it's partially an issue of ours.

@otaj
Contributor

otaj commented Aug 8, 2022

Hi @krshrimali, @carmocca, @schlyah, if it's also affecting Flash CI, then it's almost definitely on us. @schlyah, would you mind sharing:

  1. ogb version
  2. Reproducible code using the bug report template? Reproducible meaning sometimes failing in this case 😂

@krshrimali, I expect I will find all the required info in the Flash repo. Is that correct?

@krshrimali
Contributor

Hi, @otaj - Thank you for your message! Yes, you'll definitely find everything in the Flash repo - and if not, I'll be one message away to help you out. I tried reproducing it locally with Flash, but couldn't - this bug is really flaky. :/

@otaj
Contributor

otaj commented Aug 8, 2022

Hi @schlyah, I spent a while on this and am unable to reproduce. Would you mind sharing your full environment (i.e. pip freeze) as well? Thanks a lot!

@PurpleSand123

PurpleSand123 commented Aug 9, 2022

I am having the same problem when using torch_geometric.loader.DataLoader.
For debugging, I changed the "_replace_init_method" code like this:

@contextmanager
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
    """This context manager is used to add support for re-instantiation of custom (subclasses) of "base_cls".
    It patches the "__init__" method.
    """
    classes = _get_all_subclasses(base_cls) | {base_cls}
    wrapped = set()
    print('before:', classes)
    for cls in classes:
        if cls.__init__ not in wrapped:
            print(cls.__name__)
            cls._old_init = cls.__init__
            cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
            wrapped.add(cls.__init__)
    yield
    print('after:', classes)
    for cls in classes:
        if hasattr(cls, "_old_init"):
            print("del", cls.__name__)
            cls.__init__ = cls._old_init
            del cls._old_init

Sometimes it works, and the output is:

before: {<class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>}
RandomNodeSampler
NeighborSampler
GraphSAINTNodeSampler
NeighborLoader
DenseDataLoader
GraphSAINTEdgeSampler
HGTLoader
DataLoader
ClusterLoader
BaseDataLoader
TemporalDataLoader
DataLoader
DataListLoader
GraphSAINTRandomWalkSampler
ShaDowKHopSampler
GraphSAINTSampler
before: {<class 'torch.utils.data.sampler.BatchSampler'>}
BatchSampler
after: {<class 'torch.utils.data.sampler.BatchSampler'>}
del BatchSampler
after: {<class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>}
del RandomNodeSampler
del NeighborSampler
del GraphSAINTNodeSampler
del NeighborLoader
del DenseDataLoader
del GraphSAINTEdgeSampler
del HGTLoader
del DataLoader
del ClusterLoader
del BaseDataLoader
del TemporalDataLoader
del DataLoader
del DataListLoader
del GraphSAINTRandomWalkSampler
del ShaDowKHopSampler
del GraphSAINTSampler

However, sometimes the error occurs, and the output is:

before: {<class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>}
GraphSAINTNodeSampler
GraphSAINTSampler
TemporalDataLoader
ShaDowKHopSampler
NeighborSampler
NeighborLoader
ClusterLoader
RandomNodeSampler
DataListLoader
HGTLoader
DataLoader
BaseDataLoader
DataLoader
GraphSAINTRandomWalkSampler
DenseDataLoader
before: {<class 'torch.utils.data.sampler.BatchSampler'>}
BatchSampler
after: {<class 'torch.utils.data.sampler.BatchSampler'>}
del BatchSampler
after: {<class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>}
del GraphSAINTNodeSampler
del GraphSAINTSampler
del GraphSAINTEdgeSampler
Traceback (most recent call last):
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 868, in test
    return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 654, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 915, in _test_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1249, in _run_stage
    return self._run_evaluate()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1288, in _run_evaluate
    self._evaluation_loop._reload_evaluation_dataloaders()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 234, in _reload_evaluation_dataloaders
    self.trainer.reset_test_dataloader()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1941, in reset_test_dataloader
    self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 344, in _reset_eval_dataloader
    dataloaders = self._request_dataloader(mode)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 427, in _request_dataloader
    with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
  File "/anaconda/envs/pytorch/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py", line 531, in _replace_init_method
    del cls._old_init
AttributeError: _old_init
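
A possible reading of the two logs above, offered as a sketch rather than a confirmed diagnosis: "classes" is a set, so its iteration order can change from run to run, and the cleanup loop uses hasattr, which is also true when "_old_init" is only inherited from a parent that is still patched. In the failing log, GraphSAINTEdgeSampler is never printed in the wrapping pass but still gets a "del" line. The snippet below replays the same two loops on three hypothetical stand-in classes (Root, Mid, Leaf) with a trivial stand-in wrapper, using an explicit list instead of the set so the order can be chosen.

```python
import functools


def make_classes():
    """Build a fresh three-level hierarchy (hypothetical stand-ins)."""

    class Root:  # plays the role of torch.utils.data.DataLoader
        def __init__(self):
            pass

    class Mid(Root):  # a library subclass with its own __init__
        def __init__(self):
            super().__init__()

    class Leaf(Mid):  # a subclass that only inherits __init__
        pass

    return Root, Mid, Leaf


def wrap(init):
    """Trivial stand-in for _wrap_init_method: call through unchanged."""

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        return init(self, *args, **kwargs)

    return wrapper


def patch_then_restore(classes):
    """The same two loops as the context manager, over an ordered list."""
    wrapped = set()
    for cls in classes:
        if cls.__init__ not in wrapped:
            cls._old_init = cls.__init__
            cls.__init__ = wrap(cls.__init__)
            wrapped.add(cls.__init__)
    # ... the body of the `with` block would run here ...
    for cls in classes:
        if hasattr(cls, "_old_init"):  # also True for an inherited attribute
            cls.__init__ = cls._old_init
            del cls._old_init  # only succeeds for the class's own attribute


def outcome(order):
    """Return "ok" or "AttributeError" for a given visiting order."""
    root, mid, leaf = make_classes()
    by_name = {"Root": root, "Mid": mid, "Leaf": leaf}
    try:
        patch_then_restore([by_name[name] for name in order])
    except AttributeError:
        return "AttributeError"
    return "ok"


print(outcome(["Root", "Mid", "Leaf"]))  # ok
print(outcome(["Mid", "Leaf", "Root"]))  # AttributeError
```

With the parent-first order, cleanup succeeds. When Leaf is visited while Root is still patched, hasattr finds Root's "_old_init" through inheritance, and the del fails because the attribute is not Leaf's own.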

@PurpleSand123

PurpleSand123 commented Aug 9, 2022

I wrote a simple script to reproduce it.
This is the main code, "main.py":

import torch
import numpy as np
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch.nn import functional as F
from torch_geometric.loader import DataLoader
from dataloader_1 import Dataloader_1

class my_model(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28, 1)

    def forward(self, x):
        return self.l1(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.mse_loss(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

if __name__ == '__main__':
    # Init our model
    mnist_model = my_model()

    # Init DataLoader from MNIST Dataset
    train_ds = [(torch.randint(100, size=[28],dtype=torch.float32), torch.randint(100, size=[1], dtype=torch.float32)) for i in range(1024)]
    train_loader = Dataloader_1(train_ds, batch_size=512)

    # Initialize a trainer
    trainer = Trainer(
        accelerator="auto",
        devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
        max_epochs=3
    )

    # Train the model ⚡
    trainer.fit(mnist_model, train_loader)

And this is the "dataloader_1.py" code:

from torch_geometric.loader import DataLoader

class Dataloader_1(DataLoader):
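    # Note: the method name is "__int__" as originally posted, not "__init__",
    # so it does not override the constructor inherited from DataLoader.
    def __int__(self):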
        super().__init__()
        self.name = "dataloader_1"

The error occurs with almost 50% probability.
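
Dataloader_1 above defines no "__init__" of its own, so it inherits the constructor of its parent. One way to make this kind of patching independent of iteration order is to touch only classes that define "__init__" themselves, and to keep the originals in a local list instead of a class attribute. The following is a minimal sketch under those assumptions. It is not taken from the linked fix, and "replace_init", "_all_subclasses", and "recording_wrap" are hypothetical names.

```python
from contextlib import contextmanager


def _all_subclasses(cls):
    """Recursively collect every subclass of `cls` (hypothetical helper)."""
    found = set()
    for sub in cls.__subclasses__():
        found.add(sub)
        found |= _all_subclasses(sub)
    return found


@contextmanager
def replace_init(base_cls, wrap):
    """Temporarily wrap __init__ on base_cls and subclasses that define one."""
    patched = []
    for cls in _all_subclasses(base_cls) | {base_cls}:
        if "__init__" in vars(cls):  # skip classes that merely inherit it
            original = vars(cls)["__init__"]
            cls.__init__ = wrap(original)
            patched.append((cls, original))
    try:
        yield
    finally:
        # Restore exactly what was patched, in any order.
        for cls, original in patched:
            cls.__init__ = original


class Root:
    def __init__(self):
        self.ready = True


class Mid(Root):
    def __init__(self):
        super().__init__()


class Leaf(Mid):  # inherits __init__, like Dataloader_1
    pass


calls = []


def recording_wrap(init):
    """Stand-in wrapper that records which original __init__ ran."""

    def wrapper(self, *args, **kwargs):
        calls.append(init)
        return init(self, *args, **kwargs)

    return wrapper


before = (Root.__init__, Mid.__init__)
with replace_init(Root, recording_wrap):
    Leaf()
after = (Root.__init__, Mid.__init__)

print(after == before)  # True
```

Because nothing is stored on the classes, there is no attribute to look up through inheritance during cleanup, and Leaf is never modified.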

@schlyah
Author

schlyah commented Aug 9, 2022

@otaj wrote: "Hi @schlyah, I spent a while on this and am unable to reproduce. Would you mind sharing your full environment (i.e. pip freeze) as well? Thanks a lot!"

This is the result of pip freeze:
absl-py==1.2.0
aiohttp==3.8.1
aiosignal==1.2.0
alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work
anaconda-client @ file:///tmp/build/80754af9/anaconda-client_1635342557008/work
anaconda-navigator==2.1.1
anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1626085644852/work
anyio @ file:///tmp/build/80754af9/anyio_1617783277988/work/dist
appdirs==1.4.4
argh==0.26.2
argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037499734/work
arrow @ file:///tmp/build/80754af9/arrow_1617737686940/work
asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
astroid @ file:///tmp/build/80754af9/astroid_1628063140030/work
astropy @ file:///tmp/build/80754af9/astropy_1629820331715/work
async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
async-timeout==4.0.2
atomicwrites==1.4.0
attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
autopep8 @ file:///tmp/build/80754af9/autopep8_1620866417880/work
Babel @ file:///tmp/build/80754af9/babel_1620871417480/work
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
backports.functools-lru-cache @ file:///tmp/build/80754af9/backports.functools_lru_cache_1618170165463/work
backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work
backports.tempfile @ file:///home/linux1/recipes/ci/backports.tempfile_1610991236607/work
backports.weakref==1.0.post1
beautifulsoup4 @ file:///tmp/build/80754af9/beautifulsoup4_1631874778482/work
binaryornot @ file:///tmp/build/80754af9/binaryornot_1617751525010/work
bitarray @ file:///tmp/build/80754af9/bitarray_1629132848682/work
bkcharts==0.2
black==19.10b0
bleach @ file:///tmp/build/80754af9/bleach_1628110601003/work
bokeh @ file:///tmp/build/80754af9/bokeh_1635312808503/work
boto==2.49.0
Bottleneck @ file:///tmp/build/80754af9/bottleneck_1607575130224/work
brotlipy==0.7.0
cached-property @ file:///tmp/build/80754af9/cached-property_1600785575025/work
cachetools==5.2.0
certifi==2021.10.8
cffi @ file:///tmp/build/80754af9/cffi_1625814692085/work
chardet @ file:///tmp/build/80754af9/chardet_1607706775000/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click==8.0.3
cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1632508026186/work
clyent==1.2.2
colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
conda==4.10.3
conda-build==3.21.5
conda-content-trust @ file:///tmp/build/80754af9/conda-content-trust_1617045594566/work
conda-pack @ file:///tmp/build/80754af9/conda-pack_1611163042455/work
conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1618262147379/work
conda-repo-cli @ file:///tmp/build/80754af9/conda-repo-cli_1620168426516/work
conda-token @ file:///tmp/build/80754af9/conda-token_1620076980546/work
conda-verify==3.4.2
contextlib2 @ file:///Users/ktietz/demo/mc3/conda-bld/contextlib2_1630668244042/work
cookiecutter @ file:///tmp/build/80754af9/cookiecutter_1617748928239/work
cryptography @ file:///tmp/build/80754af9/cryptography_1633520369886/work
cycler==0.10.0
Cython @ file:///tmp/build/80754af9/cython_1636035166688/work
cytoolz==0.11.0
daal4py==2021.3.0
dask==2021.10.0
debugpy @ file:///tmp/build/80754af9/debugpy_1629214122703/work
decorator @ file:///tmp/build/80754af9/decorator_1632776554403/work
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
diff-match-patch @ file:///Users/ktietz/demo/mc3/conda-bld/diff-match-patch_1630511840874/work
distributed @ file:///tmp/build/80754af9/distributed_1635956221111/work
docker-pycreds==0.4.0
docutils @ file:///tmp/build/80754af9/docutils_1620827980776/work
entrypoints==0.3
et-xmlfile==1.1.0
fastcache @ file:///tmp/build/80754af9/fastcache_1607571273717/work
filelock @ file:///tmp/build/80754af9/filelock_1635402558181/work
flake8 @ file:///tmp/build/80754af9/flake8_1620776156532/work
Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work
fonttools==4.25.0
frozenlist==1.3.1
fsspec @ file:///tmp/build/80754af9/fsspec_1632413898837/work
future @ file:///tmp/build/80754af9/future_1607571303524/work
gevent @ file:///tmp/build/80754af9/gevent_1628273270105/work
gitdb==4.0.9
GitPython==3.1.27
glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
gmpy2==2.0.8
google-auth==2.9.1
google-auth-oauthlib==0.4.6
greenlet @ file:///tmp/build/80754af9/greenlet_1628888132713/work
grpcio==1.47.0
h5py @ file:///tmp/build/80754af9/h5py_1636040554610/work
HeapDict @ file:///Users/ktietz/demo/mc3/conda-bld/heapdict_1630598515714/work
html5lib @ file:///Users/ktietz/demo/mc3/conda-bld/html5lib_1629144453894/work
idna @ file:///tmp/build/80754af9/idna_1622654382723/work
imagecodecs @ file:///tmp/build/80754af9/imagecodecs_1635529108216/work
imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work
imagesize @ file:///Users/ktietz/demo/mc3/conda-bld/imagesize_1628863108022/work
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1631916692253/work
inflection==0.5.1
iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work
intervaltree @ file:///Users/ktietz/demo/mc3/conda-bld/intervaltree_1630511889664/work
ipykernel @ file:///tmp/build/80754af9/ipykernel_1633534655931/work/dist/ipykernel-6.4.1-py3-none-any.whl
ipython @ file:///tmp/build/80754af9/ipython_1635944169458/work
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1634143127070/work
isort @ file:///tmp/build/80754af9/isort_1628603791788/work
itsdangerous @ file:///tmp/build/80754af9/itsdangerous_1621432558163/work
jdcal @ file:///Users/ktietz/demo/mc3/conda-bld/jdcal_1630584345063/work
jedi @ file:///tmp/build/80754af9/jedi_1611333729159/work
jeepney @ file:///tmp/build/80754af9/jeepney_1627537048313/work
Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
jinja2-time @ file:///tmp/build/80754af9/jinja2-time_1617751524098/work
joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work
json5 @ file:///tmp/build/80754af9/json5_1624432770122/work
jsonschema @ file:///Users/ktietz/demo/mc3/conda-bld/jsonschema_1630511932244/work
jupyter @ file:///tmp/build/80754af9/jupyter_1607700846274/work
jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1633420104186/work
jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616084066671/work
jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1635799997693/work
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1633419203660/work
jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
keyring @ file:///tmp/build/80754af9/keyring_1629321552962/work
kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612281846643/work
lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616529027849/work
libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
littleutils==0.2.2
llvmlite==0.37.0
locket==0.2.1
lxml @ file:///tmp/build/80754af9/lxml_1616443220035/work
Markdown==3.4.1
MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1607027305082/work
matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1634667019719/work
matplotlib-inline @ file:///tmp/build/80754af9/matplotlib-inline_1628242447089/work
mccabe==0.6.1
mistune @ file:///tmp/build/80754af9/mistune_1607364877025/work
mkl-fft==1.3.1
mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work
mkl-service==2.4.0
mock @ file:///tmp/build/80754af9/mock_1607622725907/work
more-itertools @ file:///tmp/build/80754af9/more-itertools_1635423142362/work
mpmath==1.2.1
msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287166301/work
multidict==6.0.2
multipledispatch @ file:///tmp/build/80754af9/multipledispatch_1607574243360/work
munkres==1.1.4
mypy-extensions==0.4.3
navigator-updater==0.2.1
nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work
nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
nbconvert @ file:///tmp/build/80754af9/nbconvert_1624472883256/work
nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
networkx @ file:///tmp/build/80754af9/networkx_1633639043937/work
nltk==3.6.5
nose @ file:///tmp/build/80754af9/nose_1606773131901/work
notebook @ file:///tmp/build/80754af9/notebook_1635411643303/work
numba @ file:///tmp/build/80754af9/numba_1635174338764/work
numexpr @ file:///tmp/build/80754af9/numexpr_1618856529730/work
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1626264704411/work
numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
oauthlib==3.2.0
ogb==1.3.3
olefile @ file:///Users/ktietz/demo/mc3/conda-bld/olefile_1629805411829/work
openpyxl @ file:///tmp/build/80754af9/openpyxl_1632777717936/work
outdated==0.2.1
packaging @ file:///tmp/build/80754af9/packaging_1625611678980/work
pandas==1.3.4
pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120906940/work
parso @ file:///tmp/build/80754af9/parso_1617223946239/work
partd @ file:///tmp/build/80754af9/partd_1618000087440/work
path @ file:///tmp/build/80754af9/path_1623603407737/work
pathlib2 @ file:///tmp/build/80754af9/pathlib2_1625585682511/work
pathspec==0.7.0
pathtools==0.1.2
patsy==0.5.2
pep8==1.7.1
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow==8.4.0
pkginfo==1.7.1
pluggy @ file:///tmp/build/80754af9/pluggy_1615976318790/work
ply==3.11
poyo @ file:///tmp/build/80754af9/poyo_1617751526755/work
prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
promise==2.3
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1633440160888/work
protobuf==3.19.4
psutil @ file:///tmp/build/80754af9/psutil_1612297992929/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
py @ file:///tmp/build/80754af9/py_1607971587848/work
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle @ file:///tmp/build/80754af9/pycodestyle_1615748559966/work
pycosat==0.6.3
pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
pycurl==7.44.1
pyDeprecate==0.3.2
pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1621600989141/work
pyerfa @ file:///tmp/build/80754af9/pyerfa_1621556109336/work
pyflakes @ file:///tmp/build/80754af9/pyflakes_1617200973297/work
Pygments @ file:///tmp/build/80754af9/pygments_1629234116488/work
PyJWT @ file:///tmp/build/80754af9/pyjwt_1619682484438/work
pylint @ file:///tmp/build/80754af9/pylint_1627536788603/work
pyls-spyder==0.4.0
pyodbc===4.0.0-unsupported
pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1635333100036/work
pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1636110951836/work
PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work
pytest==6.2.4
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
python-lsp-black @ file:///tmp/build/80754af9/python-lsp-black_1634232156041/work
python-lsp-jsonrpc==1.0.0
python-lsp-server==1.2.4
python-slugify @ file:///tmp/build/80754af9/python-slugify_1620405669636/work
pytorch-lightning==1.7.0
pytz==2021.3
PyWavelets @ file:///tmp/build/80754af9/pywavelets_1607645421828/work
pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
PyYAML==6.0
pyzmq @ file:///tmp/build/80754af9/pyzmq_1628275385016/work
QDarkStyle @ file:///tmp/build/80754af9/qdarkstyle_1617386714626/work
qstylizer @ file:///tmp/build/80754af9/qstylizer_1617713584600/work/dist/qstylizer-0.1.10-py2.py3-none-any.whl
QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work
qtconsole @ file:///tmp/build/80754af9/qtconsole_1632739723211/work
QtPy @ file:///tmp/build/80754af9/qtpy_1629397026935/work
regex @ file:///tmp/build/80754af9/regex_1628063347754/work
requests @ file:///tmp/build/80754af9/requests_1629994808627/work
requests-oauthlib==1.3.1
rope @ file:///tmp/build/80754af9/rope_1623703006312/work
rsa==4.9
Rtree @ file:///tmp/build/80754af9/rtree_1618420843093/work
ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016711199/work
ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.6
scikit-image==0.18.3
scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1621370406642/work
scikit-learn-intelex==2021.20210714.170444
scipy @ file:///tmp/build/80754af9/scipy_1630606796912/work
seaborn @ file:///tmp/build/80754af9/seaborn_1629307859561/work
SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022780358/work
Send2Trash @ file:///tmp/build/80754af9/send2trash_1632406701022/work
sentry-sdk==1.7.2
setproctitle==1.2.3
shortuuid==1.0.9
simplegeneric==0.8.1
singledispatch @ file:///tmp/build/80754af9/singledispatch_1629321204894/work
sip==4.19.13
six @ file:///tmp/build/80754af9/six_1623709665295/work
smmap==5.0.0
sniffio @ file:///tmp/build/80754af9/sniffio_1614030464178/work
snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work
sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work
sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1623949099177/work
soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work
Sphinx==4.2.0
sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work
sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work
sphinxcontrib-htmlhelp @ file:///tmp/build/80754af9/sphinxcontrib-htmlhelp_1623945626792/work
sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work
sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work
sphinxcontrib-serializinghtml @ file:///tmp/build/80754af9/sphinxcontrib-serializinghtml_1624451540180/work
sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
spyder @ file:///tmp/build/80754af9/spyder_1636479868270/work
spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1634236920897/work
SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1626947882125/work
statsmodels @ file:///tmp/build/80754af9/statsmodels_1614022848006/work
sympy @ file:///tmp/build/80754af9/sympy_1635237064765/work
tables @ file:///tmp/build/80754af9/pytables_1607975397488/work
TBB==0.2
tblib @ file:///Users/ktietz/demo/mc3/conda-bld/tblib_1629402031467/work
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
terminado==0.9.4
testpath @ file:///tmp/build/80754af9/testpath_1624638946665/work
text-unidecode @ file:///Users/ktietz/demo/mc3/conda-bld/text-unidecode_1629401354553/work
textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work
threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work
tifffile @ file:///tmp/build/80754af9/tifffile_1627275862826/work
tinycss @ file:///tmp/build/80754af9/tinycss_1617713798712/work
toml @ file:///tmp/build/80754af9/toml_1616166611790/work
toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work
torch==1.12.0
torch-geometric==2.0.4
torch-scatter==2.0.9
torch-sparse==0.6.14
torchmetrics==0.9.3
tornado @ file:///tmp/build/80754af9/tornado_1606942317143/work
tqdm @ file:///tmp/build/80754af9/tqdm_1635330843403/work
traitlets @ file:///tmp/build/80754af9/traitlets_1632522747050/work
typed-ast @ file:///tmp/build/80754af9/typed-ast_1624953673314/work
typing_extensions==4.3.0
ujson @ file:///tmp/build/80754af9/ujson_1611259520328/work
unicodecsv==0.14.1
Unidecode @ file:///tmp/build/80754af9/unidecode_1614712377438/work
urllib3==1.26.7
wandb==0.12.21
watchdog @ file:///tmp/build/80754af9/watchdog_1624954998138/work
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
webencodings==0.5.1
Werkzeug @ file:///tmp/build/80754af9/werkzeug_1635505089296/work
whichcraft @ file:///tmp/build/80754af9/whichcraft_1617751293875/work
widgetsnbextension @ file:///tmp/build/80754af9/widgetsnbextension_1607531506226/work
wrapt @ file:///tmp/build/80754af9/wrapt_1607574498026/work
wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1626947795674/work
xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work
XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1628603415431/work
xlwt==1.3.0
xmltodict @ file:///Users/ktietz/demo/mc3/conda-bld/xmltodict_1629301980723/work
yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
yarl==1.8.1
zict==2.0.0
zipp @ file:///tmp/build/80754af9/zipp_1633618647012/work
zope.event==4.5.0
zope.interface @ file:///tmp/build/80754af9/zope.interface_1625036153595/work

and the code:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from ogb.nodeproppred import *  # the import that makes the failure appear intermittently

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

@otaj
Contributor

otaj commented Aug 9, 2022

@schlyah @PurpleSand123, thank you very much for the report. I was able to figure out what the issue is and will send a PR with the fix shortly.

@otaj
Contributor

otaj commented Aug 9, 2022

Basically, the issue is caused by improper handling of inheritance combined with the non-deterministic order in which our wrapper is removed. This usually isn't a big problem, because there is hardly ever a situation with a long inheritance chain in which few of the classes override their __init__ method, but here we are 😂
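To make the mechanism concrete, here is a minimal sketch of how this kind of failure can arise. It assumes a patching scheme that stores each original `__init__` on the class as `_old_init` (the attribute named in the error) and visits the classes in an arbitrary order, as iterating over a set would. The classes `Root`, `Middle`, `Leaf` and the helpers `patch`, `unpatch`, `unpatch_own_only` are hypothetical names for illustration and are not Lightning's actual implementation. The guard in `unpatch_own_only` is one possible remedy, not necessarily the change made in the linked PR.

```python
def make_classes():
    """Build a fresh three-level chain where only the root defines __init__."""

    class Root:
        def __init__(self):
            self.ready = True

    class Middle(Root):  # inherits __init__
        pass

    class Leaf(Middle):  # inherits __init__
        pass

    return {"root": Root, "middle": Middle, "leaf": Leaf}


def make_wrapper(original):
    def wrapper(self, *args, **kwargs):
        original(self, *args, **kwargs)

    return wrapper


def patch(classes):
    """Wrap every distinct __init__ once and stash the original as _old_init."""
    seen = set()
    for cls in classes:
        if cls.__init__ not in seen:
            cls._old_init = cls.__init__
            cls.__init__ = make_wrapper(cls._old_init)
            seen.add(cls.__init__)


def unpatch(classes):
    """Naive removal: hasattr() is also True when only an ancestor owns _old_init."""
    for cls in classes:
        if hasattr(cls, "_old_init"):
            cls.__init__ = cls._old_init
            del cls._old_init  # raises if the attribute lives on an ancestor


def unpatch_own_only(classes):
    """Guarded removal: only touch classes that own _old_init themselves."""
    for cls in classes:
        if "_old_init" in cls.__dict__:
            cls.__init__ = cls.__dict__["_old_init"]
            del cls._old_init


def run(order_names, remove):
    """Patch and unpatch in the same visiting order, as one set iteration would."""
    by_name = make_classes()
    order = [by_name[name] for name in order_names]
    patch(order)
    try:
        remove(order)
    except AttributeError:
        return "AttributeError"
    by_name["leaf"]()  # the leaf class is still constructible afterwards
    return "ok"


outcomes = {
    "naive, root first": run(("root", "middle", "leaf"), unpatch),
    "naive, middle first": run(("middle", "leaf", "root"), unpatch),
    "guarded, middle first": run(("middle", "leaf", "root"), unpatch_own_only),
}
print(outcomes)
```

In this sketch the naive removal succeeds when the root class is visited first and raises `AttributeError` when `Middle` is visited first, because `Leaf` then sees the root's `_old_init` through inheritance but does not own it. Since the visiting order of a set of classes can differ from one interpreter run to the next, that would account for the same script passing on some runs and failing on others.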

@otaj
Contributor

otaj commented Aug 9, 2022

Hi, everyone. I used @PurpleSand123's example (saved as main.py) with one extra test script:

import subprocess
import sys

import numpy as np


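def main():
    # Re-run the reproducer in a fresh interpreter up to 100 times and
    # return the number of the first attempt that exits with an error.
    for attempt in range(1, 101):
        try:
            subprocess.run([sys.executable, "main.py"], check=True, capture_output=True)
        except subprocess.CalledProcessError:
            print(f"Broke on {attempt}th try")
            return attempt
    # Only reached when all 100 attempts succeeded.
    print("Didn't break")
    return 100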

if __name__ == "__main__":
    results = np.array([main() for _ in range(10)])
    print(results.mean())

When running this with PL 1.7.0, the output on my machine was:

Broke on 5th try
Broke on 4th try
Broke on 6th try
Broke on 3th try
Broke on 18th try
Broke on 6th try
Broke on 1th try
Broke on 2th try
Broke on 2th try
Broke on 2th try
4.9

When running with the change in the linked PR, the output on my machine was:

Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
100

Please try your examples with the changes in the linked PR and check whether they fix the issue for you.

@PurpleSand123

Now, the problem is fixed. Thank you

@schlyah
Author

schlyah commented Aug 10, 2022

It works 👍 Thanks for the quick response!
