Skip to content

Commit

Permalink
Merge pull request #1160 from sony/feature/20221222-fix-data-source-simple-shuffle
Browse files Browse the repository at this point in the history

enable shuffle of slice data source at the end of epoch
  • Loading branch information
YukioOobuchi authored Jan 24, 2023
2 parents a460e10 + 124b03b commit 76268f2
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 45 deletions.
2 changes: 2 additions & 0 deletions python/src/nnabla/utils/data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def _next(self):
for i, v in enumerate(self._variables):
data[i].append(d[i])

if n_reset != 0:
self._data_source.apply_order()
self._queue.put((tuple([numpy.array(x) for x in data]), n_reset))

def next(self):
Expand Down
37 changes: 32 additions & 5 deletions python/src/nnabla/utils/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ class DataSource(object):
def _get_data(self, position):
    """Return the sample stored at *position*.

    Abstract hook: the base implementation does nothing and subclasses
    (e.g. ``SimpleDataSource``) override it to actually load data.
    """
    pass

def _get_data_of_generation(self, position, generation):
return self._get_data(position)

def _generate_order(self, generation):
    """Ask this data source to (re-)generate its sample order.

    The newly generated order is labelled with *generation*; older orders
    are kept rather than overwritten, because another slice may still be
    reading them.  In the multi-node scenario this never actually happens,
    since the sliced data sources are strictly synchronized.  The base
    implementation does nothing.
    """
    pass

def __init__(self, shuffle=False, rng=None):
'''
Init method for DataSource
Expand Down Expand Up @@ -96,6 +109,10 @@ def close(self):
atexit.unregister(self.close)
self._closed = True

@property
def order(self):
    """Current sample order (index sequence) used by this data source."""
    return self._order

@property
def variables(self):
'''variables
Expand Down Expand Up @@ -147,6 +164,13 @@ def shuffle(self, value):
def reset(self):
    """Rewind the iteration position to the beginning of the data."""
    self._position = 0

def apply_order(self):
    """Hook invoked when a batch completes at the end of an epoch.

    This is the point at which the sample order should be refreshed when
    shuffle is enabled.  The base implementation does nothing; subclasses
    that maintain shuffled orders override it.
    """
    pass


class DataSourceWithFileCacheError(Exception):
    """Exception type used by ``DataSourceWithFileCache``."""
    pass
Expand All @@ -168,7 +192,7 @@ class DataSourceWithFileCache(DataSource):
Default is None.
cache_file_name_prefix (str):
Beginning of the filenames of cache files.
Default is 'cache'.
Default is 'cache'.
shuffle (bool):
Indicates whether the dataset is shuffled or not.
rng (None or :obj:`numpy.random.RandomState`): Numpy random number
Expand Down Expand Up @@ -599,16 +623,19 @@ def __init__(self, data_source, shuffle=False, rng=None, slice_start=None, slice
self._slice_start = slice_start
self._slice_end = slice_end
self._size = self._slice_end - self._slice_start
self._generation = -1
self.reset()
self._generation = 0

def apply_order(self):
    """Advance to the next shuffle generation.

    Bumps this slice's generation counter, then asks the wrapped data
    source to generate (or reuse) the order labelled with it.
    """
    next_generation = self._generation + 1
    self._generation = next_generation
    self._data_source._generate_order(next_generation)

def reset(self):
    """Reset this slice and rewind the underlying data source.

    Moves the inner source's cursor to this slice's start position and
    rewinds the local position.  The shuffle generation is deliberately
    NOT advanced here — ``apply_order`` owns that — so the stale
    pre-change ``self._generation += 1`` line left over from the diff
    rendering is removed.
    """
    self._data_source.reset()
    self._data_source._position = self._slice_start
    self._position = 0

def _get_data(self, position):
self._position = position
data = self._data_source._get_data(self._slice_start + position)
data = self._data_source._get_data_of_generation(
self._slice_start + position, self._generation)
return data
44 changes: 34 additions & 10 deletions python/src/nnabla/utils/data_source_implements.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import csv
import os
import threading
from collections import OrderedDict
from collections import OrderedDict, deque
from time import sleep

import numpy
Expand All @@ -42,25 +42,49 @@ class SimpleDataSource(DataSource):
'''

def _get_data(self, position):
return self._load_func(self._order[position])
return self._get_data_of_generation(position, 0)

def _get_data_of_generation(self, position, generation):
index = None
for s in self._orders:
if s.generation == generation:
index = s.order
break
if index is None:
raise ValueError("Its client is not synchronized!")
return self._load_func(index[position])

def apply_order(self):
    """Start a new shuffle generation at the end of an epoch.

    Increments the generation counter and materializes the order for it
    via ``_generate_order``.
    """
    new_generation = self._generation + 1
    self._generation = new_generation
    self._generate_order(new_generation)

def _generate_order(self, generation):
for s in self._orders:
if s.generation == generation:
return

class Snapshot:
pass
snapshot = Snapshot()
snapshot.order = self._rng.permutation(
self._size) if self._shuffle else numpy.arange(self._size)
snapshot.generation = generation
self._orders.append(snapshot)

def reset(self):
    """Rebuild the index order (shuffled when enabled) and rewind."""
    if self._shuffle:
        self._indexes = self._rng.permutation(self._size)
    else:
        self._indexes = numpy.arange(self._size)
    super(SimpleDataSource, self).reset()

def __init__(self, load_func, num_examples, shuffle=False, rng=None):
    """Data source that draws samples through *load_func*.

    Args:
        load_func (callable): Maps an index to one sample; called once
            here with index 0 to probe how many variables a sample has.
        num_examples (int): Total number of samples.
        shuffle (bool): Whether the order is re-shuffled per generation.
        rng (numpy.random.RandomState or None): Random number generator.

    Note: the diff rendering left the pre-commit lines interleaved here
    (the old ``super().reset()`` call and the old ``self._order``
    construction); only the post-commit body is kept.
    """
    super(SimpleDataSource, self).__init__(shuffle=shuffle, rng=rng)
    self._load_func = load_func
    self._size = num_examples
    self._generation = 0
    self._variables = ['x' + str(x)
                       for x in range(len(self._load_func(0)))]
    # We set maximal unsynchronized generations kept at once.
    self._orders = deque(maxlen=15)
    self.reset()
    self._generate_order(0)


class CachePrefetcher(object):
Expand Down
84 changes: 54 additions & 30 deletions python/test/utils/test_sliced_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def lcm(a, b):
max_epoch = lcm(batch_size, size) / size

def test_load_func(position):
return np.full((1), position, dtype=int)
return np.full((1), position, dtype=np.int32)

def simple_load_func(data_set, position):
    # Fetch the sample stored at the given position.
    sample = data_set[position]
    return sample
Expand All @@ -65,7 +65,7 @@ def get_data(iter_list, iter_num):
sliced_di_list.append(sliced_di)

ref_di_list = []
all_data = [np.full((1), position, dtype=int)
all_data = [np.full((1), position, dtype=np.int32)
for position in range(size)]
slice_block_size = size // num_of_slices
if not drop_last:
Expand Down Expand Up @@ -109,46 +109,70 @@ def get_data(iter_list, iter_num):
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.parametrize("drop_last", [True])
def test_sliced_data_iterator_duplicated_unexpectedly(test_data_csv_png_10, num_of_slices, size, batch_size, shuffle, drop_last):

# As the discussion in this: https://app.slack.com/client/TAVSDRN92/CCWT3ENA2/thread/CCWT3ENA2-1653356233.062649
# Each slice might have overlapped items after the random order of each epoch is updated.
# It only assumes that each slice at each epoch have different dataset.
def test_load_func(position):
return np.full((1), position, dtype=np.float32)
return np.full((1), position, dtype=np.int32)

di = data_iterator_simple(test_load_func, size,
batch_size, shuffle=shuffle)

def lcm(a, b):
    """Least common multiple of *a* and *b* (0 when either is 0)."""
    # Floor division keeps the result an exact int; true division (/)
    # yields a float and can lose precision for large products, which is
    # why the call site had to wrap this in int().
    return abs(a * b) // math.gcd(a, b) if a and b else 0

max_epoch = lcm(batch_size, size) / size
max_epoch = int(lcm(batch_size, size) / size)

all_data = []
epoch = []
data_per_epoch = []

slice_di_list = []
for slice_pos in range(num_of_slices):
sliced_di = di.slice(
rng=None, num_of_slices=num_of_slices, slice_pos=slice_pos, drop_last=drop_last)
sliced_data = {}
while True:
current_epoch = sliced_di.epoch
if current_epoch > max_epoch + 1:
break
data = sliced_di.next()
if current_epoch not in sliced_data:
sliced_data[current_epoch] = []
for dat in data:
for d in dat:
sliced_data[current_epoch].append(d)
all_data.append(sliced_data)

epochs = {}
for slice_pos, sliced_data in enumerate(all_data):
for epoch in sorted(sliced_data.keys()):
if epoch not in epochs:
epochs[epoch] = []
epochs[epoch].append(set(sliced_data[epoch]))

for epoch in sorted(epochs.keys()):
x0 = epochs[epoch][0]
for dup in [x0 & x for x in epochs[epoch][1:]]:
assert len(dup) == 0
slice_di_list.append(sliced_di)

offset = 0
for _ in range(max_epoch):
data_per_batch = []
for sliced_di in slice_di_list:
batch = sliced_di.next()
data_per_batch.append([d for dat in batch for d in dat])
data_per_epoch.append(data_per_batch[-1][:])
for i in range(len(data_per_batch)-1):
for j in range(i + 1, len(data_per_batch)):
ok = len(set(data_per_batch[i]) & set(data_per_batch[j])) == 0
if not ok:
print("=" * 50)
for x in range(len(data_per_batch)):
print(
f"{len(epoch)}: {i} <--> {j}: {sorted(data_per_batch[x])}")
print(f"{set(data_per_batch[i]) & set(data_per_batch[j])}")
print("=" * 50)
assert ok, "data overlapped in same epoch!"
if offset + batch_size >= size:
epoch.append(data_per_epoch)
data_per_epoch = []
offset += batch_size
offset -= size
else:
offset += batch_size

if shuffle:
for i in range(len(epoch) - 1):
for j in range(i + 1, len(epoch)):
for k in range(len(epoch[i])):
ok = set(epoch[i][k]) != set(epoch[j][k])
if not ok:
print("-" * 50)
print(f"{k}: {i} {sorted(epoch[i][k])}")
print(f"{k}: {j} {sorted(epoch[j][k])}")
print("-" * 50)
assert ok, "Not allow duplicated data set occurs in shuffle mode."

for sliced_di in slice_di_list:
sliced_di.close()
di.close()


@pytest.mark.parametrize("num_of_slices", [2, 3, 5])
Expand Down

0 comments on commit 76268f2

Please sign in to comment.