diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index e01052bc57e..f2ff70396fb 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -19,6 +19,7 @@ from keras.src.ops.function import make_node_key from keras.src.ops.node import KerasHistory from keras.src.ops.node import Node +from keras.src.ops.operation import Operation from keras.src.saving import serialization_lib from keras.src.utils import tracking @@ -523,6 +524,11 @@ def process_layer(layer_data): layer = serialization_lib.deserialize_keras_object( layer_data, custom_objects=custom_objects ) + if not isinstance(layer, Operation): + raise ValueError( + "Unexpected object from deserialization, expected a layer or " + f"operation, got a {type(layer)}" + ) created_layers[layer_name] = layer # Gather layer inputs. diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py index cf8eb327fb4..535478b62bb 100644 --- a/keras/src/saving/serialization_lib.py +++ b/keras/src/saving/serialization_lib.py @@ -783,22 +783,18 @@ def _retrieve_class_or_fn( # Otherwise, attempt to retrieve the class object given the `module` # and `class_name`. Import the module, find the class. - try: - mod = importlib.import_module(module) - except ModuleNotFoundError: - raise TypeError( - f"Could not deserialize {obj_type} '{name}' because " - f"its parent module {module} cannot be imported. " - f"Full object config: {full_config}" - ) - obj = vars(mod).get(name, None) - - # Special case for keras.metrics.metrics - if obj is None and registered_name is not None: - obj = vars(mod).get(registered_name, None) - - if obj is not None: - return obj + if module == "keras.src" or module.startswith("keras.src."): + try: + mod = importlib.import_module(module) + obj = vars(mod).get(name, None) + if obj is not None: + return obj + except ModuleNotFoundError: + raise TypeError( + f"Could not deserialize {obj_type} '{name}' because " + f"its parent module {module} cannot be imported. " + f"Full object config: {full_config}" + ) raise TypeError( f"Could not locate {obj_type} '{name}'. "