Refactor dataloading (#955)
* Refactor dataloading

* Refactor dataloading

* Refactor dataloading

* Add shuffle to test
ethanwharris authored Feb 26, 2020
1 parent be24456 commit b2e9607
Showing 5 changed files with 115 additions and 83 deletions.
108 changes: 41 additions & 67 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,22 +1,10 @@
import warnings
from abc import ABC

import torch.distributed as dist
from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
# loading for pyTorch 1.3
from torch.utils.data import IterableDataset
except ImportError:
# loading for pyTorch 1.1
import torch
warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
' please upgrade to 1.2+' % torch.__version__, ImportWarning)
EXIST_ITER_DATASET = False
else:
EXIST_ITER_DATASET = True
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
from apex import amp
@@ -90,36 +78,19 @@ def call_prepare_data(self, model):
model.prepare_data()

def auto_add_sampler(self, dataloader, train):
# do nothing when user gives a sampler
dl_args = {
'dataset': dataloader.dataset,
'batch_size': dataloader.batch_size,
'shuffle': False,
'num_workers': dataloader.num_workers,
'collate_fn': dataloader.collate_fn,
'pin_memory': dataloader.pin_memory,
'drop_last': dataloader.drop_last,
'timeout': dataloader.timeout,
'worker_init_fn': dataloader.worker_init_fn
}

if train:
if self.use_ddp or self.use_ddp2:
sampler = DistributedSampler(dataloader.dataset)
dl_args['shuffle'] = False
if self.use_ddp or self.use_ddp2 or self.use_tpu:
dl_args = {
'dataset': dataloader.dataset,
'batch_size': dataloader.batch_size,
'shuffle': False,
'num_workers': dataloader.num_workers,
'collate_fn': dataloader.collate_fn,
'pin_memory': dataloader.pin_memory,
'drop_last': dataloader.drop_last,
'timeout': dataloader.timeout,
'worker_init_fn': dataloader.worker_init_fn
}

elif self.use_tpu:
sampler = DistributedSampler(
dataloader.dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal()
)
dl_args['shuffle'] = False
else:
sampler = RandomSampler(dataloader.dataset)

# on not train
else:
if self.use_tpu:
sampler = DistributedSampler(
dataloader.dataset,
@@ -128,12 +99,16 @@ def auto_add_sampler(self, dataloader, train):
)
dl_args['shuffle'] = False
else:
sampler = SequentialSampler(dataloader.dataset)
if train:
sampler = DistributedSampler(dataloader.dataset)
dl_args['shuffle'] = False
else:
sampler = SequentialSampler(dataloader.dataset)

dl_args['sampler'] = sampler
dl_args['sampler'] = sampler

new_dataloader = DataLoader(**dl_args)
return new_dataloader
dataloader = DataLoader(**dl_args)
return dataloader

def reset_train_dataloader(self, model):
"""
@@ -148,12 +123,12 @@ def reset_train_dataloader(self, model):
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

# determine number of training batches
if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset):
self._percent_range_check('train_percent_check')

if self.is_infinite_dataloader(self.train_dataloader):
self.num_training_batches = float('inf')
else:
self._percent_range_check('train_percent_check')

# try getting the length
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.train_percent_check)

@@ -168,27 +143,26 @@ def reset_train_dataloader(self, model):
f"to the number of the training batches ({self.num_training_batches}). "
f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
else:
if self.is_infinite_dataloader(self.train_dataloader):
m = '''
When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
must be an int. An int k specifies checking validation every k training batches.
'''
raise MisconfigurationException(m)

self._percent_range_check('val_check_interval')

self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

# support IterableDataset for train data
self.is_iterable_train_dataloader = (
EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset)
)
if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int):
m = '''
When using an iterableDataset for `train_dataloader`,
`Trainer(val_check_interval)` must be an int.
An int k specifies checking validation every k training batches
'''
raise MisconfigurationException(m)

def is_iterable_dataloader(self, dataloader):
return (
EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset)
)
def is_infinite_dataloader(self, dataloader):
try:
# try getting the length
_ = len(dataloader)
return False
except TypeError as e:
return True

def reset_val_dataloader(self, model):
"""
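Note: the rewritten auto_add_sampler boils down to copying the user's DataLoader settings and attaching an explicit sampler, while the new is_infinite_dataloader simply asks the loader for a length. A minimal standalone sketch of both patterns is below; it mirrors the diff above but runs outside the Trainer, and the single-process demo at the end is illustrative only (a real DistributedSampler additionally needs an initialized process group).

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def add_sampler(dataloader, train, distributed=False):
    # copy the settings of the user's loader; shuffle must stay False once a sampler is set
    dl_args = {
        'dataset': dataloader.dataset,
        'batch_size': dataloader.batch_size,
        'shuffle': False,
        'num_workers': dataloader.num_workers,
        'collate_fn': dataloader.collate_fn,
        'pin_memory': dataloader.pin_memory,
        'drop_last': dataloader.drop_last,
        'timeout': dataloader.timeout,
        'worker_init_fn': dataloader.worker_init_fn,
    }

    if distributed:
        # each DDP/TPU process reads a disjoint shard of the dataset
        sampler = DistributedSampler(dataloader.dataset)
    elif train:
        sampler = RandomSampler(dataloader.dataset)
    else:
        sampler = SequentialSampler(dataloader.dataset)

    dl_args['sampler'] = sampler
    return DataLoader(**dl_args)


def is_infinite_dataloader(dataloader):
    # loaders that cannot report a length (e.g. built on an IterableDataset) raise TypeError
    try:
        len(dataloader)
        return False
    except TypeError:
        return True


dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
loader = add_sampler(DataLoader(dataset, batch_size=8), train=True)
print(type(loader.sampler).__name__, is_infinite_dataloader(loader))  # RandomSampler False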
9 changes: 2 additions & 7 deletions pytorch_lightning/trainer/trainer.py
@@ -1114,19 +1114,14 @@ def run_pretrain_routine(self, model: LightningModule):
self.run_evaluation(test_mode=True)
return

# load the dataloaders
self.reset_train_dataloader(ref_model)
self.reset_val_dataloader(ref_model)

# check if we should run validation during training
self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
self.disable_validation = self.disable_validation and not self.fast_dev_run
self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run

# run tiny validation (if validation defined)
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
ref_model.on_train_start()
if not self.disable_validation and self.num_sanity_val_steps > 0:
self.reset_val_dataloader(ref_model)
# init progress bars for validation sanity check
pbar = tqdm(desc='Validation sanity check',
total=self.num_sanity_val_steps * len(self.val_dataloaders),
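Note: with the dataloader reset moved out of run_pretrain_routine (see training_loop.py below), validation is now disabled purely from whether validation_step is overridden, since num_val_batches is no longer known at this point, and the val loaders are only materialised when the sanity check actually needs them. A condensed sketch of that gate, using a hypothetical helper name:

def should_run_sanity_check(trainer):
    # fast_dev_run always keeps its single validation batch
    disable_validation = not trainer.is_overriden('validation_step') and not trainer.fast_dev_run
    return not disable_validation and trainer.num_sanity_val_steps > 0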
26 changes: 17 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -271,7 +271,7 @@ def is_function_implemented(self, m):
pass

@abstractmethod
def is_iterable_dataloader(self, dataloader):
def is_infinite_dataloader(self, dataloader):
# this is just empty shell for code from other class
pass

@@ -325,6 +325,11 @@ def reset_train_dataloader(self, model):
# this is just empty shell for code from other class
pass

@abstractmethod
def reset_val_dataloader(self, model):
# this is just empty shell for code from other class
pass

@abstractmethod
def has_arg(self, f_name, arg_name):
# this is just empty shell for code from other class
@@ -334,11 +339,17 @@ def train(self):
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)

# get model
model = self.get_model()

# load data
self.reset_train_dataloader(model)
self.reset_val_dataloader(model)

# Train begin callbacks
model.on_train_start()
self.on_train_start()

# get model
model = self.get_model()
try:
# run all epochs
for epoch in range(self.current_epoch, self.max_epochs):
@@ -347,9 +358,6 @@
and hasattr(self.train_dataloader.sampler, 'set_epoch'):
self.train_dataloader.sampler.set_epoch(epoch)

# get model
model = self.get_model()

# update training progress in trainer and model
model.current_epoch = epoch
self.current_epoch = epoch
@@ -370,8 +378,8 @@
if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_iterable_dataloader(self.train_dataloader):
# for iterable train loader, the progress bar never ends
elif self.is_infinite_dataloader(self.train_dataloader):
# for infinite train loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches
@@ -380,7 +388,7 @@
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch + 1}' if not self.is_iterable_dataloader(self.train_dataloader) else ''
desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
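Note: train() now loads both dataloaders before the train-start callbacks fire, and the renamed is_infinite_dataloader drives the progress bar setup. A rough sketch of that bar logic, assuming the Trainer attributes used in the hunk above:

from tqdm import tqdm


def init_epoch_progress_bar(trainer, epoch):
    if trainer.fast_dev_run:
        num_iterations = 2  # 1 train batch + 1 val batch
    elif trainer.is_infinite_dataloader(trainer.train_dataloader):
        num_iterations = None  # unknown length: the bar runs without a total
    else:
        num_iterations = trainer.total_batches

    desc = '' if num_iterations is None else f'Epoch {epoch + 1}'
    return tqdm(desc=desc, total=num_iterations)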
1 change: 1 addition & 0 deletions tests/models/base.py
@@ -168,6 +168,7 @@ def _dataloader(self, train):
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True
)

return loader
54 changes: 54 additions & 0 deletions tests/test_trainer.py
@@ -380,6 +380,60 @@ def test_model_freeze_unfreeze():
model.unfreeze()


def test_inf_train_dataloader(tmpdir):
"""Test inf train data loader (e.g. IterableDataset)"""
tutils.reset_seed()

class CurrentTestModel(LightningTestModel):
def train_dataloader(self):
dataloader = self._dataloader(train=True)

class CustomInfDataLoader:
def __init__(self, dataloader):
self.dataloader = dataloader
self.iter = iter(dataloader)
self.count = 0

def __iter__(self):
self.count = 0
return self

def __next__(self):
if self.count >= 5:
raise StopIteration
self.count = self.count + 1
try:
return next(self.iter)
except StopIteration:
self.iter = iter(self.dataloader)
return next(self.iter)

return CustomInfDataLoader(dataloader)

hparams = tutils.get_hparams()
model = CurrentTestModel(hparams)

# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
val_check_interval=0.5
)
trainer.fit(model)

# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
val_check_interval=50,
)
result = trainer.fit(model)

# verify training completed
assert result == 1


def test_multiple_val_dataloader(tmpdir):
"""Verify multiple val_dataloader."""
tutils.reset_seed()
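Note: the CustomInfDataLoader in the test above emulates what a real IterableDataset produces, namely a loader whose length cannot be queried. A small sketch of the equivalent setup, assuming PyTorch >= 1.2 for IterableDataset (the commented Trainer call is illustrative):

import torch
from torch.utils.data import DataLoader, IterableDataset


class RandomStream(IterableDataset):
    # defines no __len__, so len(DataLoader(...)) raises TypeError
    def __iter__(self):
        while True:
            yield torch.randn(3), torch.randint(0, 2, (1,))


loader = DataLoader(RandomStream(), batch_size=8)

try:
    len(loader)
except TypeError:
    print('infinite loader detected')  # exactly what is_infinite_dataloader checks for

# with such a loader, val_check_interval must be an int number of batches, e.g.
# trainer = Trainer(max_epochs=1, val_check_interval=50)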
