From 1af5869f19916f3a9f65934d64c95c7488ebc48f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 12 Jan 2022 19:56:06 -0500 Subject: [PATCH] [TIR][Schedule] Annotate allows array as annotaton value Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou --- python/tvm/tir/schedule/schedule.py | 4 ++-- src/tir/schedule/concrete_schedule.cc | 8 ++++++++ src/tir/schedule/primitive/annotate.cc | 13 +------------ src/tir/schedule/trace.cc | 8 ++++++++ .../python/unittest/test_tir_schedule_utilities.py | 8 ++++++-- 5 files changed, 25 insertions(+), 16 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 50905eed9169..b261fd0a7518 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1740,7 +1740,7 @@ def annotate( self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str, - ann_val: Union[str, int, float, ExprRV], + ann_val: Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]], ) -> None: """Annotate a block/loop with a key value pair @@ -1750,7 +1750,7 @@ def annotate( The block/loop to be annotated ann_key : str The annotation key - ann_val : Union[str, int, float, ExprRV] + ann_val : Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]] The annotation value Examples diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 37d896a71196..9e5b6f949feb 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -608,6 +608,14 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_ << "TypeError: runtime::String is expected, but gets StringImm"; return this->Get(GetRef(expr)); } + if (const auto* arr = ann_val.as()) { + Array result; + result.reserve(arr->size()); + for (size_t i = 0; i < arr->size(); i++) { + result.push_back(CheckAndGetAnnotationValue(arr->at(i))); + } + return std::move(result); + } LOG(FATAL) << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but " << "gets: " << ann_val->GetTypeKey(); diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 0c79d55fcd86..f5c1978a1b25 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -112,18 +112,7 @@ struct AnnotateTraits : public UnpackedInstTraits { PythonAPICall py("annotate"); py.Input("block_or_loop", block_or_loop_rv); py.Input("ann_key", ann_key); - if (const auto* int_imm = ann_val.as()) { - py.Input("ann_val", std::to_string(int_imm->value)); - } else if (const auto* str_imm = ann_val.as()) { - py.Input("ann_val", GetRef(str_imm)); - } else if (const auto* expr = ann_val.as()) { - std::ostringstream os; - os << GetRef(expr); - py.Input("ann_val", os.str()); - } else { - LOG(FATAL) << "TypeError: Cannot handle type: " << ann_val->GetTypeKey(); - throw; - } + py.Input("ann_val", ann_val); return py.Str(); } diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index d8c18f0de0d6..dc05f10cc4f8 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -114,6 +114,9 @@ Array TranslateInputRVs( } else if (input->IsInstance() || input->IsInstance()) { // Case 3. integer or floating-point number results.push_back(input); + } else if (input->IsInstance()) { + // Case 4: array + results.push_back(TranslateInputRVs(Downcast>(input), rv_names)); } else if (input->IsInstance() || inputs->IsInstance() || inputs->IsInstance()) { LOG(FATAL) << "IndexError: Random variable is not defined " << input; @@ -136,6 +139,11 @@ Array TranslateInputRVs(const Array& inputs, results.push_back(input); continue; } + // Case 4. array + if (input->IsInstance()) { + results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); + continue; + } const auto* str = input.as(); CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey(); CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names"; diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index e01d469d8ec5..5ec4a1120923 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -68,7 +68,7 @@ def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None: C = T.alloc_buffer((1024, 1024)) D = T.match_buffer(d, (1024, 1024)) for i in T.serial(0, 1024, annotations={"test1": "aaa"}): - for j in T.serial(0, 1024, annotations={"test2": 612}): + for j in T.serial(0, 1024, annotations={"test2": 612, "test3": ["aa", 1]}): for k in T.serial(0, 1024): with T.block("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) @@ -97,7 +97,7 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None: for i, j in T.grid(1024, 1024): with T.block("relu"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"test2": 0.22}) + T.block_attr({"test2": 0.22, "test3": ["aa", 1]}) D[vi, vj] = T.max(C[vi, vj], 0.0) @@ -245,10 +245,12 @@ def test_annotate_unannotate_loop(): relu = sch.get_block("relu") sch.annotate(sch.get_loops(matmul)[0], "test1", "aaa") sch.annotate(sch.get_loops(matmul)[1], "test2", 612) + sch.annotate(sch.get_loops(matmul)[1], "test3", ["aa", 1]) tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann1) verify_trace_roundtrip(sch=sch, mod=matmul_relu) sch.unannotate(sch.get_loops(matmul)[0], "test1") sch.unannotate(sch.get_loops(matmul)[1], "test2") + sch.unannotate(sch.get_loops(matmul)[1], "test3") verify_trace_roundtrip(sch=sch, mod=matmul_relu) @@ -258,10 +260,12 @@ def test_annotate_unannotate_block(): relu = sch.get_block("relu") sch.annotate(matmul, "test1", "aaa") sch.annotate(relu, "test2", 0.22) + sch.annotate(relu, "test3", ["aa", 1]) tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann2) verify_trace_roundtrip(sch=sch, mod=matmul_relu) sch.unannotate(matmul, "test1") sch.unannotate(relu, "test2") + sch.unannotate(relu, "test3") verify_trace_roundtrip(sch=sch, mod=matmul_relu)