Passed the scheduling argument through the *_generator function. #7236

Merged (13 commits) on Jul 19, 2017
7 changes: 6 additions & 1 deletion keras/engine/training.py
@@ -1631,6 +1631,7 @@ def fit_generator(self, generator,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0):
"""Fits the model on data yielded batch-by-batch by a Python generator.

@@ -1680,6 +1681,9 @@ def fit_generator(self, generator,
non picklable arguments to the generator
as they can't be passed
easily to children processes.
shuffle: whether to shuffle the data at the beginning of each
epoch. Only used with instances of `Sequence`
(`keras.utils.Sequence`).
initial_epoch: epoch at which to start training
(useful for resuming a previous training run)

@@ -1781,7 +1785,8 @@ def generate_arrays_from_file(path):
try:
if is_sequence:
enqueuer = OrderedEnqueuer(generator,
use_multiprocessing=use_multiprocessing)
use_multiprocessing=use_multiprocessing,
shuffle=shuffle)
else:
enqueuer = GeneratorEnqueuer(generator,
use_multiprocessing=use_multiprocessing,
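For orientation, here is a hedged usage sketch of the new argument, not part of this PR's diff: the toy `RandomBatches` Sequence, the model, and the shapes below are illustrative assumptions. `shuffle=True` only changes the order in which batch indices are queried, and only when the data is a `keras.utils.Sequence`.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import Sequence


class RandomBatches(Sequence):
    """Hypothetical data source standing in for a real Sequence."""

    def __init__(self, batch_size=8, n_batches=20):
        self.batch_size = batch_size
        self.n_batches = n_batches

    def __len__(self):
        # Number of batches per epoch.
        return self.n_batches

    def __getitem__(self, idx):
        # Return one batch of random inputs and targets.
        x = np.random.random((self.batch_size, 4))
        y = np.random.random((self.batch_size, 3))
        return x, y

    def on_epoch_end(self):
        # Nothing to reshuffle in this toy example.
        pass


model = Sequential([Dense(3, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

# shuffle=True asks the OrderedEnqueuer to randomize the batch order each
# epoch; with a plain generator the argument has no effect.
model.fit_generator(RandomBatches(),
                    steps_per_epoch=20,
                    epochs=2,
                    shuffle=True)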
16 changes: 12 additions & 4 deletions keras/utils/data_utils.py
@@ -351,6 +351,12 @@ def __len__(self):
"""
raise NotImplementedError

@abstractmethod
def on_epoch_end(self):
Collaborator: The addition of this method appears unrelated to the shuffle arg?

Contributor (author): It is and it isn't. The shuffle arg controls shuffling of the batch indices of the Sequence, but you may also want to shuffle the underlying file paths inside your Sequence implementation; placing that shuffle in on_epoch_end makes this possible. Path shuffling ensures that the files inside each batch differ between epochs, in addition to the batch ordering. A short sketch of that pattern follows the added method below.

"""Method called at the end of every epoch.
"""
raise NotImplementedError
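
To make the reviewer exchange above concrete, here is a hedged sketch of a Sequence that shuffles its file paths in on_epoch_end. `ShuffledFileSequence`, `load_array`, and the path list are illustrative placeholders, not part of the PR.

import random
import numpy as np
from keras.utils import Sequence


def load_array(path):
    # Placeholder loader; real code would read and decode the file.
    return np.zeros((32, 32, 3))


class ShuffledFileSequence(Sequence):
    """Reshuffles its file paths at the end of every epoch, so batch
    contents change in addition to the batch order controlled by the
    enqueuer's `shuffle` argument."""

    def __init__(self, paths, batch_size):
        self.paths = list(paths)
        self.batch_size = batch_size

    def __len__(self):
        return len(self.paths) // self.batch_size

    def __getitem__(self, idx):
        start = idx * self.batch_size
        batch_paths = self.paths[start:start + self.batch_size]
        return np.stack([load_array(p) for p in batch_paths])

    def on_epoch_end(self):
        # Regroup files into different batches for the next epoch.
        random.shuffle(self.paths)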


def get_index(ds, i):
"""Quick fix for Python2, otherwise, it cannot be pickled.
@@ -434,15 +440,15 @@ class OrderedEnqueuer(SequenceEnqueuer):
# Arguments
sequence: A `keras.utils.data_utils.Sequence` object.
use_multiprocessing: use multiprocessing if True, otherwise threading
scheduling: Sequential querying of datas if 'sequential', random otherwise.
shuffle: whether to shuffle the data at the beginning of each epoch
"""

def __init__(self, sequence,
use_multiprocessing=False,
scheduling='sequential'):
shuffle=False):
self.sequence = sequence
self.use_multiprocessing = use_multiprocessing
self.scheduling = scheduling
self.shuffle = shuffle
self.workers = 0
self.executor = None
self.queue = None
@@ -474,14 +480,16 @@ def _run(self):
"""Function to submit request to the executor and queue the `Future` objects."""
sequence = list(range(len(self.sequence)))
while True:
if self.scheduling is not 'sequential':
if self.shuffle:
random.shuffle(sequence)
for i in sequence:
if self.stop_signal.is_set():
return
self.queue.put(
self.executor.apply_async(get_index,
(self.sequence, i)), block=True)
# Call the internal on epoch end.
self.sequence.on_epoch_end()

def get(self):
"""Creates a generator to extract data from the queue.
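Read in isolation, the change to `_run` amounts to the loop below. This is a simplified, framework-free sketch under stated assumptions: the real code submits `get_index` calls to an executor and a queue, and it runs until a stop signal is set rather than for a fixed number of epochs.

import random


def schedule(sequence, shuffle, process_batch, epochs=2):
    # Simplified model of OrderedEnqueuer._run: one pass over the batch
    # indices per epoch, optionally shuffled, then the Sequence hook runs.
    indices = list(range(len(sequence)))
    for _ in range(epochs):  # real loop: while not stop_signal.is_set()
        if shuffle:
            random.shuffle(indices)  # new batch order for this epoch
        for i in indices:
            process_batch(sequence[i])  # real code: executor.apply_async(get_index, ...)
        sequence.on_epoch_end()  # lets the Sequence reshuffle its own data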
3 changes: 3 additions & 0 deletions tests/keras/engine/test_training.py
@@ -31,6 +31,9 @@ def __getitem__(self, idx):
np.random.random((self.batch_size, 4)),
np.random.random((self.batch_size, 3))]

def on_epoch_end(self):
pass


@keras_test
def test_check_array_lengths():
19 changes: 19 additions & 0 deletions tests/keras/utils/data_utils_test.py
@@ -121,6 +121,9 @@ def __getitem__(self, item):
def __len__(self):
return 100

def on_epoch_end(self):
pass


class FaultSequence(Sequence):
def __getitem__(self, item):
@@ -129,6 +132,9 @@ def __getitem__(self, item):
def __len__(self):
return 100

def on_epoch_end(self):
pass


@threadsafe_generator
def create_generator_from_sequence_threads(ds):
@@ -199,6 +205,19 @@ def test_ordered_enqueuer_threads():
enqueuer.stop()


def test_ordered_enqueuer_threads_not_ordered():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]),
use_multiprocessing=False,
shuffle=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
for i in range(100):
acc.append(next(gen_output)[0, 0, 0, 0])
assert acc != list(range(100)), "Order was kept in OrderedEnqueuer with shuffle=True"
enqueuer.stop()


def test_ordered_enqueuer_processes():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
enqueuer.start(3, 10)