From 8ec1b7a1369edddd49c143389d2272c76bfb6e58 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 2 Feb 2021 06:39:41 +0900 Subject: [PATCH] [Parser] Fix tokenizing inf (#7370) * fix tokenizing inf * use ParseNumber to parse inf, handle -inf * fix neg handling * fixed multi negation * refactor * use while loop * simplifying * fix lint * simpler implementation per altan's suggestion * disable flaky test --- src/parser/tokenizer.h | 61 +++++++++++++++------------- tests/python/contrib/test_cudnn.py | 3 +- tests/python/relay/test_ir_parser.py | 14 ++++++- 3 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index c6fb3e09f4d1..5e71794cc7fb 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -212,6 +212,25 @@ struct Tokenizer { } } + Token ParseNumber(bool is_pos) { + std::stringstream ss; + while (More() && IsNumeric(Peek())) { + ss << Next(); + } + + bool is_float = false; + + // Remove trailing floating point prefix. + if (More() && Peek() == 'f') { + ss << Next(); + while (More() && IsNumeric(Peek())) { + ss << Next(); + } + is_float = true; + } + return ParseNumber(is_pos, is_float, ss.str()); + } + bool MatchString(const std::string& string) { int start = this->pos; @@ -340,38 +359,28 @@ struct Tokenizer { auto token = NewToken(TokenType::kWhitespace); Next(); return token; - } else if (IsDigit(next) || next == '-') { + } else if (next == '-') { int negs = 0; while (More() && Peek() == '-') { Next(); negs++; } - // If there isn't a number right after either, - // this is really slow for lexing, should replace - // with multi-token return or something. 
- if (negs && !IsDigit(Peek())) { + bool is_neg = negs % 2 == 1; + if (More() && IsDigit(Peek())) { + return ParseNumber(!is_neg); + } else if (More() && MatchString("inff")) { + return ParseNumber(!is_neg, true, "inff"); + } else { + // If there isn't a number right after either, + // this is really slow for lexing, should replace + // with multi-token return or something. pos = pos - (negs - 1); return NewToken(TokenType::kMinus); } - - bool is_neg = negs % 2 == 1; - std::stringstream ss; - while (More() && IsNumeric(Peek())) { - ss << Next(); - } - - bool is_float = false; - - // Remove trailing floating point prefix. - if (More() && Peek() == 'f') { - ss << Next(); - while (More() && IsNumeric(Peek())) { - ss << Next(); - } - is_float = true; - } - - return ParseNumber(!is_neg, is_float, ss.str()); + } else if (IsDigit(next)) { + return ParseNumber(true); + } else if (MatchString("inff")) { + return ParseNumber(true, true, "inff"); } else if (next == '.') { auto token = NewToken(TokenType::kPeriod); Next(); @@ -404,10 +413,6 @@ struct Tokenizer { auto token = NewToken(TokenType::kPlus); Next(); return token; - } else if (next == '-') { - auto token = NewToken(TokenType::kMinus); - Next(); - return token; } else if (next == '*') { auto token = NewToken(TokenType::kStar); Next(); diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index a17776028647..b6e6284f5893 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -94,7 +94,8 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): def test_conv2d(): verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float16", "float32", tensor_format=1) - verify_conv2d("float16", "float16", tensor_format=0) + # This test is flaky, disable for now + # verify_conv2d("float16", "float16", tensor_format=0) verify_conv2d("int8", "int32", tensor_format=1) verify_conv2d("float32", "float32", tensor_format=0, groups=2) diff --git 
a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 162271756557..70fb56049873 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm -from tvm import te from tvm import relay import tvm.relay.testing import pytest from numpy import isclose from typing import Union -from functools import wraps SEMVER = '#[version = "0.0.5"]\n' @@ -910,6 +910,16 @@ def test_load_prelude(): tvm.parser.parse(mod.astext()) +def test_tokenize_inf(): + x = relay.var("x", shape=(3, 4), dtype="float32") + y = relay.clip(x, -np.inf, np.inf) + + f = relay.Function([x], y) + mod = tvm.IRModule.from_expr(f) + + mod = relay.transform.AnnotateSpans()(mod) + + if __name__ == "__main__": import sys