Skip to content

Commit

Permalink
[BugFix] Print doubles with precision 17 in SaveJSON and TVM script p…
Browse files Browse the repository at this point in the history
…rinter (apache#7846)

* [BugFix] SaveJSON type double with precision 17

* [BugFix] Fix for TVM script printer
  • Loading branch information
MasterJH5574 authored and Trevor Morris committed May 6, 2021
1 parent 1f051c2 commit 1b9280c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ class JSONAttrGetter : public AttrVisitor {

void Visit(const char* key, double* value) final {
std::ostringstream s;
// Type <double> have approximately 16 decimal digits
s.precision(16);
// Save 17 decimal digits for type <double> to avoid precision loss during loading JSON
s.precision(17);
s << (*value);
node_->attrs[key] = s.str();
}
Expand Down
3 changes: 3 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
static Doc PrintConstScalar(DataType dtype, const T* data) {
Doc doc;
std::ostringstream os;
if (dtype.is_float() || dtype.is_float16() || dtype.is_bfloat16()) {
os.precision(17);
}
os << data[0];
if (dtype == DataType::Int(32)) {
doc << Doc::Text(os.str());
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_node_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ def test_infinity_value():
_test_infinity_value(float("-inf"), "float32")


def _test_minmax_value(value):
json_str = tvm.ir.save_json(value)
tvm.ir.assert_structural_equal(value, tvm.ir.load_json(json_str))


def test_minmax_value():
_test_minmax_value(tvm.tir.min_value("float32"))
_test_minmax_value(tvm.tir.max_value("float32"))


def test_make_smap():
# save load json
x = tvm.tir.const(1, "int32")
Expand Down Expand Up @@ -160,3 +170,4 @@ def test_dict():
test_pass_config()
test_dict()
test_infinity_value()
test_minmax_value()

0 comments on commit 1b9280c

Please sign in to comment.