From 8dd5e24325577a0b3d6c62152dfe962ed5fd14db Mon Sep 17 00:00:00 2001
From: yathindra kota
Date: Wed, 7 Jun 2023 12:33:13 -0700
Subject: [PATCH] Use aimet FP16 quantization flow instead of TF flow for GPU

Signed-off-by: yathindra kota
---
 .../tensorflow/src/AimetFp16OpUtils.h         | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/TrainingExtensions/tensorflow/src/AimetFp16OpUtils.h b/TrainingExtensions/tensorflow/src/AimetFp16OpUtils.h
index 3f097a89fc..178958e566 100644
--- a/TrainingExtensions/tensorflow/src/AimetFp16OpUtils.h
+++ b/TrainingExtensions/tensorflow/src/AimetFp16OpUtils.h
@@ -40,6 +40,7 @@
 #define AIMET_FP16_OP_UTILS_H
 
 #include "AimetOpUtils.h"
+#include "DlQuantization/Fp16Quantization.hpp"
 
 #define EIGEN_USE_THREADS
 
@@ -86,18 +87,12 @@ class QuantizeDequantizeFp16Functor
 template <>
 class QuantizeDequantizeFp16Functor<GPUDevice>
 {
-    // truncate, if set to true would truncate the inputs before casting to fp16. If set to true, tensorflow backend
-    // calls LSBZeroSetter which does the truncate operation
-    bool _truncate = false;
-
 public:
     void operator()(OpKernelContext* context, const Tensor& inTensor, Tensor* outTensor)
     {
-        Tensor tempTensorFp16;
-        OP_REQUIRES_OK(context, context->allocate_temp(DT_HALF, inTensor.shape(), &tempTensorFp16));
-
-        GetGpuCastFromFloat(DT_HALF)(context, inTensor, &tempTensorFp16, _truncate);
-        GetGpuCastFromHalf(DT_FLOAT)(context, tempTensorFp16, outTensor, _truncate);
+        DlQuantization::quantizeDequantizeFp16Gpu(inTensor.flat<float>().data(),
+                                                  inTensor.NumElements(),
+                                                  outTensor->flat<float>().data());
     }
 };
 #endif // GOOGLE_CUDA
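
Note: both the removed TF-cast path and the new DlQuantization call perform an fp32 -> fp16 -> fp32 round trip over the tensor data. Below is a minimal host-side reference sketch of that round trip for illustration; the function name quantizeDequantizeFp16Reference and the main() driver are hypothetical and not AIMET code, and Eigen::half is used only because Eigen is already a TensorFlow dependency.

    #include <Eigen/Core>   // provides Eigen::half
    #include <cstdio>
    #include <vector>

    // Hypothetical CPU reference for the fp16 quantize-dequantize round trip:
    // cast each fp32 value to fp16 and back, which is the behavior the removed
    // GetGpuCastFromFloat/GetGpuCastFromHalf pair implemented via a temp tensor
    // and that quantizeDequantizeFp16Gpu is assumed to compute on device.
    static void quantizeDequantizeFp16Reference(const float* in, long long count, float* out)
    {
        for (long long i = 0; i < count; ++i)
            out[i] = static_cast<float>(Eigen::half(in[i]));
    }

    int main()
    {
        std::vector<float> in = {0.1f, 1.0f, 65504.0f, 1.0e-8f};  // 65504 is the fp16 max normal value
        std::vector<float> out(in.size());
        quantizeDequantizeFp16Reference(in.data(), static_cast<long long>(in.size()), out.data());
        for (size_t i = 0; i < in.size(); ++i)
            std::printf("%.9g -> %.9g\n", in[i], out[i]);  // exposes fp16 rounding/flush-to-zero
        return 0;
    }

The single fused call avoids allocating the intermediate DT_HALF tensor that the old path needed, which is why the allocate_temp and the two cast launches are deleted together.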