Skip to content

Commit

Permalink
Refactor training part of engine module. (keras-team#10029)
Browse files Browse the repository at this point in the history
* Refactor topological part of Keras engine.

* Fix imports

* Fix merge mixup.

* Refactor training part of the Keras engine.

* Fix unit tests.
  • Loading branch information
fchollet authored and Vijayabhaskar96 committed May 1, 2018
1 parent f38492e commit 78b3405
Show file tree
Hide file tree
Showing 8 changed files with 1,839 additions and 1,595 deletions.
1,770 changes: 253 additions & 1,517 deletions keras/engine/training.py

Large diffs are not rendered by default.

416 changes: 416 additions & 0 deletions keras/engine/training_arrays.py

Large diffs are not rendered by default.

461 changes: 461 additions & 0 deletions keras/engine/training_generator.py

Large diffs are not rendered by default.

546 changes: 546 additions & 0 deletions keras/engine/training_utils.py

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions keras/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,44 @@ def is_all_none(iterable_or_element):
if element is not None:
return False
return True


def slice_arrays(arrays, start=None, stop=None):
"""Slices an array or list of arrays.
This takes an array-like, or a list of
array-likes, and outputs:
- arrays[start:stop] if `arrays` is an array-like
- [x[start:stop] for x in arrays] if `arrays` is a list
Can also work on list/array of indices: `_slice_arrays(x, indices)`
# Arguments
arrays: Single array or list of arrays.
start: can be an integer index (start index)
or a list/array of indices
stop: integer (stop index); should be None if
`start` was a list.
# Returns
A slice of the array(s).
"""
if arrays is None:
return [None]
elif isinstance(arrays, list):
if hasattr(start, '__len__'):
# hdf5 datasets only support list objects as indices
if hasattr(start, 'shape'):
start = start.tolist()
return [None if x is None else x[start] for x in arrays]
else:
return [None if x is None else x[start:stop] for x in arrays]
else:
if hasattr(start, '__len__'):
if hasattr(start, 'shape'):
start = start.tolist()
return arrays[start]
elif hasattr(start, '__getitem__'):
return arrays[start:stop]
else:
return [None]
192 changes: 118 additions & 74 deletions tests/keras/engine/test_training.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tests/keras/test_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from keras.utils.test_utils import get_test_data, keras_test
from keras.models import model_from_json, model_from_yaml
from keras import losses
from keras.engine.training import _make_batches
from keras.engine.training_utils import make_batches


input_dim = 16
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_sequential(in_tmpdir):
def data_generator(x, y, batch_size=50):
index_array = np.arange(len(x))
while 1:
batches = _make_batches(len(x_test), batch_size)
batches = make_batches(len(x_test), batch_size)
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
x_batch = x[batch_ids]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_loss_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from keras.models import Sequential
from keras.engine.training import _weighted_masked_objective
from keras.engine.training_utils import weighted_masked_objective
from keras.layers import TimeDistributed, Masking, Dense
from keras.utils.test_utils import keras_test
from keras import losses
Expand All @@ -26,7 +26,7 @@ def test_masking():

@keras_test
def test_loss_masking():
weighted_loss = _weighted_masked_objective(losses.get('mae'))
weighted_loss = weighted_masked_objective(losses.get('mae'))
shape = (3, 4, 2)
x = np.arange(24).reshape(shape)
y = 2 * x
Expand Down

0 comments on commit 78b3405

Please sign in to comment.