Skip to content

Commit

Permalink
Add support for TensorFlow 2.13
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Jul 13, 2023
1 parent 6b35ef5 commit 4825e71
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
4 changes: 1 addition & 3 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ Release history
0.6.1 (unreleased)
==================

*Compatible with TensorFlow 2.4 - 2.11*


*Compatible with TensorFlow 2.4 - 2.13*

0.6.0 (May 5, 2023)
===================
Expand Down
26 changes: 21 additions & 5 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
elif version.parse(tf.__version__) < version.parse("2.9.0rc0"):
from keras.layers.recurrent import DropoutRNNCellMixin
else:
elif version.parse(tf.__version__) < version.parse("2.13.0rc0"):
from keras.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
else:
from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin

if version.parse(tf.__version__) < version.parse("2.8.0rc0"):
from tensorflow.keras.layers import Layer as BaseRandomLayer
else:
elif version.parse(tf.__version__) < version.parse("2.13.0rc0"):
from keras.engine.base_layer import BaseRandomLayer
else:
from keras.src.engine.base_layer import BaseRandomLayer


@tf.keras.utils.register_keras_serializable("keras-lmu")
Expand Down Expand Up @@ -450,7 +454,11 @@ def get_config(self):
def from_config(cls, config):
"""Load model from serialized config."""

config["hidden_cell"] = tf.keras.layers.deserialize(config["hidden_cell"])
config["hidden_cell"] = (
None
if config["hidden_cell"] is None
else tf.keras.layers.deserialize(config["hidden_cell"])
)
return super().from_config(config)


Expand Down Expand Up @@ -714,7 +722,11 @@ def get_config(self):
def from_config(cls, config):
"""Load model from serialized config."""

config["hidden_cell"] = tf.keras.layers.deserialize(config["hidden_cell"])
config["hidden_cell"] = (
None
if config["hidden_cell"] is None
else tf.keras.layers.deserialize(config["hidden_cell"])
)
return super().from_config(config)


Expand Down Expand Up @@ -1065,5 +1077,9 @@ def get_config(self):
def from_config(cls, config):
"""Load model from serialized config."""

config["hidden_cell"] = tf.keras.layers.deserialize(config["hidden_cell"])
config["hidden_cell"] = (
None
if config["hidden_cell"] is None
else tf.keras.layers.deserialize(config["hidden_cell"])
)
return super().from_config(config)

0 comments on commit 4825e71

Please sign in to comment.