From 4a94a94dfc9ba6e265e0847228272562989072c7 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Fri, 22 Jul 2022 12:40:01 +0800 Subject: [PATCH] fix T.Ptr[T.void] for packed api roundtrip (#12118) --- python/tvm/_ffi/base.py | 2 +- python/tvm/script/tir/__init__.py | 2 +- python/tvm/script/tir/ty.py | 8 ++++++++ src/printer/tvmscript_printer.cc | 7 ++++++- tests/python/unittest/test_tvmscript_roundtrip.py | 9 +++++++++ 5 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index e4e1fb1bb863..744e4c93e181 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -255,7 +255,7 @@ def c2pyerror(err_msg): message = [] for line in arr: if trace_mode: - if line.startswith(" "): + if line.startswith(" ") and len(stack_trace) > 0: stack_trace[-1] += "\n" + line elif line.startswith(" "): stack_trace.append(line) diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py index de4045913102..2655f5bb3362 100644 --- a/python/tvm/script/tir/__init__.py +++ b/python/tvm/script/tir/__init__.py @@ -17,7 +17,7 @@ """TVMScript for TIR""" # Type system -from .ty import uint8, int8, int16, int32, int64, float16, float32, float64 +from .ty import uint8, int8, int16, int32, int64, float16, float32, float64, void from .ty import boolean, handle, Ptr, Tuple, Buffer from .prim_func import prim_func diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 878f029e55dd..a64485b215f8 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -69,6 +69,13 @@ def evaluate(self): return self.type +class VoidType(ConcreteType): # pylint: disable=too-few-public-methods, abstract-method + """TVM script typing class for void type""" + + def __init__(self): + super().__init__("") + + class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for PtrType @@ -202,6 +209,7 @@ def __getitem__(self, args): float64 = ConcreteType("float64") boolean = ConcreteType("bool") handle = ConcreteType("handle") +void = VoidType() Ptr = GenericPtrType() Tuple = GenericTupleType() # we don't have 'buffer' type on the cpp side diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 725e105c016a..aaebc7409f29 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1236,7 +1236,12 @@ Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) { Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) { Doc doc; - doc << tir_prefix_ << "." << runtime::DLDataType2String(node->dtype); + doc << tir_prefix_ << "."; + if (node->dtype.is_void()) { + doc << "void"; + } else { + doc << runtime::DLDataType2String(node->dtype); + } return doc; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 306f60f1b1ba..8e0561bb19f9 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3297,6 +3297,14 @@ def func(): return func +def void_ptr(): + @T.prim_func + def func(out_ret_value: T.Ptr[T.void]): + T.evaluate(out_ret_value) + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3335,6 +3343,7 @@ def func(): buffer_axis_separator, buffer_ramp_access_as_slice_index, let_expression, + void_ptr, )