Skip to content

Commit

Permalink
Update for Keras 3.3
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Jun 18, 2024
1 parent ae66e91 commit 7264265
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .nengobones.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ ci_scripts:
pre_commands:
# We run this ahead of time, otherwise the download progress bar causes
# problems in the notebook rendering
- python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()"
- python -c "import keras; keras.datasets.mnist.load_data()"
- template: examples
- template: test
coverage: true
Expand Down
9 changes: 3 additions & 6 deletions docs/basic-usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,11 @@ a 10-dimensional input and a 20-dimensional output.

.. testcode::

from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

inputs = Input((None, 10))
inputs = keras.Input((None, 10))
lmus = lmu_layer(inputs)
outputs = Dense(20)(lmus)
outputs = keras.layers.Dense(20)(lmus)

model = Model(inputs=inputs, outputs=outputs)
model = keras.Model(inputs=inputs, outputs=outputs)


Other parameters
Expand Down
6 changes: 2 additions & 4 deletions docs/examples/psMNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now obtain the standard MNIST dataset of handwritten digits from `tf.keras.datasets`."
]
"source": "We now obtain the standard MNIST dataset of handwritten digits from `keras.datasets`."
},
{
"cell_type": "code",
Expand All @@ -95,7 +93,7 @@
"(train_images, train_labels), (\n",
" test_images,\n",
" test_labels,\n",
") = tf.keras.datasets.mnist.load_data()"
") = keras.datasets.mnist.load_data()"
]
},
{
Expand Down
11 changes: 8 additions & 3 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@
else:
from keras.layers import Layer as BaseRandomLayer

if tf_version < version.parse("2.16.0rc0"):
from tensorflow.keras.utils import register_keras_serializable
else:
from keras.saving import register_keras_serializable


@tf.keras.utils.register_keras_serializable("keras-lmu")
@register_keras_serializable("keras-lmu")
class LMUCell(
DropoutRNNCellMixin, BaseRandomLayer
): # pylint: disable=too-many-ancestors
Expand Down Expand Up @@ -524,7 +529,7 @@ def from_config(cls, config):
return super().from_config(config)


@tf.keras.utils.register_keras_serializable("keras-lmu")
@register_keras_serializable("keras-lmu")
class LMU(keras.layers.Layer): # pylint: disable=too-many-ancestors,abstract-method
"""
A layer of trainable low-dimensional delay systems.
Expand Down Expand Up @@ -792,7 +797,7 @@ def from_config(cls, config):
return super().from_config(config)


@tf.keras.utils.register_keras_serializable("keras-lmu")
@register_keras_serializable("keras-lmu")
class LMUFeedforward(
keras.layers.Layer
): # pylint: disable=too-many-ancestors,abstract-method
Expand Down

0 comments on commit 7264265

Please sign in to comment.