Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable shuffle of slice data source at the end of epoch #1160

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/src/nnabla/utils/data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def _next(self):
for i, v in enumerate(self._variables):
data[i].append(d[i])

if n_reset != 0:
self._data_source.apply_order()
self._queue.put((tuple([numpy.array(x) for x in data]), n_reset))

def next(self):
Expand Down
37 changes: 32 additions & 5 deletions python/src/nnabla/utils/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ class DataSource(object):
def _get_data(self, position):
    # Abstract hook: return the sample stored at ``position``.
    # The base implementation intentionally does nothing (returns None);
    # concrete data sources override it.
    pass

def _get_data_of_generation(self, position, generation):
    # Generation-aware fetch. The base class keeps no per-generation
    # orders, so ``generation`` is ignored and the call falls through to
    # ``_get_data``. Sources that memoize orders per generation
    # (e.g. SimpleDataSource) override this.
    return self._get_data(position)

def _generate_order(self, generation):
    """
    Ask this data source to generate a new index order labelled with
    ``generation``.

    Older orders are kept rather than overwritten, because another slice
    may still be reading from a previous generation. This situation never
    occurs in the multi-node scenario, since the multiple slice data
    sources are strictly synchronized there.

    The base implementation is a no-op; shuffling sources override it.
    """
    pass

def __init__(self, shuffle=False, rng=None):
'''
Init method for DataSource
Expand Down Expand Up @@ -96,6 +109,10 @@ def close(self):
atexit.unregister(self.close)
self._closed = True

@property
def order(self):
    # Read-only access to the current index order.
    # NOTE(review): ``self._order`` is assigned outside this diff hunk —
    # confirm where it is initialized in the full file.
    return self._order

@property
def variables(self):
'''variables
Expand Down Expand Up @@ -147,6 +164,13 @@ def shuffle(self, value):
def reset(self):
    # Rewind the read cursor to the beginning of the data source.
    self._position = 0

def apply_order(self):
    """
    Hook invoked when a batch completes an epoch: the right moment to
    switch to a freshly shuffled order when shuffle is enabled.

    The base implementation is a no-op; sliced/shuffling data sources
    override it.
    """
    pass


class DataSourceWithFileCacheError(Exception):
    # Error type used by DataSourceWithFileCache for cache-related failures.
    pass
Expand All @@ -168,7 +192,7 @@ class DataSourceWithFileCache(DataSource):
Default is None.
cache_file_name_prefix (str):
Beginning of the filenames of cache files.
Default is 'cache'.
Default is 'cache'.
shuffle (bool):
Indicates whether the dataset is shuffled or not.
rng (None or :obj:`numpy.random.RandomState`): Numpy random number
Expand Down Expand Up @@ -599,16 +623,19 @@ def __init__(self, data_source, shuffle=False, rng=None, slice_start=None, slice
self._slice_start = slice_start
self._slice_end = slice_end
self._size = self._slice_end - self._slice_start
self._generation = -1
self.reset()
self._generation = 0

def apply_order(self):
    """Advance this slice to the next shuffle generation.

    Bumps the local generation counter, then asks the wrapped data source
    to make sure an index order for that generation exists.
    """
    next_gen = self._generation + 1
    self._generation = next_gen
    self._data_source._generate_order(next_gen)

def reset(self):
    # Reset the wrapped source and park its cursor at this slice's start.
    self._data_source.reset()
    self._data_source._position = self._slice_start
    # NOTE(review): bumping the generation here as well as in apply_order()
    # would advance it twice around an epoch boundary; this line may belong
    # to the pre-change side of the scraped diff — confirm against the
    # merged file.
    self._generation += 1
    self._position = 0

def _get_data(self, position):
self._position = position
data = self._data_source._get_data(self._slice_start + position)
data = self._data_source._get_data_of_generation(
self._slice_start + position, self._generation)
return data
44 changes: 34 additions & 10 deletions python/src/nnabla/utils/data_source_implements.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import csv
import os
import threading
from collections import OrderedDict
from collections import OrderedDict, deque
from time import sleep

import numpy
Expand All @@ -42,25 +42,49 @@ class SimpleDataSource(DataSource):
'''

def _get_data(self, position):
return self._load_func(self._order[position])
return self._get_data_of_generation(position, 0)

def _get_data_of_generation(self, position, generation):
index = None
for s in self._orders:
if s.generation == generation:
index = s.order
break
if index is None:
raise ValueError("Its client is not synchronized!")
return self._load_func(index[position])

def apply_order(self):
    """Advance to the next generation and make sure its order exists."""
    bumped = self._generation + 1
    self._generation = bumped
    self._generate_order(bumped)

def _generate_order(self, generation):
for s in self._orders:
if s.generation == generation:
return

class Snapshot:
pass
snapshot = Snapshot()
snapshot.order = self._rng.permutation(
self._size) if self._shuffle else numpy.arange(self._size)
snapshot.generation = generation
self._orders.append(snapshot)

def reset(self):
    # Regenerate the (possibly shuffled) index list, then rewind the cursor
    # via the base class.
    # NOTE(review): ``self._indexes`` is not read by the generation-based
    # lookup (``_get_data_of_generation`` uses ``self._orders``); this line
    # may belong to the pre-change side of the scraped diff — confirm
    # against the merged file.
    self._indexes = self._rng.permutation(
        self._size) if self._shuffle else numpy.arange(self._size)
    super(SimpleDataSource, self).reset()

def __init__(self, load_func, num_examples, shuffle=False, rng=None):
    """Build a SimpleDataSource.

    Args:
        load_func (callable): maps an integer position to one sample.
        num_examples (int): dataset size.
        shuffle (bool): whether index orders are random permutations.
        rng: numpy RandomState-like generator used when shuffling.
    """
    # NOTE(review): the scraped diff interleaved the pre-change lines
    # (``super(...).reset()`` and the ``if shuffle: self._order = ...``
    # block, which reference the attribute this change removes); the body
    # below is the post-change version with indentation restored.
    super(SimpleDataSource, self).__init__(shuffle=shuffle, rng=rng)
    self._load_func = load_func
    self._size = num_examples
    self._generation = 0
    # One variable name per element returned by load_func for position 0.
    self._variables = ['x' + str(x)
                       for x in range(len(self._load_func(0)))]
    # We set maximal unsynchronized generations: at most 15 generations of
    # orders are retained before the oldest is evicted.
    self._orders = deque(maxlen=15)
    self.reset()
    self._generate_order(0)


class CachePrefetcher(object):
Expand Down
84 changes: 54 additions & 30 deletions python/test/utils/test_sliced_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def lcm(a, b):
max_epoch = lcm(batch_size, size) / size

def test_load_func(position):
return np.full((1), position, dtype=int)
return np.full((1), position, dtype=np.int32)

def simple_load_func(data_set, position):
return data_set[position]
Expand All @@ -65,7 +65,7 @@ def get_data(iter_list, iter_num):
sliced_di_list.append(sliced_di)

ref_di_list = []
all_data = [np.full((1), position, dtype=int)
all_data = [np.full((1), position, dtype=np.int32)
for position in range(size)]
slice_block_size = size // num_of_slices
if not drop_last:
Expand Down Expand Up @@ -109,46 +109,70 @@ def get_data(iter_list, iter_num):
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.parametrize("drop_last", [True])
def test_sliced_data_iterator_duplicated_unexpectedly(test_data_csv_png_10, num_of_slices, size, batch_size, shuffle, drop_last):
    # NOTE(review): this hunk of the scraped diff interleaved pre- and
    # post-change lines; the body below is the post-change version with
    # indentation restored — confirm against the merged file.

    # As the discussion in this: https://app.slack.com/client/TAVSDRN92/CCWT3ENA2/thread/CCWT3ENA2-1653356233.062649
    # Each slice might have overlapped items after the random order of each
    # epoch is updated. It only assumes that each slice at each epoch has a
    # different dataset.
    def test_load_func(position):
        return np.full((1), position, dtype=np.int32)

    di = data_iterator_simple(test_load_func, size,
                              batch_size, shuffle=shuffle)

    def lcm(a, b):
        return abs(a * b) / math.gcd(a, b) if a and b else 0

    max_epoch = int(lcm(batch_size, size) / size)

    epoch = []
    data_per_epoch = []

    slice_di_list = []
    for slice_pos in range(num_of_slices):
        sliced_di = di.slice(
            rng=None, num_of_slices=num_of_slices, slice_pos=slice_pos, drop_last=drop_last)
        slice_di_list.append(sliced_di)

    offset = 0
    for _ in range(max_epoch):
        data_per_batch = []
        for sliced_di in slice_di_list:
            batch = sliced_di.next()
            data_per_batch.append([d for dat in batch for d in dat])
            data_per_epoch.append(data_per_batch[-1][:])
        # No two slices may share an item within the same batch step.
        for i in range(len(data_per_batch)-1):
            for j in range(i + 1, len(data_per_batch)):
                ok = len(set(data_per_batch[i]) & set(data_per_batch[j])) == 0
                if not ok:
                    print("=" * 50)
                    for x in range(len(data_per_batch)):
                        print(
                            f"{len(epoch)}: {i} <--> {j}: {sorted(data_per_batch[x])}")
                    print(f"{set(data_per_batch[i]) & set(data_per_batch[j])}")
                    print("=" * 50)
                assert ok, "data overlapped in same epoch!"
        # Track epoch boundaries by accumulated batch offset.
        if offset + batch_size >= size:
            epoch.append(data_per_epoch)
            data_per_epoch = []
            offset += batch_size
            offset -= size
        else:
            offset += batch_size

    if shuffle:
        # With shuffle enabled, per-slice data sets must differ across epochs.
        for i in range(len(epoch) - 1):
            for j in range(i + 1, len(epoch)):
                for k in range(len(epoch[i])):
                    ok = set(epoch[i][k]) != set(epoch[j][k])
                    if not ok:
                        print("-" * 50)
                        print(f"{k}: {i} {sorted(epoch[i][k])}")
                        print(f"{k}: {j} {sorted(epoch[j][k])}")
                        print("-" * 50)
                    assert ok, "Not allow duplicated data set occurs in shuffle mode."

    for sliced_di in slice_di_list:
        sliced_di.close()
    di.close()


@pytest.mark.parametrize("num_of_slices", [2, 3, 5])
Expand Down