diff --git a/TrainingExtensions/tensorflow/test/python/eager/test_batch_norm_fold_keras.py b/TrainingExtensions/tensorflow/test/python/eager/test_batch_norm_fold_keras.py index a2ec6b629b..d17a546044 100644 --- a/TrainingExtensions/tensorflow/test/python/eager/test_batch_norm_fold_keras.py +++ b/TrainingExtensions/tensorflow/test/python/eager/test_batch_norm_fold_keras.py @@ -42,6 +42,7 @@ import copy import json import os +from packaging import version import unittest import tensorflow as tf import numpy as np @@ -54,12 +55,19 @@ from aimet_common.defs import QuantScheme from aimet_tensorflow.keras.utils.quantizer_utils import get_wrappers_weight_quantizer -np.random.seed(0) -tf.random.set_seed(0) - class TestBatchNormFold(unittest.TestCase): """ Test methods for BatchNormFold""" + @pytest.fixture(autouse=True) + def set_random_seed(self): + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(43) + if version.parse(tf.version.VERSION) >= version.parse("2.10"): + tf.keras.utils.set_random_seed(43) + else: + np.random.seed(43) + yield + def test_bn_replacement_model(self): @@ -569,7 +577,6 @@ def test_modify_bn_params_to_make_as_passthrough(self): var = tf.constant([[5.0]]) out = model(var) - self.assertTrue(out.numpy(), 15.0) def test_find_conv_bn_pairs_combined_model(self): @@ -1146,6 +1153,16 @@ class TestBatchNormFoldToScale: def clear_sessions(self): tf.keras.backend.clear_session() yield + + @pytest.fixture(autouse=True) + def set_random_seed(self): + tf.compat.v1.reset_default_graph() + tf.compat.v1.set_random_seed(43) + if version.parse(tf.version.VERSION) >= version.parse("2.10"): + tf.keras.utils.set_random_seed(43) + else: + np.random.seed(43) + yield @pytest.fixture(scope="session", autouse=True) def cleanup(request): @@ -1497,7 +1514,6 @@ def test_bn_fold_auto_mode_transposed_conv2d(self): folded_pairs = fold_all_batch_norms_to_scale(sim) model = sim.model output_after_fold = model(random_input) - # Check to make sure rebuild Quantization Sim Model working properly sim.export(path="/tmp", filename_prefix="temp_bn_fold_to_scale")