From abd23fa4cf50f02784f9b7336f7710c696a5f286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?= Date: Mon, 21 Aug 2017 14:40:06 -0700 Subject: [PATCH] [RELNOTES] Add step-wise prediction and evaluation. (#7703) * Add step-wise prediction and evaluation. * Fix bug. --- keras/engine/training.py | 297 +++++++++++++++++----------- tests/keras/engine/test_training.py | 26 ++- 2 files changed, 193 insertions(+), 130 deletions(-) diff --git a/keras/engine/training.py b/keras/engine/training.py index 2eea3c1e0e1..0b3e6574203 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -573,15 +573,15 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None, """Configures the model for training. # Arguments - optimizer: str (name of optimizer) or optimizer object. + optimizer: String (name of optimizer) or optimizer object. See [optimizers](/optimizers). - loss: str (name of objective function) or objective function. + loss: String (name of objective function) or objective function. See [losses](/losses). If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses. - metrics: list of metrics to be evaluated by the model + metrics: List of metrics to be evaluated by the model during training and testing. Typically you will use `metrics=['accuracy']`. To specify different metrics for different outputs of a @@ -596,13 +596,13 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None, If a list, it is expected to have a 1:1 mapping to the model's outputs. If a tensor, it is expected to map output names (strings) to scalar coefficients. - sample_weight_mode: if you need to do timestep-wise + sample_weight_mode: If you need to do timestep-wise sample weighting (2D weights), set this to `"temporal"`. `None` defaults to sample-wise weights (1D). If the model has multiple outputs, you can use a different `sample_weight_mode` on each output by passing a dictionary or a list of modes. - target_tensors: by default, Keras will create placeholders for the + target_tensors: By default, Keras will create placeholders for the model's target, which will be fed with the target data during training. If instead you would like to use your own target tensors (in turn, Keras will not expect external @@ -610,9 +610,9 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None, can specify them via the `target_tensors` argument. It can be a single tensor (for a single-output model), a list of tensors, or a dict mapping output names to target tensors. - weighted_metrics: list of metrics to be evaluated and weighted + weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing - **kwargs: when using the Theano/CNTK backends, these arguments + **kwargs: When using the Theano/CNTK backends, these arguments are passed into K.function. When using the TensorFlow backend, these arguments are passed into `tf.Session.run`. @@ -1005,8 +1005,8 @@ def _check_num_samples(self, ins, batch_size=None, steps=None, steps_name='steps in which case the number of samples is set to `None`. # Arguments - ins: list of tensors to be fed to the Keras function. - batch_size: integer batch size or `None` if not defined. + ins: List of tensors to be fed to the Keras function. + batch_size: Integer batch size or `None` if not defined. steps: Total number of steps (batches of samples) before declaring `_predict_loop` finished. Ignored with the default value of `None`. @@ -1028,12 +1028,13 @@ def _check_num_samples(self, ins, batch_size=None, steps=None, steps_name='steps num_samples = None if batch_size is not None: raise ValueError('If ' + steps_name + - ' is set the batch_size must be None.') + ' is set, the `batch_size` must be None.') elif ins and hasattr(ins[0], 'shape'): num_samples = ins[0].shape[0] else: - raise ValueError('The input data should have shape, or ' - 'please specify ' + steps_name + '.') + raise ValueError('Either the input data should have ' + 'a defined shape, or ' + steps_name + + ' should be specified.') return num_samples def _fit_loop(self, f, ins, out_labels=None, batch_size=None, @@ -1047,21 +1048,21 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None, # Arguments f: Keras function returning a list of tensors - ins: list of tensors to be fed to `f` - out_labels: list of strings, display names of + ins: List of tensors to be fed to `f` + out_labels: List of strings, display names of the outputs of `f` - batch_size: integer batch size or None if unknown. - epochs: number of times to iterate over the data - verbose: verbosity mode, 0, 1 or 2 - callbacks: list of callbacks to be called during training + batch_size: Integer batch size or None if unknown. + epochs: Number of times to iterate over the data + verbose: Verbosity mode, 0, 1 or 2 + callbacks: List of callbacks to be called during training val_f: Keras function to call for validation - val_ins: list of tensors to be fed to `val_f` - shuffle: whether to shuffle the data at the beginning of each epoch - callback_metrics: list of strings, the display names of the metrics + val_ins: List of tensors to be fed to `val_f` + shuffle: Whether to shuffle the data at the beginning of each epoch + callback_metrics: List of strings, the display names of the metrics passed to the callbacks. They should be the concatenation of list the display names of the outputs of `f` and the list of display names of the outputs of `f_val`. - initial_epoch: epoch at which to start training + initial_epoch: Epoch at which to start training (useful for resuming a previous training run) steps_per_epoch: Total number of steps (batches of samples) before declaring one epoch finished and starting the @@ -1129,11 +1130,11 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None, callbacks.on_epoch_begin(epoch) epoch_logs = {} if steps_per_epoch is not None: - for step_num in range(steps_per_epoch): + for step_index in range(steps_per_epoch): batch_logs = {} - batch_logs['batch'] = step_num + batch_logs['batch'] = step_index batch_logs['size'] = 1 - callbacks.on_batch_begin(step_num, batch_logs) + callbacks.on_batch_begin(step_index, batch_logs) outs = f(ins) if not isinstance(outs, list): @@ -1141,7 +1142,7 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None, for l, o in zip(out_labels, outs): batch_logs[l] = o - callbacks.on_batch_end(step_num, batch_logs) + callbacks.on_batch_end(step_index, batch_logs) if callback_model.stop_training: break @@ -1204,7 +1205,7 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None, callbacks.on_train_end() return self.history - def _predict_loop(self, f, ins, batch_size=32, verbose=0): + def _predict_loop(self, f, ins, batch_size=32, verbose=0, steps=None): """Abstract method to loop over some data in batches. # Arguments @@ -1221,44 +1222,64 @@ def _predict_loop(self, f, ins, batch_size=32, verbose=0): or list of arrays of predictions (if the model has multiple outputs). """ - if ins and hasattr(ins[0], 'shape'): - samples = ins[0].shape[0] - else: - # May happen if we are running `predict` without Numpy input data, - # i.e. if all inputs to the models are data tensors - # instead of placeholders. - # In that case we will run `predict` over a single batch. - samples = batch_size - verbose = 2 - - outs = [] + num_samples = self._check_num_samples(ins, batch_size, + steps, + 'steps') if verbose == 1: - progbar = Progbar(target=samples) - batches = _make_batches(samples, batch_size) - index_array = np.arange(samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - if ins and isinstance(ins[-1], float): - # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + if steps is not None: + progbar = Progbar(target=steps) else: - ins_batch = _slice_arrays(ins, batch_ids) - - batch_outs = f(ins_batch) - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - if batch_index == 0: - for batch_out in batch_outs: - shape = (samples,) + batch_out.shape[1:] - outs.append(np.zeros(shape, dtype=batch_out.dtype)) - - for i, batch_out in enumerate(batch_outs): - outs[i][batch_start:batch_end] = batch_out - if verbose == 1: - progbar.update(batch_end) - if len(outs) == 1: - return outs[0] - return outs + progbar = Progbar(target=num_samples) + if steps is not None: + # Step-based predictions. + # Since we do not know how many samples + # we will see, we cannot pre-allocate + # the returned Numpy arrays. + # Instead, we store one array per batch seen + # and concatenate them upon returning. + unconcatenated_outs = [] + for step in range(steps): + batch_outs = f(ins) + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if step == 0: + for batch_out in batch_outs: + unconcatenated_outs.append([]) + for i, batch_out in enumerate(batch_outs): + unconcatenated_outs[i].append(batch_out) + if verbose == 1: + progbar.update(step) + if len(unconcatenated_outs) == 1: + return np.concatenate(unconcatenated_outs[0], axis=0) + return [np.concatenate(unconcatenated_outs[i], axis=0) + for i in range(len(unconcatenated_outs))] + else: + # Sample-based predictions. + outs = [] + batches = _make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + if ins and isinstance(ins[-1], float): + # Do not slice the training phase flag. + ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + else: + ins_batch = _slice_arrays(ins, batch_ids) + batch_outs = f(ins_batch) + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if batch_index == 0: + # Pre-allocate the results arrays. + for batch_out in batch_outs: + shape = (num_samples,) + batch_out.shape[1:] + outs.append(np.zeros(shape, dtype=batch_out.dtype)) + for i, batch_out in enumerate(batch_outs): + outs[i][batch_start:batch_end] = batch_out + if verbose == 1: + progbar.update(batch_end) + if len(outs) == 1: + return outs[0] + return outs def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None): """Abstract method to loop over some data in batches. @@ -1278,30 +1299,34 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. """ + num_samples = self._check_num_samples(ins, batch_size, + steps, + 'steps') outs = [] - if batch_size is None and steps is not None: + if steps is not None: if verbose == 1: progbar = Progbar(target=steps) - for step_num in range(steps): - f(ins) + for step in range(steps): + batch_outs = f(ins) + if isinstance(batch_outs, list): + if step == 0: + for _ in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out + else: + if step == 0: + outs.append(0.) + outs[0] += batch_outs if verbose == 1: - progbar.update(step_num) - + progbar.update(step) + for i in range(len(outs)): + outs[i] /= steps else: - if ins and hasattr(ins[0], 'shape'): - samples = ins[0].shape[0] - else: - # May happen if we are running `predict` without Numpy input data, - # i.e. if all inputs to the models are data tensors - # instead of placeholders. - # In that case we will run `predict` over a single batch. - samples = batch_size - verbose = 2 - if verbose == 1: - progbar = Progbar(target=samples) - batches = _make_batches(samples, batch_size) - index_array = np.arange(samples) + progbar = Progbar(target=num_samples) + batches = _make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] if isinstance(ins[-1], float): @@ -1324,8 +1349,8 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None): if verbose == 1: progbar.update(batch_end) - for i in range(len(outs)): - outs[i] /= samples + for i in range(len(outs)): + outs[i] /= num_samples if len(outs) == 1: return outs[0] return outs @@ -1416,44 +1441,44 @@ def fit(self, x=None, If all outputs in the model are named, you can also pass a dictionary mapping output names to Numpy arrays. - batch_size: integer or `None`. Number of samples per gradient update. - Defaults to 32 if training numpy arrays and no batch - size is specified. Defaults to `None` when `steps_per_epoch` is set. - epochs: integer, the number of times to iterate + batch_size: Integer or `None`. + Number of samples per gradient update. + If unspecified, it will default to 32. + epochs: Integer, the number of times to iterate over the training data arrays. verbose: 0, 1, or 2. Verbosity mode. 0 = silent, 1 = verbose, 2 = one log line per epoch. - callbacks: list of callbacks to be called during training. + callbacks: List of callbacks to be called during training. See [callbacks](/callbacks). - validation_split: float between 0 and 1: + validation_split: Float between 0 and 1: fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, will not train on it, and will evaluate the loss and any model metrics on this data at the end of each epoch. - validation_data: data on which to evaluate + validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. This could be a tuple (x_val, y_val) or a tuple (x_val, y_val, val_sample_weights). - shuffle: boolean, whether to shuffle the training data + shuffle: Boolean, whether to shuffle the training data before each epoch. Has no effect when `steps_per_epoch` is not `None`. - class_weight: optional dictionary mapping + class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples from this class during training. This can be useful to tell the model to "pay more attention" to samples from an under-represented class. - sample_weight: optional array of the same length as x, containing + sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). - initial_epoch: epoch at which to start training + initial_epoch: Epoch at which to start training (useful for resuming a previous training run) steps_per_epoch: Total number of steps (batches of samples) before declaring one epoch finished and starting the @@ -1473,7 +1498,7 @@ class indices (integers) to ValueError: In case of mismatch between the provided input data and what the model expects. """ - # backwards compatibility + # Backwards compatibility if batch_size is None and steps_per_epoch is None: batch_size = 32 # Legacy support @@ -1483,7 +1508,10 @@ class indices (integers) to epochs = kwargs.pop('nb_epoch') if kwargs: raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) - + if x is None and y is None and steps_per_epoch is None: + raise ValueError('If fitting from data tensors, ' + 'you should specify the `steps_per_epoch` ' + 'argument.') # Validate user data. x, y, sample_weights = self._standardize_user_data( x, y, @@ -1566,7 +1594,11 @@ class indices (integers) to steps_per_epoch=steps_per_epoch, validation_steps=validation_steps) - def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): + def evaluate(self, x, y, + batch_size=None, + verbose=1, + sample_weight=None, + steps=None): """Returns the loss value & metrics values for the model in test mode. Computation is done in batches. @@ -1582,10 +1614,13 @@ def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): If all outputs in the model are named, you can also pass a dictionary mapping output names to Numpy arrays. - batch_size: integer. Number of samples per gradient update. - verbose: verbosity mode, 0 or 1. + batch_size: Integer. If unspecified, it will default to 32. + verbose: Verbosity mode, 0 or 1. sample_weight: Array of weights to weight the contribution of different samples to the loss and metrics. + steps: Total number of steps (batches of samples) + before declaring the evaluation round finished. + Ignored with the default value of `None`. # Returns Scalar test loss (if the model has a single output and no metrics) @@ -1593,6 +1628,13 @@ def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. """ + # Backwards compatibility. + if batch_size is None and steps is None: + batch_size = 32 + if x is None and y is None and steps is None: + raise ValueError('If evaluating from data tensors, ' + 'you should specify the `steps` ' + 'argument.') # Validate user data. x, y, sample_weights = self._standardize_user_data( x, y, @@ -1608,18 +1650,25 @@ def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): f = self.test_function return self._test_loop(f, ins, batch_size=batch_size, - verbose=verbose) + verbose=verbose, + steps=steps) - def predict(self, x, batch_size=32, verbose=0): + def predict(self, x, + batch_size=None, + verbose=0, + steps=None): """Generates output predictions for the input samples. Computation is done in batches. # Arguments - x: the input data, as a Numpy array + x: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple outputs). - batch_size: integer. - verbose: verbosity mode, 0 or 1. + batch_size: Integer. If unspecified, it will default to 32. + verbose: Verbosity mode, 0 or 1. + steps: Total number of steps (batches of samples) + before declaring the prediction round finished. + Ignored with the default value of `None`. # Returns Numpy array(s) of predictions. @@ -1630,6 +1679,13 @@ def predict(self, x, batch_size=32, verbose=0): or in case a stateful model receives a number of samples that is not a multiple of the batch size. """ + # Backwards compatibility. + if batch_size is None and steps is None: + batch_size = 32 + if x is None and steps is None: + raise ValueError('If predicting from data tensors, ' + 'you should specify the `steps` ' + 'argument.') # Validate user data. x = _standardize_input_data(x, self._feed_input_names, self._feed_input_shapes, @@ -1651,10 +1707,11 @@ def predict(self, x, batch_size=32, verbose=0): self._make_predict_function() f = self.predict_function return self._predict_loop(f, ins, batch_size=batch_size, - verbose=verbose) + verbose=verbose, steps=steps) def train_on_batch(self, x, y, - sample_weight=None, class_weight=None): + sample_weight=None, + class_weight=None): """Runs a single gradient update on a single batch of data. # Arguments @@ -1668,14 +1725,14 @@ def train_on_batch(self, x, y, If all outputs in the model are named, you can also pass a dictionary mapping output names to Numpy arrays. - sample_weight: optional array of the same length as x, containing + sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). - class_weight: optional dictionary mapping + class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples from this class during training. @@ -1718,7 +1775,7 @@ def test_on_batch(self, x, y, sample_weight=None): If all outputs in the model are named, you can also pass a dictionary mapping output names to Numpy arrays. - sample_weight: optional array of the same length as x, containing + sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), @@ -1792,7 +1849,7 @@ def fit_generator(self, generator, using `use_multiprocessing=True`. # Arguments - generator: a generator or an instance of Sequence (keras.utils.Sequence) + generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either @@ -1807,32 +1864,32 @@ def fit_generator(self, generator, finished and starting the next epoch. It should typically be equal to the number of unique samples if your dataset divided by the batch size. - epochs: integer, total number of iterations on the data. - verbose: verbosity mode, 0, 1, or 2. - callbacks: list of callbacks to be called during training. - validation_data: this can be either + epochs: Integer, total number of iterations on the data. + verbose: Verbosity mode, 0, 1, or 2. + callbacks: List of callbacks to be called during training. + validation_data: This can be either - a generator for the validation data - a tuple (inputs, targets) - a tuple (inputs, targets, sample_weights). validation_steps: Only relevant if `validation_data` is a generator. Total number of steps (batches of samples) to yield from `generator` before stopping. - class_weight: dictionary mapping class indices to a weight + class_weight: Dictionary mapping class indices to a weight for the class. - max_queue_size: maximum size for the generator queue - workers: maximum number of processes to spin up + max_queue_size: Maximum size for the generator queue + workers: Maximum number of processes to spin up when using process based threading - use_multiprocessing: if True, use process based threading. + use_multiprocessing: If True, use process based threading. Note that because this implementation relies on multiprocessing, you should not pass non picklable arguments to the generator as they can't be passed easily to children processes. - shuffle: whether to shuffle the data at the beginning of each + shuffle: Whether to shuffle the data at the beginning of each epoch. Only used with instances of `Sequence` ( keras.utils.Sequence). - initial_epoch: epoch at which to start training + initial_epoch: Epoch at which to start training (useful for resuming a previous training run) # Returns diff --git a/tests/keras/engine/test_training.py b/tests/keras/engine/test_training.py index e94591b3005..7658eab4649 100644 --- a/tests/keras/engine/test_training.py +++ b/tests/keras/engine/test_training.py @@ -583,9 +583,9 @@ def test_model_with_input_feed_tensor(): output_a_np, batch_size=10) # test predict - out = model.predict(None, batch_size=10) - out = model.predict(None, batch_size=10) - assert out.shape == (10, 4) + out = model.predict(None, steps=3) + out = model.predict(None, steps=3) + assert out.shape == (10 * 3, 4) # Same, without learning phase # i.e. we don't pass any data to fit the model. @@ -624,9 +624,9 @@ def test_model_with_input_feed_tensor(): output_a_np, batch_size=10) # test predict - out = model.predict(None, batch_size=10) - out = model.predict(None, batch_size=10) - assert out.shape == (10, 4) + out = model.predict(None, steps=3) + out = model.predict(None, steps=3) + assert out.shape == (10 * 3, 4) @keras_test @@ -739,14 +739,20 @@ def test_model_with_external_loss(): out = model.predict_on_batch(None) # test fit - out = model.fit(None, None, epochs=1, batch_size=None, steps_per_epoch=1) + with pytest.raises(ValueError): + out = model.fit(None, None, epochs=1, batch_size=10) + out = model.fit(None, None, epochs=1, steps_per_epoch=1) # test evaluate - out = model.evaluate(None, None, batch_size=10) + with pytest.raises(ValueError): + out = model.evaluate(None, None, batch_size=10) + out = model.evaluate(None, None, steps=3) # test predict - out = model.predict(None, batch_size=10) - assert out.shape == (10, 4) + with pytest.raises(ValueError): + out = model.predict(None, batch_size=10) + out = model.predict(None, steps=3) + assert out.shape == (10 * 3, 4) @keras_test