-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add numpy trainer #633
Merged
Merged
Add numpy trainer #633
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,313 @@ | ||
class NumpyTrainer: | ||
def fit(self): | ||
raise NotImplementedError("Trainer not implemented for NumPy backend.") | ||
import numpy as np | ||
import tree | ||
|
||
def predict(self): | ||
raise NotImplementedError("Trainer not implemented for NumPy backend.") | ||
from keras_core import backend | ||
from keras_core import callbacks as callbacks_module | ||
from keras_core.backend.common import standardize_dtype | ||
from keras_core.backend.common.keras_tensor import KerasTensor | ||
from keras_core.backend.numpy.core import is_tensor | ||
from keras_core.trainers import trainer as base_trainer | ||
from keras_core.trainers.data_adapters import data_adapter_utils | ||
from keras_core.trainers.epoch_iterator import EpochIterator | ||
from keras_core.utils import traceback_utils | ||
|
||
def evaluate(self): | ||
raise NotImplementedError("Trainer not implemented for NumPy backend.") | ||
|
||
def train_on_batch(self): | ||
raise NotImplementedError("Trainer not implemented for NumPy backend.") | ||
class NumpyTrainer(base_trainer.Trainer):
    """Trainer for the NumPy backend.

    Only inference-style workflows (`evaluate`, `predict`, and their
    batch-level variants) are supported; training entry points raise
    `NotImplementedError`.
    """

    def __init__(self):
        super().__init__()
        # Step functions are built lazily by make_test_function() /
        # make_predict_function() and cached here.
        self.test_function = None
        self.predict_function = None
|
||
def test_on_batch(self): | ||
raise NotImplementedError("Trainer not implemented for NumPy backend.") | ||
def test_step(self, data):
    """Run one evaluation step on a single batch.

    Unpacks `(x, y, sample_weight)` from `data`, runs a forward pass,
    updates the loss tracker, and returns the computed metrics.
    """
    x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
    # Only pass `training=False` when the model's call() accepts it.
    if self._call_has_training_arg:
        y_pred = self(x, training=False)
    else:
        y_pred = self(x)
    batch_loss = self.compute_loss(
        x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
    )
    self._loss_tracker.update_state(batch_loss)
    return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
|
||
def predict_on_batch(self): | ||
raise NotImplementedError("Trainer not implemented for NumPy backend.") | ||
def predict_step(self, data):
    """Run one inference step: forward pass only, no loss or metrics."""
    x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
    # Only pass `training=False` when the model's call() accepts it.
    call_kwargs = {"training": False} if self._call_has_training_arg else {}
    return self(x, **call_kwargs)
|
||
def make_test_function(self, force=False):
    """Build (or return the cached) evaluation step function.

    Args:
        force: If True, rebuild the function even if one is cached.

    Returns:
        The callable stored in `self.test_function`, which consumes a
        list of batches and returns the logs of the last executed step.
    """
    if self.test_function is not None and not force:
        return self.test_function

    def one_test_step(data):
        # `data` is a single-element list holding one batch.
        data = data[0]
        return self.test_step(data)

    def multi_test_steps(data):
        # Run each batch in turn; only the last step's logs are reported.
        for single_step_data in data:
            logs = one_test_step([single_step_data])
        return logs

    if self.steps_per_execution > 1:
        test_step = multi_test_steps
    else:
        test_step = one_test_step

    self.test_function = test_step
    # Fix: the build path previously returned None while the cached path
    # above returned the function; return it on both paths.
    return self.test_function
|
||
def make_predict_function(self, force=False):
    """Build (or return the cached) prediction step function.

    Args:
        force: If True, rebuild the function even if one is cached.

    Returns:
        The callable stored in `self.predict_function`, which consumes a
        list of batches and returns the (concatenated) outputs.
    """
    if self.predict_function is not None and not force:
        return self.predict_function

    def one_predict_step(data):
        # `data` is a single-element list holding one batch.
        data = data[0]
        return self.predict_step(data)

    def multi_predict_steps(data):
        outputs = one_predict_step(data[:1])

        # Concatenate each subsequent step's outputs onto the running
        # structure, leaf by leaf.
        for single_step_data in data[1:]:
            step_outputs = one_predict_step([single_step_data])
            outputs = tree.map_structure(
                lambda t1, t2: np.concatenate([t1, t2]),
                outputs,
                step_outputs,
            )
        return outputs

    if self.steps_per_execution > 1:
        predict_step = multi_predict_steps
    else:
        predict_step = one_predict_step

    self.predict_function = predict_step
    # Fix: the build path previously returned None while the cached path
    # above returned the function; return it on both paths.
    return self.predict_function
|
||
def _symbolic_build(self, data_batch):
    """Build model and compile-metric state from one symbolic batch.

    Converts the concrete batch into `KerasTensor`s and traces the model
    (and, if needed, the compiled metrics) with
    `backend.compute_output_spec` so that all variables get created
    without running real computation. Finishes with `self._post_build()`.

    Raises:
        RuntimeError: If the model cannot be built automatically.
    """
    model_unbuilt = not all(layer.built for layer in self._flatten_layers())
    compile_metrics_unbuilt = (
        self._compile_metrics is not None
        and not self._compile_metrics.built
    )
    if model_unbuilt or compile_metrics_unbuilt:
        # Create symbolic tensors matching an input batch.

        def to_symbolic_input(v):
            if is_tensor(v):
                return KerasTensor(v.shape, standardize_dtype(v.dtype))
            return v

        data_batch = tree.map_structure(to_symbolic_input, data_batch)
        (
            x,
            y,
            sample_weight,
        ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
        # Build all model state with `backend.compute_output_spec`.
        try:
            y_pred = backend.compute_output_spec(self, x)
        except Exception as e:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and discarded the original
            # error; catch Exception and chain the cause instead.
            raise RuntimeError(
                "Unable to automatically build the model. "
                "Please build it yourself before calling "
                "fit/evaluate/predict. "
                "A model is 'built' when its variables have "
                "been created and its `self.built` attribute "
                "is True. Usually, calling the model on a batch "
                "of data is the right way to build it."
            ) from e
        if compile_metrics_unbuilt:
            # Build all metric state with `backend.compute_output_spec`.
            backend.compute_output_spec(
                self.compute_metrics,
                x,
                y,
                y_pred,
                sample_weight=sample_weight,
            )
    self._post_build()
|
||
def fit(
    self,
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
):
    """Not supported on this backend.

    Raises:
        NotImplementedError: Always; training is not available with NumPy.
    """
    raise NotImplementedError("fit not implemented for NumPy backend.")
|
||
@traceback_utils.filter_traceback
def predict(
    self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
    """Generate output predictions for the input samples.

    Iterates `x` in batches, runs the prediction function on each batch,
    accumulates the per-batch outputs, and concatenates them leaf-wise.

    Raises:
        ValueError: If the input yields no batches at all.
    """
    # Create an iterator that yields batches of input data.
    epoch_iterator = EpochIterator(
        x=x,
        batch_size=batch_size,
        steps_per_epoch=steps,
        shuffle=False,
        steps_per_execution=self.steps_per_execution,
    )

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    def append_to_outputs(batch_outputs, outputs):
        # First batch: wrap each leaf in a list; later batches: append
        # each leaf to the matching list.
        if outputs is None:
            outputs = tree.map_structure(
                lambda batch_output: [batch_output],
                batch_outputs,
            )
        else:
            tree.map_structure_up_to(
                batch_outputs,
                lambda output, batch_output: output.append(batch_output),
                outputs,
                batch_outputs,
            )
        return outputs

    self.make_predict_function()
    callbacks.on_predict_begin()
    outputs = None
    batch_outputs = None
    for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
        callbacks.on_predict_batch_begin(step)
        batch_outputs = self.predict_function(data)
        outputs = append_to_outputs(batch_outputs, outputs)
        callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
    callbacks.on_predict_end()
    if outputs is None:
        # Fix: with an empty iterator, `batch_outputs` was previously
        # unbound here, raising UnboundLocalError. Fail with a clear
        # message instead.
        raise ValueError("No data yielded by the provided input `x`.")
    return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
|
||
@traceback_utils.filter_traceback
def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs,
):
    """Evaluate the model on the given data.

    Returns the metric logs as a dict when `return_dict=True`, otherwise
    the result of `self._flatten_metrics_in_order(logs)`.
    """
    # TODO: respect compiled trainable state
    use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
    if kwargs:
        raise ValueError(f"Arguments not recognized: {kwargs}")

    if use_cached_eval_dataset:
        # Reuse the evaluation iterator cached on the model.
        epoch_iterator = self._eval_epoch_iterator
    else:
        # Create an iterator that yields batches of input/target data.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

    if not all(layer.built for layer in self._flatten_layers()):
        # Build the model on one batch of data.
        for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
            self._symbolic_build(data[0])
            break

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    self.make_test_function()
    callbacks.on_test_begin()
    logs = None
    self.reset_metrics()
    for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
        callbacks.on_test_batch_begin(step)
        logs = self.test_function(data)
        callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
    logs = self.get_metrics_result()
    callbacks.on_test_end(logs)

    return logs if return_dict else self._flatten_metrics_in_order(logs)
|
||
def train_on_batch(
    self,
    x,
    y=None,
    sample_weight=None,
    class_weight=None,
    return_dict=False,
):
    """Not supported on this backend.

    Raises:
        NotImplementedError: Always; training is not available with NumPy.
    """
    raise NotImplementedError(
        "train_on_batch not implemented for NumPy backend."
    )
|
||
def test_on_batch(
    self,
    x,
    y=None,
    sample_weight=None,
    return_dict=False,
):
    """Test the model on a single batch of samples.

    Requires `compile()` to have been called. Returns the logs as a dict
    when `return_dict=True`, otherwise the result of
    `self._flatten_metrics_in_order(logs)`.
    """
    self._assert_compile_called("test_on_batch")

    data = (x, y, sample_weight)

    # Maybe build model
    self._symbolic_build(data)
    self.make_test_function()

    logs = self.test_function([data])
    # Fix: `lambda x: np.array(x)` shadowed the `x` argument and wrapped
    # np.array pointlessly; pass the callable directly.
    logs = tree.map_structure(np.array, logs)
    if return_dict:
        return logs
    return self._flatten_metrics_in_order(logs)
|
||
def predict_on_batch(self, x):
    """Return predictions for a single batch of samples."""
    self.make_predict_function()
    raw_outputs = self.predict_function((x,))
    # Convert every output leaf to a plain NumPy array.
    return tree.map_structure(backend.convert_to_numpy, raw_outputs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reference, this is something that tf.keras always does (it uses the
`trainable` values of layers as they existed at the time `compile()`
was called), which may not always be verified in some cases in Keras Core.