From 97a4ed923677c6dfd545fd654c55c424cf490a19 Mon Sep 17 00:00:00 2001 From: liuzhe Date: Thu, 10 Oct 2019 13:14:32 +0800 Subject: [PATCH] use tensorflow.compat.v1 --- src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py | 2 +- .../pynni/nni/compression/tensorflow/builtin_quantizers.py | 2 +- src/sdk/pynni/nni/compression/tensorflow/compressor.py | 2 +- src/sdk/pynni/tests/test_compressor.py | 4 +++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py index c2b7e4453d..0a977bf218 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_pruners.py @@ -1,5 +1,5 @@ import logging -import tensorflow as tf +import tensorflow.compat.v1 as tf from .compressor import Pruner __all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ] diff --git a/src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py b/src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py index a7ed2b9338..7cda8bf254 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py +++ b/src/sdk/pynni/nni/compression/tensorflow/builtin_quantizers.py @@ -1,5 +1,5 @@ import logging -import tensorflow as tf +import tensorflow.compat.v1 as tf from .compressor import Quantizer __all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ] diff --git a/src/sdk/pynni/nni/compression/tensorflow/compressor.py b/src/sdk/pynni/nni/compression/tensorflow/compressor.py index 3c0cf04d1c..76fd95ad43 100644 --- a/src/sdk/pynni/nni/compression/tensorflow/compressor.py +++ b/src/sdk/pynni/nni/compression/tensorflow/compressor.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf import logging from . import default_layers diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index 83735a20a2..34d3caa196 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -1,10 +1,12 @@ from unittest import TestCase, main -import tensorflow as tf +import tensorflow.compat.v1 as tf import torch import torch.nn.functional as F import nni.compression.tensorflow as tf_compressor import nni.compression.torch as torch_compressor +tf.disable_v2_behavior() + def weight_variable(shape): return tf.Variable(tf.truncated_normal(shape, stddev = 0.1))