diff --git a/keras_lmu/tests/conftest.py b/keras_lmu/tests/conftest.py index 9f85b33..9c957a6 100644 --- a/keras_lmu/tests/conftest.py +++ b/keras_lmu/tests/conftest.py @@ -7,5 +7,7 @@ def pytest_configure(config): tf.debugging.disable_traceback_filtering() + tf.config.experimental.enable_op_determinism() + tf.keras.utils.set_random_seed(0) if version.parse(tf.__version__) >= version.parse("2.16.0"): keras.config.disable_traceback_filtering()