Skip to content

Commit

Permalink
[TIR] Add 'global_symbol' and 'tir.noalias' as default attributes in …
Browse files Browse the repository at this point in the history
…script auto completion (apache#9744)

* [TVMScript] Update default TIR prefix to T

* [TIR] Add 'global_symbol' and 'tir.noalias' as default attributes in script auto completion

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
  • Loading branch information
7 people authored and ylc committed Jan 13, 2022
1 parent 42b7b6d commit 7fd300f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintBody(const Stmt& body, bool indent = true);
};

String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false);
String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool show_meta = false);

String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate);
Expand Down
13 changes: 6 additions & 7 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,16 +302,15 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
ICHECK(it != info.tensor2buffers.end());
buffer_map.Set(arg, it->second);
}
PrimFunc func = PrimFunc(/*params=*/std::move(parameters),
/*body=*/SeqStmt::Flatten(root_stmts),
/*ret_type=*/VoidType(),
/*buffer_map=*/std::move(buffer_map));

PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
/*body=*/SeqStmt::Flatten(root_stmts),
/*ret_type=*/VoidType(),
/*buffer_map=*/std::move(buffer_map)),
{{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}});
const auto* complete = runtime::Registry::Get("script.Complete");
ICHECK(complete);

return (*complete)(func, info.root_alloc);
} // namespace tir
}

PrimFunc CreatePrimFuncFromOutputs(const Array<te::Tensor>& outputs) {
std::vector<te::Tensor> stack;
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def te_matmul():

@T.prim_func
def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))
Expand All @@ -75,6 +76,7 @@ def te_element_wise():

@T.prim_func
def tir_element_wise(a: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, (128, 128))
C = T.match_buffer(c, (128, 128))
B = T.alloc_buffer((128, 128))
Expand Down Expand Up @@ -126,6 +128,7 @@ def te_conv2d():

@T.prim_func
def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, [16, 16, 14, 14])
W = T.match_buffer(w, [16, 3, 3, 32])
B = T.match_buffer(b, [16, 32, 14, 14])
Expand Down Expand Up @@ -163,6 +166,7 @@ def te_multi_output():

@T.prim_func
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
A0 = T.match_buffer(a0, (m, n))
Expand Down Expand Up @@ -199,6 +203,7 @@ def te_extern():

@T.prim_func
def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))
Expand Down Expand Up @@ -257,6 +262,7 @@ def te_reordered_matmul():

@T.prim_func
def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))
Expand Down

0 comments on commit 7fd300f

Please sign in to comment.