From 7e023954caf564d769c8ad472e9b0bd069042862 Mon Sep 17 00:00:00 2001
From: HongYu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 28 Jul 2023 03:19:33 +0000
Subject: [PATCH 1/4] Add numpy trainer

---
 keras_core/backend/numpy/trainer.py | 321 ++++++++++++++++++++++++++--
 1 file changed, 308 insertions(+), 13 deletions(-)

diff --git a/keras_core/backend/numpy/trainer.py b/keras_core/backend/numpy/trainer.py
index 875d3468f..53ec9daa2 100644
--- a/keras_core/backend/numpy/trainer.py
+++ b/keras_core/backend/numpy/trainer.py
@@ -1,18 +1,313 @@
-class NumpyTrainer:
-    def fit(self):
-        raise NotImplementedError("Trainer not implemented for NumPy backend.")
+import numpy as np
+import tree
 
-    def predict(self):
-        raise NotImplementedError("Trainer not implemented for NumPy backend.")
+from keras_core import backend
+from keras_core import callbacks as callbacks_module
+from keras_core.backend.common import standardize_dtype
+from keras_core.backend.common.keras_tensor import KerasTensor
+from keras_core.backend.numpy.core import is_tensor
+from keras_core.trainers import trainer as base_trainer
+from keras_core.trainers.data_adapters import data_adapter_utils
+from keras_core.trainers.epoch_iterator import EpochIterator
+from keras_core.utils import traceback_utils
 
-    def evaluate(self):
-        raise NotImplementedError("Trainer not implemented for NumPy backend.")
 
-    def train_on_batch(self):
-        raise NotImplementedError("Trainer not implemented for NumPy backend.")
+class NumpyTrainer(base_trainer.Trainer):
+    def __init__(self):
+        super().__init__()
+        self.test_function = None
+        self.predict_function = None
 
-    def test_on_batch(self):
-        raise NotImplementedError("Trainer not implemented for NumPy backend.")
+    def test_step(self, data):
+        (
+            x,
+            y,
+            sample_weight,
+        ) = data_adapter_utils.unpack_x_y_sample_weight(data)
+        if self._call_has_training_arg:
+            y_pred = self(x, training=False)
+        else:
+            y_pred = self(x)
+        loss = self.compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        )
+        self._loss_tracker.update_state(loss)
+        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
-    def predict_on_batch(self):
-        raise NotImplementedError("Trainer not implemented for NumPy backend.")
+    def predict_step(self, data):
+        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
+        if self._call_has_training_arg:
+            y_pred = self(x, training=False)
+        else:
+            y_pred = self(x)
+        return y_pred
+
+    def make_test_function(self, force=False):
+        if self.test_function is not None and not force:
+            return self.test_function
+
+        def one_test_step(data):
+            data = data[0]
+            return self.test_step(data)
+
+        def multi_test_steps(data):
+            for single_step_data in data:
+                logs = one_test_step([single_step_data])
+            return logs
+
+        if self.steps_per_execution > 1:
+            test_step = multi_test_steps
+        else:
+            test_step = one_test_step
+
+        self.test_function = test_step
+
+    def make_predict_function(self, force=False):
+        if self.predict_function is not None and not force:
+            return self.predict_function
+
+        def one_predict_step(data):
+            data = data[0]
+            return self.predict_step(data)
+
+        def multi_predict_steps(data):
+            outputs = one_predict_step(data[:1])
+
+            for single_step_data in data[1:]:
+                step_outputs = one_predict_step([single_step_data])
+                outputs = tree.map_structure(
+                    lambda t1, t2: np.concatenate([t1, t2]),
+                    outputs,
+                    step_outputs,
+                )
+            return outputs
+
+        if self.steps_per_execution > 1:
+            predict_step = multi_predict_steps
+        else:
+            predict_step = one_predict_step
+
+        self.predict_function = predict_step
+
+    def _symbolic_build(self, data_batch):
+        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
+        compile_metrics_unbuilt = (
+            self._compile_metrics is not None
+            and not self._compile_metrics.built
+        )
+        if model_unbuilt or compile_metrics_unbuilt:
+            # Create symbolic tensors matching an input batch.
+
+            def to_symbolic_input(v):
+                if is_tensor(v):
+                    return KerasTensor(v.shape, standardize_dtype(v.dtype))
+                return v
+
+            data_batch = tree.map_structure(to_symbolic_input, data_batch)
+            (
+                x,
+                y,
+                sample_weight,
+            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
+            # Build all model state with `backend.compute_output_spec`.
+            try:
+                y_pred = backend.compute_output_spec(self, x)
+            except:
+                raise RuntimeError(
+                    "Unable to automatically build the model. "
+                    "Please build it yourself before calling "
+                    "fit/evaluate/predict. "
+                    "A model is 'built' when its variables have "
+                    "been created and its `self.built` attribute "
+                    "is True. Usually, calling the model on a batch "
+                    "of data is the right way to build it."
+                )
+            if compile_metrics_unbuilt:
+                # Build all metric state with `backend.compute_output_spec`.
+                backend.compute_output_spec(
+                    self.compute_metrics,
+                    x,
+                    y,
+                    y_pred,
+                    sample_weight=sample_weight,
+                )
+        self._post_build()
+
+    def fit(
+        self,
+        x=None,
+        y=None,
+        batch_size=None,
+        epochs=1,
+        verbose="auto",
+        callbacks=None,
+        validation_split=0.0,
+        validation_data=None,
+        shuffle=True,
+        class_weight=None,
+        sample_weight=None,
+        initial_epoch=0,
+        steps_per_epoch=None,
+        validation_steps=None,
+        validation_batch_size=None,
+        validation_freq=1,
+    ):
+        raise NotImplementedError("fit not implemented for NumPy backend.")
+
+    @traceback_utils.filter_traceback
+    def predict(
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+    ):
+        # Create an iterator that yields batches of input data.
+        epoch_iterator = EpochIterator(
+            x=x,
+            batch_size=batch_size,
+            steps_per_epoch=steps,
+            shuffle=False,
+            steps_per_execution=self.steps_per_execution,
+        )
+
+        # Container that configures and calls callbacks.
+        if not isinstance(callbacks, callbacks_module.CallbackList):
+            callbacks = callbacks_module.CallbackList(
+                callbacks,
+                add_history=True,
+                add_progbar=verbose != 0,
+                verbose=verbose,
+                epochs=1,
+                steps=epoch_iterator.num_batches,
+                model=self,
+            )
+
+        def append_to_outputs(batch_outputs, outputs):
+            if outputs is None:
+                outputs = tree.map_structure(
+                    lambda batch_output: [batch_output],
+                    batch_outputs,
+                )
+            else:
+                tree.map_structure_up_to(
+                    batch_outputs,
+                    lambda output, batch_output: output.append(batch_output),
+                    outputs,
+                    batch_outputs,
+                )
+            return outputs
+
+        self.make_predict_function()
+        callbacks.on_predict_begin()
+        outputs = None
+        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
+            callbacks.on_predict_batch_begin(step)
+            batch_outputs = self.predict_function(data)
+            outputs = append_to_outputs(batch_outputs, outputs)
+            callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
+        callbacks.on_predict_end()
+        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
+
+    @traceback_utils.filter_traceback
+    def evaluate(
+        self,
+        x=None,
+        y=None,
+        batch_size=None,
+        verbose="auto",
+        sample_weight=None,
+        steps=None,
+        callbacks=None,
+        return_dict=False,
+        **kwargs,
+    ):
+        # TODO: respect compiled trainable state
+        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
+        if kwargs:
+            raise ValueError(f"Arguments not recognized: {kwargs}")
+
+        if use_cached_eval_dataset:
+            epoch_iterator = self._eval_epoch_iterator
+        else:
+            # Create an iterator that yields batches of input/target data.
+            epoch_iterator = EpochIterator(
+                x=x,
+                y=y,
+                sample_weight=sample_weight,
+                batch_size=batch_size,
+                steps_per_epoch=steps,
+                shuffle=False,
+                steps_per_execution=self.steps_per_execution,
+            )
+
+        if not all(layer.built for layer in self._flatten_layers()):
+            # Build the model on one batch of data.
+            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
+                data_batch = data[0]
+                self._symbolic_build(data_batch)
+                break
+
+        # Container that configures and calls callbacks.
+        if not isinstance(callbacks, callbacks_module.CallbackList):
+            callbacks = callbacks_module.CallbackList(
+                callbacks,
+                add_history=True,
+                add_progbar=verbose != 0,
+                verbose=verbose,
+                epochs=1,
+                steps=epoch_iterator.num_batches,
+                model=self,
+            )
+
+        self.make_test_function()
+        callbacks.on_test_begin()
+        logs = None
+        self.reset_metrics()
+        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
+            callbacks.on_test_batch_begin(step)
+            logs = self.test_function(data)
+            callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
+        logs = self.get_metrics_result()
+        callbacks.on_test_end(logs)
+
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)
+
+    def train_on_batch(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        class_weight=None,
+        return_dict=False,
+    ):
+        raise NotImplementedError(
+            "train_on_batch not implemented for NumPy backend."
+        )
+
+    def test_on_batch(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        return_dict=False,
+    ):
+        self._assert_compile_called("test_on_batch")
+
+        data = (x, y, sample_weight)
+
+        # Maybe build model
+        self._symbolic_build(data)
+        self.make_test_function()
+
+        logs = self.test_function([data])
+        logs = tree.map_structure(lambda x: np.array(x), logs)
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)
+
+    def predict_on_batch(self, x):
+        self.make_predict_function()
+        batch_outputs = self.predict_function((x,))
+        batch_outputs = tree.map_structure(
+            backend.convert_to_numpy, batch_outputs
+        )
+        return batch_outputs
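
With this patch the NumPy backend covers the inference-side Trainer API (evaluate, predict, test_on_batch, predict_on_batch), while fit and train_on_batch still raise NotImplementedError. A minimal usage sketch, assuming keras_core is installed and the KERAS_BACKEND=numpy environment variable is set before import; the model size and data below are illustrative only:

    import numpy as np
    import keras_core
    from keras_core import layers

    # A small compiled model; compile() attaches the loss/metrics that
    # NumpyTrainer.test_step() updates during evaluate() and test_on_batch().
    inputs = layers.Input((4,))
    outputs = layers.Dense(3)(inputs)
    model = keras_core.Model(inputs, outputs)
    model.compile(loss="mse", metrics=["mean_squared_error"])

    x = np.ones((32, 4))
    y = np.zeros((32, 3))

    preds = model.predict(x, batch_size=8)        # batches go through predict_function
    results = model.evaluate(x, y, batch_size=8)  # [loss, mean_squared_error]
    one_batch = model.predict_on_batch(x[:8])     # single batch, no EpochIterator
    # model.fit(x, y) raises NotImplementedError on this backend.
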
From 55e7c40431423842e35e404fd0ecd446e933d6f6 Mon Sep 17 00:00:00 2001
From: HongYu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 28 Jul 2023 03:19:46 +0000
Subject: [PATCH 2/4] Improve test coverage

---
 keras_core/trainers/trainer_test.py | 73 ++++++++++++++++++++++++++++-
 1 file changed, 72 insertions(+), 1 deletion(-)

diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py
index 040458a19..469dcecf5 100644
--- a/keras_core/trainers/trainer_test.py
+++ b/keras_core/trainers/trainer_test.py
@@ -72,8 +72,8 @@ def call(self, x, training=False):
         return x * 0
 
 
-@pytest.mark.requires_trainable_backend
 class TestTrainer(testing.TestCase, parameterized.TestCase):
+    @pytest.mark.requires_trainable_backend
     def test_metric_tracking(self):
         class ModelWithMetric(layers.Dense, Trainer):
             def __init__(self, units):
@@ -138,6 +138,7 @@ def __init__(self, units):
             ("steps_per_epoch_jit", False, True, True),
         ]
     )
+    @pytest.mark.requires_trainable_backend
     def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
         if not run_eagerly and not jit_compile and use_steps_per_epoch:
             if backend.backend() == "tensorflow":
@@ -239,6 +240,7 @@ def test_predict_flow(self, run_eagerly, jit_compile):
         self.assertAllClose(outputs["y_one"], 4 * np.ones((100, 3)))
         self.assertAllClose(outputs["y_two"], 4 * np.ones((100, 3)))
 
+    @pytest.mark.requires_trainable_backend
     @pytest.mark.skipif(
         backend.backend() == "torch",
         reason="`steps_per_execution` not implemented for torch yet",
@@ -278,6 +280,38 @@ def on_batch_begin(self, batch, logs=None):
         )
         self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))
 
+    @pytest.mark.skipif(
+        backend.backend() == "torch",
+        reason="`steps_per_execution` not implemented for torch yet",
+    )
+    def test_steps_per_execution_steps_count_without_training(self):
+        class StepCount(Callback):
+            def __init__(self):
+                super().__init__()
+                self.test_count = 0
+                self.predict_count = 0
+                self.batches = [0, 3, 6]
+
+            def on_test_batch_begin(self, batch, logs=None):
+                assert batch == self.batches[self.test_count]
+                self.test_count += 1
+
+            def on_predict_batch_begin(self, batch, logs=None):
+                assert batch == self.batches[self.predict_count]
+                self.predict_count += 1
+
+        x = np.ones((100, 4))
+        y = np.ones((100, 1))
+        batch_size = 16
+        model = ExampleModel(units=1)
+        model.compile(loss="mse", steps_per_execution=3)
+        step_count = StepCount()
+        model.predict(x, batch_size=batch_size, callbacks=[step_count])
+        self.assertEqual(step_count.predict_count, 3)
+        model.evaluate(x, y, batch_size=batch_size, callbacks=[step_count])
+        self.assertEqual(step_count.test_count, 3)
+
+    @pytest.mark.requires_trainable_backend
     def test_training_arg(self):
         model = TrainingTestingLayer()
         model.compile(optimizer="rmsprop", loss="mse")
@@ -297,6 +331,7 @@ def test_training_arg(self):
             ("jit", False, True),
         ]
     )
+    @pytest.mark.requires_trainable_backend
     def test_on_batch_methods(self, run_eagerly, jit_compile):
         model = ExampleModel(units=3)
         x = np.ones((100, 4))
@@ -346,6 +381,38 @@ def test_on_batch_methods(self, run_eagerly, jit_compile):
         logs = model.train_on_batch(x, y, class_weight={1: 0.3, 0: 0.2})
         self.assertAlmostEqual(logs[0], 12.899)
 
+    @parameterized.named_parameters(
+        [
+            ("eager", True, False),
+            ("graph_fn", False, False),
+            ("jit", False, True),
+        ]
+    )
+    def test_on_batch_methods_without_training(self, run_eagerly, jit_compile):
+        model = ExampleModel(units=3)
+        x = np.ones((100, 4))
+        y = np.zeros((100, 3))
+
+        model.compile(
+            loss=losses.MeanSquaredError(),
+            metrics=[metrics.MeanSquaredError()],
+            run_eagerly=run_eagerly,
+            jit_compile=jit_compile,
+        )
+        logs = model.test_on_batch(x, y)
+        self.assertTrue(isinstance(logs, list))
+        self.assertEqual(len(logs), 2)
+        self.assertAlmostEqual(logs[0], 16.0)
+
+        logs = model.test_on_batch(x, y, return_dict=True)
+        self.assertTrue(isinstance(logs, dict))
+        self.assertEqual(len(logs), 2)
+        self.assertAlmostEqual(logs["loss"], 16.0)
+
+        output = model.predict_on_batch(x)
+        self.assertTrue(isinstance(output, np.ndarray))
+        self.assertAllClose(output[0], np.array([4.0, 4.0, 4.0]))
+
     def test_nested_input_predict(self):
         # https://github.com/keras-team/keras-core/issues/325
 
@@ -368,6 +435,7 @@ def call(self, inputs):
         out = model.predict({"a": x1, "b": x2})
         self.assertEqual(out.shape, (3, 4))
 
+    @pytest.mark.requires_trainable_backend
     def test_callback_methods_keys(self):
         class CustomCallback(Callback):
             def on_train_begin(self, logs=None):
@@ -452,6 +520,7 @@ def on_predict_batch_end(self, batch, logs=None):
         model.evaluate(x_test, y_test, batch_size=4)
         model.predict(x_test, batch_size=4)
 
+    @pytest.mark.requires_trainable_backend
     def test_internal_only_loss(self):
         class LossLayer(layers.Layer):
             def call(self, x):
@@ -511,6 +580,7 @@ def __init__(self, input_shape=(None,)):
             },
         ]
     )
+    @pytest.mark.requires_trainable_backend
     @pytest.mark.skipif(
         keras_core.backend.backend() != "tensorflow",
         reason="Only tensorflow supports raggeds",
@@ -556,6 +626,7 @@ def test_predict_dropout(self):
         out3 = model.predict_on_batch(np.ones((2, 20)))
         self.assertGreater(5, np.sum(np.abs(out2 - out3)))
 
+    @pytest.mark.requires_trainable_backend
     def test_recompile(self):
         inputs = layers.Input((2,))
         outputs = layers.Dense(3)(inputs)
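
For reference, the `batches = [0, 3, 6]` expectation in the new steps_per_execution test follows from simple batch counting; a sketch of the same arithmetic (not part of the patch):

    import math

    num_samples, batch_size, steps_per_execution = 100, 16, 3
    num_batches = math.ceil(num_samples / batch_size)  # 7 batches per epoch
    # Batches are dispatched in groups of `steps_per_execution`, and the
    # *_batch_begin callbacks fire once per group, at its first batch index.
    group_starts = list(range(0, num_batches, steps_per_execution))
    assert group_starts == [0, 3, 6]  # hence test_count == predict_count == 3
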
From 02afcc1635359b9709a556236c73dc751689bbe8 Mon Sep 17 00:00:00 2001
From: HongYu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 28 Jul 2023 03:20:02 +0000
Subject: [PATCH 3/4] Fix steps_per_execution bug in tf trainer

---
 keras_core/backend/tensorflow/trainer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py
index df15266f1..7458a031f 100644
--- a/keras_core/backend/tensorflow/trainer.py
+++ b/keras_core/backend/tensorflow/trainer.py
@@ -337,6 +337,7 @@ def fit(
                         sample_weight=val_sample_weight,
                         batch_size=validation_batch_size or batch_size,
                         distribute_strategy=self.distribute_strategy,
+                        steps_per_execution=self.steps_per_execution,
                     )
                 val_logs = self.evaluate(
                     x=val_x,
@@ -401,6 +402,7 @@ def evaluate(
                 steps_per_epoch=steps,
                 shuffle=False,
                 distribute_strategy=self.distribute_strategy,
+                steps_per_execution=self.steps_per_execution,
             )
 
         # Container that configures and calls callbacks.
@@ -442,6 +444,7 @@ def predict(
             steps_per_epoch=steps,
             shuffle=False,
             distribute_strategy=self.distribute_strategy,
+            steps_per_execution=self.steps_per_execution,
        )
 
         # Container that configures and calls callbacks.

From b7d5c5a3853daded3d2ee0583d1ff5065e3c886b Mon Sep 17 00:00:00 2001
From: HongYu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 28 Jul 2023 03:32:40 +0000
Subject: [PATCH 4/4] Fix test

---
 keras_core/trainers/trainer_test.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py
index 469dcecf5..cac924e1c 100644
--- a/keras_core/trainers/trainer_test.py
+++ b/keras_core/trainers/trainer_test.py
@@ -399,6 +399,10 @@ def test_on_batch_methods_without_training(self, run_eagerly, jit_compile):
             run_eagerly=run_eagerly,
             jit_compile=jit_compile,
         )
+        output = model.predict_on_batch(x)
+        self.assertTrue(isinstance(output, np.ndarray))
+        self.assertAllClose(output[0], np.array([4.0, 4.0, 4.0]))
+
         logs = model.test_on_batch(x, y)
         self.assertTrue(isinstance(logs, list))
         self.assertEqual(len(logs), 2)
         self.assertAlmostEqual(logs[0], 16.0)
@@ -409,10 +413,6 @@ def test_on_batch_methods_without_training(self, run_eagerly, jit_compile):
         self.assertTrue(isinstance(logs, dict))
         self.assertEqual(len(logs), 2)
         self.assertAlmostEqual(logs["loss"], 16.0)
-
-        output = model.predict_on_batch(x)
-        self.assertTrue(isinstance(output, np.ndarray))
-        self.assertAllClose(output[0], np.array([4.0, 4.0, 4.0]))
 
     def test_nested_input_predict(self):
         # https://github.com/keras-team/keras-core/issues/325
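
As background for the steps_per_execution handling in this series: when several batches are executed per step, the NumPy trainer's multi-step predict path merges per-step outputs structure-wise with tree.map_structure (see multi_predict_steps in PATCH 1). A standalone sketch of that merging pattern, using plain NumPy structures rather than the Trainer itself:

    import numpy as np
    import tree  # dm-tree, the structure utility used by keras_core

    def merge_step_outputs(step_outputs_list):
        # Each element mirrors one predict_step() result (possibly nested).
        outputs = step_outputs_list[0]
        for step_outputs in step_outputs_list[1:]:
            outputs = tree.map_structure(
                lambda t1, t2: np.concatenate([t1, t2]), outputs, step_outputs
            )
        return outputs

    merged = merge_step_outputs(
        [{"y": np.ones((16, 3))}, {"y": np.zeros((8, 3))}]
    )
    assert merged["y"].shape == (24, 3)
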