Skip to content

Commit

Permalink
Fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 30, 2024
1 parent 78f1c71 commit 34c4536
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion keras/src/legacy/saving/legacy_h5_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import pytest
import tf_keras

import keras
from keras.src import layers
Expand All @@ -17,6 +16,11 @@
# on exact weight ordering for each layer, so we need
# to test across all types of layers.

try:
import tf_keras
except:
tf_keras = None


def get_sequential_model(keras):
return keras.Sequential(
Expand Down Expand Up @@ -73,6 +77,7 @@ def call(self, x):


@pytest.mark.requires_trainable_backend
@pytest.mark.skipif(tf_keras is None, reason="Test requires tf_keras")
class LegacyH5WeightsTest(testing.TestCase):
def _check_reloading_weights(self, ref_input, model, tf_keras_model):
ref_output = tf_keras_model(ref_input)
Expand Down Expand Up @@ -276,6 +281,7 @@ class RegisteredSubLayer(layers.Layer):


@pytest.mark.requires_trainable_backend
@pytest.mark.skipif(tf_keras is None, reason="Test requires tf_keras")
class LegacyH5BackwardsCompatTest(testing.TestCase):
def _check_reloading_model(self, ref_input, model, tf_keras_model):
# Whole model file
Expand Down

0 comments on commit 34c4536

Please sign in to comment.