From 309dfef4890ca823cf6e32dea2d8066566e0a689 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 15 Apr 2021 00:55:14 +0800 Subject: [PATCH] [BugFix] Print doubles with precision 17 in SaveJSON and TVM script printer (#7846) * [BugFix] SaveJSON type double with precision 17 * [BugFix] Fix for TVM script printer --- src/node/serialization.cc | 4 ++-- src/printer/tvmscript_printer.cc | 3 +++ tests/python/unittest/test_node_reflection.py | 11 +++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/node/serialization.cc b/src/node/serialization.cc index ad42799b55e5d..75f03fbc79546 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -214,8 +214,8 @@ class JSONAttrGetter : public AttrVisitor { void Visit(const char* key, double* value) final { std::ostringstream s; - // Type have approximately 16 decimal digits - s.precision(16); + // Save 17 decimal digits for type to avoid precision loss during loading JSON + s.precision(17); s << (*value); node_->attrs[key] = s.str(); } diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 7afdcab371daa..f6586ce7000a5 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -232,6 +232,9 @@ class TVMScriptPrinter : public StmtFunctor, static Doc PrintConstScalar(DataType dtype, const T* data) { Doc doc; std::ostringstream os; + if (dtype.is_float() || dtype.is_float16() || dtype.is_bfloat16()) { + os.precision(17); + } os << data[0]; if (dtype == DataType::Int(32)) { doc << Doc::Text(os.str()); diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index 67c8283b641bd..c1298b56f7fbd 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -43,6 +43,16 @@ def test_infinity_value(): _test_infinity_value(float("-inf"), "float32") +def _test_minmax_value(value): + json_str = tvm.ir.save_json(value) + tvm.ir.assert_structural_equal(value, tvm.ir.load_json(json_str)) + + +def test_minmax_value(): + _test_minmax_value(tvm.tir.min_value("float32")) + _test_minmax_value(tvm.tir.max_value("float32")) + + def test_make_smap(): # save load json x = tvm.tir.const(1, "int32") @@ -160,3 +170,4 @@ def test_dict(): test_pass_config() test_dict() test_infinity_value() + test_minmax_value()