diff --git a/.nengobones.yml b/.nengobones.yml index 92730e7..0be4120 100644 --- a/.nengobones.yml +++ b/.nengobones.yml @@ -70,7 +70,7 @@ ci_scripts: pre_commands: # We run this ahead of time, otherwise the download progress bar causes # problems in the notebook rendering - - python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()" + - python -c "import keras; keras.datasets.mnist.load_data()" - template: examples - template: test coverage: true diff --git a/docs/basic-usage.rst b/docs/basic-usage.rst index 62c7dcb..968fb64 100644 --- a/docs/basic-usage.rst +++ b/docs/basic-usage.rst @@ -41,14 +41,11 @@ a 10-dimensional input and a 20-dimensional output. .. testcode:: - from tensorflow.keras import Input, Model - from tensorflow.keras.layers import Dense - - inputs = Input((None, 10)) + inputs = keras.Input((None, 10)) lmus = lmu_layer(inputs) - outputs = Dense(20)(lmus) + outputs = keras.layers.Dense(20)(lmus) - model = Model(inputs=inputs, outputs=outputs) + model = keras.Model(inputs=inputs, outputs=outputs) Other parameters diff --git a/docs/examples/psMNIST.ipynb b/docs/examples/psMNIST.ipynb index 9167c77..f241820 100644 --- a/docs/examples/psMNIST.ipynb +++ b/docs/examples/psMNIST.ipynb @@ -82,9 +82,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "We now obtain the standard MNIST dataset of handwritten digits from `tf.keras.datasets`." - ] + "source": "We now obtain the standard MNIST dataset of handwritten digits from `keras.datasets`." }, { "cell_type": "code", @@ -95,7 +93,7 @@ "(train_images, train_labels), (\n", " test_images,\n", " test_labels,\n", - ") = tf.keras.datasets.mnist.load_data()" + ") = keras.datasets.mnist.load_data()" ] }, { diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index 22b2003..982a4f2 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -29,8 +29,13 @@ else: from keras.layers import Layer as BaseRandomLayer +if tf_version < version.parse("2.16.0rc0"): + from tensorflow.keras.utils import register_keras_serializable +else: + from keras.saving import register_keras_serializable + -@tf.keras.utils.register_keras_serializable("keras-lmu") +@register_keras_serializable("keras-lmu") class LMUCell( DropoutRNNCellMixin, BaseRandomLayer ): # pylint: disable=too-many-ancestors @@ -524,7 +529,7 @@ def from_config(cls, config): return super().from_config(config) -@tf.keras.utils.register_keras_serializable("keras-lmu") +@register_keras_serializable("keras-lmu") class LMU(keras.layers.Layer): # pylint: disable=too-many-ancestors,abstract-method """ A layer of trainable low-dimensional delay systems. @@ -792,7 +797,7 @@ def from_config(cls, config): return super().from_config(config) -@tf.keras.utils.register_keras_serializable("keras-lmu") +@register_keras_serializable("keras-lmu") class LMUFeedforward( keras.layers.Layer ): # pylint: disable=too-many-ancestors,abstract-method