
Fix shared memory with gluon dataloader, add option pin_memory #11908

Merged (13 commits) on Aug 10, 2018
1 change: 1 addition & 0 deletions ci/windows/test_py2_cpu.ps1
@@ -16,6 +16,7 @@
# under the License.

7z x -y windows_package.7z
$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
$env:PYTHONPATH=join-path $pwd.Path windows_package\python
$env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
c:\Anaconda3\envs\py2\Scripts\pip install -r tests\requirements.txt
1 change: 1 addition & 0 deletions ci/windows/test_py2_gpu.ps1
@@ -16,6 +16,7 @@
# under the License.

7z x -y windows_package.7z
$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
$env:PYTHONPATH=join-path $pwd.Path windows_package\python
$env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
c:\Anaconda3\envs\py2\Scripts\pip install -r tests\requirements.txt
1 change: 1 addition & 0 deletions ci/windows/test_py3_cpu.ps1
@@ -16,6 +16,7 @@
# under the License.

7z x -y windows_package.7z
$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
$env:PYTHONPATH=join-path $pwd.Path windows_package\python
$env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
c:\Anaconda3\envs\py3\Scripts\pip install -r tests\requirements.txt
1 change: 1 addition & 0 deletions ci/windows/test_py3_gpu.ps1
@@ -16,6 +16,7 @@
# under the License.

7z x -y windows_package.7z
$env:MXNET_LIBRARY_PATH=join-path $pwd.Path windows_package\lib\libmxnet.dll
$env:PYTHONPATH=join-path $pwd.Path windows_package\python
$env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
c:\Anaconda3\envs\py3\Scripts\pip install -r tests\requirements.txt
61 changes: 45 additions & 16 deletions python/mxnet/gluon/data/dataloader.py
@@ -16,7 +16,7 @@
# under the License.

# coding: utf-8
# pylint: disable=ungrouped-imports
"""Dataset generator."""
__all__ = ['DataLoader']

@@ -26,6 +26,7 @@
import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
import threading
import numpy as np

try:
@@ -149,6 +150,14 @@ def default_mp_batchify_fn(data):
ctx=context.Context('cpu_shared', 0))


def _as_in_context(data, ctx):
"""Move data into new context."""
if isinstance(data, nd.NDArray):
return data.as_in_context(ctx)
elif isinstance(data, (list, tuple)):
return [_as_in_context(d, ctx) for d in data]
return data

def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
dataset._fork()
@@ -159,9 +168,21 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
batch = batchify_fn([dataset[i] for i in samples])
data_queue.put((idx, batch))

def fetcher_loop(data_queue, data_buffer, pin_memory=False):
"""Fetcher loop for fetching data from queue and put in reorder dict."""
while True:
idx, batch = data_queue.get()
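        # a (None, None) sentinel sent by shutdown() terminates this thread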
if idx is None:
break
if pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
else:
batch = _as_in_context(batch, context.cpu())
data_buffer[idx] = batch

class _MultiWorkerIter(object):
"""Interal multi-worker iterator for DataLoader."""
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False):
assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
self._num_workers = num_workers
self._dataset = dataset
@@ -184,6 +205,12 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
worker.start()
workers.append(worker)

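        # a single fetcher thread drains data_queue into data_buffer,
        # optionally moving each batch into pinned memory on the way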
self._fetcher = threading.Thread(
target=fetcher_loop,
args=(self._data_queue, self._data_buffer, pin_memory))
self._fetcher.daemon = True
self._fetcher.start()

# pre-fetch
for _ in range(2 * self._num_workers):
self._push_next()
@@ -210,13 +237,11 @@ def __next__(self):
raise StopIteration

        while True:
            if self._rcvd_idx in self._data_buffer:
                batch = self._data_buffer.pop(self._rcvd_idx)
                self._rcvd_idx += 1
                self._push_next()
                return batch

def next(self):
return self.__next__()
@@ -229,11 +254,7 @@ def shutdown(self):
if not self._shutdown:
for _ in range(self._num_workers):
self._key_queue.put((None, None))
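            # stop the fetcher thread with a sentinel instead of draining the queue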
self._data_queue.put((None, None))
self._shutdown = True


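Side note: the fetcher/buffer pair above implements a small reorder stage. Workers emit (idx, batch) pairs in completion order, and the fetcher files them by index so __next__ can pop batches in submission order. A self-contained sketch of the same pattern (illustrative names, single process, not part of this diff):

import threading
from queue import Queue

data_queue = Queue()   # stand-in for the multiprocessing results queue
data_buffer = {}       # results keyed by idx for in-order consumption

def fetcher(queue, buf):
    while True:
        idx, batch = queue.get()
        if idx is None:   # (None, None) sentinel, as sent by shutdown()
            return
        buf[idx] = batch

t = threading.Thread(target=fetcher, args=(data_queue, data_buffer), daemon=True)
t.start()
for item in [(1, 'batch1'), (0, 'batch0')]:   # batches can finish out of order
    data_queue.put(item)
data_queue.put((None, None))
t.join()
assert [data_buffer.pop(i) for i in (0, 1)] == ['batch0', 'batch1']
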
@@ -277,12 +298,16 @@ def default_batchify_fn(data):

num_workers : int, default 0
The number of multiprocessing workers to use for data preprocessing.
`num_workers > 0` is not supported on Windows yet.
pin_memory : boolean, default False
Member: Why default to False?

Member Author: Not necessary for non-GPU instances, and using non-pageable memory for everything can hurt stability.

If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False):
self._dataset = dataset
self._pin_memory = pin_memory

if batch_sampler is None:
if batch_size is None:
@@ -315,13 +340,17 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,

def __iter__(self):
if self._num_workers == 0:
def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
yield ret
return same_process_iter()

# multi-worker
return _MultiWorkerIter(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler, self._pin_memory)

def __len__(self):
return len(self._batch_sampler)
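
For reference, a minimal usage sketch of the new option (illustrative and not part of this diff; assumes a CUDA GPU is available):

import mxnet as mx
from mxnet.gluon.data import ArrayDataset, DataLoader

dataset = ArrayDataset(mx.nd.random.uniform(shape=(100, 28, 28)),
                       mx.nd.arange(100))
# pin_memory=True returns batches in page-locked (pinned) host memory,
# which speeds up the host-to-device copies below
loader = DataLoader(dataset, batch_size=10, num_workers=2, pin_memory=True)
for data, label in loader:
    data = data.as_in_context(mx.gpu(0))
    label = label.as_in_context(mx.gpu(0))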
168 changes: 79 additions & 89 deletions tests/python/unittest/test_gluon_data.py
@@ -130,99 +130,89 @@ def test_multi_worker():
for i, batch in enumerate(loader):
assert (batch.asnumpy() == i).all()

class _Dummy(Dataset):
"""Dummy dataset for randomized shape arrays."""
def __init__(self, random_shape):
self.random_shape = random_shape

def __getitem__(self, idx):
key = idx
if self.random_shape:
out = np.random.uniform(size=(random.randint(1000, 1100), 40))
labels = np.random.uniform(size=(random.randint(10, 15)))
else:
out = np.random.uniform(size=(1000, 40))
labels = np.random.uniform(size=(10))
return key, out, labels

def __len__(self):
return 50

def _batchify_list(data):
"""
return list of ndarray without stack/concat/pad
"""
if isinstance(data, (tuple, list)):
return list(data)
if isinstance(data, mx.nd.NDArray):
return [data]
return data

def _batchify(data):
"""
Collate data into batch. Use shared memory for stacking.
    :param data: a list of arrays, with layout 'NTC'.
    :return: (x, x's unpadded lengths) if labels are not supplied, otherwise
        (x, x's unpadded lengths, y, y's unpadded lengths).
"""

# input layout is NTC
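    # i.e. (batch, time, channel); sequences are padded below to the batch-wide
    # max length, then transposed to TNC (labels to TN) for the output arrays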
keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
[item[2] for item in data]

if len(data) > 1:
max_data_len = max([seq.shape[0] for seq in inputs])
max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
else:
max_data_len = inputs[0].shape[0]
max_labels_len = 0 if not labels else labels[0].shape[0]

x_lens = [item.shape[0] for item in inputs]
y_lens = [item.shape[0] for item in labels]

for i, seq in enumerate(inputs):
pad_len = max_data_len - seq.shape[0]
inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
'constant', constant_values=-1)

inputs = np.asarray(inputs, dtype=np.float32)
if labels is not None:
labels = np.asarray(labels, dtype=np.float32)
inputs = inputs.transpose((1, 0, 2))
labels = labels.transpose((1, 0))

return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
if labels is None else (
nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))

@with_seed()
def test_multi_worker_forked_data_loader():
data = _Dummy(False)
loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2)
for epoch in range(1):
for i, data in enumerate(loader):
pass

data = _Dummy(True)
loader = DataLoader(data, batch_size=40, batchify_fn=_batchify_list, num_workers=2)
for epoch in range(1):
for i, data in enumerate(loader):
pass

if __name__ == '__main__':
import nose