diff --git a/keras_core/ops/numpy.py b/keras_core/ops/numpy.py index 2d6b4ff28..1fa63924a 100644 --- a/keras_core/ops/numpy.py +++ b/keras_core/ops/numpy.py @@ -2547,6 +2547,8 @@ def full_like(x, fill_value, dtype=None): class GetItem(Operation): def call(self, x, key): + if isinstance(key, list): + key = tuple(key) return x[key] def compute_output_spec(self, x, key): diff --git a/keras_core/saving/serialization_lib.py b/keras_core/saving/serialization_lib.py index b829f1bb6..6c1435af9 100644 --- a/keras_core/saving/serialization_lib.py +++ b/keras_core/saving/serialization_lib.py @@ -153,6 +153,15 @@ def serialize_keras_object(obj): "class_name": "__bytes__", "config": {"value": obj.decode("utf-8")}, } + if isinstance(obj, slice): + return { + "class_name": "__slice__", + "config": { + "start": serialize_keras_object(obj.start), + "stop": serialize_keras_object(obj.stop), + "step": serialize_keras_object(obj.step), + }, + } if isinstance(obj, backend.KerasTensor): history = getattr(obj, "_keras_history", None) if history: @@ -602,6 +611,24 @@ class ModifiedMeanSquaredError(keras_core.losses.MeanSquaredError): return np.array(inner_config["value"], dtype=inner_config["dtype"]) if config["class_name"] == "__bytes__": return inner_config["value"].encode("utf-8") + if config["class_name"] == "__slice__": + return slice( + deserialize_keras_object( + inner_config["start"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + deserialize_keras_object( + inner_config["stop"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + deserialize_keras_object( + inner_config["step"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + ) if config["class_name"] == "__lambda__": if safe_mode: raise ValueError( diff --git a/keras_core/saving/serialization_lib_test.py b/keras_core/saving/serialization_lib_test.py index 8e3b03089..26daa33c2 100644 --- a/keras_core/saving/serialization_lib_test.py +++ b/keras_core/saving/serialization_lib_test.py @@ -75,6 +75,8 @@ def test_simple_objects(self): ["hello", 0, "world", 1.0, True], {"1": "hello", "2": 0, "3": True}, {"1": "hello", "2": [True, False]}, + slice(None, 20, 1), + slice(None, np.array([0, 1]), 1), ]: serialized, _, reserialized = self.roundtrip(obj) self.assertEqual(serialized, reserialized)