diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 569e8f042486..c2b4fdb2d00e 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -68,15 +68,28 @@ from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel +def _should_print_backtrace(): + in_pytest = "PYTEST_CURRENT_TEST" in os.environ + tvm_backtrace = os.environ.get("TVM_BACKTRACE", "0") + + try: + tvm_backtrace = bool(int(tvm_backtrace)) + except ValueError: + raise ValueError( + f"invalid value for TVM_BACKTRACE `{tvm_backtrace}`, please set to 0 or 1." + ) + + return in_pytest or tvm_backtrace + + def tvm_wrap_excepthook(exception_hook): """Wrap given excepthook with TVM additional work.""" def wrapper(exctype, value, trbk): """Clean subprocesses when TVM is interrupted.""" - in_pytest = "PYTEST_CURRENT_TEST" in os.environ - - if exctype is error.DiagnosticError and not in_pytest: - pass + if exctype is error.DiagnosticError and not _should_print_backtrace(): + # TODO(@jroesch): consider moving to C++? + print("note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.") else: exception_hook(exctype, value, trbk) diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 987a6e20ec38..afcf70737933 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -605,30 +605,43 @@ class Parser { return ast; } + struct MetaRef { + std::string type_key; + uint64_t node_index; + Span span; + MetaRef(std::string type_key, uint64_t node_index, Span span) + : type_key(type_key), node_index(node_index), span(span) {} + }; + + MetaRef MetaRefFromToken(const Token& tok) { + Call ref = Downcast(tok->data); + auto attrs = ref->attrs.as(); + auto type_key = attrs->node_type_key; + auto index = attrs->node_index; + return MetaRef(type_key, index, ref->span); + } + /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`. * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]` * the second, and so on. */ ObjectRef ParseMetaRef() { - auto meta_ref = Match(TokenType::kMetaReference); - Call ref = Downcast(meta_ref->data); - auto attrs = ref->attrs.as(); - auto type_key = attrs->node_type_key; - auto index = attrs->node_index; - auto it = this->meta_table.find(type_key); + auto meta_ref_tok = Match(TokenType::kMetaReference); + auto meta_ref = MetaRefFromToken(meta_ref_tok); + auto it = this->meta_table.find(meta_ref.type_key); if (it != this->meta_table.end()) { auto nodes = (*it).second; - if (index < nodes.size()) { - return nodes[index]; + if (meta_ref.node_index < nodes.size()) { + return nodes[meta_ref.node_index]; } else { - this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span) - << "the node index `" << index << "` is out of bounds for `" << type_key - << "`"); + this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) + << "the node index `" << meta_ref.node_index + << "` is out of bounds for `" << meta_ref.type_key << "`"); return ObjectRef(); } } else { - this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span) - << "no entry in the meta table for `" << type_key << "`"); + this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span) + << "no entry in the meta table for `" << meta_ref.type_key << "`"); return ObjectRef(); } } @@ -922,10 +935,7 @@ class Parser { exprs.push_back(ParseMatch(is_total)); break; } - case TokenType::kIf: { - exprs.push_back(ParseIf()); - break; - } + // %x ... case TokenType::kGraph: if (Lookahead(2)->token_type == TokenType::kEqual) { @@ -1344,6 +1354,10 @@ class Parser { Match(TokenType::kIdentifier); return ObjectRef(); } + if (id == "None") { + Match(TokenType::kIdentifier); + return Optional(); + } } } default: @@ -1372,7 +1386,7 @@ class Parser { ICHECK(op.defined()) << "the operator must be defined"; DLOG(INFO) << "Parser::ParseCallArgs"; - Map raw_attrs; + Attrs attrs; std::string op_key; bool is_op = false; @@ -1388,21 +1402,40 @@ class Parser { [&] { auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier; auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual; - - if (is_op && is_ident && next_is_equal) { - raw_attrs = ParseAttrs(); + auto is_pretty_attrs = is_ident && next_is_equal; + auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference; + // TODO(@jroesch): might not handle trailing comma + auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen; + auto is_meta_attrs = is_meta_next && last_meta; + + if (is_op && (is_pretty_attrs || is_meta_attrs)) { + if (is_meta_attrs) { + auto meta_ref = ParseMetaRef(); + if (meta_ref.as()) { + attrs = Downcast(meta_ref); + } else { + // Not awesome parsing code here. + this->pos--; + return false; + } + } else { + auto raw_attrs = ParseAttrs(); + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); + ICHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } return true; } return false; }); - Attrs attrs; - - if (is_op && op_key.size()) { - auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs); - ICHECK(attr_obj.defined()); - attrs = Downcast(attr_obj); + if (!attrs.defined()) { + if (is_op && op_key.size()) { + auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {}); + ICHECK(attr_obj.defined()); + attrs = Downcast(attr_obj); + } } // TODO(@jroesch): in a secondary pass adjust spans. @@ -1527,6 +1560,10 @@ class Parser { ICHECK(e->span.defined()) << "function spans must be defined.\n" << e; return e; } + case TokenType::kIf: { + Expr e = ParseIf(); + return e; + } case TokenType::kRef: { Consume(TokenType::kRef); Match(TokenType::kOpenParen); diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index c5217ba41bfd..162271756557 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -875,6 +875,20 @@ def @example() { parse_module(program) +def test_parse_if_in_binding(): + program = """ + def @example(%b: bool) { + %0 = if (%b) { + 1 + } else { + 0 + }; + %0 + } + """ + parse_module(program) + + def test_op_string_attr(): call = parse_text( """