diff --git a/tests/models/base.py b/tests/models/base.py index 073b2ca2d629d..dce2ef624e97a 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -16,9 +16,15 @@ # TODO: this should be discussed and moved out of this package raise ImportError('Missing test-tube package.') -from pytorch_lightning.core.decorators import data_loader from pytorch_lightning.core.lightning import LightningModule +# TODO: remove after getting own MNIST +# TEMPORAL FIX, https://github.com/pytorch/vision/issues/1938 +import urllib.request +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener) + class TestingMNIST(MNIST): @@ -54,8 +60,8 @@ def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.02) def train_dataloader(self): - return DataLoader(MNIST(os.getcwd(), train=True, download=True, - transform=transforms.ToTensor()), batch_size=32) + return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True, + transform=transforms.ToTensor()), batch_size=32) class TestModelBase(LightningModule):