diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index da7bc12619bd..59cbfb914649 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -226,7 +226,7 @@ class DictAttrsNode : public BaseAttrsNode { class DictAttrs : public Attrs { public: /*! - * \brief Consruct a Attrs backed by DictAttrsNode. + * \brief Construct a Attrs backed by DictAttrsNode. * \param dict The attributes. * \return The dict attributes. */ diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index aa5f271c20c2..9152b8d170dd 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -66,11 +66,11 @@ class PrimFuncSpecializer : public StmtExprMutator { } } - // Updating parmeters + // Updating parameters Array params; bool param_updated = false; for (const auto& var : f->params) { - // Remove parmeters which has been specialized. + // Remove parameters which have been specialized. if (var_map.find(var) == var_map.end()) { params.push_back(var); } else { @@ -78,13 +78,34 @@ class PrimFuncSpecializer : public StmtExprMutator { } } + // Updating the `attrs` dictionary + DictAttrs attrs{nullptr}; + bool attrs_updated = false; + if (f->attrs.defined()) { + Map dict; + for (const std::pair& kv : f->attrs->dict) { + const auto* expr = kv.second.as(); + if (expr == nullptr) { + dict.Set(kv.first, kv.second); + continue; + } + PrimExpr result = Substitute(GetRef(expr), var_map); + dict.Set(kv.first, result); + if (result.get() != expr) { + attrs_updated = true; + } + } + attrs = DictAttrs(dict); + } + // Updating function body Stmt body = specializer(f->body); - if (param_updated || buffer_map_updated || !f->body.same_as(body)) { + if (param_updated || buffer_map_updated || attrs_updated || !f->body.same_as(body)) { PrimFuncNode* f_ptr = f.CopyOnWrite(); f_ptr->params = std::move(params); f_ptr->buffer_map = std::move(buffer_map); + f_ptr->attrs = std::move(attrs); f_ptr->body = std::move(body); } return f; @@ -248,9 +269,9 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer // build var mapping using specific_buf's parameters auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) { if (!equal(new_expr, old_expr)) { - CHECK(old_expr->IsInstance()) - << "TypeError: The signature of target buffer exprected an independent Var, but got " - << old_expr << "."; + CHECK(old_expr->IsInstance()) << "TypeError: The signature of target buffer is " + "expected to be an independent Var, but it is " + << old_expr << "."; const Var& var = Downcast(old_expr); auto it = var_map->find(var); if (it != var_map->end()) { @@ -282,7 +303,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset); // Check data_alignment and offset_factor. - // These two signatures are int, so we do not need map them. + // These two signatures are int, so we do not need to map them. CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment) << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment << " vs. " << specific_buf->data_alignment << "."; diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index 2e9f1110732a..ef5f9b5d9b80 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -116,6 +116,38 @@ def element_wise_128_n(a: ty.handle, c: ty.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 +@tvm.script.tir +def element_wise_func_attr(a: ty.handle, c: ty.handle) -> None: + m = tir.var("int32") + n = tir.var("int32") + tir.func_attr({"test_attr": m * n}) + + A = tir.match_buffer(a, (m, n), "float32") + C = tir.match_buffer(c, (m, n), "float32") + + B = tir.alloc_buffer((m, n), "float32") + + with tir.block([m, n], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + with tir.block([m, n], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_func_attr_128_64(a: ty.handle, c: ty.handle) -> None: + tir.func_attr({"test_attr": 128 * 64}) + A = tir.match_buffer(a, (128, 64), "float32") + C = tir.match_buffer(c, (128, 64), "float32") + B = tir.alloc_buffer((128, 64), "float32") + + with tir.block([128, 64], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + with tir.block([128, 64], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + @tvm.script.tir def mem_copy( a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32 @@ -172,6 +204,10 @@ def test_specialize_elemwise(): # partially specialized func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))}) tvm.ir.assert_structural_equal(func, element_wise_128_n) + # specialization also updates the attributes dictionary + a_attr, _ = element_wise_func_attr.params + func = element_wise_func_attr.specialize({a_attr: tir.decl_buffer((128, 64))}) + tvm.ir.assert_structural_equal(func, element_wise_func_attr_128_64) def test_specialize_mem_copy():