diff --git a/CHANGELOG.md b/CHANGELOG.md index 1211d8fbbb1..30c9b171459 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,9 +45,15 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with the `MultiProcessDataLoading` this is controlled by the `max_instances_in_memory` setting. - `ArrayField` is now called `TensorField`, and implemented in terms of torch tensors, not numpy. +- Improved `nn.util.move_to_device` function by avoiding an unnecessary recursive check for tensors and + adding a `non_blocking` optional argument, which is the same argument as in `torch.Tensor.to()`. - If you are trying to create a heterogeneous batch, you now get a better error message. - Readers using the new vision features now explicitly log how they are featurizing images. +### Removed + +- Removed `nn.util.has_tensor`. + ### Fixed - The `build-vocab` command no longer crashes when the resulting vocab file is diff --git a/allennlp/data/batch.py b/allennlp/data/batch.py index dc3bd01d540..0a98537c3d3 100644 --- a/allennlp/data/batch.py +++ b/allennlp/data/batch.py @@ -83,7 +83,9 @@ def get_padding_lengths(self) -> Dict[str, Dict[str, int]]: return {**padding_lengths} def as_tensor_dict( - self, padding_lengths: Dict[str, Dict[str, int]] = None, verbose: bool = False + self, + padding_lengths: Dict[str, Dict[str, int]] = None, + verbose: bool = False, ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: # This complex return type is actually predefined elsewhere as a DataArray, # but we can't use it because mypy doesn't like it. 
diff --git a/allennlp/data/data_loaders/data_loader.py b/allennlp/data/data_loaders/data_loader.py index 3724e3f849b..9895e1dd01e 100644 --- a/allennlp/data/data_loaders/data_loader.py +++ b/allennlp/data/data_loaders/data_loader.py @@ -29,11 +29,14 @@ class DataLoader(Registrable): [`DatasetReader`](/api/data/dataset_readers/dataset_reader/#datasetreader), or another source of data. - This class has three required methods: + This is purely an abstract base class. All concrete subclasses must provide + implementations of the following methods: - [`__iter__()`](#__iter__) that creates an iterable of `TensorDict`s, - - [`iter_instances()`](#iter_instances) that creates an iterable of `Instance`s, and - - [`index_with()`](#index_with) that should index the data with a vocabulary. + - [`iter_instances()`](#iter_instances) that creates an iterable of `Instance`s, + - [`index_with()`](#index_with) that should index the data with a vocabulary, and + - [`set_target_device()`](#set_target_device), which updates the device that batch + tensors should be put on when they are generated in `__iter__()`. Additionally, this class should also implement `__len__()` when possible. 
@@ -54,3 +57,6 @@ def iter_instances(self) -> Iterator[Instance]: def index_with(self, vocab: Vocabulary) -> None: raise NotImplementedError + + def set_target_device(self, device: torch.device) -> None: + raise NotImplementedError diff --git a/allennlp/data/data_loaders/multi_process_data_loader.py b/allennlp/data/data_loaders/multi_process_data_loader.py index d44b1b20f3e..fd8fe94f28f 100644 --- a/allennlp/data/data_loaders/multi_process_data_loader.py +++ b/allennlp/data/data_loaders/multi_process_data_loader.py @@ -4,8 +4,10 @@ import random import sys import traceback -from typing import List, Iterator, Optional, Callable, Iterable +from typing import List, Iterator, Optional, Iterable, Union +from overrides import overrides +import torch import torch.multiprocessing as mp from allennlp.common.util import lazy_groups_of, shuffle_iterable @@ -16,6 +18,7 @@ from allennlp.data.fields import TextField from allennlp.data.samplers import BatchSampler from allennlp.data.vocabulary import Vocabulary +import allennlp.nn.util as nn_util logger = logging.getLogger(__name__) @@ -42,6 +45,12 @@ class MultiProcessDataLoader(DataLoader): data_path: `str`, required Passed to `DatasetReader.read()`. + !!! Note + In a typical AllenNLP configuration file, the `reader` and `data_path` parameters don't + get an entry under the `data_loader`. The `reader` is constructed separately from + the corresponding `dataset_reader` params, and the `data_path` is taken from the + `train_data_path`, `validation_data_path`, or `test_data_path`. + batch_size: `int`, optional (default = `None`) When `batch_sampler` is unspecified, this option can be combined with `drop_last` and `shuffle` to control automatic batch sampling. @@ -75,9 +84,6 @@ class MultiProcessDataLoader(DataLoader): the `reader` needs to implement [`manual_multi_process_sharding`](/api/data/dataset_readers/dataset_reader/#datasetreader). 
- collate_fn: `Callable[[List[Instance]], TensorDict]`, optional (default = `allennlp_collate`) - The function used to turn `Instance`s into a `TensorDict` batch. - max_instances_in_memory: `int`, optional (default = `None`) If not specified, all instances will be read and cached in memory for the duration of the data loader's life. This is generally ideal when your data can fit in memory @@ -85,20 +91,23 @@ class MultiProcessDataLoader(DataLoader): will turn on lazy loading, where only `max_instances_in_memory` instances are processed at a time. - Note that this setting will affect how a `batch_sampler` is applied. If - `max_instances_in_memory` is `None`, the sampler will be applied to all `Instance`s. - Otherwise the sampler will be applied to only `max_instances_in_memory` `Instance`s - at a time. + !!! Note + This setting will affect how a `batch_sampler` is applied. If + `max_instances_in_memory` is `None`, the sampler will be applied to all `Instance`s. + Otherwise the sampler will be applied to only `max_instances_in_memory` `Instance`s + at a time. start_method: `str`, optional (default = `"fork"`) The [start method](https://docs.python.org/3.7/library/multiprocessing.html#contexts-and-start-methods) used to spin up workers. - !!! Note - In a typical AllenNLP configuration file, the `reader` and `data_path` parameters don't - get an entry under the "data_loader". The `reader` is constructed separately from - the corresponding `dataset_reader` params, and the `data_path` is taken from the - `train_data_path`, `validation_data_path`, or `test_data_path`. + cuda_device: `Optional[Union[int, str, torch.device]]`, optional (default = `None`) + If given, batches will automatically be put on this device. + + !!! Note + This should typically not be set in an AllenNLP configuration file. The `Trainer` + will automatically call [`set_target_device()`](#set_target_device) before iterating + over batches. !!! Warning Multiprocessing code in Python is complicated! 
Especially code that involves lower-level libraries @@ -131,15 +140,16 @@ def __init__( self, reader: DatasetReader, data_path: str, + *, batch_size: int = None, drop_last: bool = False, shuffle: bool = False, batch_sampler: BatchSampler = None, batches_per_epoch: int = None, num_workers: int = 0, - collate_fn: Callable[[List[Instance]], TensorDict] = allennlp_collate, max_instances_in_memory: int = None, start_method: str = "fork", + cuda_device: Optional[Union[int, str, torch.device]] = None, ) -> None: # Do some parameter validation. if num_workers is not None and num_workers < 0: @@ -177,9 +187,18 @@ def __init__( self.batch_sampler = batch_sampler self.batches_per_epoch = batches_per_epoch self.num_workers = num_workers - self.collate_fn = collate_fn + self.collate_fn = allennlp_collate self.max_instances_in_memory = max_instances_in_memory self.start_method = start_method + self.cuda_device: Optional[torch.device] = None + if cuda_device is not None: + if not isinstance(cuda_device, torch.device): + self.cuda_device = torch.device(cuda_device) + else: + self.cuda_device = cuda_device + + # Can only initialize CUDA in workers when these `start_methods` are used. + self._worker_cuda_safe = self.start_method in {"spawn", "forkserver"} # To make sure we have some backpressure in the worker queues we try to set # reasonable defaults for the maximum size of these queues. @@ -206,12 +225,14 @@ def __init__( # Load all instances right away. deque(self.iter_instances(), maxlen=0) + @overrides def index_with(self, vocab: Vocabulary) -> None: self._vocab = vocab if self._instances: for instance in self._instances: instance.index_fields(vocab) + @overrides def __len__(self) -> int: if self.batches_per_epoch is not None: return self.batches_per_epoch @@ -235,6 +256,7 @@ def __len__(self) -> int: # is not specified. 
raise TypeError + @overrides def __iter__(self) -> Iterator[TensorDict]: if self._vocab is None: raise ValueError( @@ -254,6 +276,7 @@ def __iter__(self) -> Iterator[TensorDict]: self._batch_generator = self._iter_batches() # so refresh it yield next(self._batch_generator) + @overrides def iter_instances(self) -> Iterator[Instance]: if self._instances: yield from self._instances @@ -293,9 +316,13 @@ def iter_instances(self) -> Iterator[Instance]: queue.close() # type: ignore[attr-defined] self._join_workers(workers, queue) + @overrides + def set_target_device(self, device: torch.device) -> None: + self.cuda_device = device + def _iter_batches(self) -> Iterator[TensorDict]: if self._instances is not None or self.num_workers <= 0: - for batch in self._instances_to_batches(self.iter_instances()): + for batch in self._instances_to_batches(self.iter_instances(), move_to_device=True): yield batch else: ctx = mp.get_context(self.start_method) @@ -318,6 +345,9 @@ def _iter_batches(self) -> Iterator[TensorDict]: sys.stderr.write("".join(tb)) raise WorkerError(e) + if not self._worker_cuda_safe and self.cuda_device is not None: + # Need to move batch to target device now. 
+ batch = nn_util.move_to_device(batch, self.cuda_device) yield batch queue.task_done() done_count += 1 @@ -398,7 +428,9 @@ def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None: try: self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id)) instances = self.reader.read(self.data_path) - for batch in self._instances_to_batches(instances): + for batch in self._instances_to_batches( + instances, move_to_device=self._worker_cuda_safe + ): queue.put((batch, None)) except Exception as e: queue.put((None, (str(e), traceback.format_exc()))) @@ -431,7 +463,9 @@ def _index_instance(self, instance: Instance) -> Instance: instance.index_fields(self._vocab) return instance - def _instances_to_batches(self, instance_iterator: Iterable[Instance]) -> Iterator[TensorDict]: + def _instances_to_batches( + self, instance_iterator: Iterable[Instance], move_to_device + ) -> Iterator[TensorDict]: instance_iterator = (self._index_instance(instance) for instance in instance_iterator) if self.max_instances_in_memory is not None: @@ -480,7 +514,10 @@ def _instances_to_batches(self, instance_iterator: Iterable[Instance]) -> Iterat and len(batch) < self.batch_size # type: ignore[operator] ): break - yield self.collate_fn(batch) + tensor_dict = self.collate_fn(batch) + if move_to_device and self.cuda_device is not None: + tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device) + yield tensor_dict class WorkerError(Exception): diff --git a/allennlp/data/data_loaders/multitask_data_loader.py b/allennlp/data/data_loaders/multitask_data_loader.py index ac853186515..f4b4869e875 100644 --- a/allennlp/data/data_loaders/multitask_data_loader.py +++ b/allennlp/data/data_loaders/multitask_data_loader.py @@ -1,7 +1,10 @@ -from typing import Any, Dict, Iterable, Iterator, List +from typing import Any, Dict, Iterable, Iterator, List, Union, Optional import itertools import math +import torch +from overrides import overrides + from allennlp.common import util from 
allennlp.data.batch import Batch from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict @@ -14,6 +17,7 @@ from allennlp.data.dataset_readers.multitask import MultiTaskDatasetReader from allennlp.data.instance import Instance from allennlp.data.vocabulary import Vocabulary +import allennlp.nn.util as nn_util def maybe_shuffle_instances(loader: DataLoader, shuffle: bool) -> Iterable[Instance]: @@ -109,6 +113,13 @@ class MultiTaskDataLoader(DataLoader): shuffle: `bool`, optional (default = `True`) If `False`, we will not shuffle the instances that come from each underlying data loader. You almost certainly never want to use this except when debugging. + cuda_device: `Optional[Union[int, str, torch.device]]`, optional (default = `None`) + If given, batches will automatically be put on this device. + + !!! Note + This should typically not be set in an AllenNLP configuration file. The `Trainer` + will automatically call [`set_target_device()`](#set_target_device) before iterating + over batches. 
""" def __init__( @@ -128,11 +139,18 @@ def __init__( instance_queue_size: Dict[str, int] = None, instance_chunk_size: Dict[str, int] = None, shuffle: bool = True, + cuda_device: Optional[Union[int, str, torch.device]] = None, ) -> None: self.readers = reader.readers self.data_paths = data_path self.scheduler = scheduler or HomogeneousRoundRobinScheduler(batch_size=batch_size) self.sampler = sampler + self.cuda_device: Optional[torch.device] = None + if cuda_device is not None: + if not isinstance(cuda_device, torch.device): + self.cuda_device = torch.device(cuda_device) + else: + self.cuda_device = cuda_device self._batch_size = batch_size self._instances_per_epoch = instances_per_epoch @@ -186,6 +204,7 @@ def __init__( for key, loader in self._loaders.items() } + @overrides def __len__(self) -> int: if self._instances_per_epoch is not None: return self._instances_per_epoch @@ -204,6 +223,7 @@ def __len__(self) -> int: else: return int(1 + total_instances) // self._batch_size + @overrides def __iter__(self) -> Iterator[TensorDict]: # Basic outline: first we _sample_ the instances that we're going to be using for this # epoch, which relies on the scheduler if `self._instances_per_epoch` is not None. This is @@ -219,7 +239,10 @@ def __iter__(self) -> Iterator[TensorDict]: current_batch_size += self._batch_size_multiplier.get(dataset, 1.0) if current_batch_size > self._batch_size: batch = Batch(batch_instances) - yield batch.as_tensor_dict() + tensor_dict = batch.as_tensor_dict() + if self.cuda_device is not None: + tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device) + yield tensor_dict batch_instances = [instance] current_batch_size = self._batch_size_multiplier.get(dataset, 1.0) else: @@ -229,8 +252,12 @@ def __iter__(self) -> Iterator[TensorDict]: # so we don't need a check for that here. 
if not self._drop_last or current_batch_size == self._batch_size: batch = Batch(batch_instances) - yield batch.as_tensor_dict() + tensor_dict = batch.as_tensor_dict() + if self.cuda_device is not None: + tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device) + yield tensor_dict + @overrides def iter_instances(self) -> Iterator[Instance]: # The only external contract for this method is that it iterates over instances # individually; it doesn't actually specify anything about batching or anything else. The @@ -248,10 +275,15 @@ def iter_instances(self) -> Iterator[Instance]: for loader in self._loaders.values(): yield from loader.iter_instances() + @overrides def index_with(self, vocab: Vocabulary) -> None: for loader in self._loaders.values(): loader.index_with(vocab) + @overrides + def set_target_device(self, device: torch.device) -> None: + self.cuda_device = device + def _get_instances_for_epoch(self) -> Dict[str, Iterable[Instance]]: if self._instances_per_epoch is None: return { diff --git a/allennlp/data/data_loaders/pytorch_data_loader.py b/allennlp/data/data_loaders/pytorch_data_loader.py index 9bbc7eb4595..690b38b6adc 100644 --- a/allennlp/data/data_loaders/pytorch_data_loader.py +++ b/allennlp/data/data_loaders/pytorch_data_loader.py @@ -1,5 +1,7 @@ -from typing import List, Iterator, Optional +from typing import List, Iterator, Optional, Union +from overrides import overrides +import torch from torch.utils import data from allennlp.common.lazy import Lazy @@ -9,6 +11,7 @@ from allennlp.data.samplers import PyTorchSampler, PyTorchBatchSampler from allennlp.data.vocabulary import Vocabulary from allennlp.data.data_loaders.data_loader import DataLoader, allennlp_collate +import allennlp.nn.util as nn_util class AllennlpDataset(data.Dataset): @@ -162,6 +165,7 @@ def __init__( timeout: int = 0, multiprocessing_context: str = None, batches_per_epoch: int = None, + cuda_device: Optional[Union[int, str, torch.device]] = None, ): super().__init__( 
dataset, @@ -179,32 +183,52 @@ def __init__( ) self._data_generator = super().__iter__() self._batches_per_epoch = batches_per_epoch - + self.cuda_device: Optional[torch.device] = None + if cuda_device is not None: + if not isinstance(cuda_device, torch.device): + self.cuda_device = torch.device(cuda_device) + else: + self.cuda_device = cuda_device + + @overrides def __len__(self): if self._batches_per_epoch is not None: return self._batches_per_epoch return super().__len__() + @overrides def __iter__(self): if self._batches_per_epoch is None: # NOTE: since torch's DataLoader is listed as the first super class of this class, # super().__iter__() will resolve to the __iter__ method from torch's DataLoader, # which is what we want. - yield from super().__iter__() + for tensor_dict in super().__iter__(): + if self.cuda_device is not None: + tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device) + yield tensor_dict else: for i in range(self._batches_per_epoch): try: - yield next(self._data_generator) + tensor_dict = next(self._data_generator) except StopIteration: # data_generator is exhausted self._data_generator = super().__iter__() # so refresh it - yield next(self._data_generator) # and yield required instance + tensor_dict = next(self._data_generator) # and yield required instance + if self.cuda_device is not None: + tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device) + yield tensor_dict + @overrides def iter_instances(self) -> Iterator[Instance]: yield from self.dataset + @overrides def index_with(self, vocab: Vocabulary): self.dataset.index_with(vocab) # type: ignore + @overrides + def set_target_device(self, device: torch.device) -> None: + self.cuda_device = device + @classmethod def from_partial_objects( cls, diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index c593694aaed..a5c7574ef70 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -19,43 +19,35 @@ T = TypeVar("T") -def has_tensor(obj) -> bool: +def 
move_to_device(obj, device: Union[torch.device, int]): """ - Given a possibly complex data structure, - check if it has any torch.Tensors in it. - """ - if isinstance(obj, torch.Tensor): - return True - elif isinstance(obj, dict): - return any(has_tensor(value) for value in obj.values()) - elif isinstance(obj, (list, tuple)): - return any(has_tensor(item) for item in obj) - else: - return False - - -def move_to_device(obj, cuda_device: Union[torch.device, int]): - """ - Given a structure (possibly) containing Tensors on the CPU, - move all the Tensors to the specified GPU (or do nothing, if they should be on the CPU). + Given a structure (possibly) containing Tensors, + move all the Tensors to the specified device (or do nothing, if they are already on + the target device). """ from allennlp.common.util import int_to_device - cuda_device = int_to_device(cuda_device) + device = int_to_device(device) - if cuda_device == torch.device("cpu") or not has_tensor(obj): - return obj - elif isinstance(obj, torch.Tensor): - return obj.cuda(cuda_device) + if isinstance(obj, torch.Tensor): + # You may be wondering why we don't just always call `obj.to(device)` since that would + # be a no-op anyway if `obj` is already on `device`. Well that works fine except + # when PyTorch is not compiled with CUDA support, in which case even calling + # `obj.to(torch.device("cpu"))` would result in an error. + return obj if obj.device == device else obj.to(device=device) elif isinstance(obj, dict): - return {key: move_to_device(value, cuda_device) for key, value in obj.items()} + for key, value in obj.items(): + obj[key] = move_to_device(value, device) + return obj elif isinstance(obj, list): - return [move_to_device(item, cuda_device) for item in obj] + for i, item in enumerate(obj): + obj[i] = move_to_device(item, device) + return obj elif isinstance(obj, tuple) and hasattr(obj, "_fields"): # This is the best way to detect a NamedTuple, it turns out. 
- return obj.__class__(*(move_to_device(item, cuda_device) for item in obj)) + return obj.__class__(*(move_to_device(item, device) for item in obj)) elif isinstance(obj, tuple): - return tuple(move_to_device(item, cuda_device) for item in obj) + return tuple(move_to_device(item, device) for item in obj) else: return obj diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py index 7dda0a9e267..a4726de9405 100644 --- a/allennlp/training/trainer.py +++ b/allennlp/training/trainer.py @@ -22,7 +22,6 @@ from allennlp.common.checks import ConfigurationError, check_for_gpu from allennlp.data import DataLoader, TensorDict from allennlp.models.model import Model -from allennlp.nn import util as nn_util from allennlp.training import util as training_util from allennlp.training.checkpointer import Checkpointer from allennlp.training.learning_rate_schedulers import LearningRateScheduler @@ -499,7 +498,10 @@ def __init__( self.model = model self.data_loader = data_loader + self.data_loader.set_target_device(self.cuda_device) self._validation_data_loader = validation_data_loader + if self._validation_data_loader is not None: + self._validation_data_loader.set_target_device(self.cuda_device) self.optimizer = optimizer if patience is None: # no early stopping @@ -599,7 +601,6 @@ def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torc Does a forward pass on the given batch and returns the output dictionary that the model returns, after adding any specified regularization penalty to the loss (if training). 
""" - batch = nn_util.move_to_device(batch, self.cuda_device) output_dict = self._pytorch_model(**batch) if for_training: diff --git a/allennlp/training/util.py b/allennlp/training/util.py index ad2adac2685..09feb3481f3 100644 --- a/allennlp/training/util.py +++ b/allennlp/training/util.py @@ -16,7 +16,7 @@ from allennlp.common.checks import check_for_gpu, ConfigurationError from allennlp.common.params import Params from allennlp.common.tqdm import Tqdm -from allennlp.common.util import dump_metrics, sanitize +from allennlp.common.util import dump_metrics, sanitize, int_to_device from allennlp.data import Instance, Vocabulary, Batch, DataLoader from allennlp.data.dataset_readers import DatasetReader from allennlp.models.archival import CONFIG_NAME @@ -315,6 +315,7 @@ def evaluate( The final metrics. """ check_for_gpu(cuda_device) + data_loader.set_target_device(int_to_device(cuda_device)) predictions_file = ( None if predictions_output_file is None else open(predictions_output_file, "w") ) diff --git a/tests/commands/evaluate_test.py b/tests/commands/evaluate_test.py index 971816b1411..8ad4e624df5 100644 --- a/tests/commands/evaluate_test.py +++ b/tests/commands/evaluate_test.py @@ -23,6 +23,9 @@ def __iter__(self) -> Iterator[TensorDict]: def __len__(self): return len(self._outputs) + def set_target_device(self, _): + pass + class DummyModel(Model): def __init__(self) -> None: diff --git a/tests/data/data_loaders/multi_process_data_loader_test.py b/tests/data/data_loaders/multi_process_data_loader_test.py index dc7875e1295..8ae2000c317 100644 --- a/tests/data/data_loaders/multi_process_data_loader_test.py +++ b/tests/data/data_loaders/multi_process_data_loader_test.py @@ -1,11 +1,12 @@ -from typing import List, Iterable +from typing import List, Iterable, Dict +import torch import pytest - +from allennlp.common.testing import requires_gpu from allennlp.data.instance import Instance from allennlp.data.dataset_readers import DatasetReader from 
allennlp.data.data_loaders import MultiProcessDataLoader, WorkerError -from allennlp.data.fields import TextField, MetadataField +from allennlp.data.fields import Field, TextField, MetadataField, TensorField from allennlp.data.tokenizers import PretrainedTransformerTokenizer from allennlp.data.token_indexers import PretrainedTransformerIndexer from allennlp.data.vocabulary import Vocabulary @@ -40,9 +41,12 @@ def _read(self, file_path: str): yield self.text_to_instance(i, source, target) def text_to_instance(self, index: int, source: str, target: str = None) -> Instance: # type: ignore - fields = {} + fields: Dict[str, Field] = {} fields["source"] = TextField(self.tokenizer.tokenize(source)) fields["index"] = MetadataField(index) # type: ignore + # It's important to have tests that use a tensor field since sending tensors + # between processes has a lot of pitfalls. + fields["tensor"] = TensorField(torch.tensor([1, 2, 3])) if target is not None: fields["target"] = TextField(self.tokenizer.tokenize(target)) return Instance(fields) # type: ignore @@ -171,3 +175,27 @@ def test_batches_per_epoch(): assert len(loader) == 10 assert len(list(loader)) == 10 + + +@pytest.mark.parametrize( + "options", + [ + dict(num_workers=0, batch_size=2), + dict(num_workers=1, batch_size=2), + dict(num_workers=1, batch_size=2, start_method="spawn"), + ], + ids=str, +) +@requires_gpu +def test_load_to_cuda(options): + reader = MockDatasetReader() + loader = MultiProcessDataLoader( + reader=reader, + data_path="this doens't matter", + cuda_device=0, + **options, + ) + vocab = Vocabulary.from_instances(loader.iter_instances()) + loader.index_with(vocab) + for batch in loader: + assert batch["tensor"].device == torch.device("cuda:0") diff --git a/tests/data/samplers/bucket_batch_sampler_test.py b/tests/data/samplers/bucket_batch_sampler_test.py index 0d2e814838b..3a972facdc2 100644 --- a/tests/data/samplers/bucket_batch_sampler_test.py +++ b/tests/data/samplers/bucket_batch_sampler_test.py 
@@ -91,14 +91,18 @@ def test_drop_last_works(self): sorting_keys=["text"], drop_last=True, ) + # We use a custom collate_fn for testing, which doesn't actually create tensors, # just the allennlp Batches. + def collate_fn(x, **kwargs): + return Batch(x) + data_loader = MultiProcessDataLoader( self.get_mock_reader(), "fake_path", batch_sampler=sampler, - collate_fn=lambda x: Batch(x), ) + data_loader.collate_fn = collate_fn data_loader.index_with(self.vocab) batches = [batch for batch in iter(data_loader)] stats = self.get_batches_stats(batches) diff --git a/tests/nn/util_test.py b/tests/nn/util_test.py index 705f3f7ab74..d98439534ff 100644 --- a/tests/nn/util_test.py +++ b/tests/nn/util_test.py @@ -1427,25 +1427,6 @@ def test_combine_tensors_and_multiply_with_batch_size_one_and_seq_len_one(self): assert_almost_equal(result.size(), [1, seq_len_1, seq_len_2]) - def test_has_tensor(self): - - has_tensor = util.has_tensor - tensor = torch.tensor([1, 2, 3]) - - assert has_tensor(["a", 10, tensor]) - assert not has_tensor(["a", 10]) - - assert has_tensor(("a", 10, tensor)) - assert not has_tensor(("a", 10)) - - assert has_tensor({"a": tensor, "b": 1}) - assert not has_tensor({"a": 10, "b": 1}) - - assert has_tensor(tensor) - assert not has_tensor(3) - - assert has_tensor({"x": [0, {"inside": {"double_inside": [3, [10, tensor]]}}]}) - def test_combine_initial_dims(self): tensor = torch.randn(4, 10, 20, 17, 5) @@ -1471,13 +1452,13 @@ def test_inspect_model_parameters(self): assert parameters_inspection_dict == util.inspect_parameters(model) def test_move_to_device(self): - # We're faking the tensor here so that we can test the calls to .cuda() without actually + # We're faking the tensor here so that we can test the calls to .to() without actually # needing a GPU. 
class FakeTensor(torch.Tensor): def __init__(self): self._device = None - def cuda(self, device): + def to(self, device, **kwargs): self._device = device return self diff --git a/tests/training/trainer_test.py b/tests/training/trainer_test.py index e64ed6f8c1a..8d46b8af40c 100644 --- a/tests/training/trainer_test.py +++ b/tests/training/trainer_test.py @@ -669,6 +669,9 @@ def __iter__(self): def __len__(self): return len(self.data_loader) + def set_target_device(self, _): + pass + trainer = GradientDescentTrainer( self.model, self.optimizer,