Skip to content

Commit

Permalink
fix T.Ptr[T.void] for packed api roundtrip (#12118)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored Jul 22, 2022
1 parent 8c42a83 commit 4a94a94
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/tvm/_ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def c2pyerror(err_msg):
message = []
for line in arr:
if trace_mode:
if line.startswith(" "):
if line.startswith(" ") and len(stack_trace) > 0:
stack_trace[-1] += "\n" + line
elif line.startswith(" "):
stack_trace.append(line)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""TVMScript for TIR"""

# Type system
from .ty import uint8, int8, int16, int32, int64, float16, float32, float64
from .ty import uint8, int8, int16, int32, int64, float16, float32, float64, void
from .ty import boolean, handle, Ptr, Tuple, Buffer

from .prim_func import prim_func
8 changes: 8 additions & 0 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def evaluate(self):
return self.type


class VoidType(ConcreteType): # pylint: disable=too-few-public-methods, abstract-method
"""TVM script typing class for void type"""

def __init__(self):
super().__init__("")


class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
"""TVM script typing class generator for PtrType
Expand Down Expand Up @@ -202,6 +209,7 @@ def __getitem__(self, args):
float64 = ConcreteType("float64")
boolean = ConcreteType("bool")
handle = ConcreteType("handle")
void = VoidType()
Ptr = GenericPtrType()
Tuple = GenericTupleType()
# we don't have 'buffer' type on the cpp side
Expand Down
7 changes: 6 additions & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,12 @@ Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) {

Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
Doc doc;
doc << tir_prefix_ << "." << runtime::DLDataType2String(node->dtype);
doc << tir_prefix_ << ".";
if (node->dtype.is_void()) {
doc << "void";
} else {
doc << runtime::DLDataType2String(node->dtype);
}
return doc;
}

Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3297,6 +3297,14 @@ def func():
return func


def void_ptr():
@T.prim_func
def func(out_ret_value: T.Ptr[T.void]):
T.evaluate(out_ret_value)

return func


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3335,6 +3343,7 @@ def func():
buffer_axis_separator,
buffer_ramp_access_as_slice_index,
let_expression,
void_ptr,
)


Expand Down

0 comments on commit 4a94a94

Please sign in to comment.