This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Data loading cuda device #4879

Merged · 18 commits · Jan 11, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -45,9 +45,15 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with
the `MultiProcessDataLoader` this is controlled by the `max_instances_in_memory` setting.
- `ArrayField` is now called `TensorField`, and implemented in terms of torch tensors, not numpy.
- Improved the `nn.util.move_to_device` function by avoiding an unnecessary recursive check for tensors and
  adding a `non_blocking` optional argument, which is the same argument as in `torch.Tensor.to()`.
- If you are trying to create a heterogeneous batch, you now get a better error message.
- Readers using the new vision features now explicitly log how they are featurizing images.

### Removed

- Removed `nn.util.has_tensor`.

### Fixed

- The `build-vocab` command no longer crashes when the resulting vocab file is
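The `non_blocking` argument mentioned in the changelog entry above mirrors `torch.Tensor.to()`. Below is a minimal sketch of how it might be used, assuming the new argument is passed by keyword as `non_blocking`; the batch structure is illustrative.

```python
import torch
import allennlp.nn.util as nn_util

# A nested batch structure like the ones the data loaders produce.
batch = {"tokens": {"token_ids": torch.zeros(2, 5, dtype=torch.long)}}

if torch.cuda.is_available():
    # Pinned (page-locked) memory lets the non_blocking copy overlap with compute,
    # mirroring the semantics of torch.Tensor.to().
    batch["tokens"]["token_ids"] = batch["tokens"]["token_ids"].pin_memory()
    batch = nn_util.move_to_device(batch, torch.device("cuda:0"), non_blocking=True)
```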
4 changes: 3 additions & 1 deletion allennlp/data/batch.py
@@ -83,7 +83,9 @@ def get_padding_lengths(self) -> Dict[str, Dict[str, int]]:
return {**padding_lengths}

def as_tensor_dict(
self, padding_lengths: Dict[str, Dict[str, int]] = None, verbose: bool = False
self,
padding_lengths: Dict[str, Dict[str, int]] = None,
verbose: bool = False,
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
# This complex return type is actually predefined elsewhere as a DataArray,
# but we can't use it because mypy doesn't like it.
12 changes: 9 additions & 3 deletions allennlp/data/data_loaders/data_loader.py
@@ -29,11 +29,14 @@ class DataLoader(Registrable):
[`DatasetReader`](/api/data/dataset_readers/dataset_reader/#datasetreader),
or another source of data.

This class has three required methods:
This is purely an abstract base class. All concrete subclasses must provide
implementations of the following methods:

- [`__iter__()`](#__iter__) that creates an iterable of `TensorDict`s,
- [`iter_instances()`](#iter_instances) that creates an iterable of `Instance`s, and
- [`index_with()`](#index_with) that should index the data with a vocabulary.
- [`iter_instances()`](#iter_instances) that creates an iterable of `Instance`s,
- [`index_with()`](#index_with) that should index the data with a vocabulary, and
- [`set_target_device()`](#set_target_device), which updates the device that batch
tensors should be put on when they are generated in `__iter__()`.

Additionally, this class should also implement `__len__()` when possible.

@@ -54,3 +57,6 @@ def iter_instances(self) -> Iterator[Instance]:

def index_with(self, vocab: Vocabulary) -> None:
raise NotImplementedError

def set_target_device(self, device: torch.device) -> None:
raise NotImplementedError
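
The docstring above lists the four methods every concrete `DataLoader` must now provide, including the new `set_target_device()`. For reference, here is a minimal sketch of a hypothetical in-memory subclass; the class name, constructor, and batching logic are illustrative only and are not part of this PR.

```python
from typing import Iterator, List, Optional

import torch

from allennlp.data.batch import Batch
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.data.instance import Instance
from allennlp.data.vocabulary import Vocabulary
import allennlp.nn.util as nn_util


class InMemoryDataLoader(DataLoader):
    """Hypothetical loader over a fixed list of instances, for illustration only."""

    def __init__(self, instances: List[Instance], batch_size: int) -> None:
        self.instances = instances
        self.batch_size = batch_size
        self.vocab: Optional[Vocabulary] = None
        self.cuda_device: Optional[torch.device] = None

    def __len__(self) -> int:
        return (len(self.instances) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[TensorDict]:
        if self.vocab is None:
            raise ValueError("index_with() must be called before iterating")
        for start in range(0, len(self.instances), self.batch_size):
            batch = Batch(self.instances[start : start + self.batch_size])
            tensor_dict = batch.as_tensor_dict()
            if self.cuda_device is not None:
                # Move tensors onto whatever device set_target_device() last gave us.
                tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device)
            yield tensor_dict

    def iter_instances(self) -> Iterator[Instance]:
        yield from self.instances

    def index_with(self, vocab: Vocabulary) -> None:
        self.vocab = vocab
        for instance in self.instances:
            instance.index_fields(vocab)

    def set_target_device(self, device: torch.device) -> None:
        self.cuda_device = device
```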
75 changes: 56 additions & 19 deletions allennlp/data/data_loaders/multi_process_data_loader.py
@@ -4,8 +4,10 @@
import random
import sys
import traceback
from typing import List, Iterator, Optional, Callable, Iterable
from typing import List, Iterator, Optional, Iterable, Union

from overrides import overrides
import torch
import torch.multiprocessing as mp

from allennlp.common.util import lazy_groups_of, shuffle_iterable
@@ -16,6 +18,7 @@
from allennlp.data.fields import TextField
from allennlp.data.samplers import BatchSampler
from allennlp.data.vocabulary import Vocabulary
import allennlp.nn.util as nn_util


logger = logging.getLogger(__name__)
@@ -42,6 +45,12 @@ class MultiProcessDataLoader(DataLoader):
data_path: `str`, required
Passed to `DatasetReader.read()`.

!!! Note
In a typical AllenNLP configuration file, the `reader` and `data_path` parameters don't
get an entry under the `data_loader`. The `reader` is constructed separately from
the corresponding `dataset_reader` params, and the `data_path` is taken from the
`train_data_path`, `validation_data_path`, or `test_data_path`.

batch_size: `int`, optional (default = `None`)
When `batch_sampler` is unspecified, this option can be combined with `drop_last`
and `shuffle` to control automatic batch sampling.
@@ -75,30 +84,30 @@ class MultiProcessDataLoader(DataLoader):
the `reader` needs to implement
[`manual_multi_process_sharding`](/api/data/dataset_readers/dataset_reader/#datasetreader).

collate_fn: `Callable[[List[Instance]], TensorDict]`, optional (default = `allennlp_collate`)
The function used to turn `Instance`s into a `TensorDict` batch.

max_instances_in_memory: `int`, optional (default = `None`)
If not specified, all instances will be read and cached in memory for the duration
of the data loader's life. This is generally ideal when your data can fit in memory
during training. However, when your datasets are too big, using this option
will turn on lazy loading, where only `max_instances_in_memory` instances are processed
at a time.

Note that this setting will affect how a `batch_sampler` is applied. If
`max_instances_in_memory` is `None`, the sampler will be applied to all `Instance`s.
Otherwise the sampler will be applied to only `max_instances_in_memory` `Instance`s
at a time.
!!! Note
This setting will affect how a `batch_sampler` is applied. If
`max_instances_in_memory` is `None`, the sampler will be applied to all `Instance`s.
Otherwise the sampler will be applied to only `max_instances_in_memory` `Instance`s
at a time.

start_method: `str`, optional (default = `"fork"`)
The [start method](https://docs.python.org/3.7/library/multiprocessing.html#contexts-and-start-methods)
used to spin up workers.

!!! Note
In a typical AllenNLP configuration file, the `reader` and `data_path` parameters don't
get an entry under the "data_loader". The `reader` is constructed separately from
the corresponding `dataset_reader` params, and the `data_path` is taken from the
`train_data_path`, `validation_data_path`, or `test_data_path`.
cuda_device: `Optional[Union[int, str, torch.device]]`, optional (default = `None`)
If given, batches will automatically be put on this device.

!!! Note
This should typically not be set in an AllenNLP configuration file. The `Trainer`
will automatically call [`set_target_device()`](#set_target_device) before iterating
over batches.

!!! Warning
Multiprocessing code in Python is complicated! Especially code that involves lower-level libraries
@@ -131,15 +140,16 @@ def __init__(
self,
reader: DatasetReader,
data_path: str,
*,
batch_size: int = None,
drop_last: bool = False,
shuffle: bool = False,
batch_sampler: BatchSampler = None,
batches_per_epoch: int = None,
num_workers: int = 0,
collate_fn: Callable[[List[Instance]], TensorDict] = allennlp_collate,
max_instances_in_memory: int = None,
start_method: str = "fork",
cuda_device: Optional[Union[int, str, torch.device]] = None,
) -> None:
# Do some parameter validation.
if num_workers is not None and num_workers < 0:
@@ -177,9 +187,18 @@ def __init__(
self.batch_sampler = batch_sampler
self.batches_per_epoch = batches_per_epoch
self.num_workers = num_workers
self.collate_fn = collate_fn
self.collate_fn = allennlp_collate
self.max_instances_in_memory = max_instances_in_memory
self.start_method = start_method
self.cuda_device: Optional[torch.device] = None
if cuda_device is not None:
if not isinstance(cuda_device, torch.device):
self.cuda_device = torch.device(cuda_device)
else:
self.cuda_device = cuda_device

# Can only initialize CUDA in workers when these `start_methods` are used.
self._worker_cuda_safe = self.start_method in {"spawn", "forkserver"}

# To make sure we have some backpressure in the worker queues we try to set
# reasonable defaults for the maximum size of these queues.
@@ -206,12 +225,14 @@ def __init__(
# Load all instances right away.
deque(self.iter_instances(), maxlen=0)

@overrides
def index_with(self, vocab: Vocabulary) -> None:
self._vocab = vocab
if self._instances:
for instance in self._instances:
instance.index_fields(vocab)

@overrides
def __len__(self) -> int:
if self.batches_per_epoch is not None:
return self.batches_per_epoch
Expand All @@ -235,6 +256,7 @@ def __len__(self) -> int:
# is not specified.
raise TypeError

@overrides
def __iter__(self) -> Iterator[TensorDict]:
if self._vocab is None:
raise ValueError(
Expand All @@ -254,6 +276,7 @@ def __iter__(self) -> Iterator[TensorDict]:
self._batch_generator = self._iter_batches() # so refresh it
yield next(self._batch_generator)

@overrides
def iter_instances(self) -> Iterator[Instance]:
if self._instances:
yield from self._instances
@@ -293,9 +316,13 @@ def iter_instances(self) -> Iterator[Instance]:
queue.close() # type: ignore[attr-defined]
self._join_workers(workers, queue)

@overrides
def set_target_device(self, device: torch.device) -> None:
self.cuda_device = device

def _iter_batches(self) -> Iterator[TensorDict]:
if self._instances is not None or self.num_workers <= 0:
for batch in self._instances_to_batches(self.iter_instances()):
for batch in self._instances_to_batches(self.iter_instances(), move_to_device=True):
yield batch
else:
ctx = mp.get_context(self.start_method)
@@ -318,6 +345,9 @@ def _iter_batches(self) -> Iterator[TensorDict]:
sys.stderr.write("".join(tb))
raise WorkerError(e)

if not self._worker_cuda_safe and self.cuda_device is not None:
# Need to move batch to target device now.
batch = nn_util.move_to_device(batch, self.cuda_device)
yield batch
queue.task_done()
done_count += 1
@@ -398,7 +428,9 @@ def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
try:
self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
instances = self.reader.read(self.data_path)
for batch in self._instances_to_batches(instances):
for batch in self._instances_to_batches(
instances, move_to_device=self._worker_cuda_safe
):
queue.put((batch, None))
except Exception as e:
queue.put((None, (str(e), traceback.format_exc())))
@@ -431,7 +463,9 @@ def _index_instance(self, instance: Instance) -> Instance:
instance.index_fields(self._vocab)
return instance

def _instances_to_batches(self, instance_iterator: Iterable[Instance]) -> Iterator[TensorDict]:
def _instances_to_batches(
self, instance_iterator: Iterable[Instance], move_to_device: bool
) -> Iterator[TensorDict]:
instance_iterator = (self._index_instance(instance) for instance in instance_iterator)

if self.max_instances_in_memory is not None:
@@ -480,7 +514,10 @@ def _instances_to_batches(self, instance_iterator: Iterable[Instance]) -> Iterat
and len(batch) < self.batch_size # type: ignore[operator]
):
break
yield self.collate_fn(batch)
tensor_dict = self.collate_fn(batch)
if move_to_device and self.cuda_device is not None:
tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device)
yield tensor_dict


class WorkerError(Exception):
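
For reference, a sketch of how the new `cuda_device` parameter might be used when constructing the loader directly in code rather than from a configuration file. The reader choice, data path, and batch settings below are placeholders, not part of this PR.

```python
import torch

from allennlp.data.data_loaders.multi_process_data_loader import MultiProcessDataLoader
from allennlp.data.dataset_readers import TextClassificationJsonReader
from allennlp.data.vocabulary import Vocabulary

# Any DatasetReader works here; the reader and path are only examples.
loader = MultiProcessDataLoader(
    TextClassificationJsonReader(),
    "path/to/train.jsonl",
    batch_size=32,
    shuffle=True,
    max_instances_in_memory=10_000,  # turn on lazy loading, 10k instances at a time
    cuda_device=0 if torch.cuda.is_available() else None,
)
loader.index_with(Vocabulary.from_instances(loader.iter_instances()))

for batch in loader:
    # When cuda_device is set, the tensors in `batch` are already on that device.
    pass
```

Passing `cuda_device` here is equivalent to calling `set_target_device()` after construction, which is what the docstring says the `Trainer` does.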
38 changes: 35 additions & 3 deletions allennlp/data/data_loaders/multitask_data_loader.py
@@ -1,7 +1,10 @@
from typing import Any, Dict, Iterable, Iterator, List
from typing import Any, Dict, Iterable, Iterator, List, Union, Optional
import itertools
import math

import torch
from overrides import overrides

from allennlp.common import util
from allennlp.data.batch import Batch
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
@@ -14,6 +17,7 @@
from allennlp.data.dataset_readers.multitask import MultiTaskDatasetReader
from allennlp.data.instance import Instance
from allennlp.data.vocabulary import Vocabulary
import allennlp.nn.util as nn_util


def maybe_shuffle_instances(loader: DataLoader, shuffle: bool) -> Iterable[Instance]:
@@ -109,6 +113,13 @@ class MultiTaskDataLoader(DataLoader):
shuffle: `bool`, optional (default = `True`)
If `False`, we will not shuffle the instances that come from each underlying data loader.
You almost certainly never want to use this except when debugging.
cuda_device: `Optional[Union[int, str, torch.device]]`, optional (default = `None`)
If given, batches will automatically be put on this device.

!!! Note
This should typically not be set in an AllenNLP configuration file. The `Trainer`
will automatically call [`set_target_device()`](#set_target_device) before iterating
over batches.
"""

def __init__(
@@ -128,11 +139,18 @@ def __init__(
instance_queue_size: Dict[str, int] = None,
instance_chunk_size: Dict[str, int] = None,
shuffle: bool = True,
cuda_device: Optional[Union[int, str, torch.device]] = None,
) -> None:
self.readers = reader.readers
self.data_paths = data_path
self.scheduler = scheduler or HomogeneousRoundRobinScheduler(batch_size=batch_size)
self.sampler = sampler
self.cuda_device: Optional[torch.device] = None
if cuda_device is not None:
if not isinstance(cuda_device, torch.device):
self.cuda_device = torch.device(cuda_device)
else:
self.cuda_device = cuda_device

self._batch_size = batch_size
self._instances_per_epoch = instances_per_epoch
@@ -186,6 +204,7 @@ def __init__(
for key, loader in self._loaders.items()
}

@overrides
def __len__(self) -> int:
if self._instances_per_epoch is not None:
return self._instances_per_epoch
@@ -204,6 +223,7 @@ def __len__(self) -> int:
else:
return int(1 + total_instances) // self._batch_size

@overrides
def __iter__(self) -> Iterator[TensorDict]:
# Basic outline: first we _sample_ the instances that we're going to be using for this
# epoch, which relies on the scheduler if `self._instances_per_epoch` is not None. This is
@@ -219,7 +239,10 @@ def __iter__(self) -> Iterator[TensorDict]:
current_batch_size += self._batch_size_multiplier.get(dataset, 1.0)
if current_batch_size > self._batch_size:
batch = Batch(batch_instances)
yield batch.as_tensor_dict()
tensor_dict = batch.as_tensor_dict()
if self.cuda_device is not None:
tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device)
yield tensor_dict
batch_instances = [instance]
current_batch_size = self._batch_size_multiplier.get(dataset, 1.0)
else:
@@ -229,8 +252,12 @@ def __iter__(self) -> Iterator[TensorDict]:
# so we don't need a check for that here.
if not self._drop_last or current_batch_size == self._batch_size:
batch = Batch(batch_instances)
yield batch.as_tensor_dict()
tensor_dict = batch.as_tensor_dict()
if self.cuda_device is not None:
tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device)
yield tensor_dict

@overrides
def iter_instances(self) -> Iterator[Instance]:
# The only external contract for this method is that it iterates over instances
# individually; it doesn't actually specify anything about batching or anything else. The
@@ -248,10 +275,15 @@ def iter_instances(self) -> Iterator[Instance]:
for loader in self._loaders.values():
yield from loader.iter_instances()

@overrides
def index_with(self, vocab: Vocabulary) -> None:
for loader in self._loaders.values():
loader.index_with(vocab)

@overrides
def set_target_device(self, device: torch.device) -> None:
self.cuda_device = device

def _get_instances_for_epoch(self) -> Dict[str, Iterable[Instance]]:
if self._instances_per_epoch is None:
return {
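
Finally, a sketch of the `set_target_device()` pattern that the docstrings above attribute to the `Trainer`: calling code points the loader at the target device before iterating, so batches arrive already on it. The training function and model here are hypothetical; any `DataLoader` from this PR works as the loader.

```python
import torch

from allennlp.data.data_loaders.data_loader import DataLoader


def train_one_epoch(model: torch.nn.Module, loader: DataLoader) -> None:
    # `model` is assumed to be an AllenNLP-style module whose forward() takes the
    # batch keys as keyword arguments and returns a dict with a "loss" entry.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    if device.type == "cuda":
        # Mirrors what the Trainer is described as doing: set the target device
        # before iterating so every batch is already on the GPU.
        loader.set_target_device(device)
    for batch in loader:
        output = model(**batch)
        output["loss"].backward()
```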