This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Data loading cuda device #4879

Merged · 18 commits · Jan 11, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -43,6 +43,12 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with
the `MultiProcessDataLoading` this is controlled by the `max_instances_in_memory` setting.
- `ArrayField` is now called `TensorField`, and implemented in terms of torch tensors, not numpy.
- Improved `nn.util.move_to_device` function by avoiding an unnecessary recursive check for tensors and
adding a `non_blocking` optional argument, which is the same argument as in `torch.Tensor.to()`.

### Removed

- Removed `nn.util.has_tensor`.


## Unreleased (1.x branch)
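To make the two changelog entries above concrete, here is a minimal PyTorch-only sketch of what pinned memory plus `non_blocking` transfers buy. The tensor names are illustrative and not part of this PR.

```python
import torch

# A fake batch of CPU tensors; the names are illustrative.
batch = {
    "token_ids": torch.randint(0, 100, (32, 128)),
    "mask": torch.ones(32, 128, dtype=torch.bool),
}

if torch.cuda.is_available():
    # pin_memory() copies each tensor into page-locked host memory, which is
    # what makes a subsequent non-blocking host-to-device copy possible.
    pinned = {name: tensor.pin_memory() for name, tensor in batch.items()}
    # With non_blocking=True the copy is queued on the CUDA stream and the
    # Python thread can keep working until the tensors are actually used.
    on_gpu = {name: tensor.to("cuda", non_blocking=True) for name, tensor in pinned.items()}
```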
34 changes: 29 additions & 5 deletions allennlp/data/batch.py
@@ -71,7 +71,10 @@ def get_padding_lengths(self) -> Dict[str, Dict[str, int]]:
return {**padding_lengths}

def as_tensor_dict(
self, padding_lengths: Dict[str, Dict[str, int]] = None, verbose: bool = False
self,
padding_lengths: Dict[str, Dict[str, int]] = None,
verbose: bool = False,
pin_memory: bool = False,
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
# This complex return type is actually predefined elsewhere as a DataArray,
# but we can't use it because mypy doesn't like it.
@@ -98,6 +101,10 @@ def as_tensor_dict(
But if you're doing this inside of a data generator, having all of this output per
batch is a bit obnoxious (and really slow).

pin_memory : `bool`, optional (default = `False`)
When `True`, tensors will be put into pinned (page-locked) memory, allowing for faster
(and potentially asynchronous) transfer to GPU.

# Returns

tensors : `Dict[str, DataArray]`
@@ -146,10 +153,27 @@ def as_tensor_dict(
# tensors together, so we grab a dictionary of field_name -> field class from the first
# instance in the batch.
field_classes = self.instances[0].fields
return {
field_name: field_classes[field_name].batch_tensors(field_tensor_list)
for field_name, field_tensor_list in field_tensors.items()
}

if pin_memory:
return {
field_name: self.maybe_pin(
field_classes[field_name].batch_tensors(field_tensor_list)
)
for field_name, field_tensor_list in field_tensors.items()
}
else:
return {
field_name: field_classes[field_name].batch_tensors(field_tensor_list)
for field_name, field_tensor_list in field_tensors.items()
}

def maybe_pin(self, maybe_tensor):
if isinstance(maybe_tensor, torch.Tensor):
return maybe_tensor.pin_memory()
elif isinstance(maybe_tensor, dict):
return {k: self.maybe_pin(v) for k, v in maybe_tensor.items()}
else:
return maybe_tensor

def __iter__(self) -> Iterator[Instance]:
return iter(self.instances)
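With the new `pin_memory` flag above, `as_tensor_dict` pins each batched tensor (recursively, via `maybe_pin`). A short usage sketch, where `instances` and `vocab` are assumed to already exist:

```python
from allennlp.data.batch import Batch

# `instances` is an assumed list of `Instance`s and `vocab` an assumed `Vocabulary`.
batch = Batch(instances)
batch.index_instances(vocab)
# Every tensor in the returned (possibly nested) dict lives in pinned host memory.
tensor_dict = batch.as_tensor_dict(pin_memory=True)
```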
4 changes: 2 additions & 2 deletions allennlp/data/data_loaders/data_loader.py
@@ -14,13 +14,13 @@
"""


def allennlp_collate(instances: List[Instance]) -> TensorDict:
def allennlp_collate(instances: List[Instance], *, pin_memory: bool = False) -> TensorDict:
"""
This is the default function used to turn a list of `Instance`s into a `TensorDict`
batch.
"""
batch = Batch(instances)
return batch.as_tensor_dict()
return batch.as_tensor_dict(pin_memory=pin_memory)


class DataLoader(Registrable):
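The flag is threaded through the default collate function as a keyword-only argument. A one-line usage sketch, again assuming a list of indexed `instances`:

```python
from allennlp.data.data_loaders.data_loader import allennlp_collate

# pin_memory is keyword-only, so it has to be spelled out at the call site.
tensor_dict = allennlp_collate(instances, pin_memory=True)
```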
26 changes: 19 additions & 7 deletions allennlp/data/data_loaders/multi_process_data_loader.py
@@ -4,7 +4,7 @@
import random
import sys
import traceback
from typing import List, Iterator, Optional, Callable, Iterable
from typing import List, Iterator, Optional, Iterable

import torch.multiprocessing as mp

@@ -75,9 +75,6 @@ class MultiProcessDataLoader(DataLoader):
the `reader` needs to implement
[`manual_multi_process_sharding`](/api/data/dataset_readers/dataset_reader/#datasetreader).

collate_fn: `Callable[[List[Instance]], TensorDict]`, optional (default = `allennlp_collate`)
The function used to turn `Instance`s into a `TensorDict` batch.

max_instances_in_memory: `int`, optional (default = `None`)
If not specified, all instances will be read and cached in memory for the duration
of the data loader's life. This is generally ideal when your data can fit in memory
@@ -94,6 +91,14 @@
The [start method](https://docs.python.org/3.7/library/multiprocessing.html#contexts-and-start-methods)
used to spin up workers.

pin_memory: `bool`, optional (default = `False`)
When `True`, CPU tensors will be put into pinned (page-locked) memory, which results in faster copies to GPU.
It also lets you make asynchronous copies to GPU by passing the `non_blocking=True` argument to
`.to()` or `.cuda()`.

See [the PyTorch docs](https://pytorch.org/docs/stable/notes/cuda.html#use-pinned-memory-buffers)
for more info.

!!! Note
In a typical AllenNLP configuration file, the `reader` and `data_path` parameters don't
get an entry under the "data_loader". The `reader` is constructed separately from
@@ -131,15 +136,16 @@ def __init__(
self,
reader: DatasetReader,
data_path: str,
*,
batch_size: int = None,
drop_last: bool = False,
shuffle: bool = False,
batch_sampler: BatchSampler = None,
batches_per_epoch: int = None,
num_workers: int = 0,
collate_fn: Callable[[List[Instance]], TensorDict] = allennlp_collate,
max_instances_in_memory: int = None,
start_method: str = "fork",
pin_memory: bool = False,
) -> None:
# Do some parameter validation.
if num_workers is not None and num_workers < 0:
@@ -169,6 +175,11 @@ def __init__(
elif max_instances_in_memory < 1:
raise ValueError("max_instances_in_memory must be at least 1")

if pin_memory and max_instances_in_memory is not None and start_method != "spawn":
raise ValueError(
"start_method must be set to 'spawn' when memory_pinning=True and max_instances_in_memory!=None"
)

self.reader = reader
self.data_path = data_path
self.batch_size = batch_size
@@ -177,9 +188,10 @@
self.batch_sampler = batch_sampler
self.batches_per_epoch = batches_per_epoch
self.num_workers = num_workers
self.collate_fn = collate_fn
self.collate_fn = allennlp_collate
self.max_instances_in_memory = max_instances_in_memory
self.start_method = start_method
self.pin_memory = pin_memory

# To make sure we have some backpressure in the worker queues we try to set
# reasonable defaults for the maximum size of these queues.
@@ -480,7 +492,7 @@ def _instances_to_batches(self, instance_iterator: Iterable[Instance]) -> Iterat
and len(batch) < self.batch_size # type: ignore[operator]
):
break
yield self.collate_fn(batch)
yield self.collate_fn(batch, pin_memory=self.pin_memory)


class WorkerError(Exception):
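Putting the loader-side pieces together: a sketch of constructing the loader with memory pinning enabled, mirroring the new test added later in this PR. `reader` and `data_path` are assumed to exist; note that with worker processes plus `max_instances_in_memory`, the validation above forces `start_method="spawn"`.

```python
from allennlp.data.data_loaders import MultiProcessDataLoader
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn.util import move_to_device

loader = MultiProcessDataLoader(
    reader=reader,        # assumed DatasetReader
    data_path=data_path,  # assumed path
    batch_size=32,
    num_workers=1,
    max_instances_in_memory=1024,
    start_method="spawn",  # required when pin_memory is combined with max_instances_in_memory
    pin_memory=True,
)
loader.index_with(Vocabulary.from_instances(loader.iter_instances()))
for batch in loader:
    # Pinned CPU tensors make this host-to-device copy asynchronous.
    batch = move_to_device(batch, 0, non_blocking=True)
```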
48 changes: 21 additions & 27 deletions allennlp/nn/util.py
@@ -19,43 +19,37 @@
T = TypeVar("T")


def has_tensor(obj) -> bool:
def move_to_device(obj, device: Union[torch.device, int], *, non_blocking: bool = False):
"""
Given a possibly complex data structure,
check if it has any torch.Tensors in it.
"""
if isinstance(obj, torch.Tensor):
return True
elif isinstance(obj, dict):
return any(has_tensor(value) for value in obj.values())
elif isinstance(obj, (list, tuple)):
return any(has_tensor(item) for item in obj)
else:
return False


def move_to_device(obj, cuda_device: Union[torch.device, int]):
"""
Given a structure (possibly) containing Tensors on the CPU,
move all the Tensors to the specified GPU (or do nothing, if they should be on the CPU).
Given a structure (possibly) containing Tensors,
move all the Tensors to the specified device (or do nothing, if they are already on
the target device).
"""
from allennlp.common.util import int_to_device

cuda_device = int_to_device(cuda_device)
device = int_to_device(device)

if cuda_device == torch.device("cpu") or not has_tensor(obj):
return obj
elif isinstance(obj, torch.Tensor):
return obj.cuda(cuda_device)
if isinstance(obj, torch.Tensor):
# You may be wondering why we don't just always call `obj.to(device)` since that would
# be a no-op anyway if `obj` is already on `device`. Well that works fine except
# when PyTorch is not compiled with CUDA support, in which case even calling
# `obj.to(torch.device("cpu"))` would result in an error.
return obj if obj.device == device else obj.to(device=device, non_blocking=non_blocking)
elif isinstance(obj, dict):
return {key: move_to_device(value, cuda_device) for key, value in obj.items()}
for key, value in obj.items():
obj[key] = move_to_device(value, device, non_blocking=non_blocking)
return obj
elif isinstance(obj, list):
return [move_to_device(item, cuda_device) for item in obj]
for i, item in enumerate(obj):
obj[i] = move_to_device(item, device, non_blocking=non_blocking)
return obj
elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
# This is the best way to detect a NamedTuple, it turns out.
return obj.__class__(*(move_to_device(item, cuda_device) for item in obj))
return obj.__class__(
*(move_to_device(item, device, non_blocking=non_blocking) for item in obj)
)
elif isinstance(obj, tuple):
return tuple(move_to_device(item, cuda_device) for item in obj)
return tuple(move_to_device(item, device, non_blocking=non_blocking) for item in obj)
else:
return obj

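Based on the rewritten function above, a small sketch of how `move_to_device` now handles nested structures: dicts and lists are updated in place, NamedTuples and plain tuples are rebuilt, and non-tensor leaves pass through untouched. Device 0 is illustrative and assumes a CUDA build.

```python
import torch
from allennlp.nn.util import move_to_device

nested = {
    "ids": torch.tensor([1, 2, 3]),
    "meta": ["strings", "are", "left", "alone"],
    "extras": [torch.zeros(2), {"inner": torch.ones(2)}],
}

# Moves every tensor it finds to cuda:0; the outer dict and inner list are
# mutated in place rather than copied, unlike the previous implementation.
moved = move_to_device(nested, 0, non_blocking=True)
assert moved is nested
```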
38 changes: 34 additions & 4 deletions tests/data/data_loaders/multi_process_data_loader_test.py
@@ -1,14 +1,16 @@
from typing import List, Iterable
from typing import List, Iterable, Dict

import torch
import pytest

from allennlp.common.testing import requires_gpu
from allennlp.data.instance import Instance
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.data_loaders import MultiProcessDataLoader, WorkerError
from allennlp.data.fields import TextField, MetadataField
from allennlp.data.fields import Field, TextField, MetadataField, TensorField
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn.util import move_to_device


class MockDatasetReader(DatasetReader):
@@ -40,9 +42,12 @@ def _read(self, file_path: str):
yield self.text_to_instance(i, source, target)

def text_to_instance(self, index: int, source: str, target: str = None) -> Instance: # type: ignore
fields = {}
fields: Dict[str, Field] = {}
fields["source"] = TextField(self.tokenizer.tokenize(source))
fields["index"] = MetadataField(index) # type: ignore
# It's important to have tests that use a tensor field since sending tensors
# between processes has a lot of pitfalls.
fields["tensor"] = TensorField(torch.tensor([1, 2, 3]))
if target is not None:
fields["target"] = TextField(self.tokenizer.tokenize(target))
return Instance(fields) # type: ignore
@@ -171,3 +176,28 @@ def test_batches_per_epoch():

assert len(loader) == 10
assert len(list(loader)) == 10


@pytest.mark.parametrize(
"options",
[
dict(num_workers=0, batch_size=2),
dict(num_workers=1, batch_size=2),
dict(num_workers=1, batch_size=2, max_instances_in_memory=10),
],
ids=str,
)
@requires_gpu
def test_pin_memory(options):
reader = MockDatasetReader()
loader = MultiProcessDataLoader(
reader=reader,
data_path="this doens't matter",
pin_memory=True,
start_method="spawn",
**options,
)
vocab = Vocabulary.from_instances(loader.iter_instances())
loader.index_with(vocab)
for batch in loader:
batch = move_to_device(batch, 0, non_blocking=True)
6 changes: 5 additions & 1 deletion tests/data/samplers/bucket_batch_sampler_test.py
@@ -91,14 +91,18 @@ def test_drop_last_works(self):
sorting_keys=["text"],
drop_last=True,
)

# We use a custom collate_fn for testing, which doesn't actually create tensors,
# just the allennlp Batches.
def collate_fn(x, **kwargs):
return Batch(x)

data_loader = MultiProcessDataLoader(
self.get_mock_reader(),
"fake_path",
batch_sampler=sampler,
collate_fn=lambda x: Batch(x),
)
data_loader.collate_fn = collate_fn
data_loader.index_with(self.vocab)
batches = [batch for batch in iter(data_loader)]
stats = self.get_batches_stats(batches)
23 changes: 2 additions & 21 deletions tests/nn/util_test.py
@@ -1427,25 +1427,6 @@ def test_combine_tensors_and_multiply_with_batch_size_one_and_seq_len_one(self):

assert_almost_equal(result.size(), [1, seq_len_1, seq_len_2])

def test_has_tensor(self):

has_tensor = util.has_tensor
tensor = torch.tensor([1, 2, 3])

assert has_tensor(["a", 10, tensor])
assert not has_tensor(["a", 10])

assert has_tensor(("a", 10, tensor))
assert not has_tensor(("a", 10))

assert has_tensor({"a": tensor, "b": 1})
assert not has_tensor({"a": 10, "b": 1})

assert has_tensor(tensor)
assert not has_tensor(3)

assert has_tensor({"x": [0, {"inside": {"double_inside": [3, [10, tensor]]}}]})

def test_combine_initial_dims(self):
tensor = torch.randn(4, 10, 20, 17, 5)

@@ -1471,13 +1452,13 @@ def test_inspect_model_parameters(self):
assert parameters_inspection_dict == util.inspect_parameters(model)

def test_move_to_device(self):
# We're faking the tensor here so that we can test the calls to .cuda() without actually
# We're faking the tensor here so that we can test the calls to .to() without actually
# needing a GPU.
class FakeTensor(torch.Tensor):
def __init__(self):
self._device = None

def cuda(self, device):
def to(self, device, **kwargs):
self._device = device
return self
