Skip to content

Commit

Permalink
seeding for eager testing in tf 2.10
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Ernst <quic_ernst@quicinc.com>
  • Loading branch information
quic-ernst authored and quic-bharathr committed Jul 11, 2023
1 parent 7f5cb4d commit 2a8b8e1
Showing 1 changed file with 21 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import copy
import json
import os
from packaging import version
import unittest
import tensorflow as tf
import numpy as np
Expand All @@ -54,12 +55,19 @@
from aimet_common.defs import QuantScheme
from aimet_tensorflow.keras.utils.quantizer_utils import get_wrappers_weight_quantizer

np.random.seed(0)
tf.random.set_seed(0)


class TestBatchNormFold(unittest.TestCase):
""" Test methods for BatchNormFold"""
@pytest.fixture(autouse=True)
def set_random_seed(self):
tf.compat.v1.reset_default_graph()
tf.compat.v1.set_random_seed(43)
if version.parse(tf.version.VERSION) >= version.parse("2.10"):
tf.keras.utils.set_random_seed(43)
else:
np.random.seed(43)
yield


def test_bn_replacement_model(self):

Expand Down Expand Up @@ -569,7 +577,6 @@ def test_modify_bn_params_to_make_as_passthrough(self):

var = tf.constant([[5.0]])
out = model(var)

self.assertTrue(out.numpy(), 15.0)

def test_find_conv_bn_pairs_combined_model(self):
Expand Down Expand Up @@ -1146,6 +1153,16 @@ class TestBatchNormFoldToScale:
def clear_sessions(self):
tf.keras.backend.clear_session()
yield

@pytest.fixture(autouse=True)
def set_random_seed(self):
tf.compat.v1.reset_default_graph()
tf.compat.v1.set_random_seed(43)
if version.parse(tf.version.VERSION) >= version.parse("2.10"):
tf.keras.utils.set_random_seed(43)
else:
np.random.seed(43)
yield

@pytest.fixture(scope="session", autouse=True)
def cleanup(request):
Expand Down Expand Up @@ -1497,7 +1514,6 @@ def test_bn_fold_auto_mode_transposed_conv2d(self):
folded_pairs = fold_all_batch_norms_to_scale(sim)
model = sim.model
output_after_fold = model(random_input)

# Check to make sure rebuild Quantization Sim Model working properly
sim.export(path="/tmp", filename_prefix="temp_bn_fold_to_scale")

Expand Down

0 comments on commit 2a8b8e1

Please sign in to comment.