Skip to content

Commit

Permalink
Add checks to deserialization. (#20751)
Browse files Browse the repository at this point in the history
In particular for functional models.
  • Loading branch information
hertschuh authored Jan 12, 2025
1 parent 3d20616 commit e67ac8f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
6 changes: 6 additions & 0 deletions keras/src/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras.src.ops.function import make_node_key
from keras.src.ops.node import KerasHistory
from keras.src.ops.node import Node
from keras.src.ops.operation import Operation
from keras.src.saving import serialization_lib
from keras.src.utils import tracking

Expand Down Expand Up @@ -523,6 +524,11 @@ def process_layer(layer_data):
layer = serialization_lib.deserialize_keras_object(
layer_data, custom_objects=custom_objects
)
if not isinstance(layer, Operation):
raise ValueError(
"Unexpected object from deserialization, expected a layer or "
f"operation, got a {type(layer)}"
)
created_layers[layer_name] = layer

# Gather layer inputs.
Expand Down
28 changes: 12 additions & 16 deletions keras/src/saving/serialization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,22 +783,18 @@ def _retrieve_class_or_fn(

# Otherwise, attempt to retrieve the class object given the `module`
# and `class_name`. Import the module, find the class.
try:
mod = importlib.import_module(module)
except ModuleNotFoundError:
raise TypeError(
f"Could not deserialize {obj_type} '{name}' because "
f"its parent module {module} cannot be imported. "
f"Full object config: {full_config}"
)
obj = vars(mod).get(name, None)

# Special case for keras.metrics.metrics
if obj is None and registered_name is not None:
obj = vars(mod).get(registered_name, None)

if obj is not None:
return obj
if module == "keras.src" or module.startswith("keras.src."):
try:
mod = importlib.import_module(module)
obj = vars(mod).get(name, None)
if obj is not None:
return obj
except ModuleNotFoundError:
raise TypeError(
f"Could not deserialize {obj_type} '{name}' because "
f"its parent module {module} cannot be imported. "
f"Full object config: {full_config}"
)

raise TypeError(
f"Could not locate {obj_type} '{name}'. "
Expand Down

0 comments on commit e67ac8f

Please sign in to comment.