diff --git a/autokeras/blocks/__init__.py b/autokeras/blocks/__init__.py index a73020ea3..51a1e2389 100644 --- a/autokeras/blocks/__init__.py +++ b/autokeras/blocks/__init__.py @@ -41,14 +41,15 @@ from autokeras.blocks.wrapper import StructuredDataBlock from autokeras.blocks.wrapper import TextBlock from autokeras.blocks.wrapper import TimeseriesBlock +from autokeras.utils import utils def serialize(obj): - return keras.utils.serialize_keras_object(obj) + return utils.serialize_keras_object(obj) def deserialize(config, custom_objects=None): - return keras.utils.deserialize_keras_object( + return utils.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, diff --git a/autokeras/hyper_preprocessors.py b/autokeras/hyper_preprocessors.py index 10252181d..6b239e744 100644 --- a/autokeras/hyper_preprocessors.py +++ b/autokeras/hyper_preprocessors.py @@ -11,18 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tensorflow import keras from autokeras import preprocessors from autokeras.engine import hyper_preprocessor +from autokeras.utils import utils def serialize(encoder): - return keras.utils.serialize_keras_object(encoder) + return utils.serialize_keras_object(encoder) def deserialize(config, custom_objects=None): - return keras.utils.deserialize_keras_object( + return utils.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, diff --git a/autokeras/nodes.py b/autokeras/nodes.py index 7b86b59c9..99ebc1bf9 100644 --- a/autokeras/nodes.py +++ b/autokeras/nodes.py @@ -28,14 +28,15 @@ from autokeras import preprocessors from autokeras.engine import io_hypermodel from autokeras.engine import node as node_module +from autokeras.utils import utils def serialize(obj): - return keras.utils.serialize_keras_object(obj) + return utils.serialize_keras_object(obj) def deserialize(config, custom_objects=None): - return keras.utils.deserialize_keras_object( + return utils.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, diff --git a/autokeras/preprocessors/__init__.py b/autokeras/preprocessors/__init__.py index 8691e36df..db6d88faf 100644 --- a/autokeras/preprocessors/__init__.py +++ b/autokeras/preprocessors/__init__.py @@ -24,14 +24,15 @@ from autokeras.preprocessors.encoders import OneHotEncoder from autokeras.preprocessors.postprocessors import SigmoidPostprocessor from autokeras.preprocessors.postprocessors import SoftmaxPostprocessor +from autokeras.utils import utils def serialize(preprocessor): - return keras.utils.serialize_keras_object(preprocessor) + return utils.serialize_keras_object(preprocessor) def deserialize(config, custom_objects=None): - return keras.utils.deserialize_keras_object( + return utils.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, diff --git a/autokeras/utils/utils.py b/autokeras/utils/utils.py index 9552b9e54..d707f1f97 100644 --- a/autokeras/utils/utils.py +++ b/autokeras/utils/utils.py @@ -138,3 +138,23 @@ def add_to_hp(hp, hps, name=None): class_name = hp.__class__.__name__ func = getattr(hps, class_name) return func(name=name, **kwargs) + + +def serialize_keras_object(obj): + if hasattr(tf.keras.utils, "legacy"): + return tf.keras.utils.legacy.serialize_keras_object(obj) # pragma: no cover + else: + return tf.keras.utils.serialize_keras_object(obj) + + +def deserialize_keras_object( + config, module_objects=None, custom_objects=None, printable_module_name=None +): + if hasattr(tf.keras.utils, "legacy"): + return tf.keras.utils.legacy.deserialize_keras_object( # pragma: no cover + config, custom_objects, module_objects, printable_module_name + ) + else: + return tf.keras.utils.deserialize_keras_object( + config, custom_objects, module_objects, printable_module_name + )