[Parser] Fix tokenizing inf (apache#7370)
* fix tokenizing inf

* use ParseNumber to parse inf, handle -inf

* fix neg handling

* fixed multi negation

* refactor

* use while loop

* simplifying

* fix lint

* simpler implementation per altan's suggestion

* disable flaky test
masahi authored and alexwong committed Feb 11, 2021
1 parent f9d2196 commit 8ec1b7a
Showing 3 changed files with 47 additions and 31 deletions.
61 changes: 33 additions & 28 deletions src/parser/tokenizer.h
@@ -212,6 +212,25 @@ struct Tokenizer {
}
}

Token ParseNumber(bool is_pos) {
std::stringstream ss;
while (More() && IsNumeric(Peek())) {
ss << Next();
}

bool is_float = false;

// Remove trailing floating point prefix.
if (More() && Peek() == 'f') {
ss << Next();
while (More() && IsNumeric(Peek())) {
ss << Next();
}
is_float = true;
}
return ParseNumber(is_pos, is_float, ss.str());
}

bool MatchString(const std::string& string) {
int start = this->pos;

@@ -340,38 +359,28 @@ struct Tokenizer {
auto token = NewToken(TokenType::kWhitespace);
Next();
return token;
} else if (IsDigit(next) || next == '-') {
} else if (next == '-') {
int negs = 0;
while (More() && Peek() == '-') {
Next();
negs++;
}
// If there isn't a number right after either,
// this is really slow for lexing, should replace
// with multi-token return or something.
if (negs && !IsDigit(Peek())) {
bool is_neg = negs % 2 == 1;
if (More() && IsDigit(Peek())) {
return ParseNumber(!is_neg);
} else if (More() && MatchString("inff")) {
return ParseNumber(!is_neg, true, "inff");
} else {
// If there isn't a number right after either,
// this is really slow for lexing, should replace
// with multi-token return or something.
pos = pos - (negs - 1);
return NewToken(TokenType::kMinus);
}

bool is_neg = negs % 2 == 1;
std::stringstream ss;
while (More() && IsNumeric(Peek())) {
ss << Next();
}

bool is_float = false;

// Remove trailing floating point prefix.
if (More() && Peek() == 'f') {
ss << Next();
while (More() && IsNumeric(Peek())) {
ss << Next();
}
is_float = true;
}

return ParseNumber(!is_neg, is_float, ss.str());
} else if (IsDigit(next)) {
return ParseNumber(true);
} else if (MatchString("inff")) {
return ParseNumber(true, true, "inff");
} else if (next == '.') {
auto token = NewToken(TokenType::kPeriod);
Next();
@@ -404,10 +413,6 @@
auto token = NewToken(TokenType::kPlus);
Next();
return token;
} else if (next == '-') {
auto token = NewToken(TokenType::kMinus);
Next();
return token;
} else if (next == '*') {
auto token = NewToken(TokenType::kStar);
Next();
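Below is a minimal, self-contained Python sketch of the lexing idea in the tokenizer.h change above: consume a run of leading '-' characters and use its parity as the sign, then accept either a numeric literal (with an optional f<width> suffix) or the literal "inff", and otherwise back up so a single minus token is emitted. The helper name lex_minus_or_number and its tuple-based tokens are invented for illustration; this is not TVM's actual tokenizer.

# Illustration only, not TVM code: mirrors the '-' run / number / "inff" handling above.
def lex_minus_or_number(src, pos):
    negs = 0
    while pos < len(src) and src[pos] == "-":
        negs += 1
        pos += 1
    is_neg = negs % 2 == 1

    if pos < len(src) and src[pos].isdigit():
        start = pos
        while pos < len(src) and (src[pos].isdigit() or src[pos] == "."):
            pos += 1
        is_float = "." in src[start:pos]
        if pos < len(src) and src[pos] == "f":  # trailing width suffix, e.g. "3.5f32"
            pos += 1
            while pos < len(src) and src[pos].isdigit():
                pos += 1
            is_float = True
        digits = src[start:pos].split("f")[0]
        value = float(digits) if is_float else int(digits)
        return ("FLOAT" if is_float else "INT", -value if is_neg else value), pos

    if src.startswith("inff", pos):
        value = float("-inf") if is_neg else float("inf")
        return ("FLOAT", value), pos + len("inff")

    # No numeral after the '-' run: rewind all but one '-' and report a minus token,
    # mirroring `pos = pos - (negs - 1)` in the C++ code.
    return ("MINUS", "-"), pos - (negs - 1)

assert lex_minus_or_number("--3.5f32", 0) == (("FLOAT", 3.5), 8)
assert lex_minus_or_number("-inff", 0) == (("FLOAT", float("-inf")), 5)
assert lex_minus_or_number("-x", 0) == (("MINUS", "-"), 1)
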
3 changes: 2 additions & 1 deletion tests/python/contrib/test_cudnn.py
@@ -94,7 +94,8 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1):
def test_conv2d():
verify_conv2d("float32", "float32", tensor_format=0)
verify_conv2d("float16", "float32", tensor_format=1)
verify_conv2d("float16", "float16", tensor_format=0)
# This test is flaky, disable for now
# verify_conv2d("float16", "float16", tensor_format=0)
verify_conv2d("int8", "int32", tensor_format=1)

verify_conv2d("float32", "float32", tensor_format=0, groups=2)
14 changes: 12 additions & 2 deletions tests/python/relay/test_ir_parser.py
@@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
from tvm import te
from tvm import relay
import tvm.relay.testing
import pytest
from numpy import isclose
from typing import Union
from functools import wraps


SEMVER = '#[version = "0.0.5"]\n'
@@ -910,6 +910,16 @@ def test_load_prelude():
tvm.parser.parse(mod.astext())


def test_tokenize_inf():
x = relay.var("x", shape=(3, 4), dtype="float32")
y = relay.clip(x, -np.inf, np.inf)

f = relay.Function([x], y)
mod = tvm.IRModule.from_expr(f)

mod = relay.transform.AnnotateSpans()(mod)


if __name__ == "__main__":
import sys

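As a companion to the new test, the sketch below (an illustration, not part of the commit) round-trips the same module text through tvm.parser.parse, the call already used in test_load_prelude above. It assumes the printed text of the clip bounds contains the infinity literal that this commit teaches the tokenizer to accept.

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")
f = relay.Function([x], relay.clip(x, -np.inf, np.inf))
mod = tvm.IRModule.from_expr(f)

# Print the module to text and parse it back; the clip bounds are assumed to
# print as infinity literals, exercising the tokenizer change above.
reparsed = tvm.parser.parse(mod.astext())
print(reparsed)
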
