Add some additional TF SavedModel tests (#569)
* Add saved_model_test

* Add extra saved model tests

* Fix formatting
nkovela1 authored Jul 21, 2023
1 parent 42cad2b commit 4be412e
Showing 2 changed files with 84 additions and 2 deletions.
81 changes: 81 additions & 0 deletions keras_core/backend/tensorflow/saved_model_test.py
@@ -97,3 +97,84 @@ def test_subclassed(self):
            rtol=1e-4,
            atol=1e-4,
        )

    def test_custom_model_and_layer(self):
        @object_registration.register_keras_serializable(package="my_package")
        class CustomLayer(layers.Layer):
            def __call__(self, inputs):
                return inputs

        @object_registration.register_keras_serializable(package="my_package")
        class Model(models.Model):
            def __init__(self):
                super().__init__()
                self.layer = CustomLayer()

            @tf.function(input_signature=[tf.TensorSpec([None, 1])])
            def call(self, inputs):
                return self.layer(inputs)

        model = Model()
        inp = np.array([[1.0]])
        result = model(inp)
        path = os.path.join(self.get_temp_dir(), "my_keras_core_model")
        tf.saved_model.save(model, path)
        restored_model = tf.saved_model.load(path)
        self.assertAllClose(
            result,
            restored_model.call(inp),
            rtol=1e-4,
            atol=1e-4,
        )
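
For context, the explicit input_signature on call above is what gives tf.saved_model.save a traced concrete function to export from a subclassed model. A minimal standalone sketch of the same pattern, using a plain tf.Module rather than the keras_core classes in this test (the class name and /tmp path are illustrative assumptions, not part of this commit):

    import numpy as np
    import tensorflow as tf

    # Sketch only: tf.function with an input_signature lets the SavedModel
    # exporter record a concrete function even if the object was never called.
    class Doubler(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float32)])
        def call(self, x):
            return x * 2.0

    tf.saved_model.save(Doubler(), "/tmp/doubler")
    restored = tf.saved_model.load("/tmp/doubler")
    print(restored.call(np.array([[1.0]], dtype="float32")))  # tf.Tensor([[2.]], ...)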

    def test_multi_input_model(self):
        input_1 = layers.Input(shape=(3,))
        input_2 = layers.Input(shape=(5,))
        model = models.Model([input_1, input_2], [input_1, input_2])
        path = os.path.join(self.get_temp_dir(), "my_keras_core_model")

        tf.saved_model.save(model, path)
        restored_model = tf.saved_model.load(path)
        input_arr_1 = np.random.random((1, 3)).astype("float32")
        input_arr_2 = np.random.random((1, 5)).astype("float32")

        outputs = restored_model.signatures["serving_default"](
            inputs=tf.convert_to_tensor(input_arr_1, dtype=tf.float32),
            inputs_1=tf.convert_to_tensor(input_arr_2, dtype=tf.float32),
        )

        self.assertAllClose(
            input_arr_1, outputs["output_0"], rtol=1e-4, atol=1e-4
        )
        self.assertAllClose(
            input_arr_2, outputs["output_1"], rtol=1e-4, atol=1e-4
        )
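
As an aside, the keyword names used above ("inputs", "inputs_1") and the output keys ("output_0", "output_1") come from the exported serving signature. A hedged sketch of how one might discover them on any restored SavedModel (assuming path points at the model saved above):

    serving_fn = tf.saved_model.load(path).signatures["serving_default"]
    # Keyword arguments the signature expects, e.g. "inputs" and "inputs_1".
    print(serving_fn.structured_input_signature)
    # Output dictionary keys, e.g. "output_0" and "output_1".
    print(serving_fn.structured_outputs)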

    def test_multi_input_custom_model_and_layer(self):
        @object_registration.register_keras_serializable(package="my_package")
        class CustomLayer(layers.Layer):
            def __call__(self, *input_list):
                self.add_loss(input_list[-2] * 2)
                return sum(input_list)

        @object_registration.register_keras_serializable(package="my_package")
        class CustomModel(models.Model):
            def build(self, input_shape):
                super().build(input_shape)
                self.layer = CustomLayer()

            @tf.function
            def call(self, *inputs):
                inputs = list(inputs)
                return self.layer(*inputs)

        model = CustomModel()
        inp = [
            tf.constant(i, shape=[1, 1], dtype=tf.float32) for i in range(1, 4)
        ]
        expected = model(*inp)
        path = os.path.join(self.get_temp_dir(), "my_keras_core_model")
        tf.saved_model.save(model, path)
        restored_model = tf.saved_model.load(path)
        output = restored_model.call(*inp)
        self.assertAllClose(expected, output, rtol=1e-4, atol=1e-4)
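
One caveat worth noting (general tf.function behavior, not specific to this commit): since call here is a tf.function without an input_signature, the concrete function that gets exported is the one traced by the model(*inp) call before saving, and a restored tf.function cannot retrace for new shapes or dtypes:

    # After tf.saved_model.load, restored_model.call only accepts argument
    # structures matching a trace recorded before saving.
    restored_model.call(*inp)  # OK: three [1, 1] float32 tensors, as traced above
    # restored_model.call(tf.zeros([2, 3]))  # would raise: no matching concrete function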
5 changes: 3 additions & 2 deletions keras_core/layers/merging/concatenate.py
@@ -55,8 +55,9 @@ def build(self, input_shape):
             # in case self.axis is a negative number
             concat_axis = self.axis % len(reduced_inputs_shapes[i])
             # Skip batch axis.
-            for axis, axis_value in enumerate(reduced_inputs_shapes[i][1:],
-                                              start=1):
+            for axis, axis_value in enumerate(
+                reduced_inputs_shapes[i][1:], start=1
+            ):
                 # Remove squeezable axes (axes with value of 1)
                 # if not in the axis that will be used for concatenation
                 # otherwise leave it.
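
For illustration only (a toy snippet, not keras_core code): enumerating with start=1 keeps axis aligned with positions in the full input shape while the batch dimension at index 0 is skipped:

    shape = [None, 1, 4, 1]
    for axis, axis_value in enumerate(shape[1:], start=1):
        print(axis, axis_value)
    # 1 1
    # 2 4
    # 3 1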
