Skip to content

Commit

Permalink
[TIR][Schedule] Annotate allows array as annotaton value (apache#9920)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
  • Loading branch information
7 people authored and crazydemo committed Jan 27, 2022
1 parent 037b9ee commit 7233afb
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
4 changes: 2 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,7 +1740,7 @@ def annotate(
self,
block_or_loop: Union[BlockRV, LoopRV],
ann_key: str,
ann_val: Union[str, int, float, ExprRV],
ann_val: Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]],
) -> None:
"""Annotate a block/loop with a key value pair
Expand All @@ -1750,7 +1750,7 @@ def annotate(
The block/loop to be annotated
ann_key : str
The annotation key
ann_val : Union[str, int, float, ExprRV]
ann_val : Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]]
The annotation value
Examples
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,14 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_
<< "TypeError: runtime::String is expected, but gets StringImm";
return this->Get(GetRef<PrimExpr>(expr));
}
if (const auto* arr = ann_val.as<ArrayNode>()) {
Array<ObjectRef> result;
result.reserve(arr->size());
for (size_t i = 0; i < arr->size(); i++) {
result.push_back(CheckAndGetAnnotationValue(arr->at(i)));
}
return std::move(result);
}
LOG(FATAL)
<< "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but "
<< "gets: " << ann_val->GetTypeKey();
Expand Down
13 changes: 1 addition & 12 deletions src/tir/schedule/primitive/annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,7 @@ struct AnnotateTraits : public UnpackedInstTraits<AnnotateTraits> {
PythonAPICall py("annotate");
py.Input("block_or_loop", block_or_loop_rv);
py.Input("ann_key", ann_key);
if (const auto* int_imm = ann_val.as<IntImmNode>()) {
py.Input("ann_val", std::to_string(int_imm->value));
} else if (const auto* str_imm = ann_val.as<StringObj>()) {
py.Input("ann_val", GetRef<String>(str_imm));
} else if (const auto* expr = ann_val.as<PrimExprNode>()) {
std::ostringstream os;
os << GetRef<PrimExpr>(expr);
py.Input("ann_val", os.str());
} else {
LOG(FATAL) << "TypeError: Cannot handle type: " << ann_val->GetTypeKey();
throw;
}
py.Input("ann_val", ann_val);
return py.Str();
}

Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ Array<ObjectRef> TranslateInputRVs(
} else if (input->IsInstance<IntImmNode>() || input->IsInstance<FloatImmNode>()) {
// Case 3. integer or floating-point number
results.push_back(input);
} else if (input->IsInstance<ArrayNode>()) {
// Case 4: array
results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_names));
} else if (input->IsInstance<BlockRVNode>() || inputs->IsInstance<LoopRVNode>() ||
inputs->IsInstance<VarNode>()) {
LOG(FATAL) << "IndexError: Random variable is not defined " << input;
Expand All @@ -136,6 +139,11 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
results.push_back(input);
continue;
}
// Case 4. array
if (input->IsInstance<ArrayNode>()) {
results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), named_rvs));
continue;
}
const auto* str = input.as<StringObj>();
CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey();
CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names";
Expand Down
8 changes: 6 additions & 2 deletions tests/python/unittest/test_tir_schedule_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
C = T.alloc_buffer((1024, 1024))
D = T.match_buffer(d, (1024, 1024))
for i in T.serial(0, 1024, annotations={"test1": "aaa"}):
for j in T.serial(0, 1024, annotations={"test2": 612}):
for j in T.serial(0, 1024, annotations={"test2": 612, "test3": ["aa", 1]}):
for k in T.serial(0, 1024):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
Expand Down Expand Up @@ -97,7 +97,7 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
for i, j in T.grid(1024, 1024):
with T.block("relu"):
vi, vj = T.axis.remap("SS", [i, j])
T.block_attr({"test2": 0.22})
T.block_attr({"test2": 0.22, "test3": ["aa", 1]})
D[vi, vj] = T.max(C[vi, vj], 0.0)


Expand Down Expand Up @@ -245,10 +245,12 @@ def test_annotate_unannotate_loop():
relu = sch.get_block("relu")
sch.annotate(sch.get_loops(matmul)[0], "test1", "aaa")
sch.annotate(sch.get_loops(matmul)[1], "test2", 612)
sch.annotate(sch.get_loops(matmul)[1], "test3", ["aa", 1])
tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann1)
verify_trace_roundtrip(sch=sch, mod=matmul_relu)
sch.unannotate(sch.get_loops(matmul)[0], "test1")
sch.unannotate(sch.get_loops(matmul)[1], "test2")
sch.unannotate(sch.get_loops(matmul)[1], "test3")
verify_trace_roundtrip(sch=sch, mod=matmul_relu)


Expand All @@ -258,10 +260,12 @@ def test_annotate_unannotate_block():
relu = sch.get_block("relu")
sch.annotate(matmul, "test1", "aaa")
sch.annotate(relu, "test2", 0.22)
sch.annotate(relu, "test3", ["aa", 1])
tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann2)
verify_trace_roundtrip(sch=sch, mod=matmul_relu)
sch.unannotate(matmul, "test1")
sch.unannotate(relu, "test2")
sch.unannotate(relu, "test3")
verify_trace_roundtrip(sch=sch, mod=matmul_relu)


Expand Down

0 comments on commit 7233afb

Please sign in to comment.