Skip to content

Commit

Permalink
Add saving of loaded/trained compatibility models in test and fix a c…
Browse files Browse the repository at this point in the history
…ompatibility bug.

PiperOrigin-RevId: 273455709
  • Loading branch information
vbardiovskyg authored and goldiegadde committed Oct 11, 2019
1 parent 38ea9bb commit 8d71a87
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags

Expand Down Expand Up @@ -57,6 +58,10 @@ def train(fine_tuning):

model.fit_generator(generator=dataset.batch(1), epochs=5)

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())


def main(argv):
del argv
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags
import numpy as np
Expand All @@ -39,6 +40,10 @@ def main(argv):
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
initial_state)

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(cell, tempfile.mkdtemp())


if __name__ == "__main__":
app.run(main)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags

Expand Down Expand Up @@ -55,6 +56,10 @@ def _map_fn(features, labels):

model.fit_generator(generator=dataset.batch(10), epochs=5)

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())


def main(argv):
del argv
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import tempfile
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
Expand All @@ -40,6 +41,9 @@ def main(argv):
sequence_length=10, first_word=tf.constant("<S>"))
_ = [d.numpy() for d in decoded]

# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())

if __name__ == "__main__":
app.run(main)
6 changes: 4 additions & 2 deletions tensorflow/python/saved_model/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,11 +425,13 @@ def _list_functions_for_serialization(self, unused_serialization_cache):
# Overwrite this method to avoid the implementation of
# base class to re-wrap the polymorphic functions into
# another layer of `tf.function`.
return {
functions = {
"_create_resource": self._create_resource,
"_initialize": self._initialize,
"_destroy_resource": self._destroy_resource,
}
if self._destroy_resource:
functions.update(_destroy_resource=self._destroy_resource)
return functions


def _call_attribute(instance, *args, **kwargs):
Expand Down

0 comments on commit 8d71a87

Please sign in to comment.